Home | History | Annotate | Download | only in socket
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/socket/socks_client_socket.h"
      6 
      7 #include "base/memory/scoped_ptr.h"
      8 #include "net/base/address_list.h"
      9 #include "net/base/net_log.h"
     10 #include "net/base/net_log_unittest.h"
     11 #include "net/base/test_completion_callback.h"
     12 #include "net/base/winsock_init.h"
     13 #include "net/dns/host_resolver.h"
     14 #include "net/dns/mock_host_resolver.h"
     15 #include "net/socket/client_socket_factory.h"
     16 #include "net/socket/socket_test_util.h"
     17 #include "net/socket/tcp_client_socket.h"
     18 #include "testing/gtest/include/gtest/gtest.h"
     19 #include "testing/platform_test.h"
     20 
     21 //-----------------------------------------------------------------------------
     22 
     23 namespace net {
     24 
     25 const char kSOCKSOkRequest[] = { 0x04, 0x01, 0x00, 0x50, 127, 0, 0, 1, 0 };
     26 const char kSOCKSOkReply[] = { 0x00, 0x5A, 0x00, 0x00, 0, 0, 0, 0 };
     27 
     28 class SOCKSClientSocketTest : public PlatformTest {
     29  public:
     30   SOCKSClientSocketTest();
     31   // Create a SOCKSClientSocket on top of a MockSocket.
     32   scoped_ptr<SOCKSClientSocket> BuildMockSocket(
     33       MockRead reads[], size_t reads_count,
     34       MockWrite writes[], size_t writes_count,
     35       HostResolver* host_resolver,
     36       const std::string& hostname, int port,
     37       NetLog* net_log);
     38   virtual void SetUp();
     39 
     40  protected:
     41   scoped_ptr<SOCKSClientSocket> user_sock_;
     42   AddressList address_list_;
     43   // Filled in by BuildMockSocket() and owned by its return value
     44   // (which |user_sock| is set to).
     45   StreamSocket* tcp_sock_;
     46   TestCompletionCallback callback_;
     47   scoped_ptr<MockHostResolver> host_resolver_;
     48   scoped_ptr<SocketDataProvider> data_;
     49 };
     50 
     51 SOCKSClientSocketTest::SOCKSClientSocketTest()
     52   : host_resolver_(new MockHostResolver) {
     53 }
     54 
     55 // Set up platform before every test case
     56 void SOCKSClientSocketTest::SetUp() {
     57   PlatformTest::SetUp();
     58 }
     59 
     60 scoped_ptr<SOCKSClientSocket> SOCKSClientSocketTest::BuildMockSocket(
     61     MockRead reads[],
     62     size_t reads_count,
     63     MockWrite writes[],
     64     size_t writes_count,
     65     HostResolver* host_resolver,
     66     const std::string& hostname,
     67     int port,
     68     NetLog* net_log) {
     69 
     70   TestCompletionCallback callback;
     71   data_.reset(new StaticSocketDataProvider(reads, reads_count,
     72                                            writes, writes_count));
     73   tcp_sock_ = new MockTCPClientSocket(address_list_, net_log, data_.get());
     74 
     75   int rv = tcp_sock_->Connect(callback.callback());
     76   EXPECT_EQ(ERR_IO_PENDING, rv);
     77   rv = callback.WaitForResult();
     78   EXPECT_EQ(OK, rv);
     79   EXPECT_TRUE(tcp_sock_->IsConnected());
     80 
     81   scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle);
     82   // |connection| takes ownership of |tcp_sock_|, but keep a
     83   // non-owning pointer to it.
     84   connection->SetSocket(scoped_ptr<StreamSocket>(tcp_sock_));
     85   return scoped_ptr<SOCKSClientSocket>(new SOCKSClientSocket(
     86       connection.Pass(),
     87       HostResolver::RequestInfo(HostPortPair(hostname, port)),
     88       DEFAULT_PRIORITY,
     89       host_resolver));
     90 }
     91 
     92 // Implementation of HostResolver that never completes its resolve request.
     93 // We use this in the test "DisconnectWhileHostResolveInProgress" to make
     94 // sure that the outstanding resolve request gets cancelled.
     95 class HangingHostResolverWithCancel : public HostResolver {
     96  public:
     97   HangingHostResolverWithCancel() : outstanding_request_(NULL) {}
     98 
     99   virtual int Resolve(const RequestInfo& info,
    100                       RequestPriority priority,
    101                       AddressList* addresses,
    102                       const CompletionCallback& callback,
    103                       RequestHandle* out_req,
    104                       const BoundNetLog& net_log) OVERRIDE {
    105     DCHECK(addresses);
    106     DCHECK_EQ(false, callback.is_null());
    107     EXPECT_FALSE(HasOutstandingRequest());
    108     outstanding_request_ = reinterpret_cast<RequestHandle>(1);
    109     *out_req = outstanding_request_;
    110     return ERR_IO_PENDING;
    111   }
    112 
    113   virtual int ResolveFromCache(const RequestInfo& info,
    114                                AddressList* addresses,
    115                                const BoundNetLog& net_log) OVERRIDE {
    116     NOTIMPLEMENTED();
    117     return ERR_UNEXPECTED;
    118   }
    119 
    120   virtual void CancelRequest(RequestHandle req) OVERRIDE {
    121     EXPECT_TRUE(HasOutstandingRequest());
    122     EXPECT_EQ(outstanding_request_, req);
    123     outstanding_request_ = NULL;
    124   }
    125 
    126   bool HasOutstandingRequest() {
    127     return outstanding_request_ != NULL;
    128   }
    129 
    130  private:
    131   RequestHandle outstanding_request_;
    132 
    133   DISALLOW_COPY_AND_ASSIGN(HangingHostResolverWithCancel);
    134 };
    135 
    136 // Tests a complete handshake and the disconnection.
    137 TEST_F(SOCKSClientSocketTest, CompleteHandshake) {
    138   const std::string payload_write = "random data";
    139   const std::string payload_read = "moar random data";
    140 
    141   MockWrite data_writes[] = {
    142       MockWrite(ASYNC, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)),
    143       MockWrite(ASYNC, payload_write.data(), payload_write.size()) };
    144   MockRead data_reads[] = {
    145       MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply)),
    146       MockRead(ASYNC, payload_read.data(), payload_read.size()) };
    147   CapturingNetLog log;
    148 
    149   user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
    150                                data_writes, arraysize(data_writes),
    151                                host_resolver_.get(),
    152                                "localhost", 80,
    153                                &log);
    154 
    155   // At this state the TCP connection is completed but not the SOCKS handshake.
    156   EXPECT_TRUE(tcp_sock_->IsConnected());
    157   EXPECT_FALSE(user_sock_->IsConnected());
    158 
    159   int rv = user_sock_->Connect(callback_.callback());
    160   EXPECT_EQ(ERR_IO_PENDING, rv);
    161 
    162   CapturingNetLog::CapturedEntryList entries;
    163   log.GetEntries(&entries);
    164   EXPECT_TRUE(
    165       LogContainsBeginEvent(entries, 0, NetLog::TYPE_SOCKS_CONNECT));
    166   EXPECT_FALSE(user_sock_->IsConnected());
    167 
    168   rv = callback_.WaitForResult();
    169   EXPECT_EQ(OK, rv);
    170   EXPECT_TRUE(user_sock_->IsConnected());
    171   log.GetEntries(&entries);
    172   EXPECT_TRUE(LogContainsEndEvent(
    173       entries, -1, NetLog::TYPE_SOCKS_CONNECT));
    174 
    175   scoped_refptr<IOBuffer> buffer(new IOBuffer(payload_write.size()));
    176   memcpy(buffer->data(), payload_write.data(), payload_write.size());
    177   rv = user_sock_->Write(
    178       buffer.get(), payload_write.size(), callback_.callback());
    179   EXPECT_EQ(ERR_IO_PENDING, rv);
    180   rv = callback_.WaitForResult();
    181   EXPECT_EQ(static_cast<int>(payload_write.size()), rv);
    182 
    183   buffer = new IOBuffer(payload_read.size());
    184   rv =
    185       user_sock_->Read(buffer.get(), payload_read.size(), callback_.callback());
    186   EXPECT_EQ(ERR_IO_PENDING, rv);
    187   rv = callback_.WaitForResult();
    188   EXPECT_EQ(static_cast<int>(payload_read.size()), rv);
    189   EXPECT_EQ(payload_read, std::string(buffer->data(), payload_read.size()));
    190 
    191   user_sock_->Disconnect();
    192   EXPECT_FALSE(tcp_sock_->IsConnected());
    193   EXPECT_FALSE(user_sock_->IsConnected());
    194 }
    195 
    196 // List of responses from the socks server and the errors they should
    197 // throw up are tested here.
    198 TEST_F(SOCKSClientSocketTest, HandshakeFailures) {
    199   const struct {
    200     const char fail_reply[8];
    201     Error fail_code;
    202   } tests[] = {
    203     // Failure of the server response code
    204     {
    205       { 0x01, 0x5A, 0x00, 0x00, 0, 0, 0, 0 },
    206       ERR_SOCKS_CONNECTION_FAILED,
    207     },
    208     // Failure of the null byte
    209     {
    210       { 0x00, 0x5B, 0x00, 0x00, 0, 0, 0, 0 },
    211       ERR_SOCKS_CONNECTION_FAILED,
    212     },
    213   };
    214 
    215   //---------------------------------------
    216 
    217   for (size_t i = 0; i < ARRAYSIZE_UNSAFE(tests); ++i) {
    218     MockWrite data_writes[] = {
    219         MockWrite(SYNCHRONOUS, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) };
    220     MockRead data_reads[] = {
    221         MockRead(SYNCHRONOUS, tests[i].fail_reply,
    222                  arraysize(tests[i].fail_reply)) };
    223     CapturingNetLog log;
    224 
    225     user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
    226                                  data_writes, arraysize(data_writes),
    227                                  host_resolver_.get(),
    228                                  "localhost", 80,
    229                                  &log);
    230 
    231     int rv = user_sock_->Connect(callback_.callback());
    232     EXPECT_EQ(ERR_IO_PENDING, rv);
    233 
    234     CapturingNetLog::CapturedEntryList entries;
    235     log.GetEntries(&entries);
    236     EXPECT_TRUE(LogContainsBeginEvent(
    237         entries, 0, NetLog::TYPE_SOCKS_CONNECT));
    238 
    239     rv = callback_.WaitForResult();
    240     EXPECT_EQ(tests[i].fail_code, rv);
    241     EXPECT_FALSE(user_sock_->IsConnected());
    242     EXPECT_TRUE(tcp_sock_->IsConnected());
    243     log.GetEntries(&entries);
    244     EXPECT_TRUE(LogContainsEndEvent(
    245         entries, -1, NetLog::TYPE_SOCKS_CONNECT));
    246   }
    247 }
    248 
    249 // Tests scenario when the server sends the handshake response in
    250 // more than one packet.
    251 TEST_F(SOCKSClientSocketTest, PartialServerReads) {
    252   const char kSOCKSPartialReply1[] = { 0x00 };
    253   const char kSOCKSPartialReply2[] = { 0x5A, 0x00, 0x00, 0, 0, 0, 0 };
    254 
    255   MockWrite data_writes[] = {
    256       MockWrite(ASYNC, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) };
    257   MockRead data_reads[] = {
    258       MockRead(ASYNC, kSOCKSPartialReply1, arraysize(kSOCKSPartialReply1)),
    259       MockRead(ASYNC, kSOCKSPartialReply2, arraysize(kSOCKSPartialReply2)) };
    260   CapturingNetLog log;
    261 
    262   user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
    263                                data_writes, arraysize(data_writes),
    264                                host_resolver_.get(),
    265                                "localhost", 80,
    266                                &log);
    267 
    268   int rv = user_sock_->Connect(callback_.callback());
    269   EXPECT_EQ(ERR_IO_PENDING, rv);
    270   CapturingNetLog::CapturedEntryList entries;
    271   log.GetEntries(&entries);
    272   EXPECT_TRUE(LogContainsBeginEvent(
    273       entries, 0, NetLog::TYPE_SOCKS_CONNECT));
    274 
    275   rv = callback_.WaitForResult();
    276   EXPECT_EQ(OK, rv);
    277   EXPECT_TRUE(user_sock_->IsConnected());
    278   log.GetEntries(&entries);
    279   EXPECT_TRUE(LogContainsEndEvent(
    280       entries, -1, NetLog::TYPE_SOCKS_CONNECT));
    281 }
    282 
    283 // Tests scenario when the client sends the handshake request in
    284 // more than one packet.
    285 TEST_F(SOCKSClientSocketTest, PartialClientWrites) {
    286   const char kSOCKSPartialRequest1[] = { 0x04, 0x01 };
    287   const char kSOCKSPartialRequest2[] = { 0x00, 0x50, 127, 0, 0, 1, 0 };
    288 
    289   MockWrite data_writes[] = {
    290       MockWrite(ASYNC, arraysize(kSOCKSPartialRequest1)),
    291       // simulate some empty writes
    292       MockWrite(ASYNC, 0),
    293       MockWrite(ASYNC, 0),
    294       MockWrite(ASYNC, kSOCKSPartialRequest2,
    295                 arraysize(kSOCKSPartialRequest2)) };
    296   MockRead data_reads[] = {
    297       MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply)) };
    298   CapturingNetLog log;
    299 
    300   user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
    301                                data_writes, arraysize(data_writes),
    302                                host_resolver_.get(),
    303                                "localhost", 80,
    304                                &log);
    305 
    306   int rv = user_sock_->Connect(callback_.callback());
    307   EXPECT_EQ(ERR_IO_PENDING, rv);
    308   CapturingNetLog::CapturedEntryList entries;
    309   log.GetEntries(&entries);
    310   EXPECT_TRUE(LogContainsBeginEvent(
    311       entries, 0, NetLog::TYPE_SOCKS_CONNECT));
    312 
    313   rv = callback_.WaitForResult();
    314   EXPECT_EQ(OK, rv);
    315   EXPECT_TRUE(user_sock_->IsConnected());
    316   log.GetEntries(&entries);
    317   EXPECT_TRUE(LogContainsEndEvent(
    318       entries, -1, NetLog::TYPE_SOCKS_CONNECT));
    319 }
    320 
    321 // Tests the case when the server sends a smaller sized handshake data
    322 // and closes the connection.
    323 TEST_F(SOCKSClientSocketTest, FailedSocketRead) {
    324   MockWrite data_writes[] = {
    325       MockWrite(ASYNC, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) };
    326   MockRead data_reads[] = {
    327       MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply) - 2),
    328       // close connection unexpectedly
    329       MockRead(SYNCHRONOUS, 0) };
    330   CapturingNetLog log;
    331 
    332   user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
    333                                data_writes, arraysize(data_writes),
    334                                host_resolver_.get(),
    335                                "localhost", 80,
    336                                &log);
    337 
    338   int rv = user_sock_->Connect(callback_.callback());
    339   EXPECT_EQ(ERR_IO_PENDING, rv);
    340   CapturingNetLog::CapturedEntryList entries;
    341   log.GetEntries(&entries);
    342   EXPECT_TRUE(LogContainsBeginEvent(
    343       entries, 0, NetLog::TYPE_SOCKS_CONNECT));
    344 
    345   rv = callback_.WaitForResult();
    346   EXPECT_EQ(ERR_CONNECTION_CLOSED, rv);
    347   EXPECT_FALSE(user_sock_->IsConnected());
    348   log.GetEntries(&entries);
    349   EXPECT_TRUE(LogContainsEndEvent(
    350       entries, -1, NetLog::TYPE_SOCKS_CONNECT));
    351 }
    352 
    353 // Tries to connect to an unknown hostname. Should fail rather than
    354 // falling back to SOCKS4a.
    355 TEST_F(SOCKSClientSocketTest, FailedDNS) {
    356   const char hostname[] = "unresolved.ipv4.address";
    357 
    358   host_resolver_->rules()->AddSimulatedFailure(hostname);
    359 
    360   CapturingNetLog log;
    361 
    362   user_sock_ = BuildMockSocket(NULL, 0,
    363                                NULL, 0,
    364                                host_resolver_.get(),
    365                                hostname, 80,
    366                                &log);
    367 
    368   int rv = user_sock_->Connect(callback_.callback());
    369   EXPECT_EQ(ERR_IO_PENDING, rv);
    370   CapturingNetLog::CapturedEntryList entries;
    371   log.GetEntries(&entries);
    372   EXPECT_TRUE(LogContainsBeginEvent(
    373       entries, 0, NetLog::TYPE_SOCKS_CONNECT));
    374 
    375   rv = callback_.WaitForResult();
    376   EXPECT_EQ(ERR_NAME_NOT_RESOLVED, rv);
    377   EXPECT_FALSE(user_sock_->IsConnected());
    378   log.GetEntries(&entries);
    379   EXPECT_TRUE(LogContainsEndEvent(
    380       entries, -1, NetLog::TYPE_SOCKS_CONNECT));
    381 }
    382 
    383 // Calls Disconnect() while a host resolve is in progress. The outstanding host
    384 // resolve should be cancelled.
    385 TEST_F(SOCKSClientSocketTest, DisconnectWhileHostResolveInProgress) {
    386   scoped_ptr<HangingHostResolverWithCancel> hanging_resolver(
    387     new HangingHostResolverWithCancel());
    388 
    389   // Doesn't matter what the socket data is, we will never use it -- garbage.
    390   MockWrite data_writes[] = { MockWrite(SYNCHRONOUS, "", 0) };
    391   MockRead data_reads[] = { MockRead(SYNCHRONOUS, "", 0) };
    392 
    393   user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
    394                                data_writes, arraysize(data_writes),
    395                                hanging_resolver.get(),
    396                                "foo", 80,
    397                                NULL);
    398 
    399   // Start connecting (will get stuck waiting for the host to resolve).
    400   int rv = user_sock_->Connect(callback_.callback());
    401   EXPECT_EQ(ERR_IO_PENDING, rv);
    402 
    403   EXPECT_FALSE(user_sock_->IsConnected());
    404   EXPECT_FALSE(user_sock_->IsConnectedAndIdle());
    405 
    406   // The host resolver should have received the resolve request.
    407   EXPECT_TRUE(hanging_resolver->HasOutstandingRequest());
    408 
    409   // Disconnect the SOCKS socket -- this should cancel the outstanding resolve.
    410   user_sock_->Disconnect();
    411 
    412   EXPECT_FALSE(hanging_resolver->HasOutstandingRequest());
    413 
    414   EXPECT_FALSE(user_sock_->IsConnected());
    415   EXPECT_FALSE(user_sock_->IsConnectedAndIdle());
    416 }
    417 
    418 // Tries to connect to an IPv6 IP.  Should fail, as SOCKS4 does not support
    419 // IPv6.
    420 TEST_F(SOCKSClientSocketTest, NoIPv6) {
    421   const char kHostName[] = "::1";
    422 
    423   user_sock_ = BuildMockSocket(NULL, 0,
    424                                NULL, 0,
    425                                host_resolver_.get(),
    426                                kHostName, 80,
    427                                NULL);
    428 
    429   EXPECT_EQ(ERR_NAME_NOT_RESOLVED,
    430             callback_.GetResult(user_sock_->Connect(callback_.callback())));
    431 }
    432 
    433 // Same as above, but with a real resolver, to protect against regressions.
    434 TEST_F(SOCKSClientSocketTest, NoIPv6RealResolver) {
    435   const char kHostName[] = "::1";
    436 
    437   scoped_ptr<HostResolver> host_resolver(
    438       HostResolver::CreateSystemResolver(HostResolver::Options(), NULL));
    439 
    440   user_sock_ = BuildMockSocket(NULL, 0,
    441                                NULL, 0,
    442                                host_resolver.get(),
    443                                kHostName, 80,
    444                                NULL);
    445 
    446   EXPECT_EQ(ERR_NAME_NOT_RESOLVED,
    447             callback_.GetResult(user_sock_->Connect(callback_.callback())));
    448 }
    449 
    450 }  // namespace net
    451