...
1
2
3
4
5
6
7 package deepequalerrors
8
9 import (
10 "go/ast"
11 "go/types"
12
13 "golang.org/x/tools/go/analysis"
14 "golang.org/x/tools/go/analysis/passes/inspect"
15 "golang.org/x/tools/go/ast/inspector"
16 "golang.org/x/tools/go/types/typeutil"
17 )
18
// Doc is the documentation string for Analyzer. The text before the first
// blank line is the analyzer's one-line title; the remainder describes the
// pattern it reports.
const Doc = `check for calls of reflect.DeepEqual on error values

The deepequalerrors checker looks for calls of the form:

    reflect.DeepEqual(err1, err2)

where err1 and err2 are errors. Using reflect.DeepEqual to compare
errors is discouraged.`
27
28 var Analyzer = &analysis.Analyzer{
29 Name: "deepequalerrors",
30 Doc: Doc,
31 Requires: []*analysis.Analyzer{inspect.Analyzer},
32 Run: run,
33 }
34
35 func run(pass *analysis.Pass) (interface{}, error) {
36 inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
37
38 nodeFilter := []ast.Node{
39 (*ast.CallExpr)(nil),
40 }
41 inspect.Preorder(nodeFilter, func(n ast.Node) {
42 call := n.(*ast.CallExpr)
43 fn, ok := typeutil.Callee(pass.TypesInfo, call).(*types.Func)
44 if !ok {
45 return
46 }
47 if fn.FullName() == "reflect.DeepEqual" && hasError(pass, call.Args[0]) && hasError(pass, call.Args[1]) {
48 pass.ReportRangef(call, "avoid using reflect.DeepEqual with errors")
49 }
50 })
51 return nil, nil
52 }
53
// errorType is the interface type underlying the predeclared error type.
// containsError compares types against this pointer; the predeclared error
// itself is a named type and reaches it through its Underlying method.
var errorType = func() *types.Interface {
	predeclared := types.Universe.Lookup("error").Type()
	return predeclared.Underlying().(*types.Interface)
}()
55
56
57
58 func hasError(pass *analysis.Pass, e ast.Expr) bool {
59 tv, ok := pass.TypesInfo.Types[e]
60 if !ok {
61 return false
62 }
63 return containsError(tv.Type)
64 }
65
66
67
68
69
70
71 func containsError(typ types.Type) bool {
72
73
74
75 inProgress := make(map[types.Type]bool)
76
77 var check func(t types.Type) bool
78 check = func(t types.Type) bool {
79 if t == errorType {
80 return true
81 }
82 if inProgress[t] {
83 return false
84 }
85 inProgress[t] = true
86 switch t := t.(type) {
87 case *types.Pointer:
88 return check(t.Elem())
89 case *types.Slice:
90 return check(t.Elem())
91 case *types.Array:
92 return check(t.Elem())
93 case *types.Map:
94 return check(t.Key()) || check(t.Elem())
95 case *types.Struct:
96 for i := 0; i < t.NumFields(); i++ {
97 if check(t.Field(i).Type()) {
98 return true
99 }
100 }
101 case *types.Named:
102 return check(t.Underlying())
103
104
105 case *types.Basic:
106 case *types.Chan:
107 case *types.Signature:
108 case *types.Tuple:
109 case *types.Interface:
110 }
111 return false
112 }
113
114 return check(typ)
115 }
116
View as plain text