1
2
3
4
5
6
7 package loopclosure
8
9 import (
10 "go/ast"
11 "go/types"
12
13 "golang.org/x/tools/go/analysis"
14 "golang.org/x/tools/go/analysis/passes/inspect"
15 "golang.org/x/tools/go/ast/inspector"
16 "golang.org/x/tools/go/types/typeutil"
17 )
18
// Doc is the documentation string for the loopclosure analyzer. Its first
// line is the one-line summary; the remainder is the detailed description.
const Doc = `check references to loop variables from within nested functions

This analyzer reports places where a function literal references the
iteration variable of an enclosing loop, and the loop calls the function
in such a way (e.g. with go or defer) that it may outlive the loop
iteration and possibly observe the wrong value of the variable.

In this example, all the deferred functions run after the loop has
completed, so all observe the final value of v.

	for _, v := range list {
		defer func() {
			use(v) // incorrect
		}()
	}

One fix is to create a new variable for each iteration of the loop:

	for _, v := range list {
		v := v // new var per iteration
		defer func() {
			use(v) // ok
		}()
	}

The next example uses a go statement and has a similar problem.
In addition, it has a data race because the loop updates v
concurrently with the goroutines accessing it.

	for _, v := range list {
		go func() {
			use(v) // incorrect, and a data race
		}()
	}

A fix is the same as before. The checker also reports problems
in goroutines started by golang.org/x/sync/errgroup.Group.
A hard-to-spot variant of this form is common in parallel tests:

	func Test(t *testing.T) {
		for _, test := range tests {
			t.Run(test.name, func(t *testing.T) {
				t.Parallel()
				use(test) // incorrect, and a data race
			})
		}
	}

The t.Parallel() call causes the rest of the function to execute
concurrently with the loop.

For go, defer, and errgroup.Group.Go, the analyzer reports references
only when the call is the last statement of the loop body, as it is not
deep enough to understand the effects of subsequent statements that
might render the reference benign. ("Last statement" is defined
recursively in compound statements such as if, switch, and select.)
Calls to t.Run are checked wherever they appear at the top level of
the loop body, but only statements of the subtest function that follow
its call to t.Parallel() are examined.

See: https://golang.org/doc/go_faq.html#closures_and_goroutines`
77
78 var Analyzer = &analysis.Analyzer{
79 Name: "loopclosure",
80 Doc: Doc,
81 Requires: []*analysis.Analyzer{inspect.Analyzer},
82 Run: run,
83 }
84
85 func run(pass *analysis.Pass) (interface{}, error) {
86 inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
87
88 nodeFilter := []ast.Node{
89 (*ast.RangeStmt)(nil),
90 (*ast.ForStmt)(nil),
91 }
92 inspect.Preorder(nodeFilter, func(n ast.Node) {
93
94 var vars []types.Object
95 addVar := func(expr ast.Expr) {
96 if id, _ := expr.(*ast.Ident); id != nil {
97 if obj := pass.TypesInfo.ObjectOf(id); obj != nil {
98 vars = append(vars, obj)
99 }
100 }
101 }
102 var body *ast.BlockStmt
103 switch n := n.(type) {
104 case *ast.RangeStmt:
105 body = n.Body
106 addVar(n.Key)
107 addVar(n.Value)
108 case *ast.ForStmt:
109 body = n.Body
110 switch post := n.Post.(type) {
111 case *ast.AssignStmt:
112
113 for _, lhs := range post.Lhs {
114 addVar(lhs)
115 }
116 case *ast.IncDecStmt:
117
118 addVar(post.X)
119 }
120 }
121 if vars == nil {
122 return
123 }
124
125
126
127
128
129
130
131
132
133
134
135
136
137 forEachLastStmt(body.List, func(last ast.Stmt) {
138 var stmts []ast.Stmt
139 switch s := last.(type) {
140 case *ast.GoStmt:
141 stmts = litStmts(s.Call.Fun)
142 case *ast.DeferStmt:
143 stmts = litStmts(s.Call.Fun)
144 case *ast.ExprStmt:
145 if call, ok := s.X.(*ast.CallExpr); ok {
146 stmts = litStmts(goInvoke(pass.TypesInfo, call))
147 }
148 }
149 for _, stmt := range stmts {
150 reportCaptured(pass, vars, stmt)
151 }
152 })
153
154
155
156
157
158
159
160 for _, s := range body.List {
161 switch s := s.(type) {
162 case *ast.ExprStmt:
163 if call, ok := s.X.(*ast.CallExpr); ok {
164 for _, stmt := range parallelSubtest(pass.TypesInfo, call) {
165 reportCaptured(pass, vars, stmt)
166 }
167
168 }
169 }
170 }
171 })
172 return nil, nil
173 }
174
175
176
177
178
179 func reportCaptured(pass *analysis.Pass, vars []types.Object, checkStmt ast.Stmt) {
180 ast.Inspect(checkStmt, func(n ast.Node) bool {
181 id, ok := n.(*ast.Ident)
182 if !ok {
183 return true
184 }
185 obj := pass.TypesInfo.Uses[id]
186 if obj == nil {
187 return true
188 }
189 for _, v := range vars {
190 if v == obj {
191 pass.ReportRangef(id, "loop variable %s captured by func literal", id.Name)
192 }
193 }
194 return true
195 })
196 }
197
198
199
200
201
// forEachLastStmt calls onLast for each "last" statement of stmts. The
// definition is recursive. When the final statement of stmts is an if (with
// its else-if chain and final else), for, range, switch, type switch, or
// select statement, that statement is not passed to onLast. Instead, each of
// its branches, bodies, or clauses is searched for its own last statements.
// Any other final statement is passed to onLast directly. An empty list
// produces no calls.
func forEachLastStmt(stmts []ast.Stmt, onLast func(last ast.Stmt)) {
	n := len(stmts)
	if n == 0 {
		return
	}

	switch last := stmts[n-1].(type) {
	case *ast.IfStmt:
		forEachLastStmt(last.Body.List, onLast)
		switch alt := last.Else.(type) {
		case *ast.BlockStmt:
			forEachLastStmt(alt.List, onLast)
		case *ast.IfStmt:
			// "else if": the nested if is the only, and therefore the last,
			// statement of the else branch.
			forEachLastStmt([]ast.Stmt{alt}, onLast)
		}
	case *ast.ForStmt:
		forEachLastStmt(last.Body.List, onLast)
	case *ast.RangeStmt:
		forEachLastStmt(last.Body.List, onLast)
	case *ast.SwitchStmt:
		for _, clause := range last.Body.List {
			forEachLastStmt(clause.(*ast.CaseClause).Body, onLast)
		}
	case *ast.TypeSwitchStmt:
		for _, clause := range last.Body.List {
			forEachLastStmt(clause.(*ast.CaseClause).Body, onLast)
		}
	case *ast.SelectStmt:
		for _, clause := range last.Body.List {
			forEachLastStmt(clause.(*ast.CommClause).Body, onLast)
		}
	default:
		onLast(last)
	}
}
246
247
248
249
250
// litStmts returns the top-level statements of fun's body when fun is a
// function literal. For any other expression, including nil, it returns nil.
func litStmts(fun ast.Expr) []ast.Stmt {
	if lit, isLit := fun.(*ast.FuncLit); isLit && lit != nil {
		return lit.Body.List
	}
	return nil
}
258
259
260
261
262
263
264
265
266
267
268 func goInvoke(info *types.Info, call *ast.CallExpr) ast.Expr {
269 if !isMethodCall(info, call, "golang.org/x/sync/errgroup", "Group", "Go") {
270 return nil
271 }
272 return call.Args[0]
273 }
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301 func parallelSubtest(info *types.Info, call *ast.CallExpr) []ast.Stmt {
302 if !isMethodCall(info, call, "testing", "T", "Run") {
303 return nil
304 }
305
306 lit, _ := call.Args[1].(*ast.FuncLit)
307 if lit == nil {
308 return nil
309 }
310
311
312
313 if len(lit.Type.Params.List[0].Names) == 0 {
314 return nil
315 }
316
317 tObj := info.Defs[lit.Type.Params.List[0].Names[0]]
318 if tObj == nil {
319 return nil
320 }
321
322
323
324
325
326
327
328 var stmts []ast.Stmt
329 afterParallel := false
330 for _, stmt := range lit.Body.List {
331 stmt, labeled := unlabel(stmt)
332 if labeled {
333
334
335 stmts = nil
336 afterParallel = false
337 }
338
339 if afterParallel {
340 stmts = append(stmts, stmt)
341 continue
342 }
343
344
345 exprStmt, ok := stmt.(*ast.ExprStmt)
346 if !ok {
347 continue
348 }
349 expr := exprStmt.X
350 if isMethodCall(info, expr, "testing", "T", "Parallel") {
351 call, _ := expr.(*ast.CallExpr)
352 if call == nil {
353 continue
354 }
355 x, _ := call.Fun.(*ast.SelectorExpr)
356 if x == nil {
357 continue
358 }
359 id, _ := x.X.(*ast.Ident)
360 if id == nil {
361 continue
362 }
363 if info.Uses[id] == tObj {
364 afterParallel = true
365 }
366 }
367 }
368
369 return stmts
370 }
371
372
373
374
375
// unlabel strips every (possibly nested) *ast.LabeledStmt wrapper from stmt.
// It returns the innermost statement, and reports whether at least one label
// was removed.
func unlabel(stmt ast.Stmt) (ast.Stmt, bool) {
	inner, labeled := stmt, false
	for ls, ok := inner.(*ast.LabeledStmt); ok; ls, ok = inner.(*ast.LabeledStmt) {
		inner, labeled = ls.Stmt, true
	}
	return inner, labeled
}
387
388
389
390 func isMethodCall(info *types.Info, expr ast.Expr, pkgPath, typeName, method string) bool {
391 call, ok := expr.(*ast.CallExpr)
392 if !ok {
393 return false
394 }
395
396
397 f := typeutil.StaticCallee(info, call)
398 if f == nil || f.Name() != method {
399 return false
400 }
401 recv := f.Type().(*types.Signature).Recv()
402 if recv == nil {
403 return false
404 }
405
406
407
408 rtype := recv.Type()
409 if ptr, ok := recv.Type().(*types.Pointer); ok {
410 rtype = ptr.Elem()
411 }
412 named, ok := rtype.(*types.Named)
413 if !ok {
414 return false
415 }
416 if named.Obj().Name() != typeName {
417 return false
418 }
419 pkg := f.Pkg()
420 if pkg == nil {
421 return false
422 }
423 if pkg.Path() != pkgPath {
424 return false
425 }
426
427 return true
428 }
429
View as plain text