1 // Copyright 2009 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package tls 6 7 import "bytes" 8 9 type clientHelloMsg struct { 10 raw []byte 11 vers uint16 12 random []byte 13 sessionId []byte 14 cipherSuites []uint16 15 compressionMethods []uint8 16 nextProtoNeg bool 17 serverName string 18 ocspStapling bool 19 scts bool 20 supportedCurves []CurveID 21 supportedPoints []uint8 22 ticketSupported bool 23 sessionTicket []uint8 24 signatureAndHashes []signatureAndHash 25 secureRenegotiation bool 26 alpnProtocols []string 27 } 28 29 func (m *clientHelloMsg) equal(i interface{}) bool { 30 m1, ok := i.(*clientHelloMsg) 31 if !ok { 32 return false 33 } 34 35 return bytes.Equal(m.raw, m1.raw) && 36 m.vers == m1.vers && 37 bytes.Equal(m.random, m1.random) && 38 bytes.Equal(m.sessionId, m1.sessionId) && 39 eqUint16s(m.cipherSuites, m1.cipherSuites) && 40 bytes.Equal(m.compressionMethods, m1.compressionMethods) && 41 m.nextProtoNeg == m1.nextProtoNeg && 42 m.serverName == m1.serverName && 43 m.ocspStapling == m1.ocspStapling && 44 m.scts == m1.scts && 45 eqCurveIDs(m.supportedCurves, m1.supportedCurves) && 46 bytes.Equal(m.supportedPoints, m1.supportedPoints) && 47 m.ticketSupported == m1.ticketSupported && 48 bytes.Equal(m.sessionTicket, m1.sessionTicket) && 49 eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes) && 50 m.secureRenegotiation == m1.secureRenegotiation && 51 eqStrings(m.alpnProtocols, m1.alpnProtocols) 52 } 53 54 func (m *clientHelloMsg) marshal() []byte { 55 if m.raw != nil { 56 return m.raw 57 } 58 59 length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods) 60 numExtensions := 0 61 extensionsLength := 0 62 if m.nextProtoNeg { 63 numExtensions++ 64 } 65 if m.ocspStapling { 66 extensionsLength += 1 + 2 + 2 67 numExtensions++ 68 } 69 if len(m.serverName) > 0 { 70 extensionsLength += 5 + len(m.serverName) 71 numExtensions++ 72 } 73 if len(m.supportedCurves) > 0 { 74 extensionsLength += 2 + 2*len(m.supportedCurves) 75 numExtensions++ 76 } 77 if len(m.supportedPoints) > 0 { 78 extensionsLength += 1 + len(m.supportedPoints) 79 numExtensions++ 80 } 81 if m.ticketSupported { 82 extensionsLength += len(m.sessionTicket) 83 numExtensions++ 84 } 85 if len(m.signatureAndHashes) > 0 { 86 extensionsLength += 2 + 2*len(m.signatureAndHashes) 87 numExtensions++ 88 } 89 if m.secureRenegotiation { 90 extensionsLength += 1 91 numExtensions++ 92 } 93 if len(m.alpnProtocols) > 0 { 94 extensionsLength += 2 95 for _, s := range m.alpnProtocols { 96 if l := len(s); l == 0 || l > 255 { 97 panic("invalid ALPN protocol") 98 } 99 extensionsLength++ 100 extensionsLength += len(s) 101 } 102 numExtensions++ 103 } 104 if m.scts { 105 numExtensions++ 106 } 107 if numExtensions > 0 { 108 extensionsLength += 4 * numExtensions 109 length += 2 + extensionsLength 110 } 111 112 x := make([]byte, 4+length) 113 x[0] = typeClientHello 114 x[1] = uint8(length >> 16) 115 x[2] = uint8(length >> 8) 116 x[3] = uint8(length) 117 x[4] = uint8(m.vers >> 8) 118 x[5] = uint8(m.vers) 119 copy(x[6:38], m.random) 120 x[38] = uint8(len(m.sessionId)) 121 copy(x[39:39+len(m.sessionId)], m.sessionId) 122 y := x[39+len(m.sessionId):] 123 y[0] = uint8(len(m.cipherSuites) >> 7) 124 y[1] = uint8(len(m.cipherSuites) << 1) 125 for i, suite := range m.cipherSuites { 126 y[2+i*2] = uint8(suite >> 8) 127 y[3+i*2] = uint8(suite) 128 } 129 z := y[2+len(m.cipherSuites)*2:] 130 z[0] = uint8(len(m.compressionMethods)) 131 copy(z[1:], m.compressionMethods) 132 133 z = z[1+len(m.compressionMethods):] 134 if numExtensions > 0 { 135 z[0] = byte(extensionsLength >> 8) 136 z[1] = byte(extensionsLength) 137 z = z[2:] 138 } 139 if m.nextProtoNeg { 140 z[0] = byte(extensionNextProtoNeg >> 8) 141 z[1] = byte(extensionNextProtoNeg & 0xff) 142 // The length is always 0 143 z = z[4:] 144 } 145 if len(m.serverName) > 0 { 146 z[0] = byte(extensionServerName >> 8) 147 z[1] = byte(extensionServerName & 0xff) 148 l := len(m.serverName) + 5 149 z[2] = byte(l >> 8) 150 z[3] = byte(l) 151 z = z[4:] 152 153 // RFC 3546, section 3.1 154 // 155 // struct { 156 // NameType name_type; 157 // select (name_type) { 158 // case host_name: HostName; 159 // } name; 160 // } ServerName; 161 // 162 // enum { 163 // host_name(0), (255) 164 // } NameType; 165 // 166 // opaque HostName<1..2^16-1>; 167 // 168 // struct { 169 // ServerName server_name_list<1..2^16-1> 170 // } ServerNameList; 171 172 z[0] = byte((len(m.serverName) + 3) >> 8) 173 z[1] = byte(len(m.serverName) + 3) 174 z[3] = byte(len(m.serverName) >> 8) 175 z[4] = byte(len(m.serverName)) 176 copy(z[5:], []byte(m.serverName)) 177 z = z[l:] 178 } 179 if m.ocspStapling { 180 // RFC 4366, section 3.6 181 z[0] = byte(extensionStatusRequest >> 8) 182 z[1] = byte(extensionStatusRequest) 183 z[2] = 0 184 z[3] = 5 185 z[4] = 1 // OCSP type 186 // Two zero valued uint16s for the two lengths. 187 z = z[9:] 188 } 189 if len(m.supportedCurves) > 0 { 190 // http://tools.ietf.org/html/rfc4492#section-5.5.1 191 z[0] = byte(extensionSupportedCurves >> 8) 192 z[1] = byte(extensionSupportedCurves) 193 l := 2 + 2*len(m.supportedCurves) 194 z[2] = byte(l >> 8) 195 z[3] = byte(l) 196 l -= 2 197 z[4] = byte(l >> 8) 198 z[5] = byte(l) 199 z = z[6:] 200 for _, curve := range m.supportedCurves { 201 z[0] = byte(curve >> 8) 202 z[1] = byte(curve) 203 z = z[2:] 204 } 205 } 206 if len(m.supportedPoints) > 0 { 207 // http://tools.ietf.org/html/rfc4492#section-5.5.2 208 z[0] = byte(extensionSupportedPoints >> 8) 209 z[1] = byte(extensionSupportedPoints) 210 l := 1 + len(m.supportedPoints) 211 z[2] = byte(l >> 8) 212 z[3] = byte(l) 213 l-- 214 z[4] = byte(l) 215 z = z[5:] 216 for _, pointFormat := range m.supportedPoints { 217 z[0] = byte(pointFormat) 218 z = z[1:] 219 } 220 } 221 if m.ticketSupported { 222 // http://tools.ietf.org/html/rfc5077#section-3.2 223 z[0] = byte(extensionSessionTicket >> 8) 224 z[1] = byte(extensionSessionTicket) 225 l := len(m.sessionTicket) 226 z[2] = byte(l >> 8) 227 z[3] = byte(l) 228 z = z[4:] 229 copy(z, m.sessionTicket) 230 z = z[len(m.sessionTicket):] 231 } 232 if len(m.signatureAndHashes) > 0 { 233 // https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1 234 z[0] = byte(extensionSignatureAlgorithms >> 8) 235 z[1] = byte(extensionSignatureAlgorithms) 236 l := 2 + 2*len(m.signatureAndHashes) 237 z[2] = byte(l >> 8) 238 z[3] = byte(l) 239 z = z[4:] 240 241 l -= 2 242 z[0] = byte(l >> 8) 243 z[1] = byte(l) 244 z = z[2:] 245 for _, sigAndHash := range m.signatureAndHashes { 246 z[0] = sigAndHash.hash 247 z[1] = sigAndHash.signature 248 z = z[2:] 249 } 250 } 251 if m.secureRenegotiation { 252 z[0] = byte(extensionRenegotiationInfo >> 8) 253 z[1] = byte(extensionRenegotiationInfo & 0xff) 254 z[2] = 0 255 z[3] = 1 256 z = z[5:] 257 } 258 if len(m.alpnProtocols) > 0 { 259 z[0] = byte(extensionALPN >> 8) 260 z[1] = byte(extensionALPN & 0xff) 261 lengths := z[2:] 262 z = z[6:] 263 264 stringsLength := 0 265 for _, s := range m.alpnProtocols { 266 l := len(s) 267 z[0] = byte(l) 268 copy(z[1:], s) 269 z = z[1+l:] 270 stringsLength += 1 + l 271 } 272 273 lengths[2] = byte(stringsLength >> 8) 274 lengths[3] = byte(stringsLength) 275 stringsLength += 2 276 lengths[0] = byte(stringsLength >> 8) 277 lengths[1] = byte(stringsLength) 278 } 279 if m.scts { 280 // https://tools.ietf.org/html/rfc6962#section-3.3.1 281 z[0] = byte(extensionSCT >> 8) 282 z[1] = byte(extensionSCT) 283 // zero uint16 for the zero-length extension_data 284 z = z[4:] 285 } 286 287 m.raw = x 288 289 return x 290 } 291 292 func (m *clientHelloMsg) unmarshal(data []byte) bool { 293 if len(data) < 42 { 294 return false 295 } 296 m.raw = data 297 m.vers = uint16(data[4])<<8 | uint16(data[5]) 298 m.random = data[6:38] 299 sessionIdLen := int(data[38]) 300 if sessionIdLen > 32 || len(data) < 39+sessionIdLen { 301 return false 302 } 303 m.sessionId = data[39 : 39+sessionIdLen] 304 data = data[39+sessionIdLen:] 305 if len(data) < 2 { 306 return false 307 } 308 // cipherSuiteLen is the number of bytes of cipher suite numbers. Since 309 // they are uint16s, the number must be even. 310 cipherSuiteLen := int(data[0])<<8 | int(data[1]) 311 if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen { 312 return false 313 } 314 numCipherSuites := cipherSuiteLen / 2 315 m.cipherSuites = make([]uint16, numCipherSuites) 316 for i := 0; i < numCipherSuites; i++ { 317 m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i]) 318 if m.cipherSuites[i] == scsvRenegotiation { 319 m.secureRenegotiation = true 320 } 321 } 322 data = data[2+cipherSuiteLen:] 323 if len(data) < 1 { 324 return false 325 } 326 compressionMethodsLen := int(data[0]) 327 if len(data) < 1+compressionMethodsLen { 328 return false 329 } 330 m.compressionMethods = data[1 : 1+compressionMethodsLen] 331 332 data = data[1+compressionMethodsLen:] 333 334 m.nextProtoNeg = false 335 m.serverName = "" 336 m.ocspStapling = false 337 m.ticketSupported = false 338 m.sessionTicket = nil 339 m.signatureAndHashes = nil 340 m.alpnProtocols = nil 341 m.scts = false 342 343 if len(data) == 0 { 344 // ClientHello is optionally followed by extension data 345 return true 346 } 347 if len(data) < 2 { 348 return false 349 } 350 351 extensionsLength := int(data[0])<<8 | int(data[1]) 352 data = data[2:] 353 if extensionsLength != len(data) { 354 return false 355 } 356 357 for len(data) != 0 { 358 if len(data) < 4 { 359 return false 360 } 361 extension := uint16(data[0])<<8 | uint16(data[1]) 362 length := int(data[2])<<8 | int(data[3]) 363 data = data[4:] 364 if len(data) < length { 365 return false 366 } 367 368 switch extension { 369 case extensionServerName: 370 d := data[:length] 371 if len(d) < 2 { 372 return false 373 } 374 namesLen := int(d[0])<<8 | int(d[1]) 375 d = d[2:] 376 if len(d) != namesLen { 377 return false 378 } 379 for len(d) > 0 { 380 if len(d) < 3 { 381 return false 382 } 383 nameType := d[0] 384 nameLen := int(d[1])<<8 | int(d[2]) 385 d = d[3:] 386 if len(d) < nameLen { 387 return false 388 } 389 if nameType == 0 { 390 m.serverName = string(d[:nameLen]) 391 break 392 } 393 d = d[nameLen:] 394 } 395 case extensionNextProtoNeg: 396 if length > 0 { 397 return false 398 } 399 m.nextProtoNeg = true 400 case extensionStatusRequest: 401 m.ocspStapling = length > 0 && data[0] == statusTypeOCSP 402 case extensionSupportedCurves: 403 // http://tools.ietf.org/html/rfc4492#section-5.5.1 404 if length < 2 { 405 return false 406 } 407 l := int(data[0])<<8 | int(data[1]) 408 if l%2 == 1 || length != l+2 { 409 return false 410 } 411 numCurves := l / 2 412 m.supportedCurves = make([]CurveID, numCurves) 413 d := data[2:] 414 for i := 0; i < numCurves; i++ { 415 m.supportedCurves[i] = CurveID(d[0])<<8 | CurveID(d[1]) 416 d = d[2:] 417 } 418 case extensionSupportedPoints: 419 // http://tools.ietf.org/html/rfc4492#section-5.5.2 420 if length < 1 { 421 return false 422 } 423 l := int(data[0]) 424 if length != l+1 { 425 return false 426 } 427 m.supportedPoints = make([]uint8, l) 428 copy(m.supportedPoints, data[1:]) 429 case extensionSessionTicket: 430 // http://tools.ietf.org/html/rfc5077#section-3.2 431 m.ticketSupported = true 432 m.sessionTicket = data[:length] 433 case extensionSignatureAlgorithms: 434 // https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1 435 if length < 2 || length&1 != 0 { 436 return false 437 } 438 l := int(data[0])<<8 | int(data[1]) 439 if l != length-2 { 440 return false 441 } 442 n := l / 2 443 d := data[2:] 444 m.signatureAndHashes = make([]signatureAndHash, n) 445 for i := range m.signatureAndHashes { 446 m.signatureAndHashes[i].hash = d[0] 447 m.signatureAndHashes[i].signature = d[1] 448 d = d[2:] 449 } 450 case extensionRenegotiationInfo: 451 if length != 1 || data[0] != 0 { 452 return false 453 } 454 m.secureRenegotiation = true 455 case extensionALPN: 456 if length < 2 { 457 return false 458 } 459 l := int(data[0])<<8 | int(data[1]) 460 if l != length-2 { 461 return false 462 } 463 d := data[2:length] 464 for len(d) != 0 { 465 stringLen := int(d[0]) 466 d = d[1:] 467 if stringLen == 0 || stringLen > len(d) { 468 return false 469 } 470 m.alpnProtocols = append(m.alpnProtocols, string(d[:stringLen])) 471 d = d[stringLen:] 472 } 473 case extensionSCT: 474 m.scts = true 475 if length != 0 { 476 return false 477 } 478 } 479 data = data[length:] 480 } 481 482 return true 483 } 484 485 type serverHelloMsg struct { 486 raw []byte 487 vers uint16 488 random []byte 489 sessionId []byte 490 cipherSuite uint16 491 compressionMethod uint8 492 nextProtoNeg bool 493 nextProtos []string 494 ocspStapling bool 495 scts [][]byte 496 ticketSupported bool 497 secureRenegotiation bool 498 alpnProtocol string 499 } 500 501 func (m *serverHelloMsg) equal(i interface{}) bool { 502 m1, ok := i.(*serverHelloMsg) 503 if !ok { 504 return false 505 } 506 507 if len(m.scts) != len(m1.scts) { 508 return false 509 } 510 for i, sct := range m.scts { 511 if !bytes.Equal(sct, m1.scts[i]) { 512 return false 513 } 514 } 515 516 return bytes.Equal(m.raw, m1.raw) && 517 m.vers == m1.vers && 518 bytes.Equal(m.random, m1.random) && 519 bytes.Equal(m.sessionId, m1.sessionId) && 520 m.cipherSuite == m1.cipherSuite && 521 m.compressionMethod == m1.compressionMethod && 522 m.nextProtoNeg == m1.nextProtoNeg && 523 eqStrings(m.nextProtos, m1.nextProtos) && 524 m.ocspStapling == m1.ocspStapling && 525 m.ticketSupported == m1.ticketSupported && 526 m.secureRenegotiation == m1.secureRenegotiation && 527 m.alpnProtocol == m1.alpnProtocol 528 } 529 530 func (m *serverHelloMsg) marshal() []byte { 531 if m.raw != nil { 532 return m.raw 533 } 534 535 length := 38 + len(m.sessionId) 536 numExtensions := 0 537 extensionsLength := 0 538 539 nextProtoLen := 0 540 if m.nextProtoNeg { 541 numExtensions++ 542 for _, v := range m.nextProtos { 543 nextProtoLen += len(v) 544 } 545 nextProtoLen += len(m.nextProtos) 546 extensionsLength += nextProtoLen 547 } 548 if m.ocspStapling { 549 numExtensions++ 550 } 551 if m.ticketSupported { 552 numExtensions++ 553 } 554 if m.secureRenegotiation { 555 extensionsLength += 1 556 numExtensions++ 557 } 558 if alpnLen := len(m.alpnProtocol); alpnLen > 0 { 559 if alpnLen >= 256 { 560 panic("invalid ALPN protocol") 561 } 562 extensionsLength += 2 + 1 + alpnLen 563 numExtensions++ 564 } 565 sctLen := 0 566 if len(m.scts) > 0 { 567 for _, sct := range m.scts { 568 sctLen += len(sct) + 2 569 } 570 extensionsLength += 2 + sctLen 571 numExtensions++ 572 } 573 574 if numExtensions > 0 { 575 extensionsLength += 4 * numExtensions 576 length += 2 + extensionsLength 577 } 578 579 x := make([]byte, 4+length) 580 x[0] = typeServerHello 581 x[1] = uint8(length >> 16) 582 x[2] = uint8(length >> 8) 583 x[3] = uint8(length) 584 x[4] = uint8(m.vers >> 8) 585 x[5] = uint8(m.vers) 586 copy(x[6:38], m.random) 587 x[38] = uint8(len(m.sessionId)) 588 copy(x[39:39+len(m.sessionId)], m.sessionId) 589 z := x[39+len(m.sessionId):] 590 z[0] = uint8(m.cipherSuite >> 8) 591 z[1] = uint8(m.cipherSuite) 592 z[2] = uint8(m.compressionMethod) 593 594 z = z[3:] 595 if numExtensions > 0 { 596 z[0] = byte(extensionsLength >> 8) 597 z[1] = byte(extensionsLength) 598 z = z[2:] 599 } 600 if m.nextProtoNeg { 601 z[0] = byte(extensionNextProtoNeg >> 8) 602 z[1] = byte(extensionNextProtoNeg & 0xff) 603 z[2] = byte(nextProtoLen >> 8) 604 z[3] = byte(nextProtoLen) 605 z = z[4:] 606 607 for _, v := range m.nextProtos { 608 l := len(v) 609 if l > 255 { 610 l = 255 611 } 612 z[0] = byte(l) 613 copy(z[1:], []byte(v[0:l])) 614 z = z[1+l:] 615 } 616 } 617 if m.ocspStapling { 618 z[0] = byte(extensionStatusRequest >> 8) 619 z[1] = byte(extensionStatusRequest) 620 z = z[4:] 621 } 622 if m.ticketSupported { 623 z[0] = byte(extensionSessionTicket >> 8) 624 z[1] = byte(extensionSessionTicket) 625 z = z[4:] 626 } 627 if m.secureRenegotiation { 628 z[0] = byte(extensionRenegotiationInfo >> 8) 629 z[1] = byte(extensionRenegotiationInfo & 0xff) 630 z[2] = 0 631 z[3] = 1 632 z = z[5:] 633 } 634 if alpnLen := len(m.alpnProtocol); alpnLen > 0 { 635 z[0] = byte(extensionALPN >> 8) 636 z[1] = byte(extensionALPN & 0xff) 637 l := 2 + 1 + alpnLen 638 z[2] = byte(l >> 8) 639 z[3] = byte(l) 640 l -= 2 641 z[4] = byte(l >> 8) 642 z[5] = byte(l) 643 l -= 1 644 z[6] = byte(l) 645 copy(z[7:], []byte(m.alpnProtocol)) 646 z = z[7+alpnLen:] 647 } 648 if sctLen > 0 { 649 z[0] = byte(extensionSCT >> 8) 650 z[1] = byte(extensionSCT) 651 l := sctLen + 2 652 z[2] = byte(l >> 8) 653 z[3] = byte(l) 654 z[4] = byte(sctLen >> 8) 655 z[5] = byte(sctLen) 656 657 z = z[6:] 658 for _, sct := range m.scts { 659 z[0] = byte(len(sct) >> 8) 660 z[1] = byte(len(sct)) 661 copy(z[2:], sct) 662 z = z[len(sct)+2:] 663 } 664 } 665 666 m.raw = x 667 668 return x 669 } 670 671 func (m *serverHelloMsg) unmarshal(data []byte) bool { 672 if len(data) < 42 { 673 return false 674 } 675 m.raw = data 676 m.vers = uint16(data[4])<<8 | uint16(data[5]) 677 m.random = data[6:38] 678 sessionIdLen := int(data[38]) 679 if sessionIdLen > 32 || len(data) < 39+sessionIdLen { 680 return false 681 } 682 m.sessionId = data[39 : 39+sessionIdLen] 683 data = data[39+sessionIdLen:] 684 if len(data) < 3 { 685 return false 686 } 687 m.cipherSuite = uint16(data[0])<<8 | uint16(data[1]) 688 m.compressionMethod = data[2] 689 data = data[3:] 690 691 m.nextProtoNeg = false 692 m.nextProtos = nil 693 m.ocspStapling = false 694 m.scts = nil 695 m.ticketSupported = false 696 m.alpnProtocol = "" 697 698 if len(data) == 0 { 699 // ServerHello is optionally followed by extension data 700 return true 701 } 702 if len(data) < 2 { 703 return false 704 } 705 706 extensionsLength := int(data[0])<<8 | int(data[1]) 707 data = data[2:] 708 if len(data) != extensionsLength { 709 return false 710 } 711 712 for len(data) != 0 { 713 if len(data) < 4 { 714 return false 715 } 716 extension := uint16(data[0])<<8 | uint16(data[1]) 717 length := int(data[2])<<8 | int(data[3]) 718 data = data[4:] 719 if len(data) < length { 720 return false 721 } 722 723 switch extension { 724 case extensionNextProtoNeg: 725 m.nextProtoNeg = true 726 d := data[:length] 727 for len(d) > 0 { 728 l := int(d[0]) 729 d = d[1:] 730 if l == 0 || l > len(d) { 731 return false 732 } 733 m.nextProtos = append(m.nextProtos, string(d[:l])) 734 d = d[l:] 735 } 736 case extensionStatusRequest: 737 if length > 0 { 738 return false 739 } 740 m.ocspStapling = true 741 case extensionSessionTicket: 742 if length > 0 { 743 return false 744 } 745 m.ticketSupported = true 746 case extensionRenegotiationInfo: 747 if length != 1 || data[0] != 0 { 748 return false 749 } 750 m.secureRenegotiation = true 751 case extensionALPN: 752 d := data[:length] 753 if len(d) < 3 { 754 return false 755 } 756 l := int(d[0])<<8 | int(d[1]) 757 if l != len(d)-2 { 758 return false 759 } 760 d = d[2:] 761 l = int(d[0]) 762 if l != len(d)-1 { 763 return false 764 } 765 d = d[1:] 766 m.alpnProtocol = string(d) 767 case extensionSCT: 768 d := data[:length] 769 770 if len(d) < 2 { 771 return false 772 } 773 l := int(d[0])<<8 | int(d[1]) 774 d = d[2:] 775 if len(d) != l { 776 return false 777 } 778 if l == 0 { 779 continue 780 } 781 782 m.scts = make([][]byte, 0, 3) 783 for len(d) != 0 { 784 if len(d) < 2 { 785 return false 786 } 787 sctLen := int(d[0])<<8 | int(d[1]) 788 d = d[2:] 789 if len(d) < sctLen { 790 return false 791 } 792 m.scts = append(m.scts, d[:sctLen]) 793 d = d[sctLen:] 794 } 795 } 796 data = data[length:] 797 } 798 799 return true 800 } 801 802 type certificateMsg struct { 803 raw []byte 804 certificates [][]byte 805 } 806 807 func (m *certificateMsg) equal(i interface{}) bool { 808 m1, ok := i.(*certificateMsg) 809 if !ok { 810 return false 811 } 812 813 return bytes.Equal(m.raw, m1.raw) && 814 eqByteSlices(m.certificates, m1.certificates) 815 } 816 817 func (m *certificateMsg) marshal() (x []byte) { 818 if m.raw != nil { 819 return m.raw 820 } 821 822 var i int 823 for _, slice := range m.certificates { 824 i += len(slice) 825 } 826 827 length := 3 + 3*len(m.certificates) + i 828 x = make([]byte, 4+length) 829 x[0] = typeCertificate 830 x[1] = uint8(length >> 16) 831 x[2] = uint8(length >> 8) 832 x[3] = uint8(length) 833 834 certificateOctets := length - 3 835 x[4] = uint8(certificateOctets >> 16) 836 x[5] = uint8(certificateOctets >> 8) 837 x[6] = uint8(certificateOctets) 838 839 y := x[7:] 840 for _, slice := range m.certificates { 841 y[0] = uint8(len(slice) >> 16) 842 y[1] = uint8(len(slice) >> 8) 843 y[2] = uint8(len(slice)) 844 copy(y[3:], slice) 845 y = y[3+len(slice):] 846 } 847 848 m.raw = x 849 return 850 } 851 852 func (m *certificateMsg) unmarshal(data []byte) bool { 853 if len(data) < 7 { 854 return false 855 } 856 857 m.raw = data 858 certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6]) 859 if uint32(len(data)) != certsLen+7 { 860 return false 861 } 862 863 numCerts := 0 864 d := data[7:] 865 for certsLen > 0 { 866 if len(d) < 4 { 867 return false 868 } 869 certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2]) 870 if uint32(len(d)) < 3+certLen { 871 return false 872 } 873 d = d[3+certLen:] 874 certsLen -= 3 + certLen 875 numCerts++ 876 } 877 878 m.certificates = make([][]byte, numCerts) 879 d = data[7:] 880 for i := 0; i < numCerts; i++ { 881 certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2]) 882 m.certificates[i] = d[3 : 3+certLen] 883 d = d[3+certLen:] 884 } 885 886 return true 887 } 888 889 type serverKeyExchangeMsg struct { 890 raw []byte 891 key []byte 892 } 893 894 func (m *serverKeyExchangeMsg) equal(i interface{}) bool { 895 m1, ok := i.(*serverKeyExchangeMsg) 896 if !ok { 897 return false 898 } 899 900 return bytes.Equal(m.raw, m1.raw) && 901 bytes.Equal(m.key, m1.key) 902 } 903 904 func (m *serverKeyExchangeMsg) marshal() []byte { 905 if m.raw != nil { 906 return m.raw 907 } 908 length := len(m.key) 909 x := make([]byte, length+4) 910 x[0] = typeServerKeyExchange 911 x[1] = uint8(length >> 16) 912 x[2] = uint8(length >> 8) 913 x[3] = uint8(length) 914 copy(x[4:], m.key) 915 916 m.raw = x 917 return x 918 } 919 920 func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool { 921 m.raw = data 922 if len(data) < 4 { 923 return false 924 } 925 m.key = data[4:] 926 return true 927 } 928 929 type certificateStatusMsg struct { 930 raw []byte 931 statusType uint8 932 response []byte 933 } 934 935 func (m *certificateStatusMsg) equal(i interface{}) bool { 936 m1, ok := i.(*certificateStatusMsg) 937 if !ok { 938 return false 939 } 940 941 return bytes.Equal(m.raw, m1.raw) && 942 m.statusType == m1.statusType && 943 bytes.Equal(m.response, m1.response) 944 } 945 946 func (m *certificateStatusMsg) marshal() []byte { 947 if m.raw != nil { 948 return m.raw 949 } 950 951 var x []byte 952 if m.statusType == statusTypeOCSP { 953 x = make([]byte, 4+4+len(m.response)) 954 x[0] = typeCertificateStatus 955 l := len(m.response) + 4 956 x[1] = byte(l >> 16) 957 x[2] = byte(l >> 8) 958 x[3] = byte(l) 959 x[4] = statusTypeOCSP 960 961 l -= 4 962 x[5] = byte(l >> 16) 963 x[6] = byte(l >> 8) 964 x[7] = byte(l) 965 copy(x[8:], m.response) 966 } else { 967 x = []byte{typeCertificateStatus, 0, 0, 1, m.statusType} 968 } 969 970 m.raw = x 971 return x 972 } 973 974 func (m *certificateStatusMsg) unmarshal(data []byte) bool { 975 m.raw = data 976 if len(data) < 5 { 977 return false 978 } 979 m.statusType = data[4] 980 981 m.response = nil 982 if m.statusType == statusTypeOCSP { 983 if len(data) < 8 { 984 return false 985 } 986 respLen := uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7]) 987 if uint32(len(data)) != 4+4+respLen { 988 return false 989 } 990 m.response = data[8:] 991 } 992 return true 993 } 994 995 type serverHelloDoneMsg struct{} 996 997 func (m *serverHelloDoneMsg) equal(i interface{}) bool { 998 _, ok := i.(*serverHelloDoneMsg) 999 return ok 1000 } 1001 1002 func (m *serverHelloDoneMsg) marshal() []byte { 1003 x := make([]byte, 4) 1004 x[0] = typeServerHelloDone 1005 return x 1006 } 1007 1008 func (m *serverHelloDoneMsg) unmarshal(data []byte) bool { 1009 return len(data) == 4 1010 } 1011 1012 type clientKeyExchangeMsg struct { 1013 raw []byte 1014 ciphertext []byte 1015 } 1016 1017 func (m *clientKeyExchangeMsg) equal(i interface{}) bool { 1018 m1, ok := i.(*clientKeyExchangeMsg) 1019 if !ok { 1020 return false 1021 } 1022 1023 return bytes.Equal(m.raw, m1.raw) && 1024 bytes.Equal(m.ciphertext, m1.ciphertext) 1025 } 1026 1027 func (m *clientKeyExchangeMsg) marshal() []byte { 1028 if m.raw != nil { 1029 return m.raw 1030 } 1031 length := len(m.ciphertext) 1032 x := make([]byte, length+4) 1033 x[0] = typeClientKeyExchange 1034 x[1] = uint8(length >> 16) 1035 x[2] = uint8(length >> 8) 1036 x[3] = uint8(length) 1037 copy(x[4:], m.ciphertext) 1038 1039 m.raw = x 1040 return x 1041 } 1042 1043 func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool { 1044 m.raw = data 1045 if len(data) < 4 { 1046 return false 1047 } 1048 l := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) 1049 if l != len(data)-4 { 1050 return false 1051 } 1052 m.ciphertext = data[4:] 1053 return true 1054 } 1055 1056 type finishedMsg struct { 1057 raw []byte 1058 verifyData []byte 1059 } 1060 1061 func (m *finishedMsg) equal(i interface{}) bool { 1062 m1, ok := i.(*finishedMsg) 1063 if !ok { 1064 return false 1065 } 1066 1067 return bytes.Equal(m.raw, m1.raw) && 1068 bytes.Equal(m.verifyData, m1.verifyData) 1069 } 1070 1071 func (m *finishedMsg) marshal() (x []byte) { 1072 if m.raw != nil { 1073 return m.raw 1074 } 1075 1076 x = make([]byte, 4+len(m.verifyData)) 1077 x[0] = typeFinished 1078 x[3] = byte(len(m.verifyData)) 1079 copy(x[4:], m.verifyData) 1080 m.raw = x 1081 return 1082 } 1083 1084 func (m *finishedMsg) unmarshal(data []byte) bool { 1085 m.raw = data 1086 if len(data) < 4 { 1087 return false 1088 } 1089 m.verifyData = data[4:] 1090 return true 1091 } 1092 1093 type nextProtoMsg struct { 1094 raw []byte 1095 proto string 1096 } 1097 1098 func (m *nextProtoMsg) equal(i interface{}) bool { 1099 m1, ok := i.(*nextProtoMsg) 1100 if !ok { 1101 return false 1102 } 1103 1104 return bytes.Equal(m.raw, m1.raw) && 1105 m.proto == m1.proto 1106 } 1107 1108 func (m *nextProtoMsg) marshal() []byte { 1109 if m.raw != nil { 1110 return m.raw 1111 } 1112 l := len(m.proto) 1113 if l > 255 { 1114 l = 255 1115 } 1116 1117 padding := 32 - (l+2)%32 1118 length := l + padding + 2 1119 x := make([]byte, length+4) 1120 x[0] = typeNextProtocol 1121 x[1] = uint8(length >> 16) 1122 x[2] = uint8(length >> 8) 1123 x[3] = uint8(length) 1124 1125 y := x[4:] 1126 y[0] = byte(l) 1127 copy(y[1:], []byte(m.proto[0:l])) 1128 y = y[1+l:] 1129 y[0] = byte(padding) 1130 1131 m.raw = x 1132 1133 return x 1134 } 1135 1136 func (m *nextProtoMsg) unmarshal(data []byte) bool { 1137 m.raw = data 1138 1139 if len(data) < 5 { 1140 return false 1141 } 1142 data = data[4:] 1143 protoLen := int(data[0]) 1144 data = data[1:] 1145 if len(data) < protoLen { 1146 return false 1147 } 1148 m.proto = string(data[0:protoLen]) 1149 data = data[protoLen:] 1150 1151 if len(data) < 1 { 1152 return false 1153 } 1154 paddingLen := int(data[0]) 1155 data = data[1:] 1156 if len(data) != paddingLen { 1157 return false 1158 } 1159 1160 return true 1161 } 1162 1163 type certificateRequestMsg struct { 1164 raw []byte 1165 // hasSignatureAndHash indicates whether this message includes a list 1166 // of signature and hash functions. This change was introduced with TLS 1167 // 1.2. 1168 hasSignatureAndHash bool 1169 1170 certificateTypes []byte 1171 signatureAndHashes []signatureAndHash 1172 certificateAuthorities [][]byte 1173 } 1174 1175 func (m *certificateRequestMsg) equal(i interface{}) bool { 1176 m1, ok := i.(*certificateRequestMsg) 1177 if !ok { 1178 return false 1179 } 1180 1181 return bytes.Equal(m.raw, m1.raw) && 1182 bytes.Equal(m.certificateTypes, m1.certificateTypes) && 1183 eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) && 1184 eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes) 1185 } 1186 1187 func (m *certificateRequestMsg) marshal() (x []byte) { 1188 if m.raw != nil { 1189 return m.raw 1190 } 1191 1192 // See http://tools.ietf.org/html/rfc4346#section-7.4.4 1193 length := 1 + len(m.certificateTypes) + 2 1194 casLength := 0 1195 for _, ca := range m.certificateAuthorities { 1196 casLength += 2 + len(ca) 1197 } 1198 length += casLength 1199 1200 if m.hasSignatureAndHash { 1201 length += 2 + 2*len(m.signatureAndHashes) 1202 } 1203 1204 x = make([]byte, 4+length) 1205 x[0] = typeCertificateRequest 1206 x[1] = uint8(length >> 16) 1207 x[2] = uint8(length >> 8) 1208 x[3] = uint8(length) 1209 1210 x[4] = uint8(len(m.certificateTypes)) 1211 1212 copy(x[5:], m.certificateTypes) 1213 y := x[5+len(m.certificateTypes):] 1214 1215 if m.hasSignatureAndHash { 1216 n := len(m.signatureAndHashes) * 2 1217 y[0] = uint8(n >> 8) 1218 y[1] = uint8(n) 1219 y = y[2:] 1220 for _, sigAndHash := range m.signatureAndHashes { 1221 y[0] = sigAndHash.hash 1222 y[1] = sigAndHash.signature 1223 y = y[2:] 1224 } 1225 } 1226 1227 y[0] = uint8(casLength >> 8) 1228 y[1] = uint8(casLength) 1229 y = y[2:] 1230 for _, ca := range m.certificateAuthorities { 1231 y[0] = uint8(len(ca) >> 8) 1232 y[1] = uint8(len(ca)) 1233 y = y[2:] 1234 copy(y, ca) 1235 y = y[len(ca):] 1236 } 1237 1238 m.raw = x 1239 return 1240 } 1241 1242 func (m *certificateRequestMsg) unmarshal(data []byte) bool { 1243 m.raw = data 1244 1245 if len(data) < 5 { 1246 return false 1247 } 1248 1249 length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) 1250 if uint32(len(data))-4 != length { 1251 return false 1252 } 1253 1254 numCertTypes := int(data[4]) 1255 data = data[5:] 1256 if numCertTypes == 0 || len(data) <= numCertTypes { 1257 return false 1258 } 1259 1260 m.certificateTypes = make([]byte, numCertTypes) 1261 if copy(m.certificateTypes, data) != numCertTypes { 1262 return false 1263 } 1264 1265 data = data[numCertTypes:] 1266 1267 if m.hasSignatureAndHash { 1268 if len(data) < 2 { 1269 return false 1270 } 1271 sigAndHashLen := uint16(data[0])<<8 | uint16(data[1]) 1272 data = data[2:] 1273 if sigAndHashLen&1 != 0 { 1274 return false 1275 } 1276 if len(data) < int(sigAndHashLen) { 1277 return false 1278 } 1279 numSigAndHash := sigAndHashLen / 2 1280 m.signatureAndHashes = make([]signatureAndHash, numSigAndHash) 1281 for i := range m.signatureAndHashes { 1282 m.signatureAndHashes[i].hash = data[0] 1283 m.signatureAndHashes[i].signature = data[1] 1284 data = data[2:] 1285 } 1286 } 1287 1288 if len(data) < 2 { 1289 return false 1290 } 1291 casLength := uint16(data[0])<<8 | uint16(data[1]) 1292 data = data[2:] 1293 if len(data) < int(casLength) { 1294 return false 1295 } 1296 cas := make([]byte, casLength) 1297 copy(cas, data) 1298 data = data[casLength:] 1299 1300 m.certificateAuthorities = nil 1301 for len(cas) > 0 { 1302 if len(cas) < 2 { 1303 return false 1304 } 1305 caLen := uint16(cas[0])<<8 | uint16(cas[1]) 1306 cas = cas[2:] 1307 1308 if len(cas) < int(caLen) { 1309 return false 1310 } 1311 1312 m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen]) 1313 cas = cas[caLen:] 1314 } 1315 if len(data) > 0 { 1316 return false 1317 } 1318 1319 return true 1320 } 1321 1322 type certificateVerifyMsg struct { 1323 raw []byte 1324 hasSignatureAndHash bool 1325 signatureAndHash signatureAndHash 1326 signature []byte 1327 } 1328 1329 func (m *certificateVerifyMsg) equal(i interface{}) bool { 1330 m1, ok := i.(*certificateVerifyMsg) 1331 if !ok { 1332 return false 1333 } 1334 1335 return bytes.Equal(m.raw, m1.raw) && 1336 m.hasSignatureAndHash == m1.hasSignatureAndHash && 1337 m.signatureAndHash.hash == m1.signatureAndHash.hash && 1338 m.signatureAndHash.signature == m1.signatureAndHash.signature && 1339 bytes.Equal(m.signature, m1.signature) 1340 } 1341 1342 func (m *certificateVerifyMsg) marshal() (x []byte) { 1343 if m.raw != nil { 1344 return m.raw 1345 } 1346 1347 // See http://tools.ietf.org/html/rfc4346#section-7.4.8 1348 siglength := len(m.signature) 1349 length := 2 + siglength 1350 if m.hasSignatureAndHash { 1351 length += 2 1352 } 1353 x = make([]byte, 4+length) 1354 x[0] = typeCertificateVerify 1355 x[1] = uint8(length >> 16) 1356 x[2] = uint8(length >> 8) 1357 x[3] = uint8(length) 1358 y := x[4:] 1359 if m.hasSignatureAndHash { 1360 y[0] = m.signatureAndHash.hash 1361 y[1] = m.signatureAndHash.signature 1362 y = y[2:] 1363 } 1364 y[0] = uint8(siglength >> 8) 1365 y[1] = uint8(siglength) 1366 copy(y[2:], m.signature) 1367 1368 m.raw = x 1369 1370 return 1371 } 1372 1373 func (m *certificateVerifyMsg) unmarshal(data []byte) bool { 1374 m.raw = data 1375 1376 if len(data) < 6 { 1377 return false 1378 } 1379 1380 length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) 1381 if uint32(len(data))-4 != length { 1382 return false 1383 } 1384 1385 data = data[4:] 1386 if m.hasSignatureAndHash { 1387 m.signatureAndHash.hash = data[0] 1388 m.signatureAndHash.signature = data[1] 1389 data = data[2:] 1390 } 1391 1392 if len(data) < 2 { 1393 return false 1394 } 1395 siglength := int(data[0])<<8 + int(data[1]) 1396 data = data[2:] 1397 if len(data) != siglength { 1398 return false 1399 } 1400 1401 m.signature = data 1402 1403 return true 1404 } 1405 1406 type newSessionTicketMsg struct { 1407 raw []byte 1408 ticket []byte 1409 } 1410 1411 func (m *newSessionTicketMsg) equal(i interface{}) bool { 1412 m1, ok := i.(*newSessionTicketMsg) 1413 if !ok { 1414 return false 1415 } 1416 1417 return bytes.Equal(m.raw, m1.raw) && 1418 bytes.Equal(m.ticket, m1.ticket) 1419 } 1420 1421 func (m *newSessionTicketMsg) marshal() (x []byte) { 1422 if m.raw != nil { 1423 return m.raw 1424 } 1425 1426 // See http://tools.ietf.org/html/rfc5077#section-3.3 1427 ticketLen := len(m.ticket) 1428 length := 2 + 4 + ticketLen 1429 x = make([]byte, 4+length) 1430 x[0] = typeNewSessionTicket 1431 x[1] = uint8(length >> 16) 1432 x[2] = uint8(length >> 8) 1433 x[3] = uint8(length) 1434 x[8] = uint8(ticketLen >> 8) 1435 x[9] = uint8(ticketLen) 1436 copy(x[10:], m.ticket) 1437 1438 m.raw = x 1439 1440 return 1441 } 1442 1443 func (m *newSessionTicketMsg) unmarshal(data []byte) bool { 1444 m.raw = data 1445 1446 if len(data) < 10 { 1447 return false 1448 } 1449 1450 length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) 1451 if uint32(len(data))-4 != length { 1452 return false 1453 } 1454 1455 ticketLen := int(data[8])<<8 + int(data[9]) 1456 if len(data)-10 != ticketLen { 1457 return false 1458 } 1459 1460 m.ticket = data[10:] 1461 1462 return true 1463 } 1464 1465 func eqUint16s(x, y []uint16) bool { 1466 if len(x) != len(y) { 1467 return false 1468 } 1469 for i, v := range x { 1470 if y[i] != v { 1471 return false 1472 } 1473 } 1474 return true 1475 } 1476 1477 func eqCurveIDs(x, y []CurveID) bool { 1478 if len(x) != len(y) { 1479 return false 1480 } 1481 for i, v := range x { 1482 if y[i] != v { 1483 return false 1484 } 1485 } 1486 return true 1487 } 1488 1489 func eqStrings(x, y []string) bool { 1490 if len(x) != len(y) { 1491 return false 1492 } 1493 for i, v := range x { 1494 if y[i] != v { 1495 return false 1496 } 1497 } 1498 return true 1499 } 1500 1501 func eqByteSlices(x, y [][]byte) bool { 1502 if len(x) != len(y) { 1503 return false 1504 } 1505 for i, v := range x { 1506 if !bytes.Equal(v, y[i]) { 1507 return false 1508 } 1509 } 1510 return true 1511 } 1512 1513 func eqSignatureAndHashes(x, y []signatureAndHash) bool { 1514 if len(x) != len(y) { 1515 return false 1516 } 1517 for i, v := range x { 1518 v2 := y[i] 1519 if v.hash != v2.hash || v.signature != v2.signature { 1520 return false 1521 } 1522 } 1523 return true 1524 } 1525