Home | History | Annotate | Download | only in socket
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/socket/socks5_client_socket.h"
      6 
      7 #include <algorithm>
      8 #include <iterator>
      9 #include <map>
     10 
     11 #include "base/sys_byteorder.h"
     12 #include "net/base/address_list.h"
     13 #include "net/base/net_log.h"
     14 #include "net/base/net_log_unittest.h"
     15 #include "net/base/test_completion_callback.h"
     16 #include "net/base/winsock_init.h"
     17 #include "net/dns/mock_host_resolver.h"
     18 #include "net/socket/client_socket_factory.h"
     19 #include "net/socket/socket_test_util.h"
     20 #include "net/socket/tcp_client_socket.h"
     21 #include "testing/gtest/include/gtest/gtest.h"
     22 #include "testing/platform_test.h"
     23 
     24 //-----------------------------------------------------------------------------
     25 
     26 namespace net {
     27 
     28 namespace {
     29 
     30 // Base class to test SOCKS5ClientSocket
     31 class SOCKS5ClientSocketTest : public PlatformTest {
     32  public:
     33   SOCKS5ClientSocketTest();
     34   // Create a SOCKSClientSocket on top of a MockSocket.
     35   scoped_ptr<SOCKS5ClientSocket> BuildMockSocket(MockRead reads[],
     36                                                  size_t reads_count,
     37                                                  MockWrite writes[],
     38                                                  size_t writes_count,
     39                                                  const std::string& hostname,
     40                                                  int port,
     41                                                  NetLog* net_log);
     42 
     43   virtual void SetUp();
     44 
     45  protected:
     46   const uint16 kNwPort;
     47   CapturingNetLog net_log_;
     48   scoped_ptr<SOCKS5ClientSocket> user_sock_;
     49   AddressList address_list_;
     50   // Filled in by BuildMockSocket() and owned by its return value
     51   // (which |user_sock| is set to).
     52   StreamSocket* tcp_sock_;
     53   TestCompletionCallback callback_;
     54   scoped_ptr<MockHostResolver> host_resolver_;
     55   scoped_ptr<SocketDataProvider> data_;
     56 
     57  private:
     58   DISALLOW_COPY_AND_ASSIGN(SOCKS5ClientSocketTest);
     59 };
     60 
     61 SOCKS5ClientSocketTest::SOCKS5ClientSocketTest()
     62   : kNwPort(base::HostToNet16(80)),
     63     host_resolver_(new MockHostResolver) {
     64 }
     65 
     66 // Set up platform before every test case
     67 void SOCKS5ClientSocketTest::SetUp() {
     68   PlatformTest::SetUp();
     69 
     70   // Resolve the "localhost" AddressList used by the TCP connection to connect.
     71   HostResolver::RequestInfo info(HostPortPair("www.socks-proxy.com", 1080));
     72   TestCompletionCallback callback;
     73   int rv = host_resolver_->Resolve(info,
     74                                    DEFAULT_PRIORITY,
     75                                    &address_list_,
     76                                    callback.callback(),
     77                                    NULL,
     78                                    BoundNetLog());
     79   ASSERT_EQ(ERR_IO_PENDING, rv);
     80   rv = callback.WaitForResult();
     81   ASSERT_EQ(OK, rv);
     82 }
     83 
     84 scoped_ptr<SOCKS5ClientSocket> SOCKS5ClientSocketTest::BuildMockSocket(
     85     MockRead reads[],
     86     size_t reads_count,
     87     MockWrite writes[],
     88     size_t writes_count,
     89     const std::string& hostname,
     90     int port,
     91     NetLog* net_log) {
     92   TestCompletionCallback callback;
     93   data_.reset(new StaticSocketDataProvider(reads, reads_count,
     94                                            writes, writes_count));
     95   tcp_sock_ = new MockTCPClientSocket(address_list_, net_log, data_.get());
     96 
     97   int rv = tcp_sock_->Connect(callback.callback());
     98   EXPECT_EQ(ERR_IO_PENDING, rv);
     99   rv = callback.WaitForResult();
    100   EXPECT_EQ(OK, rv);
    101   EXPECT_TRUE(tcp_sock_->IsConnected());
    102 
    103   scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle);
    104   // |connection| takes ownership of |tcp_sock_|, but keep a
    105   // non-owning pointer to it.
    106   connection->SetSocket(scoped_ptr<StreamSocket>(tcp_sock_));
    107   return scoped_ptr<SOCKS5ClientSocket>(new SOCKS5ClientSocket(
    108       connection.Pass(),
    109       HostResolver::RequestInfo(HostPortPair(hostname, port))));
    110 }
    111 
    112 // Tests a complete SOCKS5 handshake and the disconnection.
    113 TEST_F(SOCKS5ClientSocketTest, CompleteHandshake) {
    114   const std::string payload_write = "random data";
    115   const std::string payload_read = "moar random data";
    116 
    117   const char kOkRequest[] = {
    118     0x05,  // Version
    119     0x01,  // Command (CONNECT)
    120     0x00,  // Reserved.
    121     0x03,  // Address type (DOMAINNAME).
    122     0x09,  // Length of domain (9)
    123     // Domain string:
    124     'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't',
    125     0x00, 0x50,  // 16-bit port (80)
    126   };
    127 
    128   MockWrite data_writes[] = {
    129       MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
    130       MockWrite(ASYNC, kOkRequest, arraysize(kOkRequest)),
    131       MockWrite(ASYNC, payload_write.data(), payload_write.size()) };
    132   MockRead data_reads[] = {
    133       MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
    134       MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength),
    135       MockRead(ASYNC, payload_read.data(), payload_read.size()) };
    136 
    137   user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
    138                                data_writes, arraysize(data_writes),
    139                                "localhost", 80, &net_log_);
    140 
    141   // At this state the TCP connection is completed but not the SOCKS handshake.
    142   EXPECT_TRUE(tcp_sock_->IsConnected());
    143   EXPECT_FALSE(user_sock_->IsConnected());
    144 
    145   int rv = user_sock_->Connect(callback_.callback());
    146   EXPECT_EQ(ERR_IO_PENDING, rv);
    147   EXPECT_FALSE(user_sock_->IsConnected());
    148 
    149   CapturingNetLog::CapturedEntryList net_log_entries;
    150   net_log_.GetEntries(&net_log_entries);
    151   EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
    152                                     NetLog::TYPE_SOCKS5_CONNECT));
    153 
    154   rv = callback_.WaitForResult();
    155 
    156   EXPECT_EQ(OK, rv);
    157   EXPECT_TRUE(user_sock_->IsConnected());
    158 
    159   net_log_.GetEntries(&net_log_entries);
    160   EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
    161                                   NetLog::TYPE_SOCKS5_CONNECT));
    162 
    163   scoped_refptr<IOBuffer> buffer(new IOBuffer(payload_write.size()));
    164   memcpy(buffer->data(), payload_write.data(), payload_write.size());
    165   rv = user_sock_->Write(
    166       buffer.get(), payload_write.size(), callback_.callback());
    167   EXPECT_EQ(ERR_IO_PENDING, rv);
    168   rv = callback_.WaitForResult();
    169   EXPECT_EQ(static_cast<int>(payload_write.size()), rv);
    170 
    171   buffer = new IOBuffer(payload_read.size());
    172   rv =
    173       user_sock_->Read(buffer.get(), payload_read.size(), callback_.callback());
    174   EXPECT_EQ(ERR_IO_PENDING, rv);
    175   rv = callback_.WaitForResult();
    176   EXPECT_EQ(static_cast<int>(payload_read.size()), rv);
    177   EXPECT_EQ(payload_read, std::string(buffer->data(), payload_read.size()));
    178 
    179   user_sock_->Disconnect();
    180   EXPECT_FALSE(tcp_sock_->IsConnected());
    181   EXPECT_FALSE(user_sock_->IsConnected());
    182 }
    183 
    184 // Test that you can call Connect() again after having called Disconnect().
    185 TEST_F(SOCKS5ClientSocketTest, ConnectAndDisconnectTwice) {
    186   const std::string hostname = "my-host-name";
    187   const char kSOCKS5DomainRequest[] = {
    188       0x05,  // VER
    189       0x01,  // CMD
    190       0x00,  // RSV
    191       0x03,  // ATYPE
    192   };
    193 
    194   std::string request(kSOCKS5DomainRequest, arraysize(kSOCKS5DomainRequest));
    195   request.push_back(hostname.size());
    196   request.append(hostname);
    197   request.append(reinterpret_cast<const char*>(&kNwPort), sizeof(kNwPort));
    198 
    199   for (int i = 0; i < 2; ++i) {
    200     MockWrite data_writes[] = {
    201         MockWrite(SYNCHRONOUS, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
    202         MockWrite(SYNCHRONOUS, request.data(), request.size())
    203     };
    204     MockRead data_reads[] = {
    205         MockRead(SYNCHRONOUS, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
    206         MockRead(SYNCHRONOUS, kSOCKS5OkResponse, kSOCKS5OkResponseLength)
    207     };
    208 
    209     user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
    210                                  data_writes, arraysize(data_writes),
    211                                  hostname, 80, NULL);
    212 
    213     int rv = user_sock_->Connect(callback_.callback());
    214     EXPECT_EQ(OK, rv);
    215     EXPECT_TRUE(user_sock_->IsConnected());
    216 
    217     user_sock_->Disconnect();
    218     EXPECT_FALSE(user_sock_->IsConnected());
    219   }
    220 }
    221 
    222 // Test that we fail trying to connect to a hosname longer than 255 bytes.
    223 TEST_F(SOCKS5ClientSocketTest, LargeHostNameFails) {
    224   // Create a string of length 256, where each character is 'x'.
    225   std::string large_host_name;
    226   std::fill_n(std::back_inserter(large_host_name), 256, 'x');
    227 
    228   // Create a SOCKS socket, with mock transport socket.
    229   MockWrite data_writes[] = {MockWrite()};
    230   MockRead data_reads[] = {MockRead()};
    231   user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
    232                                data_writes, arraysize(data_writes),
    233                                large_host_name, 80, NULL);
    234 
    235   // Try to connect -- should fail (without having read/written anything to
    236   // the transport socket first) because the hostname is too long.
    237   TestCompletionCallback callback;
    238   int rv = user_sock_->Connect(callback.callback());
    239   EXPECT_EQ(ERR_SOCKS_CONNECTION_FAILED, rv);
    240 }
    241 
    242 TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) {
    243   const std::string hostname = "www.google.com";
    244 
    245   const char kOkRequest[] = {
    246     0x05,  // Version
    247     0x01,  // Command (CONNECT)
    248     0x00,  // Reserved.
    249     0x03,  // Address type (DOMAINNAME).
    250     0x0E,  // Length of domain (14)
    251     // Domain string:
    252     'w', 'w', 'w', '.', 'g', 'o', 'o', 'g', 'l', 'e', '.', 'c', 'o', 'm',
    253     0x00, 0x50,  // 16-bit port (80)
    254   };
    255 
    256   // Test for partial greet request write
    257   {
    258     const char partial1[] = { 0x05, 0x01 };
    259     const char partial2[] = { 0x00 };
    260     MockWrite data_writes[] = {
    261         MockWrite(ASYNC, arraysize(partial1)),
    262         MockWrite(ASYNC, partial2, arraysize(partial2)),
    263         MockWrite(ASYNC, kOkRequest, arraysize(kOkRequest)) };
    264     MockRead data_reads[] = {
    265         MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
    266         MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) };
    267     user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
    268                                  data_writes, arraysize(data_writes),
    269                                  hostname, 80, &net_log_);
    270     int rv = user_sock_->Connect(callback_.callback());
    271     EXPECT_EQ(ERR_IO_PENDING, rv);
    272 
    273     CapturingNetLog::CapturedEntryList net_log_entries;
    274     net_log_.GetEntries(&net_log_entries);
    275     EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
    276                 NetLog::TYPE_SOCKS5_CONNECT));
    277 
    278     rv = callback_.WaitForResult();
    279     EXPECT_EQ(OK, rv);
    280     EXPECT_TRUE(user_sock_->IsConnected());
    281 
    282     net_log_.GetEntries(&net_log_entries);
    283     EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
    284                 NetLog::TYPE_SOCKS5_CONNECT));
    285   }
    286 
    287   // Test for partial greet response read
    288   {
    289     const char partial1[] = { 0x05 };
    290     const char partial2[] = { 0x00 };
    291     MockWrite data_writes[] = {
    292         MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
    293         MockWrite(ASYNC, kOkRequest, arraysize(kOkRequest)) };
    294     MockRead data_reads[] = {
    295         MockRead(ASYNC, partial1, arraysize(partial1)),
    296         MockRead(ASYNC, partial2, arraysize(partial2)),
    297         MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) };
    298     user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
    299                                  data_writes, arraysize(data_writes),
    300                                  hostname, 80, &net_log_);
    301     int rv = user_sock_->Connect(callback_.callback());
    302     EXPECT_EQ(ERR_IO_PENDING, rv);
    303 
    304     CapturingNetLog::CapturedEntryList net_log_entries;
    305     net_log_.GetEntries(&net_log_entries);
    306     EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
    307                                       NetLog::TYPE_SOCKS5_CONNECT));
    308     rv = callback_.WaitForResult();
    309     EXPECT_EQ(OK, rv);
    310     EXPECT_TRUE(user_sock_->IsConnected());
    311     net_log_.GetEntries(&net_log_entries);
    312     EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
    313                                     NetLog::TYPE_SOCKS5_CONNECT));
    314   }
    315 
    316   // Test for partial handshake request write.
    317   {
    318     const int kSplitPoint = 3;  // Break handshake write into two parts.
    319     MockWrite data_writes[] = {
    320         MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
    321         MockWrite(ASYNC, kOkRequest, kSplitPoint),
    322         MockWrite(ASYNC, kOkRequest + kSplitPoint,
    323                   arraysize(kOkRequest) - kSplitPoint)
    324     };
    325     MockRead data_reads[] = {
    326         MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
    327         MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) };
    328     user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
    329                                  data_writes, arraysize(data_writes),
    330                                  hostname, 80, &net_log_);
    331     int rv = user_sock_->Connect(callback_.callback());
    332     EXPECT_EQ(ERR_IO_PENDING, rv);
    333     CapturingNetLog::CapturedEntryList net_log_entries;
    334     net_log_.GetEntries(&net_log_entries);
    335     EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
    336                                       NetLog::TYPE_SOCKS5_CONNECT));
    337     rv = callback_.WaitForResult();
    338     EXPECT_EQ(OK, rv);
    339     EXPECT_TRUE(user_sock_->IsConnected());
    340     net_log_.GetEntries(&net_log_entries);
    341     EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
    342                                     NetLog::TYPE_SOCKS5_CONNECT));
    343   }
    344 
    345   // Test for partial handshake response read
    346   {
    347     const int kSplitPoint = 6;  // Break the handshake read into two parts.
    348     MockWrite data_writes[] = {
    349         MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
    350         MockWrite(ASYNC, kOkRequest, arraysize(kOkRequest))
    351     };
    352     MockRead data_reads[] = {
    353         MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
    354         MockRead(ASYNC, kSOCKS5OkResponse, kSplitPoint),
    355         MockRead(ASYNC, kSOCKS5OkResponse + kSplitPoint,
    356                  kSOCKS5OkResponseLength - kSplitPoint)
    357     };
    358 
    359     user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
    360                                  data_writes, arraysize(data_writes),
    361                                  hostname, 80, &net_log_);
    362     int rv = user_sock_->Connect(callback_.callback());
    363     EXPECT_EQ(ERR_IO_PENDING, rv);
    364     CapturingNetLog::CapturedEntryList net_log_entries;
    365     net_log_.GetEntries(&net_log_entries);
    366     EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
    367                                       NetLog::TYPE_SOCKS5_CONNECT));
    368     rv = callback_.WaitForResult();
    369     EXPECT_EQ(OK, rv);
    370     EXPECT_TRUE(user_sock_->IsConnected());
    371     net_log_.GetEntries(&net_log_entries);
    372     EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
    373                                     NetLog::TYPE_SOCKS5_CONNECT));
    374   }
    375 }
    376 
    377 }  // namespace
    378 
    379 }  // namespace net
    380