...

Source file src/golang.org/x/tools/go/ssa/subst.go

Documentation: golang.org/x/tools/go/ssa

     1  // Copyright 2022 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  package ssa
     5  
     6  import (
     7  	"fmt"
     8  	"go/types"
     9  
    10  	"golang.org/x/tools/internal/typeparams"
    11  )
    12  
    13  // subster is a type substituter for a fixed set of replacement types.
    14  //
    15  // A nil *subster is a valid, empty substitution map. It always acts as
    16  // the identity function. This allows for treating parameterized and
    17  // non-parameterized functions identically while compiling to ssa.
    18  //
    19  // Not concurrency-safe.
    20  type subster struct {
    21  	// TODO(zpavlinovic): replacements can contain type params
    22  	// when generating instances inside of a generic function body.
    23  	replacements map[*typeparams.TypeParam]types.Type // type param -> replacement; values should contain no type params
    24  	cache        map[types.Type]types.Type            // memoized results of typ, keyed by the input type
    25  	ctxt         *typeparams.Context
    26  	debug        bool // perform extra debugging checks (makeSubster runs wellFormed when set)
    27  	// TODO(taking): consider adding Pos
    28  }
    29  
    30  // Returns a subster that replaces tparams[i] with targs[i]. Uses ctxt as a cache.
    31  // targs should not contain any types in tparams.
    32  func makeSubster(ctxt *typeparams.Context, tparams *typeparams.TypeParamList, targs []types.Type, debug bool) *subster {
    33  	assert(tparams.Len() == len(targs), "makeSubster argument count must match")
    34  
    35  	subst := &subster{
    36  		replacements: make(map[*typeparams.TypeParam]types.Type, tparams.Len()),
    37  		cache:        make(map[types.Type]types.Type),
    38  		ctxt:         ctxt,
    39  		debug:        debug,
    40  	}
    41  	for i := 0; i < tparams.Len(); i++ {
    42  		subst.replacements[tparams.At(i)] = targs[i]
    43  	}
    44  	if subst.debug {
    45  		if err := subst.wellFormed(); err != nil {
    46  			panic(err)
    47  		}
    48  	}
    49  	return subst
    50  }
    51  
    52  // wellFormed returns an error if subst was not properly initialized.
    53  func (subst *subster) wellFormed() error {
    54  	if subst == nil || len(subst.replacements) == 0 {
    55  		return nil
    56  	}
    57  	// Check that all of the type params do not appear in the arguments.
    58  	s := make(map[types.Type]bool, len(subst.replacements))
    59  	for tparam := range subst.replacements {
    60  		s[tparam] = true
    61  	}
    62  	for _, r := range subst.replacements {
    63  		if reaches(r, s) {
    64  			return fmt.Errorf("\n‰r %s s %v replacements %v\n", r, s, subst.replacements)
    65  		}
    66  	}
    67  	return nil
    68  }
    69  
    70  // typ returns the type of t with the type parameter tparams[i] substituted
    71  // for the type targs[i] where subst was created using tparams and targs.
    72  func (subst *subster) typ(t types.Type) (res types.Type) {
    73  	if subst == nil {
    74  		return t // A nil subst is type preserving.
    75  	}
    76  	if r, ok := subst.cache[t]; ok {
    77  		return r
    78  	}
    79  	defer func() {
    80  		subst.cache[t] = res
    81  	}()
    82  
    83  	// fall through if result r will be identical to t, types.Identical(r, t).
    84  	switch t := t.(type) {
    85  	case *typeparams.TypeParam:
    86  		r := subst.replacements[t]
    87  		assert(r != nil, "type param without replacement encountered")
    88  		return r
    89  
    90  	case *types.Basic:
    91  		return t
    92  
    93  	case *types.Array:
    94  		if r := subst.typ(t.Elem()); r != t.Elem() {
    95  			return types.NewArray(r, t.Len())
    96  		}
    97  		return t
    98  
    99  	case *types.Slice:
   100  		if r := subst.typ(t.Elem()); r != t.Elem() {
   101  			return types.NewSlice(r)
   102  		}
   103  		return t
   104  
   105  	case *types.Pointer:
   106  		if r := subst.typ(t.Elem()); r != t.Elem() {
   107  			return types.NewPointer(r)
   108  		}
   109  		return t
   110  
   111  	case *types.Tuple:
   112  		return subst.tuple(t)
   113  
   114  	case *types.Struct:
   115  		return subst.struct_(t)
   116  
   117  	case *types.Map:
   118  		key := subst.typ(t.Key())
   119  		elem := subst.typ(t.Elem())
   120  		if key != t.Key() || elem != t.Elem() {
   121  			return types.NewMap(key, elem)
   122  		}
   123  		return t
   124  
   125  	case *types.Chan:
   126  		if elem := subst.typ(t.Elem()); elem != t.Elem() {
   127  			return types.NewChan(t.Dir(), elem)
   128  		}
   129  		return t
   130  
   131  	case *types.Signature:
   132  		return subst.signature(t)
   133  
   134  	case *typeparams.Union:
   135  		return subst.union(t)
   136  
   137  	case *types.Interface:
   138  		return subst.interface_(t)
   139  
   140  	case *types.Named:
   141  		return subst.named(t)
   142  
   143  	default:
   144  		panic("unreachable")
   145  	}
   146  }
   147  
   148  // types returns the result of {subst.typ(ts[i])}.
   149  func (subst *subster) types(ts []types.Type) []types.Type {
   150  	res := make([]types.Type, len(ts))
   151  	for i := range ts {
   152  		res[i] = subst.typ(ts[i])
   153  	}
   154  	return res
   155  }
   156  
   157  func (subst *subster) tuple(t *types.Tuple) *types.Tuple {
   158  	if t != nil {
   159  		if vars := subst.varlist(t); vars != nil {
   160  			return types.NewTuple(vars...)
   161  		}
   162  	}
   163  	return t
   164  }
   165  
   166  type varlist interface {
   167  	At(i int) *types.Var
   168  	Len() int
   169  }
   170  
   171  // fieldlist is an adapter for structs for the varlist interface.
   172  type fieldlist struct {
   173  	str *types.Struct
   174  }
   175  
   176  func (fl fieldlist) At(i int) *types.Var { return fl.str.Field(i) }
   177  func (fl fieldlist) Len() int            { return fl.str.NumFields() }
   178  
   179  func (subst *subster) struct_(t *types.Struct) *types.Struct {
   180  	if t != nil {
   181  		if fields := subst.varlist(fieldlist{t}); fields != nil {
   182  			tags := make([]string, t.NumFields())
   183  			for i, n := 0, t.NumFields(); i < n; i++ {
   184  				tags[i] = t.Tag(i)
   185  			}
   186  			return types.NewStruct(fields, tags)
   187  		}
   188  	}
   189  	return t
   190  }
   191  
   192  // varlist reutrns subst(in[i]) or return nils if subst(v[i]) == v[i] for all i.
   193  func (subst *subster) varlist(in varlist) []*types.Var {
   194  	var out []*types.Var // nil => no updates
   195  	for i, n := 0, in.Len(); i < n; i++ {
   196  		v := in.At(i)
   197  		w := subst.var_(v)
   198  		if v != w && out == nil {
   199  			out = make([]*types.Var, n)
   200  			for j := 0; j < i; j++ {
   201  				out[j] = in.At(j)
   202  			}
   203  		}
   204  		if out != nil {
   205  			out[i] = w
   206  		}
   207  	}
   208  	return out
   209  }
   210  
   211  func (subst *subster) var_(v *types.Var) *types.Var {
   212  	if v != nil {
   213  		if typ := subst.typ(v.Type()); typ != v.Type() {
   214  			if v.IsField() {
   215  				return types.NewField(v.Pos(), v.Pkg(), v.Name(), typ, v.Embedded())
   216  			}
   217  			return types.NewVar(v.Pos(), v.Pkg(), v.Name(), typ)
   218  		}
   219  	}
   220  	return v
   221  }
   222  
   223  func (subst *subster) union(u *typeparams.Union) *typeparams.Union {
   224  	var out []*typeparams.Term // nil => no updates
   225  
   226  	for i, n := 0, u.Len(); i < n; i++ {
   227  		t := u.Term(i)
   228  		r := subst.typ(t.Type())
   229  		if r != t.Type() && out == nil {
   230  			out = make([]*typeparams.Term, n)
   231  			for j := 0; j < i; j++ {
   232  				out[j] = u.Term(j)
   233  			}
   234  		}
   235  		if out != nil {
   236  			out[i] = typeparams.NewTerm(t.Tilde(), r)
   237  		}
   238  	}
   239  
   240  	if out != nil {
   241  		return typeparams.NewUnion(out)
   242  	}
   243  	return u
   244  }
   245  
   246  func (subst *subster) interface_(iface *types.Interface) *types.Interface {
   247  	if iface == nil {
   248  		return nil
   249  	}
   250  
   251  	// methods for the interface. Initially nil if there is no known change needed.
   252  	// Signatures for the method where recv is nil. NewInterfaceType fills in the recievers.
   253  	var methods []*types.Func
   254  	initMethods := func(n int) { // copy first n explicit methods
   255  		methods = make([]*types.Func, iface.NumExplicitMethods())
   256  		for i := 0; i < n; i++ {
   257  			f := iface.ExplicitMethod(i)
   258  			norecv := changeRecv(f.Type().(*types.Signature), nil)
   259  			methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), norecv)
   260  		}
   261  	}
   262  	for i := 0; i < iface.NumExplicitMethods(); i++ {
   263  		f := iface.ExplicitMethod(i)
   264  		// On interfaces, we need to cycle break on anonymous interface types
   265  		// being in a cycle with their signatures being in cycles with their recievers
   266  		// that do not go through a Named.
   267  		norecv := changeRecv(f.Type().(*types.Signature), nil)
   268  		sig := subst.typ(norecv)
   269  		if sig != norecv && methods == nil {
   270  			initMethods(i)
   271  		}
   272  		if methods != nil {
   273  			methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), sig.(*types.Signature))
   274  		}
   275  	}
   276  
   277  	var embeds []types.Type
   278  	initEmbeds := func(n int) { // copy first n embedded types
   279  		embeds = make([]types.Type, iface.NumEmbeddeds())
   280  		for i := 0; i < n; i++ {
   281  			embeds[i] = iface.EmbeddedType(i)
   282  		}
   283  	}
   284  	for i := 0; i < iface.NumEmbeddeds(); i++ {
   285  		e := iface.EmbeddedType(i)
   286  		r := subst.typ(e)
   287  		if e != r && embeds == nil {
   288  			initEmbeds(i)
   289  		}
   290  		if embeds != nil {
   291  			embeds[i] = r
   292  		}
   293  	}
   294  
   295  	if methods == nil && embeds == nil {
   296  		return iface
   297  	}
   298  	if methods == nil {
   299  		initMethods(iface.NumExplicitMethods())
   300  	}
   301  	if embeds == nil {
   302  		initEmbeds(iface.NumEmbeddeds())
   303  	}
   304  	return types.NewInterfaceType(methods, embeds).Complete()
   305  }
   306  
   307  func (subst *subster) named(t *types.Named) types.Type {
   308  	// A name type may be:
   309  	// (1) ordinary (no type parameters, no type arguments),
   310  	// (2) generic (type parameters but no type arguments), or
   311  	// (3) instantiated (type parameters and type arguments).
   312  	tparams := typeparams.ForNamed(t)
   313  	if tparams.Len() == 0 {
   314  		// case (1) ordinary
   315  
   316  		// Note: If Go allows for local type declarations in generic
   317  		// functions we may need to descend into underlying as well.
   318  		return t
   319  	}
   320  	targs := typeparams.NamedTypeArgs(t)
   321  
   322  	// insts are arguments to instantiate using.
   323  	insts := make([]types.Type, tparams.Len())
   324  
   325  	// case (2) generic ==> targs.Len() == 0
   326  	// Instantiating a generic with no type arguments should be unreachable.
   327  	// Please report a bug if you encounter this.
   328  	assert(targs.Len() != 0, "substition into a generic Named type is currently unsupported")
   329  
   330  	// case (3) instantiated.
   331  	// Substitute into the type arguments and instantiate the replacements/
   332  	// Example:
   333  	//    type N[A any] func() A
   334  	//    func Foo[T](g N[T]) {}
   335  	//  To instantiate Foo[string], one goes through {T->string}. To get the type of g
   336  	//  one subsitutes T with string in {N with typeargs == {T} and typeparams == {A} }
   337  	//  to get {N with TypeArgs == {string} and typeparams == {A} }.
   338  	assert(targs.Len() == tparams.Len(), "typeargs.Len() must match typeparams.Len() if present")
   339  	for i, n := 0, targs.Len(); i < n; i++ {
   340  		inst := subst.typ(targs.At(i)) // TODO(generic): Check with rfindley for mutual recursion
   341  		insts[i] = inst
   342  	}
   343  	r, err := typeparams.Instantiate(subst.ctxt, typeparams.NamedTypeOrigin(t), insts, false)
   344  	assert(err == nil, "failed to Instantiate Named type")
   345  	return r
   346  }
   347  
   348  func (subst *subster) signature(t *types.Signature) types.Type {
   349  	tparams := typeparams.ForSignature(t)
   350  
   351  	// We are choosing not to support tparams.Len() > 0 until a need has been observed in practice.
   352  	//
   353  	// There are some known usages for types.Types coming from types.{Eval,CheckExpr}.
   354  	// To support tparams.Len() > 0, we just need to do the following [psuedocode]:
   355  	//   targs := {subst.replacements[tparams[i]]]}; Instantiate(ctxt, t, targs, false)
   356  
   357  	assert(tparams.Len() == 0, "Substituting types.Signatures with generic functions are currently unsupported.")
   358  
   359  	// Either:
   360  	// (1)non-generic function.
   361  	//    no type params to substitute
   362  	// (2)generic method and recv needs to be substituted.
   363  
   364  	// Recievers can be either:
   365  	// named
   366  	// pointer to named
   367  	// interface
   368  	// nil
   369  	// interface is the problematic case. We need to cycle break there!
   370  	recv := subst.var_(t.Recv())
   371  	params := subst.tuple(t.Params())
   372  	results := subst.tuple(t.Results())
   373  	if recv != t.Recv() || params != t.Params() || results != t.Results() {
   374  		return typeparams.NewSignatureType(recv, nil, nil, params, results, t.Variadic())
   375  	}
   376  	return t
   377  }
   378  
   379  // reaches returns true if a type t reaches any type t' s.t. c[t'] == true.
   380  // Updates c to cache results.
   381  func reaches(t types.Type, c map[types.Type]bool) (res bool) {
   382  	if c, ok := c[t]; ok {
   383  		return c
   384  	}
   385  	c[t] = false // prevent cycles
   386  	defer func() {
   387  		c[t] = res
   388  	}()
   389  
   390  	switch t := t.(type) {
   391  	case *typeparams.TypeParam, *types.Basic:
   392  		// no-op => c == false
   393  	case *types.Array:
   394  		return reaches(t.Elem(), c)
   395  	case *types.Slice:
   396  		return reaches(t.Elem(), c)
   397  	case *types.Pointer:
   398  		return reaches(t.Elem(), c)
   399  	case *types.Tuple:
   400  		for i := 0; i < t.Len(); i++ {
   401  			if reaches(t.At(i).Type(), c) {
   402  				return true
   403  			}
   404  		}
   405  	case *types.Struct:
   406  		for i := 0; i < t.NumFields(); i++ {
   407  			if reaches(t.Field(i).Type(), c) {
   408  				return true
   409  			}
   410  		}
   411  	case *types.Map:
   412  		return reaches(t.Key(), c) || reaches(t.Elem(), c)
   413  	case *types.Chan:
   414  		return reaches(t.Elem(), c)
   415  	case *types.Signature:
   416  		if t.Recv() != nil && reaches(t.Recv().Type(), c) {
   417  			return true
   418  		}
   419  		return reaches(t.Params(), c) || reaches(t.Results(), c)
   420  	case *typeparams.Union:
   421  		for i := 0; i < t.Len(); i++ {
   422  			if reaches(t.Term(i).Type(), c) {
   423  				return true
   424  			}
   425  		}
   426  	case *types.Interface:
   427  		for i := 0; i < t.NumEmbeddeds(); i++ {
   428  			if reaches(t.Embedded(i), c) {
   429  				return true
   430  			}
   431  		}
   432  		for i := 0; i < t.NumExplicitMethods(); i++ {
   433  			if reaches(t.ExplicitMethod(i).Type(), c) {
   434  				return true
   435  			}
   436  		}
   437  	case *types.Named:
   438  		return reaches(t.Underlying(), c)
   439  	default:
   440  		panic("unreachable")
   441  	}
   442  	return false
   443  }
   444  

View as plain text