...

Source file src/text/template/funcs.go

Documentation: text/template

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package template
     6  
     7  import (
     8  	"bytes"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"net/url"
    13  	"reflect"
    14  	"strings"
    15  	"sync"
    16  	"unicode"
    17  	"unicode/utf8"
    18  )
    19  
    20  // FuncMap is the type of the map defining the mapping from names to functions.
    21  // Each function must have either a single return value, or two return values of
    22  // which the second has type error. In that case, if the second (error)
    23  // return value evaluates to non-nil during execution, execution terminates and
    24  // Execute returns that error.
    25  //
    26  // Errors returned by Execute wrap the underlying error; call errors.As to
    27  // uncover them.
    28  //
    29  // When template execution invokes a function with an argument list, that list
    30  // must be assignable to the function's parameter types. Functions meant to
    31  // apply to arguments of arbitrary type can use parameters of type interface{} or
    32  // of type reflect.Value. Similarly, functions meant to return a result of arbitrary
    33  // type can return interface{} or reflect.Value.
    34  type FuncMap map[string]any
    35  
    36  // builtins returns the FuncMap.
    37  // It is not a global variable so the linker can dead code eliminate
    38  // more when this isn't called. See golang.org/issue/36021.
    39  // TODO: revert this back to a global map once golang.org/issue/2559 is fixed.
    40  func builtins() FuncMap {
    41  	return FuncMap{
    42  		"and":      and,
    43  		"call":     call,
    44  		"html":     HTMLEscaper,
    45  		"index":    index,
    46  		"slice":    slice,
    47  		"js":       JSEscaper,
    48  		"len":      length,
    49  		"not":      not,
    50  		"or":       or,
    51  		"print":    fmt.Sprint,
    52  		"printf":   fmt.Sprintf,
    53  		"println":  fmt.Sprintln,
    54  		"urlquery": URLQueryEscaper,
    55  
    56  		// Comparisons
    57  		"eq": eq, // ==
    58  		"ge": ge, // >=
    59  		"gt": gt, // >
    60  		"le": le, // <=
    61  		"lt": lt, // <
    62  		"ne": ne, // !=
    63  	}
    64  }
    65  
    66  var builtinFuncsOnce struct {
    67  	sync.Once
    68  	v map[string]reflect.Value
    69  }
    70  
    71  // builtinFuncsOnce lazily computes & caches the builtinFuncs map.
    72  // TODO: revert this back to a global map once golang.org/issue/2559 is fixed.
    73  func builtinFuncs() map[string]reflect.Value {
    74  	builtinFuncsOnce.Do(func() {
    75  		builtinFuncsOnce.v = createValueFuncs(builtins())
    76  	})
    77  	return builtinFuncsOnce.v
    78  }
    79  
    80  // createValueFuncs turns a FuncMap into a map[string]reflect.Value
    81  func createValueFuncs(funcMap FuncMap) map[string]reflect.Value {
    82  	m := make(map[string]reflect.Value)
    83  	addValueFuncs(m, funcMap)
    84  	return m
    85  }
    86  
    87  // addValueFuncs adds to values the functions in funcs, converting them to reflect.Values.
    88  func addValueFuncs(out map[string]reflect.Value, in FuncMap) {
    89  	for name, fn := range in {
    90  		if !goodName(name) {
    91  			panic(fmt.Errorf("function name %q is not a valid identifier", name))
    92  		}
    93  		v := reflect.ValueOf(fn)
    94  		if v.Kind() != reflect.Func {
    95  			panic("value for " + name + " not a function")
    96  		}
    97  		if !goodFunc(v.Type()) {
    98  			panic(fmt.Errorf("can't install method/function %q with %d results", name, v.Type().NumOut()))
    99  		}
   100  		out[name] = v
   101  	}
   102  }
   103  
   104  // addFuncs adds to values the functions in funcs. It does no checking of the input -
   105  // call addValueFuncs first.
   106  func addFuncs(out, in FuncMap) {
   107  	for name, fn := range in {
   108  		out[name] = fn
   109  	}
   110  }
   111  
   112  // goodFunc reports whether the function or method has the right result signature.
   113  func goodFunc(typ reflect.Type) bool {
   114  	// We allow functions with 1 result or 2 results where the second is an error.
   115  	switch {
   116  	case typ.NumOut() == 1:
   117  		return true
   118  	case typ.NumOut() == 2 && typ.Out(1) == errorType:
   119  		return true
   120  	}
   121  	return false
   122  }
   123  
   124  // goodName reports whether the function name is a valid identifier.
   125  func goodName(name string) bool {
   126  	if name == "" {
   127  		return false
   128  	}
   129  	for i, r := range name {
   130  		switch {
   131  		case r == '_':
   132  		case i == 0 && !unicode.IsLetter(r):
   133  			return false
   134  		case !unicode.IsLetter(r) && !unicode.IsDigit(r):
   135  			return false
   136  		}
   137  	}
   138  	return true
   139  }
   140  
   141  // findFunction looks for a function in the template, and global map.
   142  func findFunction(name string, tmpl *Template) (v reflect.Value, isBuiltin, ok bool) {
   143  	if tmpl != nil && tmpl.common != nil {
   144  		tmpl.muFuncs.RLock()
   145  		defer tmpl.muFuncs.RUnlock()
   146  		if fn := tmpl.execFuncs[name]; fn.IsValid() {
   147  			return fn, false, true
   148  		}
   149  	}
   150  	if fn := builtinFuncs()[name]; fn.IsValid() {
   151  		return fn, true, true
   152  	}
   153  	return reflect.Value{}, false, false
   154  }
   155  
   156  // prepareArg checks if value can be used as an argument of type argType, and
   157  // converts an invalid value to appropriate zero if possible.
   158  func prepareArg(value reflect.Value, argType reflect.Type) (reflect.Value, error) {
   159  	if !value.IsValid() {
   160  		if !canBeNil(argType) {
   161  			return reflect.Value{}, fmt.Errorf("value is nil; should be of type %s", argType)
   162  		}
   163  		value = reflect.Zero(argType)
   164  	}
   165  	if value.Type().AssignableTo(argType) {
   166  		return value, nil
   167  	}
   168  	if intLike(value.Kind()) && intLike(argType.Kind()) && value.Type().ConvertibleTo(argType) {
   169  		value = value.Convert(argType)
   170  		return value, nil
   171  	}
   172  	return reflect.Value{}, fmt.Errorf("value has type %s; should be %s", value.Type(), argType)
   173  }
   174  
   175  func intLike(typ reflect.Kind) bool {
   176  	switch typ {
   177  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   178  		return true
   179  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   180  		return true
   181  	}
   182  	return false
   183  }
   184  
   185  // indexArg checks if a reflect.Value can be used as an index, and converts it to int if possible.
   186  func indexArg(index reflect.Value, cap int) (int, error) {
   187  	var x int64
   188  	switch index.Kind() {
   189  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   190  		x = index.Int()
   191  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   192  		x = int64(index.Uint())
   193  	case reflect.Invalid:
   194  		return 0, fmt.Errorf("cannot index slice/array with nil")
   195  	default:
   196  		return 0, fmt.Errorf("cannot index slice/array with type %s", index.Type())
   197  	}
   198  	if x < 0 || int(x) < 0 || int(x) > cap {
   199  		return 0, fmt.Errorf("index out of range: %d", x)
   200  	}
   201  	return int(x), nil
   202  }
   203  
   204  // Indexing.
   205  
   206  // index returns the result of indexing its first argument by the following
   207  // arguments. Thus "index x 1 2 3" is, in Go syntax, x[1][2][3]. Each
   208  // indexed item must be a map, slice, or array.
   209  func index(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
   210  	item = indirectInterface(item)
   211  	if !item.IsValid() {
   212  		return reflect.Value{}, fmt.Errorf("index of untyped nil")
   213  	}
   214  	for _, index := range indexes {
   215  		index = indirectInterface(index)
   216  		var isNil bool
   217  		if item, isNil = indirect(item); isNil {
   218  			return reflect.Value{}, fmt.Errorf("index of nil pointer")
   219  		}
   220  		switch item.Kind() {
   221  		case reflect.Array, reflect.Slice, reflect.String:
   222  			x, err := indexArg(index, item.Len())
   223  			if err != nil {
   224  				return reflect.Value{}, err
   225  			}
   226  			item = item.Index(x)
   227  		case reflect.Map:
   228  			index, err := prepareArg(index, item.Type().Key())
   229  			if err != nil {
   230  				return reflect.Value{}, err
   231  			}
   232  			if x := item.MapIndex(index); x.IsValid() {
   233  				item = x
   234  			} else {
   235  				item = reflect.Zero(item.Type().Elem())
   236  			}
   237  		case reflect.Invalid:
   238  			// the loop holds invariant: item.IsValid()
   239  			panic("unreachable")
   240  		default:
   241  			return reflect.Value{}, fmt.Errorf("can't index item of type %s", item.Type())
   242  		}
   243  	}
   244  	return item, nil
   245  }
   246  
   247  // Slicing.
   248  
   249  // slice returns the result of slicing its first argument by the remaining
   250  // arguments. Thus "slice x 1 2" is, in Go syntax, x[1:2], while "slice x"
   251  // is x[:], "slice x 1" is x[1:], and "slice x 1 2 3" is x[1:2:3]. The first
   252  // argument must be a string, slice, or array.
   253  func slice(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
   254  	item = indirectInterface(item)
   255  	if !item.IsValid() {
   256  		return reflect.Value{}, fmt.Errorf("slice of untyped nil")
   257  	}
   258  	if len(indexes) > 3 {
   259  		return reflect.Value{}, fmt.Errorf("too many slice indexes: %d", len(indexes))
   260  	}
   261  	var cap int
   262  	switch item.Kind() {
   263  	case reflect.String:
   264  		if len(indexes) == 3 {
   265  			return reflect.Value{}, fmt.Errorf("cannot 3-index slice a string")
   266  		}
   267  		cap = item.Len()
   268  	case reflect.Array, reflect.Slice:
   269  		cap = item.Cap()
   270  	default:
   271  		return reflect.Value{}, fmt.Errorf("can't slice item of type %s", item.Type())
   272  	}
   273  	// set default values for cases item[:], item[i:].
   274  	idx := [3]int{0, item.Len()}
   275  	for i, index := range indexes {
   276  		x, err := indexArg(index, cap)
   277  		if err != nil {
   278  			return reflect.Value{}, err
   279  		}
   280  		idx[i] = x
   281  	}
   282  	// given item[i:j], make sure i <= j.
   283  	if idx[0] > idx[1] {
   284  		return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[0], idx[1])
   285  	}
   286  	if len(indexes) < 3 {
   287  		return item.Slice(idx[0], idx[1]), nil
   288  	}
   289  	// given item[i:j:k], make sure i <= j <= k.
   290  	if idx[1] > idx[2] {
   291  		return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[1], idx[2])
   292  	}
   293  	return item.Slice3(idx[0], idx[1], idx[2]), nil
   294  }
   295  
   296  // Length
   297  
   298  // length returns the length of the item, with an error if it has no defined length.
   299  func length(item reflect.Value) (int, error) {
   300  	item, isNil := indirect(item)
   301  	if isNil {
   302  		return 0, fmt.Errorf("len of nil pointer")
   303  	}
   304  	switch item.Kind() {
   305  	case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String:
   306  		return item.Len(), nil
   307  	}
   308  	return 0, fmt.Errorf("len of type %s", item.Type())
   309  }
   310  
   311  // Function invocation
   312  
   313  // call returns the result of evaluating the first argument as a function.
   314  // The function must return 1 result, or 2 results, the second of which is an error.
   315  func call(fn reflect.Value, args ...reflect.Value) (reflect.Value, error) {
   316  	fn = indirectInterface(fn)
   317  	if !fn.IsValid() {
   318  		return reflect.Value{}, fmt.Errorf("call of nil")
   319  	}
   320  	typ := fn.Type()
   321  	if typ.Kind() != reflect.Func {
   322  		return reflect.Value{}, fmt.Errorf("non-function of type %s", typ)
   323  	}
   324  	if !goodFunc(typ) {
   325  		return reflect.Value{}, fmt.Errorf("function called with %d args; should be 1 or 2", typ.NumOut())
   326  	}
   327  	numIn := typ.NumIn()
   328  	var dddType reflect.Type
   329  	if typ.IsVariadic() {
   330  		if len(args) < numIn-1 {
   331  			return reflect.Value{}, fmt.Errorf("wrong number of args: got %d want at least %d", len(args), numIn-1)
   332  		}
   333  		dddType = typ.In(numIn - 1).Elem()
   334  	} else {
   335  		if len(args) != numIn {
   336  			return reflect.Value{}, fmt.Errorf("wrong number of args: got %d want %d", len(args), numIn)
   337  		}
   338  	}
   339  	argv := make([]reflect.Value, len(args))
   340  	for i, arg := range args {
   341  		arg = indirectInterface(arg)
   342  		// Compute the expected type. Clumsy because of variadics.
   343  		argType := dddType
   344  		if !typ.IsVariadic() || i < numIn-1 {
   345  			argType = typ.In(i)
   346  		}
   347  
   348  		var err error
   349  		if argv[i], err = prepareArg(arg, argType); err != nil {
   350  			return reflect.Value{}, fmt.Errorf("arg %d: %w", i, err)
   351  		}
   352  	}
   353  	return safeCall(fn, argv)
   354  }
   355  
   356  // safeCall runs fun.Call(args), and returns the resulting value and error, if
   357  // any. If the call panics, the panic value is returned as an error.
   358  func safeCall(fun reflect.Value, args []reflect.Value) (val reflect.Value, err error) {
   359  	defer func() {
   360  		if r := recover(); r != nil {
   361  			if e, ok := r.(error); ok {
   362  				err = e
   363  			} else {
   364  				err = fmt.Errorf("%v", r)
   365  			}
   366  		}
   367  	}()
   368  	ret := fun.Call(args)
   369  	if len(ret) == 2 && !ret[1].IsNil() {
   370  		return ret[0], ret[1].Interface().(error)
   371  	}
   372  	return ret[0], nil
   373  }
   374  
   375  // Boolean logic.
   376  
   377  func truth(arg reflect.Value) bool {
   378  	t, _ := isTrue(indirectInterface(arg))
   379  	return t
   380  }
   381  
   382  // and computes the Boolean AND of its arguments, returning
   383  // the first false argument it encounters, or the last argument.
   384  func and(arg0 reflect.Value, args ...reflect.Value) reflect.Value {
   385  	panic("unreachable") // implemented as a special case in evalCall
   386  }
   387  
   388  // or computes the Boolean OR of its arguments, returning
   389  // the first true argument it encounters, or the last argument.
   390  func or(arg0 reflect.Value, args ...reflect.Value) reflect.Value {
   391  	panic("unreachable") // implemented as a special case in evalCall
   392  }
   393  
   394  // not returns the Boolean negation of its argument.
   395  func not(arg reflect.Value) bool {
   396  	return !truth(arg)
   397  }
   398  
   399  // Comparison.
   400  
   401  // TODO: Perhaps allow comparison between signed and unsigned integers.
   402  
   403  var (
   404  	errBadComparisonType = errors.New("invalid type for comparison")
   405  	errBadComparison     = errors.New("incompatible types for comparison")
   406  	errNoComparison      = errors.New("missing argument for comparison")
   407  )
   408  
   409  type kind int
   410  
   411  const (
   412  	invalidKind kind = iota
   413  	boolKind
   414  	complexKind
   415  	intKind
   416  	floatKind
   417  	stringKind
   418  	uintKind
   419  )
   420  
   421  func basicKind(v reflect.Value) (kind, error) {
   422  	switch v.Kind() {
   423  	case reflect.Bool:
   424  		return boolKind, nil
   425  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   426  		return intKind, nil
   427  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   428  		return uintKind, nil
   429  	case reflect.Float32, reflect.Float64:
   430  		return floatKind, nil
   431  	case reflect.Complex64, reflect.Complex128:
   432  		return complexKind, nil
   433  	case reflect.String:
   434  		return stringKind, nil
   435  	}
   436  	return invalidKind, errBadComparisonType
   437  }
   438  
   439  // isNil returns true if v is the zero reflect.Value, or nil of its type.
   440  func isNil(v reflect.Value) bool {
   441  	if v == zero {
   442  		return true
   443  	}
   444  	switch v.Kind() {
   445  	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
   446  		return v.IsNil()
   447  	}
   448  	return false
   449  }
   450  
   451  // canCompare reports whether v1 and v2 are both the same kind, or one is nil.
   452  // Called only when dealing with nillable types, or there's about to be an error.
   453  func canCompare(v1, v2 reflect.Value) bool {
   454  	k1 := v1.Kind()
   455  	k2 := v2.Kind()
   456  	if k1 == k2 {
   457  		return true
   458  	}
   459  	// We know the type can be compared to nil.
   460  	return k1 == reflect.Invalid || k2 == reflect.Invalid
   461  }
   462  
   463  // eq evaluates the comparison a == b || a == c || ...
   464  func eq(arg1 reflect.Value, arg2 ...reflect.Value) (bool, error) {
   465  	arg1 = indirectInterface(arg1)
   466  	if len(arg2) == 0 {
   467  		return false, errNoComparison
   468  	}
   469  	k1, _ := basicKind(arg1)
   470  	for _, arg := range arg2 {
   471  		arg = indirectInterface(arg)
   472  		k2, _ := basicKind(arg)
   473  		truth := false
   474  		if k1 != k2 {
   475  			// Special case: Can compare integer values regardless of type's sign.
   476  			switch {
   477  			case k1 == intKind && k2 == uintKind:
   478  				truth = arg1.Int() >= 0 && uint64(arg1.Int()) == arg.Uint()
   479  			case k1 == uintKind && k2 == intKind:
   480  				truth = arg.Int() >= 0 && arg1.Uint() == uint64(arg.Int())
   481  			default:
   482  				if arg1 != zero && arg != zero {
   483  					return false, errBadComparison
   484  				}
   485  			}
   486  		} else {
   487  			switch k1 {
   488  			case boolKind:
   489  				truth = arg1.Bool() == arg.Bool()
   490  			case complexKind:
   491  				truth = arg1.Complex() == arg.Complex()
   492  			case floatKind:
   493  				truth = arg1.Float() == arg.Float()
   494  			case intKind:
   495  				truth = arg1.Int() == arg.Int()
   496  			case stringKind:
   497  				truth = arg1.String() == arg.String()
   498  			case uintKind:
   499  				truth = arg1.Uint() == arg.Uint()
   500  			default:
   501  				if !canCompare(arg1, arg) {
   502  					return false, fmt.Errorf("non-comparable types %s: %v, %s: %v", arg1, arg1.Type(), arg.Type(), arg)
   503  				}
   504  				if isNil(arg1) || isNil(arg) {
   505  					truth = isNil(arg) == isNil(arg1)
   506  				} else {
   507  					if !arg.Type().Comparable() {
   508  						return false, fmt.Errorf("non-comparable type %s: %v", arg, arg.Type())
   509  					}
   510  					truth = arg1.Interface() == arg.Interface()
   511  				}
   512  			}
   513  		}
   514  		if truth {
   515  			return true, nil
   516  		}
   517  	}
   518  	return false, nil
   519  }
   520  
   521  // ne evaluates the comparison a != b.
   522  func ne(arg1, arg2 reflect.Value) (bool, error) {
   523  	// != is the inverse of ==.
   524  	equal, err := eq(arg1, arg2)
   525  	return !equal, err
   526  }
   527  
   528  // lt evaluates the comparison a < b.
   529  func lt(arg1, arg2 reflect.Value) (bool, error) {
   530  	arg1 = indirectInterface(arg1)
   531  	k1, err := basicKind(arg1)
   532  	if err != nil {
   533  		return false, err
   534  	}
   535  	arg2 = indirectInterface(arg2)
   536  	k2, err := basicKind(arg2)
   537  	if err != nil {
   538  		return false, err
   539  	}
   540  	truth := false
   541  	if k1 != k2 {
   542  		// Special case: Can compare integer values regardless of type's sign.
   543  		switch {
   544  		case k1 == intKind && k2 == uintKind:
   545  			truth = arg1.Int() < 0 || uint64(arg1.Int()) < arg2.Uint()
   546  		case k1 == uintKind && k2 == intKind:
   547  			truth = arg2.Int() >= 0 && arg1.Uint() < uint64(arg2.Int())
   548  		default:
   549  			return false, errBadComparison
   550  		}
   551  	} else {
   552  		switch k1 {
   553  		case boolKind, complexKind:
   554  			return false, errBadComparisonType
   555  		case floatKind:
   556  			truth = arg1.Float() < arg2.Float()
   557  		case intKind:
   558  			truth = arg1.Int() < arg2.Int()
   559  		case stringKind:
   560  			truth = arg1.String() < arg2.String()
   561  		case uintKind:
   562  			truth = arg1.Uint() < arg2.Uint()
   563  		default:
   564  			panic("invalid kind")
   565  		}
   566  	}
   567  	return truth, nil
   568  }
   569  
   570  // le evaluates the comparison <= b.
   571  func le(arg1, arg2 reflect.Value) (bool, error) {
   572  	// <= is < or ==.
   573  	lessThan, err := lt(arg1, arg2)
   574  	if lessThan || err != nil {
   575  		return lessThan, err
   576  	}
   577  	return eq(arg1, arg2)
   578  }
   579  
   580  // gt evaluates the comparison a > b.
   581  func gt(arg1, arg2 reflect.Value) (bool, error) {
   582  	// > is the inverse of <=.
   583  	lessOrEqual, err := le(arg1, arg2)
   584  	if err != nil {
   585  		return false, err
   586  	}
   587  	return !lessOrEqual, nil
   588  }
   589  
   590  // ge evaluates the comparison a >= b.
   591  func ge(arg1, arg2 reflect.Value) (bool, error) {
   592  	// >= is the inverse of <.
   593  	lessThan, err := lt(arg1, arg2)
   594  	if err != nil {
   595  		return false, err
   596  	}
   597  	return !lessThan, nil
   598  }
   599  
   600  // HTML escaping.
   601  
   602  var (
   603  	htmlQuot = []byte("&#34;") // shorter than "&quot;"
   604  	htmlApos = []byte("&#39;") // shorter than "&apos;" and apos was not in HTML until HTML5
   605  	htmlAmp  = []byte("&amp;")
   606  	htmlLt   = []byte("&lt;")
   607  	htmlGt   = []byte("&gt;")
   608  	htmlNull = []byte("\uFFFD")
   609  )
   610  
   611  // HTMLEscape writes to w the escaped HTML equivalent of the plain text data b.
   612  func HTMLEscape(w io.Writer, b []byte) {
   613  	last := 0
   614  	for i, c := range b {
   615  		var html []byte
   616  		switch c {
   617  		case '\000':
   618  			html = htmlNull
   619  		case '"':
   620  			html = htmlQuot
   621  		case '\'':
   622  			html = htmlApos
   623  		case '&':
   624  			html = htmlAmp
   625  		case '<':
   626  			html = htmlLt
   627  		case '>':
   628  			html = htmlGt
   629  		default:
   630  			continue
   631  		}
   632  		w.Write(b[last:i])
   633  		w.Write(html)
   634  		last = i + 1
   635  	}
   636  	w.Write(b[last:])
   637  }
   638  
   639  // HTMLEscapeString returns the escaped HTML equivalent of the plain text data s.
   640  func HTMLEscapeString(s string) string {
   641  	// Avoid allocation if we can.
   642  	if !strings.ContainsAny(s, "'\"&<>\000") {
   643  		return s
   644  	}
   645  	var b bytes.Buffer
   646  	HTMLEscape(&b, []byte(s))
   647  	return b.String()
   648  }
   649  
   650  // HTMLEscaper returns the escaped HTML equivalent of the textual
   651  // representation of its arguments.
   652  func HTMLEscaper(args ...any) string {
   653  	return HTMLEscapeString(evalArgs(args))
   654  }
   655  
   656  // JavaScript escaping.
   657  
   658  var (
   659  	jsLowUni = []byte(`\u00`)
   660  	hex      = []byte("0123456789ABCDEF")
   661  
   662  	jsBackslash = []byte(`\\`)
   663  	jsApos      = []byte(`\'`)
   664  	jsQuot      = []byte(`\"`)
   665  	jsLt        = []byte(`\u003C`)
   666  	jsGt        = []byte(`\u003E`)
   667  	jsAmp       = []byte(`\u0026`)
   668  	jsEq        = []byte(`\u003D`)
   669  )
   670  
   671  // JSEscape writes to w the escaped JavaScript equivalent of the plain text data b.
   672  func JSEscape(w io.Writer, b []byte) {
   673  	last := 0
   674  	for i := 0; i < len(b); i++ {
   675  		c := b[i]
   676  
   677  		if !jsIsSpecial(rune(c)) {
   678  			// fast path: nothing to do
   679  			continue
   680  		}
   681  		w.Write(b[last:i])
   682  
   683  		if c < utf8.RuneSelf {
   684  			// Quotes, slashes and angle brackets get quoted.
   685  			// Control characters get written as \u00XX.
   686  			switch c {
   687  			case '\\':
   688  				w.Write(jsBackslash)
   689  			case '\'':
   690  				w.Write(jsApos)
   691  			case '"':
   692  				w.Write(jsQuot)
   693  			case '<':
   694  				w.Write(jsLt)
   695  			case '>':
   696  				w.Write(jsGt)
   697  			case '&':
   698  				w.Write(jsAmp)
   699  			case '=':
   700  				w.Write(jsEq)
   701  			default:
   702  				w.Write(jsLowUni)
   703  				t, b := c>>4, c&0x0f
   704  				w.Write(hex[t : t+1])
   705  				w.Write(hex[b : b+1])
   706  			}
   707  		} else {
   708  			// Unicode rune.
   709  			r, size := utf8.DecodeRune(b[i:])
   710  			if unicode.IsPrint(r) {
   711  				w.Write(b[i : i+size])
   712  			} else {
   713  				fmt.Fprintf(w, "\\u%04X", r)
   714  			}
   715  			i += size - 1
   716  		}
   717  		last = i + 1
   718  	}
   719  	w.Write(b[last:])
   720  }
   721  
   722  // JSEscapeString returns the escaped JavaScript equivalent of the plain text data s.
   723  func JSEscapeString(s string) string {
   724  	// Avoid allocation if we can.
   725  	if strings.IndexFunc(s, jsIsSpecial) < 0 {
   726  		return s
   727  	}
   728  	var b bytes.Buffer
   729  	JSEscape(&b, []byte(s))
   730  	return b.String()
   731  }
   732  
   733  func jsIsSpecial(r rune) bool {
   734  	switch r {
   735  	case '\\', '\'', '"', '<', '>', '&', '=':
   736  		return true
   737  	}
   738  	return r < ' ' || utf8.RuneSelf <= r
   739  }
   740  
   741  // JSEscaper returns the escaped JavaScript equivalent of the textual
   742  // representation of its arguments.
   743  func JSEscaper(args ...any) string {
   744  	return JSEscapeString(evalArgs(args))
   745  }
   746  
   747  // URLQueryEscaper returns the escaped value of the textual representation of
   748  // its arguments in a form suitable for embedding in a URL query.
   749  func URLQueryEscaper(args ...any) string {
   750  	return url.QueryEscape(evalArgs(args))
   751  }
   752  
   753  // evalArgs formats the list of arguments into a string. It is therefore equivalent to
   754  //
   755  //	fmt.Sprint(args...)
   756  //
   757  // except that each argument is indirected (if a pointer), as required,
   758  // using the same rules as the default string evaluation during template
   759  // execution.
   760  func evalArgs(args []any) string {
   761  	ok := false
   762  	var s string
   763  	// Fast path for simple common case.
   764  	if len(args) == 1 {
   765  		s, ok = args[0].(string)
   766  	}
   767  	if !ok {
   768  		for i, arg := range args {
   769  			a, ok := printableValue(reflect.ValueOf(arg))
   770  			if ok {
   771  				args[i] = a
   772  			} // else let fmt do its thing
   773  		}
   774  		s = fmt.Sprint(args...)
   775  	}
   776  	return s
   777  }
   778  

View as plain text