Home | History | Annotate | Download | only in template
      1 // Copyright 2011 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package template
      6 
      7 import (
      8 	"bytes"
      9 	"errors"
     10 	"fmt"
     11 	"io"
     12 	"net/url"
     13 	"reflect"
     14 	"strings"
     15 	"unicode"
     16 	"unicode/utf8"
     17 )
     18 
     19 // FuncMap is the type of the map defining the mapping from names to functions.
     20 // Each function must have either a single return value, or two return values of
     21 // which the second has type error. In that case, if the second (error)
     22 // return value evaluates to non-nil during execution, execution terminates and
     23 // Execute returns that error.
     24 //
     25 // When template execution invokes a function with an argument list, that list
     26 // must be assignable to the function's parameter types. Functions meant to
     27 // apply to arguments of arbitrary type can use parameters of type interface{} or
     28 // of type reflect.Value. Similarly, functions meant to return a result of arbitrary
     29 // type can return interface{} or reflect.Value.
     30 type FuncMap map[string]interface{}
     31 
     32 var builtins = FuncMap{
     33 	"and":      and,
     34 	"call":     call,
     35 	"html":     HTMLEscaper,
     36 	"index":    index,
     37 	"js":       JSEscaper,
     38 	"len":      length,
     39 	"not":      not,
     40 	"or":       or,
     41 	"print":    fmt.Sprint,
     42 	"printf":   fmt.Sprintf,
     43 	"println":  fmt.Sprintln,
     44 	"urlquery": URLQueryEscaper,
     45 
     46 	// Comparisons
     47 	"eq": eq, // ==
     48 	"ge": ge, // >=
     49 	"gt": gt, // >
     50 	"le": le, // <=
     51 	"lt": lt, // <
     52 	"ne": ne, // !=
     53 }
     54 
     55 var builtinFuncs = createValueFuncs(builtins)
     56 
     57 // createValueFuncs turns a FuncMap into a map[string]reflect.Value
     58 func createValueFuncs(funcMap FuncMap) map[string]reflect.Value {
     59 	m := make(map[string]reflect.Value)
     60 	addValueFuncs(m, funcMap)
     61 	return m
     62 }
     63 
     64 // addValueFuncs adds to values the functions in funcs, converting them to reflect.Values.
     65 func addValueFuncs(out map[string]reflect.Value, in FuncMap) {
     66 	for name, fn := range in {
     67 		if !goodName(name) {
     68 			panic(fmt.Errorf("function name %s is not a valid identifier", name))
     69 		}
     70 		v := reflect.ValueOf(fn)
     71 		if v.Kind() != reflect.Func {
     72 			panic("value for " + name + " not a function")
     73 		}
     74 		if !goodFunc(v.Type()) {
     75 			panic(fmt.Errorf("can't install method/function %q with %d results", name, v.Type().NumOut()))
     76 		}
     77 		out[name] = v
     78 	}
     79 }
     80 
     81 // addFuncs adds to values the functions in funcs. It does no checking of the input -
     82 // call addValueFuncs first.
     83 func addFuncs(out, in FuncMap) {
     84 	for name, fn := range in {
     85 		out[name] = fn
     86 	}
     87 }
     88 
     89 // goodFunc reports whether the function or method has the right result signature.
     90 func goodFunc(typ reflect.Type) bool {
     91 	// We allow functions with 1 result or 2 results where the second is an error.
     92 	switch {
     93 	case typ.NumOut() == 1:
     94 		return true
     95 	case typ.NumOut() == 2 && typ.Out(1) == errorType:
     96 		return true
     97 	}
     98 	return false
     99 }
    100 
    101 // goodName reports whether the function name is a valid identifier.
    102 func goodName(name string) bool {
    103 	if name == "" {
    104 		return false
    105 	}
    106 	for i, r := range name {
    107 		switch {
    108 		case r == '_':
    109 		case i == 0 && !unicode.IsLetter(r):
    110 			return false
    111 		case !unicode.IsLetter(r) && !unicode.IsDigit(r):
    112 			return false
    113 		}
    114 	}
    115 	return true
    116 }
    117 
    118 // findFunction looks for a function in the template, and global map.
    119 func findFunction(name string, tmpl *Template) (reflect.Value, bool) {
    120 	if tmpl != nil && tmpl.common != nil {
    121 		tmpl.muFuncs.RLock()
    122 		defer tmpl.muFuncs.RUnlock()
    123 		if fn := tmpl.execFuncs[name]; fn.IsValid() {
    124 			return fn, true
    125 		}
    126 	}
    127 	if fn := builtinFuncs[name]; fn.IsValid() {
    128 		return fn, true
    129 	}
    130 	return reflect.Value{}, false
    131 }
    132 
    133 // prepareArg checks if value can be used as an argument of type argType, and
    134 // converts an invalid value to appropriate zero if possible.
    135 func prepareArg(value reflect.Value, argType reflect.Type) (reflect.Value, error) {
    136 	if !value.IsValid() {
    137 		if !canBeNil(argType) {
    138 			return reflect.Value{}, fmt.Errorf("value is nil; should be of type %s", argType)
    139 		}
    140 		value = reflect.Zero(argType)
    141 	}
    142 	if !value.Type().AssignableTo(argType) {
    143 		return reflect.Value{}, fmt.Errorf("value has type %s; should be %s", value.Type(), argType)
    144 	}
    145 	return value, nil
    146 }
    147 
    148 // Indexing.
    149 
    150 // index returns the result of indexing its first argument by the following
    151 // arguments. Thus "index x 1 2 3" is, in Go syntax, x[1][2][3]. Each
    152 // indexed item must be a map, slice, or array.
    153 func index(item reflect.Value, indices ...reflect.Value) (reflect.Value, error) {
    154 	v := indirectInterface(item)
    155 	if !v.IsValid() {
    156 		return reflect.Value{}, fmt.Errorf("index of untyped nil")
    157 	}
    158 	for _, i := range indices {
    159 		index := indirectInterface(i)
    160 		var isNil bool
    161 		if v, isNil = indirect(v); isNil {
    162 			return reflect.Value{}, fmt.Errorf("index of nil pointer")
    163 		}
    164 		switch v.Kind() {
    165 		case reflect.Array, reflect.Slice, reflect.String:
    166 			var x int64
    167 			switch index.Kind() {
    168 			case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
    169 				x = index.Int()
    170 			case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
    171 				x = int64(index.Uint())
    172 			case reflect.Invalid:
    173 				return reflect.Value{}, fmt.Errorf("cannot index slice/array with nil")
    174 			default:
    175 				return reflect.Value{}, fmt.Errorf("cannot index slice/array with type %s", index.Type())
    176 			}
    177 			if x < 0 || x >= int64(v.Len()) {
    178 				return reflect.Value{}, fmt.Errorf("index out of range: %d", x)
    179 			}
    180 			v = v.Index(int(x))
    181 		case reflect.Map:
    182 			index, err := prepareArg(index, v.Type().Key())
    183 			if err != nil {
    184 				return reflect.Value{}, err
    185 			}
    186 			if x := v.MapIndex(index); x.IsValid() {
    187 				v = x
    188 			} else {
    189 				v = reflect.Zero(v.Type().Elem())
    190 			}
    191 		case reflect.Invalid:
    192 			// the loop holds invariant: v.IsValid()
    193 			panic("unreachable")
    194 		default:
    195 			return reflect.Value{}, fmt.Errorf("can't index item of type %s", v.Type())
    196 		}
    197 	}
    198 	return v, nil
    199 }
    200 
    201 // Length
    202 
    203 // length returns the length of the item, with an error if it has no defined length.
    204 func length(item interface{}) (int, error) {
    205 	v := reflect.ValueOf(item)
    206 	if !v.IsValid() {
    207 		return 0, fmt.Errorf("len of untyped nil")
    208 	}
    209 	v, isNil := indirect(v)
    210 	if isNil {
    211 		return 0, fmt.Errorf("len of nil pointer")
    212 	}
    213 	switch v.Kind() {
    214 	case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String:
    215 		return v.Len(), nil
    216 	}
    217 	return 0, fmt.Errorf("len of type %s", v.Type())
    218 }
    219 
    220 // Function invocation
    221 
    222 // call returns the result of evaluating the first argument as a function.
    223 // The function must return 1 result, or 2 results, the second of which is an error.
    224 func call(fn reflect.Value, args ...reflect.Value) (reflect.Value, error) {
    225 	v := indirectInterface(fn)
    226 	if !v.IsValid() {
    227 		return reflect.Value{}, fmt.Errorf("call of nil")
    228 	}
    229 	typ := v.Type()
    230 	if typ.Kind() != reflect.Func {
    231 		return reflect.Value{}, fmt.Errorf("non-function of type %s", typ)
    232 	}
    233 	if !goodFunc(typ) {
    234 		return reflect.Value{}, fmt.Errorf("function called with %d args; should be 1 or 2", typ.NumOut())
    235 	}
    236 	numIn := typ.NumIn()
    237 	var dddType reflect.Type
    238 	if typ.IsVariadic() {
    239 		if len(args) < numIn-1 {
    240 			return reflect.Value{}, fmt.Errorf("wrong number of args: got %d want at least %d", len(args), numIn-1)
    241 		}
    242 		dddType = typ.In(numIn - 1).Elem()
    243 	} else {
    244 		if len(args) != numIn {
    245 			return reflect.Value{}, fmt.Errorf("wrong number of args: got %d want %d", len(args), numIn)
    246 		}
    247 	}
    248 	argv := make([]reflect.Value, len(args))
    249 	for i, arg := range args {
    250 		value := indirectInterface(arg)
    251 		// Compute the expected type. Clumsy because of variadics.
    252 		var argType reflect.Type
    253 		if !typ.IsVariadic() || i < numIn-1 {
    254 			argType = typ.In(i)
    255 		} else {
    256 			argType = dddType
    257 		}
    258 
    259 		var err error
    260 		if argv[i], err = prepareArg(value, argType); err != nil {
    261 			return reflect.Value{}, fmt.Errorf("arg %d: %s", i, err)
    262 		}
    263 	}
    264 	result := v.Call(argv)
    265 	if len(result) == 2 && !result[1].IsNil() {
    266 		return result[0], result[1].Interface().(error)
    267 	}
    268 	return result[0], nil
    269 }
    270 
    271 // Boolean logic.
    272 
    273 func truth(arg reflect.Value) bool {
    274 	t, _ := isTrue(indirectInterface(arg))
    275 	return t
    276 }
    277 
    278 // and computes the Boolean AND of its arguments, returning
    279 // the first false argument it encounters, or the last argument.
    280 func and(arg0 reflect.Value, args ...reflect.Value) reflect.Value {
    281 	if !truth(arg0) {
    282 		return arg0
    283 	}
    284 	for i := range args {
    285 		arg0 = args[i]
    286 		if !truth(arg0) {
    287 			break
    288 		}
    289 	}
    290 	return arg0
    291 }
    292 
    293 // or computes the Boolean OR of its arguments, returning
    294 // the first true argument it encounters, or the last argument.
    295 func or(arg0 reflect.Value, args ...reflect.Value) reflect.Value {
    296 	if truth(arg0) {
    297 		return arg0
    298 	}
    299 	for i := range args {
    300 		arg0 = args[i]
    301 		if truth(arg0) {
    302 			break
    303 		}
    304 	}
    305 	return arg0
    306 }
    307 
    308 // not returns the Boolean negation of its argument.
    309 func not(arg reflect.Value) bool {
    310 	return !truth(arg)
    311 }
    312 
    313 // Comparison.
    314 
// Signed and unsigned integers may be compared with each other: eq and lt
// handle that mixed case explicitly by comparing values rather than types.
    316 
    317 var (
    318 	errBadComparisonType = errors.New("invalid type for comparison")
    319 	errBadComparison     = errors.New("incompatible types for comparison")
    320 	errNoComparison      = errors.New("missing argument for comparison")
    321 )
    322 
    323 type kind int
    324 
    325 const (
    326 	invalidKind kind = iota
    327 	boolKind
    328 	complexKind
    329 	intKind
    330 	floatKind
    331 	stringKind
    332 	uintKind
    333 )
    334 
    335 func basicKind(v reflect.Value) (kind, error) {
    336 	switch v.Kind() {
    337 	case reflect.Bool:
    338 		return boolKind, nil
    339 	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
    340 		return intKind, nil
    341 	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
    342 		return uintKind, nil
    343 	case reflect.Float32, reflect.Float64:
    344 		return floatKind, nil
    345 	case reflect.Complex64, reflect.Complex128:
    346 		return complexKind, nil
    347 	case reflect.String:
    348 		return stringKind, nil
    349 	}
    350 	return invalidKind, errBadComparisonType
    351 }
    352 
    353 // eq evaluates the comparison a == b || a == c || ...
    354 func eq(arg1 reflect.Value, arg2 ...reflect.Value) (bool, error) {
    355 	v1 := indirectInterface(arg1)
    356 	k1, err := basicKind(v1)
    357 	if err != nil {
    358 		return false, err
    359 	}
    360 	if len(arg2) == 0 {
    361 		return false, errNoComparison
    362 	}
    363 	for _, arg := range arg2 {
    364 		v2 := indirectInterface(arg)
    365 		k2, err := basicKind(v2)
    366 		if err != nil {
    367 			return false, err
    368 		}
    369 		truth := false
    370 		if k1 != k2 {
    371 			// Special case: Can compare integer values regardless of type's sign.
    372 			switch {
    373 			case k1 == intKind && k2 == uintKind:
    374 				truth = v1.Int() >= 0 && uint64(v1.Int()) == v2.Uint()
    375 			case k1 == uintKind && k2 == intKind:
    376 				truth = v2.Int() >= 0 && v1.Uint() == uint64(v2.Int())
    377 			default:
    378 				return false, errBadComparison
    379 			}
    380 		} else {
    381 			switch k1 {
    382 			case boolKind:
    383 				truth = v1.Bool() == v2.Bool()
    384 			case complexKind:
    385 				truth = v1.Complex() == v2.Complex()
    386 			case floatKind:
    387 				truth = v1.Float() == v2.Float()
    388 			case intKind:
    389 				truth = v1.Int() == v2.Int()
    390 			case stringKind:
    391 				truth = v1.String() == v2.String()
    392 			case uintKind:
    393 				truth = v1.Uint() == v2.Uint()
    394 			default:
    395 				panic("invalid kind")
    396 			}
    397 		}
    398 		if truth {
    399 			return true, nil
    400 		}
    401 	}
    402 	return false, nil
    403 }
    404 
    405 // ne evaluates the comparison a != b.
    406 func ne(arg1, arg2 reflect.Value) (bool, error) {
    407 	// != is the inverse of ==.
    408 	equal, err := eq(arg1, arg2)
    409 	return !equal, err
    410 }
    411 
    412 // lt evaluates the comparison a < b.
    413 func lt(arg1, arg2 reflect.Value) (bool, error) {
    414 	v1 := indirectInterface(arg1)
    415 	k1, err := basicKind(v1)
    416 	if err != nil {
    417 		return false, err
    418 	}
    419 	v2 := indirectInterface(arg2)
    420 	k2, err := basicKind(v2)
    421 	if err != nil {
    422 		return false, err
    423 	}
    424 	truth := false
    425 	if k1 != k2 {
    426 		// Special case: Can compare integer values regardless of type's sign.
    427 		switch {
    428 		case k1 == intKind && k2 == uintKind:
    429 			truth = v1.Int() < 0 || uint64(v1.Int()) < v2.Uint()
    430 		case k1 == uintKind && k2 == intKind:
    431 			truth = v2.Int() >= 0 && v1.Uint() < uint64(v2.Int())
    432 		default:
    433 			return false, errBadComparison
    434 		}
    435 	} else {
    436 		switch k1 {
    437 		case boolKind, complexKind:
    438 			return false, errBadComparisonType
    439 		case floatKind:
    440 			truth = v1.Float() < v2.Float()
    441 		case intKind:
    442 			truth = v1.Int() < v2.Int()
    443 		case stringKind:
    444 			truth = v1.String() < v2.String()
    445 		case uintKind:
    446 			truth = v1.Uint() < v2.Uint()
    447 		default:
    448 			panic("invalid kind")
    449 		}
    450 	}
    451 	return truth, nil
    452 }
    453 
    454 // le evaluates the comparison <= b.
    455 func le(arg1, arg2 reflect.Value) (bool, error) {
    456 	// <= is < or ==.
    457 	lessThan, err := lt(arg1, arg2)
    458 	if lessThan || err != nil {
    459 		return lessThan, err
    460 	}
    461 	return eq(arg1, arg2)
    462 }
    463 
    464 // gt evaluates the comparison a > b.
    465 func gt(arg1, arg2 reflect.Value) (bool, error) {
    466 	// > is the inverse of <=.
    467 	lessOrEqual, err := le(arg1, arg2)
    468 	if err != nil {
    469 		return false, err
    470 	}
    471 	return !lessOrEqual, nil
    472 }
    473 
    474 // ge evaluates the comparison a >= b.
    475 func ge(arg1, arg2 reflect.Value) (bool, error) {
    476 	// >= is the inverse of <.
    477 	lessThan, err := lt(arg1, arg2)
    478 	if err != nil {
    479 		return false, err
    480 	}
    481 	return !lessThan, nil
    482 }
    483 
    484 // HTML escaping.
    485 
    486 var (
    487 	htmlQuot = []byte("&#34;") // shorter than "&quot;"
    488 	htmlApos = []byte("&#39;") // shorter than "&apos;" and apos was not in HTML until HTML5
    489 	htmlAmp  = []byte("&amp;")
    490 	htmlLt   = []byte("&lt;")
    491 	htmlGt   = []byte("&gt;")
    492 )
    493 
    494 // HTMLEscape writes to w the escaped HTML equivalent of the plain text data b.
    495 func HTMLEscape(w io.Writer, b []byte) {
    496 	last := 0
    497 	for i, c := range b {
    498 		var html []byte
    499 		switch c {
    500 		case '"':
    501 			html = htmlQuot
    502 		case '\'':
    503 			html = htmlApos
    504 		case '&':
    505 			html = htmlAmp
    506 		case '<':
    507 			html = htmlLt
    508 		case '>':
    509 			html = htmlGt
    510 		default:
    511 			continue
    512 		}
    513 		w.Write(b[last:i])
    514 		w.Write(html)
    515 		last = i + 1
    516 	}
    517 	w.Write(b[last:])
    518 }
    519 
    520 // HTMLEscapeString returns the escaped HTML equivalent of the plain text data s.
    521 func HTMLEscapeString(s string) string {
    522 	// Avoid allocation if we can.
    523 	if !strings.ContainsAny(s, `'"&<>`) {
    524 		return s
    525 	}
    526 	var b bytes.Buffer
    527 	HTMLEscape(&b, []byte(s))
    528 	return b.String()
    529 }
    530 
    531 // HTMLEscaper returns the escaped HTML equivalent of the textual
    532 // representation of its arguments.
    533 func HTMLEscaper(args ...interface{}) string {
    534 	return HTMLEscapeString(evalArgs(args))
    535 }
    536 
    537 // JavaScript escaping.
    538 
    539 var (
    540 	jsLowUni = []byte(`\u00`)
    541 	hex      = []byte("0123456789ABCDEF")
    542 
    543 	jsBackslash = []byte(`\\`)
    544 	jsApos      = []byte(`\'`)
    545 	jsQuot      = []byte(`\"`)
    546 	jsLt        = []byte(`\x3C`)
    547 	jsGt        = []byte(`\x3E`)
    548 )
    549 
    550 // JSEscape writes to w the escaped JavaScript equivalent of the plain text data b.
    551 func JSEscape(w io.Writer, b []byte) {
    552 	last := 0
    553 	for i := 0; i < len(b); i++ {
    554 		c := b[i]
    555 
    556 		if !jsIsSpecial(rune(c)) {
    557 			// fast path: nothing to do
    558 			continue
    559 		}
    560 		w.Write(b[last:i])
    561 
    562 		if c < utf8.RuneSelf {
    563 			// Quotes, slashes and angle brackets get quoted.
    564 			// Control characters get written as \u00XX.
    565 			switch c {
    566 			case '\\':
    567 				w.Write(jsBackslash)
    568 			case '\'':
    569 				w.Write(jsApos)
    570 			case '"':
    571 				w.Write(jsQuot)
    572 			case '<':
    573 				w.Write(jsLt)
    574 			case '>':
    575 				w.Write(jsGt)
    576 			default:
    577 				w.Write(jsLowUni)
    578 				t, b := c>>4, c&0x0f
    579 				w.Write(hex[t : t+1])
    580 				w.Write(hex[b : b+1])
    581 			}
    582 		} else {
    583 			// Unicode rune.
    584 			r, size := utf8.DecodeRune(b[i:])
    585 			if unicode.IsPrint(r) {
    586 				w.Write(b[i : i+size])
    587 			} else {
    588 				fmt.Fprintf(w, "\\u%04X", r)
    589 			}
    590 			i += size - 1
    591 		}
    592 		last = i + 1
    593 	}
    594 	w.Write(b[last:])
    595 }
    596 
    597 // JSEscapeString returns the escaped JavaScript equivalent of the plain text data s.
    598 func JSEscapeString(s string) string {
    599 	// Avoid allocation if we can.
    600 	if strings.IndexFunc(s, jsIsSpecial) < 0 {
    601 		return s
    602 	}
    603 	var b bytes.Buffer
    604 	JSEscape(&b, []byte(s))
    605 	return b.String()
    606 }
    607 
    608 func jsIsSpecial(r rune) bool {
    609 	switch r {
    610 	case '\\', '\'', '"', '<', '>':
    611 		return true
    612 	}
    613 	return r < ' ' || utf8.RuneSelf <= r
    614 }
    615 
    616 // JSEscaper returns the escaped JavaScript equivalent of the textual
    617 // representation of its arguments.
    618 func JSEscaper(args ...interface{}) string {
    619 	return JSEscapeString(evalArgs(args))
    620 }
    621 
    622 // URLQueryEscaper returns the escaped value of the textual representation of
    623 // its arguments in a form suitable for embedding in a URL query.
    624 func URLQueryEscaper(args ...interface{}) string {
    625 	return url.QueryEscape(evalArgs(args))
    626 }
    627 
    628 // evalArgs formats the list of arguments into a string. It is therefore equivalent to
    629 //	fmt.Sprint(args...)
    630 // except that each argument is indirected (if a pointer), as required,
    631 // using the same rules as the default string evaluation during template
    632 // execution.
    633 func evalArgs(args []interface{}) string {
    634 	ok := false
    635 	var s string
    636 	// Fast path for simple common case.
    637 	if len(args) == 1 {
    638 		s, ok = args[0].(string)
    639 	}
    640 	if !ok {
    641 		for i, arg := range args {
    642 			a, ok := printableValue(reflect.ValueOf(arg))
    643 			if ok {
    644 				args[i] = a
    645 			} // else let fmt do its thing
    646 		}
    647 		s = fmt.Sprint(args...)
    648 	}
    649 	return s
    650 }
    651