Home | History | Annotate | Download | only in socket
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/socket/socks_client_socket.h"
      6 
      7 #include "base/memory/scoped_ptr.h"
      8 #include "net/base/address_list.h"
      9 #include "net/base/net_log.h"
     10 #include "net/base/net_log_unittest.h"
     11 #include "net/base/test_completion_callback.h"
     12 #include "net/base/winsock_init.h"
     13 #include "net/dns/mock_host_resolver.h"
     14 #include "net/socket/client_socket_factory.h"
     15 #include "net/socket/socket_test_util.h"
     16 #include "net/socket/tcp_client_socket.h"
     17 #include "testing/gtest/include/gtest/gtest.h"
     18 #include "testing/platform_test.h"
     19 
     20 //-----------------------------------------------------------------------------
     21 
     22 namespace net {
     23 
     24 const char kSOCKSOkRequest[] = { 0x04, 0x01, 0x00, 0x50, 127, 0, 0, 1, 0 };
     25 const char kSOCKSOkReply[] = { 0x00, 0x5A, 0x00, 0x00, 0, 0, 0, 0 };
     26 
     27 class SOCKSClientSocketTest : public PlatformTest {
     28  public:
     29   SOCKSClientSocketTest();
     30   // Create a SOCKSClientSocket on top of a MockSocket.
     31   scoped_ptr<SOCKSClientSocket> BuildMockSocket(
     32       MockRead reads[], size_t reads_count,
     33       MockWrite writes[], size_t writes_count,
     34       HostResolver* host_resolver,
     35       const std::string& hostname, int port,
     36       NetLog* net_log);
     37   virtual void SetUp();
     38 
     39  protected:
     40   scoped_ptr<SOCKSClientSocket> user_sock_;
     41   AddressList address_list_;
     42   // Filled in by BuildMockSocket() and owned by its return value
     43   // (which |user_sock| is set to).
     44   StreamSocket* tcp_sock_;
     45   TestCompletionCallback callback_;
     46   scoped_ptr<MockHostResolver> host_resolver_;
     47   scoped_ptr<SocketDataProvider> data_;
     48 };
     49 
     50 SOCKSClientSocketTest::SOCKSClientSocketTest()
     51   : host_resolver_(new MockHostResolver) {
     52 }
     53 
     54 // Set up platform before every test case
     55 void SOCKSClientSocketTest::SetUp() {
     56   PlatformTest::SetUp();
     57 }
     58 
     59 scoped_ptr<SOCKSClientSocket> SOCKSClientSocketTest::BuildMockSocket(
     60     MockRead reads[],
     61     size_t reads_count,
     62     MockWrite writes[],
     63     size_t writes_count,
     64     HostResolver* host_resolver,
     65     const std::string& hostname,
     66     int port,
     67     NetLog* net_log) {
     68 
     69   TestCompletionCallback callback;
     70   data_.reset(new StaticSocketDataProvider(reads, reads_count,
     71                                            writes, writes_count));
     72   tcp_sock_ = new MockTCPClientSocket(address_list_, net_log, data_.get());
     73 
     74   int rv = tcp_sock_->Connect(callback.callback());
     75   EXPECT_EQ(ERR_IO_PENDING, rv);
     76   rv = callback.WaitForResult();
     77   EXPECT_EQ(OK, rv);
     78   EXPECT_TRUE(tcp_sock_->IsConnected());
     79 
     80   scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle);
     81   // |connection| takes ownership of |tcp_sock_|, but keep a
     82   // non-owning pointer to it.
     83   connection->SetSocket(scoped_ptr<StreamSocket>(tcp_sock_));
     84   return scoped_ptr<SOCKSClientSocket>(new SOCKSClientSocket(
     85       connection.Pass(),
     86       HostResolver::RequestInfo(HostPortPair(hostname, port)),
     87       DEFAULT_PRIORITY,
     88       host_resolver));
     89 }
     90 
     91 // Implementation of HostResolver that never completes its resolve request.
     92 // We use this in the test "DisconnectWhileHostResolveInProgress" to make
     93 // sure that the outstanding resolve request gets cancelled.
     94 class HangingHostResolverWithCancel : public HostResolver {
     95  public:
     96   HangingHostResolverWithCancel() : outstanding_request_(NULL) {}
     97 
     98   virtual int Resolve(const RequestInfo& info,
     99                       RequestPriority priority,
    100                       AddressList* addresses,
    101                       const CompletionCallback& callback,
    102                       RequestHandle* out_req,
    103                       const BoundNetLog& net_log) OVERRIDE {
    104     DCHECK(addresses);
    105     DCHECK_EQ(false, callback.is_null());
    106     EXPECT_FALSE(HasOutstandingRequest());
    107     outstanding_request_ = reinterpret_cast<RequestHandle>(1);
    108     *out_req = outstanding_request_;
    109     return ERR_IO_PENDING;
    110   }
    111 
    112   virtual int ResolveFromCache(const RequestInfo& info,
    113                                AddressList* addresses,
    114                                const BoundNetLog& net_log) OVERRIDE {
    115     NOTIMPLEMENTED();
    116     return ERR_UNEXPECTED;
    117   }
    118 
    119   virtual void CancelRequest(RequestHandle req) OVERRIDE {
    120     EXPECT_TRUE(HasOutstandingRequest());
    121     EXPECT_EQ(outstanding_request_, req);
    122     outstanding_request_ = NULL;
    123   }
    124 
    125   bool HasOutstandingRequest() {
    126     return outstanding_request_ != NULL;
    127   }
    128 
    129  private:
    130   RequestHandle outstanding_request_;
    131 
    132   DISALLOW_COPY_AND_ASSIGN(HangingHostResolverWithCancel);
    133 };
    134 
    135 // Tests a complete handshake and the disconnection.
    136 TEST_F(SOCKSClientSocketTest, CompleteHandshake) {
    137   const std::string payload_write = "random data";
    138   const std::string payload_read = "moar random data";
    139 
    140   MockWrite data_writes[] = {
    141       MockWrite(ASYNC, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)),
    142       MockWrite(ASYNC, payload_write.data(), payload_write.size()) };
    143   MockRead data_reads[] = {
    144       MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply)),
    145       MockRead(ASYNC, payload_read.data(), payload_read.size()) };
    146   CapturingNetLog log;
    147 
    148   user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
    149                                data_writes, arraysize(data_writes),
    150                                host_resolver_.get(),
    151                                "localhost", 80,
    152                                &log);
    153 
    154   // At this state the TCP connection is completed but not the SOCKS handshake.
    155   EXPECT_TRUE(tcp_sock_->IsConnected());
    156   EXPECT_FALSE(user_sock_->IsConnected());
    157 
    158   int rv = user_sock_->Connect(callback_.callback());
    159   EXPECT_EQ(ERR_IO_PENDING, rv);
    160 
    161   CapturingNetLog::CapturedEntryList entries;
    162   log.GetEntries(&entries);
    163   EXPECT_TRUE(
    164       LogContainsBeginEvent(entries, 0, NetLog::TYPE_SOCKS_CONNECT));
    165   EXPECT_FALSE(user_sock_->IsConnected());
    166 
    167   rv = callback_.WaitForResult();
    168   EXPECT_EQ(OK, rv);
    169   EXPECT_TRUE(user_sock_->IsConnected());
    170   log.GetEntries(&entries);
    171   EXPECT_TRUE(LogContainsEndEvent(
    172       entries, -1, NetLog::TYPE_SOCKS_CONNECT));
    173 
    174   scoped_refptr<IOBuffer> buffer(new IOBuffer(payload_write.size()));
    175   memcpy(buffer->data(), payload_write.data(), payload_write.size());
    176   rv = user_sock_->Write(
    177       buffer.get(), payload_write.size(), callback_.callback());
    178   EXPECT_EQ(ERR_IO_PENDING, rv);
    179   rv = callback_.WaitForResult();
    180   EXPECT_EQ(static_cast<int>(payload_write.size()), rv);
    181 
    182   buffer = new IOBuffer(payload_read.size());
    183   rv =
    184       user_sock_->Read(buffer.get(), payload_read.size(), callback_.callback());
    185   EXPECT_EQ(ERR_IO_PENDING, rv);
    186   rv = callback_.WaitForResult();
    187   EXPECT_EQ(static_cast<int>(payload_read.size()), rv);
    188   EXPECT_EQ(payload_read, std::string(buffer->data(), payload_read.size()));
    189 
    190   user_sock_->Disconnect();
    191   EXPECT_FALSE(tcp_sock_->IsConnected());
    192   EXPECT_FALSE(user_sock_->IsConnected());
    193 }
    194 
    195 // List of responses from the socks server and the errors they should
    196 // throw up are tested here.
    197 TEST_F(SOCKSClientSocketTest, HandshakeFailures) {
    198   const struct {
    199     const char fail_reply[8];
    200     Error fail_code;
    201   } tests[] = {
    202     // Failure of the server response code
    203     {
    204       { 0x01, 0x5A, 0x00, 0x00, 0, 0, 0, 0 },
    205       ERR_SOCKS_CONNECTION_FAILED,
    206     },
    207     // Failure of the null byte
    208     {
    209       { 0x00, 0x5B, 0x00, 0x00, 0, 0, 0, 0 },
    210       ERR_SOCKS_CONNECTION_FAILED,
    211     },
    212   };
    213 
    214   //---------------------------------------
    215 
    216   for (size_t i = 0; i < ARRAYSIZE_UNSAFE(tests); ++i) {
    217     MockWrite data_writes[] = {
    218         MockWrite(SYNCHRONOUS, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) };
    219     MockRead data_reads[] = {
    220         MockRead(SYNCHRONOUS, tests[i].fail_reply,
    221                  arraysize(tests[i].fail_reply)) };
    222     CapturingNetLog log;
    223 
    224     user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
    225                                  data_writes, arraysize(data_writes),
    226                                  host_resolver_.get(),
    227                                  "localhost", 80,
    228                                  &log);
    229 
    230     int rv = user_sock_->Connect(callback_.callback());
    231     EXPECT_EQ(ERR_IO_PENDING, rv);
    232 
    233     CapturingNetLog::CapturedEntryList entries;
    234     log.GetEntries(&entries);
    235     EXPECT_TRUE(LogContainsBeginEvent(
    236         entries, 0, NetLog::TYPE_SOCKS_CONNECT));
    237 
    238     rv = callback_.WaitForResult();
    239     EXPECT_EQ(tests[i].fail_code, rv);
    240     EXPECT_FALSE(user_sock_->IsConnected());
    241     EXPECT_TRUE(tcp_sock_->IsConnected());
    242     log.GetEntries(&entries);
    243     EXPECT_TRUE(LogContainsEndEvent(
    244         entries, -1, NetLog::TYPE_SOCKS_CONNECT));
    245   }
    246 }
    247 
    248 // Tests scenario when the server sends the handshake response in
    249 // more than one packet.
    250 TEST_F(SOCKSClientSocketTest, PartialServerReads) {
    251   const char kSOCKSPartialReply1[] = { 0x00 };
    252   const char kSOCKSPartialReply2[] = { 0x5A, 0x00, 0x00, 0, 0, 0, 0 };
    253 
    254   MockWrite data_writes[] = {
    255       MockWrite(ASYNC, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) };
    256   MockRead data_reads[] = {
    257       MockRead(ASYNC, kSOCKSPartialReply1, arraysize(kSOCKSPartialReply1)),
    258       MockRead(ASYNC, kSOCKSPartialReply2, arraysize(kSOCKSPartialReply2)) };
    259   CapturingNetLog log;
    260 
    261   user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
    262                                data_writes, arraysize(data_writes),
    263                                host_resolver_.get(),
    264                                "localhost", 80,
    265                                &log);
    266 
    267   int rv = user_sock_->Connect(callback_.callback());
    268   EXPECT_EQ(ERR_IO_PENDING, rv);
    269   CapturingNetLog::CapturedEntryList entries;
    270   log.GetEntries(&entries);
    271   EXPECT_TRUE(LogContainsBeginEvent(
    272       entries, 0, NetLog::TYPE_SOCKS_CONNECT));
    273 
    274   rv = callback_.WaitForResult();
    275   EXPECT_EQ(OK, rv);
    276   EXPECT_TRUE(user_sock_->IsConnected());
    277   log.GetEntries(&entries);
    278   EXPECT_TRUE(LogContainsEndEvent(
    279       entries, -1, NetLog::TYPE_SOCKS_CONNECT));
    280 }
    281 
    282 // Tests scenario when the client sends the handshake request in
    283 // more than one packet.
    284 TEST_F(SOCKSClientSocketTest, PartialClientWrites) {
    285   const char kSOCKSPartialRequest1[] = { 0x04, 0x01 };
    286   const char kSOCKSPartialRequest2[] = { 0x00, 0x50, 127, 0, 0, 1, 0 };
    287 
    288   MockWrite data_writes[] = {
    289       MockWrite(ASYNC, arraysize(kSOCKSPartialRequest1)),
    290       // simulate some empty writes
    291       MockWrite(ASYNC, 0),
    292       MockWrite(ASYNC, 0),
    293       MockWrite(ASYNC, kSOCKSPartialRequest2,
    294                 arraysize(kSOCKSPartialRequest2)) };
    295   MockRead data_reads[] = {
    296       MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply)) };
    297   CapturingNetLog log;
    298 
    299   user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
    300                                data_writes, arraysize(data_writes),
    301                                host_resolver_.get(),
    302                                "localhost", 80,
    303                                &log);
    304 
    305   int rv = user_sock_->Connect(callback_.callback());
    306   EXPECT_EQ(ERR_IO_PENDING, rv);
    307   CapturingNetLog::CapturedEntryList entries;
    308   log.GetEntries(&entries);
    309   EXPECT_TRUE(LogContainsBeginEvent(
    310       entries, 0, NetLog::TYPE_SOCKS_CONNECT));
    311 
    312   rv = callback_.WaitForResult();
    313   EXPECT_EQ(OK, rv);
    314   EXPECT_TRUE(user_sock_->IsConnected());
    315   log.GetEntries(&entries);
    316   EXPECT_TRUE(LogContainsEndEvent(
    317       entries, -1, NetLog::TYPE_SOCKS_CONNECT));
    318 }
    319 
    320 // Tests the case when the server sends a smaller sized handshake data
    321 // and closes the connection.
    322 TEST_F(SOCKSClientSocketTest, FailedSocketRead) {
    323   MockWrite data_writes[] = {
    324       MockWrite(ASYNC, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) };
    325   MockRead data_reads[] = {
    326       MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply) - 2),
    327       // close connection unexpectedly
    328       MockRead(SYNCHRONOUS, 0) };
    329   CapturingNetLog log;
    330 
    331   user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
    332                                data_writes, arraysize(data_writes),
    333                                host_resolver_.get(),
    334                                "localhost", 80,
    335                                &log);
    336 
    337   int rv = user_sock_->Connect(callback_.callback());
    338   EXPECT_EQ(ERR_IO_PENDING, rv);
    339   CapturingNetLog::CapturedEntryList entries;
    340   log.GetEntries(&entries);
    341   EXPECT_TRUE(LogContainsBeginEvent(
    342       entries, 0, NetLog::TYPE_SOCKS_CONNECT));
    343 
    344   rv = callback_.WaitForResult();
    345   EXPECT_EQ(ERR_CONNECTION_CLOSED, rv);
    346   EXPECT_FALSE(user_sock_->IsConnected());
    347   log.GetEntries(&entries);
    348   EXPECT_TRUE(LogContainsEndEvent(
    349       entries, -1, NetLog::TYPE_SOCKS_CONNECT));
    350 }
    351 
    352 // Tries to connect to an unknown hostname. Should fail rather than
    353 // falling back to SOCKS4a.
    354 TEST_F(SOCKSClientSocketTest, FailedDNS) {
    355   const char hostname[] = "unresolved.ipv4.address";
    356 
    357   host_resolver_->rules()->AddSimulatedFailure(hostname);
    358 
    359   CapturingNetLog log;
    360 
    361   user_sock_ = BuildMockSocket(NULL, 0,
    362                                NULL, 0,
    363                                host_resolver_.get(),
    364                                hostname, 80,
    365                                &log);
    366 
    367   int rv = user_sock_->Connect(callback_.callback());
    368   EXPECT_EQ(ERR_IO_PENDING, rv);
    369   CapturingNetLog::CapturedEntryList entries;
    370   log.GetEntries(&entries);
    371   EXPECT_TRUE(LogContainsBeginEvent(
    372       entries, 0, NetLog::TYPE_SOCKS_CONNECT));
    373 
    374   rv = callback_.WaitForResult();
    375   EXPECT_EQ(ERR_NAME_NOT_RESOLVED, rv);
    376   EXPECT_FALSE(user_sock_->IsConnected());
    377   log.GetEntries(&entries);
    378   EXPECT_TRUE(LogContainsEndEvent(
    379       entries, -1, NetLog::TYPE_SOCKS_CONNECT));
    380 }
    381 
    382 // Calls Disconnect() while a host resolve is in progress. The outstanding host
    383 // resolve should be cancelled.
    384 TEST_F(SOCKSClientSocketTest, DisconnectWhileHostResolveInProgress) {
    385   scoped_ptr<HangingHostResolverWithCancel> hanging_resolver(
    386     new HangingHostResolverWithCancel());
    387 
    388   // Doesn't matter what the socket data is, we will never use it -- garbage.
    389   MockWrite data_writes[] = { MockWrite(SYNCHRONOUS, "", 0) };
    390   MockRead data_reads[] = { MockRead(SYNCHRONOUS, "", 0) };
    391 
    392   user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
    393                                data_writes, arraysize(data_writes),
    394                                hanging_resolver.get(),
    395                                "foo", 80,
    396                                NULL);
    397 
    398   // Start connecting (will get stuck waiting for the host to resolve).
    399   int rv = user_sock_->Connect(callback_.callback());
    400   EXPECT_EQ(ERR_IO_PENDING, rv);
    401 
    402   EXPECT_FALSE(user_sock_->IsConnected());
    403   EXPECT_FALSE(user_sock_->IsConnectedAndIdle());
    404 
    405   // The host resolver should have received the resolve request.
    406   EXPECT_TRUE(hanging_resolver->HasOutstandingRequest());
    407 
    408   // Disconnect the SOCKS socket -- this should cancel the outstanding resolve.
    409   user_sock_->Disconnect();
    410 
    411   EXPECT_FALSE(hanging_resolver->HasOutstandingRequest());
    412 
    413   EXPECT_FALSE(user_sock_->IsConnected());
    414   EXPECT_FALSE(user_sock_->IsConnectedAndIdle());
    415 }
    416 
    417 }  // namespace net
    418