1 // Copyright 2009 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package tls 6 7 import ( 8 "bytes" 9 "math/rand" 10 "reflect" 11 "testing" 12 "testing/quick" 13 ) 14 15 var tests = []interface{}{ 16 &clientHelloMsg{}, 17 &serverHelloMsg{}, 18 &finishedMsg{}, 19 20 &certificateMsg{}, 21 &certificateRequestMsg{}, 22 &certificateVerifyMsg{}, 23 &certificateStatusMsg{}, 24 &clientKeyExchangeMsg{}, 25 &nextProtoMsg{}, 26 &newSessionTicketMsg{}, 27 &sessionState{}, 28 } 29 30 type testMessage interface { 31 marshal() []byte 32 unmarshal([]byte) bool 33 equal(interface{}) bool 34 } 35 36 func TestMarshalUnmarshal(t *testing.T) { 37 rand := rand.New(rand.NewSource(0)) 38 39 for i, iface := range tests { 40 ty := reflect.ValueOf(iface).Type() 41 42 n := 100 43 if testing.Short() { 44 n = 5 45 } 46 for j := 0; j < n; j++ { 47 v, ok := quick.Value(ty, rand) 48 if !ok { 49 t.Errorf("#%d: failed to create value", i) 50 break 51 } 52 53 m1 := v.Interface().(testMessage) 54 marshaled := m1.marshal() 55 m2 := iface.(testMessage) 56 if !m2.unmarshal(marshaled) { 57 t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled) 58 break 59 } 60 m2.marshal() // to fill any marshal cache in the message 61 62 if !m1.equal(m2) { 63 t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled) 64 break 65 } 66 67 if i >= 3 { 68 // The first three message types (ClientHello, 69 // ServerHello and Finished) are allowed to 70 // have parsable prefixes because the extension 71 // data is optional and the length of the 72 // Finished varies across versions. 73 for j := 0; j < len(marshaled); j++ { 74 if m2.unmarshal(marshaled[0:j]) { 75 t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1) 76 break 77 } 78 } 79 } 80 } 81 } 82 } 83 84 func TestFuzz(t *testing.T) { 85 rand := rand.New(rand.NewSource(0)) 86 for _, iface := range tests { 87 m := iface.(testMessage) 88 89 for j := 0; j < 1000; j++ { 90 len := rand.Intn(100) 91 bytes := randomBytes(len, rand) 92 // This just looks for crashes due to bounds errors etc. 93 m.unmarshal(bytes) 94 } 95 } 96 } 97 98 func randomBytes(n int, rand *rand.Rand) []byte { 99 r := make([]byte, n) 100 for i := 0; i < n; i++ { 101 r[i] = byte(rand.Int31()) 102 } 103 return r 104 } 105 106 func randomString(n int, rand *rand.Rand) string { 107 b := randomBytes(n, rand) 108 return string(b) 109 } 110 111 func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { 112 m := &clientHelloMsg{} 113 m.vers = uint16(rand.Intn(65536)) 114 m.random = randomBytes(32, rand) 115 m.sessionId = randomBytes(rand.Intn(32), rand) 116 m.cipherSuites = make([]uint16, rand.Intn(63)+1) 117 for i := 0; i < len(m.cipherSuites); i++ { 118 m.cipherSuites[i] = uint16(rand.Int31()) 119 } 120 m.compressionMethods = randomBytes(rand.Intn(63)+1, rand) 121 if rand.Intn(10) > 5 { 122 m.nextProtoNeg = true 123 } 124 if rand.Intn(10) > 5 { 125 m.serverName = randomString(rand.Intn(255), rand) 126 } 127 m.ocspStapling = rand.Intn(10) > 5 128 m.supportedPoints = randomBytes(rand.Intn(5)+1, rand) 129 m.supportedCurves = make([]CurveID, rand.Intn(5)+1) 130 for i := range m.supportedCurves { 131 m.supportedCurves[i] = CurveID(rand.Intn(30000)) 132 } 133 if rand.Intn(10) > 5 { 134 m.ticketSupported = true 135 if rand.Intn(10) > 5 { 136 m.sessionTicket = randomBytes(rand.Intn(300), rand) 137 } 138 } 139 if rand.Intn(10) > 5 { 140 m.signatureAndHashes = supportedSignatureAlgorithms 141 } 142 m.alpnProtocols = make([]string, rand.Intn(5)) 143 for i := range m.alpnProtocols { 144 m.alpnProtocols[i] = randomString(rand.Intn(20)+1, rand) 145 } 146 if rand.Intn(10) > 5 { 147 m.scts = true 148 } 149 150 return reflect.ValueOf(m) 151 } 152 153 func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { 154 m := &serverHelloMsg{} 155 m.vers = uint16(rand.Intn(65536)) 156 m.random = randomBytes(32, rand) 157 m.sessionId = randomBytes(rand.Intn(32), rand) 158 m.cipherSuite = uint16(rand.Int31()) 159 m.compressionMethod = uint8(rand.Intn(256)) 160 161 if rand.Intn(10) > 5 { 162 m.nextProtoNeg = true 163 164 n := rand.Intn(10) 165 m.nextProtos = make([]string, n) 166 for i := 0; i < n; i++ { 167 m.nextProtos[i] = randomString(20, rand) 168 } 169 } 170 171 if rand.Intn(10) > 5 { 172 m.ocspStapling = true 173 } 174 if rand.Intn(10) > 5 { 175 m.ticketSupported = true 176 } 177 m.alpnProtocol = randomString(rand.Intn(32)+1, rand) 178 179 if rand.Intn(10) > 5 { 180 numSCTs := rand.Intn(4) 181 m.scts = make([][]byte, numSCTs) 182 for i := range m.scts { 183 m.scts[i] = randomBytes(rand.Intn(500), rand) 184 } 185 } 186 187 return reflect.ValueOf(m) 188 } 189 190 func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value { 191 m := &certificateMsg{} 192 numCerts := rand.Intn(20) 193 m.certificates = make([][]byte, numCerts) 194 for i := 0; i < numCerts; i++ { 195 m.certificates[i] = randomBytes(rand.Intn(10)+1, rand) 196 } 197 return reflect.ValueOf(m) 198 } 199 200 func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value { 201 m := &certificateRequestMsg{} 202 m.certificateTypes = randomBytes(rand.Intn(5)+1, rand) 203 numCAs := rand.Intn(100) 204 m.certificateAuthorities = make([][]byte, numCAs) 205 for i := 0; i < numCAs; i++ { 206 m.certificateAuthorities[i] = randomBytes(rand.Intn(15)+1, rand) 207 } 208 return reflect.ValueOf(m) 209 } 210 211 func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value { 212 m := &certificateVerifyMsg{} 213 m.signature = randomBytes(rand.Intn(15)+1, rand) 214 return reflect.ValueOf(m) 215 } 216 217 func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value { 218 m := &certificateStatusMsg{} 219 if rand.Intn(10) > 5 { 220 m.statusType = statusTypeOCSP 221 m.response = randomBytes(rand.Intn(10)+1, rand) 222 } else { 223 m.statusType = 42 224 } 225 return reflect.ValueOf(m) 226 } 227 228 func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value { 229 m := &clientKeyExchangeMsg{} 230 m.ciphertext = randomBytes(rand.Intn(1000)+1, rand) 231 return reflect.ValueOf(m) 232 } 233 234 func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value { 235 m := &finishedMsg{} 236 m.verifyData = randomBytes(12, rand) 237 return reflect.ValueOf(m) 238 } 239 240 func (*nextProtoMsg) Generate(rand *rand.Rand, size int) reflect.Value { 241 m := &nextProtoMsg{} 242 m.proto = randomString(rand.Intn(255), rand) 243 return reflect.ValueOf(m) 244 } 245 246 func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value { 247 m := &newSessionTicketMsg{} 248 m.ticket = randomBytes(rand.Intn(4), rand) 249 return reflect.ValueOf(m) 250 } 251 252 func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value { 253 s := &sessionState{} 254 s.vers = uint16(rand.Intn(10000)) 255 s.cipherSuite = uint16(rand.Intn(10000)) 256 s.masterSecret = randomBytes(rand.Intn(100), rand) 257 numCerts := rand.Intn(20) 258 s.certificates = make([][]byte, numCerts) 259 for i := 0; i < numCerts; i++ { 260 s.certificates[i] = randomBytes(rand.Intn(10)+1, rand) 261 } 262 return reflect.ValueOf(s) 263 } 264 265 func TestRejectEmptySCTList(t *testing.T) { 266 // https://tools.ietf.org/html/rfc6962#section-3.3.1 specifies that 267 // empty SCT lists are invalid. 268 269 var random [32]byte 270 sct := []byte{0x42, 0x42, 0x42, 0x42} 271 serverHello := serverHelloMsg{ 272 vers: VersionTLS12, 273 random: random[:], 274 scts: [][]byte{sct}, 275 } 276 serverHelloBytes := serverHello.marshal() 277 278 var serverHelloCopy serverHelloMsg 279 if !serverHelloCopy.unmarshal(serverHelloBytes) { 280 t.Fatal("Failed to unmarshal initial message") 281 } 282 283 // Change serverHelloBytes so that the SCT list is empty 284 i := bytes.Index(serverHelloBytes, sct) 285 if i < 0 { 286 t.Fatal("Cannot find SCT in ServerHello") 287 } 288 289 var serverHelloEmptySCT []byte 290 serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...) 291 // Append the extension length and SCT list length for an empty list. 292 serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...) 293 serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...) 294 295 // Update the handshake message length. 296 serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16) 297 serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8) 298 serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4) 299 300 // Update the extensions length 301 serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8) 302 serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44)) 303 304 if serverHelloCopy.unmarshal(serverHelloEmptySCT) { 305 t.Fatal("Unmarshaled ServerHello with empty SCT list") 306 } 307 } 308 309 func TestRejectEmptySCT(t *testing.T) { 310 // Not only must the SCT list be non-empty, but the SCT elements must 311 // not be zero length. 312 313 var random [32]byte 314 serverHello := serverHelloMsg{ 315 vers: VersionTLS12, 316 random: random[:], 317 scts: [][]byte{nil}, 318 } 319 serverHelloBytes := serverHello.marshal() 320 321 var serverHelloCopy serverHelloMsg 322 if serverHelloCopy.unmarshal(serverHelloBytes) { 323 t.Fatal("Unmarshaled ServerHello with zero-length SCT") 324 } 325 } 326