Home | History | Annotate | Download | only in socket
      1 // Copyright (c) 2011 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 // This test suite uses SSLClientSocket to test the implementation of
      6 // SSLServerSocket. In order to establish connections between the sockets
      7 // we need two additional classes:
      8 // 1. FakeSocket
      9 //    Connects SSL socket to FakeDataChannel. This class is just a stub.
     10 //
     11 // 2. FakeDataChannel
     12 //    Implements the actual exchange of data between two FakeSockets.
     13 //
     14 // Implementations of these two classes are included in this file.
     15 
     16 #include "net/socket/ssl_server_socket.h"
     17 
     18 #include <queue>
     19 
     20 #include "base/file_path.h"
     21 #include "base/file_util.h"
     22 #include "base/path_service.h"
     23 #include "crypto/nss_util.h"
     24 #include "crypto/rsa_private_key.h"
     25 #include "net/base/address_list.h"
     26 #include "net/base/cert_status_flags.h"
     27 #include "net/base/cert_verifier.h"
     28 #include "net/base/host_port_pair.h"
     29 #include "net/base/io_buffer.h"
     30 #include "net/base/ip_endpoint.h"
     31 #include "net/base/net_errors.h"
     32 #include "net/base/net_log.h"
     33 #include "net/base/ssl_config_service.h"
     34 #include "net/base/x509_certificate.h"
     35 #include "net/socket/client_socket.h"
     36 #include "net/socket/client_socket_factory.h"
     37 #include "net/socket/socket_test_util.h"
     38 #include "net/socket/ssl_client_socket.h"
     39 #include "testing/gtest/include/gtest/gtest.h"
     40 #include "testing/platform_test.h"
     41 
     42 namespace net {
     43 
     44 namespace {
     45 
     46 class FakeDataChannel {
     47  public:
     48   FakeDataChannel() : read_callback_(NULL), read_buf_len_(0) {
     49   }
     50 
     51   virtual int Read(IOBuffer* buf, int buf_len,
     52                    CompletionCallback* callback) {
     53     if (data_.empty()) {
     54       read_callback_ = callback;
     55       read_buf_ = buf;
     56       read_buf_len_ = buf_len;
     57       return net::ERR_IO_PENDING;
     58     }
     59     return PropogateData(buf, buf_len);
     60   }
     61 
     62   virtual int Write(IOBuffer* buf, int buf_len,
     63                     CompletionCallback* callback) {
     64     data_.push(new net::DrainableIOBuffer(buf, buf_len));
     65     DoReadCallback();
     66     return buf_len;
     67   }
     68 
     69  private:
     70   void DoReadCallback() {
     71     if (!read_callback_)
     72       return;
     73 
     74     int copied = PropogateData(read_buf_, read_buf_len_);
     75     net::CompletionCallback* callback = read_callback_;
     76     read_callback_ = NULL;
     77     read_buf_ = NULL;
     78     read_buf_len_ = 0;
     79     callback->Run(copied);
     80   }
     81 
     82   int PropogateData(scoped_refptr<net::IOBuffer> read_buf, int read_buf_len) {
     83     scoped_refptr<net::DrainableIOBuffer> buf = data_.front();
     84     int copied = std::min(buf->BytesRemaining(), read_buf_len);
     85     memcpy(read_buf->data(), buf->data(), copied);
     86     buf->DidConsume(copied);
     87 
     88     if (!buf->BytesRemaining())
     89       data_.pop();
     90     return copied;
     91   }
     92 
     93   net::CompletionCallback* read_callback_;
     94   scoped_refptr<net::IOBuffer> read_buf_;
     95   int read_buf_len_;
     96 
     97   std::queue<scoped_refptr<net::DrainableIOBuffer> > data_;
     98 
     99   DISALLOW_COPY_AND_ASSIGN(FakeDataChannel);
    100 };
    101 
    102 class FakeSocket : public ClientSocket {
    103  public:
    104   FakeSocket(FakeDataChannel* incoming_channel,
    105              FakeDataChannel* outgoing_channel)
    106       : incoming_(incoming_channel),
    107         outgoing_(outgoing_channel) {
    108   }
    109 
    110   virtual ~FakeSocket() {
    111 
    112   }
    113 
    114   virtual int Read(IOBuffer* buf, int buf_len,
    115                    CompletionCallback* callback) {
    116     return incoming_->Read(buf, buf_len, callback);
    117   }
    118 
    119   virtual int Write(IOBuffer* buf, int buf_len,
    120                     CompletionCallback* callback) {
    121     return outgoing_->Write(buf, buf_len, callback);
    122   }
    123 
    124   virtual bool SetReceiveBufferSize(int32 size) {
    125     return true;
    126   }
    127 
    128   virtual bool SetSendBufferSize(int32 size) {
    129     return true;
    130   }
    131 
    132   virtual int Connect(CompletionCallback* callback) {
    133     return net::OK;
    134   }
    135 
    136   virtual void Disconnect() {}
    137 
    138   virtual bool IsConnected() const {
    139     return true;
    140   }
    141 
    142   virtual bool IsConnectedAndIdle() const {
    143     return true;
    144   }
    145 
    146   virtual int GetPeerAddress(AddressList* address) const {
    147     net::IPAddressNumber ip_address(4);
    148     *address = net::AddressList(ip_address, 0, false);
    149     return net::OK;
    150   }
    151 
    152   virtual int GetLocalAddress(IPEndPoint* address) const {
    153     net::IPAddressNumber ip_address(4);
    154     *address = net::IPEndPoint(ip_address, 0);
    155     return net::OK;
    156   }
    157 
    158   virtual const BoundNetLog& NetLog() const {
    159     return net_log_;
    160   }
    161 
    162   virtual void SetSubresourceSpeculation() {}
    163   virtual void SetOmniboxSpeculation() {}
    164 
    165   virtual bool WasEverUsed() const {
    166     return true;
    167   }
    168 
    169   virtual bool UsingTCPFastOpen() const {
    170     return false;
    171   }
    172 
    173  private:
    174   net::BoundNetLog net_log_;
    175   FakeDataChannel* incoming_;
    176   FakeDataChannel* outgoing_;
    177 
    178   DISALLOW_COPY_AND_ASSIGN(FakeSocket);
    179 };
    180 
    181 }  // namespace
    182 
    183 // Verify the correctness of the test helper classes first.
    184 TEST(FakeSocketTest, DataTransfer) {
    185   // Establish channels between two sockets.
    186   FakeDataChannel channel_1;
    187   FakeDataChannel channel_2;
    188   FakeSocket client(&channel_1, &channel_2);
    189   FakeSocket server(&channel_2, &channel_1);
    190 
    191   const char kTestData[] = "testing123";
    192   const int kTestDataSize = strlen(kTestData);
    193   const int kReadBufSize = 1024;
    194   scoped_refptr<net::IOBuffer> write_buf = new net::StringIOBuffer(kTestData);
    195   scoped_refptr<net::IOBuffer> read_buf = new net::IOBuffer(kReadBufSize);
    196 
    197   // Write then read.
    198   EXPECT_EQ(kTestDataSize, server.Write(write_buf, kTestDataSize, NULL));
    199   EXPECT_EQ(kTestDataSize, client.Read(read_buf, kReadBufSize, NULL));
    200   EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), kTestDataSize));
    201 
    202   // Read then write.
    203   TestCompletionCallback callback;
    204   EXPECT_EQ(net::ERR_IO_PENDING,
    205             server.Read(read_buf, kReadBufSize, &callback));
    206   EXPECT_EQ(kTestDataSize, client.Write(write_buf, kTestDataSize, NULL));
    207   EXPECT_EQ(kTestDataSize, callback.WaitForResult());
    208   EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), kTestDataSize));
    209 }
    210 
    211 class SSLServerSocketTest : public PlatformTest {
    212  public:
    213   SSLServerSocketTest()
    214       : socket_factory_(net::ClientSocketFactory::GetDefaultFactory()) {
    215   }
    216 
    217  protected:
    218   void Initialize() {
    219     FakeSocket* fake_client_socket = new FakeSocket(&channel_1_, &channel_2_);
    220     FakeSocket* fake_server_socket = new FakeSocket(&channel_2_, &channel_1_);
    221 
    222     FilePath certs_dir;
    223     PathService::Get(base::DIR_SOURCE_ROOT, &certs_dir);
    224     certs_dir = certs_dir.AppendASCII("net");
    225     certs_dir = certs_dir.AppendASCII("data");
    226     certs_dir = certs_dir.AppendASCII("ssl");
    227     certs_dir = certs_dir.AppendASCII("certificates");
    228 
    229     FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der");
    230     std::string cert_der;
    231     ASSERT_TRUE(file_util::ReadFileToString(cert_path, &cert_der));
    232 
    233     scoped_refptr<net::X509Certificate> cert =
    234         X509Certificate::CreateFromBytes(cert_der.data(), cert_der.size());
    235 
    236     FilePath key_path = certs_dir.AppendASCII("unittest.key.bin");
    237     std::string key_string;
    238     ASSERT_TRUE(file_util::ReadFileToString(key_path, &key_string));
    239     std::vector<uint8> key_vector(
    240         reinterpret_cast<const uint8*>(key_string.data()),
    241         reinterpret_cast<const uint8*>(key_string.data() +
    242                                        key_string.length()));
    243 
    244     scoped_ptr<crypto::RSAPrivateKey> private_key(
    245         crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector));
    246 
    247     net::SSLConfig ssl_config;
    248     ssl_config.false_start_enabled = false;
    249     ssl_config.ssl3_enabled = true;
    250     ssl_config.tls1_enabled = true;
    251 
    252     // Certificate provided by the host doesn't need authority.
    253     net::SSLConfig::CertAndStatus cert_and_status;
    254     cert_and_status.cert_status = net::CERT_STATUS_AUTHORITY_INVALID;
    255     cert_and_status.cert = cert;
    256     ssl_config.allowed_bad_certs.push_back(cert_and_status);
    257 
    258     net::HostPortPair host_and_pair("unittest", 0);
    259     client_socket_.reset(
    260         socket_factory_->CreateSSLClientSocket(
    261             fake_client_socket, host_and_pair, ssl_config, NULL,
    262             &cert_verifier_));
    263     server_socket_.reset(net::CreateSSLServerSocket(fake_server_socket,
    264                                                     cert, private_key.get(),
    265                                                     net::SSLConfig()));
    266   }
    267 
    268   FakeDataChannel channel_1_;
    269   FakeDataChannel channel_2_;
    270   scoped_ptr<net::SSLClientSocket> client_socket_;
    271   scoped_ptr<net::SSLServerSocket> server_socket_;
    272   net::ClientSocketFactory* socket_factory_;
    273   net::CertVerifier cert_verifier_;
    274 };
    275 
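// SSLServerSocket is only implemented using NSS, so the tests below are only
// built on the platforms listed in the following #if (NSS ports, Windows and
// Mac OS X).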
    277 #if defined(USE_NSS) || defined(OS_WIN) || defined(OS_MACOSX)
    278 
// This test only creates the client and server sockets. It checks that socket
// creation doesn't crash, and provides a minimal piece of code to run under
// valgrind to help debug memory problems.
    282 TEST_F(SSLServerSocketTest, Initialize) {
    283   Initialize();
    284 }
    285 
// This test calls Connect() on the SSLClientSocket and Accept() on the
// SSLServerSocket to make sure the handshake between the two sockets
// completes successfully.
    289 TEST_F(SSLServerSocketTest, Handshake) {
    290   Initialize();
    291 
    292   TestCompletionCallback connect_callback;
    293   TestCompletionCallback accept_callback;
    294 
    295   int server_ret = server_socket_->Accept(&accept_callback);
    296   EXPECT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING);
    297 
    298   int client_ret = client_socket_->Connect(&connect_callback);
    299   EXPECT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING);
    300 
    301   if (client_ret == net::ERR_IO_PENDING) {
    302     EXPECT_EQ(net::OK, connect_callback.WaitForResult());
    303   }
    304   if (server_ret == net::ERR_IO_PENDING) {
    305     EXPECT_EQ(net::OK, accept_callback.WaitForResult());
    306   }
    307 }
    308 
    309 TEST_F(SSLServerSocketTest, DataTransfer) {
    310   Initialize();
    311 
    312   TestCompletionCallback connect_callback;
    313   TestCompletionCallback accept_callback;
    314 
    315   // Establish connection.
    316   int client_ret = client_socket_->Connect(&connect_callback);
    317   ASSERT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING);
    318 
    319   int server_ret = server_socket_->Accept(&accept_callback);
    320   ASSERT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING);
    321 
    322   if (client_ret == net::ERR_IO_PENDING) {
    323     ASSERT_EQ(net::OK, connect_callback.WaitForResult());
    324   }
    325   if (server_ret == net::ERR_IO_PENDING) {
    326     ASSERT_EQ(net::OK, accept_callback.WaitForResult());
    327   }
    328 
    329   const int kReadBufSize = 1024;
    330   scoped_refptr<net::StringIOBuffer> write_buf =
    331       new net::StringIOBuffer("testing123");
    332   scoped_refptr<net::IOBuffer> read_buf = new net::IOBuffer(kReadBufSize);
    333 
    334   // Write then read.
    335   TestCompletionCallback write_callback;
    336   TestCompletionCallback read_callback;
    337   server_ret = server_socket_->Write(write_buf, write_buf->size(),
    338                                      &write_callback);
    339   EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING);
    340   client_ret = client_socket_->Read(read_buf, kReadBufSize, &read_callback);
    341   EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING);
    342 
    343   if (server_ret == net::ERR_IO_PENDING) {
    344     EXPECT_GT(write_callback.WaitForResult(), 0);
    345   }
    346   if (client_ret == net::ERR_IO_PENDING) {
    347     EXPECT_GT(read_callback.WaitForResult(), 0);
    348   }
    349   EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size()));
    350 
    351   // Read then write.
    352   write_buf = new net::StringIOBuffer("hello123");
    353   server_ret = server_socket_->Read(read_buf, kReadBufSize, &read_callback);
    354   EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING);
    355   client_ret = client_socket_->Write(write_buf, write_buf->size(),
    356                                      &write_callback);
    357   EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING);
    358 
    359   if (server_ret == net::ERR_IO_PENDING) {
    360     EXPECT_GT(read_callback.WaitForResult(), 0);
    361   }
    362   if (client_ret == net::ERR_IO_PENDING) {
    363     EXPECT_GT(write_callback.WaitForResult(), 0);
    364   }
    365   EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size()));
    366 }
    367 #endif
    368 
    369 }  // namespace net
    370