1 // Copyright 2010 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package tls 6 7 import ( 8 "bytes" 9 "crypto/ecdsa" 10 "crypto/rsa" 11 "crypto/x509" 12 "encoding/base64" 13 "encoding/binary" 14 "encoding/pem" 15 "fmt" 16 "io" 17 "net" 18 "os" 19 "os/exec" 20 "path/filepath" 21 "strconv" 22 "testing" 23 "time" 24 ) 25 26 // Note: see comment in handshake_test.go for details of how the reference 27 // tests work. 28 29 // blockingSource is an io.Reader that blocks a Read call until it's closed. 30 type blockingSource chan bool 31 32 func (b blockingSource) Read([]byte) (n int, err error) { 33 <-b 34 return 0, io.EOF 35 } 36 37 // clientTest represents a test of the TLS client handshake against a reference 38 // implementation. 39 type clientTest struct { 40 // name is a freeform string identifying the test and the file in which 41 // the expected results will be stored. 42 name string 43 // command, if not empty, contains a series of arguments for the 44 // command to run for the reference server. 45 command []string 46 // config, if not nil, contains a custom Config to use for this test. 47 config *Config 48 // cert, if not empty, contains a DER-encoded certificate for the 49 // reference server. 50 cert []byte 51 // key, if not nil, contains either a *rsa.PrivateKey or 52 // *ecdsa.PrivateKey which is the private key for the reference server. 53 key interface{} 54 // extensions, if not nil, contains a list of extension data to be returned 55 // from the ServerHello. The data should be in standard TLS format with 56 // a 2-byte uint16 type, 2-byte data length, followed by the extension data. 57 extensions [][]byte 58 // validate, if not nil, is a function that will be called with the 59 // ConnectionState of the resulting connection. It returns a non-nil 60 // error if the ConnectionState is unacceptable. 61 validate func(ConnectionState) error 62 } 63 64 var defaultServerCommand = []string{"openssl", "s_server"} 65 66 // connFromCommand starts the reference server process, connects to it and 67 // returns a recordingConn for the connection. The stdin return value is a 68 // blockingSource for the stdin of the child process. It must be closed before 69 // Waiting for child. 70 func (test *clientTest) connFromCommand() (conn *recordingConn, child *exec.Cmd, stdin blockingSource, err error) { 71 cert := testRSACertificate 72 if len(test.cert) > 0 { 73 cert = test.cert 74 } 75 certPath := tempFile(string(cert)) 76 defer os.Remove(certPath) 77 78 var key interface{} = testRSAPrivateKey 79 if test.key != nil { 80 key = test.key 81 } 82 var pemType string 83 var derBytes []byte 84 switch key := key.(type) { 85 case *rsa.PrivateKey: 86 pemType = "RSA" 87 derBytes = x509.MarshalPKCS1PrivateKey(key) 88 case *ecdsa.PrivateKey: 89 pemType = "EC" 90 var err error 91 derBytes, err = x509.MarshalECPrivateKey(key) 92 if err != nil { 93 panic(err) 94 } 95 default: 96 panic("unknown key type") 97 } 98 99 var pemOut bytes.Buffer 100 pem.Encode(&pemOut, &pem.Block{Type: pemType + " PRIVATE KEY", Bytes: derBytes}) 101 102 keyPath := tempFile(string(pemOut.Bytes())) 103 defer os.Remove(keyPath) 104 105 var command []string 106 if len(test.command) > 0 { 107 command = append(command, test.command...) 108 } else { 109 command = append(command, defaultServerCommand...) 110 } 111 command = append(command, "-cert", certPath, "-certform", "DER", "-key", keyPath) 112 // serverPort contains the port that OpenSSL will listen on. OpenSSL 113 // can't take "0" as an argument here so we have to pick a number and 114 // hope that it's not in use on the machine. Since this only occurs 115 // when -update is given and thus when there's a human watching the 116 // test, this isn't too bad. 117 const serverPort = 24323 118 command = append(command, "-accept", strconv.Itoa(serverPort)) 119 120 if len(test.extensions) > 0 { 121 var serverInfo bytes.Buffer 122 for _, ext := range test.extensions { 123 pem.Encode(&serverInfo, &pem.Block{ 124 Type: fmt.Sprintf("SERVERINFO FOR EXTENSION %d", binary.BigEndian.Uint16(ext)), 125 Bytes: ext, 126 }) 127 } 128 serverInfoPath := tempFile(serverInfo.String()) 129 defer os.Remove(serverInfoPath) 130 command = append(command, "-serverinfo", serverInfoPath) 131 } 132 133 cmd := exec.Command(command[0], command[1:]...) 134 stdin = blockingSource(make(chan bool)) 135 cmd.Stdin = stdin 136 var out bytes.Buffer 137 cmd.Stdout = &out 138 cmd.Stderr = &out 139 if err := cmd.Start(); err != nil { 140 return nil, nil, nil, err 141 } 142 143 // OpenSSL does print an "ACCEPT" banner, but it does so *before* 144 // opening the listening socket, so we can't use that to wait until it 145 // has started listening. Thus we are forced to poll until we get a 146 // connection. 147 var tcpConn net.Conn 148 for i := uint(0); i < 5; i++ { 149 tcpConn, err = net.DialTCP("tcp", nil, &net.TCPAddr{ 150 IP: net.IPv4(127, 0, 0, 1), 151 Port: serverPort, 152 }) 153 if err == nil { 154 break 155 } 156 time.Sleep((1 << i) * 5 * time.Millisecond) 157 } 158 if err != nil { 159 close(stdin) 160 out.WriteTo(os.Stdout) 161 cmd.Process.Kill() 162 return nil, nil, nil, cmd.Wait() 163 } 164 165 record := &recordingConn{ 166 Conn: tcpConn, 167 } 168 169 return record, cmd, stdin, nil 170 } 171 172 func (test *clientTest) dataPath() string { 173 return filepath.Join("testdata", "Client-"+test.name) 174 } 175 176 func (test *clientTest) loadData() (flows [][]byte, err error) { 177 in, err := os.Open(test.dataPath()) 178 if err != nil { 179 return nil, err 180 } 181 defer in.Close() 182 return parseTestData(in) 183 } 184 185 func (test *clientTest) run(t *testing.T, write bool) { 186 var clientConn, serverConn net.Conn 187 var recordingConn *recordingConn 188 var childProcess *exec.Cmd 189 var stdin blockingSource 190 191 if write { 192 var err error 193 recordingConn, childProcess, stdin, err = test.connFromCommand() 194 if err != nil { 195 t.Fatalf("Failed to start subcommand: %s", err) 196 } 197 clientConn = recordingConn 198 } else { 199 clientConn, serverConn = net.Pipe() 200 } 201 202 config := test.config 203 if config == nil { 204 config = testConfig 205 } 206 client := Client(clientConn, config) 207 208 doneChan := make(chan bool) 209 go func() { 210 if _, err := client.Write([]byte("hello\n")); err != nil { 211 t.Errorf("Client.Write failed: %s", err) 212 } 213 if test.validate != nil { 214 if err := test.validate(client.ConnectionState()); err != nil { 215 t.Errorf("validate callback returned error: %s", err) 216 } 217 } 218 client.Close() 219 clientConn.Close() 220 doneChan <- true 221 }() 222 223 if !write { 224 flows, err := test.loadData() 225 if err != nil { 226 t.Fatalf("%s: failed to load data from %s: %v", test.name, test.dataPath(), err) 227 } 228 for i, b := range flows { 229 if i%2 == 1 { 230 serverConn.Write(b) 231 continue 232 } 233 bb := make([]byte, len(b)) 234 _, err := io.ReadFull(serverConn, bb) 235 if err != nil { 236 t.Fatalf("%s #%d: %s", test.name, i, err) 237 } 238 if !bytes.Equal(b, bb) { 239 t.Fatalf("%s #%d: mismatch on read: got:%x want:%x", test.name, i, bb, b) 240 } 241 } 242 serverConn.Close() 243 } 244 245 <-doneChan 246 247 if write { 248 path := test.dataPath() 249 out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) 250 if err != nil { 251 t.Fatalf("Failed to create output file: %s", err) 252 } 253 defer out.Close() 254 recordingConn.Close() 255 close(stdin) 256 childProcess.Process.Kill() 257 childProcess.Wait() 258 if len(recordingConn.flows) < 3 { 259 childProcess.Stdout.(*bytes.Buffer).WriteTo(os.Stdout) 260 t.Fatalf("Client connection didn't work") 261 } 262 recordingConn.WriteTo(out) 263 fmt.Printf("Wrote %s\n", path) 264 } 265 } 266 267 func runClientTestForVersion(t *testing.T, template *clientTest, prefix, option string) { 268 test := *template 269 test.name = prefix + test.name 270 if len(test.command) == 0 { 271 test.command = defaultClientCommand 272 } 273 test.command = append([]string(nil), test.command...) 274 test.command = append(test.command, option) 275 test.run(t, *update) 276 } 277 278 func runClientTestTLS10(t *testing.T, template *clientTest) { 279 runClientTestForVersion(t, template, "TLSv10-", "-tls1") 280 } 281 282 func runClientTestTLS11(t *testing.T, template *clientTest) { 283 runClientTestForVersion(t, template, "TLSv11-", "-tls1_1") 284 } 285 286 func runClientTestTLS12(t *testing.T, template *clientTest) { 287 runClientTestForVersion(t, template, "TLSv12-", "-tls1_2") 288 } 289 290 func TestHandshakeClientRSARC4(t *testing.T) { 291 test := &clientTest{ 292 name: "RSA-RC4", 293 command: []string{"openssl", "s_server", "-cipher", "RC4-SHA"}, 294 } 295 runClientTestTLS10(t, test) 296 runClientTestTLS11(t, test) 297 runClientTestTLS12(t, test) 298 } 299 300 func TestHandshakeClientECDHERSAAES(t *testing.T) { 301 test := &clientTest{ 302 name: "ECDHE-RSA-AES", 303 command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-AES128-SHA"}, 304 } 305 runClientTestTLS10(t, test) 306 runClientTestTLS11(t, test) 307 runClientTestTLS12(t, test) 308 } 309 310 func TestHandshakeClientECDHEECDSAAES(t *testing.T) { 311 test := &clientTest{ 312 name: "ECDHE-ECDSA-AES", 313 command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA"}, 314 cert: testECDSACertificate, 315 key: testECDSAPrivateKey, 316 } 317 runClientTestTLS10(t, test) 318 runClientTestTLS11(t, test) 319 runClientTestTLS12(t, test) 320 } 321 322 func TestHandshakeClientECDHEECDSAAESGCM(t *testing.T) { 323 test := &clientTest{ 324 name: "ECDHE-ECDSA-AES-GCM", 325 command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-GCM-SHA256"}, 326 cert: testECDSACertificate, 327 key: testECDSAPrivateKey, 328 } 329 runClientTestTLS12(t, test) 330 } 331 332 func TestHandshakeClientAES256GCMSHA384(t *testing.T) { 333 test := &clientTest{ 334 name: "ECDHE-ECDSA-AES256-GCM-SHA384", 335 command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES256-GCM-SHA384"}, 336 cert: testECDSACertificate, 337 key: testECDSAPrivateKey, 338 } 339 runClientTestTLS12(t, test) 340 } 341 342 func TestHandshakeClientCertRSA(t *testing.T) { 343 config := *testConfig 344 cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM)) 345 config.Certificates = []Certificate{cert} 346 347 test := &clientTest{ 348 name: "ClientCert-RSA-RSA", 349 command: []string{"openssl", "s_server", "-cipher", "RC4-SHA", "-verify", "1"}, 350 config: &config, 351 } 352 353 runClientTestTLS10(t, test) 354 runClientTestTLS12(t, test) 355 356 test = &clientTest{ 357 name: "ClientCert-RSA-ECDSA", 358 command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA", "-verify", "1"}, 359 config: &config, 360 cert: testECDSACertificate, 361 key: testECDSAPrivateKey, 362 } 363 364 runClientTestTLS10(t, test) 365 runClientTestTLS12(t, test) 366 367 test = &clientTest{ 368 name: "ClientCert-RSA-AES256-GCM-SHA384", 369 command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-AES256-GCM-SHA384", "-verify", "1"}, 370 config: &config, 371 cert: testRSACertificate, 372 key: testRSAPrivateKey, 373 } 374 375 runClientTestTLS12(t, test) 376 } 377 378 func TestHandshakeClientCertECDSA(t *testing.T) { 379 config := *testConfig 380 cert, _ := X509KeyPair([]byte(clientECDSACertificatePEM), []byte(clientECDSAKeyPEM)) 381 config.Certificates = []Certificate{cert} 382 383 test := &clientTest{ 384 name: "ClientCert-ECDSA-RSA", 385 command: []string{"openssl", "s_server", "-cipher", "RC4-SHA", "-verify", "1"}, 386 config: &config, 387 } 388 389 runClientTestTLS10(t, test) 390 runClientTestTLS12(t, test) 391 392 test = &clientTest{ 393 name: "ClientCert-ECDSA-ECDSA", 394 command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA", "-verify", "1"}, 395 config: &config, 396 cert: testECDSACertificate, 397 key: testECDSAPrivateKey, 398 } 399 400 runClientTestTLS10(t, test) 401 runClientTestTLS12(t, test) 402 } 403 404 func TestClientResumption(t *testing.T) { 405 serverConfig := &Config{ 406 CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA}, 407 Certificates: testConfig.Certificates, 408 } 409 410 issuer, err := x509.ParseCertificate(testRSACertificateIssuer) 411 if err != nil { 412 panic(err) 413 } 414 415 rootCAs := x509.NewCertPool() 416 rootCAs.AddCert(issuer) 417 418 clientConfig := &Config{ 419 CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, 420 ClientSessionCache: NewLRUClientSessionCache(32), 421 RootCAs: rootCAs, 422 ServerName: "example.golang", 423 } 424 425 testResumeState := func(test string, didResume bool) { 426 _, hs, err := testHandshake(clientConfig, serverConfig) 427 if err != nil { 428 t.Fatalf("%s: handshake failed: %s", test, err) 429 } 430 if hs.DidResume != didResume { 431 t.Fatalf("%s resumed: %v, expected: %v", test, hs.DidResume, didResume) 432 } 433 if didResume && (hs.PeerCertificates == nil || hs.VerifiedChains == nil) { 434 t.Fatalf("expected non-nil certificates after resumption. Got peerCertificates: %#v, verifedCertificates: %#v", hs.PeerCertificates, hs.VerifiedChains) 435 } 436 } 437 438 getTicket := func() []byte { 439 return clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.sessionTicket 440 } 441 randomKey := func() [32]byte { 442 var k [32]byte 443 if _, err := io.ReadFull(serverConfig.rand(), k[:]); err != nil { 444 t.Fatalf("Failed to read new SessionTicketKey: %s", err) 445 } 446 return k 447 } 448 449 testResumeState("Handshake", false) 450 ticket := getTicket() 451 testResumeState("Resume", true) 452 if !bytes.Equal(ticket, getTicket()) { 453 t.Fatal("first ticket doesn't match ticket after resumption") 454 } 455 456 key2 := randomKey() 457 serverConfig.SetSessionTicketKeys([][32]byte{key2}) 458 459 testResumeState("InvalidSessionTicketKey", false) 460 testResumeState("ResumeAfterInvalidSessionTicketKey", true) 461 462 serverConfig.SetSessionTicketKeys([][32]byte{randomKey(), key2}) 463 ticket = getTicket() 464 testResumeState("KeyChange", true) 465 if bytes.Equal(ticket, getTicket()) { 466 t.Fatal("new ticket wasn't included while resuming") 467 } 468 testResumeState("KeyChangeFinish", true) 469 470 clientConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_RC4_128_SHA} 471 testResumeState("DifferentCipherSuite", false) 472 testResumeState("DifferentCipherSuiteRecovers", true) 473 474 clientConfig.ClientSessionCache = nil 475 testResumeState("WithoutSessionCache", false) 476 } 477 478 func TestLRUClientSessionCache(t *testing.T) { 479 // Initialize cache of capacity 4. 480 cache := NewLRUClientSessionCache(4) 481 cs := make([]ClientSessionState, 6) 482 keys := []string{"0", "1", "2", "3", "4", "5", "6"} 483 484 // Add 4 entries to the cache and look them up. 485 for i := 0; i < 4; i++ { 486 cache.Put(keys[i], &cs[i]) 487 } 488 for i := 0; i < 4; i++ { 489 if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] { 490 t.Fatalf("session cache failed lookup for added key: %s", keys[i]) 491 } 492 } 493 494 // Add 2 more entries to the cache. First 2 should be evicted. 495 for i := 4; i < 6; i++ { 496 cache.Put(keys[i], &cs[i]) 497 } 498 for i := 0; i < 2; i++ { 499 if s, ok := cache.Get(keys[i]); ok || s != nil { 500 t.Fatalf("session cache should have evicted key: %s", keys[i]) 501 } 502 } 503 504 // Touch entry 2. LRU should evict 3 next. 505 cache.Get(keys[2]) 506 cache.Put(keys[0], &cs[0]) 507 if s, ok := cache.Get(keys[3]); ok || s != nil { 508 t.Fatalf("session cache should have evicted key 3") 509 } 510 511 // Update entry 0 in place. 512 cache.Put(keys[0], &cs[3]) 513 if s, ok := cache.Get(keys[0]); !ok || s != &cs[3] { 514 t.Fatalf("session cache failed update for key 0") 515 } 516 517 // Adding a nil entry is valid. 518 cache.Put(keys[0], nil) 519 if s, ok := cache.Get(keys[0]); !ok || s != nil { 520 t.Fatalf("failed to add nil entry to cache") 521 } 522 } 523 524 func TestHandshakeClientALPNMatch(t *testing.T) { 525 config := *testConfig 526 config.NextProtos = []string{"proto2", "proto1"} 527 528 test := &clientTest{ 529 name: "ALPN", 530 // Note that this needs OpenSSL 1.0.2 because that is the first 531 // version that supports the -alpn flag. 532 command: []string{"openssl", "s_server", "-alpn", "proto1,proto2"}, 533 config: &config, 534 validate: func(state ConnectionState) error { 535 // The server's preferences should override the client. 536 if state.NegotiatedProtocol != "proto1" { 537 return fmt.Errorf("Got protocol %q, wanted proto1", state.NegotiatedProtocol) 538 } 539 return nil 540 }, 541 } 542 runClientTestTLS12(t, test) 543 } 544 545 func TestHandshakeClientALPNNoMatch(t *testing.T) { 546 config := *testConfig 547 config.NextProtos = []string{"proto3"} 548 549 test := &clientTest{ 550 name: "ALPN-NoMatch", 551 // Note that this needs OpenSSL 1.0.2 because that is the first 552 // version that supports the -alpn flag. 553 command: []string{"openssl", "s_server", "-alpn", "proto1,proto2"}, 554 config: &config, 555 validate: func(state ConnectionState) error { 556 // There's no overlap so OpenSSL will not select a protocol. 557 if state.NegotiatedProtocol != "" { 558 return fmt.Errorf("Got protocol %q, wanted ''", state.NegotiatedProtocol) 559 } 560 return nil 561 }, 562 } 563 runClientTestTLS12(t, test) 564 } 565 566 // sctsBase64 contains data from `openssl s_client -serverinfo 18 -connect ritter.vg:443` 567 const sctsBase64 = "ABIBaQFnAHUApLkJkLQYWBSHuxOizGdwCjw1mAT5G9+443fNDsgN3BAAAAFHl5nuFgAABAMARjBEAiAcS4JdlW5nW9sElUv2zvQyPoZ6ejKrGGB03gjaBZFMLwIgc1Qbbn+hsH0RvObzhS+XZhr3iuQQJY8S9G85D9KeGPAAdgBo9pj4H2SCvjqM7rkoHUz8cVFdZ5PURNEKZ6y7T0/7xAAAAUeX4bVwAAAEAwBHMEUCIDIhFDgG2HIuADBkGuLobU5a4dlCHoJLliWJ1SYT05z6AiEAjxIoZFFPRNWMGGIjskOTMwXzQ1Wh2e7NxXE1kd1J0QsAdgDuS723dc5guuFCaR+r4Z5mow9+X7By2IMAxHuJeqj9ywAAAUhcZIqHAAAEAwBHMEUCICmJ1rBT09LpkbzxtUC+Hi7nXLR0J+2PmwLp+sJMuqK+AiEAr0NkUnEVKVhAkccIFpYDqHOlZaBsuEhWWrYpg2RtKp0=" 568 569 func TestHandshakClientSCTs(t *testing.T) { 570 config := *testConfig 571 572 scts, err := base64.StdEncoding.DecodeString(sctsBase64) 573 if err != nil { 574 t.Fatal(err) 575 } 576 577 test := &clientTest{ 578 name: "SCT", 579 // Note that this needs OpenSSL 1.0.2 because that is the first 580 // version that supports the -serverinfo flag. 581 command: []string{"openssl", "s_server"}, 582 config: &config, 583 extensions: [][]byte{scts}, 584 validate: func(state ConnectionState) error { 585 expectedSCTs := [][]byte{ 586 scts[8:125], 587 scts[127:245], 588 scts[247:], 589 } 590 if n := len(state.SignedCertificateTimestamps); n != len(expectedSCTs) { 591 return fmt.Errorf("Got %d scts, wanted %d", n, len(expectedSCTs)) 592 } 593 for i, expected := range expectedSCTs { 594 if sct := state.SignedCertificateTimestamps[i]; !bytes.Equal(sct, expected) { 595 return fmt.Errorf("SCT #%d contained %x, expected %x", i, sct, expected) 596 } 597 } 598 return nil 599 }, 600 } 601 runClientTestTLS12(t, test) 602 } 603