...

Source file src/golang.org/x/tools/go/ast/astutil/rewrite.go

Documentation: golang.org/x/tools/go/ast/astutil

     1  // Copyright 2017 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package astutil
     6  
     7  import (
     8  	"fmt"
     9  	"go/ast"
    10  	"reflect"
    11  	"sort"
    12  
    13  	"golang.org/x/tools/internal/typeparams"
    14  )
    15  
    16  // An ApplyFunc is invoked by Apply for each node n, even if n is nil,
    17  // before and/or after the node's children, using a Cursor describing
    18  // the current node and providing operations on it.
    19  //
    20  // The return value of ApplyFunc controls the syntax tree traversal.
    21  // See Apply for details.
    22  type ApplyFunc func(*Cursor) bool
    23  
    24  // Apply traverses a syntax tree recursively, starting with root,
    25  // and calling pre and post for each node as described below.
    26  // Apply returns the syntax tree, possibly modified.
    27  //
    28  // If pre is not nil, it is called for each node before the node's
    29  // children are traversed (pre-order). If pre returns false, no
    30  // children are traversed, and post is not called for that node.
    31  //
    32  // If post is not nil, and a prior call of pre didn't return false,
    33  // post is called for each node after its children are traversed
    34  // (post-order). If post returns false, traversal is terminated and
    35  // Apply returns immediately.
    36  //
    37  // Only fields that refer to AST nodes are considered children;
    38  // i.e., token.Pos, Scopes, Objects, and fields of basic types
    39  // (strings, etc.) are ignored.
    40  //
    41  // Children are traversed in the order in which they appear in the
    42  // respective node's struct definition. A package's files are
    43  // traversed in the filenames' alphabetical order.
    44  func Apply(root ast.Node, pre, post ApplyFunc) (result ast.Node) {
    45  	parent := &struct{ ast.Node }{root}
    46  	defer func() {
    47  		if r := recover(); r != nil && r != abort {
    48  			panic(r)
    49  		}
    50  		result = parent.Node
    51  	}()
    52  	a := &application{pre: pre, post: post}
    53  	a.apply(parent, "Node", nil, root)
    54  	return
    55  }
    56  
    57  var abort = new(int) // singleton, to signal termination of Apply
    58  
    59  // A Cursor describes a node encountered during Apply.
    60  // Information about the node and its parent is available
    61  // from the Node, Parent, Name, and Index methods.
    62  //
    63  // If p is a variable of type and value of the current parent node
    64  // c.Parent(), and f is the field identifier with name c.Name(),
    65  // the following invariants hold:
    66  //
    67  //	p.f            == c.Node()  if c.Index() <  0
    68  //	p.f[c.Index()] == c.Node()  if c.Index() >= 0
    69  //
    70  // The methods Replace, Delete, InsertBefore, and InsertAfter
    71  // can be used to change the AST without disrupting Apply.
    72  type Cursor struct {
    73  	parent ast.Node
    74  	name   string
    75  	iter   *iterator // valid if non-nil
    76  	node   ast.Node
    77  }
    78  
    79  // Node returns the current Node.
    80  func (c *Cursor) Node() ast.Node { return c.node }
    81  
    82  // Parent returns the parent of the current Node.
    83  func (c *Cursor) Parent() ast.Node { return c.parent }
    84  
    85  // Name returns the name of the parent Node field that contains the current Node.
    86  // If the parent is a *ast.Package and the current Node is a *ast.File, Name returns
    87  // the filename for the current Node.
    88  func (c *Cursor) Name() string { return c.name }
    89  
    90  // Index reports the index >= 0 of the current Node in the slice of Nodes that
    91  // contains it, or a value < 0 if the current Node is not part of a slice.
    92  // The index of the current node changes if InsertBefore is called while
    93  // processing the current node.
    94  func (c *Cursor) Index() int {
    95  	if c.iter != nil {
    96  		return c.iter.index
    97  	}
    98  	return -1
    99  }
   100  
   101  // field returns the current node's parent field value.
   102  func (c *Cursor) field() reflect.Value {
   103  	return reflect.Indirect(reflect.ValueOf(c.parent)).FieldByName(c.name)
   104  }
   105  
   106  // Replace replaces the current Node with n.
   107  // The replacement node is not walked by Apply.
   108  func (c *Cursor) Replace(n ast.Node) {
   109  	if _, ok := c.node.(*ast.File); ok {
   110  		file, ok := n.(*ast.File)
   111  		if !ok {
   112  			panic("attempt to replace *ast.File with non-*ast.File")
   113  		}
   114  		c.parent.(*ast.Package).Files[c.name] = file
   115  		return
   116  	}
   117  
   118  	v := c.field()
   119  	if i := c.Index(); i >= 0 {
   120  		v = v.Index(i)
   121  	}
   122  	v.Set(reflect.ValueOf(n))
   123  }
   124  
   125  // Delete deletes the current Node from its containing slice.
   126  // If the current Node is not part of a slice, Delete panics.
   127  // As a special case, if the current node is a package file,
   128  // Delete removes it from the package's Files map.
   129  func (c *Cursor) Delete() {
   130  	if _, ok := c.node.(*ast.File); ok {
   131  		delete(c.parent.(*ast.Package).Files, c.name)
   132  		return
   133  	}
   134  
   135  	i := c.Index()
   136  	if i < 0 {
   137  		panic("Delete node not contained in slice")
   138  	}
   139  	v := c.field()
   140  	l := v.Len()
   141  	reflect.Copy(v.Slice(i, l), v.Slice(i+1, l))
   142  	v.Index(l - 1).Set(reflect.Zero(v.Type().Elem()))
   143  	v.SetLen(l - 1)
   144  	c.iter.step--
   145  }
   146  
   147  // InsertAfter inserts n after the current Node in its containing slice.
   148  // If the current Node is not part of a slice, InsertAfter panics.
   149  // Apply does not walk n.
   150  func (c *Cursor) InsertAfter(n ast.Node) {
   151  	i := c.Index()
   152  	if i < 0 {
   153  		panic("InsertAfter node not contained in slice")
   154  	}
   155  	v := c.field()
   156  	v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
   157  	l := v.Len()
   158  	reflect.Copy(v.Slice(i+2, l), v.Slice(i+1, l))
   159  	v.Index(i + 1).Set(reflect.ValueOf(n))
   160  	c.iter.step++
   161  }
   162  
   163  // InsertBefore inserts n before the current Node in its containing slice.
   164  // If the current Node is not part of a slice, InsertBefore panics.
   165  // Apply will not walk n.
   166  func (c *Cursor) InsertBefore(n ast.Node) {
   167  	i := c.Index()
   168  	if i < 0 {
   169  		panic("InsertBefore node not contained in slice")
   170  	}
   171  	v := c.field()
   172  	v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
   173  	l := v.Len()
   174  	reflect.Copy(v.Slice(i+1, l), v.Slice(i, l))
   175  	v.Index(i).Set(reflect.ValueOf(n))
   176  	c.iter.index++
   177  }
   178  
   179  // application carries all the shared data so we can pass it around cheaply.
   180  type application struct {
   181  	pre, post ApplyFunc
   182  	cursor    Cursor
   183  	iter      iterator
   184  }
   185  
   186  func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast.Node) {
   187  	// convert typed nil into untyped nil
   188  	if v := reflect.ValueOf(n); v.Kind() == reflect.Ptr && v.IsNil() {
   189  		n = nil
   190  	}
   191  
   192  	// avoid heap-allocating a new cursor for each apply call; reuse a.cursor instead
   193  	saved := a.cursor
   194  	a.cursor.parent = parent
   195  	a.cursor.name = name
   196  	a.cursor.iter = iter
   197  	a.cursor.node = n
   198  
   199  	if a.pre != nil && !a.pre(&a.cursor) {
   200  		a.cursor = saved
   201  		return
   202  	}
   203  
   204  	// walk children
   205  	// (the order of the cases matches the order of the corresponding node types in go/ast)
   206  	switch n := n.(type) {
   207  	case nil:
   208  		// nothing to do
   209  
   210  	// Comments and fields
   211  	case *ast.Comment:
   212  		// nothing to do
   213  
   214  	case *ast.CommentGroup:
   215  		if n != nil {
   216  			a.applyList(n, "List")
   217  		}
   218  
   219  	case *ast.Field:
   220  		a.apply(n, "Doc", nil, n.Doc)
   221  		a.applyList(n, "Names")
   222  		a.apply(n, "Type", nil, n.Type)
   223  		a.apply(n, "Tag", nil, n.Tag)
   224  		a.apply(n, "Comment", nil, n.Comment)
   225  
   226  	case *ast.FieldList:
   227  		a.applyList(n, "List")
   228  
   229  	// Expressions
   230  	case *ast.BadExpr, *ast.Ident, *ast.BasicLit:
   231  		// nothing to do
   232  
   233  	case *ast.Ellipsis:
   234  		a.apply(n, "Elt", nil, n.Elt)
   235  
   236  	case *ast.FuncLit:
   237  		a.apply(n, "Type", nil, n.Type)
   238  		a.apply(n, "Body", nil, n.Body)
   239  
   240  	case *ast.CompositeLit:
   241  		a.apply(n, "Type", nil, n.Type)
   242  		a.applyList(n, "Elts")
   243  
   244  	case *ast.ParenExpr:
   245  		a.apply(n, "X", nil, n.X)
   246  
   247  	case *ast.SelectorExpr:
   248  		a.apply(n, "X", nil, n.X)
   249  		a.apply(n, "Sel", nil, n.Sel)
   250  
   251  	case *ast.IndexExpr:
   252  		a.apply(n, "X", nil, n.X)
   253  		a.apply(n, "Index", nil, n.Index)
   254  
   255  	case *typeparams.IndexListExpr:
   256  		a.apply(n, "X", nil, n.X)
   257  		a.applyList(n, "Indices")
   258  
   259  	case *ast.SliceExpr:
   260  		a.apply(n, "X", nil, n.X)
   261  		a.apply(n, "Low", nil, n.Low)
   262  		a.apply(n, "High", nil, n.High)
   263  		a.apply(n, "Max", nil, n.Max)
   264  
   265  	case *ast.TypeAssertExpr:
   266  		a.apply(n, "X", nil, n.X)
   267  		a.apply(n, "Type", nil, n.Type)
   268  
   269  	case *ast.CallExpr:
   270  		a.apply(n, "Fun", nil, n.Fun)
   271  		a.applyList(n, "Args")
   272  
   273  	case *ast.StarExpr:
   274  		a.apply(n, "X", nil, n.X)
   275  
   276  	case *ast.UnaryExpr:
   277  		a.apply(n, "X", nil, n.X)
   278  
   279  	case *ast.BinaryExpr:
   280  		a.apply(n, "X", nil, n.X)
   281  		a.apply(n, "Y", nil, n.Y)
   282  
   283  	case *ast.KeyValueExpr:
   284  		a.apply(n, "Key", nil, n.Key)
   285  		a.apply(n, "Value", nil, n.Value)
   286  
   287  	// Types
   288  	case *ast.ArrayType:
   289  		a.apply(n, "Len", nil, n.Len)
   290  		a.apply(n, "Elt", nil, n.Elt)
   291  
   292  	case *ast.StructType:
   293  		a.apply(n, "Fields", nil, n.Fields)
   294  
   295  	case *ast.FuncType:
   296  		if tparams := typeparams.ForFuncType(n); tparams != nil {
   297  			a.apply(n, "TypeParams", nil, tparams)
   298  		}
   299  		a.apply(n, "Params", nil, n.Params)
   300  		a.apply(n, "Results", nil, n.Results)
   301  
   302  	case *ast.InterfaceType:
   303  		a.apply(n, "Methods", nil, n.Methods)
   304  
   305  	case *ast.MapType:
   306  		a.apply(n, "Key", nil, n.Key)
   307  		a.apply(n, "Value", nil, n.Value)
   308  
   309  	case *ast.ChanType:
   310  		a.apply(n, "Value", nil, n.Value)
   311  
   312  	// Statements
   313  	case *ast.BadStmt:
   314  		// nothing to do
   315  
   316  	case *ast.DeclStmt:
   317  		a.apply(n, "Decl", nil, n.Decl)
   318  
   319  	case *ast.EmptyStmt:
   320  		// nothing to do
   321  
   322  	case *ast.LabeledStmt:
   323  		a.apply(n, "Label", nil, n.Label)
   324  		a.apply(n, "Stmt", nil, n.Stmt)
   325  
   326  	case *ast.ExprStmt:
   327  		a.apply(n, "X", nil, n.X)
   328  
   329  	case *ast.SendStmt:
   330  		a.apply(n, "Chan", nil, n.Chan)
   331  		a.apply(n, "Value", nil, n.Value)
   332  
   333  	case *ast.IncDecStmt:
   334  		a.apply(n, "X", nil, n.X)
   335  
   336  	case *ast.AssignStmt:
   337  		a.applyList(n, "Lhs")
   338  		a.applyList(n, "Rhs")
   339  
   340  	case *ast.GoStmt:
   341  		a.apply(n, "Call", nil, n.Call)
   342  
   343  	case *ast.DeferStmt:
   344  		a.apply(n, "Call", nil, n.Call)
   345  
   346  	case *ast.ReturnStmt:
   347  		a.applyList(n, "Results")
   348  
   349  	case *ast.BranchStmt:
   350  		a.apply(n, "Label", nil, n.Label)
   351  
   352  	case *ast.BlockStmt:
   353  		a.applyList(n, "List")
   354  
   355  	case *ast.IfStmt:
   356  		a.apply(n, "Init", nil, n.Init)
   357  		a.apply(n, "Cond", nil, n.Cond)
   358  		a.apply(n, "Body", nil, n.Body)
   359  		a.apply(n, "Else", nil, n.Else)
   360  
   361  	case *ast.CaseClause:
   362  		a.applyList(n, "List")
   363  		a.applyList(n, "Body")
   364  
   365  	case *ast.SwitchStmt:
   366  		a.apply(n, "Init", nil, n.Init)
   367  		a.apply(n, "Tag", nil, n.Tag)
   368  		a.apply(n, "Body", nil, n.Body)
   369  
   370  	case *ast.TypeSwitchStmt:
   371  		a.apply(n, "Init", nil, n.Init)
   372  		a.apply(n, "Assign", nil, n.Assign)
   373  		a.apply(n, "Body", nil, n.Body)
   374  
   375  	case *ast.CommClause:
   376  		a.apply(n, "Comm", nil, n.Comm)
   377  		a.applyList(n, "Body")
   378  
   379  	case *ast.SelectStmt:
   380  		a.apply(n, "Body", nil, n.Body)
   381  
   382  	case *ast.ForStmt:
   383  		a.apply(n, "Init", nil, n.Init)
   384  		a.apply(n, "Cond", nil, n.Cond)
   385  		a.apply(n, "Post", nil, n.Post)
   386  		a.apply(n, "Body", nil, n.Body)
   387  
   388  	case *ast.RangeStmt:
   389  		a.apply(n, "Key", nil, n.Key)
   390  		a.apply(n, "Value", nil, n.Value)
   391  		a.apply(n, "X", nil, n.X)
   392  		a.apply(n, "Body", nil, n.Body)
   393  
   394  	// Declarations
   395  	case *ast.ImportSpec:
   396  		a.apply(n, "Doc", nil, n.Doc)
   397  		a.apply(n, "Name", nil, n.Name)
   398  		a.apply(n, "Path", nil, n.Path)
   399  		a.apply(n, "Comment", nil, n.Comment)
   400  
   401  	case *ast.ValueSpec:
   402  		a.apply(n, "Doc", nil, n.Doc)
   403  		a.applyList(n, "Names")
   404  		a.apply(n, "Type", nil, n.Type)
   405  		a.applyList(n, "Values")
   406  		a.apply(n, "Comment", nil, n.Comment)
   407  
   408  	case *ast.TypeSpec:
   409  		a.apply(n, "Doc", nil, n.Doc)
   410  		a.apply(n, "Name", nil, n.Name)
   411  		if tparams := typeparams.ForTypeSpec(n); tparams != nil {
   412  			a.apply(n, "TypeParams", nil, tparams)
   413  		}
   414  		a.apply(n, "Type", nil, n.Type)
   415  		a.apply(n, "Comment", nil, n.Comment)
   416  
   417  	case *ast.BadDecl:
   418  		// nothing to do
   419  
   420  	case *ast.GenDecl:
   421  		a.apply(n, "Doc", nil, n.Doc)
   422  		a.applyList(n, "Specs")
   423  
   424  	case *ast.FuncDecl:
   425  		a.apply(n, "Doc", nil, n.Doc)
   426  		a.apply(n, "Recv", nil, n.Recv)
   427  		a.apply(n, "Name", nil, n.Name)
   428  		a.apply(n, "Type", nil, n.Type)
   429  		a.apply(n, "Body", nil, n.Body)
   430  
   431  	// Files and packages
   432  	case *ast.File:
   433  		a.apply(n, "Doc", nil, n.Doc)
   434  		a.apply(n, "Name", nil, n.Name)
   435  		a.applyList(n, "Decls")
   436  		// Don't walk n.Comments; they have either been walked already if
   437  		// they are Doc comments, or they can be easily walked explicitly.
   438  
   439  	case *ast.Package:
   440  		// collect and sort names for reproducible behavior
   441  		var names []string
   442  		for name := range n.Files {
   443  			names = append(names, name)
   444  		}
   445  		sort.Strings(names)
   446  		for _, name := range names {
   447  			a.apply(n, name, nil, n.Files[name])
   448  		}
   449  
   450  	default:
   451  		panic(fmt.Sprintf("Apply: unexpected node type %T", n))
   452  	}
   453  
   454  	if a.post != nil && !a.post(&a.cursor) {
   455  		panic(abort)
   456  	}
   457  
   458  	a.cursor = saved
   459  }
   460  
   461  // An iterator controls iteration over a slice of nodes.
   462  type iterator struct {
   463  	index, step int
   464  }
   465  
   466  func (a *application) applyList(parent ast.Node, name string) {
   467  	// avoid heap-allocating a new iterator for each applyList call; reuse a.iter instead
   468  	saved := a.iter
   469  	a.iter.index = 0
   470  	for {
   471  		// must reload parent.name each time, since cursor modifications might change it
   472  		v := reflect.Indirect(reflect.ValueOf(parent)).FieldByName(name)
   473  		if a.iter.index >= v.Len() {
   474  			break
   475  		}
   476  
   477  		// element x may be nil in a bad AST - be cautious
   478  		var x ast.Node
   479  		if e := v.Index(a.iter.index); e.IsValid() {
   480  			x = e.Interface().(ast.Node)
   481  		}
   482  
   483  		a.iter.step = 1
   484  		a.apply(parent, name, &a.iter, x)
   485  		a.iter.index += a.iter.step
   486  	}
   487  	a.iter = saved
   488  }
   489  

View as plain text