1 // Copyright 2011 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package template 6 7 import ( 8 "bytes" 9 "errors" 10 "fmt" 11 "io" 12 "net/url" 13 "reflect" 14 "strings" 15 "unicode" 16 "unicode/utf8" 17 ) 18 19 // FuncMap is the type of the map defining the mapping from names to functions. 20 // Each function must have either a single return value, or two return values of 21 // which the second has type error. In that case, if the second (error) 22 // return value evaluates to non-nil during execution, execution terminates and 23 // Execute returns that error. 24 type FuncMap map[string]interface{} 25 26 var builtins = FuncMap{ 27 "and": and, 28 "call": call, 29 "html": HTMLEscaper, 30 "index": index, 31 "js": JSEscaper, 32 "len": length, 33 "not": not, 34 "or": or, 35 "print": fmt.Sprint, 36 "printf": fmt.Sprintf, 37 "println": fmt.Sprintln, 38 "urlquery": URLQueryEscaper, 39 40 // Comparisons 41 "eq": eq, // == 42 "ge": ge, // >= 43 "gt": gt, // > 44 "le": le, // <= 45 "lt": lt, // < 46 "ne": ne, // != 47 } 48 49 var builtinFuncs = createValueFuncs(builtins) 50 51 // createValueFuncs turns a FuncMap into a map[string]reflect.Value 52 func createValueFuncs(funcMap FuncMap) map[string]reflect.Value { 53 m := make(map[string]reflect.Value) 54 addValueFuncs(m, funcMap) 55 return m 56 } 57 58 // addValueFuncs adds to values the functions in funcs, converting them to reflect.Values. 59 func addValueFuncs(out map[string]reflect.Value, in FuncMap) { 60 for name, fn := range in { 61 v := reflect.ValueOf(fn) 62 if v.Kind() != reflect.Func { 63 panic("value for " + name + " not a function") 64 } 65 if !goodFunc(v.Type()) { 66 panic(fmt.Errorf("can't install method/function %q with %d results", name, v.Type().NumOut())) 67 } 68 out[name] = v 69 } 70 } 71 72 // addFuncs adds to values the functions in funcs. It does no checking of the input - 73 // call addValueFuncs first. 74 func addFuncs(out, in FuncMap) { 75 for name, fn := range in { 76 out[name] = fn 77 } 78 } 79 80 // goodFunc checks that the function or method has the right result signature. 81 func goodFunc(typ reflect.Type) bool { 82 // We allow functions with 1 result or 2 results where the second is an error. 83 switch { 84 case typ.NumOut() == 1: 85 return true 86 case typ.NumOut() == 2 && typ.Out(1) == errorType: 87 return true 88 } 89 return false 90 } 91 92 // findFunction looks for a function in the template, and global map. 93 func findFunction(name string, tmpl *Template) (reflect.Value, bool) { 94 if tmpl != nil && tmpl.common != nil { 95 tmpl.muFuncs.RLock() 96 defer tmpl.muFuncs.RUnlock() 97 if fn := tmpl.execFuncs[name]; fn.IsValid() { 98 return fn, true 99 } 100 } 101 if fn := builtinFuncs[name]; fn.IsValid() { 102 return fn, true 103 } 104 return reflect.Value{}, false 105 } 106 107 // Indexing. 108 109 // index returns the result of indexing its first argument by the following 110 // arguments. Thus "index x 1 2 3" is, in Go syntax, x[1][2][3]. Each 111 // indexed item must be a map, slice, or array. 112 func index(item interface{}, indices ...interface{}) (interface{}, error) { 113 v := reflect.ValueOf(item) 114 for _, i := range indices { 115 index := reflect.ValueOf(i) 116 var isNil bool 117 if v, isNil = indirect(v); isNil { 118 return nil, fmt.Errorf("index of nil pointer") 119 } 120 switch v.Kind() { 121 case reflect.Array, reflect.Slice, reflect.String: 122 var x int64 123 switch index.Kind() { 124 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 125 x = index.Int() 126 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: 127 x = int64(index.Uint()) 128 default: 129 return nil, fmt.Errorf("cannot index slice/array with type %s", index.Type()) 130 } 131 if x < 0 || x >= int64(v.Len()) { 132 return nil, fmt.Errorf("index out of range: %d", x) 133 } 134 v = v.Index(int(x)) 135 case reflect.Map: 136 if !index.IsValid() { 137 index = reflect.Zero(v.Type().Key()) 138 } 139 if !index.Type().AssignableTo(v.Type().Key()) { 140 return nil, fmt.Errorf("%s is not index type for %s", index.Type(), v.Type()) 141 } 142 if x := v.MapIndex(index); x.IsValid() { 143 v = x 144 } else { 145 v = reflect.Zero(v.Type().Elem()) 146 } 147 default: 148 return nil, fmt.Errorf("can't index item of type %s", v.Type()) 149 } 150 } 151 return v.Interface(), nil 152 } 153 154 // Length 155 156 // length returns the length of the item, with an error if it has no defined length. 157 func length(item interface{}) (int, error) { 158 v, isNil := indirect(reflect.ValueOf(item)) 159 if isNil { 160 return 0, fmt.Errorf("len of nil pointer") 161 } 162 switch v.Kind() { 163 case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String: 164 return v.Len(), nil 165 } 166 return 0, fmt.Errorf("len of type %s", v.Type()) 167 } 168 169 // Function invocation 170 171 // call returns the result of evaluating the first argument as a function. 172 // The function must return 1 result, or 2 results, the second of which is an error. 173 func call(fn interface{}, args ...interface{}) (interface{}, error) { 174 v := reflect.ValueOf(fn) 175 typ := v.Type() 176 if typ.Kind() != reflect.Func { 177 return nil, fmt.Errorf("non-function of type %s", typ) 178 } 179 if !goodFunc(typ) { 180 return nil, fmt.Errorf("function called with %d args; should be 1 or 2", typ.NumOut()) 181 } 182 numIn := typ.NumIn() 183 var dddType reflect.Type 184 if typ.IsVariadic() { 185 if len(args) < numIn-1 { 186 return nil, fmt.Errorf("wrong number of args: got %d want at least %d", len(args), numIn-1) 187 } 188 dddType = typ.In(numIn - 1).Elem() 189 } else { 190 if len(args) != numIn { 191 return nil, fmt.Errorf("wrong number of args: got %d want %d", len(args), numIn) 192 } 193 } 194 argv := make([]reflect.Value, len(args)) 195 for i, arg := range args { 196 value := reflect.ValueOf(arg) 197 // Compute the expected type. Clumsy because of variadics. 198 var argType reflect.Type 199 if !typ.IsVariadic() || i < numIn-1 { 200 argType = typ.In(i) 201 } else { 202 argType = dddType 203 } 204 if !value.IsValid() && canBeNil(argType) { 205 value = reflect.Zero(argType) 206 } 207 if !value.Type().AssignableTo(argType) { 208 return nil, fmt.Errorf("arg %d has type %s; should be %s", i, value.Type(), argType) 209 } 210 argv[i] = value 211 } 212 result := v.Call(argv) 213 if len(result) == 2 && !result[1].IsNil() { 214 return result[0].Interface(), result[1].Interface().(error) 215 } 216 return result[0].Interface(), nil 217 } 218 219 // Boolean logic. 220 221 func truth(a interface{}) bool { 222 t, _ := isTrue(reflect.ValueOf(a)) 223 return t 224 } 225 226 // and computes the Boolean AND of its arguments, returning 227 // the first false argument it encounters, or the last argument. 228 func and(arg0 interface{}, args ...interface{}) interface{} { 229 if !truth(arg0) { 230 return arg0 231 } 232 for i := range args { 233 arg0 = args[i] 234 if !truth(arg0) { 235 break 236 } 237 } 238 return arg0 239 } 240 241 // or computes the Boolean OR of its arguments, returning 242 // the first true argument it encounters, or the last argument. 243 func or(arg0 interface{}, args ...interface{}) interface{} { 244 if truth(arg0) { 245 return arg0 246 } 247 for i := range args { 248 arg0 = args[i] 249 if truth(arg0) { 250 break 251 } 252 } 253 return arg0 254 } 255 256 // not returns the Boolean negation of its argument. 257 func not(arg interface{}) (truth bool) { 258 truth, _ = isTrue(reflect.ValueOf(arg)) 259 return !truth 260 } 261 262 // Comparison. 263 264 // TODO: Perhaps allow comparison between signed and unsigned integers. 265 266 var ( 267 errBadComparisonType = errors.New("invalid type for comparison") 268 errBadComparison = errors.New("incompatible types for comparison") 269 errNoComparison = errors.New("missing argument for comparison") 270 ) 271 272 type kind int 273 274 const ( 275 invalidKind kind = iota 276 boolKind 277 complexKind 278 intKind 279 floatKind 280 integerKind 281 stringKind 282 uintKind 283 ) 284 285 func basicKind(v reflect.Value) (kind, error) { 286 switch v.Kind() { 287 case reflect.Bool: 288 return boolKind, nil 289 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 290 return intKind, nil 291 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: 292 return uintKind, nil 293 case reflect.Float32, reflect.Float64: 294 return floatKind, nil 295 case reflect.Complex64, reflect.Complex128: 296 return complexKind, nil 297 case reflect.String: 298 return stringKind, nil 299 } 300 return invalidKind, errBadComparisonType 301 } 302 303 // eq evaluates the comparison a == b || a == c || ... 304 func eq(arg1 interface{}, arg2 ...interface{}) (bool, error) { 305 v1 := reflect.ValueOf(arg1) 306 k1, err := basicKind(v1) 307 if err != nil { 308 return false, err 309 } 310 if len(arg2) == 0 { 311 return false, errNoComparison 312 } 313 for _, arg := range arg2 { 314 v2 := reflect.ValueOf(arg) 315 k2, err := basicKind(v2) 316 if err != nil { 317 return false, err 318 } 319 truth := false 320 if k1 != k2 { 321 // Special case: Can compare integer values regardless of type's sign. 322 switch { 323 case k1 == intKind && k2 == uintKind: 324 truth = v1.Int() >= 0 && uint64(v1.Int()) == v2.Uint() 325 case k1 == uintKind && k2 == intKind: 326 truth = v2.Int() >= 0 && v1.Uint() == uint64(v2.Int()) 327 default: 328 return false, errBadComparison 329 } 330 } else { 331 switch k1 { 332 case boolKind: 333 truth = v1.Bool() == v2.Bool() 334 case complexKind: 335 truth = v1.Complex() == v2.Complex() 336 case floatKind: 337 truth = v1.Float() == v2.Float() 338 case intKind: 339 truth = v1.Int() == v2.Int() 340 case stringKind: 341 truth = v1.String() == v2.String() 342 case uintKind: 343 truth = v1.Uint() == v2.Uint() 344 default: 345 panic("invalid kind") 346 } 347 } 348 if truth { 349 return true, nil 350 } 351 } 352 return false, nil 353 } 354 355 // ne evaluates the comparison a != b. 356 func ne(arg1, arg2 interface{}) (bool, error) { 357 // != is the inverse of ==. 358 equal, err := eq(arg1, arg2) 359 return !equal, err 360 } 361 362 // lt evaluates the comparison a < b. 363 func lt(arg1, arg2 interface{}) (bool, error) { 364 v1 := reflect.ValueOf(arg1) 365 k1, err := basicKind(v1) 366 if err != nil { 367 return false, err 368 } 369 v2 := reflect.ValueOf(arg2) 370 k2, err := basicKind(v2) 371 if err != nil { 372 return false, err 373 } 374 truth := false 375 if k1 != k2 { 376 // Special case: Can compare integer values regardless of type's sign. 377 switch { 378 case k1 == intKind && k2 == uintKind: 379 truth = v1.Int() < 0 || uint64(v1.Int()) < v2.Uint() 380 case k1 == uintKind && k2 == intKind: 381 truth = v2.Int() >= 0 && v1.Uint() < uint64(v2.Int()) 382 default: 383 return false, errBadComparison 384 } 385 } else { 386 switch k1 { 387 case boolKind, complexKind: 388 return false, errBadComparisonType 389 case floatKind: 390 truth = v1.Float() < v2.Float() 391 case intKind: 392 truth = v1.Int() < v2.Int() 393 case stringKind: 394 truth = v1.String() < v2.String() 395 case uintKind: 396 truth = v1.Uint() < v2.Uint() 397 default: 398 panic("invalid kind") 399 } 400 } 401 return truth, nil 402 } 403 404 // le evaluates the comparison <= b. 405 func le(arg1, arg2 interface{}) (bool, error) { 406 // <= is < or ==. 407 lessThan, err := lt(arg1, arg2) 408 if lessThan || err != nil { 409 return lessThan, err 410 } 411 return eq(arg1, arg2) 412 } 413 414 // gt evaluates the comparison a > b. 415 func gt(arg1, arg2 interface{}) (bool, error) { 416 // > is the inverse of <=. 417 lessOrEqual, err := le(arg1, arg2) 418 if err != nil { 419 return false, err 420 } 421 return !lessOrEqual, nil 422 } 423 424 // ge evaluates the comparison a >= b. 425 func ge(arg1, arg2 interface{}) (bool, error) { 426 // >= is the inverse of <. 427 lessThan, err := lt(arg1, arg2) 428 if err != nil { 429 return false, err 430 } 431 return !lessThan, nil 432 } 433 434 // HTML escaping. 435 436 var ( 437 htmlQuot = []byte(""") // shorter than """ 438 htmlApos = []byte("'") // shorter than "'" and apos was not in HTML until HTML5 439 htmlAmp = []byte("&") 440 htmlLt = []byte("<") 441 htmlGt = []byte(">") 442 ) 443 444 // HTMLEscape writes to w the escaped HTML equivalent of the plain text data b. 445 func HTMLEscape(w io.Writer, b []byte) { 446 last := 0 447 for i, c := range b { 448 var html []byte 449 switch c { 450 case '"': 451 html = htmlQuot 452 case '\'': 453 html = htmlApos 454 case '&': 455 html = htmlAmp 456 case '<': 457 html = htmlLt 458 case '>': 459 html = htmlGt 460 default: 461 continue 462 } 463 w.Write(b[last:i]) 464 w.Write(html) 465 last = i + 1 466 } 467 w.Write(b[last:]) 468 } 469 470 // HTMLEscapeString returns the escaped HTML equivalent of the plain text data s. 471 func HTMLEscapeString(s string) string { 472 // Avoid allocation if we can. 473 if strings.IndexAny(s, `'"&<>`) < 0 { 474 return s 475 } 476 var b bytes.Buffer 477 HTMLEscape(&b, []byte(s)) 478 return b.String() 479 } 480 481 // HTMLEscaper returns the escaped HTML equivalent of the textual 482 // representation of its arguments. 483 func HTMLEscaper(args ...interface{}) string { 484 return HTMLEscapeString(evalArgs(args)) 485 } 486 487 // JavaScript escaping. 488 489 var ( 490 jsLowUni = []byte(`\u00`) 491 hex = []byte("0123456789ABCDEF") 492 493 jsBackslash = []byte(`\\`) 494 jsApos = []byte(`\'`) 495 jsQuot = []byte(`\"`) 496 jsLt = []byte(`\x3C`) 497 jsGt = []byte(`\x3E`) 498 ) 499 500 // JSEscape writes to w the escaped JavaScript equivalent of the plain text data b. 501 func JSEscape(w io.Writer, b []byte) { 502 last := 0 503 for i := 0; i < len(b); i++ { 504 c := b[i] 505 506 if !jsIsSpecial(rune(c)) { 507 // fast path: nothing to do 508 continue 509 } 510 w.Write(b[last:i]) 511 512 if c < utf8.RuneSelf { 513 // Quotes, slashes and angle brackets get quoted. 514 // Control characters get written as \u00XX. 515 switch c { 516 case '\\': 517 w.Write(jsBackslash) 518 case '\'': 519 w.Write(jsApos) 520 case '"': 521 w.Write(jsQuot) 522 case '<': 523 w.Write(jsLt) 524 case '>': 525 w.Write(jsGt) 526 default: 527 w.Write(jsLowUni) 528 t, b := c>>4, c&0x0f 529 w.Write(hex[t : t+1]) 530 w.Write(hex[b : b+1]) 531 } 532 } else { 533 // Unicode rune. 534 r, size := utf8.DecodeRune(b[i:]) 535 if unicode.IsPrint(r) { 536 w.Write(b[i : i+size]) 537 } else { 538 fmt.Fprintf(w, "\\u%04X", r) 539 } 540 i += size - 1 541 } 542 last = i + 1 543 } 544 w.Write(b[last:]) 545 } 546 547 // JSEscapeString returns the escaped JavaScript equivalent of the plain text data s. 548 func JSEscapeString(s string) string { 549 // Avoid allocation if we can. 550 if strings.IndexFunc(s, jsIsSpecial) < 0 { 551 return s 552 } 553 var b bytes.Buffer 554 JSEscape(&b, []byte(s)) 555 return b.String() 556 } 557 558 func jsIsSpecial(r rune) bool { 559 switch r { 560 case '\\', '\'', '"', '<', '>': 561 return true 562 } 563 return r < ' ' || utf8.RuneSelf <= r 564 } 565 566 // JSEscaper returns the escaped JavaScript equivalent of the textual 567 // representation of its arguments. 568 func JSEscaper(args ...interface{}) string { 569 return JSEscapeString(evalArgs(args)) 570 } 571 572 // URLQueryEscaper returns the escaped value of the textual representation of 573 // its arguments in a form suitable for embedding in a URL query. 574 func URLQueryEscaper(args ...interface{}) string { 575 return url.QueryEscape(evalArgs(args)) 576 } 577 578 // evalArgs formats the list of arguments into a string. It is therefore equivalent to 579 // fmt.Sprint(args...) 580 // except that each argument is indirected (if a pointer), as required, 581 // using the same rules as the default string evaluation during template 582 // execution. 583 func evalArgs(args []interface{}) string { 584 ok := false 585 var s string 586 // Fast path for simple common case. 587 if len(args) == 1 { 588 s, ok = args[0].(string) 589 } 590 if !ok { 591 for i, arg := range args { 592 a, ok := printableValue(reflect.ValueOf(arg)) 593 if ok { 594 args[i] = a 595 } // else let fmt do its thing 596 } 597 s = fmt.Sprint(args...) 598 } 599 return s 600 } 601