Home | History | Annotate | Download | only in vet
      1 // Copyright 2016 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package main
      6 
      7 import (
      8 	"cmd/vet/internal/cfg"
      9 	"fmt"
     10 	"go/ast"
     11 	"go/types"
     12 	"strconv"
     13 )
     14 
     15 func init() {
     16 	register("lostcancel",
     17 		"check for failure to call cancelation function returned by context.WithCancel",
     18 		checkLostCancel,
     19 		funcDecl, funcLit)
     20 }
     21 
     22 const debugLostCancel = false
     23 
     24 var contextPackage = "context"
     25 
     26 // checkLostCancel reports a failure to the call the cancel function
     27 // returned by context.WithCancel, either because the variable was
     28 // assigned to the blank identifier, or because there exists a
     29 // control-flow path from the call to a return statement and that path
     30 // does not "use" the cancel function.  Any reference to the variable
     31 // counts as a use, even within a nested function literal.
     32 //
     33 // checkLostCancel analyzes a single named or literal function.
     34 func checkLostCancel(f *File, node ast.Node) {
     35 	// Fast path: bypass check if file doesn't use context.WithCancel.
     36 	if !hasImport(f.file, contextPackage) {
     37 		return
     38 	}
     39 
     40 	// Maps each cancel variable to its defining ValueSpec/AssignStmt.
     41 	cancelvars := make(map[*types.Var]ast.Node)
     42 
     43 	// Find the set of cancel vars to analyze.
     44 	stack := make([]ast.Node, 0, 32)
     45 	ast.Inspect(node, func(n ast.Node) bool {
     46 		switch n.(type) {
     47 		case *ast.FuncLit:
     48 			if len(stack) > 0 {
     49 				return false // don't stray into nested functions
     50 			}
     51 		case nil:
     52 			stack = stack[:len(stack)-1] // pop
     53 			return true
     54 		}
     55 		stack = append(stack, n) // push
     56 
     57 		// Look for [{AssignStmt,ValueSpec} CallExpr SelectorExpr]:
     58 		//
     59 		//   ctx, cancel    := context.WithCancel(...)
     60 		//   ctx, cancel     = context.WithCancel(...)
     61 		//   var ctx, cancel = context.WithCancel(...)
     62 		//
     63 		if isContextWithCancel(f, n) && isCall(stack[len(stack)-2]) {
     64 			var id *ast.Ident // id of cancel var
     65 			stmt := stack[len(stack)-3]
     66 			switch stmt := stmt.(type) {
     67 			case *ast.ValueSpec:
     68 				if len(stmt.Names) > 1 {
     69 					id = stmt.Names[1]
     70 				}
     71 			case *ast.AssignStmt:
     72 				if len(stmt.Lhs) > 1 {
     73 					id, _ = stmt.Lhs[1].(*ast.Ident)
     74 				}
     75 			}
     76 			if id != nil {
     77 				if id.Name == "_" {
     78 					f.Badf(id.Pos(), "the cancel function returned by context.%s should be called, not discarded, to avoid a context leak",
     79 						n.(*ast.SelectorExpr).Sel.Name)
     80 				} else if v, ok := f.pkg.uses[id].(*types.Var); ok {
     81 					cancelvars[v] = stmt
     82 				} else if v, ok := f.pkg.defs[id].(*types.Var); ok {
     83 					cancelvars[v] = stmt
     84 				}
     85 			}
     86 		}
     87 
     88 		return true
     89 	})
     90 
     91 	if len(cancelvars) == 0 {
     92 		return // no need to build CFG
     93 	}
     94 
     95 	// Tell the CFG builder which functions never return.
     96 	info := &types.Info{Uses: f.pkg.uses, Selections: f.pkg.selectors}
     97 	mayReturn := func(call *ast.CallExpr) bool {
     98 		name := callName(info, call)
     99 		return !noReturnFuncs[name]
    100 	}
    101 
    102 	// Build the CFG.
    103 	var g *cfg.CFG
    104 	var sig *types.Signature
    105 	switch node := node.(type) {
    106 	case *ast.FuncDecl:
    107 		sig, _ = f.pkg.defs[node.Name].Type().(*types.Signature)
    108 		g = cfg.New(node.Body, mayReturn)
    109 	case *ast.FuncLit:
    110 		sig, _ = f.pkg.types[node.Type].Type.(*types.Signature)
    111 		g = cfg.New(node.Body, mayReturn)
    112 	}
    113 
    114 	// Print CFG.
    115 	if debugLostCancel {
    116 		fmt.Println(g.Format(f.fset))
    117 	}
    118 
    119 	// Examine the CFG for each variable in turn.
    120 	// (It would be more efficient to analyze all cancelvars in a
    121 	// single pass over the AST, but seldom is there more than one.)
    122 	for v, stmt := range cancelvars {
    123 		if ret := lostCancelPath(f, g, v, stmt, sig); ret != nil {
    124 			lineno := f.fset.Position(stmt.Pos()).Line
    125 			f.Badf(stmt.Pos(), "the %s function is not used on all paths (possible context leak)", v.Name())
    126 			f.Badf(ret.Pos(), "this return statement may be reached without using the %s var defined on line %d", v.Name(), lineno)
    127 		}
    128 	}
    129 }
    130 
    131 func isCall(n ast.Node) bool { _, ok := n.(*ast.CallExpr); return ok }
    132 
    133 func hasImport(f *ast.File, path string) bool {
    134 	for _, imp := range f.Imports {
    135 		v, _ := strconv.Unquote(imp.Path.Value)
    136 		if v == path {
    137 			return true
    138 		}
    139 	}
    140 	return false
    141 }
    142 
    143 // isContextWithCancel reports whether n is one of the qualified identifiers
    144 // context.With{Cancel,Timeout,Deadline}.
    145 func isContextWithCancel(f *File, n ast.Node) bool {
    146 	if sel, ok := n.(*ast.SelectorExpr); ok {
    147 		switch sel.Sel.Name {
    148 		case "WithCancel", "WithTimeout", "WithDeadline":
    149 			if x, ok := sel.X.(*ast.Ident); ok {
    150 				if pkgname, ok := f.pkg.uses[x].(*types.PkgName); ok {
    151 					return pkgname.Imported().Path() == contextPackage
    152 				}
    153 				// Import failed, so we can't check package path.
    154 				// Just check the local package name (heuristic).
    155 				return x.Name == "context"
    156 			}
    157 		}
    158 	}
    159 	return false
    160 }
    161 
// lostCancelPath finds a path through the CFG, from stmt (which defines
// the 'cancel' variable v) to a return statement, that doesn't "use" v.
// If it finds one, it returns the return statement (which may be synthetic).
// sig is the function's type, if known; when sig is nil, naked returns
// are never counted as uses of v.
//
// It returns nil when every path from stmt to a return uses v, and it
// panics if stmt cannot be found among the nodes of g's blocks.
func lostCancelPath(f *File, g *cfg.CFG, v *types.Var, stmt ast.Node, sig *types.Signature) *ast.ReturnStmt {
	// If v is one of the function's result variables, a naked return
	// reads it implicitly (see the ReturnStmt case in uses below).
	vIsNamedResult := sig != nil && tupleContains(sig.Results(), v)

	// uses reports whether stmts contain a "use" of variable v.
	// ast.Inspect descends into nested function literals too, so a
	// reference inside a closure counts. Once a use is found, the
	// visitor stops descending.
	uses := func(f *File, v *types.Var, stmts []ast.Node) bool {
		found := false
		for _, stmt := range stmts {
			ast.Inspect(stmt, func(n ast.Node) bool {
				switch n := n.(type) {
				case *ast.Ident:
					if f.pkg.uses[n] == v {
						found = true
					}
				case *ast.ReturnStmt:
					// A naked return statement counts as a use
					// of the named result variables.
					if n.Results == nil && vIsNamedResult {
						found = true
					}
				}
				return !found
			})
		}
		return found
	}

	// blockUses computes "uses" for each block, caching the result.
	// The cache is local to this call, so it is only ever filled for v.
	memo := make(map[*cfg.Block]bool)
	blockUses := func(f *File, v *types.Var, b *cfg.Block) bool {
		res, ok := memo[b]
		if !ok {
			res = uses(f, v, b.Nodes)
			memo[b] = res
		}
		return res
	}

	// Find the var's defining block in the CFG,
	// plus the rest of the statements of that block.
	// This is a linear scan over every node of every block.
	var defblock *cfg.Block
	var rest []ast.Node
outer:
	for _, b := range g.Blocks {
		for i, n := range b.Nodes {
			if n == stmt {
				defblock = b
				rest = b.Nodes[i+1:]
				break outer
			}
		}
	}
	if defblock == nil {
		panic("internal error: can't find defining block for cancel var")
	}

	// Is v "used" in the remainder of its defining block?
	// If so, every path leaving stmt passes through that use.
	if uses(f, v, rest) {
		return nil
	}

	// Does the defining block return without using v?
	// (Return presumably yields the block's terminating return
	// statement, or nil if it has none.)
	if ret := defblock.Return(); ret != nil {
		return ret
	}

	// Search the CFG depth-first for a path, from defblock to a
	// return block, in which v is never "used".
	// A block is marked seen before it is tested, so blocks that use v
	// are pruned permanently. The first return reached in DFS order over
	// the Succs slices is the one reported.
	seen := make(map[*cfg.Block]bool)
	var search func(blocks []*cfg.Block) *ast.ReturnStmt
	search = func(blocks []*cfg.Block) *ast.ReturnStmt {
		for _, b := range blocks {
			if !seen[b] {
				seen[b] = true

				// Prune the search if the block uses v.
				if blockUses(f, v, b) {
					continue
				}

				// Found path to return statement?
				if ret := b.Return(); ret != nil {
					if debugLostCancel {
						fmt.Printf("found path to return in block %s\n", b)
					}
					return ret // found
				}

				// Recur
				if ret := search(b.Succs); ret != nil {
					if debugLostCancel {
						fmt.Printf(" from block %s\n", b)
					}
					return ret
				}
			}
		}
		return nil
	}
	return search(defblock.Succs)
}
    266 
    267 func tupleContains(tuple *types.Tuple, v *types.Var) bool {
    268 	for i := 0; i < tuple.Len(); i++ {
    269 		if tuple.At(i) == v {
    270 			return true
    271 		}
    272 	}
    273 	return false
    274 }
    275 
    276 var noReturnFuncs = map[string]bool{
    277 	"(*testing.common).FailNow": true,
    278 	"(*testing.common).Fatal":   true,
    279 	"(*testing.common).Fatalf":  true,
    280 	"(*testing.common).Skip":    true,
    281 	"(*testing.common).SkipNow": true,
    282 	"(*testing.common).Skipf":   true,
    283 	"log.Fatal":                 true,
    284 	"log.Fatalf":                true,
    285 	"log.Fatalln":               true,
    286 	"os.Exit":                   true,
    287 	"panic":                     true,
    288 	"runtime.Goexit":            true,
    289 }
    290 
    291 // callName returns the canonical name of the builtin, method, or
    292 // function called by call, if known.
    293 func callName(info *types.Info, call *ast.CallExpr) string {
    294 	switch fun := call.Fun.(type) {
    295 	case *ast.Ident:
    296 		// builtin, e.g. "panic"
    297 		if obj, ok := info.Uses[fun].(*types.Builtin); ok {
    298 			return obj.Name()
    299 		}
    300 	case *ast.SelectorExpr:
    301 		if sel, ok := info.Selections[fun]; ok && sel.Kind() == types.MethodVal {
    302 			// method call, e.g. "(*testing.common).Fatal"
    303 			meth := sel.Obj()
    304 			return fmt.Sprintf("(%s).%s",
    305 				meth.Type().(*types.Signature).Recv().Type(),
    306 				meth.Name())
    307 		}
    308 		if obj, ok := info.Uses[fun.Sel]; ok {
    309 			// qualified identifier, e.g. "os.Exit"
    310 			return fmt.Sprintf("%s.%s",
    311 				obj.Pkg().Path(),
    312 				obj.Name())
    313 		}
    314 	}
    315 
    316 	// function with no name, or defined in missing imported package
    317 	return ""
    318 }
    319