Home | History | Annotate | Download | only in subtle
      1 // Copyright 2009 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package subtle
      6 
      7 import (
      8 	"testing"
      9 	"testing/quick"
     10 )
     11 
     12 type TestConstantTimeCompareStruct struct {
     13 	a, b []byte
     14 	out  int
     15 }
     16 
     17 var testConstantTimeCompareData = []TestConstantTimeCompareStruct{
     18 	{[]byte{}, []byte{}, 1},
     19 	{[]byte{0x11}, []byte{0x11}, 1},
     20 	{[]byte{0x12}, []byte{0x11}, 0},
     21 	{[]byte{0x11}, []byte{0x11, 0x12}, 0},
     22 	{[]byte{0x11, 0x12}, []byte{0x11}, 0},
     23 }
     24 
     25 func TestConstantTimeCompare(t *testing.T) {
     26 	for i, test := range testConstantTimeCompareData {
     27 		if r := ConstantTimeCompare(test.a, test.b); r != test.out {
     28 			t.Errorf("#%d bad result (got %x, want %x)", i, r, test.out)
     29 		}
     30 	}
     31 }
     32 
     33 type TestConstantTimeByteEqStruct struct {
     34 	a, b uint8
     35 	out  int
     36 }
     37 
     38 var testConstandTimeByteEqData = []TestConstantTimeByteEqStruct{
     39 	{0, 0, 1},
     40 	{0, 1, 0},
     41 	{1, 0, 0},
     42 	{0xff, 0xff, 1},
     43 	{0xff, 0xfe, 0},
     44 }
     45 
     46 func byteEq(a, b uint8) int {
     47 	if a == b {
     48 		return 1
     49 	}
     50 	return 0
     51 }
     52 
     53 func TestConstantTimeByteEq(t *testing.T) {
     54 	for i, test := range testConstandTimeByteEqData {
     55 		if r := ConstantTimeByteEq(test.a, test.b); r != test.out {
     56 			t.Errorf("#%d bad result (got %x, want %x)", i, r, test.out)
     57 		}
     58 	}
     59 	err := quick.CheckEqual(ConstantTimeByteEq, byteEq, nil)
     60 	if err != nil {
     61 		t.Error(err)
     62 	}
     63 }
     64 
     65 func eq(a, b int32) int {
     66 	if a == b {
     67 		return 1
     68 	}
     69 	return 0
     70 }
     71 
     72 func TestConstantTimeEq(t *testing.T) {
     73 	err := quick.CheckEqual(ConstantTimeEq, eq, nil)
     74 	if err != nil {
     75 		t.Error(err)
     76 	}
     77 }
     78 
     79 func makeCopy(v int, x, y []byte) []byte {
     80 	if len(x) > len(y) {
     81 		x = x[0:len(y)]
     82 	} else {
     83 		y = y[0:len(x)]
     84 	}
     85 	if v == 1 {
     86 		copy(x, y)
     87 	}
     88 	return x
     89 }
     90 
     91 func constantTimeCopyWrapper(v int, x, y []byte) []byte {
     92 	if len(x) > len(y) {
     93 		x = x[0:len(y)]
     94 	} else {
     95 		y = y[0:len(x)]
     96 	}
     97 	v &= 1
     98 	ConstantTimeCopy(v, x, y)
     99 	return x
    100 }
    101 
    102 func TestConstantTimeCopy(t *testing.T) {
    103 	err := quick.CheckEqual(constantTimeCopyWrapper, makeCopy, nil)
    104 	if err != nil {
    105 		t.Error(err)
    106 	}
    107 }
    108 
    109 var lessOrEqTests = []struct {
    110 	x, y, result int
    111 }{
    112 	{0, 0, 1},
    113 	{1, 0, 0},
    114 	{0, 1, 1},
    115 	{10, 20, 1},
    116 	{20, 10, 0},
    117 	{10, 10, 1},
    118 }
    119 
    120 func TestConstantTimeLessOrEq(t *testing.T) {
    121 	for i, test := range lessOrEqTests {
    122 		result := ConstantTimeLessOrEq(test.x, test.y)
    123 		if result != test.result {
    124 			t.Errorf("#%d: %d <= %d gave %d, expected %d", i, test.x, test.y, result, test.result)
    125 		}
    126 	}
    127 }
    128