Home | History | Annotate | Download | only in gotestmain
      1 // Copyright 2015 Google Inc. All rights reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 package gotestmain
     16 
import (
	"bytes"
	"flag"
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"io/ioutil"
	"os"
	"strings"
	"text/template"
	"unicode"
	"unicode/utf8"
)
     28 )
     29 
     30 var (
     31 	output   = flag.String("o", "", "output filename")
     32 	pkg      = flag.String("pkg", "", "test package")
     33 	exitCode = 0
     34 )
     35 
     36 type data struct {
     37 	Package string
     38 	Tests   []string
     39 }
     40 
     41 func findTests(srcs []string) (tests []string) {
     42 	for _, src := range srcs {
     43 		f, err := parser.ParseFile(token.NewFileSet(), src, nil, 0)
     44 		if err != nil {
     45 			panic(err)
     46 		}
     47 		for _, obj := range f.Scope.Objects {
     48 			if obj.Kind != ast.Fun || !strings.HasPrefix(obj.Name, "Test") {
     49 				continue
     50 			}
     51 			tests = append(tests, obj.Name)
     52 		}
     53 	}
     54 	return
     55 }
     56 
     57 func main() {
     58 	flag.Parse()
     59 
     60 	if flag.NArg() == 0 {
     61 		fmt.Fprintln(os.Stderr, "error: must pass at least one input")
     62 		exitCode = 1
     63 		return
     64 	}
     65 
     66 	buf := &bytes.Buffer{}
     67 
     68 	d := data{
     69 		Package: *pkg,
     70 		Tests:   findTests(flag.Args()),
     71 	}
     72 
     73 	err := testMainTmpl.Execute(buf, d)
     74 	if err != nil {
     75 		panic(err)
     76 	}
     77 
     78 	err = ioutil.WriteFile(*output, buf.Bytes(), 0666)
     79 	if err != nil {
     80 		panic(err)
     81 	}
     82 }
     83 
     84 var testMainTmpl = template.Must(template.New("testMain").Parse(`
     85 package main
     86 
     87 import (
     88 	"testing"
     89 
     90 	pkg "{{.Package}}"
     91 )
     92 
     93 var t = []testing.InternalTest{
     94 {{range .Tests}}
     95 	{"{{.}}", pkg.{{.}}},
     96 {{end}}
     97 }
     98 
     99 func matchString(pat, str string) (bool, error) {
    100 	return true, nil
    101 }
    102 
    103 func main() {
    104 	testing.Main(matchString, t, nil, nil)
    105 }
    106 `))
    107