1
2
3
4 package ssa
5
6 import (
7 "fmt"
8 "go/types"
9
10 "golang.org/x/tools/internal/typeparams"
11 )
12
13
14
15
16
17
18
19
20 type subster struct {
21
22
23 replacements map[*typeparams.TypeParam]types.Type
24 cache map[types.Type]types.Type
25 ctxt *typeparams.Context
26 debug bool
27
28 }
29
30
31
32 func makeSubster(ctxt *typeparams.Context, tparams *typeparams.TypeParamList, targs []types.Type, debug bool) *subster {
33 assert(tparams.Len() == len(targs), "makeSubster argument count must match")
34
35 subst := &subster{
36 replacements: make(map[*typeparams.TypeParam]types.Type, tparams.Len()),
37 cache: make(map[types.Type]types.Type),
38 ctxt: ctxt,
39 debug: debug,
40 }
41 for i := 0; i < tparams.Len(); i++ {
42 subst.replacements[tparams.At(i)] = targs[i]
43 }
44 if subst.debug {
45 if err := subst.wellFormed(); err != nil {
46 panic(err)
47 }
48 }
49 return subst
50 }
51
52
53 func (subst *subster) wellFormed() error {
54 if subst == nil || len(subst.replacements) == 0 {
55 return nil
56 }
57
58 s := make(map[types.Type]bool, len(subst.replacements))
59 for tparam := range subst.replacements {
60 s[tparam] = true
61 }
62 for _, r := range subst.replacements {
63 if reaches(r, s) {
64 return fmt.Errorf("\n‰r %s s %v replacements %v\n", r, s, subst.replacements)
65 }
66 }
67 return nil
68 }
69
70
71
72 func (subst *subster) typ(t types.Type) (res types.Type) {
73 if subst == nil {
74 return t
75 }
76 if r, ok := subst.cache[t]; ok {
77 return r
78 }
79 defer func() {
80 subst.cache[t] = res
81 }()
82
83
84 switch t := t.(type) {
85 case *typeparams.TypeParam:
86 r := subst.replacements[t]
87 assert(r != nil, "type param without replacement encountered")
88 return r
89
90 case *types.Basic:
91 return t
92
93 case *types.Array:
94 if r := subst.typ(t.Elem()); r != t.Elem() {
95 return types.NewArray(r, t.Len())
96 }
97 return t
98
99 case *types.Slice:
100 if r := subst.typ(t.Elem()); r != t.Elem() {
101 return types.NewSlice(r)
102 }
103 return t
104
105 case *types.Pointer:
106 if r := subst.typ(t.Elem()); r != t.Elem() {
107 return types.NewPointer(r)
108 }
109 return t
110
111 case *types.Tuple:
112 return subst.tuple(t)
113
114 case *types.Struct:
115 return subst.struct_(t)
116
117 case *types.Map:
118 key := subst.typ(t.Key())
119 elem := subst.typ(t.Elem())
120 if key != t.Key() || elem != t.Elem() {
121 return types.NewMap(key, elem)
122 }
123 return t
124
125 case *types.Chan:
126 if elem := subst.typ(t.Elem()); elem != t.Elem() {
127 return types.NewChan(t.Dir(), elem)
128 }
129 return t
130
131 case *types.Signature:
132 return subst.signature(t)
133
134 case *typeparams.Union:
135 return subst.union(t)
136
137 case *types.Interface:
138 return subst.interface_(t)
139
140 case *types.Named:
141 return subst.named(t)
142
143 default:
144 panic("unreachable")
145 }
146 }
147
148
149 func (subst *subster) types(ts []types.Type) []types.Type {
150 res := make([]types.Type, len(ts))
151 for i := range ts {
152 res[i] = subst.typ(ts[i])
153 }
154 return res
155 }
156
157 func (subst *subster) tuple(t *types.Tuple) *types.Tuple {
158 if t != nil {
159 if vars := subst.varlist(t); vars != nil {
160 return types.NewTuple(vars...)
161 }
162 }
163 return t
164 }
165
// varlist is an indexable, fixed-length sequence of variables.
type varlist interface {
	// Len returns the number of variables in the sequence.
	Len() int
	// At returns the i-th variable of the sequence.
	At(i int) *types.Var
}
170
171
// fieldlist exposes the fields of a struct type through the At/Len method
// pair.
type fieldlist struct {
	str *types.Struct
}

// At returns the i-th field of the wrapped struct.
func (fl fieldlist) At(i int) *types.Var {
	return fl.str.Field(i)
}

// Len returns the number of fields of the wrapped struct.
func (fl fieldlist) Len() int {
	return fl.str.NumFields()
}
178
179 func (subst *subster) struct_(t *types.Struct) *types.Struct {
180 if t != nil {
181 if fields := subst.varlist(fieldlist{t}); fields != nil {
182 tags := make([]string, t.NumFields())
183 for i, n := 0, t.NumFields(); i < n; i++ {
184 tags[i] = t.Tag(i)
185 }
186 return types.NewStruct(fields, tags)
187 }
188 }
189 return t
190 }
191
192
193 func (subst *subster) varlist(in varlist) []*types.Var {
194 var out []*types.Var
195 for i, n := 0, in.Len(); i < n; i++ {
196 v := in.At(i)
197 w := subst.var_(v)
198 if v != w && out == nil {
199 out = make([]*types.Var, n)
200 for j := 0; j < i; j++ {
201 out[j] = in.At(j)
202 }
203 }
204 if out != nil {
205 out[i] = w
206 }
207 }
208 return out
209 }
210
211 func (subst *subster) var_(v *types.Var) *types.Var {
212 if v != nil {
213 if typ := subst.typ(v.Type()); typ != v.Type() {
214 if v.IsField() {
215 return types.NewField(v.Pos(), v.Pkg(), v.Name(), typ, v.Embedded())
216 }
217 return types.NewVar(v.Pos(), v.Pkg(), v.Name(), typ)
218 }
219 }
220 return v
221 }
222
223 func (subst *subster) union(u *typeparams.Union) *typeparams.Union {
224 var out []*typeparams.Term
225
226 for i, n := 0, u.Len(); i < n; i++ {
227 t := u.Term(i)
228 r := subst.typ(t.Type())
229 if r != t.Type() && out == nil {
230 out = make([]*typeparams.Term, n)
231 for j := 0; j < i; j++ {
232 out[j] = u.Term(j)
233 }
234 }
235 if out != nil {
236 out[i] = typeparams.NewTerm(t.Tilde(), r)
237 }
238 }
239
240 if out != nil {
241 return typeparams.NewUnion(out)
242 }
243 return u
244 }
245
246 func (subst *subster) interface_(iface *types.Interface) *types.Interface {
247 if iface == nil {
248 return nil
249 }
250
251
252
253 var methods []*types.Func
254 initMethods := func(n int) {
255 methods = make([]*types.Func, iface.NumExplicitMethods())
256 for i := 0; i < n; i++ {
257 f := iface.ExplicitMethod(i)
258 norecv := changeRecv(f.Type().(*types.Signature), nil)
259 methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), norecv)
260 }
261 }
262 for i := 0; i < iface.NumExplicitMethods(); i++ {
263 f := iface.ExplicitMethod(i)
264
265
266
267 norecv := changeRecv(f.Type().(*types.Signature), nil)
268 sig := subst.typ(norecv)
269 if sig != norecv && methods == nil {
270 initMethods(i)
271 }
272 if methods != nil {
273 methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), sig.(*types.Signature))
274 }
275 }
276
277 var embeds []types.Type
278 initEmbeds := func(n int) {
279 embeds = make([]types.Type, iface.NumEmbeddeds())
280 for i := 0; i < n; i++ {
281 embeds[i] = iface.EmbeddedType(i)
282 }
283 }
284 for i := 0; i < iface.NumEmbeddeds(); i++ {
285 e := iface.EmbeddedType(i)
286 r := subst.typ(e)
287 if e != r && embeds == nil {
288 initEmbeds(i)
289 }
290 if embeds != nil {
291 embeds[i] = r
292 }
293 }
294
295 if methods == nil && embeds == nil {
296 return iface
297 }
298 if methods == nil {
299 initMethods(iface.NumExplicitMethods())
300 }
301 if embeds == nil {
302 initEmbeds(iface.NumEmbeddeds())
303 }
304 return types.NewInterfaceType(methods, embeds).Complete()
305 }
306
307 func (subst *subster) named(t *types.Named) types.Type {
308
309
310
311
312 tparams := typeparams.ForNamed(t)
313 if tparams.Len() == 0 {
314
315
316
317
318 return t
319 }
320 targs := typeparams.NamedTypeArgs(t)
321
322
323 insts := make([]types.Type, tparams.Len())
324
325
326
327
328 assert(targs.Len() != 0, "substition into a generic Named type is currently unsupported")
329
330
331
332
333
334
335
336
337
338 assert(targs.Len() == tparams.Len(), "typeargs.Len() must match typeparams.Len() if present")
339 for i, n := 0, targs.Len(); i < n; i++ {
340 inst := subst.typ(targs.At(i))
341 insts[i] = inst
342 }
343 r, err := typeparams.Instantiate(subst.ctxt, typeparams.NamedTypeOrigin(t), insts, false)
344 assert(err == nil, "failed to Instantiate Named type")
345 return r
346 }
347
348 func (subst *subster) signature(t *types.Signature) types.Type {
349 tparams := typeparams.ForSignature(t)
350
351
352
353
354
355
356
357 assert(tparams.Len() == 0, "Substituting types.Signatures with generic functions are currently unsupported.")
358
359
360
361
362
363
364
365
366
367
368
369
370 recv := subst.var_(t.Recv())
371 params := subst.tuple(t.Params())
372 results := subst.tuple(t.Results())
373 if recv != t.Recv() || params != t.Params() || results != t.Results() {
374 return typeparams.NewSignatureType(recv, nil, nil, params, results, t.Variadic())
375 }
376 return t
377 }
378
379
380
381 func reaches(t types.Type, c map[types.Type]bool) (res bool) {
382 if c, ok := c[t]; ok {
383 return c
384 }
385 c[t] = false
386 defer func() {
387 c[t] = res
388 }()
389
390 switch t := t.(type) {
391 case *typeparams.TypeParam, *types.Basic:
392
393 case *types.Array:
394 return reaches(t.Elem(), c)
395 case *types.Slice:
396 return reaches(t.Elem(), c)
397 case *types.Pointer:
398 return reaches(t.Elem(), c)
399 case *types.Tuple:
400 for i := 0; i < t.Len(); i++ {
401 if reaches(t.At(i).Type(), c) {
402 return true
403 }
404 }
405 case *types.Struct:
406 for i := 0; i < t.NumFields(); i++ {
407 if reaches(t.Field(i).Type(), c) {
408 return true
409 }
410 }
411 case *types.Map:
412 return reaches(t.Key(), c) || reaches(t.Elem(), c)
413 case *types.Chan:
414 return reaches(t.Elem(), c)
415 case *types.Signature:
416 if t.Recv() != nil && reaches(t.Recv().Type(), c) {
417 return true
418 }
419 return reaches(t.Params(), c) || reaches(t.Results(), c)
420 case *typeparams.Union:
421 for i := 0; i < t.Len(); i++ {
422 if reaches(t.Term(i).Type(), c) {
423 return true
424 }
425 }
426 case *types.Interface:
427 for i := 0; i < t.NumEmbeddeds(); i++ {
428 if reaches(t.Embedded(i), c) {
429 return true
430 }
431 }
432 for i := 0; i < t.NumExplicitMethods(); i++ {
433 if reaches(t.ExplicitMethod(i).Type(), c) {
434 return true
435 }
436 }
437 case *types.Named:
438 return reaches(t.Underlying(), c)
439 default:
440 panic("unreachable")
441 }
442 return false
443 }
444
View as plain text