1
2
3
4
5 package eg
6
7
8
9
10
11 import (
12 "fmt"
13 "go/ast"
14 "go/token"
15 "go/types"
16 "os"
17 "reflect"
18 "sort"
19 "strconv"
20 "strings"
21
22 "golang.org/x/tools/go/ast/astutil"
23 )
24
25
26
27
28 func (tr *Transformer) transformItem(rv reflect.Value) (reflect.Value, bool, map[string]ast.Expr) {
29
30 if !rv.IsValid() {
31 return reflect.Value{}, false, nil
32 }
33
34 rv, changed, newEnv := tr.apply(tr.transformItem, rv)
35
36 e := rvToExpr(rv)
37 if e == nil {
38 return rv, changed, newEnv
39 }
40
41 savedEnv := tr.env
42 tr.env = make(map[string]ast.Expr)
43
44 if tr.matchExpr(tr.before, e) {
45 if tr.verbose {
46 fmt.Fprintf(os.Stderr, "%s matches %s",
47 astString(tr.fset, tr.before), astString(tr.fset, e))
48 if len(tr.env) > 0 {
49 fmt.Fprintf(os.Stderr, " with:")
50 for name, ast := range tr.env {
51 fmt.Fprintf(os.Stderr, " %s->%s",
52 name, astString(tr.fset, ast))
53 }
54 }
55 fmt.Fprintf(os.Stderr, "\n")
56 }
57 tr.nsubsts++
58
59
60
61 rv = tr.subst(tr.env, reflect.ValueOf(tr.after),
62 reflect.ValueOf(e.Pos()))
63 changed = true
64 newEnv = tr.env
65 }
66 tr.env = savedEnv
67
68 return rv, changed, newEnv
69 }
70
71
72
73
74
75
76
77
78
79
80 func (tr *Transformer) Transform(info *types.Info, pkg *types.Package, file *ast.File) int {
81 if !tr.seenInfos[info] {
82 tr.seenInfos[info] = true
83 mergeTypeInfo(tr.info, info)
84 }
85 tr.currentPkg = pkg
86 tr.nsubsts = 0
87
88 if tr.verbose {
89 fmt.Fprintf(os.Stderr, "before: %s\n", astString(tr.fset, tr.before))
90 fmt.Fprintf(os.Stderr, "after: %s\n", astString(tr.fset, tr.after))
91 fmt.Fprintf(os.Stderr, "afterStmts: %s\n", tr.afterStmts)
92 }
93
94 o, changed, _ := tr.apply(tr.transformItem, reflect.ValueOf(file))
95 if changed {
96 panic("BUG")
97 }
98 file2 := o.Interface().(*ast.File)
99
100
101 if file != file2 {
102 panic("BUG")
103 }
104
105
106
107 if tr.nsubsts > 0 {
108 pkgs := make(map[string]*types.Package)
109 for obj := range tr.importedObjs {
110 pkgs[obj.Pkg().Path()] = obj.Pkg()
111 }
112
113 for _, imp := range file.Imports {
114 path, _ := strconv.Unquote(imp.Path.Value)
115 delete(pkgs, path)
116 }
117 delete(pkgs, pkg.Path())
118
119
120
121 var paths []string
122 for path := range pkgs {
123 paths = append(paths, path)
124 }
125 sort.Strings(paths)
126 for _, path := range paths {
127 astutil.AddImport(tr.fset, file, path)
128 }
129 }
130
131 tr.currentPkg = nil
132
133 return tr.nsubsts
134 }
135
136
137
// setValue stores y into x. The store is skipped when y is the zero Value.
// It is also skipped when reflect.Value.Set panics with a message saying
// y's type cannot be assigned to x (for example, a rewrite produced a node
// of a different concrete type than the field holds). Any other panic from
// Set, such as an unaddressable x, is re-raised unchanged.
//
// NOTE(review): this depends on the wording of reflect's panic messages.
// Checking y.Type().AssignableTo(x.Type()) before calling Set would not.
func setValue(x, y reflect.Value) {
	if !y.IsValid() {
		return
	}
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		msg, isString := r.(string)
		if isString && (strings.Contains(msg, "type mismatch") || strings.Contains(msg, "not assignable")) {
			// y does not fit in x: silently drop this rewrite.
			return
		}
		panic(r)
	}()
	x.Set(y)
}
155
156
// Reflection handles for the go/ast node types that the traversal and
// substitution code treats specially.
var (
	// Pointer types of individual AST nodes.
	identType        = reflect.TypeOf((*ast.Ident)(nil))
	selectorExprType = reflect.TypeOf((*ast.SelectorExpr)(nil))
	objectPtrType    = reflect.TypeOf((*ast.Object)(nil))
	scopePtrType     = reflect.TypeOf((*ast.Scope)(nil))

	// The ast.Stmt interface type (element type of statement lists).
	statementType = reflect.TypeOf((*ast.Stmt)(nil)).Elem()

	// token.Pos, whose values are rewritten during substitution.
	positionType = reflect.TypeOf(token.Pos(0))

	// Typed nil pointers returned in place of *ast.Object / *ast.Scope values.
	objectPtrNil = reflect.Zero(objectPtrType)
	scopePtrNil  = reflect.Zero(scopePtrType)
)
168
169
170
171
172
173
174
// apply calls f on each child of val and writes the results back into val
// in place (via setValue), returning val. The one exception is a statement
// list, for which it returns a new slice.
//
// The children are the elements of a slice, the fields of a struct, or the
// dynamic value of an interface. Pointers are followed with
// reflect.Indirect. Values of any other kind, and nil pointers, are
// returned unchanged. *ast.Object and *ast.Scope values are not traversed;
// they are replaced by typed nil pointers.
//
// It returns the (possibly replaced) value, whether f reported a change for
// any child, and the env returned by the last child that reported a change.
//
// For a []ast.Stmt, every statement that f reports as changed is preceded
// in the result by copies of tr.afterStmts, substituted with that
// statement's env. The statement list as a whole is then reported as
// unchanged, so changes do not propagate past it.
func (tr *Transformer) apply(f func(reflect.Value) (reflect.Value, bool, map[string]ast.Expr), val reflect.Value) (reflect.Value, bool, map[string]ast.Expr) {
	if !val.IsValid() {
		return reflect.Value{}, false, nil
	}

	// Replace *ast.Object links with a typed nil instead of walking them.
	// NOTE(review): presumably because these links can form cycles and are
	// not meaningful after rewriting. Confirm.
	if val.Type() == objectPtrType {
		return objectPtrNil, false, nil
	}

	// Same treatment for *ast.Scope.
	if val.Type() == scopePtrType {
		return scopePtrNil, false, nil
	}

	switch v := reflect.Indirect(val); v.Kind() {
	case reflect.Slice:
		// Non-statement slices: rewrite each element in place.
		if v.Type().Elem() != statementType {
			changed := false
			var envp map[string]ast.Expr
			for i := 0; i < v.Len(); i++ {
				e := v.Index(i)
				o, localchanged, env := f(e)
				if localchanged {
					changed = true
					// Only the env of the last changed element survives.
					envp = env
				}
				setValue(e, o)
			}
			return val, changed, envp
		}

		// Statement lists: rebuild the slice so that substituted copies of
		// tr.afterStmts can be inserted ahead of each changed statement.
		var out []ast.Stmt
		for i := 0; i < v.Len(); i++ {
			e := v.Index(i)
			o, changed, env := f(e)
			if changed {
				for _, s := range tr.afterStmts {
					t := tr.subst(env, reflect.ValueOf(s), reflect.Value{}).Interface()
					out = append(out, t.(ast.Stmt))
				}
			}
			setValue(e, o)
			out = append(out, e.Interface().(ast.Stmt))
		}
		// The change is absorbed here; enclosing nodes see "unchanged".
		// An empty input list yields a nil []ast.Stmt.
		return reflect.ValueOf(out), false, nil
	case reflect.Struct:
		// Rewrite each field in place.
		changed := false
		var envp map[string]ast.Expr
		for i := 0; i < v.NumField(); i++ {
			e := v.Field(i)
			o, localchanged, env := f(e)
			if localchanged {
				changed = true
				envp = env
			}
			setValue(e, o)
		}
		return val, changed, envp
	case reflect.Interface:
		// Rewrite the dynamic value and store the result back into the
		// interface (setValue skips results whose type does not fit).
		e := v.Elem()
		o, changed, env := f(e)
		setValue(v, o)
		return val, changed, env
	}
	return val, false, nil
}
250
251
252
253
254
// subst returns a deep copy of pattern in which every *ast.Ident whose Name
// is a key of env is replaced by a copy of the bound expression. Those
// copies keep their own positions and are not substituted further.
//
// If pos is valid, every valid token.Pos in the copy is set to pos, and
// invalid positions stay invalid. *ast.Object and *ast.Scope values become
// typed nil pointers.
//
// While tr.importedObjs is non-nil, a selector expression that refers to an
// object recorded in tr.importedObjs is replaced by a copy of the recorded
// selector. When that object belongs to tr.currentPkg, only the selector's
// Sel identifier is used.
//
// For each copied pointer that holds an ast.Expr, the type information
// recorded in tr.info for the original is copied to the clone (see
// updateTypeInfo). Slices are rebuilt with MakeSlice, so a nil slice
// becomes an empty non-nil one. Values of kinds other than slice, struct,
// pointer and interface are returned as-is.
//
// NOTE(review): substitution is by identifier name only. Any identifier
// spelled like a wildcard, including a selector's field name, is replaced.
// The struct and slice cases use Value.Set, which panics if the replacement
// does not fit the field (e.g. a non-identifier bound where *ast.Ident is
// required). Confirm templates cannot produce this.
func (tr *Transformer) subst(env map[string]ast.Expr, pattern, pos reflect.Value) reflect.Value {
	if !pattern.IsValid() {
		return reflect.Value{}
	}

	// Do not copy *ast.Object links; return a typed nil instead.
	if pattern.Type() == objectPtrType {
		return objectPtrNil
	}

	// Likewise for *ast.Scope.
	if pattern.Type() == scopePtrType {
		return scopePtrNil
	}

	// Wildcard reference: substitute a copy of the bound expression, with a
	// nil env and invalid pos so that it is copied verbatim.
	if env != nil && pattern.Type() == identType {
		id := pattern.Interface().(*ast.Ident)
		if old, ok := env[id.Name]; ok {
			return tr.subst(nil, reflect.ValueOf(old), reflect.Value{})
		}
	}

	// Reference to an imported object: emit the recorded selector, or just
	// its Sel identifier when the object lives in the package being
	// rewritten.
	if tr.importedObjs != nil && pattern.Type() == selectorExprType {
		obj := isRef(pattern.Interface().(*ast.SelectorExpr), tr.info)
		if obj != nil {
			if sel, ok := tr.importedObjs[obj]; ok {
				var id ast.Expr
				if obj.Pkg() == tr.currentPkg {
					id = sel.Sel
				} else {
					id = sel
				}

				// Copy the replacement with this rule disabled.
				// NOTE(review): presumably so the replacement selector,
				// which refers to the same object, is not rewritten again.
				// Confirm.
				saved := tr.importedObjs
				tr.importedObjs = nil
				r := tr.subst(nil, reflect.ValueOf(id), pos)
				tr.importedObjs = saved
				return r
			}
		}
	}

	// Move valid positions to pos; leave invalid (NoPos) ones alone.
	if pos.IsValid() && pattern.Type() == positionType {
		if old := pattern.Interface().(token.Pos); !old.IsValid() {
			return pattern
		}
		return pos
	}

	// Otherwise, copy the value recursively.
	switch p := pattern; p.Kind() {
	case reflect.Slice:
		v := reflect.MakeSlice(p.Type(), p.Len(), p.Len())
		for i := 0; i < p.Len(); i++ {
			v.Index(i).Set(tr.subst(env, p.Index(i), pos))
		}
		return v

	case reflect.Struct:
		v := reflect.New(p.Type()).Elem()
		for i := 0; i < p.NumField(); i++ {
			v.Field(i).Set(tr.subst(env, p.Field(i), pos))
		}
		return v

	case reflect.Ptr:
		// Addr requires an addressable copy. The struct case above returns
		// one (reflect.New(...).Elem()). Assumes AST pointers point to
		// structs; TODO confirm.
		v := reflect.New(p.Type()).Elem()
		if elem := p.Elem(); elem.IsValid() {
			v.Set(tr.subst(env, elem, pos).Addr())
		}

		// Carry the original's type information over to the clone.
		// NOTE(review): for a nil expression pointer, rvToExpr returns a
		// non-nil interface holding nil, so this also runs with nil keys.
		if e := rvToExpr(v); e != nil {
			updateTypeInfo(tr.info, e, p.Interface().(ast.Expr))
		}
		return v

	case reflect.Interface:
		v := reflect.New(p.Type()).Elem()
		if elem := p.Elem(); elem.IsValid() {
			v.Set(tr.subst(env, elem, pos))
		}
		return v
	}

	return pattern
}
365
366
367
// rvToExpr returns the ast.Expr held by rv. It returns nil when rv's value
// cannot be extracted with Interface (e.g. it came from an unexported
// field) or does not implement ast.Expr.
//
// rv must be a valid Value; CanInterface panics on the zero Value.
// A typed nil pointer such as (*ast.Ident)(nil) is returned as a non-nil
// ast.Expr holding that nil pointer.
func rvToExpr(rv reflect.Value) ast.Expr {
	if !rv.CanInterface() {
		return nil
	}
	expr, _ := rv.Interface().(ast.Expr)
	return expr
}
376
377
378
// updateTypeInfo copies the type-checker facts that info records for the
// template expression src onto its clone dst.
//
// For identifiers, the Defs and Uses entries are copied. For selector
// expressions, the Selections entry is copied. For every expression, the
// Types entry is copied.
//
// An entry is written only when src has one, and the presence check is
// used, so an entry that maps to nil is still copied. Nil maps in info are
// therefore never written to. src must have the same dynamic type as dst
// when dst is an *ast.Ident or *ast.SelectorExpr; the type assertions panic
// otherwise.
func updateTypeInfo(info *types.Info, dst, src ast.Expr) {
	if id, isIdent := dst.(*ast.Ident); isIdent {
		srcID := src.(*ast.Ident)
		if obj, found := info.Defs[srcID]; found {
			info.Defs[id] = obj
		}
		if obj, found := info.Uses[srcID]; found {
			info.Uses[id] = obj
		}
	} else if sel, isSel := dst.(*ast.SelectorExpr); isSel {
		if selection, found := info.Selections[src.(*ast.SelectorExpr)]; found {
			info.Selections[sel] = selection
		}
	}

	if tv, found := info.Types[src]; found {
		info.Types[dst] = tv
	}
}
401
View as plain text