...
1
2
3
4
5
6
7 package sigchanyzer
8
9 import (
10 "bytes"
11 "go/ast"
12 "go/format"
13 "go/token"
14 "go/types"
15
16 "golang.org/x/tools/go/analysis"
17 "golang.org/x/tools/go/analysis/passes/inspect"
18 "golang.org/x/tools/go/ast/inspector"
19 )
20
// Doc is the documentation string installed as the sigchanyzer Analyzer's Doc.
const Doc = `check for unbuffered channel of os.Signal

This checker reports call expressions of the form signal.Notify(c chan<- os.Signal, sig ...os.Signal),
where c is an unbuffered channel, which can be at risk of missing the signal.`
25
26
27 var Analyzer = &analysis.Analyzer{
28 Name: "sigchanyzer",
29 Doc: Doc,
30 Requires: []*analysis.Analyzer{inspect.Analyzer},
31 Run: run,
32 }
33
34 func run(pass *analysis.Pass) (interface{}, error) {
35 inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
36
37 nodeFilter := []ast.Node{
38 (*ast.CallExpr)(nil),
39 }
40 inspect.Preorder(nodeFilter, func(n ast.Node) {
41 call := n.(*ast.CallExpr)
42 if !isSignalNotify(pass.TypesInfo, call) {
43 return
44 }
45 var chanDecl *ast.CallExpr
46 switch arg := call.Args[0].(type) {
47 case *ast.Ident:
48 if decl, ok := findDecl(arg).(*ast.CallExpr); ok {
49 chanDecl = decl
50 }
51 case *ast.CallExpr:
52
53
54 if isBuiltinMake(pass.TypesInfo, arg) {
55 return
56 }
57 chanDecl = arg
58 }
59 if chanDecl == nil || len(chanDecl.Args) != 1 {
60 return
61 }
62
63
64
65 chanDeclCopy := &ast.CallExpr{}
66 *chanDeclCopy = *chanDecl
67 chanDeclCopy.Args = append([]ast.Expr(nil), chanDecl.Args...)
68 chanDeclCopy.Args = append(chanDeclCopy.Args, &ast.BasicLit{
69 Kind: token.INT,
70 Value: "1",
71 })
72
73 var buf bytes.Buffer
74 if err := format.Node(&buf, token.NewFileSet(), chanDeclCopy); err != nil {
75 return
76 }
77 pass.Report(analysis.Diagnostic{
78 Pos: call.Pos(),
79 End: call.End(),
80 Message: "misuse of unbuffered os.Signal channel as argument to signal.Notify",
81 SuggestedFixes: []analysis.SuggestedFix{{
82 Message: "Change to buffer channel",
83 TextEdits: []analysis.TextEdit{{
84 Pos: chanDecl.Pos(),
85 End: chanDecl.End(),
86 NewText: buf.Bytes(),
87 }},
88 }},
89 })
90 })
91 return nil, nil
92 }
93
94 func isSignalNotify(info *types.Info, call *ast.CallExpr) bool {
95 check := func(id *ast.Ident) bool {
96 obj := info.ObjectOf(id)
97 return obj.Name() == "Notify" && obj.Pkg().Path() == "os/signal"
98 }
99 switch fun := call.Fun.(type) {
100 case *ast.SelectorExpr:
101 return check(fun.Sel)
102 case *ast.Ident:
103 if fun, ok := findDecl(fun).(*ast.SelectorExpr); ok {
104 return check(fun.Sel)
105 }
106 return false
107 default:
108 return false
109 }
110 }
111
// findDecl returns the initializer expression that arg's declaring statement
// binds to it, using the parser's ast.Object resolution (arg.Obj).
//
// Two declaration forms are understood, each only when every name on the
// left has its own expression on the right:
//   - assignments and short variable declarations (x := expr, x = expr);
//   - var specs (var x = expr).
//
// It returns nil when arg is unresolved, was declared by some other node
// (a parameter, a range clause, ...), or when the two sides differ in length
// (for example a, b := f()).
func findDecl(arg *ast.Ident) ast.Node {
	target := arg.Obj
	if target == nil {
		return nil
	}

	// declares reports whether e is an identifier resolved to target.
	declares := func(e ast.Expr) bool {
		id, isIdent := e.(*ast.Ident)
		return isIdent && id.Obj == target
	}

	switch decl := target.Decl.(type) {
	case *ast.ValueSpec:
		if len(decl.Values) == len(decl.Names) {
			for idx := range decl.Names {
				if declares(decl.Names[idx]) {
					return decl.Values[idx]
				}
			}
		}
	case *ast.AssignStmt:
		if len(decl.Rhs) == len(decl.Lhs) {
			for idx := range decl.Lhs {
				if declares(decl.Lhs[idx]) {
					return decl.Rhs[idx]
				}
			}
		}
	}
	return nil
}
142
// isBuiltinMake reports whether call invokes the predeclared make function
// through a plain identifier, e.g. make(chan os.Signal). The callee must be
// recorded as a builtin in info.Types, so a user-declared function named
// make does not match. A parenthesized callee such as (make)(...) is not
// recognized.
func isBuiltinMake(info *types.Info, call *ast.CallExpr) bool {
	id, isIdent := call.Fun.(*ast.Ident)
	if !isIdent || !info.Types[call.Fun].IsBuiltin() {
		return false
	}
	return info.ObjectOf(id).Name() == "make"
}
155
View as plain text