Home | History | Annotate | Download | only in socket
      1 // Copyright (c) 2010 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/socket/socks5_client_socket.h"
      6 
      7 #include <algorithm>
      8 #include <map>
      9 
     10 #include "net/base/address_list.h"
     11 #include "net/base/net_log.h"
     12 #include "net/base/net_log_unittest.h"
     13 #include "net/base/mock_host_resolver.h"
     14 #include "net/base/sys_addrinfo.h"
     15 #include "net/base/test_completion_callback.h"
     16 #include "net/base/winsock_init.h"
     17 #include "net/socket/client_socket_factory.h"
     18 #include "net/socket/socket_test_util.h"
     19 #include "net/socket/tcp_client_socket.h"
     20 #include "testing/gtest/include/gtest/gtest.h"
     21 #include "testing/platform_test.h"
     22 
     23 //-----------------------------------------------------------------------------
     24 
     25 namespace net {
     26 
     27 namespace {
     28 
     29 // Base class to test SOCKS5ClientSocket
     30 class SOCKS5ClientSocketTest : public PlatformTest {
     31  public:
     32   SOCKS5ClientSocketTest();
     33   // Create a SOCKSClientSocket on top of a MockSocket.
     34   SOCKS5ClientSocket* BuildMockSocket(MockRead reads[],
     35                                       size_t reads_count,
     36                                       MockWrite writes[],
     37                                       size_t writes_count,
     38                                       const std::string& hostname,
     39                                       int port,
     40                                       NetLog* net_log);
     41 
     42   virtual void SetUp();
     43 
     44  protected:
     45   const uint16 kNwPort;
     46   CapturingNetLog net_log_;
     47   scoped_ptr<SOCKS5ClientSocket> user_sock_;
     48   AddressList address_list_;
     49   ClientSocket* tcp_sock_;
     50   TestCompletionCallback callback_;
     51   scoped_ptr<MockHostResolver> host_resolver_;
     52   scoped_ptr<SocketDataProvider> data_;
     53 
     54  private:
     55   DISALLOW_COPY_AND_ASSIGN(SOCKS5ClientSocketTest);
     56 };
     57 
     58 SOCKS5ClientSocketTest::SOCKS5ClientSocketTest()
     59   : kNwPort(htons(80)),
     60     net_log_(CapturingNetLog::kUnbounded),
     61     host_resolver_(new MockHostResolver) {
     62 }
     63 
     64 // Set up platform before every test case
     65 void SOCKS5ClientSocketTest::SetUp() {
     66   PlatformTest::SetUp();
     67 
     68   // Resolve the "localhost" AddressList used by the TCP connection to connect.
     69   HostResolver::RequestInfo info(HostPortPair("www.socks-proxy.com", 1080));
     70   int rv = host_resolver_->Resolve(info, &address_list_, NULL, NULL,
     71                                    BoundNetLog());
     72   ASSERT_EQ(OK, rv);
     73 }
     74 
     75 SOCKS5ClientSocket* SOCKS5ClientSocketTest::BuildMockSocket(
     76     MockRead reads[],
     77     size_t reads_count,
     78     MockWrite writes[],
     79     size_t writes_count,
     80     const std::string& hostname,
     81     int port,
     82     NetLog* net_log) {
     83   TestCompletionCallback callback;
     84   data_.reset(new StaticSocketDataProvider(reads, reads_count,
     85                                            writes, writes_count));
     86   tcp_sock_ = new MockTCPClientSocket(address_list_, net_log, data_.get());
     87 
     88   int rv = tcp_sock_->Connect(&callback);
     89   EXPECT_EQ(ERR_IO_PENDING, rv);
     90   rv = callback.WaitForResult();
     91   EXPECT_EQ(OK, rv);
     92   EXPECT_TRUE(tcp_sock_->IsConnected());
     93 
     94   return new SOCKS5ClientSocket(tcp_sock_,
     95       HostResolver::RequestInfo(HostPortPair(hostname, port)));
     96 }
     97 
     98 // Tests a complete SOCKS5 handshake and the disconnection.
     99 TEST_F(SOCKS5ClientSocketTest, CompleteHandshake) {
    100   const std::string payload_write = "random data";
    101   const std::string payload_read = "moar random data";
    102 
    103   const char kOkRequest[] = {
    104     0x05,  // Version
    105     0x01,  // Command (CONNECT)
    106     0x00,  // Reserved.
    107     0x03,  // Address type (DOMAINNAME).
    108     0x09,  // Length of domain (9)
    109     // Domain string:
    110     'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't',
    111     0x00, 0x50,  // 16-bit port (80)
    112   };
    113 
    114   MockWrite data_writes[] = {
    115       MockWrite(true, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
    116       MockWrite(true, kOkRequest, arraysize(kOkRequest)),
    117       MockWrite(true, payload_write.data(), payload_write.size()) };
    118   MockRead data_reads[] = {
    119       MockRead(true, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
    120       MockRead(true, kSOCKS5OkResponse, kSOCKS5OkResponseLength),
    121       MockRead(true, payload_read.data(), payload_read.size()) };
    122 
    123   user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
    124                                    data_writes, arraysize(data_writes),
    125                                    "localhost", 80, &net_log_));
    126 
    127   // At this state the TCP connection is completed but not the SOCKS handshake.
    128   EXPECT_TRUE(tcp_sock_->IsConnected());
    129   EXPECT_FALSE(user_sock_->IsConnected());
    130 
    131   int rv = user_sock_->Connect(&callback_);
    132   EXPECT_EQ(ERR_IO_PENDING, rv);
    133   EXPECT_FALSE(user_sock_->IsConnected());
    134 
    135   net::CapturingNetLog::EntryList net_log_entries;
    136   net_log_.GetEntries(&net_log_entries);
    137   EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
    138                                     NetLog::TYPE_SOCKS5_CONNECT));
    139 
    140   rv = callback_.WaitForResult();
    141 
    142   EXPECT_EQ(OK, rv);
    143   EXPECT_TRUE(user_sock_->IsConnected());
    144 
    145   net_log_.GetEntries(&net_log_entries);
    146   EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
    147                                   NetLog::TYPE_SOCKS5_CONNECT));
    148 
    149   scoped_refptr<IOBuffer> buffer(new IOBuffer(payload_write.size()));
    150   memcpy(buffer->data(), payload_write.data(), payload_write.size());
    151   rv = user_sock_->Write(buffer, payload_write.size(), &callback_);
    152   EXPECT_EQ(ERR_IO_PENDING, rv);
    153   rv = callback_.WaitForResult();
    154   EXPECT_EQ(static_cast<int>(payload_write.size()), rv);
    155 
    156   buffer = new IOBuffer(payload_read.size());
    157   rv = user_sock_->Read(buffer, payload_read.size(), &callback_);
    158   EXPECT_EQ(ERR_IO_PENDING, rv);
    159   rv = callback_.WaitForResult();
    160   EXPECT_EQ(static_cast<int>(payload_read.size()), rv);
    161   EXPECT_EQ(payload_read, std::string(buffer->data(), payload_read.size()));
    162 
    163   user_sock_->Disconnect();
    164   EXPECT_FALSE(tcp_sock_->IsConnected());
    165   EXPECT_FALSE(user_sock_->IsConnected());
    166 }
    167 
    168 // Test that you can call Connect() again after having called Disconnect().
    169 TEST_F(SOCKS5ClientSocketTest, ConnectAndDisconnectTwice) {
    170   const std::string hostname = "my-host-name";
    171   const char kSOCKS5DomainRequest[] = {
    172       0x05,  // VER
    173       0x01,  // CMD
    174       0x00,  // RSV
    175       0x03,  // ATYPE
    176   };
    177 
    178   std::string request(kSOCKS5DomainRequest, arraysize(kSOCKS5DomainRequest));
    179   request.push_back(hostname.size());
    180   request.append(hostname);
    181   request.append(reinterpret_cast<const char*>(&kNwPort), sizeof(kNwPort));
    182 
    183   for (int i = 0; i < 2; ++i) {
    184     MockWrite data_writes[] = {
    185         MockWrite(false, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
    186         MockWrite(false, request.data(), request.size())
    187     };
    188     MockRead data_reads[] = {
    189         MockRead(false, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
    190         MockRead(false, kSOCKS5OkResponse, kSOCKS5OkResponseLength)
    191     };
    192 
    193     user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
    194                                      data_writes, arraysize(data_writes),
    195                                      hostname, 80, NULL));
    196 
    197     int rv = user_sock_->Connect(&callback_);
    198     EXPECT_EQ(OK, rv);
    199     EXPECT_TRUE(user_sock_->IsConnected());
    200 
    201     user_sock_->Disconnect();
    202     EXPECT_FALSE(user_sock_->IsConnected());
    203   }
    204 }
    205 
    206 // Test that we fail trying to connect to a hosname longer than 255 bytes.
    207 TEST_F(SOCKS5ClientSocketTest, LargeHostNameFails) {
    208   // Create a string of length 256, where each character is 'x'.
    209   std::string large_host_name;
    210   std::fill_n(std::back_inserter(large_host_name), 256, 'x');
    211 
    212   // Create a SOCKS socket, with mock transport socket.
    213   MockWrite data_writes[] = {MockWrite()};
    214   MockRead data_reads[] = {MockRead()};
    215   user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
    216                                    data_writes, arraysize(data_writes),
    217                                    large_host_name, 80, NULL));
    218 
    219   // Try to connect -- should fail (without having read/written anything to
    220   // the transport socket first) because the hostname is too long.
    221   TestCompletionCallback callback;
    222   int rv = user_sock_->Connect(&callback);
    223   EXPECT_EQ(ERR_SOCKS_CONNECTION_FAILED, rv);
    224 }
    225 
    226 TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) {
    227   const std::string hostname = "www.google.com";
    228 
    229   const char kOkRequest[] = {
    230     0x05,  // Version
    231     0x01,  // Command (CONNECT)
    232     0x00,  // Reserved.
    233     0x03,  // Address type (DOMAINNAME).
    234     0x0E,  // Length of domain (14)
    235     // Domain string:
    236     'w', 'w', 'w', '.', 'g', 'o', 'o', 'g', 'l', 'e', '.', 'c', 'o', 'm',
    237     0x00, 0x50,  // 16-bit port (80)
    238   };
    239 
    240   // Test for partial greet request write
    241   {
    242     const char partial1[] = { 0x05, 0x01 };
    243     const char partial2[] = { 0x00 };
    244     MockWrite data_writes[] = {
    245         MockWrite(true, arraysize(partial1)),
    246         MockWrite(true, partial2, arraysize(partial2)),
    247         MockWrite(true, kOkRequest, arraysize(kOkRequest)) };
    248     MockRead data_reads[] = {
    249         MockRead(true, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
    250         MockRead(true, kSOCKS5OkResponse, kSOCKS5OkResponseLength) };
    251     user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
    252                                      data_writes, arraysize(data_writes),
    253                                      hostname, 80, &net_log_));
    254     int rv = user_sock_->Connect(&callback_);
    255     EXPECT_EQ(ERR_IO_PENDING, rv);
    256 
    257     net::CapturingNetLog::EntryList net_log_entries;
    258     net_log_.GetEntries(&net_log_entries);
    259     EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
    260                 NetLog::TYPE_SOCKS5_CONNECT));
    261 
    262     rv = callback_.WaitForResult();
    263     EXPECT_EQ(OK, rv);
    264     EXPECT_TRUE(user_sock_->IsConnected());
    265 
    266     net_log_.GetEntries(&net_log_entries);
    267     EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
    268                 NetLog::TYPE_SOCKS5_CONNECT));
    269   }
    270 
    271   // Test for partial greet response read
    272   {
    273     const char partial1[] = { 0x05 };
    274     const char partial2[] = { 0x00 };
    275     MockWrite data_writes[] = {
    276         MockWrite(true, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
    277         MockWrite(true, kOkRequest, arraysize(kOkRequest)) };
    278     MockRead data_reads[] = {
    279         MockRead(true, partial1, arraysize(partial1)),
    280         MockRead(true, partial2, arraysize(partial2)),
    281         MockRead(true, kSOCKS5OkResponse, kSOCKS5OkResponseLength) };
    282     user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
    283                                      data_writes, arraysize(data_writes),
    284                                      hostname, 80, &net_log_));
    285     int rv = user_sock_->Connect(&callback_);
    286     EXPECT_EQ(ERR_IO_PENDING, rv);
    287 
    288     net::CapturingNetLog::EntryList net_log_entries;
    289     net_log_.GetEntries(&net_log_entries);
    290     EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
    291                                       NetLog::TYPE_SOCKS5_CONNECT));
    292     rv = callback_.WaitForResult();
    293     EXPECT_EQ(OK, rv);
    294     EXPECT_TRUE(user_sock_->IsConnected());
    295     net_log_.GetEntries(&net_log_entries);
    296     EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
    297                                     NetLog::TYPE_SOCKS5_CONNECT));
    298   }
    299 
    300   // Test for partial handshake request write.
    301   {
    302     const int kSplitPoint = 3;  // Break handshake write into two parts.
    303     MockWrite data_writes[] = {
    304         MockWrite(true, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
    305         MockWrite(true, kOkRequest, kSplitPoint),
    306         MockWrite(true, kOkRequest + kSplitPoint,
    307                   arraysize(kOkRequest) - kSplitPoint)
    308     };
    309     MockRead data_reads[] = {
    310         MockRead(true, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
    311         MockRead(true, kSOCKS5OkResponse, kSOCKS5OkResponseLength) };
    312     user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
    313                                      data_writes, arraysize(data_writes),
    314                                      hostname, 80, &net_log_));
    315     int rv = user_sock_->Connect(&callback_);
    316     EXPECT_EQ(ERR_IO_PENDING, rv);
    317     net::CapturingNetLog::EntryList net_log_entries;
    318     net_log_.GetEntries(&net_log_entries);
    319     EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
    320                                       NetLog::TYPE_SOCKS5_CONNECT));
    321     rv = callback_.WaitForResult();
    322     EXPECT_EQ(OK, rv);
    323     EXPECT_TRUE(user_sock_->IsConnected());
    324     net_log_.GetEntries(&net_log_entries);
    325     EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
    326                                     NetLog::TYPE_SOCKS5_CONNECT));
    327   }
    328 
    329   // Test for partial handshake response read
    330   {
    331     const int kSplitPoint = 6;  // Break the handshake read into two parts.
    332     MockWrite data_writes[] = {
    333         MockWrite(true, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
    334         MockWrite(true, kOkRequest, arraysize(kOkRequest))
    335     };
    336     MockRead data_reads[] = {
    337         MockRead(true, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
    338         MockRead(true, kSOCKS5OkResponse, kSplitPoint),
    339         MockRead(true, kSOCKS5OkResponse + kSplitPoint,
    340                  kSOCKS5OkResponseLength - kSplitPoint)
    341     };
    342 
    343     user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
    344                                      data_writes, arraysize(data_writes),
    345                                      hostname, 80, &net_log_));
    346     int rv = user_sock_->Connect(&callback_);
    347     EXPECT_EQ(ERR_IO_PENDING, rv);
    348     net::CapturingNetLog::EntryList net_log_entries;
    349     net_log_.GetEntries(&net_log_entries);
    350     EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
    351                                       NetLog::TYPE_SOCKS5_CONNECT));
    352     rv = callback_.WaitForResult();
    353     EXPECT_EQ(OK, rv);
    354     EXPECT_TRUE(user_sock_->IsConnected());
    355     net_log_.GetEntries(&net_log_entries);
    356     EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
    357                                     NetLog::TYPE_SOCKS5_CONNECT));
    358   }
    359 }
    360 
    361 }  // namespace
    362 
    363 }  // namespace net
    364