Home | History | Annotate | Download | only in tls
      1 // Copyright 2009 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package tls
      6 
      7 import (
      8 	"bytes"
      9 	"math/rand"
     10 	"reflect"
     11 	"testing"
     12 	"testing/quick"
     13 )
     14 
     15 var tests = []interface{}{
     16 	&clientHelloMsg{},
     17 	&serverHelloMsg{},
     18 	&finishedMsg{},
     19 
     20 	&certificateMsg{},
     21 	&certificateRequestMsg{},
     22 	&certificateVerifyMsg{},
     23 	&certificateStatusMsg{},
     24 	&clientKeyExchangeMsg{},
     25 	&nextProtoMsg{},
     26 	&newSessionTicketMsg{},
     27 	&sessionState{},
     28 }
     29 
     30 type testMessage interface {
     31 	marshal() []byte
     32 	unmarshal([]byte) bool
     33 	equal(interface{}) bool
     34 }
     35 
     36 func TestMarshalUnmarshal(t *testing.T) {
     37 	rand := rand.New(rand.NewSource(0))
     38 
     39 	for i, iface := range tests {
     40 		ty := reflect.ValueOf(iface).Type()
     41 
     42 		n := 100
     43 		if testing.Short() {
     44 			n = 5
     45 		}
     46 		for j := 0; j < n; j++ {
     47 			v, ok := quick.Value(ty, rand)
     48 			if !ok {
     49 				t.Errorf("#%d: failed to create value", i)
     50 				break
     51 			}
     52 
     53 			m1 := v.Interface().(testMessage)
     54 			marshaled := m1.marshal()
     55 			m2 := iface.(testMessage)
     56 			if !m2.unmarshal(marshaled) {
     57 				t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
     58 				break
     59 			}
     60 			m2.marshal() // to fill any marshal cache in the message
     61 
     62 			if !m1.equal(m2) {
     63 				t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
     64 				break
     65 			}
     66 
     67 			if i >= 3 {
     68 				// The first three message types (ClientHello,
     69 				// ServerHello and Finished) are allowed to
     70 				// have parsable prefixes because the extension
     71 				// data is optional and the length of the
     72 				// Finished varies across versions.
     73 				for j := 0; j < len(marshaled); j++ {
     74 					if m2.unmarshal(marshaled[0:j]) {
     75 						t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
     76 						break
     77 					}
     78 				}
     79 			}
     80 		}
     81 	}
     82 }
     83 
     84 func TestFuzz(t *testing.T) {
     85 	rand := rand.New(rand.NewSource(0))
     86 	for _, iface := range tests {
     87 		m := iface.(testMessage)
     88 
     89 		for j := 0; j < 1000; j++ {
     90 			len := rand.Intn(100)
     91 			bytes := randomBytes(len, rand)
     92 			// This just looks for crashes due to bounds errors etc.
     93 			m.unmarshal(bytes)
     94 		}
     95 	}
     96 }
     97 
     98 func randomBytes(n int, rand *rand.Rand) []byte {
     99 	r := make([]byte, n)
    100 	for i := 0; i < n; i++ {
    101 		r[i] = byte(rand.Int31())
    102 	}
    103 	return r
    104 }
    105 
    106 func randomString(n int, rand *rand.Rand) string {
    107 	b := randomBytes(n, rand)
    108 	return string(b)
    109 }
    110 
    111 func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
    112 	m := &clientHelloMsg{}
    113 	m.vers = uint16(rand.Intn(65536))
    114 	m.random = randomBytes(32, rand)
    115 	m.sessionId = randomBytes(rand.Intn(32), rand)
    116 	m.cipherSuites = make([]uint16, rand.Intn(63)+1)
    117 	for i := 0; i < len(m.cipherSuites); i++ {
    118 		m.cipherSuites[i] = uint16(rand.Int31())
    119 	}
    120 	m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
    121 	if rand.Intn(10) > 5 {
    122 		m.nextProtoNeg = true
    123 	}
    124 	if rand.Intn(10) > 5 {
    125 		m.serverName = randomString(rand.Intn(255), rand)
    126 	}
    127 	m.ocspStapling = rand.Intn(10) > 5
    128 	m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
    129 	m.supportedCurves = make([]CurveID, rand.Intn(5)+1)
    130 	for i := range m.supportedCurves {
    131 		m.supportedCurves[i] = CurveID(rand.Intn(30000))
    132 	}
    133 	if rand.Intn(10) > 5 {
    134 		m.ticketSupported = true
    135 		if rand.Intn(10) > 5 {
    136 			m.sessionTicket = randomBytes(rand.Intn(300), rand)
    137 		}
    138 	}
    139 	if rand.Intn(10) > 5 {
    140 		m.signatureAndHashes = supportedSignatureAlgorithms
    141 	}
    142 	m.alpnProtocols = make([]string, rand.Intn(5))
    143 	for i := range m.alpnProtocols {
    144 		m.alpnProtocols[i] = randomString(rand.Intn(20)+1, rand)
    145 	}
    146 	if rand.Intn(10) > 5 {
    147 		m.scts = true
    148 	}
    149 
    150 	return reflect.ValueOf(m)
    151 }
    152 
    153 func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
    154 	m := &serverHelloMsg{}
    155 	m.vers = uint16(rand.Intn(65536))
    156 	m.random = randomBytes(32, rand)
    157 	m.sessionId = randomBytes(rand.Intn(32), rand)
    158 	m.cipherSuite = uint16(rand.Int31())
    159 	m.compressionMethod = uint8(rand.Intn(256))
    160 
    161 	if rand.Intn(10) > 5 {
    162 		m.nextProtoNeg = true
    163 
    164 		n := rand.Intn(10)
    165 		m.nextProtos = make([]string, n)
    166 		for i := 0; i < n; i++ {
    167 			m.nextProtos[i] = randomString(20, rand)
    168 		}
    169 	}
    170 
    171 	if rand.Intn(10) > 5 {
    172 		m.ocspStapling = true
    173 	}
    174 	if rand.Intn(10) > 5 {
    175 		m.ticketSupported = true
    176 	}
    177 	m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
    178 
    179 	if rand.Intn(10) > 5 {
    180 		numSCTs := rand.Intn(4)
    181 		m.scts = make([][]byte, numSCTs)
    182 		for i := range m.scts {
    183 			m.scts[i] = randomBytes(rand.Intn(500), rand)
    184 		}
    185 	}
    186 
    187 	return reflect.ValueOf(m)
    188 }
    189 
    190 func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
    191 	m := &certificateMsg{}
    192 	numCerts := rand.Intn(20)
    193 	m.certificates = make([][]byte, numCerts)
    194 	for i := 0; i < numCerts; i++ {
    195 		m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
    196 	}
    197 	return reflect.ValueOf(m)
    198 }
    199 
    200 func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
    201 	m := &certificateRequestMsg{}
    202 	m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
    203 	numCAs := rand.Intn(100)
    204 	m.certificateAuthorities = make([][]byte, numCAs)
    205 	for i := 0; i < numCAs; i++ {
    206 		m.certificateAuthorities[i] = randomBytes(rand.Intn(15)+1, rand)
    207 	}
    208 	return reflect.ValueOf(m)
    209 }
    210 
    211 func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
    212 	m := &certificateVerifyMsg{}
    213 	m.signature = randomBytes(rand.Intn(15)+1, rand)
    214 	return reflect.ValueOf(m)
    215 }
    216 
    217 func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
    218 	m := &certificateStatusMsg{}
    219 	if rand.Intn(10) > 5 {
    220 		m.statusType = statusTypeOCSP
    221 		m.response = randomBytes(rand.Intn(10)+1, rand)
    222 	} else {
    223 		m.statusType = 42
    224 	}
    225 	return reflect.ValueOf(m)
    226 }
    227 
    228 func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
    229 	m := &clientKeyExchangeMsg{}
    230 	m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
    231 	return reflect.ValueOf(m)
    232 }
    233 
    234 func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
    235 	m := &finishedMsg{}
    236 	m.verifyData = randomBytes(12, rand)
    237 	return reflect.ValueOf(m)
    238 }
    239 
    240 func (*nextProtoMsg) Generate(rand *rand.Rand, size int) reflect.Value {
    241 	m := &nextProtoMsg{}
    242 	m.proto = randomString(rand.Intn(255), rand)
    243 	return reflect.ValueOf(m)
    244 }
    245 
    246 func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
    247 	m := &newSessionTicketMsg{}
    248 	m.ticket = randomBytes(rand.Intn(4), rand)
    249 	return reflect.ValueOf(m)
    250 }
    251 
    252 func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
    253 	s := &sessionState{}
    254 	s.vers = uint16(rand.Intn(10000))
    255 	s.cipherSuite = uint16(rand.Intn(10000))
    256 	s.masterSecret = randomBytes(rand.Intn(100), rand)
    257 	numCerts := rand.Intn(20)
    258 	s.certificates = make([][]byte, numCerts)
    259 	for i := 0; i < numCerts; i++ {
    260 		s.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
    261 	}
    262 	return reflect.ValueOf(s)
    263 }
    264 
    265 func TestRejectEmptySCTList(t *testing.T) {
    266 	// https://tools.ietf.org/html/rfc6962#section-3.3.1 specifies that
    267 	// empty SCT lists are invalid.
    268 
    269 	var random [32]byte
    270 	sct := []byte{0x42, 0x42, 0x42, 0x42}
    271 	serverHello := serverHelloMsg{
    272 		vers:   VersionTLS12,
    273 		random: random[:],
    274 		scts:   [][]byte{sct},
    275 	}
    276 	serverHelloBytes := serverHello.marshal()
    277 
    278 	var serverHelloCopy serverHelloMsg
    279 	if !serverHelloCopy.unmarshal(serverHelloBytes) {
    280 		t.Fatal("Failed to unmarshal initial message")
    281 	}
    282 
    283 	// Change serverHelloBytes so that the SCT list is empty
    284 	i := bytes.Index(serverHelloBytes, sct)
    285 	if i < 0 {
    286 		t.Fatal("Cannot find SCT in ServerHello")
    287 	}
    288 
    289 	var serverHelloEmptySCT []byte
    290 	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...)
    291 	// Append the extension length and SCT list length for an empty list.
    292 	serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...)
    293 	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...)
    294 
    295 	// Update the handshake message length.
    296 	serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16)
    297 	serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8)
    298 	serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4)
    299 
    300 	// Update the extensions length
    301 	serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8)
    302 	serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44))
    303 
    304 	if serverHelloCopy.unmarshal(serverHelloEmptySCT) {
    305 		t.Fatal("Unmarshaled ServerHello with empty SCT list")
    306 	}
    307 }
    308 
    309 func TestRejectEmptySCT(t *testing.T) {
    310 	// Not only must the SCT list be non-empty, but the SCT elements must
    311 	// not be zero length.
    312 
    313 	var random [32]byte
    314 	serverHello := serverHelloMsg{
    315 		vers:   VersionTLS12,
    316 		random: random[:],
    317 		scts:   [][]byte{nil},
    318 	}
    319 	serverHelloBytes := serverHello.marshal()
    320 
    321 	var serverHelloCopy serverHelloMsg
    322 	if serverHelloCopy.unmarshal(serverHelloBytes) {
    323 		t.Fatal("Unmarshaled ServerHello with zero-length SCT")
    324 	}
    325 }
    326