Home | History | Annotate | Download | only in tls
      1 // Copyright 2009 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package tls
      6 
      7 import (
      8 	"math/rand"
      9 	"reflect"
     10 	"testing"
     11 	"testing/quick"
     12 )
     13 
// tests lists one zero-valued instance of every handshake message type under
// test. The dynamic type of each element is what quick.Value uses to build
// random messages, and the instance itself is reused as the target of
// unmarshal in both TestMarshalUnmarshal and TestFuzz.
//
// ClientHello, ServerHello and Finished are deliberately listed first (and set
// apart by a blank line): TestMarshalUnmarshal may depend on that position when
// deciding which types are exempt from its parsable-prefix check, so review it
// before reordering this slice.
var tests = []interface{}{
	&clientHelloMsg{},
	&serverHelloMsg{},
	&finishedMsg{},

	&certificateMsg{},
	&certificateRequestMsg{},
	&certificateVerifyMsg{},
	&certificateStatusMsg{},
	&clientKeyExchangeMsg{},
	&nextProtoMsg{},
	&newSessionTicketMsg{},
	&sessionState{},
}
     28 
     29 type testMessage interface {
     30 	marshal() []byte
     31 	unmarshal([]byte) bool
     32 	equal(interface{}) bool
     33 }
     34 
     35 func TestMarshalUnmarshal(t *testing.T) {
     36 	rand := rand.New(rand.NewSource(0))
     37 
     38 	for i, iface := range tests {
     39 		ty := reflect.ValueOf(iface).Type()
     40 
     41 		n := 100
     42 		if testing.Short() {
     43 			n = 5
     44 		}
     45 		for j := 0; j < n; j++ {
     46 			v, ok := quick.Value(ty, rand)
     47 			if !ok {
     48 				t.Errorf("#%d: failed to create value", i)
     49 				break
     50 			}
     51 
     52 			m1 := v.Interface().(testMessage)
     53 			marshaled := m1.marshal()
     54 			m2 := iface.(testMessage)
     55 			if !m2.unmarshal(marshaled) {
     56 				t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
     57 				break
     58 			}
     59 			m2.marshal() // to fill any marshal cache in the message
     60 
     61 			if !m1.equal(m2) {
     62 				t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
     63 				break
     64 			}
     65 
     66 			if i >= 3 {
     67 				// The first three message types (ClientHello,
     68 				// ServerHello and Finished) are allowed to
     69 				// have parsable prefixes because the extension
     70 				// data is optional and the length of the
     71 				// Finished varies across versions.
     72 				for j := 0; j < len(marshaled); j++ {
     73 					if m2.unmarshal(marshaled[0:j]) {
     74 						t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
     75 						break
     76 					}
     77 				}
     78 			}
     79 		}
     80 	}
     81 }
     82 
     83 func TestFuzz(t *testing.T) {
     84 	rand := rand.New(rand.NewSource(0))
     85 	for _, iface := range tests {
     86 		m := iface.(testMessage)
     87 
     88 		for j := 0; j < 1000; j++ {
     89 			len := rand.Intn(100)
     90 			bytes := randomBytes(len, rand)
     91 			// This just looks for crashes due to bounds errors etc.
     92 			m.unmarshal(bytes)
     93 		}
     94 	}
     95 }
     96 
     97 func randomBytes(n int, rand *rand.Rand) []byte {
     98 	r := make([]byte, n)
     99 	for i := 0; i < n; i++ {
    100 		r[i] = byte(rand.Int31())
    101 	}
    102 	return r
    103 }
    104 
    105 func randomString(n int, rand *rand.Rand) string {
    106 	b := randomBytes(n, rand)
    107 	return string(b)
    108 }
    109 
    110 func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
    111 	m := &clientHelloMsg{}
    112 	m.vers = uint16(rand.Intn(65536))
    113 	m.random = randomBytes(32, rand)
    114 	m.sessionId = randomBytes(rand.Intn(32), rand)
    115 	m.cipherSuites = make([]uint16, rand.Intn(63)+1)
    116 	for i := 0; i < len(m.cipherSuites); i++ {
    117 		m.cipherSuites[i] = uint16(rand.Int31())
    118 	}
    119 	m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
    120 	if rand.Intn(10) > 5 {
    121 		m.nextProtoNeg = true
    122 	}
    123 	if rand.Intn(10) > 5 {
    124 		m.serverName = randomString(rand.Intn(255), rand)
    125 	}
    126 	m.ocspStapling = rand.Intn(10) > 5
    127 	m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
    128 	m.supportedCurves = make([]CurveID, rand.Intn(5)+1)
    129 	for i := range m.supportedCurves {
    130 		m.supportedCurves[i] = CurveID(rand.Intn(30000))
    131 	}
    132 	if rand.Intn(10) > 5 {
    133 		m.ticketSupported = true
    134 		if rand.Intn(10) > 5 {
    135 			m.sessionTicket = randomBytes(rand.Intn(300), rand)
    136 		}
    137 	}
    138 	if rand.Intn(10) > 5 {
    139 		m.signatureAndHashes = supportedSignatureAlgorithms
    140 	}
    141 	m.alpnProtocols = make([]string, rand.Intn(5))
    142 	for i := range m.alpnProtocols {
    143 		m.alpnProtocols[i] = randomString(rand.Intn(20)+1, rand)
    144 	}
    145 	if rand.Intn(10) > 5 {
    146 		m.scts = true
    147 	}
    148 
    149 	return reflect.ValueOf(m)
    150 }
    151 
    152 func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
    153 	m := &serverHelloMsg{}
    154 	m.vers = uint16(rand.Intn(65536))
    155 	m.random = randomBytes(32, rand)
    156 	m.sessionId = randomBytes(rand.Intn(32), rand)
    157 	m.cipherSuite = uint16(rand.Int31())
    158 	m.compressionMethod = uint8(rand.Intn(256))
    159 
    160 	if rand.Intn(10) > 5 {
    161 		m.nextProtoNeg = true
    162 
    163 		n := rand.Intn(10)
    164 		m.nextProtos = make([]string, n)
    165 		for i := 0; i < n; i++ {
    166 			m.nextProtos[i] = randomString(20, rand)
    167 		}
    168 	}
    169 
    170 	if rand.Intn(10) > 5 {
    171 		m.ocspStapling = true
    172 	}
    173 	if rand.Intn(10) > 5 {
    174 		m.ticketSupported = true
    175 	}
    176 	m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
    177 
    178 	if rand.Intn(10) > 5 {
    179 		numSCTs := rand.Intn(4)
    180 		m.scts = make([][]byte, numSCTs)
    181 		for i := range m.scts {
    182 			m.scts[i] = randomBytes(rand.Intn(500), rand)
    183 		}
    184 	}
    185 
    186 	return reflect.ValueOf(m)
    187 }
    188 
    189 func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
    190 	m := &certificateMsg{}
    191 	numCerts := rand.Intn(20)
    192 	m.certificates = make([][]byte, numCerts)
    193 	for i := 0; i < numCerts; i++ {
    194 		m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
    195 	}
    196 	return reflect.ValueOf(m)
    197 }
    198 
    199 func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
    200 	m := &certificateRequestMsg{}
    201 	m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
    202 	numCAs := rand.Intn(100)
    203 	m.certificateAuthorities = make([][]byte, numCAs)
    204 	for i := 0; i < numCAs; i++ {
    205 		m.certificateAuthorities[i] = randomBytes(rand.Intn(15)+1, rand)
    206 	}
    207 	return reflect.ValueOf(m)
    208 }
    209 
    210 func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
    211 	m := &certificateVerifyMsg{}
    212 	m.signature = randomBytes(rand.Intn(15)+1, rand)
    213 	return reflect.ValueOf(m)
    214 }
    215 
    216 func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
    217 	m := &certificateStatusMsg{}
    218 	if rand.Intn(10) > 5 {
    219 		m.statusType = statusTypeOCSP
    220 		m.response = randomBytes(rand.Intn(10)+1, rand)
    221 	} else {
    222 		m.statusType = 42
    223 	}
    224 	return reflect.ValueOf(m)
    225 }
    226 
    227 func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
    228 	m := &clientKeyExchangeMsg{}
    229 	m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
    230 	return reflect.ValueOf(m)
    231 }
    232 
    233 func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
    234 	m := &finishedMsg{}
    235 	m.verifyData = randomBytes(12, rand)
    236 	return reflect.ValueOf(m)
    237 }
    238 
    239 func (*nextProtoMsg) Generate(rand *rand.Rand, size int) reflect.Value {
    240 	m := &nextProtoMsg{}
    241 	m.proto = randomString(rand.Intn(255), rand)
    242 	return reflect.ValueOf(m)
    243 }
    244 
    245 func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
    246 	m := &newSessionTicketMsg{}
    247 	m.ticket = randomBytes(rand.Intn(4), rand)
    248 	return reflect.ValueOf(m)
    249 }
    250 
    251 func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
    252 	s := &sessionState{}
    253 	s.vers = uint16(rand.Intn(10000))
    254 	s.cipherSuite = uint16(rand.Intn(10000))
    255 	s.masterSecret = randomBytes(rand.Intn(100), rand)
    256 	numCerts := rand.Intn(20)
    257 	s.certificates = make([][]byte, numCerts)
    258 	for i := 0; i < numCerts; i++ {
    259 		s.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
    260 	}
    261 	return reflect.ValueOf(s)
    262 }
    263