Home | History | Annotate | Download | only in gen
      1 // Copyright 2016 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
// This program generates a test to verify that the standard integer
// arithmetic, shift, and comparison operators properly handle constant
// folding. The test file should be generated with a known working version of Go.
// Running `go run constFoldGen.go` writes the generated tests to a file called
// constFold_test.go in the grandparent directory (../../).
     10 
     11 package main
     12 
     13 import (
     14 	"bytes"
     15 	"fmt"
     16 	"go/format"
     17 	"io/ioutil"
     18 	"log"
     19 )
     20 
     21 type op struct {
     22 	name, symbol string
     23 }
     24 type szD struct {
     25 	name string
     26 	sn   string
     27 	u    []uint64
     28 	i    []int64
     29 }
     30 
     31 var szs []szD = []szD{
     32 	szD{name: "uint64", sn: "64", u: []uint64{0, 1, 4294967296, 0xffffFFFFffffFFFF}},
     33 	szD{name: "int64", sn: "64", i: []int64{-0x8000000000000000, -0x7FFFFFFFFFFFFFFF,
     34 		-4294967296, -1, 0, 1, 4294967296, 0x7FFFFFFFFFFFFFFE, 0x7FFFFFFFFFFFFFFF}},
     35 
     36 	szD{name: "uint32", sn: "32", u: []uint64{0, 1, 4294967295}},
     37 	szD{name: "int32", sn: "32", i: []int64{-0x80000000, -0x7FFFFFFF, -1, 0,
     38 		1, 0x7FFFFFFF}},
     39 
     40 	szD{name: "uint16", sn: "16", u: []uint64{0, 1, 65535}},
     41 	szD{name: "int16", sn: "16", i: []int64{-32768, -32767, -1, 0, 1, 32766, 32767}},
     42 
     43 	szD{name: "uint8", sn: "8", u: []uint64{0, 1, 255}},
     44 	szD{name: "int8", sn: "8", i: []int64{-128, -127, -1, 0, 1, 126, 127}},
     45 }
     46 
     47 var ops = []op{
     48 	op{"add", "+"}, op{"sub", "-"}, op{"div", "/"}, op{"mul", "*"},
     49 	op{"lsh", "<<"}, op{"rsh", ">>"}, op{"mod", "%"},
     50 }
     51 
     52 // compute the result of i op j, cast as type t.
     53 func ansU(i, j uint64, t, op string) string {
     54 	var ans uint64
     55 	switch op {
     56 	case "+":
     57 		ans = i + j
     58 	case "-":
     59 		ans = i - j
     60 	case "*":
     61 		ans = i * j
     62 	case "/":
     63 		if j != 0 {
     64 			ans = i / j
     65 		}
     66 	case "%":
     67 		if j != 0 {
     68 			ans = i % j
     69 		}
     70 	case "<<":
     71 		ans = i << j
     72 	case ">>":
     73 		ans = i >> j
     74 	}
     75 	switch t {
     76 	case "uint32":
     77 		ans = uint64(uint32(ans))
     78 	case "uint16":
     79 		ans = uint64(uint16(ans))
     80 	case "uint8":
     81 		ans = uint64(uint8(ans))
     82 	}
     83 	return fmt.Sprintf("%d", ans)
     84 }
     85 
     86 // compute the result of i op j, cast as type t.
     87 func ansS(i, j int64, t, op string) string {
     88 	var ans int64
     89 	switch op {
     90 	case "+":
     91 		ans = i + j
     92 	case "-":
     93 		ans = i - j
     94 	case "*":
     95 		ans = i * j
     96 	case "/":
     97 		if j != 0 {
     98 			ans = i / j
     99 		}
    100 	case "%":
    101 		if j != 0 {
    102 			ans = i % j
    103 		}
    104 	case "<<":
    105 		ans = i << uint64(j)
    106 	case ">>":
    107 		ans = i >> uint64(j)
    108 	}
    109 	switch t {
    110 	case "int32":
    111 		ans = int64(int32(ans))
    112 	case "int16":
    113 		ans = int64(int16(ans))
    114 	case "int8":
    115 		ans = int64(int8(ans))
    116 	}
    117 	return fmt.Sprintf("%d", ans)
    118 }
    119 
    120 func main() {
    121 
    122 	w := new(bytes.Buffer)
    123 
    124 	fmt.Fprintf(w, "package gc\n")
    125 	fmt.Fprintf(w, "import \"testing\"\n")
    126 
    127 	for _, s := range szs {
    128 		for _, o := range ops {
    129 			if o.symbol == "<<" || o.symbol == ">>" {
    130 				// shifts handled separately below, as they can have
    131 				// different types on the LHS and RHS.
    132 				continue
    133 			}
    134 			fmt.Fprintf(w, "func TestConstFold%s%s(t *testing.T) {\n", s.name, o.name)
    135 			fmt.Fprintf(w, "\tvar x, y, r %s\n", s.name)
    136 			// unsigned test cases
    137 			for _, c := range s.u {
    138 				fmt.Fprintf(w, "\tx = %d\n", c)
    139 				for _, d := range s.u {
    140 					if d == 0 && (o.symbol == "/" || o.symbol == "%") {
    141 						continue
    142 					}
    143 					fmt.Fprintf(w, "\ty = %d\n", d)
    144 					fmt.Fprintf(w, "\tr = x %s y\n", o.symbol)
    145 					want := ansU(c, d, s.name, o.symbol)
    146 					fmt.Fprintf(w, "\tif r != %s {\n", want)
    147 					fmt.Fprintf(w, "\t\tt.Errorf(\"%d %%s %d = %%d, want %s\", %q, r)\n", c, d, want, o.symbol)
    148 					fmt.Fprintf(w, "\t}\n")
    149 				}
    150 			}
    151 			// signed test cases
    152 			for _, c := range s.i {
    153 				fmt.Fprintf(w, "\tx = %d\n", c)
    154 				for _, d := range s.i {
    155 					if d == 0 && (o.symbol == "/" || o.symbol == "%") {
    156 						continue
    157 					}
    158 					fmt.Fprintf(w, "\ty = %d\n", d)
    159 					fmt.Fprintf(w, "\tr = x %s y\n", o.symbol)
    160 					want := ansS(c, d, s.name, o.symbol)
    161 					fmt.Fprintf(w, "\tif r != %s {\n", want)
    162 					fmt.Fprintf(w, "\t\tt.Errorf(\"%d %%s %d = %%d, want %s\", %q, r)\n", c, d, want, o.symbol)
    163 					fmt.Fprintf(w, "\t}\n")
    164 				}
    165 			}
    166 			fmt.Fprintf(w, "}\n")
    167 		}
    168 	}
    169 
    170 	// Special signed/unsigned cases for shifts
    171 	for _, ls := range szs {
    172 		for _, rs := range szs {
    173 			if rs.name[0] != 'u' {
    174 				continue
    175 			}
    176 			for _, o := range ops {
    177 				if o.symbol != "<<" && o.symbol != ">>" {
    178 					continue
    179 				}
    180 				fmt.Fprintf(w, "func TestConstFold%s%s%s(t *testing.T) {\n", ls.name, rs.name, o.name)
    181 				fmt.Fprintf(w, "\tvar x, r %s\n", ls.name)
    182 				fmt.Fprintf(w, "\tvar y %s\n", rs.name)
    183 				// unsigned LHS
    184 				for _, c := range ls.u {
    185 					fmt.Fprintf(w, "\tx = %d\n", c)
    186 					for _, d := range rs.u {
    187 						fmt.Fprintf(w, "\ty = %d\n", d)
    188 						fmt.Fprintf(w, "\tr = x %s y\n", o.symbol)
    189 						want := ansU(c, d, ls.name, o.symbol)
    190 						fmt.Fprintf(w, "\tif r != %s {\n", want)
    191 						fmt.Fprintf(w, "\t\tt.Errorf(\"%d %%s %d = %%d, want %s\", %q, r)\n", c, d, want, o.symbol)
    192 						fmt.Fprintf(w, "\t}\n")
    193 					}
    194 				}
    195 				// signed LHS
    196 				for _, c := range ls.i {
    197 					fmt.Fprintf(w, "\tx = %d\n", c)
    198 					for _, d := range rs.u {
    199 						fmt.Fprintf(w, "\ty = %d\n", d)
    200 						fmt.Fprintf(w, "\tr = x %s y\n", o.symbol)
    201 						want := ansS(c, int64(d), ls.name, o.symbol)
    202 						fmt.Fprintf(w, "\tif r != %s {\n", want)
    203 						fmt.Fprintf(w, "\t\tt.Errorf(\"%d %%s %d = %%d, want %s\", %q, r)\n", c, d, want, o.symbol)
    204 						fmt.Fprintf(w, "\t}\n")
    205 					}
    206 				}
    207 				fmt.Fprintf(w, "}\n")
    208 			}
    209 		}
    210 	}
    211 
    212 	// Constant folding for comparisons
    213 	for _, s := range szs {
    214 		fmt.Fprintf(w, "func TestConstFoldCompare%s(t *testing.T) {\n", s.name)
    215 		for _, x := range s.i {
    216 			for _, y := range s.i {
    217 				fmt.Fprintf(w, "\t{\n")
    218 				fmt.Fprintf(w, "\t\tvar x %s = %d\n", s.name, x)
    219 				fmt.Fprintf(w, "\t\tvar y %s = %d\n", s.name, y)
    220 				if x == y {
    221 					fmt.Fprintf(w, "\t\tif !(x == y) { t.Errorf(\"!(%%d == %%d)\", x, y) }\n")
    222 				} else {
    223 					fmt.Fprintf(w, "\t\tif x == y { t.Errorf(\"%%d == %%d\", x, y) }\n")
    224 				}
    225 				if x != y {
    226 					fmt.Fprintf(w, "\t\tif !(x != y) { t.Errorf(\"!(%%d != %%d)\", x, y) }\n")
    227 				} else {
    228 					fmt.Fprintf(w, "\t\tif x != y { t.Errorf(\"%%d != %%d\", x, y) }\n")
    229 				}
    230 				if x < y {
    231 					fmt.Fprintf(w, "\t\tif !(x < y) { t.Errorf(\"!(%%d < %%d)\", x, y) }\n")
    232 				} else {
    233 					fmt.Fprintf(w, "\t\tif x < y { t.Errorf(\"%%d < %%d\", x, y) }\n")
    234 				}
    235 				if x > y {
    236 					fmt.Fprintf(w, "\t\tif !(x > y) { t.Errorf(\"!(%%d > %%d)\", x, y) }\n")
    237 				} else {
    238 					fmt.Fprintf(w, "\t\tif x > y { t.Errorf(\"%%d > %%d\", x, y) }\n")
    239 				}
    240 				if x <= y {
    241 					fmt.Fprintf(w, "\t\tif !(x <= y) { t.Errorf(\"!(%%d <= %%d)\", x, y) }\n")
    242 				} else {
    243 					fmt.Fprintf(w, "\t\tif x <= y { t.Errorf(\"%%d <= %%d\", x, y) }\n")
    244 				}
    245 				if x >= y {
    246 					fmt.Fprintf(w, "\t\tif !(x >= y) { t.Errorf(\"!(%%d >= %%d)\", x, y) }\n")
    247 				} else {
    248 					fmt.Fprintf(w, "\t\tif x >= y { t.Errorf(\"%%d >= %%d\", x, y) }\n")
    249 				}
    250 				fmt.Fprintf(w, "\t}\n")
    251 			}
    252 		}
    253 		for _, x := range s.u {
    254 			for _, y := range s.u {
    255 				fmt.Fprintf(w, "\t{\n")
    256 				fmt.Fprintf(w, "\t\tvar x %s = %d\n", s.name, x)
    257 				fmt.Fprintf(w, "\t\tvar y %s = %d\n", s.name, y)
    258 				if x == y {
    259 					fmt.Fprintf(w, "\t\tif !(x == y) { t.Errorf(\"!(%%d == %%d)\", x, y) }\n")
    260 				} else {
    261 					fmt.Fprintf(w, "\t\tif x == y { t.Errorf(\"%%d == %%d\", x, y) }\n")
    262 				}
    263 				if x != y {
    264 					fmt.Fprintf(w, "\t\tif !(x != y) { t.Errorf(\"!(%%d != %%d)\", x, y) }\n")
    265 				} else {
    266 					fmt.Fprintf(w, "\t\tif x != y { t.Errorf(\"%%d != %%d\", x, y) }\n")
    267 				}
    268 				if x < y {
    269 					fmt.Fprintf(w, "\t\tif !(x < y) { t.Errorf(\"!(%%d < %%d)\", x, y) }\n")
    270 				} else {
    271 					fmt.Fprintf(w, "\t\tif x < y { t.Errorf(\"%%d < %%d\", x, y) }\n")
    272 				}
    273 				if x > y {
    274 					fmt.Fprintf(w, "\t\tif !(x > y) { t.Errorf(\"!(%%d > %%d)\", x, y) }\n")
    275 				} else {
    276 					fmt.Fprintf(w, "\t\tif x > y { t.Errorf(\"%%d > %%d\", x, y) }\n")
    277 				}
    278 				if x <= y {
    279 					fmt.Fprintf(w, "\t\tif !(x <= y) { t.Errorf(\"!(%%d <= %%d)\", x, y) }\n")
    280 				} else {
    281 					fmt.Fprintf(w, "\t\tif x <= y { t.Errorf(\"%%d <= %%d\", x, y) }\n")
    282 				}
    283 				if x >= y {
    284 					fmt.Fprintf(w, "\t\tif !(x >= y) { t.Errorf(\"!(%%d >= %%d)\", x, y) }\n")
    285 				} else {
    286 					fmt.Fprintf(w, "\t\tif x >= y { t.Errorf(\"%%d >= %%d\", x, y) }\n")
    287 				}
    288 				fmt.Fprintf(w, "\t}\n")
    289 			}
    290 		}
    291 		fmt.Fprintf(w, "}\n")
    292 	}
    293 
    294 	// gofmt result
    295 	b := w.Bytes()
    296 	src, err := format.Source(b)
    297 	if err != nil {
    298 		fmt.Printf("%s\n", b)
    299 		panic(err)
    300 	}
    301 
    302 	// write to file
    303 	err = ioutil.WriteFile("../../constFold_test.go", src, 0666)
    304 	if err != nil {
    305 		log.Fatalf("can't write output: %v\n", err)
    306 	}
    307 }
    308