1
2
3
4
5
6
7
8
9
10
11
12
13
14 package satisfy
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40 import (
41 "fmt"
42 "go/ast"
43 "go/token"
44 "go/types"
45
46 "golang.org/x/tools/go/ast/astutil"
47 "golang.org/x/tools/go/types/typeutil"
48 "golang.org/x/tools/internal/typeparams"
49 )
50
51
52
53
54
55
56
57 type Constraint struct {
58 LHS, RHS types.Type
59 }
60
61
62
63
64
65
66
67
68
69 type Finder struct {
70 Result map[Constraint]bool
71 msetcache typeutil.MethodSetCache
72
73
74 info *types.Info
75 sig *types.Signature
76 }
77
78
79
80
81
82
83
84
85
86
87 func (f *Finder) Find(info *types.Info, files []*ast.File) {
88 if f.Result == nil {
89 f.Result = make(map[Constraint]bool)
90 }
91
92 f.info = info
93 for _, file := range files {
94 for _, d := range file.Decls {
95 switch d := d.(type) {
96 case *ast.GenDecl:
97 if d.Tok == token.VAR {
98 for _, spec := range d.Specs {
99 f.valueSpec(spec.(*ast.ValueSpec))
100 }
101 }
102
103 case *ast.FuncDecl:
104 if d.Body != nil {
105 f.sig = f.info.Defs[d.Name].Type().(*types.Signature)
106 f.stmt(d.Body)
107 f.sig = nil
108 }
109 }
110 }
111 }
112 f.info = nil
113 }
114
115 var (
116 tInvalid = types.Typ[types.Invalid]
117 tUntypedBool = types.Typ[types.UntypedBool]
118 tUntypedNil = types.Typ[types.UntypedNil]
119 )
120
121
122 func (f *Finder) exprN(e ast.Expr) types.Type {
123 typ := f.info.Types[e].Type.(*types.Tuple)
124 switch e := e.(type) {
125 case *ast.ParenExpr:
126 return f.exprN(e.X)
127
128 case *ast.CallExpr:
129
130 sig := coreType(f.expr(e.Fun)).(*types.Signature)
131 f.call(sig, e.Args)
132
133 case *ast.IndexExpr:
134
135 x := f.expr(e.X)
136 f.assign(f.expr(e.Index), coreType(x).(*types.Map).Key())
137
138 case *ast.TypeAssertExpr:
139
140 f.typeAssert(f.expr(e.X), typ.At(0).Type())
141
142 case *ast.UnaryExpr:
143
144 f.expr(e.X)
145
146 default:
147 panic(e)
148 }
149 return typ
150 }
151
152 func (f *Finder) call(sig *types.Signature, args []ast.Expr) {
153 if len(args) == 0 {
154 return
155 }
156
157
158 if _, ok := args[len(args)-1].(*ast.Ellipsis); ok {
159 for i, arg := range args {
160
161 f.assign(sig.Params().At(i).Type(), f.expr(arg))
162 }
163 return
164 }
165
166 var argtypes []types.Type
167
168
169 if tuple, ok := f.info.Types[args[0]].Type.(*types.Tuple); ok {
170
171 f.expr(args[0])
172
173 for i := 0; i < tuple.Len(); i++ {
174 argtypes = append(argtypes, tuple.At(i).Type())
175 }
176 } else {
177 for _, arg := range args {
178 argtypes = append(argtypes, f.expr(arg))
179 }
180 }
181
182
183 if !sig.Variadic() {
184 for i, argtype := range argtypes {
185 f.assign(sig.Params().At(i).Type(), argtype)
186 }
187 } else {
188
189 nnormals := sig.Params().Len() - 1
190 for i, argtype := range argtypes[:nnormals] {
191 f.assign(sig.Params().At(i).Type(), argtype)
192 }
193
194 tElem := sig.Params().At(nnormals).Type().(*types.Slice).Elem()
195 for i := nnormals; i < len(argtypes); i++ {
196 f.assign(tElem, argtypes[i])
197 }
198 }
199 }
200
201
202 func (f *Finder) builtin(obj *types.Builtin, sig *types.Signature, args []ast.Expr) {
203 switch obj.Name() {
204 case "make", "new":
205
206 for _, arg := range args[1:] {
207 f.expr(arg)
208 }
209
210 case "append":
211 s := f.expr(args[0])
212 if _, ok := args[len(args)-1].(*ast.Ellipsis); ok && len(args) == 2 {
213
214 f.expr(args[1])
215 } else {
216
217 tElem := coreType(s).(*types.Slice).Elem()
218 for _, arg := range args[1:] {
219 f.assign(tElem, f.expr(arg))
220 }
221 }
222
223 case "delete":
224 m := f.expr(args[0])
225 k := f.expr(args[1])
226 f.assign(coreType(m).(*types.Map).Key(), k)
227
228 default:
229
230 f.call(sig, args)
231 }
232 }
233
234 func (f *Finder) extract(tuple types.Type, i int) types.Type {
235 if tuple, ok := tuple.(*types.Tuple); ok && i < tuple.Len() {
236 return tuple.At(i).Type()
237 }
238 return tInvalid
239 }
240
241 func (f *Finder) valueSpec(spec *ast.ValueSpec) {
242 var T types.Type
243 if spec.Type != nil {
244 T = f.info.Types[spec.Type].Type
245 }
246 switch len(spec.Values) {
247 case len(spec.Names):
248 for _, value := range spec.Values {
249 v := f.expr(value)
250 if T != nil {
251 f.assign(T, v)
252 }
253 }
254
255 case 1:
256 tuple := f.exprN(spec.Values[0])
257 for i := range spec.Names {
258 if T != nil {
259 f.assign(T, f.extract(tuple, i))
260 }
261 }
262 }
263 }
264
265
266
267
268
269
270
271
272
273 func (f *Finder) assign(lhs, rhs types.Type) {
274 if types.Identical(lhs, rhs) {
275 return
276 }
277 if !isInterface(lhs) {
278 return
279 }
280
281 if f.msetcache.MethodSet(lhs).Len() == 0 {
282 return
283 }
284 if f.msetcache.MethodSet(rhs).Len() == 0 {
285 return
286 }
287
288 f.Result[Constraint{lhs, rhs}] = true
289 }
290
291
292
293 func (f *Finder) typeAssert(I, T types.Type) {
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308 if types.AssignableTo(T, I) {
309 f.assign(I, T)
310 }
311 }
312
313
314 func (f *Finder) compare(x, y types.Type) {
315 if types.AssignableTo(x, y) {
316 f.assign(y, x)
317 } else if types.AssignableTo(y, x) {
318 f.assign(x, y)
319 }
320 }
321
322
323
324 func (f *Finder) expr(e ast.Expr) types.Type {
325 tv := f.info.Types[e]
326 if tv.Value != nil {
327 return tv.Type
328 }
329
330
331
332 switch e := e.(type) {
333 case *ast.BadExpr, *ast.BasicLit:
334
335
336 case *ast.Ident:
337
338 if obj, ok := f.info.Uses[e]; ok {
339 return obj.Type()
340 }
341 if e.Name == "_" {
342 return tInvalid
343 }
344 panic("undefined ident: " + e.Name)
345
346 case *ast.Ellipsis:
347 if e.Elt != nil {
348 f.expr(e.Elt)
349 }
350
351 case *ast.FuncLit:
352 saved := f.sig
353 f.sig = tv.Type.(*types.Signature)
354 f.stmt(e.Body)
355 f.sig = saved
356
357 case *ast.CompositeLit:
358
359 switch T := deref(tv.Type).Underlying().(type) {
360 case *types.Struct:
361 for i, elem := range e.Elts {
362 if kv, ok := elem.(*ast.KeyValueExpr); ok {
363 f.assign(f.info.Uses[kv.Key.(*ast.Ident)].Type(), f.expr(kv.Value))
364 } else {
365 f.assign(T.Field(i).Type(), f.expr(elem))
366 }
367 }
368
369 case *types.Map:
370 for _, elem := range e.Elts {
371 elem := elem.(*ast.KeyValueExpr)
372 f.assign(T.Key(), f.expr(elem.Key))
373 f.assign(T.Elem(), f.expr(elem.Value))
374 }
375
376 case *types.Array, *types.Slice:
377 tElem := T.(interface {
378 Elem() types.Type
379 }).Elem()
380 for _, elem := range e.Elts {
381 if kv, ok := elem.(*ast.KeyValueExpr); ok {
382
383 f.assign(tElem, f.expr(kv.Value))
384 } else {
385 f.assign(tElem, f.expr(elem))
386 }
387 }
388
389 default:
390 panic("unexpected composite literal type: " + tv.Type.String())
391 }
392
393 case *ast.ParenExpr:
394 f.expr(e.X)
395
396 case *ast.SelectorExpr:
397 if _, ok := f.info.Selections[e]; ok {
398 f.expr(e.X)
399 } else {
400 return f.info.Uses[e.Sel].Type()
401 }
402
403 case *ast.IndexExpr:
404 if instance(f.info, e.X) {
405
406 } else {
407
408 x := f.expr(e.X)
409 i := f.expr(e.Index)
410 if ux, ok := coreType(x).(*types.Map); ok {
411 f.assign(ux.Key(), i)
412 }
413 }
414
415 case *typeparams.IndexListExpr:
416
417
418 case *ast.SliceExpr:
419 f.expr(e.X)
420 if e.Low != nil {
421 f.expr(e.Low)
422 }
423 if e.High != nil {
424 f.expr(e.High)
425 }
426 if e.Max != nil {
427 f.expr(e.Max)
428 }
429
430 case *ast.TypeAssertExpr:
431 x := f.expr(e.X)
432 f.typeAssert(x, f.info.Types[e.Type].Type)
433
434 case *ast.CallExpr:
435 if tvFun := f.info.Types[e.Fun]; tvFun.IsType() {
436
437 arg0 := f.expr(e.Args[0])
438 f.assign(tvFun.Type, arg0)
439 } else {
440
441
442
443
444
445 if s, ok := unparen(e.Fun).(*ast.SelectorExpr); ok {
446 if obj, ok := f.info.Uses[s.Sel].(*types.Builtin); ok && obj.Pkg().Path() == "unsafe" {
447 sig := f.info.Types[e.Fun].Type.(*types.Signature)
448 f.call(sig, e.Args)
449 return tv.Type
450 }
451 }
452
453
454 if id, ok := unparen(e.Fun).(*ast.Ident); ok {
455 if obj, ok := f.info.Uses[id].(*types.Builtin); ok {
456 sig := f.info.Types[id].Type.(*types.Signature)
457 f.builtin(obj, sig, e.Args)
458 return tv.Type
459 }
460 }
461
462
463 f.call(coreType(f.expr(e.Fun)).(*types.Signature), e.Args)
464 }
465
466 case *ast.StarExpr:
467 f.expr(e.X)
468
469 case *ast.UnaryExpr:
470 f.expr(e.X)
471
472 case *ast.BinaryExpr:
473 x := f.expr(e.X)
474 y := f.expr(e.Y)
475 if e.Op == token.EQL || e.Op == token.NEQ {
476 f.compare(x, y)
477 }
478
479 case *ast.KeyValueExpr:
480 f.expr(e.Key)
481 f.expr(e.Value)
482
483 case *ast.ArrayType,
484 *ast.StructType,
485 *ast.FuncType,
486 *ast.InterfaceType,
487 *ast.MapType,
488 *ast.ChanType:
489 panic(e)
490 }
491
492 if tv.Type == nil {
493 panic(fmt.Sprintf("no type for %T", e))
494 }
495
496 return tv.Type
497 }
498
499 func (f *Finder) stmt(s ast.Stmt) {
500 switch s := s.(type) {
501 case *ast.BadStmt,
502 *ast.EmptyStmt,
503 *ast.BranchStmt:
504
505
506 case *ast.DeclStmt:
507 d := s.Decl.(*ast.GenDecl)
508 if d.Tok == token.VAR {
509 for _, spec := range d.Specs {
510 f.valueSpec(spec.(*ast.ValueSpec))
511 }
512 }
513
514 case *ast.LabeledStmt:
515 f.stmt(s.Stmt)
516
517 case *ast.ExprStmt:
518 f.expr(s.X)
519
520 case *ast.SendStmt:
521 ch := f.expr(s.Chan)
522 val := f.expr(s.Value)
523 f.assign(coreType(ch).(*types.Chan).Elem(), val)
524
525 case *ast.IncDecStmt:
526 f.expr(s.X)
527
528 case *ast.AssignStmt:
529 switch s.Tok {
530 case token.ASSIGN, token.DEFINE:
531
532 var rhsTuple types.Type
533 if len(s.Lhs) != len(s.Rhs) {
534 rhsTuple = f.exprN(s.Rhs[0])
535 }
536 for i := range s.Lhs {
537 var lhs, rhs types.Type
538 if rhsTuple == nil {
539 rhs = f.expr(s.Rhs[i])
540 } else {
541 rhs = f.extract(rhsTuple, i)
542 }
543
544 if id, ok := s.Lhs[i].(*ast.Ident); ok {
545 if id.Name != "_" {
546 if obj, ok := f.info.Defs[id]; ok {
547 lhs = obj.Type()
548 }
549 }
550 }
551 if lhs == nil {
552 lhs = f.expr(s.Lhs[i])
553 }
554 f.assign(lhs, rhs)
555 }
556
557 default:
558
559 f.expr(s.Lhs[0])
560 f.expr(s.Rhs[0])
561 }
562
563 case *ast.GoStmt:
564 f.expr(s.Call)
565
566 case *ast.DeferStmt:
567 f.expr(s.Call)
568
569 case *ast.ReturnStmt:
570 formals := f.sig.Results()
571 switch len(s.Results) {
572 case formals.Len():
573 for i, result := range s.Results {
574 f.assign(formals.At(i).Type(), f.expr(result))
575 }
576
577 case 1:
578 tuple := f.exprN(s.Results[0])
579 for i := 0; i < formals.Len(); i++ {
580 f.assign(formals.At(i).Type(), f.extract(tuple, i))
581 }
582 }
583
584 case *ast.SelectStmt:
585 f.stmt(s.Body)
586
587 case *ast.BlockStmt:
588 for _, s := range s.List {
589 f.stmt(s)
590 }
591
592 case *ast.IfStmt:
593 if s.Init != nil {
594 f.stmt(s.Init)
595 }
596 f.expr(s.Cond)
597 f.stmt(s.Body)
598 if s.Else != nil {
599 f.stmt(s.Else)
600 }
601
602 case *ast.SwitchStmt:
603 if s.Init != nil {
604 f.stmt(s.Init)
605 }
606 var tag types.Type = tUntypedBool
607 if s.Tag != nil {
608 tag = f.expr(s.Tag)
609 }
610 for _, cc := range s.Body.List {
611 cc := cc.(*ast.CaseClause)
612 for _, cond := range cc.List {
613 f.compare(tag, f.info.Types[cond].Type)
614 }
615 for _, s := range cc.Body {
616 f.stmt(s)
617 }
618 }
619
620 case *ast.TypeSwitchStmt:
621 if s.Init != nil {
622 f.stmt(s.Init)
623 }
624 var I types.Type
625 switch ass := s.Assign.(type) {
626 case *ast.ExprStmt:
627 I = f.expr(unparen(ass.X).(*ast.TypeAssertExpr).X)
628 case *ast.AssignStmt:
629 I = f.expr(unparen(ass.Rhs[0]).(*ast.TypeAssertExpr).X)
630 }
631 for _, cc := range s.Body.List {
632 cc := cc.(*ast.CaseClause)
633 for _, cond := range cc.List {
634 tCase := f.info.Types[cond].Type
635 if tCase != tUntypedNil {
636 f.typeAssert(I, tCase)
637 }
638 }
639 for _, s := range cc.Body {
640 f.stmt(s)
641 }
642 }
643
644 case *ast.CommClause:
645 if s.Comm != nil {
646 f.stmt(s.Comm)
647 }
648 for _, s := range s.Body {
649 f.stmt(s)
650 }
651
652 case *ast.ForStmt:
653 if s.Init != nil {
654 f.stmt(s.Init)
655 }
656 if s.Cond != nil {
657 f.expr(s.Cond)
658 }
659 if s.Post != nil {
660 f.stmt(s.Post)
661 }
662 f.stmt(s.Body)
663
664 case *ast.RangeStmt:
665 x := f.expr(s.X)
666
667 if s.Tok == token.ASSIGN {
668 if s.Key != nil {
669 k := f.expr(s.Key)
670 var xelem types.Type
671
672
673 switch ux := coreType(x).(type) {
674 case *types.Chan:
675 xelem = ux.Elem()
676 case *types.Map:
677 xelem = ux.Key()
678 }
679 if xelem != nil {
680 f.assign(k, xelem)
681 }
682 }
683 if s.Value != nil {
684 val := f.expr(s.Value)
685 var xelem types.Type
686
687
688 switch ux := coreType(x).(type) {
689 case *types.Array:
690 xelem = ux.Elem()
691 case *types.Map:
692 xelem = ux.Elem()
693 case *types.Pointer:
694 xelem = coreType(deref(ux)).(*types.Array).Elem()
695 case *types.Slice:
696 xelem = ux.Elem()
697 }
698 if xelem != nil {
699 f.assign(val, xelem)
700 }
701 }
702 }
703 f.stmt(s.Body)
704
705 default:
706 panic(s)
707 }
708 }
709
710
711
712
713 func deref(typ types.Type) types.Type {
714 if p, ok := coreType(typ).(*types.Pointer); ok {
715 return p.Elem()
716 }
717 return typ
718 }
719
720 func unparen(e ast.Expr) ast.Expr { return astutil.Unparen(e) }
721
722 func isInterface(T types.Type) bool { return types.IsInterface(T) }
723
724 func coreType(T types.Type) types.Type { return typeparams.CoreType(T) }
725
726 func instance(info *types.Info, expr ast.Expr) bool {
727 var id *ast.Ident
728 switch x := expr.(type) {
729 case *ast.Ident:
730 id = x
731 case *ast.SelectorExpr:
732 id = x.Sel
733 default:
734 return false
735 }
736 _, ok := typeparams.GetInstances(info)[id]
737 return ok
738 }
739
View as plain text