Home | History | Annotate | Download | only in template
      1 // Copyright 2011 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package template
      6 
      7 import (
      8 	"bytes"
      9 	"errors"
     10 	"fmt"
     11 	"io"
     12 	"net/url"
     13 	"reflect"
     14 	"strings"
     15 	"unicode"
     16 	"unicode/utf8"
     17 )
     18 
// FuncMap is the type of the map defining the mapping from names to functions.
// Each name must be a valid identifier (letters, digits and underscores, not
// starting with a digit) and each value must be a function; addValueFuncs
// panics when asked to install an entry that breaks either rule.
//
// Each function must have either a single return value, or two return values of
// which the second has type error. In that case, if the second (error)
// return value evaluates to non-nil during execution, execution terminates and
// Execute returns that error.
//
// When template execution invokes a function with an argument list, that list
// must be assignable to the function's parameter types. Functions meant to
// apply to arguments of arbitrary type can use parameters of type interface{} or
// of type reflect.Value. Similarly, functions meant to return a result of arbitrary
// type can return interface{} or reflect.Value.
type FuncMap map[string]interface{}
     31 
     32 var builtins = FuncMap{
     33 	"and":      and,
     34 	"call":     call,
     35 	"html":     HTMLEscaper,
     36 	"index":    index,
     37 	"js":       JSEscaper,
     38 	"len":      length,
     39 	"not":      not,
     40 	"or":       or,
     41 	"print":    fmt.Sprint,
     42 	"printf":   fmt.Sprintf,
     43 	"println":  fmt.Sprintln,
     44 	"urlquery": URLQueryEscaper,
     45 
     46 	// Comparisons
     47 	"eq": eq, // ==
     48 	"ge": ge, // >=
     49 	"gt": gt, // >
     50 	"le": le, // <=
     51 	"lt": lt, // <
     52 	"ne": ne, // !=
     53 }
     54 
     55 var builtinFuncs = createValueFuncs(builtins)
     56 
     57 // createValueFuncs turns a FuncMap into a map[string]reflect.Value
     58 func createValueFuncs(funcMap FuncMap) map[string]reflect.Value {
     59 	m := make(map[string]reflect.Value)
     60 	addValueFuncs(m, funcMap)
     61 	return m
     62 }
     63 
     64 // addValueFuncs adds to values the functions in funcs, converting them to reflect.Values.
     65 func addValueFuncs(out map[string]reflect.Value, in FuncMap) {
     66 	for name, fn := range in {
     67 		if !goodName(name) {
     68 			panic(fmt.Errorf("function name %s is not a valid identifier", name))
     69 		}
     70 		v := reflect.ValueOf(fn)
     71 		if v.Kind() != reflect.Func {
     72 			panic("value for " + name + " not a function")
     73 		}
     74 		if !goodFunc(v.Type()) {
     75 			panic(fmt.Errorf("can't install method/function %q with %d results", name, v.Type().NumOut()))
     76 		}
     77 		out[name] = v
     78 	}
     79 }
     80 
     81 // addFuncs adds to values the functions in funcs. It does no checking of the input -
     82 // call addValueFuncs first.
     83 func addFuncs(out, in FuncMap) {
     84 	for name, fn := range in {
     85 		out[name] = fn
     86 	}
     87 }
     88 
     89 // goodFunc reports whether the function or method has the right result signature.
     90 func goodFunc(typ reflect.Type) bool {
     91 	// We allow functions with 1 result or 2 results where the second is an error.
     92 	switch {
     93 	case typ.NumOut() == 1:
     94 		return true
     95 	case typ.NumOut() == 2 && typ.Out(1) == errorType:
     96 		return true
     97 	}
     98 	return false
     99 }
    100 
    101 // goodName reports whether the function name is a valid identifier.
    102 func goodName(name string) bool {
    103 	if name == "" {
    104 		return false
    105 	}
    106 	for i, r := range name {
    107 		switch {
    108 		case r == '_':
    109 		case i == 0 && !unicode.IsLetter(r):
    110 			return false
    111 		case !unicode.IsLetter(r) && !unicode.IsDigit(r):
    112 			return false
    113 		}
    114 	}
    115 	return true
    116 }
    117 
    118 // findFunction looks for a function in the template, and global map.
    119 func findFunction(name string, tmpl *Template) (reflect.Value, bool) {
    120 	if tmpl != nil && tmpl.common != nil {
    121 		tmpl.muFuncs.RLock()
    122 		defer tmpl.muFuncs.RUnlock()
    123 		if fn := tmpl.execFuncs[name]; fn.IsValid() {
    124 			return fn, true
    125 		}
    126 	}
    127 	if fn := builtinFuncs[name]; fn.IsValid() {
    128 		return fn, true
    129 	}
    130 	return reflect.Value{}, false
    131 }
    132 
    133 // prepareArg checks if value can be used as an argument of type argType, and
    134 // converts an invalid value to appropriate zero if possible.
    135 func prepareArg(value reflect.Value, argType reflect.Type) (reflect.Value, error) {
    136 	if !value.IsValid() {
    137 		if !canBeNil(argType) {
    138 			return reflect.Value{}, fmt.Errorf("value is nil; should be of type %s", argType)
    139 		}
    140 		value = reflect.Zero(argType)
    141 	}
    142 	if value.Type().AssignableTo(argType) {
    143 		return value, nil
    144 	}
    145 	if intLike(value.Kind()) && intLike(argType.Kind()) && value.Type().ConvertibleTo(argType) {
    146 		value = value.Convert(argType)
    147 		return value, nil
    148 	}
    149 	return reflect.Value{}, fmt.Errorf("value has type %s; should be %s", value.Type(), argType)
    150 }
    151 
    152 func intLike(typ reflect.Kind) bool {
    153 	switch typ {
    154 	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
    155 		return true
    156 	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
    157 		return true
    158 	}
    159 	return false
    160 }
    161 
    162 // Indexing.
    163 
    164 // index returns the result of indexing its first argument by the following
    165 // arguments. Thus "index x 1 2 3" is, in Go syntax, x[1][2][3]. Each
    166 // indexed item must be a map, slice, or array.
    167 func index(item reflect.Value, indices ...reflect.Value) (reflect.Value, error) {
    168 	v := indirectInterface(item)
    169 	if !v.IsValid() {
    170 		return reflect.Value{}, fmt.Errorf("index of untyped nil")
    171 	}
    172 	for _, i := range indices {
    173 		index := indirectInterface(i)
    174 		var isNil bool
    175 		if v, isNil = indirect(v); isNil {
    176 			return reflect.Value{}, fmt.Errorf("index of nil pointer")
    177 		}
    178 		switch v.Kind() {
    179 		case reflect.Array, reflect.Slice, reflect.String:
    180 			var x int64
    181 			switch index.Kind() {
    182 			case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
    183 				x = index.Int()
    184 			case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
    185 				x = int64(index.Uint())
    186 			case reflect.Invalid:
    187 				return reflect.Value{}, fmt.Errorf("cannot index slice/array with nil")
    188 			default:
    189 				return reflect.Value{}, fmt.Errorf("cannot index slice/array with type %s", index.Type())
    190 			}
    191 			if x < 0 || x >= int64(v.Len()) {
    192 				return reflect.Value{}, fmt.Errorf("index out of range: %d", x)
    193 			}
    194 			v = v.Index(int(x))
    195 		case reflect.Map:
    196 			index, err := prepareArg(index, v.Type().Key())
    197 			if err != nil {
    198 				return reflect.Value{}, err
    199 			}
    200 			if x := v.MapIndex(index); x.IsValid() {
    201 				v = x
    202 			} else {
    203 				v = reflect.Zero(v.Type().Elem())
    204 			}
    205 		case reflect.Invalid:
    206 			// the loop holds invariant: v.IsValid()
    207 			panic("unreachable")
    208 		default:
    209 			return reflect.Value{}, fmt.Errorf("can't index item of type %s", v.Type())
    210 		}
    211 	}
    212 	return v, nil
    213 }
    214 
    215 // Length
    216 
    217 // length returns the length of the item, with an error if it has no defined length.
    218 func length(item interface{}) (int, error) {
    219 	v := reflect.ValueOf(item)
    220 	if !v.IsValid() {
    221 		return 0, fmt.Errorf("len of untyped nil")
    222 	}
    223 	v, isNil := indirect(v)
    224 	if isNil {
    225 		return 0, fmt.Errorf("len of nil pointer")
    226 	}
    227 	switch v.Kind() {
    228 	case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String:
    229 		return v.Len(), nil
    230 	}
    231 	return 0, fmt.Errorf("len of type %s", v.Type())
    232 }
    233 
    234 // Function invocation
    235 
    236 // call returns the result of evaluating the first argument as a function.
    237 // The function must return 1 result, or 2 results, the second of which is an error.
    238 func call(fn reflect.Value, args ...reflect.Value) (reflect.Value, error) {
    239 	v := indirectInterface(fn)
    240 	if !v.IsValid() {
    241 		return reflect.Value{}, fmt.Errorf("call of nil")
    242 	}
    243 	typ := v.Type()
    244 	if typ.Kind() != reflect.Func {
    245 		return reflect.Value{}, fmt.Errorf("non-function of type %s", typ)
    246 	}
    247 	if !goodFunc(typ) {
    248 		return reflect.Value{}, fmt.Errorf("function called with %d args; should be 1 or 2", typ.NumOut())
    249 	}
    250 	numIn := typ.NumIn()
    251 	var dddType reflect.Type
    252 	if typ.IsVariadic() {
    253 		if len(args) < numIn-1 {
    254 			return reflect.Value{}, fmt.Errorf("wrong number of args: got %d want at least %d", len(args), numIn-1)
    255 		}
    256 		dddType = typ.In(numIn - 1).Elem()
    257 	} else {
    258 		if len(args) != numIn {
    259 			return reflect.Value{}, fmt.Errorf("wrong number of args: got %d want %d", len(args), numIn)
    260 		}
    261 	}
    262 	argv := make([]reflect.Value, len(args))
    263 	for i, arg := range args {
    264 		value := indirectInterface(arg)
    265 		// Compute the expected type. Clumsy because of variadics.
    266 		var argType reflect.Type
    267 		if !typ.IsVariadic() || i < numIn-1 {
    268 			argType = typ.In(i)
    269 		} else {
    270 			argType = dddType
    271 		}
    272 
    273 		var err error
    274 		if argv[i], err = prepareArg(value, argType); err != nil {
    275 			return reflect.Value{}, fmt.Errorf("arg %d: %s", i, err)
    276 		}
    277 	}
    278 	result := v.Call(argv)
    279 	if len(result) == 2 && !result[1].IsNil() {
    280 		return result[0], result[1].Interface().(error)
    281 	}
    282 	return result[0], nil
    283 }
    284 
    285 // Boolean logic.
    286 
    287 func truth(arg reflect.Value) bool {
    288 	t, _ := isTrue(indirectInterface(arg))
    289 	return t
    290 }
    291 
    292 // and computes the Boolean AND of its arguments, returning
    293 // the first false argument it encounters, or the last argument.
    294 func and(arg0 reflect.Value, args ...reflect.Value) reflect.Value {
    295 	if !truth(arg0) {
    296 		return arg0
    297 	}
    298 	for i := range args {
    299 		arg0 = args[i]
    300 		if !truth(arg0) {
    301 			break
    302 		}
    303 	}
    304 	return arg0
    305 }
    306 
    307 // or computes the Boolean OR of its arguments, returning
    308 // the first true argument it encounters, or the last argument.
    309 func or(arg0 reflect.Value, args ...reflect.Value) reflect.Value {
    310 	if truth(arg0) {
    311 		return arg0
    312 	}
    313 	for i := range args {
    314 		arg0 = args[i]
    315 		if truth(arg0) {
    316 			break
    317 		}
    318 	}
    319 	return arg0
    320 }
    321 
    322 // not returns the Boolean negation of its argument.
    323 func not(arg reflect.Value) bool {
    324 	return !truth(arg)
    325 }
    326 
    327 // Comparison.
    328 
// NOTE: eq and lt already compare signed and unsigned integers with each
// other; every other mix of kinds is rejected with errBadComparison.
    330 
// Errors returned by the comparison functions (eq, ne, lt, le, gt, ge):
//   - errBadComparisonType: an operand's kind is not one basicKind recognizes,
//     or the operator is not defined for that kind (lt on bool or complex).
//   - errBadComparison: the operands have different basic kinds and are not a
//     signed/unsigned integer pair.
//   - errNoComparison: eq was given nothing to compare its first argument to.
var (
	errBadComparisonType = errors.New("invalid type for comparison")
	errBadComparison     = errors.New("incompatible types for comparison")
	errNoComparison      = errors.New("missing argument for comparison")
)

// kind is a coarse classification of reflect.Kind used by the comparison
// functions; basicKind performs the mapping.
type kind int

// The basic kinds. invalidKind is the zero value and is what basicKind
// returns together with errBadComparisonType.
const (
	invalidKind kind = iota
	boolKind
	complexKind
	intKind
	floatKind
	stringKind
	uintKind
)
    348 
    349 func basicKind(v reflect.Value) (kind, error) {
    350 	switch v.Kind() {
    351 	case reflect.Bool:
    352 		return boolKind, nil
    353 	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
    354 		return intKind, nil
    355 	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
    356 		return uintKind, nil
    357 	case reflect.Float32, reflect.Float64:
    358 		return floatKind, nil
    359 	case reflect.Complex64, reflect.Complex128:
    360 		return complexKind, nil
    361 	case reflect.String:
    362 		return stringKind, nil
    363 	}
    364 	return invalidKind, errBadComparisonType
    365 }
    366 
    367 // eq evaluates the comparison a == b || a == c || ...
    368 func eq(arg1 reflect.Value, arg2 ...reflect.Value) (bool, error) {
    369 	v1 := indirectInterface(arg1)
    370 	k1, err := basicKind(v1)
    371 	if err != nil {
    372 		return false, err
    373 	}
    374 	if len(arg2) == 0 {
    375 		return false, errNoComparison
    376 	}
    377 	for _, arg := range arg2 {
    378 		v2 := indirectInterface(arg)
    379 		k2, err := basicKind(v2)
    380 		if err != nil {
    381 			return false, err
    382 		}
    383 		truth := false
    384 		if k1 != k2 {
    385 			// Special case: Can compare integer values regardless of type's sign.
    386 			switch {
    387 			case k1 == intKind && k2 == uintKind:
    388 				truth = v1.Int() >= 0 && uint64(v1.Int()) == v2.Uint()
    389 			case k1 == uintKind && k2 == intKind:
    390 				truth = v2.Int() >= 0 && v1.Uint() == uint64(v2.Int())
    391 			default:
    392 				return false, errBadComparison
    393 			}
    394 		} else {
    395 			switch k1 {
    396 			case boolKind:
    397 				truth = v1.Bool() == v2.Bool()
    398 			case complexKind:
    399 				truth = v1.Complex() == v2.Complex()
    400 			case floatKind:
    401 				truth = v1.Float() == v2.Float()
    402 			case intKind:
    403 				truth = v1.Int() == v2.Int()
    404 			case stringKind:
    405 				truth = v1.String() == v2.String()
    406 			case uintKind:
    407 				truth = v1.Uint() == v2.Uint()
    408 			default:
    409 				panic("invalid kind")
    410 			}
    411 		}
    412 		if truth {
    413 			return true, nil
    414 		}
    415 	}
    416 	return false, nil
    417 }
    418 
    419 // ne evaluates the comparison a != b.
    420 func ne(arg1, arg2 reflect.Value) (bool, error) {
    421 	// != is the inverse of ==.
    422 	equal, err := eq(arg1, arg2)
    423 	return !equal, err
    424 }
    425 
    426 // lt evaluates the comparison a < b.
    427 func lt(arg1, arg2 reflect.Value) (bool, error) {
    428 	v1 := indirectInterface(arg1)
    429 	k1, err := basicKind(v1)
    430 	if err != nil {
    431 		return false, err
    432 	}
    433 	v2 := indirectInterface(arg2)
    434 	k2, err := basicKind(v2)
    435 	if err != nil {
    436 		return false, err
    437 	}
    438 	truth := false
    439 	if k1 != k2 {
    440 		// Special case: Can compare integer values regardless of type's sign.
    441 		switch {
    442 		case k1 == intKind && k2 == uintKind:
    443 			truth = v1.Int() < 0 || uint64(v1.Int()) < v2.Uint()
    444 		case k1 == uintKind && k2 == intKind:
    445 			truth = v2.Int() >= 0 && v1.Uint() < uint64(v2.Int())
    446 		default:
    447 			return false, errBadComparison
    448 		}
    449 	} else {
    450 		switch k1 {
    451 		case boolKind, complexKind:
    452 			return false, errBadComparisonType
    453 		case floatKind:
    454 			truth = v1.Float() < v2.Float()
    455 		case intKind:
    456 			truth = v1.Int() < v2.Int()
    457 		case stringKind:
    458 			truth = v1.String() < v2.String()
    459 		case uintKind:
    460 			truth = v1.Uint() < v2.Uint()
    461 		default:
    462 			panic("invalid kind")
    463 		}
    464 	}
    465 	return truth, nil
    466 }
    467 
    468 // le evaluates the comparison <= b.
    469 func le(arg1, arg2 reflect.Value) (bool, error) {
    470 	// <= is < or ==.
    471 	lessThan, err := lt(arg1, arg2)
    472 	if lessThan || err != nil {
    473 		return lessThan, err
    474 	}
    475 	return eq(arg1, arg2)
    476 }
    477 
    478 // gt evaluates the comparison a > b.
    479 func gt(arg1, arg2 reflect.Value) (bool, error) {
    480 	// > is the inverse of <=.
    481 	lessOrEqual, err := le(arg1, arg2)
    482 	if err != nil {
    483 		return false, err
    484 	}
    485 	return !lessOrEqual, nil
    486 }
    487 
    488 // ge evaluates the comparison a >= b.
    489 func ge(arg1, arg2 reflect.Value) (bool, error) {
    490 	// >= is the inverse of <.
    491 	lessThan, err := lt(arg1, arg2)
    492 	if err != nil {
    493 		return false, err
    494 	}
    495 	return !lessThan, nil
    496 }
    497 
    498 // HTML escaping.
    499 
    500 var (
    501 	htmlQuot = []byte("&#34;") // shorter than "&quot;"
    502 	htmlApos = []byte("&#39;") // shorter than "&apos;" and apos was not in HTML until HTML5
    503 	htmlAmp  = []byte("&amp;")
    504 	htmlLt   = []byte("&lt;")
    505 	htmlGt   = []byte("&gt;")
    506 	htmlNull = []byte("\uFFFD")
    507 )
    508 
    509 // HTMLEscape writes to w the escaped HTML equivalent of the plain text data b.
    510 func HTMLEscape(w io.Writer, b []byte) {
    511 	last := 0
    512 	for i, c := range b {
    513 		var html []byte
    514 		switch c {
    515 		case '\000':
    516 			html = htmlNull
    517 		case '"':
    518 			html = htmlQuot
    519 		case '\'':
    520 			html = htmlApos
    521 		case '&':
    522 			html = htmlAmp
    523 		case '<':
    524 			html = htmlLt
    525 		case '>':
    526 			html = htmlGt
    527 		default:
    528 			continue
    529 		}
    530 		w.Write(b[last:i])
    531 		w.Write(html)
    532 		last = i + 1
    533 	}
    534 	w.Write(b[last:])
    535 }
    536 
    537 // HTMLEscapeString returns the escaped HTML equivalent of the plain text data s.
    538 func HTMLEscapeString(s string) string {
    539 	// Avoid allocation if we can.
    540 	if !strings.ContainsAny(s, "'\"&<>\000") {
    541 		return s
    542 	}
    543 	var b bytes.Buffer
    544 	HTMLEscape(&b, []byte(s))
    545 	return b.String()
    546 }
    547 
    548 // HTMLEscaper returns the escaped HTML equivalent of the textual
    549 // representation of its arguments.
    550 func HTMLEscaper(args ...interface{}) string {
    551 	return HTMLEscapeString(evalArgs(args))
    552 }
    553 
    554 // JavaScript escaping.
    555 
    556 var (
    557 	jsLowUni = []byte(`\u00`)
    558 	hex      = []byte("0123456789ABCDEF")
    559 
    560 	jsBackslash = []byte(`\\`)
    561 	jsApos      = []byte(`\'`)
    562 	jsQuot      = []byte(`\"`)
    563 	jsLt        = []byte(`\x3C`)
    564 	jsGt        = []byte(`\x3E`)
    565 )
    566 
    567 // JSEscape writes to w the escaped JavaScript equivalent of the plain text data b.
    568 func JSEscape(w io.Writer, b []byte) {
    569 	last := 0
    570 	for i := 0; i < len(b); i++ {
    571 		c := b[i]
    572 
    573 		if !jsIsSpecial(rune(c)) {
    574 			// fast path: nothing to do
    575 			continue
    576 		}
    577 		w.Write(b[last:i])
    578 
    579 		if c < utf8.RuneSelf {
    580 			// Quotes, slashes and angle brackets get quoted.
    581 			// Control characters get written as \u00XX.
    582 			switch c {
    583 			case '\\':
    584 				w.Write(jsBackslash)
    585 			case '\'':
    586 				w.Write(jsApos)
    587 			case '"':
    588 				w.Write(jsQuot)
    589 			case '<':
    590 				w.Write(jsLt)
    591 			case '>':
    592 				w.Write(jsGt)
    593 			default:
    594 				w.Write(jsLowUni)
    595 				t, b := c>>4, c&0x0f
    596 				w.Write(hex[t : t+1])
    597 				w.Write(hex[b : b+1])
    598 			}
    599 		} else {
    600 			// Unicode rune.
    601 			r, size := utf8.DecodeRune(b[i:])
    602 			if unicode.IsPrint(r) {
    603 				w.Write(b[i : i+size])
    604 			} else {
    605 				fmt.Fprintf(w, "\\u%04X", r)
    606 			}
    607 			i += size - 1
    608 		}
    609 		last = i + 1
    610 	}
    611 	w.Write(b[last:])
    612 }
    613 
    614 // JSEscapeString returns the escaped JavaScript equivalent of the plain text data s.
    615 func JSEscapeString(s string) string {
    616 	// Avoid allocation if we can.
    617 	if strings.IndexFunc(s, jsIsSpecial) < 0 {
    618 		return s
    619 	}
    620 	var b bytes.Buffer
    621 	JSEscape(&b, []byte(s))
    622 	return b.String()
    623 }
    624 
    625 func jsIsSpecial(r rune) bool {
    626 	switch r {
    627 	case '\\', '\'', '"', '<', '>':
    628 		return true
    629 	}
    630 	return r < ' ' || utf8.RuneSelf <= r
    631 }
    632 
    633 // JSEscaper returns the escaped JavaScript equivalent of the textual
    634 // representation of its arguments.
    635 func JSEscaper(args ...interface{}) string {
    636 	return JSEscapeString(evalArgs(args))
    637 }
    638 
    639 // URLQueryEscaper returns the escaped value of the textual representation of
    640 // its arguments in a form suitable for embedding in a URL query.
    641 func URLQueryEscaper(args ...interface{}) string {
    642 	return url.QueryEscape(evalArgs(args))
    643 }
    644 
    645 // evalArgs formats the list of arguments into a string. It is therefore equivalent to
    646 //	fmt.Sprint(args...)
    647 // except that each argument is indirected (if a pointer), as required,
    648 // using the same rules as the default string evaluation during template
    649 // execution.
    650 func evalArgs(args []interface{}) string {
    651 	ok := false
    652 	var s string
    653 	// Fast path for simple common case.
    654 	if len(args) == 1 {
    655 		s, ok = args[0].(string)
    656 	}
    657 	if !ok {
    658 		for i, arg := range args {
    659 			a, ok := printableValue(reflect.ValueOf(arg))
    660 			if ok {
    661 				args[i] = a
    662 			} // else let fmt do its thing
    663 		}
    664 		s = fmt.Sprint(args...)
    665 	}
    666 	return s
    667 }
    668