...

Source file src/golang.org/x/tools/go/analysis/passes/loopclosure/loopclosure.go

Documentation: golang.org/x/tools/go/analysis/passes/loopclosure

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package loopclosure defines an Analyzer that checks for references to
     6  // enclosing loop variables from within nested functions.
     7  package loopclosure
     8  
     9  import (
    10  	"go/ast"
    11  	"go/types"
    12  
    13  	"golang.org/x/tools/go/analysis"
    14  	"golang.org/x/tools/go/analysis/passes/inspect"
    15  	"golang.org/x/tools/go/ast/inspector"
    16  	"golang.org/x/tools/go/types/typeutil"
    17  )
    18  
    19  const Doc = `check references to loop variables from within nested functions
    20  
    21  This analyzer reports places where a function literal references the
    22  iteration variable of an enclosing loop, and the loop calls the function
    23  in such a way (e.g. with go or defer) that it may outlive the loop
    24  iteration and possibly observe the wrong value of the variable.
    25  
    26  In this example, all the deferred functions run after the loop has
    27  completed, so all observe the final value of v.
    28  
    29      for _, v := range list {
    30          defer func() {
    31              use(v) // incorrect
    32          }()
    33      }
    34  
    35  One fix is to create a new variable for each iteration of the loop:
    36  
    37      for _, v := range list {
    38          v := v // new var per iteration
    39          defer func() {
    40              use(v) // ok
    41          }()
    42      }
    43  
    44  The next example uses a go statement and has a similar problem.
    45  In addition, it has a data race because the loop updates v
    46  concurrent with the goroutines accessing it.
    47  
    48      for _, v := range elem {
    49          go func() {
    50              use(v)  // incorrect, and a data race
    51          }()
    52      }
    53  
    54  A fix is the same as before. The checker also reports problems
    55  in goroutines started by golang.org/x/sync/errgroup.Group.
    56  A hard-to-spot variant of this form is common in parallel tests:
    57  
    58      func Test(t *testing.T) {
    59          for _, test := range tests {
    60              t.Run(test.name, func(t *testing.T) {
    61                  t.Parallel()
    62                  use(test) // incorrect, and a data race
    63              })
    64          }
    65      }
    66  
    67  The t.Parallel() call causes the rest of the function to execute
    68  concurrent with the loop.
    69  
    70  The analyzer reports references only in the last statement,
    71  as it is not deep enough to understand the effects of subsequent
    72  statements that might render the reference benign.
    73  ("Last statement" is defined recursively in compound
    74  statements such as if, switch, and select.)
    75  
    76  See: https://golang.org/doc/go_faq.html#closures_and_goroutines`
    77  
    78  var Analyzer = &analysis.Analyzer{
    79  	Name:     "loopclosure",
    80  	Doc:      Doc,
    81  	Requires: []*analysis.Analyzer{inspect.Analyzer},
    82  	Run:      run,
    83  }
    84  
    85  func run(pass *analysis.Pass) (interface{}, error) {
    86  	inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
    87  
    88  	nodeFilter := []ast.Node{
    89  		(*ast.RangeStmt)(nil),
    90  		(*ast.ForStmt)(nil),
    91  	}
    92  	inspect.Preorder(nodeFilter, func(n ast.Node) {
    93  		// Find the variables updated by the loop statement.
    94  		var vars []types.Object
    95  		addVar := func(expr ast.Expr) {
    96  			if id, _ := expr.(*ast.Ident); id != nil {
    97  				if obj := pass.TypesInfo.ObjectOf(id); obj != nil {
    98  					vars = append(vars, obj)
    99  				}
   100  			}
   101  		}
   102  		var body *ast.BlockStmt
   103  		switch n := n.(type) {
   104  		case *ast.RangeStmt:
   105  			body = n.Body
   106  			addVar(n.Key)
   107  			addVar(n.Value)
   108  		case *ast.ForStmt:
   109  			body = n.Body
   110  			switch post := n.Post.(type) {
   111  			case *ast.AssignStmt:
   112  				// e.g. for p = head; p != nil; p = p.next
   113  				for _, lhs := range post.Lhs {
   114  					addVar(lhs)
   115  				}
   116  			case *ast.IncDecStmt:
   117  				// e.g. for i := 0; i < n; i++
   118  				addVar(post.X)
   119  			}
   120  		}
   121  		if vars == nil {
   122  			return
   123  		}
   124  
   125  		// Inspect statements to find function literals that may be run outside of
   126  		// the current loop iteration.
   127  		//
   128  		// For go, defer, and errgroup.Group.Go, we ignore all but the last
   129  		// statement, because it's hard to prove go isn't followed by wait, or
   130  		// defer by return. "Last" is defined recursively.
   131  		//
   132  		// TODO: consider allowing the "last" go/defer/Go statement to be followed by
   133  		// N "trivial" statements, possibly under a recursive definition of "trivial"
   134  		// so that that checker could, for example, conclude that a go statement is
   135  		// followed by an if statement made of only trivial statements and trivial expressions,
   136  		// and hence the go statement could still be checked.
   137  		forEachLastStmt(body.List, func(last ast.Stmt) {
   138  			var stmts []ast.Stmt
   139  			switch s := last.(type) {
   140  			case *ast.GoStmt:
   141  				stmts = litStmts(s.Call.Fun)
   142  			case *ast.DeferStmt:
   143  				stmts = litStmts(s.Call.Fun)
   144  			case *ast.ExprStmt: // check for errgroup.Group.Go
   145  				if call, ok := s.X.(*ast.CallExpr); ok {
   146  					stmts = litStmts(goInvoke(pass.TypesInfo, call))
   147  				}
   148  			}
   149  			for _, stmt := range stmts {
   150  				reportCaptured(pass, vars, stmt)
   151  			}
   152  		})
   153  
   154  		// Also check for testing.T.Run (with T.Parallel).
   155  		// We consider every t.Run statement in the loop body, because there is
   156  		// no commonly used mechanism for synchronizing parallel subtests.
   157  		// It is of course theoretically possible to synchronize parallel subtests,
   158  		// though such a pattern is likely to be exceedingly rare as it would be
   159  		// fighting against the test runner.
   160  		for _, s := range body.List {
   161  			switch s := s.(type) {
   162  			case *ast.ExprStmt:
   163  				if call, ok := s.X.(*ast.CallExpr); ok {
   164  					for _, stmt := range parallelSubtest(pass.TypesInfo, call) {
   165  						reportCaptured(pass, vars, stmt)
   166  					}
   167  
   168  				}
   169  			}
   170  		}
   171  	})
   172  	return nil, nil
   173  }
   174  
   175  // reportCaptured reports a diagnostic stating a loop variable
   176  // has been captured by a func literal if checkStmt has escaping
   177  // references to vars. vars is expected to be variables updated by a loop statement,
   178  // and checkStmt is expected to be a statements from the body of a func literal in the loop.
   179  func reportCaptured(pass *analysis.Pass, vars []types.Object, checkStmt ast.Stmt) {
   180  	ast.Inspect(checkStmt, func(n ast.Node) bool {
   181  		id, ok := n.(*ast.Ident)
   182  		if !ok {
   183  			return true
   184  		}
   185  		obj := pass.TypesInfo.Uses[id]
   186  		if obj == nil {
   187  			return true
   188  		}
   189  		for _, v := range vars {
   190  			if v == obj {
   191  				pass.ReportRangef(id, "loop variable %s captured by func literal", id.Name)
   192  			}
   193  		}
   194  		return true
   195  	})
   196  }
   197  
   198  // forEachLastStmt calls onLast on each "last" statement in a list of statements.
   199  // "Last" is defined recursively so, for example, if the last statement is
   200  // a switch statement, then each switch case is also visited to examine
   201  // its last statements.
   202  func forEachLastStmt(stmts []ast.Stmt, onLast func(last ast.Stmt)) {
   203  	if len(stmts) == 0 {
   204  		return
   205  	}
   206  
   207  	s := stmts[len(stmts)-1]
   208  	switch s := s.(type) {
   209  	case *ast.IfStmt:
   210  	loop:
   211  		for {
   212  			forEachLastStmt(s.Body.List, onLast)
   213  			switch e := s.Else.(type) {
   214  			case *ast.BlockStmt:
   215  				forEachLastStmt(e.List, onLast)
   216  				break loop
   217  			case *ast.IfStmt:
   218  				s = e
   219  			case nil:
   220  				break loop
   221  			}
   222  		}
   223  	case *ast.ForStmt:
   224  		forEachLastStmt(s.Body.List, onLast)
   225  	case *ast.RangeStmt:
   226  		forEachLastStmt(s.Body.List, onLast)
   227  	case *ast.SwitchStmt:
   228  		for _, c := range s.Body.List {
   229  			cc := c.(*ast.CaseClause)
   230  			forEachLastStmt(cc.Body, onLast)
   231  		}
   232  	case *ast.TypeSwitchStmt:
   233  		for _, c := range s.Body.List {
   234  			cc := c.(*ast.CaseClause)
   235  			forEachLastStmt(cc.Body, onLast)
   236  		}
   237  	case *ast.SelectStmt:
   238  		for _, c := range s.Body.List {
   239  			cc := c.(*ast.CommClause)
   240  			forEachLastStmt(cc.Body, onLast)
   241  		}
   242  	default:
   243  		onLast(s)
   244  	}
   245  }
   246  
   247  // litStmts returns all statements from the function body of a function
   248  // literal.
   249  //
   250  // If fun is not a function literal, it returns nil.
   251  func litStmts(fun ast.Expr) []ast.Stmt {
   252  	lit, _ := fun.(*ast.FuncLit)
   253  	if lit == nil {
   254  		return nil
   255  	}
   256  	return lit.Body.List
   257  }
   258  
   259  // goInvoke returns a function expression that would be called asynchronously
   260  // (but not awaited) in another goroutine as a consequence of the call.
   261  // For example, given the g.Go call below, it returns the function literal expression.
   262  //
   263  //	import "sync/errgroup"
   264  //	var g errgroup.Group
   265  //	g.Go(func() error { ... })
   266  //
   267  // Currently only "golang.org/x/sync/errgroup.Group()" is considered.
   268  func goInvoke(info *types.Info, call *ast.CallExpr) ast.Expr {
   269  	if !isMethodCall(info, call, "golang.org/x/sync/errgroup", "Group", "Go") {
   270  		return nil
   271  	}
   272  	return call.Args[0]
   273  }
   274  
   275  // parallelSubtest returns statements that can be easily proven to execute
   276  // concurrently via the go test runner, as t.Run has been invoked with a
   277  // function literal that calls t.Parallel.
   278  //
   279  // In practice, users rely on the fact that statements before the call to
   280  // t.Parallel are synchronous. For example by declaring test := test inside the
   281  // function literal, but before the call to t.Parallel.
   282  //
   283  // Therefore, we only flag references in statements that are obviously
   284  // dominated by a call to t.Parallel. As a simple heuristic, we only consider
   285  // statements following the final labeled statement in the function body, to
   286  // avoid scenarios where a jump would cause either the call to t.Parallel or
   287  // the problematic reference to be skipped.
   288  //
   289  //	import "testing"
   290  //
   291  //	func TestFoo(t *testing.T) {
   292  //		tests := []int{0, 1, 2}
   293  //		for i, test := range tests {
   294  //			t.Run("subtest", func(t *testing.T) {
   295  //				println(i, test) // OK
   296  //		 		t.Parallel()
   297  //				println(i, test) // Not OK
   298  //			})
   299  //		}
   300  //	}
   301  func parallelSubtest(info *types.Info, call *ast.CallExpr) []ast.Stmt {
   302  	if !isMethodCall(info, call, "testing", "T", "Run") {
   303  		return nil
   304  	}
   305  
   306  	lit, _ := call.Args[1].(*ast.FuncLit)
   307  	if lit == nil {
   308  		return nil
   309  	}
   310  
   311  	// Capture the *testing.T object for the first argument to the function
   312  	// literal.
   313  	if len(lit.Type.Params.List[0].Names) == 0 {
   314  		return nil
   315  	}
   316  
   317  	tObj := info.Defs[lit.Type.Params.List[0].Names[0]]
   318  	if tObj == nil {
   319  		return nil
   320  	}
   321  
   322  	// Match statements that occur after a call to t.Parallel following the final
   323  	// labeled statement in the function body.
   324  	//
   325  	// We iterate over lit.Body.List to have a simple, fast and "frequent enough"
   326  	// dominance relationship for t.Parallel(): lit.Body.List[i] dominates
   327  	// lit.Body.List[j] for i < j unless there is a jump.
   328  	var stmts []ast.Stmt
   329  	afterParallel := false
   330  	for _, stmt := range lit.Body.List {
   331  		stmt, labeled := unlabel(stmt)
   332  		if labeled {
   333  			// Reset: naively we don't know if a jump could have caused the
   334  			// previously considered statements to be skipped.
   335  			stmts = nil
   336  			afterParallel = false
   337  		}
   338  
   339  		if afterParallel {
   340  			stmts = append(stmts, stmt)
   341  			continue
   342  		}
   343  
   344  		// Check if stmt is a call to t.Parallel(), for the correct t.
   345  		exprStmt, ok := stmt.(*ast.ExprStmt)
   346  		if !ok {
   347  			continue
   348  		}
   349  		expr := exprStmt.X
   350  		if isMethodCall(info, expr, "testing", "T", "Parallel") {
   351  			call, _ := expr.(*ast.CallExpr)
   352  			if call == nil {
   353  				continue
   354  			}
   355  			x, _ := call.Fun.(*ast.SelectorExpr)
   356  			if x == nil {
   357  				continue
   358  			}
   359  			id, _ := x.X.(*ast.Ident)
   360  			if id == nil {
   361  				continue
   362  			}
   363  			if info.Uses[id] == tObj {
   364  				afterParallel = true
   365  			}
   366  		}
   367  	}
   368  
   369  	return stmts
   370  }
   371  
   372  // unlabel returns the inner statement for the possibly labeled statement stmt,
   373  // stripping any (possibly nested) *ast.LabeledStmt wrapper.
   374  //
   375  // The second result reports whether stmt was an *ast.LabeledStmt.
   376  func unlabel(stmt ast.Stmt) (ast.Stmt, bool) {
   377  	labeled := false
   378  	for {
   379  		labelStmt, ok := stmt.(*ast.LabeledStmt)
   380  		if !ok {
   381  			return stmt, labeled
   382  		}
   383  		labeled = true
   384  		stmt = labelStmt.Stmt
   385  	}
   386  }
   387  
   388  // isMethodCall reports whether expr is a method call of
   389  // <pkgPath>.<typeName>.<method>.
   390  func isMethodCall(info *types.Info, expr ast.Expr, pkgPath, typeName, method string) bool {
   391  	call, ok := expr.(*ast.CallExpr)
   392  	if !ok {
   393  		return false
   394  	}
   395  
   396  	// Check that we are calling a method <method>
   397  	f := typeutil.StaticCallee(info, call)
   398  	if f == nil || f.Name() != method {
   399  		return false
   400  	}
   401  	recv := f.Type().(*types.Signature).Recv()
   402  	if recv == nil {
   403  		return false
   404  	}
   405  
   406  	// Check that the receiver is a <pkgPath>.<typeName> or
   407  	// *<pkgPath>.<typeName>.
   408  	rtype := recv.Type()
   409  	if ptr, ok := recv.Type().(*types.Pointer); ok {
   410  		rtype = ptr.Elem()
   411  	}
   412  	named, ok := rtype.(*types.Named)
   413  	if !ok {
   414  		return false
   415  	}
   416  	if named.Obj().Name() != typeName {
   417  		return false
   418  	}
   419  	pkg := f.Pkg()
   420  	if pkg == nil {
   421  		return false
   422  	}
   423  	if pkg.Path() != pkgPath {
   424  		return false
   425  	}
   426  
   427  	return true
   428  }
   429  

View as plain text