1
2
3
4
5 package ssautil
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21 import (
22 "bytes"
23 "fmt"
24 "go/token"
25 "go/types"
26
27 "golang.org/x/tools/go/ssa"
28 )
29
30
31
// A ConstCase describes one "x == constant" test in an inferred value
// switch. Instances are created by valueSwitch and stored in
// Switch.ConstCases.
type ConstCase struct {
	Block *ssa.BasicBlock // block that ends with the comparison and its If
	Body  *ssa.BasicBlock // first successor of Block (Block.Succs[0]); presumably the "true" target of the If — see ssa.If
	Value *ssa.Const      // constant operand of the comparison
}
37
38
39
// A TypeCase describes one "y, ok := x.(T)" test in an inferred type
// switch. Instances are created by typeSwitch and stored in
// Switch.TypeCases.
type TypeCase struct {
	Block   *ssa.BasicBlock // block that ends with the type-assert test and its If
	Body    *ssa.BasicBlock // first successor of Block (Block.Succs[0]); presumably the "true" target of the If — see ssa.If
	Type    types.Type      // asserted type T of the TypeAssert instruction
	Binding ssa.Value       // the Extract instruction reported by isTypeAssertBlock as the case's bound value
}
46
47
48
49
50
51
52
53
54
55
56
57
// A Switch is a multiway branch inferred from a chain of If-terminated
// basic blocks that all test the same operand. It is a derived view of
// the control-flow graph, built by Switches.
//
// As populated by Switches, exactly one of ConstCases and TypeCases is
// non-empty: ConstCases for a value switch, TypeCases for a type switch.
type Switch struct {
	Start      *ssa.BasicBlock // block containing the first test of the chain
	X          ssa.Value       // operand compared (value switch) or type-asserted (type switch) by every case
	ConstCases []ConstCase     // cases of a value switch, in chain order
	TypeCases  []TypeCase      // cases of a type switch, in chain order
	Default    *ssa.BasicBlock // block at which chain traversal stopped; String treats nil as "no default"
}
65
66 func (sw *Switch) String() string {
67
68
69 var buf bytes.Buffer
70 if sw.ConstCases != nil {
71 fmt.Fprintf(&buf, "switch %s {\n", sw.X.Name())
72 for _, c := range sw.ConstCases {
73 fmt.Fprintf(&buf, "case %s: %s\n", c.Value, c.Body.Instrs[0])
74 }
75 } else {
76 fmt.Fprintf(&buf, "switch %s.(type) {\n", sw.X.Name())
77 for _, c := range sw.TypeCases {
78 fmt.Fprintf(&buf, "case %s %s: %s\n",
79 c.Binding.Name(), c.Type, c.Body.Instrs[0])
80 }
81 }
82 if sw.Default != nil {
83 fmt.Fprintf(&buf, "default: %s\n", sw.Default.Instrs[0])
84 }
85 fmt.Fprintf(&buf, "}")
86 return buf.String()
87 }
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105 func Switches(fn *ssa.Function) []Switch {
106
107
108 var switches []Switch
109 seen := make(map[*ssa.BasicBlock]bool)
110 for _, b := range fn.DomPreorder() {
111 if x, k := isComparisonBlock(b); x != nil {
112
113 sw := Switch{Start: b, X: x}
114 valueSwitch(&sw, k, seen)
115 if len(sw.ConstCases) > 1 {
116 switches = append(switches, sw)
117 }
118 }
119
120 if y, x, T := isTypeAssertBlock(b); y != nil {
121
122 sw := Switch{Start: b, X: x}
123 typeSwitch(&sw, y, T, seen)
124 if len(sw.TypeCases) > 1 {
125 switches = append(switches, sw)
126 }
127 }
128 }
129 return switches
130 }
131
132 func valueSwitch(sw *Switch, k *ssa.Const, seen map[*ssa.BasicBlock]bool) {
133 b := sw.Start
134 x := sw.X
135 for x == sw.X {
136 if seen[b] {
137 break
138 }
139 seen[b] = true
140
141 sw.ConstCases = append(sw.ConstCases, ConstCase{
142 Block: b,
143 Body: b.Succs[0],
144 Value: k,
145 })
146 b = b.Succs[1]
147 if len(b.Instrs) > 2 {
148
149
150
151 break
152 }
153 if len(b.Preds) != 1 {
154
155
156 break
157 }
158 x, k = isComparisonBlock(b)
159 }
160 sw.Default = b
161 }
162
163 func typeSwitch(sw *Switch, y ssa.Value, T types.Type, seen map[*ssa.BasicBlock]bool) {
164 b := sw.Start
165 x := sw.X
166 for x == sw.X {
167 if seen[b] {
168 break
169 }
170 seen[b] = true
171
172 sw.TypeCases = append(sw.TypeCases, TypeCase{
173 Block: b,
174 Body: b.Succs[0],
175 Type: T,
176 Binding: y,
177 })
178 b = b.Succs[1]
179 if len(b.Instrs) > 4 {
180
181
182
183
184 break
185 }
186 if len(b.Preds) != 1 {
187
188
189 break
190 }
191 y, x, T = isTypeAssertBlock(b)
192 }
193 sw.Default = b
194 }
195
196
197
198 func isComparisonBlock(b *ssa.BasicBlock) (v ssa.Value, k *ssa.Const) {
199 if n := len(b.Instrs); n >= 2 {
200 if i, ok := b.Instrs[n-1].(*ssa.If); ok {
201 if binop, ok := i.Cond.(*ssa.BinOp); ok && binop.Block() == b && binop.Op == token.EQL {
202 if k, ok := binop.Y.(*ssa.Const); ok {
203 return binop.X, k
204 }
205 if k, ok := binop.X.(*ssa.Const); ok {
206 return binop.Y, k
207 }
208 }
209 }
210 }
211 return
212 }
213
214
215
216 func isTypeAssertBlock(b *ssa.BasicBlock) (y, x ssa.Value, T types.Type) {
217 if n := len(b.Instrs); n >= 4 {
218 if i, ok := b.Instrs[n-1].(*ssa.If); ok {
219 if ext1, ok := i.Cond.(*ssa.Extract); ok && ext1.Block() == b && ext1.Index == 1 {
220 if ta, ok := ext1.Tuple.(*ssa.TypeAssert); ok && ta.Block() == b {
221
222 if ext0, ok := b.Instrs[n-3].(*ssa.Extract); ok {
223 return ext0, ta.X, ta.AssertedType
224 }
225 }
226 }
227 }
228 }
229 return
230 }
231
View as plain text