Home | History | Annotate | Download | only in fix
      1 // Copyright 2011 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package main
      6 
      7 import (
      8 	"fmt"
      9 	"go/ast"
     10 	"go/token"
     11 	"os"
     12 	"reflect"
     13 	"strings"
     14 )
     15 
     16 // Partial type checker.
     17 //
     18 // The fact that it is partial is very important: the input is
     19 // an AST and a description of some type information to
     20 // assume about one or more packages, but not all the
     21 // packages that the program imports. The checker is
     22 // expected to do as much as it can with what it has been
     23 // given. There is not enough information supplied to do
     24 // a full type check, but the type checker is expected to
     25 // apply information that can be derived from variable
     26 // declarations, function and method returns, and type switches
     27 // as far as it can, so that the caller can still tell the types
     28 // of expression relevant to a particular fix.
     29 //
     30 // TODO(rsc,gri): Replace with go/typechecker.
     31 // Doing that could be an interesting test case for go/typechecker:
     32 // the constraints about working with partial information will
     33 // likely exercise it in interesting ways. The ideal interface would
     34 // be to pass typecheck a map from importpath to package API text
     35 // (Go source code), but for now we use data structures (TypeConfig, Type).
     36 //
     37 // The strings mostly use gofmt form.
     38 //
     39 // A Field or FieldList has as its type a comma-separated list
     40 // of the types of the fields. For example, the field list
     41 //	x, y, z int
     42 // has type "int, int, int".
     43 
     44 // The prefix "type " is the type of a type.
     45 // For example, given
     46 //	var x int
     47 //	type T int
     48 // x's type is "int" but T's type is "type int".
     49 // mkType inserts the "type " prefix.
     50 // getType removes it.
     51 // isType tests for it.
     52 
     53 func mkType(t string) string {
     54 	return "type " + t
     55 }
     56 
     57 func getType(t string) string {
     58 	if !isType(t) {
     59 		return ""
     60 	}
     61 	return t[len("type "):]
     62 }
     63 
     64 func isType(t string) bool {
     65 	return strings.HasPrefix(t, "type ")
     66 }
     67 
// TypeConfig describes the universe of relevant types.
// For ease of creation, the types are all referred to by string
// name (e.g., "reflect.Value").
type TypeConfig struct {
	// Type maps a type name to the description of its fields, methods
	// and embedded types. typecheck works on a copy of this map to which
	// it adds the types declared in the file being checked.
	Type map[string]*Type

	// Var maps a name ("x" or "p.X") to its type string; typeof returns
	// the entry unchanged.
	Var map[string]string

	// Func maps a function name ("f" or "p.F") to its result types;
	// typeof prepends "func()" to the entry to form a function type.
	Func map[string]string
}
     78 
     79 // typeof returns the type of the given name, which may be of
     80 // the form "x" or "p.X".
     81 func (cfg *TypeConfig) typeof(name string) string {
     82 	if cfg.Var != nil {
     83 		if t := cfg.Var[name]; t != "" {
     84 			return t
     85 		}
     86 	}
     87 	if cfg.Func != nil {
     88 		if t := cfg.Func[name]; t != "" {
     89 			return "func()" + t
     90 		}
     91 	}
     92 	return ""
     93 }
     94 
// Type describes the Fields and Methods of a type.
// If the field or method cannot be found there, it is next
// looked for in the Embed list.
type Type struct {
	// Field maps a field name to the field's type string.
	Field map[string]string

	// Method maps a method name to the method's type string. dot returns
	// the entry verbatim as the type of the selector, and call results are
	// extracted from it with splitFunc, which only accepts strings that
	// begin with "func(" (for example "func() string").
	Method map[string]string

	// Embed lists the names of the types this type embeds (for extra
	// fields and methods); dot looks each name up in TypeConfig.Type.
	Embed []string

	// Def is the definition of a named type. typecheck fills it in for
	// array, pointer and map types declared in the file being checked.
	Def string
}
    104 
    105 // dot returns the type of "typ.name", making its decision
    106 // using the type information in cfg.
    107 func (typ *Type) dot(cfg *TypeConfig, name string) string {
    108 	if typ.Field != nil {
    109 		if t := typ.Field[name]; t != "" {
    110 			return t
    111 		}
    112 	}
    113 	if typ.Method != nil {
    114 		if t := typ.Method[name]; t != "" {
    115 			return t
    116 		}
    117 	}
    118 
    119 	for _, e := range typ.Embed {
    120 		etyp := cfg.Type[e]
    121 		if etyp != nil {
    122 			if t := etyp.dot(cfg, name); t != "" {
    123 				return t
    124 			}
    125 		}
    126 	}
    127 
    128 	return ""
    129 }
    130 
    131 // typecheck type checks the AST f assuming the information in cfg.
    132 // It returns two maps with type information:
    133 // typeof maps AST nodes to type information in gofmt string form.
    134 // assign maps type strings to lists of expressions that were assigned
    135 // to values of another type that were assigned to that type.
    136 func typecheck(cfg *TypeConfig, f *ast.File) (typeof map[interface{}]string, assign map[string][]interface{}) {
    137 	typeof = make(map[interface{}]string)
    138 	assign = make(map[string][]interface{})
    139 	cfg1 := &TypeConfig{}
    140 	*cfg1 = *cfg // make copy so we can add locally
    141 	copied := false
    142 
    143 	// gather function declarations
    144 	for _, decl := range f.Decls {
    145 		fn, ok := decl.(*ast.FuncDecl)
    146 		if !ok {
    147 			continue
    148 		}
    149 		typecheck1(cfg, fn.Type, typeof, assign)
    150 		t := typeof[fn.Type]
    151 		if fn.Recv != nil {
    152 			// The receiver must be a type.
    153 			rcvr := typeof[fn.Recv]
    154 			if !isType(rcvr) {
    155 				if len(fn.Recv.List) != 1 {
    156 					continue
    157 				}
    158 				rcvr = mkType(gofmt(fn.Recv.List[0].Type))
    159 				typeof[fn.Recv.List[0].Type] = rcvr
    160 			}
    161 			rcvr = getType(rcvr)
    162 			if rcvr != "" && rcvr[0] == '*' {
    163 				rcvr = rcvr[1:]
    164 			}
    165 			typeof[rcvr+"."+fn.Name.Name] = t
    166 		} else {
    167 			if isType(t) {
    168 				t = getType(t)
    169 			} else {
    170 				t = gofmt(fn.Type)
    171 			}
    172 			typeof[fn.Name] = t
    173 
    174 			// Record typeof[fn.Name.Obj] for future references to fn.Name.
    175 			typeof[fn.Name.Obj] = t
    176 		}
    177 	}
    178 
    179 	// gather struct declarations
    180 	for _, decl := range f.Decls {
    181 		d, ok := decl.(*ast.GenDecl)
    182 		if ok {
    183 			for _, s := range d.Specs {
    184 				switch s := s.(type) {
    185 				case *ast.TypeSpec:
    186 					if cfg1.Type[s.Name.Name] != nil {
    187 						break
    188 					}
    189 					if !copied {
    190 						copied = true
    191 						// Copy map lazily: it's time.
    192 						cfg1.Type = make(map[string]*Type)
    193 						for k, v := range cfg.Type {
    194 							cfg1.Type[k] = v
    195 						}
    196 					}
    197 					t := &Type{Field: map[string]string{}}
    198 					cfg1.Type[s.Name.Name] = t
    199 					switch st := s.Type.(type) {
    200 					case *ast.StructType:
    201 						for _, f := range st.Fields.List {
    202 							for _, n := range f.Names {
    203 								t.Field[n.Name] = gofmt(f.Type)
    204 							}
    205 						}
    206 					case *ast.ArrayType, *ast.StarExpr, *ast.MapType:
    207 						t.Def = gofmt(st)
    208 					}
    209 				}
    210 			}
    211 		}
    212 	}
    213 
    214 	typecheck1(cfg1, f, typeof, assign)
    215 	return typeof, assign
    216 }
    217 
    218 func makeExprList(a []*ast.Ident) []ast.Expr {
    219 	var b []ast.Expr
    220 	for _, x := range a {
    221 		b = append(b, x)
    222 	}
    223 	return b
    224 }
    225 
    226 // Typecheck1 is the recursive form of typecheck.
    227 // It is like typecheck but adds to the information in typeof
    228 // instead of allocating a new map.
    229 func typecheck1(cfg *TypeConfig, f interface{}, typeof map[interface{}]string, assign map[string][]interface{}) {
    230 	// set sets the type of n to typ.
    231 	// If isDecl is true, n is being declared.
    232 	set := func(n ast.Expr, typ string, isDecl bool) {
    233 		if typeof[n] != "" || typ == "" {
    234 			if typeof[n] != typ {
    235 				assign[typ] = append(assign[typ], n)
    236 			}
    237 			return
    238 		}
    239 		typeof[n] = typ
    240 
    241 		// If we obtained typ from the declaration of x
    242 		// propagate the type to all the uses.
    243 		// The !isDecl case is a cheat here, but it makes
    244 		// up in some cases for not paying attention to
    245 		// struct fields. The real type checker will be
    246 		// more accurate so we won't need the cheat.
    247 		if id, ok := n.(*ast.Ident); ok && id.Obj != nil && (isDecl || typeof[id.Obj] == "") {
    248 			typeof[id.Obj] = typ
    249 		}
    250 	}
    251 
    252 	// Type-check an assignment lhs = rhs.
    253 	// If isDecl is true, this is := so we can update
    254 	// the types of the objects that lhs refers to.
    255 	typecheckAssign := func(lhs, rhs []ast.Expr, isDecl bool) {
    256 		if len(lhs) > 1 && len(rhs) == 1 {
    257 			if _, ok := rhs[0].(*ast.CallExpr); ok {
    258 				t := split(typeof[rhs[0]])
    259 				// Lists should have same length but may not; pair what can be paired.
    260 				for i := 0; i < len(lhs) && i < len(t); i++ {
    261 					set(lhs[i], t[i], isDecl)
    262 				}
    263 				return
    264 			}
    265 		}
    266 		if len(lhs) == 1 && len(rhs) == 2 {
    267 			// x = y, ok
    268 			rhs = rhs[:1]
    269 		} else if len(lhs) == 2 && len(rhs) == 1 {
    270 			// x, ok = y
    271 			lhs = lhs[:1]
    272 		}
    273 
    274 		// Match as much as we can.
    275 		for i := 0; i < len(lhs) && i < len(rhs); i++ {
    276 			x, y := lhs[i], rhs[i]
    277 			if typeof[y] != "" {
    278 				set(x, typeof[y], isDecl)
    279 			} else {
    280 				set(y, typeof[x], false)
    281 			}
    282 		}
    283 	}
    284 
    285 	expand := func(s string) string {
    286 		typ := cfg.Type[s]
    287 		if typ != nil && typ.Def != "" {
    288 			return typ.Def
    289 		}
    290 		return s
    291 	}
    292 
    293 	// The main type check is a recursive algorithm implemented
    294 	// by walkBeforeAfter(n, before, after).
    295 	// Most of it is bottom-up, but in a few places we need
    296 	// to know the type of the function we are checking.
    297 	// The before function records that information on
    298 	// the curfn stack.
    299 	var curfn []*ast.FuncType
    300 
    301 	before := func(n interface{}) {
    302 		// push function type on stack
    303 		switch n := n.(type) {
    304 		case *ast.FuncDecl:
    305 			curfn = append(curfn, n.Type)
    306 		case *ast.FuncLit:
    307 			curfn = append(curfn, n.Type)
    308 		}
    309 	}
    310 
    311 	// After is the real type checker.
    312 	after := func(n interface{}) {
    313 		if n == nil {
    314 			return
    315 		}
    316 		if false && reflect.TypeOf(n).Kind() == reflect.Ptr { // debugging trace
    317 			defer func() {
    318 				if t := typeof[n]; t != "" {
    319 					pos := fset.Position(n.(ast.Node).Pos())
    320 					fmt.Fprintf(os.Stderr, "%s: typeof[%s] = %s\n", pos, gofmt(n), t)
    321 				}
    322 			}()
    323 		}
    324 
    325 		switch n := n.(type) {
    326 		case *ast.FuncDecl, *ast.FuncLit:
    327 			// pop function type off stack
    328 			curfn = curfn[:len(curfn)-1]
    329 
    330 		case *ast.FuncType:
    331 			typeof[n] = mkType(joinFunc(split(typeof[n.Params]), split(typeof[n.Results])))
    332 
    333 		case *ast.FieldList:
    334 			// Field list is concatenation of sub-lists.
    335 			t := ""
    336 			for _, field := range n.List {
    337 				if t != "" {
    338 					t += ", "
    339 				}
    340 				t += typeof[field]
    341 			}
    342 			typeof[n] = t
    343 
    344 		case *ast.Field:
    345 			// Field is one instance of the type per name.
    346 			all := ""
    347 			t := typeof[n.Type]
    348 			if !isType(t) {
    349 				// Create a type, because it is typically *T or *p.T
    350 				// and we might care about that type.
    351 				t = mkType(gofmt(n.Type))
    352 				typeof[n.Type] = t
    353 			}
    354 			t = getType(t)
    355 			if len(n.Names) == 0 {
    356 				all = t
    357 			} else {
    358 				for _, id := range n.Names {
    359 					if all != "" {
    360 						all += ", "
    361 					}
    362 					all += t
    363 					typeof[id.Obj] = t
    364 					typeof[id] = t
    365 				}
    366 			}
    367 			typeof[n] = all
    368 
    369 		case *ast.ValueSpec:
    370 			// var declaration. Use type if present.
    371 			if n.Type != nil {
    372 				t := typeof[n.Type]
    373 				if !isType(t) {
    374 					t = mkType(gofmt(n.Type))
    375 					typeof[n.Type] = t
    376 				}
    377 				t = getType(t)
    378 				for _, id := range n.Names {
    379 					set(id, t, true)
    380 				}
    381 			}
    382 			// Now treat same as assignment.
    383 			typecheckAssign(makeExprList(n.Names), n.Values, true)
    384 
    385 		case *ast.AssignStmt:
    386 			typecheckAssign(n.Lhs, n.Rhs, n.Tok == token.DEFINE)
    387 
    388 		case *ast.Ident:
    389 			// Identifier can take its type from underlying object.
    390 			if t := typeof[n.Obj]; t != "" {
    391 				typeof[n] = t
    392 			}
    393 
    394 		case *ast.SelectorExpr:
    395 			// Field or method.
    396 			name := n.Sel.Name
    397 			if t := typeof[n.X]; t != "" {
    398 				t = strings.TrimPrefix(t, "*") // implicit *
    399 				if typ := cfg.Type[t]; typ != nil {
    400 					if t := typ.dot(cfg, name); t != "" {
    401 						typeof[n] = t
    402 						return
    403 					}
    404 				}
    405 				tt := typeof[t+"."+name]
    406 				if isType(tt) {
    407 					typeof[n] = getType(tt)
    408 					return
    409 				}
    410 			}
    411 			// Package selector.
    412 			if x, ok := n.X.(*ast.Ident); ok && x.Obj == nil {
    413 				str := x.Name + "." + name
    414 				if cfg.Type[str] != nil {
    415 					typeof[n] = mkType(str)
    416 					return
    417 				}
    418 				if t := cfg.typeof(x.Name + "." + name); t != "" {
    419 					typeof[n] = t
    420 					return
    421 				}
    422 			}
    423 
    424 		case *ast.CallExpr:
    425 			// make(T) has type T.
    426 			if isTopName(n.Fun, "make") && len(n.Args) >= 1 {
    427 				typeof[n] = gofmt(n.Args[0])
    428 				return
    429 			}
    430 			// new(T) has type *T
    431 			if isTopName(n.Fun, "new") && len(n.Args) == 1 {
    432 				typeof[n] = "*" + gofmt(n.Args[0])
    433 				return
    434 			}
    435 			// Otherwise, use type of function to determine arguments.
    436 			t := typeof[n.Fun]
    437 			in, out := splitFunc(t)
    438 			if in == nil && out == nil {
    439 				return
    440 			}
    441 			typeof[n] = join(out)
    442 			for i, arg := range n.Args {
    443 				if i >= len(in) {
    444 					break
    445 				}
    446 				if typeof[arg] == "" {
    447 					typeof[arg] = in[i]
    448 				}
    449 			}
    450 
    451 		case *ast.TypeAssertExpr:
    452 			// x.(type) has type of x.
    453 			if n.Type == nil {
    454 				typeof[n] = typeof[n.X]
    455 				return
    456 			}
    457 			// x.(T) has type T.
    458 			if t := typeof[n.Type]; isType(t) {
    459 				typeof[n] = getType(t)
    460 			} else {
    461 				typeof[n] = gofmt(n.Type)
    462 			}
    463 
    464 		case *ast.SliceExpr:
    465 			// x[i:j] has type of x.
    466 			typeof[n] = typeof[n.X]
    467 
    468 		case *ast.IndexExpr:
    469 			// x[i] has key type of x's type.
    470 			t := expand(typeof[n.X])
    471 			if strings.HasPrefix(t, "[") || strings.HasPrefix(t, "map[") {
    472 				// Lazy: assume there are no nested [] in the array
    473 				// length or map key type.
    474 				if i := strings.Index(t, "]"); i >= 0 {
    475 					typeof[n] = t[i+1:]
    476 				}
    477 			}
    478 
    479 		case *ast.StarExpr:
    480 			// *x for x of type *T has type T when x is an expr.
    481 			// We don't use the result when *x is a type, but
    482 			// compute it anyway.
    483 			t := expand(typeof[n.X])
    484 			if isType(t) {
    485 				typeof[n] = "type *" + getType(t)
    486 			} else if strings.HasPrefix(t, "*") {
    487 				typeof[n] = t[len("*"):]
    488 			}
    489 
    490 		case *ast.UnaryExpr:
    491 			// &x for x of type T has type *T.
    492 			t := typeof[n.X]
    493 			if t != "" && n.Op == token.AND {
    494 				typeof[n] = "*" + t
    495 			}
    496 
    497 		case *ast.CompositeLit:
    498 			// T{...} has type T.
    499 			typeof[n] = gofmt(n.Type)
    500 
    501 		case *ast.ParenExpr:
    502 			// (x) has type of x.
    503 			typeof[n] = typeof[n.X]
    504 
    505 		case *ast.RangeStmt:
    506 			t := expand(typeof[n.X])
    507 			if t == "" {
    508 				return
    509 			}
    510 			var key, value string
    511 			if t == "string" {
    512 				key, value = "int", "rune"
    513 			} else if strings.HasPrefix(t, "[") {
    514 				key = "int"
    515 				if i := strings.Index(t, "]"); i >= 0 {
    516 					value = t[i+1:]
    517 				}
    518 			} else if strings.HasPrefix(t, "map[") {
    519 				if i := strings.Index(t, "]"); i >= 0 {
    520 					key, value = t[4:i], t[i+1:]
    521 				}
    522 			}
    523 			changed := false
    524 			if n.Key != nil && key != "" {
    525 				changed = true
    526 				set(n.Key, key, n.Tok == token.DEFINE)
    527 			}
    528 			if n.Value != nil && value != "" {
    529 				changed = true
    530 				set(n.Value, value, n.Tok == token.DEFINE)
    531 			}
    532 			// Ugly failure of vision: already type-checked body.
    533 			// Do it again now that we have that type info.
    534 			if changed {
    535 				typecheck1(cfg, n.Body, typeof, assign)
    536 			}
    537 
    538 		case *ast.TypeSwitchStmt:
    539 			// Type of variable changes for each case in type switch,
    540 			// but go/parser generates just one variable.
    541 			// Repeat type check for each case with more precise
    542 			// type information.
    543 			as, ok := n.Assign.(*ast.AssignStmt)
    544 			if !ok {
    545 				return
    546 			}
    547 			varx, ok := as.Lhs[0].(*ast.Ident)
    548 			if !ok {
    549 				return
    550 			}
    551 			t := typeof[varx]
    552 			for _, cas := range n.Body.List {
    553 				cas := cas.(*ast.CaseClause)
    554 				if len(cas.List) == 1 {
    555 					// Variable has specific type only when there is
    556 					// exactly one type in the case list.
    557 					if tt := typeof[cas.List[0]]; isType(tt) {
    558 						tt = getType(tt)
    559 						typeof[varx] = tt
    560 						typeof[varx.Obj] = tt
    561 						typecheck1(cfg, cas.Body, typeof, assign)
    562 					}
    563 				}
    564 			}
    565 			// Restore t.
    566 			typeof[varx] = t
    567 			typeof[varx.Obj] = t
    568 
    569 		case *ast.ReturnStmt:
    570 			if len(curfn) == 0 {
    571 				// Probably can't happen.
    572 				return
    573 			}
    574 			f := curfn[len(curfn)-1]
    575 			res := n.Results
    576 			if f.Results != nil {
    577 				t := split(typeof[f.Results])
    578 				for i := 0; i < len(res) && i < len(t); i++ {
    579 					set(res[i], t[i], false)
    580 				}
    581 			}
    582 		}
    583 	}
    584 	walkBeforeAfter(f, before, after)
    585 }
    586 
    587 // Convert between function type strings and lists of types.
    588 // Using strings makes this a little harder, but it makes
    589 // a lot of the rest of the code easier. This will all go away
    590 // when we can use go/typechecker directly.
    591 
    592 // splitFunc splits "func(x,y,z) (a,b,c)" into ["x", "y", "z"] and ["a", "b", "c"].
    593 func splitFunc(s string) (in, out []string) {
    594 	if !strings.HasPrefix(s, "func(") {
    595 		return nil, nil
    596 	}
    597 
    598 	i := len("func(") // index of beginning of 'in' arguments
    599 	nparen := 0
    600 	for j := i; j < len(s); j++ {
    601 		switch s[j] {
    602 		case '(':
    603 			nparen++
    604 		case ')':
    605 			nparen--
    606 			if nparen < 0 {
    607 				// found end of parameter list
    608 				out := strings.TrimSpace(s[j+1:])
    609 				if len(out) >= 2 && out[0] == '(' && out[len(out)-1] == ')' {
    610 					out = out[1 : len(out)-1]
    611 				}
    612 				return split(s[i:j]), split(out)
    613 			}
    614 		}
    615 	}
    616 	return nil, nil
    617 }
    618 
    619 // joinFunc is the inverse of splitFunc.
    620 func joinFunc(in, out []string) string {
    621 	outs := ""
    622 	if len(out) == 1 {
    623 		outs = " " + out[0]
    624 	} else if len(out) > 1 {
    625 		outs = " (" + join(out) + ")"
    626 	}
    627 	return "func(" + join(in) + ")" + outs
    628 }
    629 
    630 // split splits "int, float" into ["int", "float"] and splits "" into [].
    631 func split(s string) []string {
    632 	out := []string{}
    633 	i := 0 // current type being scanned is s[i:j].
    634 	nparen := 0
    635 	for j := 0; j < len(s); j++ {
    636 		switch s[j] {
    637 		case ' ':
    638 			if i == j {
    639 				i++
    640 			}
    641 		case '(':
    642 			nparen++
    643 		case ')':
    644 			nparen--
    645 			if nparen < 0 {
    646 				// probably can't happen
    647 				return nil
    648 			}
    649 		case ',':
    650 			if nparen == 0 {
    651 				if i < j {
    652 					out = append(out, s[i:j])
    653 				}
    654 				i = j + 1
    655 			}
    656 		}
    657 	}
    658 	if nparen != 0 {
    659 		// probably can't happen
    660 		return nil
    661 	}
    662 	if i < len(s) {
    663 		out = append(out, s[i:])
    664 	}
    665 	return out
    666 }
    667 
    668 // join is the inverse of split.
    669 func join(x []string) string {
    670 	return strings.Join(x, ", ")
    671 }
    672