1
2
3
4
5 package astutil
6
7 import (
8 "fmt"
9 "go/ast"
10 "reflect"
11 "sort"
12
13 "golang.org/x/tools/internal/typeparams"
14 )
15
16
17
18
19
20
21
22 type ApplyFunc func(*Cursor) bool
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44 func Apply(root ast.Node, pre, post ApplyFunc) (result ast.Node) {
45 parent := &struct{ ast.Node }{root}
46 defer func() {
47 if r := recover(); r != nil && r != abort {
48 panic(r)
49 }
50 result = parent.Node
51 }()
52 a := &application{pre: pre, post: post}
53 a.apply(parent, "Node", nil, root)
54 return
55 }
56
57 var abort = new(int)
58
59
60
61
62
63
64
65
66
67
68
69
70
71
// A Cursor describes a node encountered during Apply.
// Information about the node and its parent is available
// from the Node, Parent, Name, and Index methods.
//
// If p is the current parent node c.Parent(), and f is the field of p
// named c.Name(), the following invariants hold:
//
//	p.f            == c.Node()  if c.Index() <  0
//	p.f[c.Index()] == c.Node()  if c.Index() >= 0
//
// (For a file inside an *ast.Package, Name is instead the key of the
// file in the package's Files map.)
//
// The methods Replace, Delete, InsertBefore, and InsertAfter
// can be used to change the AST without disrupting Apply.
type Cursor struct {
	parent ast.Node  // node whose field holds the current node
	name   string    // name of that field (or Files map key for package files)
	iter   *iterator // non-nil only while visiting an element of a slice field
	node   ast.Node  // the current node; nil for nil fields/elements
}
78
79
80 func (c *Cursor) Node() ast.Node { return c.node }
81
82
83 func (c *Cursor) Parent() ast.Node { return c.parent }
84
85
86
87
88 func (c *Cursor) Name() string { return c.name }
89
90
91
92
93
94 func (c *Cursor) Index() int {
95 if c.iter != nil {
96 return c.iter.index
97 }
98 return -1
99 }
100
101
102 func (c *Cursor) field() reflect.Value {
103 return reflect.Indirect(reflect.ValueOf(c.parent)).FieldByName(c.name)
104 }
105
106
107
108 func (c *Cursor) Replace(n ast.Node) {
109 if _, ok := c.node.(*ast.File); ok {
110 file, ok := n.(*ast.File)
111 if !ok {
112 panic("attempt to replace *ast.File with non-*ast.File")
113 }
114 c.parent.(*ast.Package).Files[c.name] = file
115 return
116 }
117
118 v := c.field()
119 if i := c.Index(); i >= 0 {
120 v = v.Index(i)
121 }
122 v.Set(reflect.ValueOf(n))
123 }
124
125
126
127
128
129 func (c *Cursor) Delete() {
130 if _, ok := c.node.(*ast.File); ok {
131 delete(c.parent.(*ast.Package).Files, c.name)
132 return
133 }
134
135 i := c.Index()
136 if i < 0 {
137 panic("Delete node not contained in slice")
138 }
139 v := c.field()
140 l := v.Len()
141 reflect.Copy(v.Slice(i, l), v.Slice(i+1, l))
142 v.Index(l - 1).Set(reflect.Zero(v.Type().Elem()))
143 v.SetLen(l - 1)
144 c.iter.step--
145 }
146
147
148
149
150 func (c *Cursor) InsertAfter(n ast.Node) {
151 i := c.Index()
152 if i < 0 {
153 panic("InsertAfter node not contained in slice")
154 }
155 v := c.field()
156 v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
157 l := v.Len()
158 reflect.Copy(v.Slice(i+2, l), v.Slice(i+1, l))
159 v.Index(i + 1).Set(reflect.ValueOf(n))
160 c.iter.step++
161 }
162
163
164
165
166 func (c *Cursor) InsertBefore(n ast.Node) {
167 i := c.Index()
168 if i < 0 {
169 panic("InsertBefore node not contained in slice")
170 }
171 v := c.field()
172 v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
173 l := v.Len()
174 reflect.Copy(v.Slice(i+1, l), v.Slice(i, l))
175 v.Index(i).Set(reflect.ValueOf(n))
176 c.iter.index++
177 }
178
179
// An application holds the state of one Apply traversal: the user's
// callbacks plus a single Cursor and a single iterator that are reused
// for every visit (apply and applyList save and restore them around each
// node/list) rather than being allocated per visit.
type application struct {
	pre, post ApplyFunc
	cursor    Cursor   // cursor passed (by address) to pre and post
	iter      iterator // position within the slice currently being walked
}
185
// apply visits node n, which is held in field name of parent (and, when
// iter is non-nil, at index iter.index of that slice-typed field).
// It calls a.pre, then walks n's children recursively, then calls a.post.
// If a.pre returns false, the children and a.post are skipped. If a.post
// returns false, the whole traversal is aborted by panicking with abort,
// which Apply recovers.
func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast.Node) {
	// Convert a typed nil pointer (e.g. (*ast.BlockStmt)(nil)) into an
	// untyped nil so that pre/post and the switch below see plain nil.
	if v := reflect.ValueOf(n); v.Kind() == reflect.Ptr && v.IsNil() {
		n = nil
	}

	// a.cursor is reused for every node: save the caller's cursor state
	// here and restore it before returning.
	saved := a.cursor
	a.cursor.parent = parent
	a.cursor.name = name
	a.cursor.iter = iter
	a.cursor.node = n

	if a.pre != nil && !a.pre(&a.cursor) {
		a.cursor = saved
		return
	}

	// Walk children. The switch operates on n as it was before pre ran, so
	// a node installed via Cursor.Replace is not itself walked. The order of
	// the calls within each case determines the order in which children are
	// visited.
	switch n := n.(type) {
	case nil:
		// nothing to do

	// Comments and fields
	case *ast.Comment:
		// nothing to do

	case *ast.CommentGroup:
		// (n is never a typed nil here, since those were normalized above.)
		if n != nil {
			a.applyList(n, "List")
		}

	case *ast.Field:
		a.apply(n, "Doc", nil, n.Doc)
		a.applyList(n, "Names")
		a.apply(n, "Type", nil, n.Type)
		a.apply(n, "Tag", nil, n.Tag)
		a.apply(n, "Comment", nil, n.Comment)

	case *ast.FieldList:
		a.applyList(n, "List")

	// Expressions
	case *ast.BadExpr, *ast.Ident, *ast.BasicLit:
		// nothing to do

	case *ast.Ellipsis:
		a.apply(n, "Elt", nil, n.Elt)

	case *ast.FuncLit:
		a.apply(n, "Type", nil, n.Type)
		a.apply(n, "Body", nil, n.Body)

	case *ast.CompositeLit:
		a.apply(n, "Type", nil, n.Type)
		a.applyList(n, "Elts")

	case *ast.ParenExpr:
		a.apply(n, "X", nil, n.X)

	case *ast.SelectorExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Sel", nil, n.Sel)

	case *ast.IndexExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Index", nil, n.Index)

	case *typeparams.IndexListExpr:
		a.apply(n, "X", nil, n.X)
		a.applyList(n, "Indices")

	case *ast.SliceExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Low", nil, n.Low)
		a.apply(n, "High", nil, n.High)
		a.apply(n, "Max", nil, n.Max)

	case *ast.TypeAssertExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Type", nil, n.Type)

	case *ast.CallExpr:
		a.apply(n, "Fun", nil, n.Fun)
		a.applyList(n, "Args")

	case *ast.StarExpr:
		a.apply(n, "X", nil, n.X)

	case *ast.UnaryExpr:
		a.apply(n, "X", nil, n.X)

	case *ast.BinaryExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Y", nil, n.Y)

	case *ast.KeyValueExpr:
		a.apply(n, "Key", nil, n.Key)
		a.apply(n, "Value", nil, n.Value)

	// Types
	case *ast.ArrayType:
		a.apply(n, "Len", nil, n.Len)
		a.apply(n, "Elt", nil, n.Elt)

	case *ast.StructType:
		a.apply(n, "Fields", nil, n.Fields)

	case *ast.FuncType:
		// Type parameters are visited only when present; the "TypeParams"
		// field name is what Cursor.Replace will look up via reflection.
		if tparams := typeparams.ForFuncType(n); tparams != nil {
			a.apply(n, "TypeParams", nil, tparams)
		}
		a.apply(n, "Params", nil, n.Params)
		a.apply(n, "Results", nil, n.Results)

	case *ast.InterfaceType:
		a.apply(n, "Methods", nil, n.Methods)

	case *ast.MapType:
		a.apply(n, "Key", nil, n.Key)
		a.apply(n, "Value", nil, n.Value)

	case *ast.ChanType:
		a.apply(n, "Value", nil, n.Value)

	// Statements
	case *ast.BadStmt:
		// nothing to do

	case *ast.DeclStmt:
		a.apply(n, "Decl", nil, n.Decl)

	case *ast.EmptyStmt:
		// nothing to do

	case *ast.LabeledStmt:
		a.apply(n, "Label", nil, n.Label)
		a.apply(n, "Stmt", nil, n.Stmt)

	case *ast.ExprStmt:
		a.apply(n, "X", nil, n.X)

	case *ast.SendStmt:
		a.apply(n, "Chan", nil, n.Chan)
		a.apply(n, "Value", nil, n.Value)

	case *ast.IncDecStmt:
		a.apply(n, "X", nil, n.X)

	case *ast.AssignStmt:
		a.applyList(n, "Lhs")
		a.applyList(n, "Rhs")

	case *ast.GoStmt:
		a.apply(n, "Call", nil, n.Call)

	case *ast.DeferStmt:
		a.apply(n, "Call", nil, n.Call)

	case *ast.ReturnStmt:
		a.applyList(n, "Results")

	case *ast.BranchStmt:
		a.apply(n, "Label", nil, n.Label)

	case *ast.BlockStmt:
		a.applyList(n, "List")

	case *ast.IfStmt:
		a.apply(n, "Init", nil, n.Init)
		a.apply(n, "Cond", nil, n.Cond)
		a.apply(n, "Body", nil, n.Body)
		a.apply(n, "Else", nil, n.Else)

	case *ast.CaseClause:
		a.applyList(n, "List")
		a.applyList(n, "Body")

	case *ast.SwitchStmt:
		a.apply(n, "Init", nil, n.Init)
		a.apply(n, "Tag", nil, n.Tag)
		a.apply(n, "Body", nil, n.Body)

	case *ast.TypeSwitchStmt:
		a.apply(n, "Init", nil, n.Init)
		a.apply(n, "Assign", nil, n.Assign)
		a.apply(n, "Body", nil, n.Body)

	case *ast.CommClause:
		a.apply(n, "Comm", nil, n.Comm)
		a.applyList(n, "Body")

	case *ast.SelectStmt:
		a.apply(n, "Body", nil, n.Body)

	case *ast.ForStmt:
		a.apply(n, "Init", nil, n.Init)
		a.apply(n, "Cond", nil, n.Cond)
		a.apply(n, "Post", nil, n.Post)
		a.apply(n, "Body", nil, n.Body)

	case *ast.RangeStmt:
		a.apply(n, "Key", nil, n.Key)
		a.apply(n, "Value", nil, n.Value)
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Body", nil, n.Body)

	// Declarations
	case *ast.ImportSpec:
		a.apply(n, "Doc", nil, n.Doc)
		a.apply(n, "Name", nil, n.Name)
		a.apply(n, "Path", nil, n.Path)
		a.apply(n, "Comment", nil, n.Comment)

	case *ast.ValueSpec:
		a.apply(n, "Doc", nil, n.Doc)
		a.applyList(n, "Names")
		a.apply(n, "Type", nil, n.Type)
		a.applyList(n, "Values")
		a.apply(n, "Comment", nil, n.Comment)

	case *ast.TypeSpec:
		a.apply(n, "Doc", nil, n.Doc)
		a.apply(n, "Name", nil, n.Name)
		if tparams := typeparams.ForTypeSpec(n); tparams != nil {
			a.apply(n, "TypeParams", nil, tparams)
		}
		a.apply(n, "Type", nil, n.Type)
		a.apply(n, "Comment", nil, n.Comment)

	case *ast.BadDecl:
		// nothing to do

	case *ast.GenDecl:
		a.apply(n, "Doc", nil, n.Doc)
		a.applyList(n, "Specs")

	case *ast.FuncDecl:
		a.apply(n, "Doc", nil, n.Doc)
		a.apply(n, "Recv", nil, n.Recv)
		a.apply(n, "Name", nil, n.Name)
		a.apply(n, "Type", nil, n.Type)
		a.apply(n, "Body", nil, n.Body)

	// Files and packages
	case *ast.File:
		a.apply(n, "Doc", nil, n.Doc)
		a.apply(n, "Name", nil, n.Name)
		a.applyList(n, "Decls")
		// n.Imports and n.Comments are not walked. NOTE(review): presumably
		// because they alias nodes already reachable via Decls and Doc fields.

	case *ast.Package:
		// Files is a map; visit files in sorted filename order so the
		// traversal is deterministic despite random map iteration order.
		var names []string
		for name := range n.Files {
			names = append(names, name)
		}
		sort.Strings(names)
		for _, name := range names {
			a.apply(n, name, nil, n.Files[name])
		}

	default:
		panic(fmt.Sprintf("Apply: unexpected node type %T", n))
	}

	if a.post != nil && !a.post(&a.cursor) {
		// Unwind the entire traversal; Apply recovers this sentinel.
		panic(abort)
	}

	a.cursor = saved
}
460
461
462 type iterator struct {
463 index, step int
464 }
465
466 func (a *application) applyList(parent ast.Node, name string) {
467
468 saved := a.iter
469 a.iter.index = 0
470 for {
471
472 v := reflect.Indirect(reflect.ValueOf(parent)).FieldByName(name)
473 if a.iter.index >= v.Len() {
474 break
475 }
476
477
478 var x ast.Node
479 if e := v.Index(a.iter.index); e.IsValid() {
480 x = e.Interface().(ast.Node)
481 }
482
483 a.iter.step = 1
484 a.apply(parent, name, &a.iter, x)
485 a.iter.index += a.iter.step
486 }
487 a.iter = saved
488 }
489
View as plain text