...

Source file src/golang.org/x/tools/go/analysis/passes/sigchanyzer/sigchanyzer.go

Documentation: golang.org/x/tools/go/analysis/passes/sigchanyzer

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
// Package sigchanyzer defines an Analyzer that detects
// misuse of an unbuffered os.Signal channel as an argument to signal.Notify.
     7  package sigchanyzer
     8  
     9  import (
    10  	"bytes"
    11  	"go/ast"
    12  	"go/format"
    13  	"go/token"
    14  	"go/types"
    15  
    16  	"golang.org/x/tools/go/analysis"
    17  	"golang.org/x/tools/go/analysis/passes/inspect"
    18  	"golang.org/x/tools/go/ast/inspector"
    19  )
    20  
// Doc is the help text of the sigchanyzer Analyzer; it is assigned to the
// Analyzer's Doc field below.
const Doc = `check for unbuffered channel of os.Signal

This checker reports call expression of the form signal.Notify(c <-chan os.Signal, sig ...os.Signal),
where c is an unbuffered channel, which can be at risk of missing the signal.`
    25  
    26  // Analyzer describes sigchanyzer analysis function detector.
    27  var Analyzer = &analysis.Analyzer{
    28  	Name:     "sigchanyzer",
    29  	Doc:      Doc,
    30  	Requires: []*analysis.Analyzer{inspect.Analyzer},
    31  	Run:      run,
    32  }
    33  
    34  func run(pass *analysis.Pass) (interface{}, error) {
    35  	inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
    36  
    37  	nodeFilter := []ast.Node{
    38  		(*ast.CallExpr)(nil),
    39  	}
    40  	inspect.Preorder(nodeFilter, func(n ast.Node) {
    41  		call := n.(*ast.CallExpr)
    42  		if !isSignalNotify(pass.TypesInfo, call) {
    43  			return
    44  		}
    45  		var chanDecl *ast.CallExpr
    46  		switch arg := call.Args[0].(type) {
    47  		case *ast.Ident:
    48  			if decl, ok := findDecl(arg).(*ast.CallExpr); ok {
    49  				chanDecl = decl
    50  			}
    51  		case *ast.CallExpr:
    52  			// Only signal.Notify(make(chan os.Signal), os.Interrupt) is safe,
    53  			// conservatively treat others as not safe, see golang/go#45043
    54  			if isBuiltinMake(pass.TypesInfo, arg) {
    55  				return
    56  			}
    57  			chanDecl = arg
    58  		}
    59  		if chanDecl == nil || len(chanDecl.Args) != 1 {
    60  			return
    61  		}
    62  
    63  		// Make a copy of the channel's declaration to avoid
    64  		// mutating the AST. See https://golang.org/issue/46129.
    65  		chanDeclCopy := &ast.CallExpr{}
    66  		*chanDeclCopy = *chanDecl
    67  		chanDeclCopy.Args = append([]ast.Expr(nil), chanDecl.Args...)
    68  		chanDeclCopy.Args = append(chanDeclCopy.Args, &ast.BasicLit{
    69  			Kind:  token.INT,
    70  			Value: "1",
    71  		})
    72  
    73  		var buf bytes.Buffer
    74  		if err := format.Node(&buf, token.NewFileSet(), chanDeclCopy); err != nil {
    75  			return
    76  		}
    77  		pass.Report(analysis.Diagnostic{
    78  			Pos:     call.Pos(),
    79  			End:     call.End(),
    80  			Message: "misuse of unbuffered os.Signal channel as argument to signal.Notify",
    81  			SuggestedFixes: []analysis.SuggestedFix{{
    82  				Message: "Change to buffer channel",
    83  				TextEdits: []analysis.TextEdit{{
    84  					Pos:     chanDecl.Pos(),
    85  					End:     chanDecl.End(),
    86  					NewText: buf.Bytes(),
    87  				}},
    88  			}},
    89  		})
    90  	})
    91  	return nil, nil
    92  }
    93  
    94  func isSignalNotify(info *types.Info, call *ast.CallExpr) bool {
    95  	check := func(id *ast.Ident) bool {
    96  		obj := info.ObjectOf(id)
    97  		return obj.Name() == "Notify" && obj.Pkg().Path() == "os/signal"
    98  	}
    99  	switch fun := call.Fun.(type) {
   100  	case *ast.SelectorExpr:
   101  		return check(fun.Sel)
   102  	case *ast.Ident:
   103  		if fun, ok := findDecl(fun).(*ast.SelectorExpr); ok {
   104  			return check(fun.Sel)
   105  		}
   106  		return false
   107  	default:
   108  		return false
   109  	}
   110  }
   111  
   112  func findDecl(arg *ast.Ident) ast.Node {
   113  	if arg.Obj == nil {
   114  		return nil
   115  	}
   116  	switch as := arg.Obj.Decl.(type) {
   117  	case *ast.AssignStmt:
   118  		if len(as.Lhs) != len(as.Rhs) {
   119  			return nil
   120  		}
   121  		for i, lhs := range as.Lhs {
   122  			lid, ok := lhs.(*ast.Ident)
   123  			if !ok {
   124  				continue
   125  			}
   126  			if lid.Obj == arg.Obj {
   127  				return as.Rhs[i]
   128  			}
   129  		}
   130  	case *ast.ValueSpec:
   131  		if len(as.Names) != len(as.Values) {
   132  			return nil
   133  		}
   134  		for i, name := range as.Names {
   135  			if name.Obj == arg.Obj {
   136  				return as.Values[i]
   137  			}
   138  		}
   139  	}
   140  	return nil
   141  }
   142  
   143  func isBuiltinMake(info *types.Info, call *ast.CallExpr) bool {
   144  	typVal := info.Types[call.Fun]
   145  	if !typVal.IsBuiltin() {
   146  		return false
   147  	}
   148  	switch fun := call.Fun.(type) {
   149  	case *ast.Ident:
   150  		return info.ObjectOf(fun).Name() == "make"
   151  	default:
   152  		return false
   153  	}
   154  }
   155  

View as plain text