Home | History | Annotate | Download | only in template
      1 // Copyright 2011 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package template
      6 
      7 import (
      8 	"bytes"
      9 	"errors"
     10 	"fmt"
     11 	"io"
     12 	"net/url"
     13 	"reflect"
     14 	"strings"
     15 	"unicode"
     16 	"unicode/utf8"
     17 )
     18 
// FuncMap is the type of the map defining the mapping from names to functions.
// Each function must have either a single return value, or two return values of
// which the second has type error. In the two-value case, if the second (error)
// return value evaluates to non-nil during execution, execution terminates and
// Execute returns that error.
type FuncMap map[string]interface{}
     25 
// builtins maps the name of each predefined template function to the Go
// function that implements it.
var builtins = FuncMap{
	"and":      and,
	"call":     call,
	"html":     HTMLEscaper,
	"index":    index,
	"js":       JSEscaper,
	"len":      length,
	"not":      not,
	"or":       or,
	"print":    fmt.Sprint,
	"printf":   fmt.Sprintf,
	"println":  fmt.Sprintln,
	"urlquery": URLQueryEscaper,

	// Comparisons
	"eq": eq, // ==
	"ge": ge, // >=
	"gt": gt, // >
	"le": le, // <=
	"lt": lt, // <
	"ne": ne, // !=
}
     48 
     49 var builtinFuncs = createValueFuncs(builtins)
     50 
     51 // createValueFuncs turns a FuncMap into a map[string]reflect.Value
     52 func createValueFuncs(funcMap FuncMap) map[string]reflect.Value {
     53 	m := make(map[string]reflect.Value)
     54 	addValueFuncs(m, funcMap)
     55 	return m
     56 }
     57 
     58 // addValueFuncs adds to values the functions in funcs, converting them to reflect.Values.
     59 func addValueFuncs(out map[string]reflect.Value, in FuncMap) {
     60 	for name, fn := range in {
     61 		v := reflect.ValueOf(fn)
     62 		if v.Kind() != reflect.Func {
     63 			panic("value for " + name + " not a function")
     64 		}
     65 		if !goodFunc(v.Type()) {
     66 			panic(fmt.Errorf("can't install method/function %q with %d results", name, v.Type().NumOut()))
     67 		}
     68 		out[name] = v
     69 	}
     70 }
     71 
     72 // addFuncs adds to values the functions in funcs. It does no checking of the input -
     73 // call addValueFuncs first.
     74 func addFuncs(out, in FuncMap) {
     75 	for name, fn := range in {
     76 		out[name] = fn
     77 	}
     78 }
     79 
     80 // goodFunc checks that the function or method has the right result signature.
     81 func goodFunc(typ reflect.Type) bool {
     82 	// We allow functions with 1 result or 2 results where the second is an error.
     83 	switch {
     84 	case typ.NumOut() == 1:
     85 		return true
     86 	case typ.NumOut() == 2 && typ.Out(1) == errorType:
     87 		return true
     88 	}
     89 	return false
     90 }
     91 
     92 // findFunction looks for a function in the template, and global map.
     93 func findFunction(name string, tmpl *Template) (reflect.Value, bool) {
     94 	if tmpl != nil && tmpl.common != nil {
     95 		tmpl.muFuncs.RLock()
     96 		defer tmpl.muFuncs.RUnlock()
     97 		if fn := tmpl.execFuncs[name]; fn.IsValid() {
     98 			return fn, true
     99 		}
    100 	}
    101 	if fn := builtinFuncs[name]; fn.IsValid() {
    102 		return fn, true
    103 	}
    104 	return reflect.Value{}, false
    105 }
    106 
    107 // Indexing.
    108 
    109 // index returns the result of indexing its first argument by the following
    110 // arguments.  Thus "index x 1 2 3" is, in Go syntax, x[1][2][3]. Each
    111 // indexed item must be a map, slice, or array.
    112 func index(item interface{}, indices ...interface{}) (interface{}, error) {
    113 	v := reflect.ValueOf(item)
    114 	for _, i := range indices {
    115 		index := reflect.ValueOf(i)
    116 		var isNil bool
    117 		if v, isNil = indirect(v); isNil {
    118 			return nil, fmt.Errorf("index of nil pointer")
    119 		}
    120 		switch v.Kind() {
    121 		case reflect.Array, reflect.Slice, reflect.String:
    122 			var x int64
    123 			switch index.Kind() {
    124 			case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
    125 				x = index.Int()
    126 			case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
    127 				x = int64(index.Uint())
    128 			default:
    129 				return nil, fmt.Errorf("cannot index slice/array with type %s", index.Type())
    130 			}
    131 			if x < 0 || x >= int64(v.Len()) {
    132 				return nil, fmt.Errorf("index out of range: %d", x)
    133 			}
    134 			v = v.Index(int(x))
    135 		case reflect.Map:
    136 			if !index.IsValid() {
    137 				index = reflect.Zero(v.Type().Key())
    138 			}
    139 			if !index.Type().AssignableTo(v.Type().Key()) {
    140 				return nil, fmt.Errorf("%s is not index type for %s", index.Type(), v.Type())
    141 			}
    142 			if x := v.MapIndex(index); x.IsValid() {
    143 				v = x
    144 			} else {
    145 				v = reflect.Zero(v.Type().Elem())
    146 			}
    147 		default:
    148 			return nil, fmt.Errorf("can't index item of type %s", v.Type())
    149 		}
    150 	}
    151 	return v.Interface(), nil
    152 }
    153 
    154 // Length
    155 
    156 // length returns the length of the item, with an error if it has no defined length.
    157 func length(item interface{}) (int, error) {
    158 	v, isNil := indirect(reflect.ValueOf(item))
    159 	if isNil {
    160 		return 0, fmt.Errorf("len of nil pointer")
    161 	}
    162 	switch v.Kind() {
    163 	case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String:
    164 		return v.Len(), nil
    165 	}
    166 	return 0, fmt.Errorf("len of type %s", v.Type())
    167 }
    168 
    169 // Function invocation
    170 
    171 // call returns the result of evaluating the first argument as a function.
    172 // The function must return 1 result, or 2 results, the second of which is an error.
    173 func call(fn interface{}, args ...interface{}) (interface{}, error) {
    174 	v := reflect.ValueOf(fn)
    175 	typ := v.Type()
    176 	if typ.Kind() != reflect.Func {
    177 		return nil, fmt.Errorf("non-function of type %s", typ)
    178 	}
    179 	if !goodFunc(typ) {
    180 		return nil, fmt.Errorf("function called with %d args; should be 1 or 2", typ.NumOut())
    181 	}
    182 	numIn := typ.NumIn()
    183 	var dddType reflect.Type
    184 	if typ.IsVariadic() {
    185 		if len(args) < numIn-1 {
    186 			return nil, fmt.Errorf("wrong number of args: got %d want at least %d", len(args), numIn-1)
    187 		}
    188 		dddType = typ.In(numIn - 1).Elem()
    189 	} else {
    190 		if len(args) != numIn {
    191 			return nil, fmt.Errorf("wrong number of args: got %d want %d", len(args), numIn)
    192 		}
    193 	}
    194 	argv := make([]reflect.Value, len(args))
    195 	for i, arg := range args {
    196 		value := reflect.ValueOf(arg)
    197 		// Compute the expected type. Clumsy because of variadics.
    198 		var argType reflect.Type
    199 		if !typ.IsVariadic() || i < numIn-1 {
    200 			argType = typ.In(i)
    201 		} else {
    202 			argType = dddType
    203 		}
    204 		if !value.IsValid() && canBeNil(argType) {
    205 			value = reflect.Zero(argType)
    206 		}
    207 		if !value.Type().AssignableTo(argType) {
    208 			return nil, fmt.Errorf("arg %d has type %s; should be %s", i, value.Type(), argType)
    209 		}
    210 		argv[i] = value
    211 	}
    212 	result := v.Call(argv)
    213 	if len(result) == 2 && !result[1].IsNil() {
    214 		return result[0].Interface(), result[1].Interface().(error)
    215 	}
    216 	return result[0].Interface(), nil
    217 }
    218 
    219 // Boolean logic.
    220 
    221 func truth(a interface{}) bool {
    222 	t, _ := isTrue(reflect.ValueOf(a))
    223 	return t
    224 }
    225 
    226 // and computes the Boolean AND of its arguments, returning
    227 // the first false argument it encounters, or the last argument.
    228 func and(arg0 interface{}, args ...interface{}) interface{} {
    229 	if !truth(arg0) {
    230 		return arg0
    231 	}
    232 	for i := range args {
    233 		arg0 = args[i]
    234 		if !truth(arg0) {
    235 			break
    236 		}
    237 	}
    238 	return arg0
    239 }
    240 
    241 // or computes the Boolean OR of its arguments, returning
    242 // the first true argument it encounters, or the last argument.
    243 func or(arg0 interface{}, args ...interface{}) interface{} {
    244 	if truth(arg0) {
    245 		return arg0
    246 	}
    247 	for i := range args {
    248 		arg0 = args[i]
    249 		if truth(arg0) {
    250 			break
    251 		}
    252 	}
    253 	return arg0
    254 }
    255 
    256 // not returns the Boolean negation of its argument.
    257 func not(arg interface{}) (truth bool) {
    258 	truth, _ = isTrue(reflect.ValueOf(arg))
    259 	return !truth
    260 }
    261 
    262 // Comparison.
    263 
// Signed and unsigned integers may be compared with each other (see eq and
// lt); any other mix of basic kinds is rejected with errBadComparison.
    265 
    266 var (
    267 	errBadComparisonType = errors.New("invalid type for comparison")
    268 	errBadComparison     = errors.New("incompatible types for comparison")
    269 	errNoComparison      = errors.New("missing argument for comparison")
    270 )
    271 
    272 type kind int
    273 
    274 const (
    275 	invalidKind kind = iota
    276 	boolKind
    277 	complexKind
    278 	intKind
    279 	floatKind
    280 	integerKind
    281 	stringKind
    282 	uintKind
    283 )
    284 
    285 func basicKind(v reflect.Value) (kind, error) {
    286 	switch v.Kind() {
    287 	case reflect.Bool:
    288 		return boolKind, nil
    289 	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
    290 		return intKind, nil
    291 	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
    292 		return uintKind, nil
    293 	case reflect.Float32, reflect.Float64:
    294 		return floatKind, nil
    295 	case reflect.Complex64, reflect.Complex128:
    296 		return complexKind, nil
    297 	case reflect.String:
    298 		return stringKind, nil
    299 	}
    300 	return invalidKind, errBadComparisonType
    301 }
    302 
    303 // eq evaluates the comparison a == b || a == c || ...
    304 func eq(arg1 interface{}, arg2 ...interface{}) (bool, error) {
    305 	v1 := reflect.ValueOf(arg1)
    306 	k1, err := basicKind(v1)
    307 	if err != nil {
    308 		return false, err
    309 	}
    310 	if len(arg2) == 0 {
    311 		return false, errNoComparison
    312 	}
    313 	for _, arg := range arg2 {
    314 		v2 := reflect.ValueOf(arg)
    315 		k2, err := basicKind(v2)
    316 		if err != nil {
    317 			return false, err
    318 		}
    319 		truth := false
    320 		if k1 != k2 {
    321 			// Special case: Can compare integer values regardless of type's sign.
    322 			switch {
    323 			case k1 == intKind && k2 == uintKind:
    324 				truth = v1.Int() >= 0 && uint64(v1.Int()) == v2.Uint()
    325 			case k1 == uintKind && k2 == intKind:
    326 				truth = v2.Int() >= 0 && v1.Uint() == uint64(v2.Int())
    327 			default:
    328 				return false, errBadComparison
    329 			}
    330 		} else {
    331 			switch k1 {
    332 			case boolKind:
    333 				truth = v1.Bool() == v2.Bool()
    334 			case complexKind:
    335 				truth = v1.Complex() == v2.Complex()
    336 			case floatKind:
    337 				truth = v1.Float() == v2.Float()
    338 			case intKind:
    339 				truth = v1.Int() == v2.Int()
    340 			case stringKind:
    341 				truth = v1.String() == v2.String()
    342 			case uintKind:
    343 				truth = v1.Uint() == v2.Uint()
    344 			default:
    345 				panic("invalid kind")
    346 			}
    347 		}
    348 		if truth {
    349 			return true, nil
    350 		}
    351 	}
    352 	return false, nil
    353 }
    354 
    355 // ne evaluates the comparison a != b.
    356 func ne(arg1, arg2 interface{}) (bool, error) {
    357 	// != is the inverse of ==.
    358 	equal, err := eq(arg1, arg2)
    359 	return !equal, err
    360 }
    361 
    362 // lt evaluates the comparison a < b.
    363 func lt(arg1, arg2 interface{}) (bool, error) {
    364 	v1 := reflect.ValueOf(arg1)
    365 	k1, err := basicKind(v1)
    366 	if err != nil {
    367 		return false, err
    368 	}
    369 	v2 := reflect.ValueOf(arg2)
    370 	k2, err := basicKind(v2)
    371 	if err != nil {
    372 		return false, err
    373 	}
    374 	truth := false
    375 	if k1 != k2 {
    376 		// Special case: Can compare integer values regardless of type's sign.
    377 		switch {
    378 		case k1 == intKind && k2 == uintKind:
    379 			truth = v1.Int() < 0 || uint64(v1.Int()) < v2.Uint()
    380 		case k1 == uintKind && k2 == intKind:
    381 			truth = v2.Int() >= 0 && v1.Uint() < uint64(v2.Int())
    382 		default:
    383 			return false, errBadComparison
    384 		}
    385 	} else {
    386 		switch k1 {
    387 		case boolKind, complexKind:
    388 			return false, errBadComparisonType
    389 		case floatKind:
    390 			truth = v1.Float() < v2.Float()
    391 		case intKind:
    392 			truth = v1.Int() < v2.Int()
    393 		case stringKind:
    394 			truth = v1.String() < v2.String()
    395 		case uintKind:
    396 			truth = v1.Uint() < v2.Uint()
    397 		default:
    398 			panic("invalid kind")
    399 		}
    400 	}
    401 	return truth, nil
    402 }
    403 
    404 // le evaluates the comparison <= b.
    405 func le(arg1, arg2 interface{}) (bool, error) {
    406 	// <= is < or ==.
    407 	lessThan, err := lt(arg1, arg2)
    408 	if lessThan || err != nil {
    409 		return lessThan, err
    410 	}
    411 	return eq(arg1, arg2)
    412 }
    413 
    414 // gt evaluates the comparison a > b.
    415 func gt(arg1, arg2 interface{}) (bool, error) {
    416 	// > is the inverse of <=.
    417 	lessOrEqual, err := le(arg1, arg2)
    418 	if err != nil {
    419 		return false, err
    420 	}
    421 	return !lessOrEqual, nil
    422 }
    423 
    424 // ge evaluates the comparison a >= b.
    425 func ge(arg1, arg2 interface{}) (bool, error) {
    426 	// >= is the inverse of <.
    427 	lessThan, err := lt(arg1, arg2)
    428 	if err != nil {
    429 		return false, err
    430 	}
    431 	return !lessThan, nil
    432 }
    433 
    434 // HTML escaping.
    435 
    436 var (
    437 	htmlQuot = []byte("&#34;") // shorter than "&quot;"
    438 	htmlApos = []byte("&#39;") // shorter than "&apos;" and apos was not in HTML until HTML5
    439 	htmlAmp  = []byte("&amp;")
    440 	htmlLt   = []byte("&lt;")
    441 	htmlGt   = []byte("&gt;")
    442 )
    443 
    444 // HTMLEscape writes to w the escaped HTML equivalent of the plain text data b.
    445 func HTMLEscape(w io.Writer, b []byte) {
    446 	last := 0
    447 	for i, c := range b {
    448 		var html []byte
    449 		switch c {
    450 		case '"':
    451 			html = htmlQuot
    452 		case '\'':
    453 			html = htmlApos
    454 		case '&':
    455 			html = htmlAmp
    456 		case '<':
    457 			html = htmlLt
    458 		case '>':
    459 			html = htmlGt
    460 		default:
    461 			continue
    462 		}
    463 		w.Write(b[last:i])
    464 		w.Write(html)
    465 		last = i + 1
    466 	}
    467 	w.Write(b[last:])
    468 }
    469 
    470 // HTMLEscapeString returns the escaped HTML equivalent of the plain text data s.
    471 func HTMLEscapeString(s string) string {
    472 	// Avoid allocation if we can.
    473 	if strings.IndexAny(s, `'"&<>`) < 0 {
    474 		return s
    475 	}
    476 	var b bytes.Buffer
    477 	HTMLEscape(&b, []byte(s))
    478 	return b.String()
    479 }
    480 
    481 // HTMLEscaper returns the escaped HTML equivalent of the textual
    482 // representation of its arguments.
    483 func HTMLEscaper(args ...interface{}) string {
    484 	return HTMLEscapeString(evalArgs(args))
    485 }
    486 
    487 // JavaScript escaping.
    488 
    489 var (
    490 	jsLowUni = []byte(`\u00`)
    491 	hex      = []byte("0123456789ABCDEF")
    492 
    493 	jsBackslash = []byte(`\\`)
    494 	jsApos      = []byte(`\'`)
    495 	jsQuot      = []byte(`\"`)
    496 	jsLt        = []byte(`\x3C`)
    497 	jsGt        = []byte(`\x3E`)
    498 )
    499 
    500 // JSEscape writes to w the escaped JavaScript equivalent of the plain text data b.
    501 func JSEscape(w io.Writer, b []byte) {
    502 	last := 0
    503 	for i := 0; i < len(b); i++ {
    504 		c := b[i]
    505 
    506 		if !jsIsSpecial(rune(c)) {
    507 			// fast path: nothing to do
    508 			continue
    509 		}
    510 		w.Write(b[last:i])
    511 
    512 		if c < utf8.RuneSelf {
    513 			// Quotes, slashes and angle brackets get quoted.
    514 			// Control characters get written as \u00XX.
    515 			switch c {
    516 			case '\\':
    517 				w.Write(jsBackslash)
    518 			case '\'':
    519 				w.Write(jsApos)
    520 			case '"':
    521 				w.Write(jsQuot)
    522 			case '<':
    523 				w.Write(jsLt)
    524 			case '>':
    525 				w.Write(jsGt)
    526 			default:
    527 				w.Write(jsLowUni)
    528 				t, b := c>>4, c&0x0f
    529 				w.Write(hex[t : t+1])
    530 				w.Write(hex[b : b+1])
    531 			}
    532 		} else {
    533 			// Unicode rune.
    534 			r, size := utf8.DecodeRune(b[i:])
    535 			if unicode.IsPrint(r) {
    536 				w.Write(b[i : i+size])
    537 			} else {
    538 				fmt.Fprintf(w, "\\u%04X", r)
    539 			}
    540 			i += size - 1
    541 		}
    542 		last = i + 1
    543 	}
    544 	w.Write(b[last:])
    545 }
    546 
    547 // JSEscapeString returns the escaped JavaScript equivalent of the plain text data s.
    548 func JSEscapeString(s string) string {
    549 	// Avoid allocation if we can.
    550 	if strings.IndexFunc(s, jsIsSpecial) < 0 {
    551 		return s
    552 	}
    553 	var b bytes.Buffer
    554 	JSEscape(&b, []byte(s))
    555 	return b.String()
    556 }
    557 
    558 func jsIsSpecial(r rune) bool {
    559 	switch r {
    560 	case '\\', '\'', '"', '<', '>':
    561 		return true
    562 	}
    563 	return r < ' ' || utf8.RuneSelf <= r
    564 }
    565 
    566 // JSEscaper returns the escaped JavaScript equivalent of the textual
    567 // representation of its arguments.
    568 func JSEscaper(args ...interface{}) string {
    569 	return JSEscapeString(evalArgs(args))
    570 }
    571 
    572 // URLQueryEscaper returns the escaped value of the textual representation of
    573 // its arguments in a form suitable for embedding in a URL query.
    574 func URLQueryEscaper(args ...interface{}) string {
    575 	return url.QueryEscape(evalArgs(args))
    576 }
    577 
    578 // evalArgs formats the list of arguments into a string. It is therefore equivalent to
    579 //	fmt.Sprint(args...)
    580 // except that each argument is indirected (if a pointer), as required,
    581 // using the same rules as the default string evaluation during template
    582 // execution.
    583 func evalArgs(args []interface{}) string {
    584 	ok := false
    585 	var s string
    586 	// Fast path for simple common case.
    587 	if len(args) == 1 {
    588 		s, ok = args[0].(string)
    589 	}
    590 	if !ok {
    591 		for i, arg := range args {
    592 			a, ok := printableValue(reflect.ValueOf(arg))
    593 			if ok {
    594 				args[i] = a
    595 			} // else let fmt do its thing
    596 		}
    597 		s = fmt.Sprint(args...)
    598 	}
    599 	return s
    600 }
    601