Home | History | Annotate | Download | only in gen
      1 // Copyright 2016 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
// This program generates a test to verify that the standard arithmetic
// operators properly handle const cases. The test file should be
// generated with a known working version of Go.
// Run it with `go run arithConstGen.go`; it writes a file called
// arithConst.go, containing the tests, into the parent directory.
     10 
     11 package main
     12 
     13 import (
     14 	"bytes"
     15 	"fmt"
     16 	"go/format"
     17 	"io/ioutil"
     18 	"log"
     19 	"strings"
     20 	"text/template"
     21 )
     22 
     23 type op struct {
     24 	name, symbol string
     25 }
     26 type szD struct {
     27 	name string
     28 	sn   string
     29 	u    []uint64
     30 	i    []int64
     31 }
     32 
     33 var szs []szD = []szD{
     34 	szD{name: "uint64", sn: "64", u: []uint64{0, 1, 4294967296, 0xffffFFFFffffFFFF}},
     35 	szD{name: "int64", sn: "64", i: []int64{-0x8000000000000000, -0x7FFFFFFFFFFFFFFF,
     36 		-4294967296, -1, 0, 1, 4294967296, 0x7FFFFFFFFFFFFFFE, 0x7FFFFFFFFFFFFFFF}},
     37 
     38 	szD{name: "uint32", sn: "32", u: []uint64{0, 1, 4294967295}},
     39 	szD{name: "int32", sn: "32", i: []int64{-0x80000000, -0x7FFFFFFF, -1, 0,
     40 		1, 0x7FFFFFFF}},
     41 
     42 	szD{name: "uint16", sn: "16", u: []uint64{0, 1, 65535}},
     43 	szD{name: "int16", sn: "16", i: []int64{-32768, -32767, -1, 0, 1, 32766, 32767}},
     44 
     45 	szD{name: "uint8", sn: "8", u: []uint64{0, 1, 255}},
     46 	szD{name: "int8", sn: "8", i: []int64{-128, -127, -1, 0, 1, 126, 127}},
     47 }
     48 
     49 var ops []op = []op{op{"add", "+"}, op{"sub", "-"}, op{"div", "/"}, op{"mul", "*"},
     50 	op{"lsh", "<<"}, op{"rsh", ">>"}, op{"mod", "%"}}
     51 
     52 // compute the result of i op j, cast as type t.
     53 func ansU(i, j uint64, t, op string) string {
     54 	var ans uint64
     55 	switch op {
     56 	case "+":
     57 		ans = i + j
     58 	case "-":
     59 		ans = i - j
     60 	case "*":
     61 		ans = i * j
     62 	case "/":
     63 		if j != 0 {
     64 			ans = i / j
     65 		}
     66 	case "%":
     67 		if j != 0 {
     68 			ans = i % j
     69 		}
     70 	case "<<":
     71 		ans = i << j
     72 	case ">>":
     73 		ans = i >> j
     74 	}
     75 	switch t {
     76 	case "uint32":
     77 		ans = uint64(uint32(ans))
     78 	case "uint16":
     79 		ans = uint64(uint16(ans))
     80 	case "uint8":
     81 		ans = uint64(uint8(ans))
     82 	}
     83 	return fmt.Sprintf("%d", ans)
     84 }
     85 
     86 // compute the result of i op j, cast as type t.
     87 func ansS(i, j int64, t, op string) string {
     88 	var ans int64
     89 	switch op {
     90 	case "+":
     91 		ans = i + j
     92 	case "-":
     93 		ans = i - j
     94 	case "*":
     95 		ans = i * j
     96 	case "/":
     97 		if j != 0 {
     98 			ans = i / j
     99 		}
    100 	case "%":
    101 		if j != 0 {
    102 			ans = i % j
    103 		}
    104 	case "<<":
    105 		ans = i << uint64(j)
    106 	case ">>":
    107 		ans = i >> uint64(j)
    108 	}
    109 	switch t {
    110 	case "int32":
    111 		ans = int64(int32(ans))
    112 	case "int16":
    113 		ans = int64(int16(ans))
    114 	case "int8":
    115 		ans = int64(int8(ans))
    116 	}
    117 	return fmt.Sprintf("%d", ans)
    118 }
    119 
    120 func main() {
    121 
    122 	w := new(bytes.Buffer)
    123 
    124 	fmt.Fprintf(w, "package main;\n")
    125 	fmt.Fprintf(w, "import \"fmt\"\n")
    126 
    127 	fncCnst1, err := template.New("fnc").Parse(
    128 		`//go:noinline
    129 		func {{.Name}}_{{.Type_}}_{{.FNumber}}_ssa(a {{.Type_}}) {{.Type_}} {
    130 	return a {{.Symbol}} {{.Number}}
    131 }
    132 `)
    133 	if err != nil {
    134 		panic(err)
    135 	}
    136 	fncCnst2, err := template.New("fnc").Parse(
    137 		`//go:noinline
    138 		func {{.Name}}_{{.FNumber}}_{{.Type_}}_ssa(a {{.Type_}}) {{.Type_}} {
    139 	return {{.Number}} {{.Symbol}} a
    140 }
    141 
    142 `)
    143 	if err != nil {
    144 		panic(err)
    145 	}
    146 
    147 	type fncData struct {
    148 		Name, Type_, Symbol, FNumber, Number string
    149 	}
    150 
    151 	for _, s := range szs {
    152 		for _, o := range ops {
    153 			fd := fncData{o.name, s.name, o.symbol, "", ""}
    154 
    155 			// unsigned test cases
    156 			if len(s.u) > 0 {
    157 				for _, i := range s.u {
    158 					fd.Number = fmt.Sprintf("%d", i)
    159 					fd.FNumber = strings.Replace(fd.Number, "-", "Neg", -1)
    160 
    161 					// avoid division by zero
    162 					if o.name != "mod" && o.name != "div" || i != 0 {
    163 						fncCnst1.Execute(w, fd)
    164 					}
    165 
    166 					fncCnst2.Execute(w, fd)
    167 				}
    168 			}
    169 
    170 			// signed test cases
    171 			if len(s.i) > 0 {
    172 				// don't generate tests for shifts by signed integers
    173 				if o.name == "lsh" || o.name == "rsh" {
    174 					continue
    175 				}
    176 				for _, i := range s.i {
    177 					fd.Number = fmt.Sprintf("%d", i)
    178 					fd.FNumber = strings.Replace(fd.Number, "-", "Neg", -1)
    179 
    180 					// avoid division by zero
    181 					if o.name != "mod" && o.name != "div" || i != 0 {
    182 						fncCnst1.Execute(w, fd)
    183 					}
    184 					fncCnst2.Execute(w, fd)
    185 				}
    186 			}
    187 		}
    188 	}
    189 
    190 	fmt.Fprintf(w, "var failed bool\n\n")
    191 	fmt.Fprintf(w, "func main() {\n\n")
    192 
    193 	vrf1, _ := template.New("vrf1").Parse(`
    194   if got := {{.Name}}_{{.FNumber}}_{{.Type_}}_ssa({{.Input}}); got != {{.Ans}} {
    195   	fmt.Printf("{{.Name}}_{{.Type_}} {{.Number}}%s{{.Input}} = %d, wanted {{.Ans}}\n", ` + "`{{.Symbol}}`" + `, got)
    196   	failed = true
    197   }
    198 `)
    199 
    200 	vrf2, _ := template.New("vrf2").Parse(`
    201   if got := {{.Name}}_{{.Type_}}_{{.FNumber}}_ssa({{.Input}}); got != {{.Ans}} {
    202     fmt.Printf("{{.Name}}_{{.Type_}} {{.Input}}%s{{.Number}} = %d, wanted {{.Ans}}\n", ` + "`{{.Symbol}}`" + `, got)
    203     failed = true
    204   }
    205 `)
    206 
    207 	type cfncData struct {
    208 		Name, Type_, Symbol, FNumber, Number string
    209 		Ans, Input                           string
    210 	}
    211 	for _, s := range szs {
    212 		if len(s.u) > 0 {
    213 			for _, o := range ops {
    214 				fd := cfncData{o.name, s.name, o.symbol, "", "", "", ""}
    215 				for _, i := range s.u {
    216 					fd.Number = fmt.Sprintf("%d", i)
    217 					fd.FNumber = strings.Replace(fd.Number, "-", "Neg", -1)
    218 
    219 					// unsigned
    220 					for _, j := range s.u {
    221 
    222 						if o.name != "mod" && o.name != "div" || j != 0 {
    223 							fd.Ans = ansU(i, j, s.name, o.symbol)
    224 							fd.Input = fmt.Sprintf("%d", j)
    225 							err = vrf1.Execute(w, fd)
    226 							if err != nil {
    227 								panic(err)
    228 							}
    229 						}
    230 
    231 						if o.name != "mod" && o.name != "div" || i != 0 {
    232 							fd.Ans = ansU(j, i, s.name, o.symbol)
    233 							fd.Input = fmt.Sprintf("%d", j)
    234 							err = vrf2.Execute(w, fd)
    235 							if err != nil {
    236 								panic(err)
    237 							}
    238 						}
    239 
    240 					}
    241 				}
    242 
    243 			}
    244 		}
    245 
    246 		// signed
    247 		if len(s.i) > 0 {
    248 			for _, o := range ops {
    249 				// don't generate tests for shifts by signed integers
    250 				if o.name == "lsh" || o.name == "rsh" {
    251 					continue
    252 				}
    253 				fd := cfncData{o.name, s.name, o.symbol, "", "", "", ""}
    254 				for _, i := range s.i {
    255 					fd.Number = fmt.Sprintf("%d", i)
    256 					fd.FNumber = strings.Replace(fd.Number, "-", "Neg", -1)
    257 					for _, j := range s.i {
    258 						if o.name != "mod" && o.name != "div" || j != 0 {
    259 							fd.Ans = ansS(i, j, s.name, o.symbol)
    260 							fd.Input = fmt.Sprintf("%d", j)
    261 							err = vrf1.Execute(w, fd)
    262 							if err != nil {
    263 								panic(err)
    264 							}
    265 						}
    266 
    267 						if o.name != "mod" && o.name != "div" || i != 0 {
    268 							fd.Ans = ansS(j, i, s.name, o.symbol)
    269 							fd.Input = fmt.Sprintf("%d", j)
    270 							err = vrf2.Execute(w, fd)
    271 							if err != nil {
    272 								panic(err)
    273 							}
    274 						}
    275 
    276 					}
    277 				}
    278 
    279 			}
    280 		}
    281 	}
    282 
    283 	fmt.Fprintf(w, `if failed {
    284         panic("tests failed")
    285     }
    286 `)
    287 	fmt.Fprintf(w, "}\n")
    288 
    289 	// gofmt result
    290 	b := w.Bytes()
    291 	src, err := format.Source(b)
    292 	if err != nil {
    293 		fmt.Printf("%s\n", b)
    294 		panic(err)
    295 	}
    296 
    297 	// write to file
    298 	err = ioutil.WriteFile("../arithConst.go", src, 0666)
    299 	if err != nil {
    300 		log.Fatalf("can't write output: %v\n", err)
    301 	}
    302 }
    303