...
1
2
3
4
5 package vta
6
7 import (
8 "go/types"
9
10 "golang.org/x/tools/go/callgraph"
11 "golang.org/x/tools/go/ssa"
12 "golang.org/x/tools/internal/typeparams"
13 )
14
15 func canAlias(n1, n2 node) bool {
16 return isReferenceNode(n1) && isReferenceNode(n2)
17 }
18
19 func isReferenceNode(n node) bool {
20 if _, ok := n.(nestedPtrInterface); ok {
21 return true
22 }
23 if _, ok := n.(nestedPtrFunction); ok {
24 return true
25 }
26
27 if _, ok := n.Type().(*types.Pointer); ok {
28 return true
29 }
30
31 return false
32 }
33
34
35
36
37
38
39
40
41
42
43 func hasInFlow(n node) bool {
44 if _, ok := n.(panicArg); ok {
45 return true
46 }
47 if _, ok := n.(recoverReturn); ok {
48 return true
49 }
50
51 t := n.Type()
52
53 if i := interfaceUnderPtr(t); i != nil {
54 return true
55 }
56 if f := functionUnderPtr(t); f != nil {
57 return true
58 }
59
60 return types.IsInterface(t) || isFunction(t)
61 }
62
// isFunction reports whether t is a function type, i.e. whether its
// underlying type is a *types.Signature.
func isFunction(t types.Type) bool {
	_, ok := t.Underlying().(*types.Signature)
	return ok
}

// interfaceUnderPtr checks whether t is a pointer, possibly nested, to an
// interface type and, if so, returns that interface type. It returns nil
// otherwise, including when t is itself an interface (no pointer involved).
func interfaceUnderPtr(t types.Type) types.Type {
	return elemUnderPtr(t, types.IsInterface)
}

// functionUnderPtr checks whether t is a pointer, possibly nested, to a
// function type (see isFunction) and, if so, returns that function type.
// It returns nil otherwise, including when t is itself a function type.
func functionUnderPtr(t types.Type) types.Type {
	return elemUnderPtr(t, isFunction)
}

// elemUnderPtr follows the chain of pointer types starting at t, looking
// through named types via Underlying, and returns the first pointee for
// which match reports true. It returns nil if the chain reaches a
// non-pointer type first. It also returns nil if it revisits a type, which
// happens for recursive definitions such as `type P *P`.
func elemUnderPtr(t types.Type, match func(types.Type) bool) types.Type {
	seen := make(map[types.Type]bool)
	for !seen[t] {
		seen[t] = true

		p, ok := t.Underlying().(*types.Pointer)
		if !ok {
			return nil
		}
		elem := p.Elem()
		if match(elem) {
			return elem
		}
		t = elem
	}
	return nil
}
119
120
121
122
123 func sliceArrayElem(t types.Type) types.Type {
124 switch u := t.Underlying().(type) {
125 case *types.Pointer:
126 return u.Elem().Underlying().(*types.Array).Elem()
127 case *types.Array:
128 return u.Elem()
129 case *types.Slice:
130 return u.Elem()
131 case *types.Basic:
132 return types.Typ[types.Byte]
133 case *types.Interface:
134 terms, err := typeparams.InterfaceTermSet(u)
135 if err != nil || len(terms) == 0 {
136 panic(t)
137 }
138 return sliceArrayElem(terms[0].Type())
139 default:
140 panic(t)
141 }
142 }
143
144
145 func siteCallees(c ssa.CallInstruction, callgraph *callgraph.Graph) []*ssa.Function {
146 var matches []*ssa.Function
147
148 node := callgraph.Nodes[c.Parent()]
149 if node == nil {
150 return nil
151 }
152
153 for _, edge := range node.Out {
154 if edge.Site == c {
155 matches = append(matches, edge.Callee.Func)
156 }
157 }
158 return matches
159 }
160
// canHaveMethods reports whether type t can have methods. That holds when
// t is itself a *types.Named, or when its underlying type is an
// interface, a function signature, or a struct.
func canHaveMethods(t types.Type) bool {
	if _, isNamed := t.(*types.Named); isNamed {
		return true
	}

	switch t.Underlying().(type) {
	case *types.Interface, *types.Signature, *types.Struct:
		return true
	}
	return false
}
174
175
176 func calls(f *ssa.Function) []ssa.CallInstruction {
177 var calls []ssa.CallInstruction
178 for _, bl := range f.Blocks {
179 for _, instr := range bl.Instrs {
180 if c, ok := instr.(ssa.CallInstruction); ok {
181 calls = append(calls, c)
182 }
183 }
184 }
185 return calls
186 }
187
188
189 func intersect(fs1, fs2 []*ssa.Function) []*ssa.Function {
190 m := make(map[*ssa.Function]bool)
191 for _, f := range fs1 {
192 m[f] = true
193 }
194
195 var res []*ssa.Function
196 for _, f := range fs2 {
197 if m[f] {
198 res = append(res, f)
199 }
200 }
201 return res
202 }
203
View as plain text