Home | History | Annotate | Download | only in gofmt
      1 // Copyright 2009 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package main
      6 
      7 import (
      8 	"fmt"
      9 	"go/ast"
     10 	"go/parser"
     11 	"go/token"
     12 	"os"
     13 	"reflect"
     14 	"strings"
     15 	"unicode"
     16 	"unicode/utf8"
     17 )
     18 
     19 func initRewrite() {
     20 	if *rewriteRule == "" {
     21 		rewrite = nil // disable any previous rewrite
     22 		return
     23 	}
     24 	f := strings.Split(*rewriteRule, "->")
     25 	if len(f) != 2 {
     26 		fmt.Fprintf(os.Stderr, "rewrite rule must be of the form 'pattern -> replacement'\n")
     27 		os.Exit(2)
     28 	}
     29 	pattern := parseExpr(f[0], "pattern")
     30 	replace := parseExpr(f[1], "replacement")
     31 	rewrite = func(p *ast.File) *ast.File { return rewriteFile(pattern, replace, p) }
     32 }
     33 
     34 // parseExpr parses s as an expression.
     35 // It might make sense to expand this to allow statement patterns,
     36 // but there are problems with preserving formatting and also
     37 // with what a wildcard for a statement looks like.
     38 func parseExpr(s, what string) ast.Expr {
     39 	x, err := parser.ParseExpr(s)
     40 	if err != nil {
     41 		fmt.Fprintf(os.Stderr, "parsing %s %s at %s\n", what, s, err)
     42 		os.Exit(2)
     43 	}
     44 	return x
     45 }
     46 
     47 // Keep this function for debugging.
     48 /*
     49 func dump(msg string, val reflect.Value) {
     50 	fmt.Printf("%s:\n", msg)
     51 	ast.Print(fileSet, val.Interface())
     52 	fmt.Println()
     53 }
     54 */
     55 
     56 // rewriteFile applies the rewrite rule 'pattern -> replace' to an entire file.
     57 func rewriteFile(pattern, replace ast.Expr, p *ast.File) *ast.File {
     58 	cmap := ast.NewCommentMap(fileSet, p, p.Comments)
     59 	m := make(map[string]reflect.Value)
     60 	pat := reflect.ValueOf(pattern)
     61 	repl := reflect.ValueOf(replace)
     62 
     63 	var rewriteVal func(val reflect.Value) reflect.Value
     64 	rewriteVal = func(val reflect.Value) reflect.Value {
     65 		// don't bother if val is invalid to start with
     66 		if !val.IsValid() {
     67 			return reflect.Value{}
     68 		}
     69 		val = apply(rewriteVal, val)
     70 		for k := range m {
     71 			delete(m, k)
     72 		}
     73 		if match(m, pat, val) {
     74 			val = subst(m, repl, reflect.ValueOf(val.Interface().(ast.Node).Pos()))
     75 		}
     76 		return val
     77 	}
     78 
     79 	r := apply(rewriteVal, reflect.ValueOf(p)).Interface().(*ast.File)
     80 	r.Comments = cmap.Filter(r).Comments() // recreate comments list
     81 	return r
     82 }
     83 
     84 // set is a wrapper for x.Set(y); it protects the caller from panics if x cannot be changed to y.
     85 func set(x, y reflect.Value) {
     86 	// don't bother if x cannot be set or y is invalid
     87 	if !x.CanSet() || !y.IsValid() {
     88 		return
     89 	}
     90 	defer func() {
     91 		if x := recover(); x != nil {
     92 			if s, ok := x.(string); ok &&
     93 				(strings.Contains(s, "type mismatch") || strings.Contains(s, "not assignable")) {
     94 				// x cannot be set to y - ignore this rewrite
     95 				return
     96 			}
     97 			panic(x)
     98 		}
     99 	}()
    100 	x.Set(y)
    101 }
    102 
    103 // Values/types for special cases.
    104 var (
    105 	objectPtrNil = reflect.ValueOf((*ast.Object)(nil))
    106 	scopePtrNil  = reflect.ValueOf((*ast.Scope)(nil))
    107 
    108 	identType     = reflect.TypeOf((*ast.Ident)(nil))
    109 	objectPtrType = reflect.TypeOf((*ast.Object)(nil))
    110 	positionType  = reflect.TypeOf(token.NoPos)
    111 	callExprType  = reflect.TypeOf((*ast.CallExpr)(nil))
    112 	scopePtrType  = reflect.TypeOf((*ast.Scope)(nil))
    113 )
    114 
    115 // apply replaces each AST field x in val with f(x), returning val.
    116 // To avoid extra conversions, f operates on the reflect.Value form.
    117 func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value {
    118 	if !val.IsValid() {
    119 		return reflect.Value{}
    120 	}
    121 
    122 	// *ast.Objects introduce cycles and are likely incorrect after
    123 	// rewrite; don't follow them but replace with nil instead
    124 	if val.Type() == objectPtrType {
    125 		return objectPtrNil
    126 	}
    127 
    128 	// similarly for scopes: they are likely incorrect after a rewrite;
    129 	// replace them with nil
    130 	if val.Type() == scopePtrType {
    131 		return scopePtrNil
    132 	}
    133 
    134 	switch v := reflect.Indirect(val); v.Kind() {
    135 	case reflect.Slice:
    136 		for i := 0; i < v.Len(); i++ {
    137 			e := v.Index(i)
    138 			set(e, f(e))
    139 		}
    140 	case reflect.Struct:
    141 		for i := 0; i < v.NumField(); i++ {
    142 			e := v.Field(i)
    143 			set(e, f(e))
    144 		}
    145 	case reflect.Interface:
    146 		e := v.Elem()
    147 		set(v, f(e))
    148 	}
    149 	return val
    150 }
    151 
    152 func isWildcard(s string) bool {
    153 	rune, size := utf8.DecodeRuneInString(s)
    154 	return size == len(s) && unicode.IsLower(rune)
    155 }
    156 
    157 // match reports whether pattern matches val,
    158 // recording wildcard submatches in m.
    159 // If m == nil, match checks whether pattern == val.
    160 func match(m map[string]reflect.Value, pattern, val reflect.Value) bool {
    161 	// Wildcard matches any expression. If it appears multiple
    162 	// times in the pattern, it must match the same expression
    163 	// each time.
    164 	if m != nil && pattern.IsValid() && pattern.Type() == identType {
    165 		name := pattern.Interface().(*ast.Ident).Name
    166 		if isWildcard(name) && val.IsValid() {
    167 			// wildcards only match valid (non-nil) expressions.
    168 			if _, ok := val.Interface().(ast.Expr); ok && !val.IsNil() {
    169 				if old, ok := m[name]; ok {
    170 					return match(nil, old, val)
    171 				}
    172 				m[name] = val
    173 				return true
    174 			}
    175 		}
    176 	}
    177 
    178 	// Otherwise, pattern and val must match recursively.
    179 	if !pattern.IsValid() || !val.IsValid() {
    180 		return !pattern.IsValid() && !val.IsValid()
    181 	}
    182 	if pattern.Type() != val.Type() {
    183 		return false
    184 	}
    185 
    186 	// Special cases.
    187 	switch pattern.Type() {
    188 	case identType:
    189 		// For identifiers, only the names need to match
    190 		// (and none of the other *ast.Object information).
    191 		// This is a common case, handle it all here instead
    192 		// of recursing down any further via reflection.
    193 		p := pattern.Interface().(*ast.Ident)
    194 		v := val.Interface().(*ast.Ident)
    195 		return p == nil && v == nil || p != nil && v != nil && p.Name == v.Name
    196 	case objectPtrType, positionType:
    197 		// object pointers and token positions always match
    198 		return true
    199 	case callExprType:
    200 		// For calls, the Ellipsis fields (token.Position) must
    201 		// match since that is how f(x) and f(x...) are different.
    202 		// Check them here but fall through for the remaining fields.
    203 		p := pattern.Interface().(*ast.CallExpr)
    204 		v := val.Interface().(*ast.CallExpr)
    205 		if p.Ellipsis.IsValid() != v.Ellipsis.IsValid() {
    206 			return false
    207 		}
    208 	}
    209 
    210 	p := reflect.Indirect(pattern)
    211 	v := reflect.Indirect(val)
    212 	if !p.IsValid() || !v.IsValid() {
    213 		return !p.IsValid() && !v.IsValid()
    214 	}
    215 
    216 	switch p.Kind() {
    217 	case reflect.Slice:
    218 		if p.Len() != v.Len() {
    219 			return false
    220 		}
    221 		for i := 0; i < p.Len(); i++ {
    222 			if !match(m, p.Index(i), v.Index(i)) {
    223 				return false
    224 			}
    225 		}
    226 		return true
    227 
    228 	case reflect.Struct:
    229 		for i := 0; i < p.NumField(); i++ {
    230 			if !match(m, p.Field(i), v.Field(i)) {
    231 				return false
    232 			}
    233 		}
    234 		return true
    235 
    236 	case reflect.Interface:
    237 		return match(m, p.Elem(), v.Elem())
    238 	}
    239 
    240 	// Handle token integers, etc.
    241 	return p.Interface() == v.Interface()
    242 }
    243 
    244 // subst returns a copy of pattern with values from m substituted in place
    245 // of wildcards and pos used as the position of tokens from the pattern.
    246 // if m == nil, subst returns a copy of pattern and doesn't change the line
    247 // number information.
    248 func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value) reflect.Value {
    249 	if !pattern.IsValid() {
    250 		return reflect.Value{}
    251 	}
    252 
    253 	// Wildcard gets replaced with map value.
    254 	if m != nil && pattern.Type() == identType {
    255 		name := pattern.Interface().(*ast.Ident).Name
    256 		if isWildcard(name) {
    257 			if old, ok := m[name]; ok {
    258 				return subst(nil, old, reflect.Value{})
    259 			}
    260 		}
    261 	}
    262 
    263 	if pos.IsValid() && pattern.Type() == positionType {
    264 		// use new position only if old position was valid in the first place
    265 		if old := pattern.Interface().(token.Pos); !old.IsValid() {
    266 			return pattern
    267 		}
    268 		return pos
    269 	}
    270 
    271 	// Otherwise copy.
    272 	switch p := pattern; p.Kind() {
    273 	case reflect.Slice:
    274 		v := reflect.MakeSlice(p.Type(), p.Len(), p.Len())
    275 		for i := 0; i < p.Len(); i++ {
    276 			v.Index(i).Set(subst(m, p.Index(i), pos))
    277 		}
    278 		return v
    279 
    280 	case reflect.Struct:
    281 		v := reflect.New(p.Type()).Elem()
    282 		for i := 0; i < p.NumField(); i++ {
    283 			v.Field(i).Set(subst(m, p.Field(i), pos))
    284 		}
    285 		return v
    286 
    287 	case reflect.Ptr:
    288 		v := reflect.New(p.Type()).Elem()
    289 		if elem := p.Elem(); elem.IsValid() {
    290 			v.Set(subst(m, elem, pos).Addr())
    291 		}
    292 		return v
    293 
    294 	case reflect.Interface:
    295 		v := reflect.New(p.Type()).Elem()
    296 		if elem := p.Elem(); elem.IsValid() {
    297 			v.Set(subst(m, elem, pos))
    298 		}
    299 		return v
    300 	}
    301 
    302 	return pattern
    303 }
    304