Home | History | Annotate | Download | only in bn
      1 // Copyright (c) 2016, Google Inc.
      2 //
      3 // Permission to use, copy, modify, and/or distribute this software for any
      4 // purpose with or without fee is hereby granted, provided that the above
      5 // copyright notice and this permission notice appear in all copies.
      6 //
      7 // THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
      8 // WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
      9 // MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
     10 // SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
     11 // WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
     12 // OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
     13 // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
     14 
     15 package main
     16 
     17 import (
     18 	"bufio"
     19 	"errors"
     20 	"fmt"
     21 	"io"
     22 	"math/big"
     23 	"os"
     24 	"strings"
     25 )
     26 
// test is one parsed test case: a run of "KEY = HEX" lines ended by a blank
// line or by the end of the input.
type test struct {
	// LineNumber is the input line on which the test's first attribute appears.
	LineNumber int
	// Type is the key of the test's first attribute; main switches on it to
	// pick the check to run (e.g. "Sum", "Product").
	Type       string
	// Values maps every attribute key of the test, including the Type key,
	// to its value parsed as a base-16 integer.
	Values     map[string]*big.Int
}
     32 
// testScanner splits a line-oriented input into test cases, one per call to
// Scan.
type testScanner struct {
	// scanner supplies the raw input lines.
	scanner *bufio.Scanner
	// lineNo counts the lines read so far; it is the number of the most
	// recently read line.
	lineNo  int
	// err holds a parse error recorded by setError. Err reports it in
	// preference to the underlying scanner's error.
	err     error
	// test is the test case most recently parsed by Scan.
	test    test
}
     39 
     40 func newTestScanner(r io.Reader) *testScanner {
     41 	return &testScanner{scanner: bufio.NewScanner(r)}
     42 }
     43 
     44 func (s *testScanner) scanLine() bool {
     45 	if !s.scanner.Scan() {
     46 		return false
     47 	}
     48 	s.lineNo++
     49 	return true
     50 }
     51 
     52 func (s *testScanner) addAttribute(line string) (key string, ok bool) {
     53 	fields := strings.SplitN(line, "=", 2)
     54 	if len(fields) != 2 {
     55 		s.setError(errors.New("invalid syntax"))
     56 		return "", false
     57 	}
     58 
     59 	key = strings.TrimSpace(fields[0])
     60 	value := strings.TrimSpace(fields[1])
     61 
     62 	valueInt, ok := new(big.Int).SetString(value, 16)
     63 	if !ok {
     64 		s.setError(fmt.Errorf("could not parse %q", value))
     65 		return "", false
     66 	}
     67 	if _, dup := s.test.Values[key]; dup {
     68 		s.setError(fmt.Errorf("duplicate key %q", key))
     69 		return "", false
     70 	}
     71 	s.test.Values[key] = valueInt
     72 	return key, true
     73 }
     74 
     75 func (s *testScanner) Scan() bool {
     76 	s.test = test{
     77 		Values: make(map[string]*big.Int),
     78 	}
     79 
     80 	// Scan until the first attribute.
     81 	for {
     82 		if !s.scanLine() {
     83 			return false
     84 		}
     85 		if len(s.scanner.Text()) != 0 && s.scanner.Text()[0] != '#' {
     86 			break
     87 		}
     88 	}
     89 
     90 	var ok bool
     91 	s.test.Type, ok = s.addAttribute(s.scanner.Text())
     92 	if !ok {
     93 		return false
     94 	}
     95 	s.test.LineNumber = s.lineNo
     96 
     97 	for s.scanLine() {
     98 		if len(s.scanner.Text()) == 0 {
     99 			break
    100 		}
    101 
    102 		if s.scanner.Text()[0] == '#' {
    103 			continue
    104 		}
    105 
    106 		if _, ok := s.addAttribute(s.scanner.Text()); !ok {
    107 			return false
    108 		}
    109 	}
    110 	return s.scanner.Err() == nil
    111 }
    112 
    113 func (s *testScanner) Test() test {
    114 	return s.test
    115 }
    116 
    117 func (s *testScanner) Err() error {
    118 	if s.err != nil {
    119 		return s.err
    120 	}
    121 	return s.scanner.Err()
    122 }
    123 
    124 func (s *testScanner) setError(err error) {
    125 	s.err = fmt.Errorf("line %d: %s", s.lineNo, err)
    126 }
    127 
    128 func checkKeys(t test, keys ...string) bool {
    129 	var foundErrors bool
    130 
    131 	for _, k := range keys {
    132 		if _, ok := t.Values[k]; !ok {
    133 			fmt.Fprintf(os.Stderr, "Line %d: missing key %q.\n", t.LineNumber, k)
    134 			foundErrors = true
    135 		}
    136 	}
    137 
    138 	for k, _ := range t.Values {
    139 		var found bool
    140 		for _, k2 := range keys {
    141 			if k == k2 {
    142 				found = true
    143 				break
    144 			}
    145 		}
    146 		if !found {
    147 			fmt.Fprintf(os.Stderr, "Line %d: unexpected key %q.\n", t.LineNumber, k)
    148 			foundErrors = true
    149 		}
    150 	}
    151 
    152 	return !foundErrors
    153 }
    154 
    155 func checkResult(t test, expr, key string, r *big.Int) {
    156 	if t.Values[key].Cmp(r) != 0 {
    157 		fmt.Fprintf(os.Stderr, "Line %d: %s did not match %s.\n\tGot %s\n", t.LineNumber, expr, key, r.Text(16))
    158 	}
    159 }
    160 
    161 func main() {
    162 	if len(os.Args) != 2 {
    163 		fmt.Fprintf(os.Stderr, "Usage: %s bn_tests.txt\n", os.Args[0])
    164 		os.Exit(1)
    165 	}
    166 
    167 	in, err := os.Open(os.Args[1])
    168 	if err != nil {
    169 		fmt.Fprintf(os.Stderr, "Error opening %s: %s.\n", os.Args[0], err)
    170 		os.Exit(1)
    171 	}
    172 	defer in.Close()
    173 
    174 	scanner := newTestScanner(in)
    175 	for scanner.Scan() {
    176 		test := scanner.Test()
    177 		switch test.Type {
    178 		case "Sum":
    179 			if checkKeys(test, "A", "B", "Sum") {
    180 				r := new(big.Int).Add(test.Values["A"], test.Values["B"])
    181 				checkResult(test, "A + B", "Sum", r)
    182 			}
    183 		case "LShift1":
    184 			if checkKeys(test, "A", "LShift1") {
    185 				r := new(big.Int).Add(test.Values["A"], test.Values["A"])
    186 				checkResult(test, "A + A", "LShift1", r)
    187 			}
    188 		case "LShift":
    189 			if checkKeys(test, "A", "N", "LShift") {
    190 				r := new(big.Int).Lsh(test.Values["A"], uint(test.Values["N"].Uint64()))
    191 				checkResult(test, "A << N", "LShift", r)
    192 			}
    193 		case "RShift":
    194 			if checkKeys(test, "A", "N", "RShift") {
    195 				r := new(big.Int).Rsh(test.Values["A"], uint(test.Values["N"].Uint64()))
    196 				checkResult(test, "A >> N", "RShift", r)
    197 			}
    198 		case "Square":
    199 			if checkKeys(test, "A", "Square") {
    200 				r := new(big.Int).Mul(test.Values["A"], test.Values["A"])
    201 				checkResult(test, "A * A", "Square", r)
    202 			}
    203 		case "Product":
    204 			if checkKeys(test, "A", "B", "Product") {
    205 				r := new(big.Int).Mul(test.Values["A"], test.Values["B"])
    206 				checkResult(test, "A * B", "Product", r)
    207 			}
    208 		case "Quotient":
    209 			if checkKeys(test, "A", "B", "Quotient", "Remainder") {
    210 				q, r := new(big.Int).QuoRem(test.Values["A"], test.Values["B"], new(big.Int))
    211 				checkResult(test, "A / B", "Quotient", q)
    212 				checkResult(test, "A % B", "Remainder", r)
    213 			}
    214 		case "ModMul":
    215 			if checkKeys(test, "A", "B", "M", "ModMul") {
    216 				r := new(big.Int).Mul(test.Values["A"], test.Values["B"])
    217 				r = r.Mod(r, test.Values["M"])
    218 				checkResult(test, "A * B (mod M)", "ModMul", r)
    219 			}
    220 		case "ModExp":
    221 			if checkKeys(test, "A", "E", "M", "ModExp") {
    222 				r := new(big.Int).Exp(test.Values["A"], test.Values["E"], test.Values["M"])
    223 				checkResult(test, "A ^ E (mod M)", "ModExp", r)
    224 			}
    225 		case "Exp":
    226 			if checkKeys(test, "A", "E", "Exp") {
    227 				r := new(big.Int).Exp(test.Values["A"], test.Values["E"], nil)
    228 				checkResult(test, "A ^ E", "Exp", r)
    229 			}
    230 		case "ModSqrt":
    231 			bigOne := new(big.Int).SetInt64(1)
    232 			bigTwo := new(big.Int).SetInt64(2)
    233 
    234 			if checkKeys(test, "A", "P", "ModSqrt") {
    235 				test.Values["A"].Mod(test.Values["A"], test.Values["P"])
    236 
    237 				r := new(big.Int).Mul(test.Values["ModSqrt"], test.Values["ModSqrt"])
    238 				r = r.Mod(r, test.Values["P"])
    239 				checkResult(test, "ModSqrt ^ 2 (mod P)", "A", r)
    240 
    241 				if test.Values["P"].Cmp(bigTwo) > 0 {
    242 					pMinus1Over2 := new(big.Int).Sub(test.Values["P"], bigOne)
    243 					pMinus1Over2.Rsh(pMinus1Over2, 1)
    244 
    245 					if test.Values["ModSqrt"].Cmp(pMinus1Over2) > 0 {
    246 						fmt.Fprintf(os.Stderr, "Line %d: ModSqrt should be minimal.\n", test.LineNumber)
    247 					}
    248 				}
    249 			}
    250 		case "ModInv":
    251 			if checkKeys(test, "A", "M", "ModInv") {
    252 				r := new(big.Int).ModInverse(test.Values["A"], test.Values["M"])
    253 				checkResult(test, "A ^ -1 (mod M)", "ModInv", r)
    254 			}
    255 		default:
    256 			fmt.Fprintf(os.Stderr, "Line %d: unknown test type %q.\n", test.LineNumber, test.Type)
    257 		}
    258 	}
    259 	if scanner.Err() != nil {
    260 		fmt.Fprintf(os.Stderr, "Error reading tests: %s.\n", scanner.Err())
    261 	}
    262 }
    263