Home | History | Annotate | Download | only in tls
      1 // Copyright 2009 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package tls
      6 
      7 import (
      8 	"bytes"
      9 	"math/rand"
     10 	"reflect"
     11 	"strings"
     12 	"testing"
     13 	"testing/quick"
     14 )
     15 
     16 var tests = []interface{}{
     17 	&clientHelloMsg{},
     18 	&serverHelloMsg{},
     19 	&finishedMsg{},
     20 
     21 	&certificateMsg{},
     22 	&certificateRequestMsg{},
     23 	&certificateVerifyMsg{},
     24 	&certificateStatusMsg{},
     25 	&clientKeyExchangeMsg{},
     26 	&nextProtoMsg{},
     27 	&newSessionTicketMsg{},
     28 	&sessionState{},
     29 }
     30 
     31 type testMessage interface {
     32 	marshal() []byte
     33 	unmarshal([]byte) bool
     34 	equal(interface{}) bool
     35 }
     36 
     37 func TestMarshalUnmarshal(t *testing.T) {
     38 	rand := rand.New(rand.NewSource(0))
     39 
     40 	for i, iface := range tests {
     41 		ty := reflect.ValueOf(iface).Type()
     42 
     43 		n := 100
     44 		if testing.Short() {
     45 			n = 5
     46 		}
     47 		for j := 0; j < n; j++ {
     48 			v, ok := quick.Value(ty, rand)
     49 			if !ok {
     50 				t.Errorf("#%d: failed to create value", i)
     51 				break
     52 			}
     53 
     54 			m1 := v.Interface().(testMessage)
     55 			marshaled := m1.marshal()
     56 			m2 := iface.(testMessage)
     57 			if !m2.unmarshal(marshaled) {
     58 				t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
     59 				break
     60 			}
     61 			m2.marshal() // to fill any marshal cache in the message
     62 
     63 			if !m1.equal(m2) {
     64 				t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
     65 				break
     66 			}
     67 
     68 			if i >= 3 {
     69 				// The first three message types (ClientHello,
     70 				// ServerHello and Finished) are allowed to
     71 				// have parsable prefixes because the extension
     72 				// data is optional and the length of the
     73 				// Finished varies across versions.
     74 				for j := 0; j < len(marshaled); j++ {
     75 					if m2.unmarshal(marshaled[0:j]) {
     76 						t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
     77 						break
     78 					}
     79 				}
     80 			}
     81 		}
     82 	}
     83 }
     84 
     85 func TestFuzz(t *testing.T) {
     86 	rand := rand.New(rand.NewSource(0))
     87 	for _, iface := range tests {
     88 		m := iface.(testMessage)
     89 
     90 		for j := 0; j < 1000; j++ {
     91 			len := rand.Intn(100)
     92 			bytes := randomBytes(len, rand)
     93 			// This just looks for crashes due to bounds errors etc.
     94 			m.unmarshal(bytes)
     95 		}
     96 	}
     97 }
     98 
     99 func randomBytes(n int, rand *rand.Rand) []byte {
    100 	r := make([]byte, n)
    101 	if _, err := rand.Read(r); err != nil {
    102 		panic("rand.Read failed: " + err.Error())
    103 	}
    104 	return r
    105 }
    106 
    107 func randomString(n int, rand *rand.Rand) string {
    108 	b := randomBytes(n, rand)
    109 	return string(b)
    110 }
    111 
    112 func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
    113 	m := &clientHelloMsg{}
    114 	m.vers = uint16(rand.Intn(65536))
    115 	m.random = randomBytes(32, rand)
    116 	m.sessionId = randomBytes(rand.Intn(32), rand)
    117 	m.cipherSuites = make([]uint16, rand.Intn(63)+1)
    118 	for i := 0; i < len(m.cipherSuites); i++ {
    119 		cs := uint16(rand.Int31())
    120 		if cs == scsvRenegotiation {
    121 			cs += 1
    122 		}
    123 		m.cipherSuites[i] = cs
    124 	}
    125 	m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
    126 	if rand.Intn(10) > 5 {
    127 		m.nextProtoNeg = true
    128 	}
    129 	if rand.Intn(10) > 5 {
    130 		m.serverName = randomString(rand.Intn(255), rand)
    131 		for strings.HasSuffix(m.serverName, ".") {
    132 			m.serverName = m.serverName[:len(m.serverName)-1]
    133 		}
    134 	}
    135 	m.ocspStapling = rand.Intn(10) > 5
    136 	m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
    137 	m.supportedCurves = make([]CurveID, rand.Intn(5)+1)
    138 	for i := range m.supportedCurves {
    139 		m.supportedCurves[i] = CurveID(rand.Intn(30000))
    140 	}
    141 	if rand.Intn(10) > 5 {
    142 		m.ticketSupported = true
    143 		if rand.Intn(10) > 5 {
    144 			m.sessionTicket = randomBytes(rand.Intn(300), rand)
    145 		}
    146 	}
    147 	if rand.Intn(10) > 5 {
    148 		m.supportedSignatureAlgorithms = supportedSignatureAlgorithms
    149 	}
    150 	m.alpnProtocols = make([]string, rand.Intn(5))
    151 	for i := range m.alpnProtocols {
    152 		m.alpnProtocols[i] = randomString(rand.Intn(20)+1, rand)
    153 	}
    154 	if rand.Intn(10) > 5 {
    155 		m.scts = true
    156 	}
    157 
    158 	return reflect.ValueOf(m)
    159 }
    160 
    161 func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
    162 	m := &serverHelloMsg{}
    163 	m.vers = uint16(rand.Intn(65536))
    164 	m.random = randomBytes(32, rand)
    165 	m.sessionId = randomBytes(rand.Intn(32), rand)
    166 	m.cipherSuite = uint16(rand.Int31())
    167 	m.compressionMethod = uint8(rand.Intn(256))
    168 
    169 	if rand.Intn(10) > 5 {
    170 		m.nextProtoNeg = true
    171 
    172 		n := rand.Intn(10)
    173 		m.nextProtos = make([]string, n)
    174 		for i := 0; i < n; i++ {
    175 			m.nextProtos[i] = randomString(20, rand)
    176 		}
    177 	}
    178 
    179 	if rand.Intn(10) > 5 {
    180 		m.ocspStapling = true
    181 	}
    182 	if rand.Intn(10) > 5 {
    183 		m.ticketSupported = true
    184 	}
    185 	m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
    186 
    187 	if rand.Intn(10) > 5 {
    188 		numSCTs := rand.Intn(4)
    189 		m.scts = make([][]byte, numSCTs)
    190 		for i := range m.scts {
    191 			m.scts[i] = randomBytes(rand.Intn(500), rand)
    192 		}
    193 	}
    194 
    195 	return reflect.ValueOf(m)
    196 }
    197 
    198 func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
    199 	m := &certificateMsg{}
    200 	numCerts := rand.Intn(20)
    201 	m.certificates = make([][]byte, numCerts)
    202 	for i := 0; i < numCerts; i++ {
    203 		m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
    204 	}
    205 	return reflect.ValueOf(m)
    206 }
    207 
    208 func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
    209 	m := &certificateRequestMsg{}
    210 	m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
    211 	numCAs := rand.Intn(100)
    212 	m.certificateAuthorities = make([][]byte, numCAs)
    213 	for i := 0; i < numCAs; i++ {
    214 		m.certificateAuthorities[i] = randomBytes(rand.Intn(15)+1, rand)
    215 	}
    216 	return reflect.ValueOf(m)
    217 }
    218 
    219 func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
    220 	m := &certificateVerifyMsg{}
    221 	m.signature = randomBytes(rand.Intn(15)+1, rand)
    222 	return reflect.ValueOf(m)
    223 }
    224 
    225 func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
    226 	m := &certificateStatusMsg{}
    227 	if rand.Intn(10) > 5 {
    228 		m.statusType = statusTypeOCSP
    229 		m.response = randomBytes(rand.Intn(10)+1, rand)
    230 	} else {
    231 		m.statusType = 42
    232 	}
    233 	return reflect.ValueOf(m)
    234 }
    235 
    236 func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
    237 	m := &clientKeyExchangeMsg{}
    238 	m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
    239 	return reflect.ValueOf(m)
    240 }
    241 
    242 func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
    243 	m := &finishedMsg{}
    244 	m.verifyData = randomBytes(12, rand)
    245 	return reflect.ValueOf(m)
    246 }
    247 
    248 func (*nextProtoMsg) Generate(rand *rand.Rand, size int) reflect.Value {
    249 	m := &nextProtoMsg{}
    250 	m.proto = randomString(rand.Intn(255), rand)
    251 	return reflect.ValueOf(m)
    252 }
    253 
    254 func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
    255 	m := &newSessionTicketMsg{}
    256 	m.ticket = randomBytes(rand.Intn(4), rand)
    257 	return reflect.ValueOf(m)
    258 }
    259 
    260 func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
    261 	s := &sessionState{}
    262 	s.vers = uint16(rand.Intn(10000))
    263 	s.cipherSuite = uint16(rand.Intn(10000))
    264 	s.masterSecret = randomBytes(rand.Intn(100), rand)
    265 	numCerts := rand.Intn(20)
    266 	s.certificates = make([][]byte, numCerts)
    267 	for i := 0; i < numCerts; i++ {
    268 		s.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
    269 	}
    270 	return reflect.ValueOf(s)
    271 }
    272 
    273 func TestRejectEmptySCTList(t *testing.T) {
    274 	// https://tools.ietf.org/html/rfc6962#section-3.3.1 specifies that
    275 	// empty SCT lists are invalid.
    276 
    277 	var random [32]byte
    278 	sct := []byte{0x42, 0x42, 0x42, 0x42}
    279 	serverHello := serverHelloMsg{
    280 		vers:   VersionTLS12,
    281 		random: random[:],
    282 		scts:   [][]byte{sct},
    283 	}
    284 	serverHelloBytes := serverHello.marshal()
    285 
    286 	var serverHelloCopy serverHelloMsg
    287 	if !serverHelloCopy.unmarshal(serverHelloBytes) {
    288 		t.Fatal("Failed to unmarshal initial message")
    289 	}
    290 
    291 	// Change serverHelloBytes so that the SCT list is empty
    292 	i := bytes.Index(serverHelloBytes, sct)
    293 	if i < 0 {
    294 		t.Fatal("Cannot find SCT in ServerHello")
    295 	}
    296 
    297 	var serverHelloEmptySCT []byte
    298 	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...)
    299 	// Append the extension length and SCT list length for an empty list.
    300 	serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...)
    301 	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...)
    302 
    303 	// Update the handshake message length.
    304 	serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16)
    305 	serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8)
    306 	serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4)
    307 
    308 	// Update the extensions length
    309 	serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8)
    310 	serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44))
    311 
    312 	if serverHelloCopy.unmarshal(serverHelloEmptySCT) {
    313 		t.Fatal("Unmarshaled ServerHello with empty SCT list")
    314 	}
    315 }
    316 
    317 func TestRejectEmptySCT(t *testing.T) {
    318 	// Not only must the SCT list be non-empty, but the SCT elements must
    319 	// not be zero length.
    320 
    321 	var random [32]byte
    322 	serverHello := serverHelloMsg{
    323 		vers:   VersionTLS12,
    324 		random: random[:],
    325 		scts:   [][]byte{nil},
    326 	}
    327 	serverHelloBytes := serverHello.marshal()
    328 
    329 	var serverHelloCopy serverHelloMsg
    330 	if serverHelloCopy.unmarshal(serverHelloBytes) {
    331 		t.Fatal("Unmarshaled ServerHello with zero-length SCT")
    332 	}
    333 }
    334