1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 // This test suite uses SSLClientSocket to test the implementation of 6 // SSLServerSocket. In order to establish connections between the sockets 7 // we need two additional classes: 8 // 1. FakeSocket 9 // Connects SSL socket to FakeDataChannel. This class is just a stub. 10 // 11 // 2. FakeDataChannel 12 // Implements the actual exchange of data between two FakeSockets. 13 // 14 // Implementations of these two classes are included in this file. 15 16 #include "net/socket/ssl_server_socket.h" 17 18 #include <stdlib.h> 19 20 #include <queue> 21 22 #include "base/compiler_specific.h" 23 #include "base/file_util.h" 24 #include "base/files/file_path.h" 25 #include "base/message_loop/message_loop.h" 26 #include "base/path_service.h" 27 #include "crypto/nss_util.h" 28 #include "crypto/rsa_private_key.h" 29 #include "net/base/address_list.h" 30 #include "net/base/completion_callback.h" 31 #include "net/base/host_port_pair.h" 32 #include "net/base/io_buffer.h" 33 #include "net/base/ip_endpoint.h" 34 #include "net/base/net_errors.h" 35 #include "net/base/net_log.h" 36 #include "net/base/test_data_directory.h" 37 #include "net/cert/cert_status_flags.h" 38 #include "net/cert/mock_cert_verifier.h" 39 #include "net/cert/x509_certificate.h" 40 #include "net/http/transport_security_state.h" 41 #include "net/socket/client_socket_factory.h" 42 #include "net/socket/socket_test_util.h" 43 #include "net/socket/ssl_client_socket.h" 44 #include "net/socket/stream_socket.h" 45 #include "net/ssl/ssl_config_service.h" 46 #include "net/ssl/ssl_info.h" 47 #include "net/test/cert_test_util.h" 48 #include "testing/gtest/include/gtest/gtest.h" 49 #include "testing/platform_test.h" 50 51 namespace net { 52 53 namespace { 54 55 class FakeDataChannel { 56 public: 57 FakeDataChannel() 58 : read_buf_len_(0), 59 weak_factory_(this), 60 closed_(false), 61 write_called_after_close_(false) { 62 } 63 64 int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { 65 if (closed_) 66 return 0; 67 if (data_.empty()) { 68 read_callback_ = callback; 69 read_buf_ = buf; 70 read_buf_len_ = buf_len; 71 return net::ERR_IO_PENDING; 72 } 73 return PropogateData(buf, buf_len); 74 } 75 76 int Write(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { 77 if (closed_) { 78 if (write_called_after_close_) 79 return net::ERR_CONNECTION_RESET; 80 write_called_after_close_ = true; 81 write_callback_ = callback; 82 base::MessageLoop::current()->PostTask( 83 FROM_HERE, base::Bind(&FakeDataChannel::DoWriteCallback, 84 weak_factory_.GetWeakPtr())); 85 return net::ERR_IO_PENDING; 86 } 87 data_.push(new net::DrainableIOBuffer(buf, buf_len)); 88 base::MessageLoop::current()->PostTask( 89 FROM_HERE, base::Bind(&FakeDataChannel::DoReadCallback, 90 weak_factory_.GetWeakPtr())); 91 return buf_len; 92 } 93 94 // Closes the FakeDataChannel. After Close() is called, Read() returns 0, 95 // indicating EOF, and Write() fails with ERR_CONNECTION_RESET. Note that 96 // after the FakeDataChannel is closed, the first Write() call completes 97 // asynchronously, which is necessary to reproduce bug 127822. 98 void Close() { 99 closed_ = true; 100 } 101 102 private: 103 void DoReadCallback() { 104 if (read_callback_.is_null() || data_.empty()) 105 return; 106 107 int copied = PropogateData(read_buf_, read_buf_len_); 108 CompletionCallback callback = read_callback_; 109 read_callback_.Reset(); 110 read_buf_ = NULL; 111 read_buf_len_ = 0; 112 callback.Run(copied); 113 } 114 115 void DoWriteCallback() { 116 if (write_callback_.is_null()) 117 return; 118 119 CompletionCallback callback = write_callback_; 120 write_callback_.Reset(); 121 callback.Run(net::ERR_CONNECTION_RESET); 122 } 123 124 int PropogateData(scoped_refptr<net::IOBuffer> read_buf, int read_buf_len) { 125 scoped_refptr<net::DrainableIOBuffer> buf = data_.front(); 126 int copied = std::min(buf->BytesRemaining(), read_buf_len); 127 memcpy(read_buf->data(), buf->data(), copied); 128 buf->DidConsume(copied); 129 130 if (!buf->BytesRemaining()) 131 data_.pop(); 132 return copied; 133 } 134 135 CompletionCallback read_callback_; 136 scoped_refptr<net::IOBuffer> read_buf_; 137 int read_buf_len_; 138 139 CompletionCallback write_callback_; 140 141 std::queue<scoped_refptr<net::DrainableIOBuffer> > data_; 142 143 base::WeakPtrFactory<FakeDataChannel> weak_factory_; 144 145 // True if Close() has been called. 146 bool closed_; 147 148 // Controls the completion of Write() after the FakeDataChannel is closed. 149 // After the FakeDataChannel is closed, the first Write() call completes 150 // asynchronously. 151 bool write_called_after_close_; 152 153 DISALLOW_COPY_AND_ASSIGN(FakeDataChannel); 154 }; 155 156 class FakeSocket : public StreamSocket { 157 public: 158 FakeSocket(FakeDataChannel* incoming_channel, 159 FakeDataChannel* outgoing_channel) 160 : incoming_(incoming_channel), 161 outgoing_(outgoing_channel) { 162 } 163 164 virtual ~FakeSocket() { 165 } 166 167 virtual int Read(IOBuffer* buf, int buf_len, 168 const CompletionCallback& callback) OVERRIDE { 169 // Read random number of bytes. 170 buf_len = rand() % buf_len + 1; 171 return incoming_->Read(buf, buf_len, callback); 172 } 173 174 virtual int Write(IOBuffer* buf, int buf_len, 175 const CompletionCallback& callback) OVERRIDE { 176 // Write random number of bytes. 177 buf_len = rand() % buf_len + 1; 178 return outgoing_->Write(buf, buf_len, callback); 179 } 180 181 virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { 182 return true; 183 } 184 185 virtual bool SetSendBufferSize(int32 size) OVERRIDE { 186 return true; 187 } 188 189 virtual int Connect(const CompletionCallback& callback) OVERRIDE { 190 return net::OK; 191 } 192 193 virtual void Disconnect() OVERRIDE { 194 incoming_->Close(); 195 outgoing_->Close(); 196 } 197 198 virtual bool IsConnected() const OVERRIDE { 199 return true; 200 } 201 202 virtual bool IsConnectedAndIdle() const OVERRIDE { 203 return true; 204 } 205 206 virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { 207 net::IPAddressNumber ip_address(net::kIPv4AddressSize); 208 *address = net::IPEndPoint(ip_address, 0 /*port*/); 209 return net::OK; 210 } 211 212 virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { 213 net::IPAddressNumber ip_address(4); 214 *address = net::IPEndPoint(ip_address, 0); 215 return net::OK; 216 } 217 218 virtual const BoundNetLog& NetLog() const OVERRIDE { 219 return net_log_; 220 } 221 222 virtual void SetSubresourceSpeculation() OVERRIDE {} 223 virtual void SetOmniboxSpeculation() OVERRIDE {} 224 225 virtual bool WasEverUsed() const OVERRIDE { 226 return true; 227 } 228 229 virtual bool UsingTCPFastOpen() const OVERRIDE { 230 return false; 231 } 232 233 234 virtual bool WasNpnNegotiated() const OVERRIDE { 235 return false; 236 } 237 238 virtual NextProto GetNegotiatedProtocol() const OVERRIDE { 239 return kProtoUnknown; 240 } 241 242 virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { 243 return false; 244 } 245 246 private: 247 net::BoundNetLog net_log_; 248 FakeDataChannel* incoming_; 249 FakeDataChannel* outgoing_; 250 251 DISALLOW_COPY_AND_ASSIGN(FakeSocket); 252 }; 253 254 } // namespace 255 256 // Verify the correctness of the test helper classes first. 257 TEST(FakeSocketTest, DataTransfer) { 258 // Establish channels between two sockets. 259 FakeDataChannel channel_1; 260 FakeDataChannel channel_2; 261 FakeSocket client(&channel_1, &channel_2); 262 FakeSocket server(&channel_2, &channel_1); 263 264 const char kTestData[] = "testing123"; 265 const int kTestDataSize = strlen(kTestData); 266 const int kReadBufSize = 1024; 267 scoped_refptr<net::IOBuffer> write_buf = new net::StringIOBuffer(kTestData); 268 scoped_refptr<net::IOBuffer> read_buf = new net::IOBuffer(kReadBufSize); 269 270 // Write then read. 271 int written = 272 server.Write(write_buf.get(), kTestDataSize, CompletionCallback()); 273 EXPECT_GT(written, 0); 274 EXPECT_LE(written, kTestDataSize); 275 276 int read = client.Read(read_buf.get(), kReadBufSize, CompletionCallback()); 277 EXPECT_GT(read, 0); 278 EXPECT_LE(read, written); 279 EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), read)); 280 281 // Read then write. 282 TestCompletionCallback callback; 283 EXPECT_EQ(net::ERR_IO_PENDING, 284 server.Read(read_buf.get(), kReadBufSize, callback.callback())); 285 286 written = client.Write(write_buf.get(), kTestDataSize, CompletionCallback()); 287 EXPECT_GT(written, 0); 288 EXPECT_LE(written, kTestDataSize); 289 290 read = callback.WaitForResult(); 291 EXPECT_GT(read, 0); 292 EXPECT_LE(read, written); 293 EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), read)); 294 } 295 296 class SSLServerSocketTest : public PlatformTest { 297 public: 298 SSLServerSocketTest() 299 : socket_factory_(net::ClientSocketFactory::GetDefaultFactory()), 300 cert_verifier_(new MockCertVerifier()), 301 transport_security_state_(new TransportSecurityState) { 302 cert_verifier_->set_default_result(net::CERT_STATUS_AUTHORITY_INVALID); 303 } 304 305 protected: 306 void Initialize() { 307 FakeSocket* fake_client_socket = new FakeSocket(&channel_1_, &channel_2_); 308 FakeSocket* fake_server_socket = new FakeSocket(&channel_2_, &channel_1_); 309 310 base::FilePath certs_dir(GetTestCertsDirectory()); 311 312 base::FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der"); 313 std::string cert_der; 314 ASSERT_TRUE(file_util::ReadFileToString(cert_path, &cert_der)); 315 316 scoped_refptr<net::X509Certificate> cert = 317 X509Certificate::CreateFromBytes(cert_der.data(), cert_der.size()); 318 319 base::FilePath key_path = certs_dir.AppendASCII("unittest.key.bin"); 320 std::string key_string; 321 ASSERT_TRUE(file_util::ReadFileToString(key_path, &key_string)); 322 std::vector<uint8> key_vector( 323 reinterpret_cast<const uint8*>(key_string.data()), 324 reinterpret_cast<const uint8*>(key_string.data() + 325 key_string.length())); 326 327 scoped_ptr<crypto::RSAPrivateKey> private_key( 328 crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector)); 329 330 net::SSLConfig ssl_config; 331 ssl_config.cached_info_enabled = false; 332 ssl_config.false_start_enabled = false; 333 ssl_config.channel_id_enabled = false; 334 ssl_config.version_min = SSL_PROTOCOL_VERSION_SSL3; 335 ssl_config.version_max = SSL_PROTOCOL_VERSION_TLS1_1; 336 337 // Certificate provided by the host doesn't need authority. 338 net::SSLConfig::CertAndStatus cert_and_status; 339 cert_and_status.cert_status = CERT_STATUS_AUTHORITY_INVALID; 340 cert_and_status.der_cert = cert_der; 341 ssl_config.allowed_bad_certs.push_back(cert_and_status); 342 343 net::HostPortPair host_and_pair("unittest", 0); 344 net::SSLClientSocketContext context; 345 context.cert_verifier = cert_verifier_.get(); 346 context.transport_security_state = transport_security_state_.get(); 347 client_socket_.reset( 348 socket_factory_->CreateSSLClientSocket( 349 fake_client_socket, host_and_pair, ssl_config, context)); 350 server_socket_.reset(net::CreateSSLServerSocket( 351 fake_server_socket, cert.get(), private_key.get(), net::SSLConfig())); 352 } 353 354 FakeDataChannel channel_1_; 355 FakeDataChannel channel_2_; 356 scoped_ptr<net::SSLClientSocket> client_socket_; 357 scoped_ptr<net::SSLServerSocket> server_socket_; 358 net::ClientSocketFactory* socket_factory_; 359 scoped_ptr<net::MockCertVerifier> cert_verifier_; 360 scoped_ptr<net::TransportSecurityState> transport_security_state_; 361 }; 362 363 // SSLServerSocket is only implemented using NSS. 364 #if defined(USE_NSS) || defined(OS_WIN) || defined(OS_MACOSX) 365 366 // This test only executes creation of client and server sockets. This is to 367 // test that creation of sockets doesn't crash and have minimal code to run 368 // under valgrind in order to help debugging memory problems. 369 TEST_F(SSLServerSocketTest, Initialize) { 370 Initialize(); 371 } 372 373 // This test executes Connect() on SSLClientSocket and Handshake() on 374 // SSLServerSocket to make sure handshaking between the two sockets is 375 // completed successfully. 376 TEST_F(SSLServerSocketTest, Handshake) { 377 Initialize(); 378 379 TestCompletionCallback connect_callback; 380 TestCompletionCallback handshake_callback; 381 382 int server_ret = server_socket_->Handshake(handshake_callback.callback()); 383 EXPECT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING); 384 385 int client_ret = client_socket_->Connect(connect_callback.callback()); 386 EXPECT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING); 387 388 if (client_ret == net::ERR_IO_PENDING) { 389 EXPECT_EQ(net::OK, connect_callback.WaitForResult()); 390 } 391 if (server_ret == net::ERR_IO_PENDING) { 392 EXPECT_EQ(net::OK, handshake_callback.WaitForResult()); 393 } 394 395 // Make sure the cert status is expected. 396 SSLInfo ssl_info; 397 client_socket_->GetSSLInfo(&ssl_info); 398 EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status); 399 } 400 401 TEST_F(SSLServerSocketTest, DataTransfer) { 402 Initialize(); 403 404 TestCompletionCallback connect_callback; 405 TestCompletionCallback handshake_callback; 406 407 // Establish connection. 408 int client_ret = client_socket_->Connect(connect_callback.callback()); 409 ASSERT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING); 410 411 int server_ret = server_socket_->Handshake(handshake_callback.callback()); 412 ASSERT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING); 413 414 client_ret = connect_callback.GetResult(client_ret); 415 ASSERT_EQ(net::OK, client_ret); 416 server_ret = handshake_callback.GetResult(server_ret); 417 ASSERT_EQ(net::OK, server_ret); 418 419 const int kReadBufSize = 1024; 420 scoped_refptr<net::StringIOBuffer> write_buf = 421 new net::StringIOBuffer("testing123"); 422 scoped_refptr<net::DrainableIOBuffer> read_buf = 423 new net::DrainableIOBuffer(new net::IOBuffer(kReadBufSize), 424 kReadBufSize); 425 426 // Write then read. 427 TestCompletionCallback write_callback; 428 TestCompletionCallback read_callback; 429 server_ret = server_socket_->Write( 430 write_buf.get(), write_buf->size(), write_callback.callback()); 431 EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING); 432 client_ret = client_socket_->Read( 433 read_buf.get(), read_buf->BytesRemaining(), read_callback.callback()); 434 EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING); 435 436 server_ret = write_callback.GetResult(server_ret); 437 EXPECT_GT(server_ret, 0); 438 client_ret = read_callback.GetResult(client_ret); 439 ASSERT_GT(client_ret, 0); 440 441 read_buf->DidConsume(client_ret); 442 while (read_buf->BytesConsumed() < write_buf->size()) { 443 client_ret = client_socket_->Read( 444 read_buf.get(), read_buf->BytesRemaining(), read_callback.callback()); 445 EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING); 446 client_ret = read_callback.GetResult(client_ret); 447 ASSERT_GT(client_ret, 0); 448 read_buf->DidConsume(client_ret); 449 } 450 EXPECT_EQ(write_buf->size(), read_buf->BytesConsumed()); 451 read_buf->SetOffset(0); 452 EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size())); 453 454 // Read then write. 455 write_buf = new net::StringIOBuffer("hello123"); 456 server_ret = server_socket_->Read( 457 read_buf.get(), read_buf->BytesRemaining(), read_callback.callback()); 458 EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING); 459 client_ret = client_socket_->Write( 460 write_buf.get(), write_buf->size(), write_callback.callback()); 461 EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING); 462 463 server_ret = read_callback.GetResult(server_ret); 464 ASSERT_GT(server_ret, 0); 465 client_ret = write_callback.GetResult(client_ret); 466 EXPECT_GT(client_ret, 0); 467 468 read_buf->DidConsume(server_ret); 469 while (read_buf->BytesConsumed() < write_buf->size()) { 470 server_ret = server_socket_->Read( 471 read_buf.get(), read_buf->BytesRemaining(), read_callback.callback()); 472 EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING); 473 server_ret = read_callback.GetResult(server_ret); 474 ASSERT_GT(server_ret, 0); 475 read_buf->DidConsume(server_ret); 476 } 477 EXPECT_EQ(write_buf->size(), read_buf->BytesConsumed()); 478 read_buf->SetOffset(0); 479 EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size())); 480 } 481 482 // A regression test for bug 127822 (http://crbug.com/127822). 483 // If the server closes the connection after the handshake is finished, 484 // the client's Write() call should not cause an infinite loop. 485 // NOTE: this is a test for SSLClientSocket rather than SSLServerSocket. 486 TEST_F(SSLServerSocketTest, ClientWriteAfterServerClose) { 487 Initialize(); 488 489 TestCompletionCallback connect_callback; 490 TestCompletionCallback handshake_callback; 491 492 // Establish connection. 493 int client_ret = client_socket_->Connect(connect_callback.callback()); 494 ASSERT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING); 495 496 int server_ret = server_socket_->Handshake(handshake_callback.callback()); 497 ASSERT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING); 498 499 client_ret = connect_callback.GetResult(client_ret); 500 ASSERT_EQ(net::OK, client_ret); 501 server_ret = handshake_callback.GetResult(server_ret); 502 ASSERT_EQ(net::OK, server_ret); 503 504 scoped_refptr<net::StringIOBuffer> write_buf = 505 new net::StringIOBuffer("testing123"); 506 507 // The server closes the connection. The server needs to write some 508 // data first so that the client's Read() calls from the transport 509 // socket won't return ERR_IO_PENDING. This ensures that the client 510 // will call Read() on the transport socket again. 511 TestCompletionCallback write_callback; 512 513 server_ret = server_socket_->Write( 514 write_buf.get(), write_buf->size(), write_callback.callback()); 515 EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING); 516 517 server_ret = write_callback.GetResult(server_ret); 518 EXPECT_GT(server_ret, 0); 519 520 server_socket_->Disconnect(); 521 522 // The client writes some data. This should not cause an infinite loop. 523 client_ret = client_socket_->Write( 524 write_buf.get(), write_buf->size(), write_callback.callback()); 525 EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING); 526 527 client_ret = write_callback.GetResult(client_ret); 528 EXPECT_GT(client_ret, 0); 529 530 base::MessageLoop::current()->PostDelayedTask( 531 FROM_HERE, base::MessageLoop::QuitClosure(), 532 base::TimeDelta::FromMilliseconds(10)); 533 base::MessageLoop::current()->Run(); 534 } 535 536 // This test executes ExportKeyingMaterial() on the client and server sockets, 537 // after connecting them, and verifies that the results match. 538 // This test will fail if False Start is enabled (see crbug.com/90208). 539 TEST_F(SSLServerSocketTest, ExportKeyingMaterial) { 540 Initialize(); 541 542 TestCompletionCallback connect_callback; 543 TestCompletionCallback handshake_callback; 544 545 int client_ret = client_socket_->Connect(connect_callback.callback()); 546 ASSERT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING); 547 548 int server_ret = server_socket_->Handshake(handshake_callback.callback()); 549 ASSERT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING); 550 551 if (client_ret == net::ERR_IO_PENDING) { 552 ASSERT_EQ(net::OK, connect_callback.WaitForResult()); 553 } 554 if (server_ret == net::ERR_IO_PENDING) { 555 ASSERT_EQ(net::OK, handshake_callback.WaitForResult()); 556 } 557 558 const int kKeyingMaterialSize = 32; 559 const char* kKeyingLabel = "EXPERIMENTAL-server-socket-test"; 560 const char* kKeyingContext = ""; 561 unsigned char server_out[kKeyingMaterialSize]; 562 int rv = server_socket_->ExportKeyingMaterial(kKeyingLabel, 563 false, kKeyingContext, 564 server_out, sizeof(server_out)); 565 ASSERT_EQ(net::OK, rv); 566 567 unsigned char client_out[kKeyingMaterialSize]; 568 rv = client_socket_->ExportKeyingMaterial(kKeyingLabel, 569 false, kKeyingContext, 570 client_out, sizeof(client_out)); 571 ASSERT_EQ(net::OK, rv); 572 EXPECT_EQ(0, memcmp(server_out, client_out, sizeof(server_out))); 573 574 const char* kKeyingLabelBad = "EXPERIMENTAL-server-socket-test-bad"; 575 unsigned char client_bad[kKeyingMaterialSize]; 576 rv = client_socket_->ExportKeyingMaterial(kKeyingLabelBad, 577 false, kKeyingContext, 578 client_bad, sizeof(client_bad)); 579 ASSERT_EQ(rv, net::OK); 580 EXPECT_NE(0, memcmp(server_out, client_bad, sizeof(server_out))); 581 } 582 #endif 583 584 } // namespace net 585