1
2
3
4
5 package template
6
7 import (
8 "bytes"
9 "errors"
10 "fmt"
11 "io"
12 "net/url"
13 "reflect"
14 "strings"
15 "sync"
16 "unicode"
17 "unicode/utf8"
18 )
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
// FuncMap maps a function name to the value registered under that name.
// Values are stored as any; addValueFuncs panics unless the name passes
// goodName, the value is a func, and its type passes goodFunc.
type FuncMap map[string]any
35
36
37
38
39
40 func builtins() FuncMap {
41 return FuncMap{
42 "and": and,
43 "call": call,
44 "html": HTMLEscaper,
45 "index": index,
46 "slice": slice,
47 "js": JSEscaper,
48 "len": length,
49 "not": not,
50 "or": or,
51 "print": fmt.Sprint,
52 "printf": fmt.Sprintf,
53 "println": fmt.Sprintln,
54 "urlquery": URLQueryEscaper,
55
56
57 "eq": eq,
58 "ge": ge,
59 "gt": gt,
60 "le": le,
61 "lt": lt,
62 "ne": ne,
63 }
64 }
65
// builtinFuncsOnce caches the reflect.Value form of the builtin functions.
// The embedded sync.Once guards the single initialization performed by
// builtinFuncs, and v holds the resulting name-to-function map.
var builtinFuncsOnce struct {
	sync.Once
	v map[string]reflect.Value
}
70
71
72
73 func builtinFuncs() map[string]reflect.Value {
74 builtinFuncsOnce.Do(func() {
75 builtinFuncsOnce.v = createValueFuncs(builtins())
76 })
77 return builtinFuncsOnce.v
78 }
79
80
81 func createValueFuncs(funcMap FuncMap) map[string]reflect.Value {
82 m := make(map[string]reflect.Value)
83 addValueFuncs(m, funcMap)
84 return m
85 }
86
87
88 func addValueFuncs(out map[string]reflect.Value, in FuncMap) {
89 for name, fn := range in {
90 if !goodName(name) {
91 panic(fmt.Errorf("function name %q is not a valid identifier", name))
92 }
93 v := reflect.ValueOf(fn)
94 if v.Kind() != reflect.Func {
95 panic("value for " + name + " not a function")
96 }
97 if !goodFunc(v.Type()) {
98 panic(fmt.Errorf("can't install method/function %q with %d results", name, v.Type().NumOut()))
99 }
100 out[name] = v
101 }
102 }
103
104
105
106 func addFuncs(out, in FuncMap) {
107 for name, fn := range in {
108 out[name] = fn
109 }
110 }
111
112
113 func goodFunc(typ reflect.Type) bool {
114
115 switch {
116 case typ.NumOut() == 1:
117 return true
118 case typ.NumOut() == 2 && typ.Out(1) == errorType:
119 return true
120 }
121 return false
122 }
123
124
// goodName reports whether name is usable as a function name: non-empty,
// made only of underscores, letters and digits, and not starting with a digit.
func goodName(name string) bool {
	if len(name) == 0 {
		return false
	}
	for pos, ch := range name {
		if ch == '_' || unicode.IsLetter(ch) {
			continue
		}
		// Anything else must be a digit, and digits may not come first.
		if pos == 0 || !unicode.IsDigit(ch) {
			return false
		}
	}
	return true
}
140
141
142 func findFunction(name string, tmpl *Template) (v reflect.Value, isBuiltin, ok bool) {
143 if tmpl != nil && tmpl.common != nil {
144 tmpl.muFuncs.RLock()
145 defer tmpl.muFuncs.RUnlock()
146 if fn := tmpl.execFuncs[name]; fn.IsValid() {
147 return fn, false, true
148 }
149 }
150 if fn := builtinFuncs()[name]; fn.IsValid() {
151 return fn, true, true
152 }
153 return reflect.Value{}, false, false
154 }
155
156
157
158 func prepareArg(value reflect.Value, argType reflect.Type) (reflect.Value, error) {
159 if !value.IsValid() {
160 if !canBeNil(argType) {
161 return reflect.Value{}, fmt.Errorf("value is nil; should be of type %s", argType)
162 }
163 value = reflect.Zero(argType)
164 }
165 if value.Type().AssignableTo(argType) {
166 return value, nil
167 }
168 if intLike(value.Kind()) && intLike(argType.Kind()) && value.Type().ConvertibleTo(argType) {
169 value = value.Convert(argType)
170 return value, nil
171 }
172 return reflect.Value{}, fmt.Errorf("value has type %s; should be %s", value.Type(), argType)
173 }
174
// intLike reports whether typ is one of the signed or unsigned integer
// kinds, including Uintptr.
func intLike(typ reflect.Kind) bool {
	switch typ {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return true
	default:
		return false
	}
}
184
185
// indexArg checks that index is an integer-kinded value in the range
// [0, cap] and returns it as an int.
//
// It returns an error when index is invalid (nil), is not of an integer kind,
// or is out of range. The comparison is done in 64-bit space rather than
// through int(x), so a large value cannot be truncated into range on
// platforms where int is 32 bits. Unsigned values are range-checked before
// conversion so the error message reports the real value.
func indexArg(index reflect.Value, cap int) (int, error) {
	var x int64
	switch index.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		x = index.Int()
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		u := index.Uint()
		// Check before converting: values above MaxInt64 would wrap negative.
		if cap < 0 || u > uint64(cap) {
			return 0, fmt.Errorf("index out of range: %d", u)
		}
		x = int64(u)
	case reflect.Invalid:
		return 0, fmt.Errorf("cannot index slice/array with nil")
	default:
		return 0, fmt.Errorf("cannot index slice/array with type %s", index.Type())
	}
	if x < 0 || x > int64(cap) {
		return 0, fmt.Errorf("index out of range: %d", x)
	}
	return int(x), nil
}
203
204
205
206
207
208
209 func index(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
210 item = indirectInterface(item)
211 if !item.IsValid() {
212 return reflect.Value{}, fmt.Errorf("index of untyped nil")
213 }
214 for _, index := range indexes {
215 index = indirectInterface(index)
216 var isNil bool
217 if item, isNil = indirect(item); isNil {
218 return reflect.Value{}, fmt.Errorf("index of nil pointer")
219 }
220 switch item.Kind() {
221 case reflect.Array, reflect.Slice, reflect.String:
222 x, err := indexArg(index, item.Len())
223 if err != nil {
224 return reflect.Value{}, err
225 }
226 item = item.Index(x)
227 case reflect.Map:
228 index, err := prepareArg(index, item.Type().Key())
229 if err != nil {
230 return reflect.Value{}, err
231 }
232 if x := item.MapIndex(index); x.IsValid() {
233 item = x
234 } else {
235 item = reflect.Zero(item.Type().Elem())
236 }
237 case reflect.Invalid:
238
239 panic("unreachable")
240 default:
241 return reflect.Value{}, fmt.Errorf("can't index item of type %s", item.Type())
242 }
243 }
244 return item, nil
245 }
246
247
248
249
250
251
252
253 func slice(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
254 item = indirectInterface(item)
255 if !item.IsValid() {
256 return reflect.Value{}, fmt.Errorf("slice of untyped nil")
257 }
258 if len(indexes) > 3 {
259 return reflect.Value{}, fmt.Errorf("too many slice indexes: %d", len(indexes))
260 }
261 var cap int
262 switch item.Kind() {
263 case reflect.String:
264 if len(indexes) == 3 {
265 return reflect.Value{}, fmt.Errorf("cannot 3-index slice a string")
266 }
267 cap = item.Len()
268 case reflect.Array, reflect.Slice:
269 cap = item.Cap()
270 default:
271 return reflect.Value{}, fmt.Errorf("can't slice item of type %s", item.Type())
272 }
273
274 idx := [3]int{0, item.Len()}
275 for i, index := range indexes {
276 x, err := indexArg(index, cap)
277 if err != nil {
278 return reflect.Value{}, err
279 }
280 idx[i] = x
281 }
282
283 if idx[0] > idx[1] {
284 return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[0], idx[1])
285 }
286 if len(indexes) < 3 {
287 return item.Slice(idx[0], idx[1]), nil
288 }
289
290 if idx[1] > idx[2] {
291 return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[1], idx[2])
292 }
293 return item.Slice3(idx[0], idx[1], idx[2]), nil
294 }
295
296
297
298
299 func length(item reflect.Value) (int, error) {
300 item, isNil := indirect(item)
301 if isNil {
302 return 0, fmt.Errorf("len of nil pointer")
303 }
304 switch item.Kind() {
305 case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String:
306 return item.Len(), nil
307 }
308 return 0, fmt.Errorf("len of type %s", item.Type())
309 }
310
311
312
313
314
315 func call(fn reflect.Value, args ...reflect.Value) (reflect.Value, error) {
316 fn = indirectInterface(fn)
317 if !fn.IsValid() {
318 return reflect.Value{}, fmt.Errorf("call of nil")
319 }
320 typ := fn.Type()
321 if typ.Kind() != reflect.Func {
322 return reflect.Value{}, fmt.Errorf("non-function of type %s", typ)
323 }
324 if !goodFunc(typ) {
325 return reflect.Value{}, fmt.Errorf("function called with %d args; should be 1 or 2", typ.NumOut())
326 }
327 numIn := typ.NumIn()
328 var dddType reflect.Type
329 if typ.IsVariadic() {
330 if len(args) < numIn-1 {
331 return reflect.Value{}, fmt.Errorf("wrong number of args: got %d want at least %d", len(args), numIn-1)
332 }
333 dddType = typ.In(numIn - 1).Elem()
334 } else {
335 if len(args) != numIn {
336 return reflect.Value{}, fmt.Errorf("wrong number of args: got %d want %d", len(args), numIn)
337 }
338 }
339 argv := make([]reflect.Value, len(args))
340 for i, arg := range args {
341 arg = indirectInterface(arg)
342
343 argType := dddType
344 if !typ.IsVariadic() || i < numIn-1 {
345 argType = typ.In(i)
346 }
347
348 var err error
349 if argv[i], err = prepareArg(arg, argType); err != nil {
350 return reflect.Value{}, fmt.Errorf("arg %d: %w", i, err)
351 }
352 }
353 return safeCall(fn, argv)
354 }
355
356
357
// safeCall calls fun with args and returns its first result. A panic during
// the call is recovered and returned as err: a panic value that is an error is
// returned as is, anything else is formatted with %v. When fun returns two
// results and the second is non-nil, it is returned as the error.
func safeCall(fun reflect.Value, args []reflect.Value) (val reflect.Value, err error) {
	defer func() {
		recovered := recover()
		if recovered == nil {
			return
		}
		if asErr, isErr := recovered.(error); isErr {
			err = asErr
			return
		}
		err = fmt.Errorf("%v", recovered)
	}()
	results := fun.Call(args)
	if len(results) == 2 {
		if errVal := results[1]; !errVal.IsNil() {
			return results[0], errVal.Interface().(error)
		}
	}
	return results[0], nil
}
374
375
376
377 func truth(arg reflect.Value) bool {
378 t, _ := isTrue(indirectInterface(arg))
379 return t
380 }
381
382
383
// and is the function registered under the name "and" in builtins. Its body
// panics unconditionally.
// NOTE(review): presumably "and" is evaluated elsewhere and this body is never
// expected to run — confirm.
func and(arg0 reflect.Value, args ...reflect.Value) reflect.Value {
	panic("unreachable")
}
387
388
389
// or is the function registered under the name "or" in builtins. Its body
// panics unconditionally.
// NOTE(review): presumably "or" is evaluated elsewhere and this body is never
// expected to run — confirm.
func or(arg0 reflect.Value, args ...reflect.Value) reflect.Value {
	panic("unreachable")
}
393
394
// not returns the Boolean negation of its argument's truth value, as
// computed by truth.
func not(arg reflect.Value) bool {
	return !truth(arg)
}
398
399
400
401
402
// Errors returned by the comparison functions.
var (
	// errBadComparisonType is returned when a value's kind cannot be
	// compared (see basicKind) or cannot be ordered (see lt).
	errBadComparisonType = errors.New("invalid type for comparison")
	// errBadComparison is returned when the two operands have basic kinds
	// that cannot be compared with each other.
	errBadComparison = errors.New("incompatible types for comparison")
	// errNoComparison is returned by eq when it is given nothing to compare
	// its first argument against.
	errNoComparison = errors.New("missing argument for comparison")
)
408
// kind is the coarse classification of a value used by the comparison
// functions; basicKind maps a reflect.Kind onto it.
type kind int

// The classifications produced by basicKind. invalidKind is the zero value
// and is what basicKind returns alongside an error.
const (
	invalidKind kind = iota
	boolKind
	complexKind
	intKind
	floatKind
	stringKind
	uintKind
)
420
421 func basicKind(v reflect.Value) (kind, error) {
422 switch v.Kind() {
423 case reflect.Bool:
424 return boolKind, nil
425 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
426 return intKind, nil
427 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
428 return uintKind, nil
429 case reflect.Float32, reflect.Float64:
430 return floatKind, nil
431 case reflect.Complex64, reflect.Complex128:
432 return complexKind, nil
433 case reflect.String:
434 return stringKind, nil
435 }
436 return invalidKind, errBadComparisonType
437 }
438
439
440 func isNil(v reflect.Value) bool {
441 if v == zero {
442 return true
443 }
444 switch v.Kind() {
445 case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
446 return v.IsNil()
447 }
448 return false
449 }
450
451
452
// canCompare reports whether v1 and v2 may be compared: either they have the
// same kind, or at least one of them is of kind Invalid.
func canCompare(v1, v2 reflect.Value) bool {
	first, second := v1.Kind(), v2.Kind()
	switch {
	case first == second:
		return true
	case first == reflect.Invalid, second == reflect.Invalid:
		return true
	}
	return false
}
462
463
464 func eq(arg1 reflect.Value, arg2 ...reflect.Value) (bool, error) {
465 arg1 = indirectInterface(arg1)
466 if len(arg2) == 0 {
467 return false, errNoComparison
468 }
469 k1, _ := basicKind(arg1)
470 for _, arg := range arg2 {
471 arg = indirectInterface(arg)
472 k2, _ := basicKind(arg)
473 truth := false
474 if k1 != k2 {
475
476 switch {
477 case k1 == intKind && k2 == uintKind:
478 truth = arg1.Int() >= 0 && uint64(arg1.Int()) == arg.Uint()
479 case k1 == uintKind && k2 == intKind:
480 truth = arg.Int() >= 0 && arg1.Uint() == uint64(arg.Int())
481 default:
482 if arg1 != zero && arg != zero {
483 return false, errBadComparison
484 }
485 }
486 } else {
487 switch k1 {
488 case boolKind:
489 truth = arg1.Bool() == arg.Bool()
490 case complexKind:
491 truth = arg1.Complex() == arg.Complex()
492 case floatKind:
493 truth = arg1.Float() == arg.Float()
494 case intKind:
495 truth = arg1.Int() == arg.Int()
496 case stringKind:
497 truth = arg1.String() == arg.String()
498 case uintKind:
499 truth = arg1.Uint() == arg.Uint()
500 default:
501 if !canCompare(arg1, arg) {
502 return false, fmt.Errorf("non-comparable types %s: %v, %s: %v", arg1, arg1.Type(), arg.Type(), arg)
503 }
504 if isNil(arg1) || isNil(arg) {
505 truth = isNil(arg) == isNil(arg1)
506 } else {
507 if !arg.Type().Comparable() {
508 return false, fmt.Errorf("non-comparable type %s: %v", arg, arg.Type())
509 }
510 truth = arg1.Interface() == arg.Interface()
511 }
512 }
513 }
514 if truth {
515 return true, nil
516 }
517 }
518 return false, nil
519 }
520
521
522 func ne(arg1, arg2 reflect.Value) (bool, error) {
523
524 equal, err := eq(arg1, arg2)
525 return !equal, err
526 }
527
528
529 func lt(arg1, arg2 reflect.Value) (bool, error) {
530 arg1 = indirectInterface(arg1)
531 k1, err := basicKind(arg1)
532 if err != nil {
533 return false, err
534 }
535 arg2 = indirectInterface(arg2)
536 k2, err := basicKind(arg2)
537 if err != nil {
538 return false, err
539 }
540 truth := false
541 if k1 != k2 {
542
543 switch {
544 case k1 == intKind && k2 == uintKind:
545 truth = arg1.Int() < 0 || uint64(arg1.Int()) < arg2.Uint()
546 case k1 == uintKind && k2 == intKind:
547 truth = arg2.Int() >= 0 && arg1.Uint() < uint64(arg2.Int())
548 default:
549 return false, errBadComparison
550 }
551 } else {
552 switch k1 {
553 case boolKind, complexKind:
554 return false, errBadComparisonType
555 case floatKind:
556 truth = arg1.Float() < arg2.Float()
557 case intKind:
558 truth = arg1.Int() < arg2.Int()
559 case stringKind:
560 truth = arg1.String() < arg2.String()
561 case uintKind:
562 truth = arg1.Uint() < arg2.Uint()
563 default:
564 panic("invalid kind")
565 }
566 }
567 return truth, nil
568 }
569
570
571 func le(arg1, arg2 reflect.Value) (bool, error) {
572
573 lessThan, err := lt(arg1, arg2)
574 if lessThan || err != nil {
575 return lessThan, err
576 }
577 return eq(arg1, arg2)
578 }
579
580
581 func gt(arg1, arg2 reflect.Value) (bool, error) {
582
583 lessOrEqual, err := le(arg1, arg2)
584 if err != nil {
585 return false, err
586 }
587 return !lessOrEqual, nil
588 }
589
590
591 func ge(arg1, arg2 reflect.Value) (bool, error) {
592
593 lessThan, err := lt(arg1, arg2)
594 if err != nil {
595 return false, err
596 }
597 return !lessThan, nil
598 }
599
600
601
// Replacement sequences written by HTMLEscape. Each special byte maps to an
// HTML character reference, and NUL maps to the Unicode replacement
// character. The literals must contain the references themselves; writing the
// raw characters here would make escaping a no-op.
var (
	htmlQuot = []byte("&#34;") // numeric reference for '"'
	htmlApos = []byte("&#39;") // numeric reference for '\''
	htmlAmp  = []byte("&amp;")
	htmlLt   = []byte("&lt;")
	htmlGt   = []byte("&gt;")
	htmlNull = []byte("\uFFFD")
)
610
611
612 func HTMLEscape(w io.Writer, b []byte) {
613 last := 0
614 for i, c := range b {
615 var html []byte
616 switch c {
617 case '\000':
618 html = htmlNull
619 case '"':
620 html = htmlQuot
621 case '\'':
622 html = htmlApos
623 case '&':
624 html = htmlAmp
625 case '<':
626 html = htmlLt
627 case '>':
628 html = htmlGt
629 default:
630 continue
631 }
632 w.Write(b[last:i])
633 w.Write(html)
634 last = i + 1
635 }
636 w.Write(b[last:])
637 }
638
639
640 func HTMLEscapeString(s string) string {
641
642 if !strings.ContainsAny(s, "'\"&<>\000") {
643 return s
644 }
645 var b bytes.Buffer
646 HTMLEscape(&b, []byte(s))
647 return b.String()
648 }
649
650
651
652 func HTMLEscaper(args ...any) string {
653 return HTMLEscapeString(evalArgs(args))
654 }
655
656
657
// Byte sequences written by JSEscape.
var (
	// jsLowUni is the prefix of a \u00XX escape; JSEscape follows it with two
	// digits taken from hex.
	jsLowUni = []byte(`\u00`)
	// hex holds the upper-case hexadecimal digits, indexed by nibble value.
	hex = []byte("0123456789ABCDEF")

	// Fixed escapes for the individual special characters.
	jsBackslash = []byte(`\\`)
	jsApos      = []byte(`\'`)
	jsQuot      = []byte(`\"`)
	jsLt        = []byte(`\u003C`)
	jsGt        = []byte(`\u003E`)
	jsAmp       = []byte(`\u0026`)
	jsEq        = []byte(`\u003D`)
)
670
671
672 func JSEscape(w io.Writer, b []byte) {
673 last := 0
674 for i := 0; i < len(b); i++ {
675 c := b[i]
676
677 if !jsIsSpecial(rune(c)) {
678
679 continue
680 }
681 w.Write(b[last:i])
682
683 if c < utf8.RuneSelf {
684
685
686 switch c {
687 case '\\':
688 w.Write(jsBackslash)
689 case '\'':
690 w.Write(jsApos)
691 case '"':
692 w.Write(jsQuot)
693 case '<':
694 w.Write(jsLt)
695 case '>':
696 w.Write(jsGt)
697 case '&':
698 w.Write(jsAmp)
699 case '=':
700 w.Write(jsEq)
701 default:
702 w.Write(jsLowUni)
703 t, b := c>>4, c&0x0f
704 w.Write(hex[t : t+1])
705 w.Write(hex[b : b+1])
706 }
707 } else {
708
709 r, size := utf8.DecodeRune(b[i:])
710 if unicode.IsPrint(r) {
711 w.Write(b[i : i+size])
712 } else {
713 fmt.Fprintf(w, "\\u%04X", r)
714 }
715 i += size - 1
716 }
717 last = i + 1
718 }
719 w.Write(b[last:])
720 }
721
722
723 func JSEscapeString(s string) string {
724
725 if strings.IndexFunc(s, jsIsSpecial) < 0 {
726 return s
727 }
728 var b bytes.Buffer
729 JSEscape(&b, []byte(s))
730 return b.String()
731 }
732
// jsIsSpecial reports whether r needs attention from JSEscape: runes below
// space, runes at or above utf8.RuneSelf, and the characters \ ' " < > & =.
func jsIsSpecial(r rune) bool {
	if r < ' ' || utf8.RuneSelf <= r {
		return true
	}
	switch r {
	case '\\', '\'', '"', '<', '>', '&', '=':
		return true
	}
	return false
}
740
741
742
743 func JSEscaper(args ...any) string {
744 return JSEscapeString(evalArgs(args))
745 }
746
747
748
749 func URLQueryEscaper(args ...any) string {
750 return url.QueryEscape(evalArgs(args))
751 }
752
753
754
755
756
757
758
759
760 func evalArgs(args []any) string {
761 ok := false
762 var s string
763
764 if len(args) == 1 {
765 s, ok = args[0].(string)
766 }
767 if !ok {
768 for i, arg := range args {
769 a, ok := printableValue(reflect.ValueOf(arg))
770 if ok {
771 args[i] = a
772 }
773 }
774 s = fmt.Sprint(args...)
775 }
776 return s
777 }
778
View as plain text