Home | History | Annotate | Download | only in socket
      1 // Copyright (c) 2011 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/socket/socks_client_socket.h"
      6 
      7 #include "net/base/address_list.h"
      8 #include "net/base/net_log.h"
      9 #include "net/base/net_log_unittest.h"
     10 #include "net/base/mock_host_resolver.h"
     11 #include "net/base/test_completion_callback.h"
     12 #include "net/base/winsock_init.h"
     13 #include "net/socket/client_socket_factory.h"
     14 #include "net/socket/tcp_client_socket.h"
     15 #include "net/socket/socket_test_util.h"
     16 #include "testing/gtest/include/gtest/gtest.h"
     17 #include "testing/platform_test.h"
     18 
     19 //-----------------------------------------------------------------------------
     20 
     21 namespace net {
     22 
     23 const char kSOCKSOkRequest[] = { 0x04, 0x01, 0x00, 0x50, 127, 0, 0, 1, 0 };
     24 const char kSOCKSOkReply[] = { 0x00, 0x5A, 0x00, 0x00, 0, 0, 0, 0 };
     25 
     26 class SOCKSClientSocketTest : public PlatformTest {
     27  public:
     28   SOCKSClientSocketTest();
     29   // Create a SOCKSClientSocket on top of a MockSocket.
     30   SOCKSClientSocket* BuildMockSocket(MockRead reads[], size_t reads_count,
     31                                      MockWrite writes[], size_t writes_count,
     32                                      HostResolver* host_resolver,
     33                                      const std::string& hostname, int port,
     34                                      NetLog* net_log);
     35   virtual void SetUp();
     36 
     37  protected:
     38   scoped_ptr<SOCKSClientSocket> user_sock_;
     39   AddressList address_list_;
     40   ClientSocket* tcp_sock_;
     41   TestCompletionCallback callback_;
     42   scoped_ptr<MockHostResolver> host_resolver_;
     43   scoped_ptr<SocketDataProvider> data_;
     44 };
     45 
     46 SOCKSClientSocketTest::SOCKSClientSocketTest()
     47   : host_resolver_(new MockHostResolver) {
     48 }
     49 
     50 // Set up platform before every test case
     51 void SOCKSClientSocketTest::SetUp() {
     52   PlatformTest::SetUp();
     53 }
     54 
     55 SOCKSClientSocket* SOCKSClientSocketTest::BuildMockSocket(
     56     MockRead reads[],
     57     size_t reads_count,
     58     MockWrite writes[],
     59     size_t writes_count,
     60     HostResolver* host_resolver,
     61     const std::string& hostname,
     62     int port,
     63     NetLog* net_log) {
     64 
     65   TestCompletionCallback callback;
     66   data_.reset(new StaticSocketDataProvider(reads, reads_count,
     67                                            writes, writes_count));
     68   tcp_sock_ = new MockTCPClientSocket(address_list_, net_log, data_.get());
     69 
     70   int rv = tcp_sock_->Connect(&callback);
     71   EXPECT_EQ(ERR_IO_PENDING, rv);
     72   rv = callback.WaitForResult();
     73   EXPECT_EQ(OK, rv);
     74   EXPECT_TRUE(tcp_sock_->IsConnected());
     75 
     76   return new SOCKSClientSocket(tcp_sock_,
     77       HostResolver::RequestInfo(HostPortPair(hostname, port)),
     78       host_resolver);
     79 }
     80 
     81 // Implementation of HostResolver that never completes its resolve request.
     82 // We use this in the test "DisconnectWhileHostResolveInProgress" to make
     83 // sure that the outstanding resolve request gets cancelled.
     84 class HangingHostResolver : public HostResolver {
     85  public:
     86   HangingHostResolver() : outstanding_request_(NULL) {}
     87 
     88   virtual int Resolve(const RequestInfo& info,
     89                       AddressList* addresses,
     90                       CompletionCallback* callback,
     91                       RequestHandle* out_req,
     92                       const BoundNetLog& net_log) {
     93     EXPECT_FALSE(HasOutstandingRequest());
     94     outstanding_request_ = reinterpret_cast<RequestHandle>(1);
     95     *out_req = outstanding_request_;
     96     return ERR_IO_PENDING;
     97   }
     98 
     99   virtual void CancelRequest(RequestHandle req) {
    100     EXPECT_TRUE(HasOutstandingRequest());
    101     EXPECT_EQ(outstanding_request_, req);
    102     outstanding_request_ = NULL;
    103   }
    104 
    105   virtual void AddObserver(Observer* observer) {}
    106   virtual void RemoveObserver(Observer* observer) {}
    107 
    108   bool HasOutstandingRequest() {
    109     return outstanding_request_ != NULL;
    110   }
    111 
    112  private:
    113   RequestHandle outstanding_request_;
    114 
    115   DISALLOW_COPY_AND_ASSIGN(HangingHostResolver);
    116 };
    117 
    118 // Tests a complete handshake and the disconnection.
    119 TEST_F(SOCKSClientSocketTest, CompleteHandshake) {
    120   const std::string payload_write = "random data";
    121   const std::string payload_read = "moar random data";
    122 
    123   MockWrite data_writes[] = {
    124       MockWrite(true, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)),
    125       MockWrite(true, payload_write.data(), payload_write.size()) };
    126   MockRead data_reads[] = {
    127       MockRead(true, kSOCKSOkReply, arraysize(kSOCKSOkReply)),
    128       MockRead(true, payload_read.data(), payload_read.size()) };
    129   CapturingNetLog log(CapturingNetLog::kUnbounded);
    130 
    131   user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
    132                                    data_writes, arraysize(data_writes),
    133                                    host_resolver_.get(),
    134                                    "localhost", 80,
    135                                    &log));
    136 
    137   // At this state the TCP connection is completed but not the SOCKS handshake.
    138   EXPECT_TRUE(tcp_sock_->IsConnected());
    139   EXPECT_FALSE(user_sock_->IsConnected());
    140 
    141   int rv = user_sock_->Connect(&callback_);
    142   EXPECT_EQ(ERR_IO_PENDING, rv);
    143 
    144   net::CapturingNetLog::EntryList entries;
    145   log.GetEntries(&entries);
    146   EXPECT_TRUE(
    147       LogContainsBeginEvent(entries, 0, NetLog::TYPE_SOCKS_CONNECT));
    148   EXPECT_FALSE(user_sock_->IsConnected());
    149 
    150   rv = callback_.WaitForResult();
    151   EXPECT_EQ(OK, rv);
    152   EXPECT_TRUE(user_sock_->IsConnected());
    153   log.GetEntries(&entries);
    154   EXPECT_TRUE(LogContainsEndEvent(
    155       entries, -1, NetLog::TYPE_SOCKS_CONNECT));
    156 
    157   scoped_refptr<IOBuffer> buffer(new IOBuffer(payload_write.size()));
    158   memcpy(buffer->data(), payload_write.data(), payload_write.size());
    159   rv = user_sock_->Write(buffer, payload_write.size(), &callback_);
    160   EXPECT_EQ(ERR_IO_PENDING, rv);
    161   rv = callback_.WaitForResult();
    162   EXPECT_EQ(static_cast<int>(payload_write.size()), rv);
    163 
    164   buffer = new IOBuffer(payload_read.size());
    165   rv = user_sock_->Read(buffer, payload_read.size(), &callback_);
    166   EXPECT_EQ(ERR_IO_PENDING, rv);
    167   rv = callback_.WaitForResult();
    168   EXPECT_EQ(static_cast<int>(payload_read.size()), rv);
    169   EXPECT_EQ(payload_read, std::string(buffer->data(), payload_read.size()));
    170 
    171   user_sock_->Disconnect();
    172   EXPECT_FALSE(tcp_sock_->IsConnected());
    173   EXPECT_FALSE(user_sock_->IsConnected());
    174 }
    175 
    176 // List of responses from the socks server and the errors they should
    177 // throw up are tested here.
    178 TEST_F(SOCKSClientSocketTest, HandshakeFailures) {
    179   const struct {
    180     const char fail_reply[8];
    181     Error fail_code;
    182   } tests[] = {
    183     // Failure of the server response code
    184     {
    185       { 0x01, 0x5A, 0x00, 0x00, 0, 0, 0, 0 },
    186       ERR_SOCKS_CONNECTION_FAILED,
    187     },
    188     // Failure of the null byte
    189     {
    190       { 0x00, 0x5B, 0x00, 0x00, 0, 0, 0, 0 },
    191       ERR_SOCKS_CONNECTION_FAILED,
    192     },
    193   };
    194 
    195   //---------------------------------------
    196 
    197   for (size_t i = 0; i < ARRAYSIZE_UNSAFE(tests); ++i) {
    198     MockWrite data_writes[] = {
    199         MockWrite(false, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) };
    200     MockRead data_reads[] = {
    201         MockRead(false, tests[i].fail_reply, arraysize(tests[i].fail_reply)) };
    202     CapturingNetLog log(CapturingNetLog::kUnbounded);
    203 
    204     user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
    205                                      data_writes, arraysize(data_writes),
    206                                      host_resolver_.get(),
    207                                      "localhost", 80,
    208                                      &log));
    209 
    210     int rv = user_sock_->Connect(&callback_);
    211     EXPECT_EQ(ERR_IO_PENDING, rv);
    212 
    213     net::CapturingNetLog::EntryList entries;
    214     log.GetEntries(&entries);
    215     EXPECT_TRUE(LogContainsBeginEvent(
    216         entries, 0, NetLog::TYPE_SOCKS_CONNECT));
    217 
    218     rv = callback_.WaitForResult();
    219     EXPECT_EQ(tests[i].fail_code, rv);
    220     EXPECT_FALSE(user_sock_->IsConnected());
    221     EXPECT_TRUE(tcp_sock_->IsConnected());
    222     log.GetEntries(&entries);
    223     EXPECT_TRUE(LogContainsEndEvent(
    224         entries, -1, NetLog::TYPE_SOCKS_CONNECT));
    225   }
    226 }
    227 
    228 // Tests scenario when the server sends the handshake response in
    229 // more than one packet.
    230 TEST_F(SOCKSClientSocketTest, PartialServerReads) {
    231   const char kSOCKSPartialReply1[] = { 0x00 };
    232   const char kSOCKSPartialReply2[] = { 0x5A, 0x00, 0x00, 0, 0, 0, 0 };
    233 
    234   MockWrite data_writes[] = {
    235       MockWrite(true, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) };
    236   MockRead data_reads[] = {
    237       MockRead(true, kSOCKSPartialReply1, arraysize(kSOCKSPartialReply1)),
    238       MockRead(true, kSOCKSPartialReply2, arraysize(kSOCKSPartialReply2)) };
    239   CapturingNetLog log(CapturingNetLog::kUnbounded);
    240 
    241   user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
    242                                    data_writes, arraysize(data_writes),
    243                                    host_resolver_.get(),
    244                                    "localhost", 80,
    245                                    &log));
    246 
    247   int rv = user_sock_->Connect(&callback_);
    248   EXPECT_EQ(ERR_IO_PENDING, rv);
    249   net::CapturingNetLog::EntryList entries;
    250   log.GetEntries(&entries);
    251   EXPECT_TRUE(LogContainsBeginEvent(
    252       entries, 0, NetLog::TYPE_SOCKS_CONNECT));
    253 
    254   rv = callback_.WaitForResult();
    255   EXPECT_EQ(OK, rv);
    256   EXPECT_TRUE(user_sock_->IsConnected());
    257   log.GetEntries(&entries);
    258   EXPECT_TRUE(LogContainsEndEvent(
    259       entries, -1, NetLog::TYPE_SOCKS_CONNECT));
    260 }
    261 
    262 // Tests scenario when the client sends the handshake request in
    263 // more than one packet.
    264 TEST_F(SOCKSClientSocketTest, PartialClientWrites) {
    265   const char kSOCKSPartialRequest1[] = { 0x04, 0x01 };
    266   const char kSOCKSPartialRequest2[] = { 0x00, 0x50, 127, 0, 0, 1, 0 };
    267 
    268   MockWrite data_writes[] = {
    269       MockWrite(true, arraysize(kSOCKSPartialRequest1)),
    270       // simulate some empty writes
    271       MockWrite(true, 0),
    272       MockWrite(true, 0),
    273       MockWrite(true, kSOCKSPartialRequest2,
    274                 arraysize(kSOCKSPartialRequest2)) };
    275   MockRead data_reads[] = {
    276       MockRead(true, kSOCKSOkReply, arraysize(kSOCKSOkReply)) };
    277   CapturingNetLog log(CapturingNetLog::kUnbounded);
    278 
    279   user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
    280                                    data_writes, arraysize(data_writes),
    281                                    host_resolver_.get(),
    282                                    "localhost", 80,
    283                                    &log));
    284 
    285   int rv = user_sock_->Connect(&callback_);
    286   EXPECT_EQ(ERR_IO_PENDING, rv);
    287   net::CapturingNetLog::EntryList entries;
    288   log.GetEntries(&entries);
    289   EXPECT_TRUE(LogContainsBeginEvent(
    290       entries, 0, NetLog::TYPE_SOCKS_CONNECT));
    291 
    292   rv = callback_.WaitForResult();
    293   EXPECT_EQ(OK, rv);
    294   EXPECT_TRUE(user_sock_->IsConnected());
    295   log.GetEntries(&entries);
    296   EXPECT_TRUE(LogContainsEndEvent(
    297       entries, -1, NetLog::TYPE_SOCKS_CONNECT));
    298 }
    299 
    300 // Tests the case when the server sends a smaller sized handshake data
    301 // and closes the connection.
    302 TEST_F(SOCKSClientSocketTest, FailedSocketRead) {
    303   MockWrite data_writes[] = {
    304       MockWrite(true, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) };
    305   MockRead data_reads[] = {
    306       MockRead(true, kSOCKSOkReply, arraysize(kSOCKSOkReply) - 2),
    307       // close connection unexpectedly
    308       MockRead(false, 0) };
    309   CapturingNetLog log(CapturingNetLog::kUnbounded);
    310 
    311   user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
    312                                    data_writes, arraysize(data_writes),
    313                                    host_resolver_.get(),
    314                                    "localhost", 80,
    315                                    &log));
    316 
    317   int rv = user_sock_->Connect(&callback_);
    318   EXPECT_EQ(ERR_IO_PENDING, rv);
    319   net::CapturingNetLog::EntryList entries;
    320   log.GetEntries(&entries);
    321   EXPECT_TRUE(LogContainsBeginEvent(
    322       entries, 0, NetLog::TYPE_SOCKS_CONNECT));
    323 
    324   rv = callback_.WaitForResult();
    325   EXPECT_EQ(ERR_CONNECTION_CLOSED, rv);
    326   EXPECT_FALSE(user_sock_->IsConnected());
    327   log.GetEntries(&entries);
    328   EXPECT_TRUE(LogContainsEndEvent(
    329       entries, -1, NetLog::TYPE_SOCKS_CONNECT));
    330 }
    331 
    332 // Tries to connect to an unknown hostname. Should fail rather than
    333 // falling back to SOCKS4a.
    334 TEST_F(SOCKSClientSocketTest, FailedDNS) {
    335   const char hostname[] = "unresolved.ipv4.address";
    336 
    337   host_resolver_->rules()->AddSimulatedFailure(hostname);
    338 
    339   CapturingNetLog log(CapturingNetLog::kUnbounded);
    340 
    341   user_sock_.reset(BuildMockSocket(NULL, 0,
    342                                    NULL, 0,
    343                                    host_resolver_.get(),
    344                                    hostname, 80,
    345                                    &log));
    346 
    347   int rv = user_sock_->Connect(&callback_);
    348   EXPECT_EQ(ERR_IO_PENDING, rv);
    349   net::CapturingNetLog::EntryList entries;
    350   log.GetEntries(&entries);
    351   EXPECT_TRUE(LogContainsBeginEvent(
    352       entries, 0, NetLog::TYPE_SOCKS_CONNECT));
    353 
    354   rv = callback_.WaitForResult();
    355   EXPECT_EQ(ERR_NAME_NOT_RESOLVED, rv);
    356   EXPECT_FALSE(user_sock_->IsConnected());
    357   log.GetEntries(&entries);
    358   EXPECT_TRUE(LogContainsEndEvent(
    359       entries, -1, NetLog::TYPE_SOCKS_CONNECT));
    360 }
    361 
    362 // Calls Disconnect() while a host resolve is in progress. The outstanding host
    363 // resolve should be cancelled.
    364 TEST_F(SOCKSClientSocketTest, DisconnectWhileHostResolveInProgress) {
    365   scoped_ptr<HangingHostResolver> hanging_resolver(new HangingHostResolver());
    366 
    367   // Doesn't matter what the socket data is, we will never use it -- garbage.
    368   MockWrite data_writes[] = { MockWrite(false, "", 0) };
    369   MockRead data_reads[] = { MockRead(false, "", 0) };
    370 
    371   user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
    372                                    data_writes, arraysize(data_writes),
    373                                    hanging_resolver.get(),
    374                                    "foo", 80,
    375                                    NULL));
    376 
    377   // Start connecting (will get stuck waiting for the host to resolve).
    378   int rv = user_sock_->Connect(&callback_);
    379   EXPECT_EQ(ERR_IO_PENDING, rv);
    380 
    381   EXPECT_FALSE(user_sock_->IsConnected());
    382   EXPECT_FALSE(user_sock_->IsConnectedAndIdle());
    383 
    384   // The host resolver should have received the resolve request.
    385   EXPECT_TRUE(hanging_resolver->HasOutstandingRequest());
    386 
    387   // Disconnect the SOCKS socket -- this should cancel the outstanding resolve.
    388   user_sock_->Disconnect();
    389 
    390   EXPECT_FALSE(hanging_resolver->HasOutstandingRequest());
    391 
    392   EXPECT_FALSE(user_sock_->IsConnected());
    393   EXPECT_FALSE(user_sock_->IsConnectedAndIdle());
    394 }
    395 
    396 }  // namespace net
    397