1 // Copyright (c) 2011 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 // This test suite uses SSLClientSocket to test the implementation of 6 // SSLServerSocket. In order to establish connections between the sockets 7 // we need two additional classes: 8 // 1. FakeSocket 9 // Connects SSL socket to FakeDataChannel. This class is just a stub. 10 // 11 // 2. FakeDataChannel 12 // Implements the actual exchange of data between two FakeSockets. 13 // 14 // Implementations of these two classes are included in this file. 15 16 #include "net/socket/ssl_server_socket.h" 17 18 #include <queue> 19 20 #include "base/file_path.h" 21 #include "base/file_util.h" 22 #include "base/path_service.h" 23 #include "crypto/nss_util.h" 24 #include "crypto/rsa_private_key.h" 25 #include "net/base/address_list.h" 26 #include "net/base/cert_status_flags.h" 27 #include "net/base/cert_verifier.h" 28 #include "net/base/host_port_pair.h" 29 #include "net/base/io_buffer.h" 30 #include "net/base/ip_endpoint.h" 31 #include "net/base/net_errors.h" 32 #include "net/base/net_log.h" 33 #include "net/base/ssl_config_service.h" 34 #include "net/base/x509_certificate.h" 35 #include "net/socket/client_socket.h" 36 #include "net/socket/client_socket_factory.h" 37 #include "net/socket/socket_test_util.h" 38 #include "net/socket/ssl_client_socket.h" 39 #include "testing/gtest/include/gtest/gtest.h" 40 #include "testing/platform_test.h" 41 42 namespace net { 43 44 namespace { 45 46 class FakeDataChannel { 47 public: 48 FakeDataChannel() : read_callback_(NULL), read_buf_len_(0) { 49 } 50 51 virtual int Read(IOBuffer* buf, int buf_len, 52 CompletionCallback* callback) { 53 if (data_.empty()) { 54 read_callback_ = callback; 55 read_buf_ = buf; 56 read_buf_len_ = buf_len; 57 return net::ERR_IO_PENDING; 58 } 59 return PropogateData(buf, buf_len); 60 } 61 62 virtual int Write(IOBuffer* buf, int buf_len, 63 CompletionCallback* callback) { 64 data_.push(new net::DrainableIOBuffer(buf, buf_len)); 65 DoReadCallback(); 66 return buf_len; 67 } 68 69 private: 70 void DoReadCallback() { 71 if (!read_callback_) 72 return; 73 74 int copied = PropogateData(read_buf_, read_buf_len_); 75 net::CompletionCallback* callback = read_callback_; 76 read_callback_ = NULL; 77 read_buf_ = NULL; 78 read_buf_len_ = 0; 79 callback->Run(copied); 80 } 81 82 int PropogateData(scoped_refptr<net::IOBuffer> read_buf, int read_buf_len) { 83 scoped_refptr<net::DrainableIOBuffer> buf = data_.front(); 84 int copied = std::min(buf->BytesRemaining(), read_buf_len); 85 memcpy(read_buf->data(), buf->data(), copied); 86 buf->DidConsume(copied); 87 88 if (!buf->BytesRemaining()) 89 data_.pop(); 90 return copied; 91 } 92 93 net::CompletionCallback* read_callback_; 94 scoped_refptr<net::IOBuffer> read_buf_; 95 int read_buf_len_; 96 97 std::queue<scoped_refptr<net::DrainableIOBuffer> > data_; 98 99 DISALLOW_COPY_AND_ASSIGN(FakeDataChannel); 100 }; 101 102 class FakeSocket : public ClientSocket { 103 public: 104 FakeSocket(FakeDataChannel* incoming_channel, 105 FakeDataChannel* outgoing_channel) 106 : incoming_(incoming_channel), 107 outgoing_(outgoing_channel) { 108 } 109 110 virtual ~FakeSocket() { 111 112 } 113 114 virtual int Read(IOBuffer* buf, int buf_len, 115 CompletionCallback* callback) { 116 return incoming_->Read(buf, buf_len, callback); 117 } 118 119 virtual int Write(IOBuffer* buf, int buf_len, 120 CompletionCallback* callback) { 121 return outgoing_->Write(buf, buf_len, callback); 122 } 123 124 virtual bool SetReceiveBufferSize(int32 size) { 125 return true; 126 } 127 128 virtual bool SetSendBufferSize(int32 size) { 129 return true; 130 } 131 132 virtual int Connect(CompletionCallback* callback) { 133 return net::OK; 134 } 135 136 virtual void Disconnect() {} 137 138 virtual bool IsConnected() const { 139 return true; 140 } 141 142 virtual bool IsConnectedAndIdle() const { 143 return true; 144 } 145 146 virtual int GetPeerAddress(AddressList* address) const { 147 net::IPAddressNumber ip_address(4); 148 *address = net::AddressList(ip_address, 0, false); 149 return net::OK; 150 } 151 152 virtual int GetLocalAddress(IPEndPoint* address) const { 153 net::IPAddressNumber ip_address(4); 154 *address = net::IPEndPoint(ip_address, 0); 155 return net::OK; 156 } 157 158 virtual const BoundNetLog& NetLog() const { 159 return net_log_; 160 } 161 162 virtual void SetSubresourceSpeculation() {} 163 virtual void SetOmniboxSpeculation() {} 164 165 virtual bool WasEverUsed() const { 166 return true; 167 } 168 169 virtual bool UsingTCPFastOpen() const { 170 return false; 171 } 172 173 private: 174 net::BoundNetLog net_log_; 175 FakeDataChannel* incoming_; 176 FakeDataChannel* outgoing_; 177 178 DISALLOW_COPY_AND_ASSIGN(FakeSocket); 179 }; 180 181 } // namespace 182 183 // Verify the correctness of the test helper classes first. 184 TEST(FakeSocketTest, DataTransfer) { 185 // Establish channels between two sockets. 186 FakeDataChannel channel_1; 187 FakeDataChannel channel_2; 188 FakeSocket client(&channel_1, &channel_2); 189 FakeSocket server(&channel_2, &channel_1); 190 191 const char kTestData[] = "testing123"; 192 const int kTestDataSize = strlen(kTestData); 193 const int kReadBufSize = 1024; 194 scoped_refptr<net::IOBuffer> write_buf = new net::StringIOBuffer(kTestData); 195 scoped_refptr<net::IOBuffer> read_buf = new net::IOBuffer(kReadBufSize); 196 197 // Write then read. 198 EXPECT_EQ(kTestDataSize, server.Write(write_buf, kTestDataSize, NULL)); 199 EXPECT_EQ(kTestDataSize, client.Read(read_buf, kReadBufSize, NULL)); 200 EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), kTestDataSize)); 201 202 // Read then write. 203 TestCompletionCallback callback; 204 EXPECT_EQ(net::ERR_IO_PENDING, 205 server.Read(read_buf, kReadBufSize, &callback)); 206 EXPECT_EQ(kTestDataSize, client.Write(write_buf, kTestDataSize, NULL)); 207 EXPECT_EQ(kTestDataSize, callback.WaitForResult()); 208 EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), kTestDataSize)); 209 } 210 211 class SSLServerSocketTest : public PlatformTest { 212 public: 213 SSLServerSocketTest() 214 : socket_factory_(net::ClientSocketFactory::GetDefaultFactory()) { 215 } 216 217 protected: 218 void Initialize() { 219 FakeSocket* fake_client_socket = new FakeSocket(&channel_1_, &channel_2_); 220 FakeSocket* fake_server_socket = new FakeSocket(&channel_2_, &channel_1_); 221 222 FilePath certs_dir; 223 PathService::Get(base::DIR_SOURCE_ROOT, &certs_dir); 224 certs_dir = certs_dir.AppendASCII("net"); 225 certs_dir = certs_dir.AppendASCII("data"); 226 certs_dir = certs_dir.AppendASCII("ssl"); 227 certs_dir = certs_dir.AppendASCII("certificates"); 228 229 FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der"); 230 std::string cert_der; 231 ASSERT_TRUE(file_util::ReadFileToString(cert_path, &cert_der)); 232 233 scoped_refptr<net::X509Certificate> cert = 234 X509Certificate::CreateFromBytes(cert_der.data(), cert_der.size()); 235 236 FilePath key_path = certs_dir.AppendASCII("unittest.key.bin"); 237 std::string key_string; 238 ASSERT_TRUE(file_util::ReadFileToString(key_path, &key_string)); 239 std::vector<uint8> key_vector( 240 reinterpret_cast<const uint8*>(key_string.data()), 241 reinterpret_cast<const uint8*>(key_string.data() + 242 key_string.length())); 243 244 scoped_ptr<crypto::RSAPrivateKey> private_key( 245 crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector)); 246 247 net::SSLConfig ssl_config; 248 ssl_config.false_start_enabled = false; 249 ssl_config.ssl3_enabled = true; 250 ssl_config.tls1_enabled = true; 251 252 // Certificate provided by the host doesn't need authority. 253 net::SSLConfig::CertAndStatus cert_and_status; 254 cert_and_status.cert_status = net::CERT_STATUS_AUTHORITY_INVALID; 255 cert_and_status.cert = cert; 256 ssl_config.allowed_bad_certs.push_back(cert_and_status); 257 258 net::HostPortPair host_and_pair("unittest", 0); 259 client_socket_.reset( 260 socket_factory_->CreateSSLClientSocket( 261 fake_client_socket, host_and_pair, ssl_config, NULL, 262 &cert_verifier_)); 263 server_socket_.reset(net::CreateSSLServerSocket(fake_server_socket, 264 cert, private_key.get(), 265 net::SSLConfig())); 266 } 267 268 FakeDataChannel channel_1_; 269 FakeDataChannel channel_2_; 270 scoped_ptr<net::SSLClientSocket> client_socket_; 271 scoped_ptr<net::SSLServerSocket> server_socket_; 272 net::ClientSocketFactory* socket_factory_; 273 net::CertVerifier cert_verifier_; 274 }; 275 276 // SSLServerSocket is only implemented using NSS. 277 #if defined(USE_NSS) || defined(OS_WIN) || defined(OS_MACOSX) 278 279 // This test only executes creation of client and server sockets. This is to 280 // test that creation of sockets doesn't crash and have minimal code to run 281 // under valgrind in order to help debugging memory problems. 282 TEST_F(SSLServerSocketTest, Initialize) { 283 Initialize(); 284 } 285 286 // This test executes Connect() of SSLClientSocket and Accept() of 287 // SSLServerSocket to make sure handshaking between the two sockets are 288 // completed successfully. 289 TEST_F(SSLServerSocketTest, Handshake) { 290 Initialize(); 291 292 TestCompletionCallback connect_callback; 293 TestCompletionCallback accept_callback; 294 295 int server_ret = server_socket_->Accept(&accept_callback); 296 EXPECT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING); 297 298 int client_ret = client_socket_->Connect(&connect_callback); 299 EXPECT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING); 300 301 if (client_ret == net::ERR_IO_PENDING) { 302 EXPECT_EQ(net::OK, connect_callback.WaitForResult()); 303 } 304 if (server_ret == net::ERR_IO_PENDING) { 305 EXPECT_EQ(net::OK, accept_callback.WaitForResult()); 306 } 307 } 308 309 TEST_F(SSLServerSocketTest, DataTransfer) { 310 Initialize(); 311 312 TestCompletionCallback connect_callback; 313 TestCompletionCallback accept_callback; 314 315 // Establish connection. 316 int client_ret = client_socket_->Connect(&connect_callback); 317 ASSERT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING); 318 319 int server_ret = server_socket_->Accept(&accept_callback); 320 ASSERT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING); 321 322 if (client_ret == net::ERR_IO_PENDING) { 323 ASSERT_EQ(net::OK, connect_callback.WaitForResult()); 324 } 325 if (server_ret == net::ERR_IO_PENDING) { 326 ASSERT_EQ(net::OK, accept_callback.WaitForResult()); 327 } 328 329 const int kReadBufSize = 1024; 330 scoped_refptr<net::StringIOBuffer> write_buf = 331 new net::StringIOBuffer("testing123"); 332 scoped_refptr<net::IOBuffer> read_buf = new net::IOBuffer(kReadBufSize); 333 334 // Write then read. 335 TestCompletionCallback write_callback; 336 TestCompletionCallback read_callback; 337 server_ret = server_socket_->Write(write_buf, write_buf->size(), 338 &write_callback); 339 EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING); 340 client_ret = client_socket_->Read(read_buf, kReadBufSize, &read_callback); 341 EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING); 342 343 if (server_ret == net::ERR_IO_PENDING) { 344 EXPECT_GT(write_callback.WaitForResult(), 0); 345 } 346 if (client_ret == net::ERR_IO_PENDING) { 347 EXPECT_GT(read_callback.WaitForResult(), 0); 348 } 349 EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size())); 350 351 // Read then write. 352 write_buf = new net::StringIOBuffer("hello123"); 353 server_ret = server_socket_->Read(read_buf, kReadBufSize, &read_callback); 354 EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING); 355 client_ret = client_socket_->Write(write_buf, write_buf->size(), 356 &write_callback); 357 EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING); 358 359 if (server_ret == net::ERR_IO_PENDING) { 360 EXPECT_GT(read_callback.WaitForResult(), 0); 361 } 362 if (client_ret == net::ERR_IO_PENDING) { 363 EXPECT_GT(write_callback.WaitForResult(), 0); 364 } 365 EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size())); 366 } 367 #endif 368 369 } // namespace net 370