// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package gotestmain generates a Go main package that runs, via
// testing.Main, the Test functions found in a set of test source files.
//
// Usage: gotestmain -o OUTPUT -pkg IMPORT_PATH FILE...
package gotestmain

import (
	"bytes"
	"flag"
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"io/ioutil"
	"os"
	"strings"
	"text/template"
	"unicode"
	"unicode/utf8"
)

var (
	output   = flag.String("o", "", "output filename")
	pkg      = flag.String("pkg", "", "test package")
	exitCode = 0 // process exit status passed to os.Exit by main
)

// data is the input to testMainTmpl.
type data struct {
	Package string   // import path of the package under test (-pkg)
	Tests   []string // Test function names to register, in source order
}

// isTest reports whether name is a test function name for the given prefix,
// using the same rule as "go test": the name is the prefix alone, or the
// prefix followed by a character that is not a lower-case letter. So TestFoo
// and Test_foo qualify, but Testing does not.
func isTest(name, prefix string) bool {
	if !strings.HasPrefix(name, prefix) {
		return false
	}
	if len(name) == len(prefix) {
		return true
	}
	r, _ := utf8.DecodeRuneInString(name[len(prefix):])
	return !unicode.IsLower(r)
}

// findTests parses each Go source file in srcs and returns the names of its
// package-level test functions. Names are returned in file order, then
// declaration order, so the generated main is identical from run to run.
// It panics if a file cannot be parsed.
//
// Function signatures are not checked. A Test function with the wrong
// signature shows up as a compile error in the generated file.
func findTests(srcs []string) (tests []string) {
	for _, src := range srcs {
		f, err := parser.ParseFile(token.NewFileSet(), src, nil, 0)
		if err != nil {
			panic(err)
		}
		// Walk f.Decls rather than the f.Scope.Objects map: map iteration
		// order is randomized, which made the generated output unstable.
		for _, decl := range f.Decls {
			fn, ok := decl.(*ast.FuncDecl)
			if !ok || fn.Recv != nil {
				continue // only package-level functions can be tests
			}
			if isTest(fn.Name.Name, "Test") {
				tests = append(tests, fn.Name.Name)
			}
		}
	}
	return tests
}

// main parses the command line, writes the generated test main to -o, and
// exits with exitCode. The exit code is non-zero if the arguments were invalid
// or the output could not be written.
func main() {
	flag.Parse()
	generate()
	// Without this the process always exited 0, even after exitCode was set.
	os.Exit(exitCode)
}

// generate does the work of main. On a usage or I/O error it prints a
// message to stderr, sets exitCode to 1 and returns.
func generate() {
	if flag.NArg() == 0 {
		fmt.Fprintln(os.Stderr, "error: must pass at least one input")
		exitCode = 1
		return
	}
	// Both flags are required. Previously an empty -pkg produced an
	// uncompilable import and an empty -o panicked in WriteFile.
	if *pkg == "" {
		fmt.Fprintln(os.Stderr, "error: -pkg is required")
		exitCode = 1
		return
	}
	if *output == "" {
		fmt.Fprintln(os.Stderr, "error: -o is required")
		exitCode = 1
		return
	}

	d := data{
		Package: *pkg,
		Tests:   findTests(flag.Args()),
	}

	var buf bytes.Buffer
	if err := testMainTmpl.Execute(&buf, d); err != nil {
		// The template is fixed and d is well-formed, so this is a bug.
		panic(err)
	}

	if err := ioutil.WriteFile(*output, buf.Bytes(), 0666); err != nil {
		fmt.Fprintf(os.Stderr, "error: writing %s: %v\n", *output, err)
		exitCode = 1
		return
	}
}

// testMainTmpl renders the generated main package from a data value.
//
// When there are no tests, the package under test is blank-imported. It is
// still linked and initialized, but a named import would be unused, and the
// generated file would fail to compile. matchString uses regexp so that
// -test.run filters tests as it does under "go test".
var testMainTmpl = template.Must(template.New("testMain").Parse(`
package main

import (
	"regexp"
	"testing"

{{if .Tests}}	pkg "{{.Package}}"
{{else}}	_ "{{.Package}}"
{{end}})

var t = []testing.InternalTest{
{{range .Tests}}	{"{{.}}", pkg.{{.}}},
{{end}}}

func matchString(pat, str string) (bool, error) {
	return regexp.MatchString(pat, str)
}

func main() {
	testing.Main(matchString, t, nil, nil)
}
`))