Home | History | Annotate | Download | only in glue
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "jingle/glue/fake_ssl_client_socket.h"
      6 
      7 #include <algorithm>
      8 #include <vector>
      9 
     10 #include "base/basictypes.h"
     11 #include "base/memory/ref_counted.h"
     12 #include "base/memory/scoped_ptr.h"
     13 #include "base/message_loop/message_loop.h"
     14 #include "net/base/io_buffer.h"
     15 #include "net/base/net_log.h"
     16 #include "net/base/test_completion_callback.h"
     17 #include "net/socket/socket_test_util.h"
     18 #include "net/socket/stream_socket.h"
     19 #include "testing/gmock/include/gmock/gmock.h"
     20 #include "testing/gtest/include/gtest/gtest.h"
     21 
     22 namespace jingle_glue {
     23 
     24 namespace {
     25 
     26 using ::testing::Return;
     27 using ::testing::ReturnRef;
     28 
     29 // Used by RunUnsuccessfulHandshakeTestHelper.  Represents where in
     30 // the handshake step an error should be inserted.
     31 enum HandshakeErrorLocation {
     32   CONNECT_ERROR,
     33   SEND_CLIENT_HELLO_ERROR,
     34   VERIFY_SERVER_HELLO_ERROR,
     35 };
     36 
     37 // Private error codes appended to the net::Error set.
     38 enum {
     39   // An error representing a server hello that has been corrupted in
     40   // transit.
     41   ERR_MALFORMED_SERVER_HELLO = -15000,
     42 };
     43 
     44 // Used by PassThroughMethods test.
     45 class MockClientSocket : public net::StreamSocket {
     46  public:
     47   virtual ~MockClientSocket() {}
     48 
     49   MOCK_METHOD3(Read, int(net::IOBuffer*, int,
     50                          const net::CompletionCallback&));
     51   MOCK_METHOD3(Write, int(net::IOBuffer*, int,
     52                           const net::CompletionCallback&));
     53   MOCK_METHOD1(SetReceiveBufferSize, bool(int32));
     54   MOCK_METHOD1(SetSendBufferSize, bool(int32));
     55   MOCK_METHOD1(Connect, int(const net::CompletionCallback&));
     56   MOCK_METHOD0(Disconnect, void());
     57   MOCK_CONST_METHOD0(IsConnected, bool());
     58   MOCK_CONST_METHOD0(IsConnectedAndIdle, bool());
     59   MOCK_CONST_METHOD1(GetPeerAddress, int(net::IPEndPoint*));
     60   MOCK_CONST_METHOD1(GetLocalAddress, int(net::IPEndPoint*));
     61   MOCK_CONST_METHOD0(NetLog, const net::BoundNetLog&());
     62   MOCK_METHOD0(SetSubresourceSpeculation, void());
     63   MOCK_METHOD0(SetOmniboxSpeculation, void());
     64   MOCK_CONST_METHOD0(WasEverUsed, bool());
     65   MOCK_CONST_METHOD0(UsingTCPFastOpen, bool());
     66   MOCK_CONST_METHOD0(NumBytesRead, int64());
     67   MOCK_CONST_METHOD0(GetConnectTimeMicros, base::TimeDelta());
     68   MOCK_CONST_METHOD0(WasNpnNegotiated, bool());
     69   MOCK_CONST_METHOD0(GetNegotiatedProtocol, net::NextProto());
     70   MOCK_METHOD1(GetSSLInfo, bool(net::SSLInfo*));
     71 };
     72 
     73 // Break up |data| into a bunch of chunked MockReads/Writes and push
     74 // them onto |ops|.
     75 template <net::MockReadWriteType type>
     76 void AddChunkedOps(base::StringPiece data, size_t chunk_size, net::IoMode mode,
     77                    std::vector<net::MockReadWrite<type> >* ops) {
     78   DCHECK_GT(chunk_size, 0U);
     79   size_t offset = 0;
     80   while (offset < data.size()) {
     81     size_t bounded_chunk_size = std::min(data.size() - offset, chunk_size);
     82     ops->push_back(net::MockReadWrite<type>(mode, data.data() + offset,
     83                                             bounded_chunk_size));
     84     offset += bounded_chunk_size;
     85   }
     86 }
     87 
     88 class FakeSSLClientSocketTest : public testing::Test {
     89  protected:
     90   FakeSSLClientSocketTest() {}
     91 
     92   virtual ~FakeSSLClientSocketTest() {}
     93 
     94   scoped_ptr<net::StreamSocket> MakeClientSocket() {
     95     return mock_client_socket_factory_.CreateTransportClientSocket(
     96         net::AddressList(), NULL, net::NetLog::Source());
     97   }
     98 
     99   void SetData(const net::MockConnect& mock_connect,
    100                std::vector<net::MockRead>* reads,
    101                std::vector<net::MockWrite>* writes) {
    102     static_socket_data_provider_.reset(
    103         new net::StaticSocketDataProvider(
    104             reads->empty() ? NULL : &*reads->begin(), reads->size(),
    105             writes->empty() ? NULL : &*writes->begin(), writes->size()));
    106     static_socket_data_provider_->set_connect_data(mock_connect);
    107     mock_client_socket_factory_.AddSocketDataProvider(
    108         static_socket_data_provider_.get());
    109   }
    110 
    111   void ExpectStatus(
    112       net::IoMode mode, int expected_status, int immediate_status,
    113       net::TestCompletionCallback* test_completion_callback) {
    114     if (mode == net::ASYNC) {
    115       EXPECT_EQ(net::ERR_IO_PENDING, immediate_status);
    116       int status = test_completion_callback->WaitForResult();
    117       EXPECT_EQ(expected_status, status);
    118     } else {
    119       EXPECT_EQ(expected_status, immediate_status);
    120     }
    121   }
    122 
    123   // Sets up the mock socket to generate a successful handshake
    124   // (sliced up according to the parameters) and makes sure the
    125   // FakeSSLClientSocket behaves as expected.
    126   void RunSuccessfulHandshakeTest(
    127       net::IoMode mode, size_t read_chunk_size, size_t write_chunk_size,
    128       int num_resets) {
    129     base::StringPiece ssl_client_hello =
    130         FakeSSLClientSocket::GetSslClientHello();
    131     base::StringPiece ssl_server_hello =
    132         FakeSSLClientSocket::GetSslServerHello();
    133 
    134     net::MockConnect mock_connect(mode, net::OK);
    135     std::vector<net::MockRead> reads;
    136     std::vector<net::MockWrite> writes;
    137     static const char kReadTestData[] = "read test data";
    138     static const char kWriteTestData[] = "write test data";
    139     for (int i = 0; i < num_resets + 1; ++i) {
    140       SCOPED_TRACE(i);
    141       AddChunkedOps(ssl_server_hello, read_chunk_size, mode, &reads);
    142       AddChunkedOps(ssl_client_hello, write_chunk_size, mode, &writes);
    143       reads.push_back(
    144           net::MockRead(mode, kReadTestData, arraysize(kReadTestData)));
    145       writes.push_back(
    146           net::MockWrite(mode, kWriteTestData, arraysize(kWriteTestData)));
    147     }
    148     SetData(mock_connect, &reads, &writes);
    149 
    150     FakeSSLClientSocket fake_ssl_client_socket(MakeClientSocket());
    151 
    152     for (int i = 0; i < num_resets + 1; ++i) {
    153       SCOPED_TRACE(i);
    154       net::TestCompletionCallback test_completion_callback;
    155       int status = fake_ssl_client_socket.Connect(
    156           test_completion_callback.callback());
    157       if (mode == net::ASYNC) {
    158         EXPECT_FALSE(fake_ssl_client_socket.IsConnected());
    159       }
    160       ExpectStatus(mode, net::OK, status, &test_completion_callback);
    161       if (fake_ssl_client_socket.IsConnected()) {
    162         int read_len = arraysize(kReadTestData);
    163         int read_buf_len = 2 * read_len;
    164         scoped_refptr<net::IOBuffer> read_buf(
    165             new net::IOBuffer(read_buf_len));
    166         int read_status = fake_ssl_client_socket.Read(
    167             read_buf.get(), read_buf_len, test_completion_callback.callback());
    168         ExpectStatus(mode, read_len, read_status, &test_completion_callback);
    169 
    170         scoped_refptr<net::IOBuffer> write_buf(
    171             new net::StringIOBuffer(kWriteTestData));
    172         int write_status =
    173             fake_ssl_client_socket.Write(write_buf.get(),
    174                                          arraysize(kWriteTestData),
    175                                          test_completion_callback.callback());
    176         ExpectStatus(mode, arraysize(kWriteTestData), write_status,
    177                      &test_completion_callback);
    178       } else {
    179         ADD_FAILURE();
    180       }
    181       fake_ssl_client_socket.Disconnect();
    182       EXPECT_FALSE(fake_ssl_client_socket.IsConnected());
    183     }
    184   }
    185 
    186   // Sets up the mock socket to generate an unsuccessful handshake
    187   // FakeSSLClientSocket fails as expected.
    188   void RunUnsuccessfulHandshakeTestHelper(
    189       net::IoMode mode, int error, HandshakeErrorLocation location) {
    190     DCHECK_NE(error, net::OK);
    191     base::StringPiece ssl_client_hello =
    192         FakeSSLClientSocket::GetSslClientHello();
    193     base::StringPiece ssl_server_hello =
    194         FakeSSLClientSocket::GetSslServerHello();
    195 
    196     net::MockConnect mock_connect(mode, net::OK);
    197     std::vector<net::MockRead> reads;
    198     std::vector<net::MockWrite> writes;
    199     const size_t kChunkSize = 1;
    200     AddChunkedOps(ssl_server_hello, kChunkSize, mode, &reads);
    201     AddChunkedOps(ssl_client_hello, kChunkSize, mode, &writes);
    202     switch (location) {
    203       case CONNECT_ERROR:
    204         mock_connect.result = error;
    205         writes.clear();
    206         reads.clear();
    207         break;
    208       case SEND_CLIENT_HELLO_ERROR: {
    209         // Use a fixed index for repeatability.
    210         size_t index = 100 % writes.size();
    211         writes[index].result = error;
    212         writes[index].data = NULL;
    213         writes[index].data_len = 0;
    214         writes.resize(index + 1);
    215         reads.clear();
    216         break;
    217       }
    218       case VERIFY_SERVER_HELLO_ERROR: {
    219         // Use a fixed index for repeatability.
    220         size_t index = 50 % reads.size();
    221         if (error == ERR_MALFORMED_SERVER_HELLO) {
    222           static const char kBadData[] = "BAD_DATA";
    223           reads[index].data = kBadData;
    224           reads[index].data_len = arraysize(kBadData);
    225         } else {
    226           reads[index].result = error;
    227           reads[index].data = NULL;
    228           reads[index].data_len = 0;
    229         }
    230         reads.resize(index + 1);
    231         if (error ==
    232             net::ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) {
    233           static const char kDummyData[] = "DUMMY";
    234           reads.push_back(net::MockRead(mode, kDummyData));
    235         }
    236         break;
    237       }
    238     }
    239     SetData(mock_connect, &reads, &writes);
    240 
    241     FakeSSLClientSocket fake_ssl_client_socket(MakeClientSocket());
    242 
    243     // The two errors below are interpreted by FakeSSLClientSocket as
    244     // an unexpected event.
    245     int expected_status =
    246         ((error == net::ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) ||
    247          (error == ERR_MALFORMED_SERVER_HELLO)) ?
    248         net::ERR_UNEXPECTED : error;
    249 
    250     net::TestCompletionCallback test_completion_callback;
    251     int status = fake_ssl_client_socket.Connect(
    252         test_completion_callback.callback());
    253     EXPECT_FALSE(fake_ssl_client_socket.IsConnected());
    254     ExpectStatus(mode, expected_status, status, &test_completion_callback);
    255     EXPECT_FALSE(fake_ssl_client_socket.IsConnected());
    256   }
    257 
    258   void RunUnsuccessfulHandshakeTest(
    259       int error, HandshakeErrorLocation location) {
    260     RunUnsuccessfulHandshakeTestHelper(net::SYNCHRONOUS, error, location);
    261     RunUnsuccessfulHandshakeTestHelper(net::ASYNC, error, location);
    262   }
    263 
    264   // MockTCPClientSocket needs a message loop.
    265   base::MessageLoop message_loop_;
    266 
    267   net::MockClientSocketFactory mock_client_socket_factory_;
    268   scoped_ptr<net::StaticSocketDataProvider> static_socket_data_provider_;
    269 };
    270 
    271 TEST_F(FakeSSLClientSocketTest, PassThroughMethods) {
    272   scoped_ptr<MockClientSocket> mock_client_socket(new MockClientSocket());
    273   const int kReceiveBufferSize = 10;
    274   const int kSendBufferSize = 20;
    275   net::IPEndPoint ip_endpoint(net::IPAddressNumber(net::kIPv4AddressSize), 80);
    276   const int kPeerAddress = 30;
    277   net::BoundNetLog net_log;
    278   EXPECT_CALL(*mock_client_socket, SetReceiveBufferSize(kReceiveBufferSize));
    279   EXPECT_CALL(*mock_client_socket, SetSendBufferSize(kSendBufferSize));
    280   EXPECT_CALL(*mock_client_socket, GetPeerAddress(&ip_endpoint)).
    281       WillOnce(Return(kPeerAddress));
    282   EXPECT_CALL(*mock_client_socket, NetLog()).WillOnce(ReturnRef(net_log));
    283   EXPECT_CALL(*mock_client_socket, SetSubresourceSpeculation());
    284   EXPECT_CALL(*mock_client_socket, SetOmniboxSpeculation());
    285 
    286   // Takes ownership of |mock_client_socket|.
    287   FakeSSLClientSocket fake_ssl_client_socket(
    288       mock_client_socket.PassAs<net::StreamSocket>());
    289   fake_ssl_client_socket.SetReceiveBufferSize(kReceiveBufferSize);
    290   fake_ssl_client_socket.SetSendBufferSize(kSendBufferSize);
    291   EXPECT_EQ(kPeerAddress,
    292             fake_ssl_client_socket.GetPeerAddress(&ip_endpoint));
    293   EXPECT_EQ(&net_log, &fake_ssl_client_socket.NetLog());
    294   fake_ssl_client_socket.SetSubresourceSpeculation();
    295   fake_ssl_client_socket.SetOmniboxSpeculation();
    296 }
    297 
    298 TEST_F(FakeSSLClientSocketTest, SuccessfulHandshakeSync) {
    299   for (size_t i = 1; i < 100; i += 3) {
    300     SCOPED_TRACE(i);
    301     for (size_t j = 1; j < 100; j += 5) {
    302       SCOPED_TRACE(j);
    303       RunSuccessfulHandshakeTest(net::SYNCHRONOUS, i, j, 0);
    304     }
    305   }
    306 }
    307 
    308 TEST_F(FakeSSLClientSocketTest, SuccessfulHandshakeAsync) {
    309   for (size_t i = 1; i < 100; i += 7) {
    310     SCOPED_TRACE(i);
    311     for (size_t j = 1; j < 100; j += 9) {
    312       SCOPED_TRACE(j);
    313       RunSuccessfulHandshakeTest(net::ASYNC, i, j, 0);
    314     }
    315   }
    316 }
    317 
    318 TEST_F(FakeSSLClientSocketTest, ResetSocket) {
    319   RunSuccessfulHandshakeTest(net::ASYNC, 1, 2, 3);
    320 }
    321 
    322 TEST_F(FakeSSLClientSocketTest, UnsuccessfulHandshakeConnectError) {
    323   RunUnsuccessfulHandshakeTest(net::ERR_ACCESS_DENIED, CONNECT_ERROR);
    324 }
    325 
    326 TEST_F(FakeSSLClientSocketTest, UnsuccessfulHandshakeWriteError) {
    327   RunUnsuccessfulHandshakeTest(net::ERR_OUT_OF_MEMORY,
    328                                SEND_CLIENT_HELLO_ERROR);
    329 }
    330 
    331 TEST_F(FakeSSLClientSocketTest, UnsuccessfulHandshakeReadError) {
    332   RunUnsuccessfulHandshakeTest(net::ERR_CONNECTION_CLOSED,
    333                                VERIFY_SERVER_HELLO_ERROR);
    334 }
    335 
    336 TEST_F(FakeSSLClientSocketTest, PeerClosedDuringHandshake) {
    337   RunUnsuccessfulHandshakeTest(
    338       net::ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ,
    339       VERIFY_SERVER_HELLO_ERROR);
    340 }
    341 
    342 TEST_F(FakeSSLClientSocketTest, MalformedServerHello) {
    343   RunUnsuccessfulHandshakeTest(ERR_MALFORMED_SERVER_HELLO,
    344                                VERIFY_SERVER_HELLO_ERROR);
    345 }
    346 
    347 }  // namespace
    348 
    349 }  // namespace jingle_glue
    350