1 // Copyright 2009 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package tls 6 7 import ( 8 "math/rand" 9 "reflect" 10 "testing" 11 "testing/quick" 12 ) 13 14 var tests = []interface{}{ 15 &clientHelloMsg{}, 16 &serverHelloMsg{}, 17 &finishedMsg{}, 18 19 &certificateMsg{}, 20 &certificateRequestMsg{}, 21 &certificateVerifyMsg{}, 22 &certificateStatusMsg{}, 23 &clientKeyExchangeMsg{}, 24 &nextProtoMsg{}, 25 &newSessionTicketMsg{}, 26 &sessionState{}, 27 } 28 29 type testMessage interface { 30 marshal() []byte 31 unmarshal([]byte) bool 32 equal(interface{}) bool 33 } 34 35 func TestMarshalUnmarshal(t *testing.T) { 36 rand := rand.New(rand.NewSource(0)) 37 38 for i, iface := range tests { 39 ty := reflect.ValueOf(iface).Type() 40 41 n := 100 42 if testing.Short() { 43 n = 5 44 } 45 for j := 0; j < n; j++ { 46 v, ok := quick.Value(ty, rand) 47 if !ok { 48 t.Errorf("#%d: failed to create value", i) 49 break 50 } 51 52 m1 := v.Interface().(testMessage) 53 marshaled := m1.marshal() 54 m2 := iface.(testMessage) 55 if !m2.unmarshal(marshaled) { 56 t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled) 57 break 58 } 59 m2.marshal() // to fill any marshal cache in the message 60 61 if !m1.equal(m2) { 62 t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled) 63 break 64 } 65 66 if i >= 3 { 67 // The first three message types (ClientHello, 68 // ServerHello and Finished) are allowed to 69 // have parsable prefixes because the extension 70 // data is optional and the length of the 71 // Finished varies across versions. 72 for j := 0; j < len(marshaled); j++ { 73 if m2.unmarshal(marshaled[0:j]) { 74 t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1) 75 break 76 } 77 } 78 } 79 } 80 } 81 } 82 83 func TestFuzz(t *testing.T) { 84 rand := rand.New(rand.NewSource(0)) 85 for _, iface := range tests { 86 m := iface.(testMessage) 87 88 for j := 0; j < 1000; j++ { 89 len := rand.Intn(100) 90 bytes := randomBytes(len, rand) 91 // This just looks for crashes due to bounds errors etc. 92 m.unmarshal(bytes) 93 } 94 } 95 } 96 97 func randomBytes(n int, rand *rand.Rand) []byte { 98 r := make([]byte, n) 99 for i := 0; i < n; i++ { 100 r[i] = byte(rand.Int31()) 101 } 102 return r 103 } 104 105 func randomString(n int, rand *rand.Rand) string { 106 b := randomBytes(n, rand) 107 return string(b) 108 } 109 110 func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { 111 m := &clientHelloMsg{} 112 m.vers = uint16(rand.Intn(65536)) 113 m.random = randomBytes(32, rand) 114 m.sessionId = randomBytes(rand.Intn(32), rand) 115 m.cipherSuites = make([]uint16, rand.Intn(63)+1) 116 for i := 0; i < len(m.cipherSuites); i++ { 117 m.cipherSuites[i] = uint16(rand.Int31()) 118 } 119 m.compressionMethods = randomBytes(rand.Intn(63)+1, rand) 120 if rand.Intn(10) > 5 { 121 m.nextProtoNeg = true 122 } 123 if rand.Intn(10) > 5 { 124 m.serverName = randomString(rand.Intn(255), rand) 125 } 126 m.ocspStapling = rand.Intn(10) > 5 127 m.supportedPoints = randomBytes(rand.Intn(5)+1, rand) 128 m.supportedCurves = make([]CurveID, rand.Intn(5)+1) 129 for i := range m.supportedCurves { 130 m.supportedCurves[i] = CurveID(rand.Intn(30000)) 131 } 132 if rand.Intn(10) > 5 { 133 m.ticketSupported = true 134 if rand.Intn(10) > 5 { 135 m.sessionTicket = randomBytes(rand.Intn(300), rand) 136 } 137 } 138 if rand.Intn(10) > 5 { 139 m.signatureAndHashes = supportedSignatureAlgorithms 140 } 141 m.alpnProtocols = make([]string, rand.Intn(5)) 142 for i := range m.alpnProtocols { 143 m.alpnProtocols[i] = randomString(rand.Intn(20)+1, rand) 144 } 145 if rand.Intn(10) > 5 { 146 m.scts = true 147 } 148 149 return reflect.ValueOf(m) 150 } 151 152 func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { 153 m := &serverHelloMsg{} 154 m.vers = uint16(rand.Intn(65536)) 155 m.random = randomBytes(32, rand) 156 m.sessionId = randomBytes(rand.Intn(32), rand) 157 m.cipherSuite = uint16(rand.Int31()) 158 m.compressionMethod = uint8(rand.Intn(256)) 159 160 if rand.Intn(10) > 5 { 161 m.nextProtoNeg = true 162 163 n := rand.Intn(10) 164 m.nextProtos = make([]string, n) 165 for i := 0; i < n; i++ { 166 m.nextProtos[i] = randomString(20, rand) 167 } 168 } 169 170 if rand.Intn(10) > 5 { 171 m.ocspStapling = true 172 } 173 if rand.Intn(10) > 5 { 174 m.ticketSupported = true 175 } 176 m.alpnProtocol = randomString(rand.Intn(32)+1, rand) 177 178 if rand.Intn(10) > 5 { 179 numSCTs := rand.Intn(4) 180 m.scts = make([][]byte, numSCTs) 181 for i := range m.scts { 182 m.scts[i] = randomBytes(rand.Intn(500), rand) 183 } 184 } 185 186 return reflect.ValueOf(m) 187 } 188 189 func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value { 190 m := &certificateMsg{} 191 numCerts := rand.Intn(20) 192 m.certificates = make([][]byte, numCerts) 193 for i := 0; i < numCerts; i++ { 194 m.certificates[i] = randomBytes(rand.Intn(10)+1, rand) 195 } 196 return reflect.ValueOf(m) 197 } 198 199 func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value { 200 m := &certificateRequestMsg{} 201 m.certificateTypes = randomBytes(rand.Intn(5)+1, rand) 202 numCAs := rand.Intn(100) 203 m.certificateAuthorities = make([][]byte, numCAs) 204 for i := 0; i < numCAs; i++ { 205 m.certificateAuthorities[i] = randomBytes(rand.Intn(15)+1, rand) 206 } 207 return reflect.ValueOf(m) 208 } 209 210 func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value { 211 m := &certificateVerifyMsg{} 212 m.signature = randomBytes(rand.Intn(15)+1, rand) 213 return reflect.ValueOf(m) 214 } 215 216 func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value { 217 m := &certificateStatusMsg{} 218 if rand.Intn(10) > 5 { 219 m.statusType = statusTypeOCSP 220 m.response = randomBytes(rand.Intn(10)+1, rand) 221 } else { 222 m.statusType = 42 223 } 224 return reflect.ValueOf(m) 225 } 226 227 func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value { 228 m := &clientKeyExchangeMsg{} 229 m.ciphertext = randomBytes(rand.Intn(1000)+1, rand) 230 return reflect.ValueOf(m) 231 } 232 233 func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value { 234 m := &finishedMsg{} 235 m.verifyData = randomBytes(12, rand) 236 return reflect.ValueOf(m) 237 } 238 239 func (*nextProtoMsg) Generate(rand *rand.Rand, size int) reflect.Value { 240 m := &nextProtoMsg{} 241 m.proto = randomString(rand.Intn(255), rand) 242 return reflect.ValueOf(m) 243 } 244 245 func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value { 246 m := &newSessionTicketMsg{} 247 m.ticket = randomBytes(rand.Intn(4), rand) 248 return reflect.ValueOf(m) 249 } 250 251 func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value { 252 s := &sessionState{} 253 s.vers = uint16(rand.Intn(10000)) 254 s.cipherSuite = uint16(rand.Intn(10000)) 255 s.masterSecret = randomBytes(rand.Intn(100), rand) 256 numCerts := rand.Intn(20) 257 s.certificates = make([][]byte, numCerts) 258 for i := 0; i < numCerts; i++ { 259 s.certificates[i] = randomBytes(rand.Intn(10)+1, rand) 260 } 261 return reflect.ValueOf(s) 262 } 263