1 // Copyright (c) 2010 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/socket/socks5_client_socket.h" 6 7 #include <algorithm> 8 #include <map> 9 10 #include "net/base/address_list.h" 11 #include "net/base/net_log.h" 12 #include "net/base/net_log_unittest.h" 13 #include "net/base/mock_host_resolver.h" 14 #include "net/base/sys_addrinfo.h" 15 #include "net/base/test_completion_callback.h" 16 #include "net/base/winsock_init.h" 17 #include "net/socket/client_socket_factory.h" 18 #include "net/socket/socket_test_util.h" 19 #include "net/socket/tcp_client_socket.h" 20 #include "testing/gtest/include/gtest/gtest.h" 21 #include "testing/platform_test.h" 22 23 //----------------------------------------------------------------------------- 24 25 namespace net { 26 27 namespace { 28 29 // Base class to test SOCKS5ClientSocket 30 class SOCKS5ClientSocketTest : public PlatformTest { 31 public: 32 SOCKS5ClientSocketTest(); 33 // Create a SOCKSClientSocket on top of a MockSocket. 34 SOCKS5ClientSocket* BuildMockSocket(MockRead reads[], 35 size_t reads_count, 36 MockWrite writes[], 37 size_t writes_count, 38 const std::string& hostname, 39 int port, 40 NetLog* net_log); 41 42 virtual void SetUp(); 43 44 protected: 45 const uint16 kNwPort; 46 CapturingNetLog net_log_; 47 scoped_ptr<SOCKS5ClientSocket> user_sock_; 48 AddressList address_list_; 49 ClientSocket* tcp_sock_; 50 TestCompletionCallback callback_; 51 scoped_ptr<MockHostResolver> host_resolver_; 52 scoped_ptr<SocketDataProvider> data_; 53 54 private: 55 DISALLOW_COPY_AND_ASSIGN(SOCKS5ClientSocketTest); 56 }; 57 58 SOCKS5ClientSocketTest::SOCKS5ClientSocketTest() 59 : kNwPort(htons(80)), 60 net_log_(CapturingNetLog::kUnbounded), 61 host_resolver_(new MockHostResolver) { 62 } 63 64 // Set up platform before every test case 65 void SOCKS5ClientSocketTest::SetUp() { 66 PlatformTest::SetUp(); 67 68 // Resolve the "localhost" AddressList used by the TCP connection to connect. 69 HostResolver::RequestInfo info(HostPortPair("www.socks-proxy.com", 1080)); 70 int rv = host_resolver_->Resolve(info, &address_list_, NULL, NULL, 71 BoundNetLog()); 72 ASSERT_EQ(OK, rv); 73 } 74 75 SOCKS5ClientSocket* SOCKS5ClientSocketTest::BuildMockSocket( 76 MockRead reads[], 77 size_t reads_count, 78 MockWrite writes[], 79 size_t writes_count, 80 const std::string& hostname, 81 int port, 82 NetLog* net_log) { 83 TestCompletionCallback callback; 84 data_.reset(new StaticSocketDataProvider(reads, reads_count, 85 writes, writes_count)); 86 tcp_sock_ = new MockTCPClientSocket(address_list_, net_log, data_.get()); 87 88 int rv = tcp_sock_->Connect(&callback); 89 EXPECT_EQ(ERR_IO_PENDING, rv); 90 rv = callback.WaitForResult(); 91 EXPECT_EQ(OK, rv); 92 EXPECT_TRUE(tcp_sock_->IsConnected()); 93 94 return new SOCKS5ClientSocket(tcp_sock_, 95 HostResolver::RequestInfo(HostPortPair(hostname, port))); 96 } 97 98 // Tests a complete SOCKS5 handshake and the disconnection. 99 TEST_F(SOCKS5ClientSocketTest, CompleteHandshake) { 100 const std::string payload_write = "random data"; 101 const std::string payload_read = "moar random data"; 102 103 const char kOkRequest[] = { 104 0x05, // Version 105 0x01, // Command (CONNECT) 106 0x00, // Reserved. 107 0x03, // Address type (DOMAINNAME). 108 0x09, // Length of domain (9) 109 // Domain string: 110 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't', 111 0x00, 0x50, // 16-bit port (80) 112 }; 113 114 MockWrite data_writes[] = { 115 MockWrite(true, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength), 116 MockWrite(true, kOkRequest, arraysize(kOkRequest)), 117 MockWrite(true, payload_write.data(), payload_write.size()) }; 118 MockRead data_reads[] = { 119 MockRead(true, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength), 120 MockRead(true, kSOCKS5OkResponse, kSOCKS5OkResponseLength), 121 MockRead(true, payload_read.data(), payload_read.size()) }; 122 123 user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), 124 data_writes, arraysize(data_writes), 125 "localhost", 80, &net_log_)); 126 127 // At this state the TCP connection is completed but not the SOCKS handshake. 128 EXPECT_TRUE(tcp_sock_->IsConnected()); 129 EXPECT_FALSE(user_sock_->IsConnected()); 130 131 int rv = user_sock_->Connect(&callback_); 132 EXPECT_EQ(ERR_IO_PENDING, rv); 133 EXPECT_FALSE(user_sock_->IsConnected()); 134 135 net::CapturingNetLog::EntryList net_log_entries; 136 net_log_.GetEntries(&net_log_entries); 137 EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0, 138 NetLog::TYPE_SOCKS5_CONNECT)); 139 140 rv = callback_.WaitForResult(); 141 142 EXPECT_EQ(OK, rv); 143 EXPECT_TRUE(user_sock_->IsConnected()); 144 145 net_log_.GetEntries(&net_log_entries); 146 EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1, 147 NetLog::TYPE_SOCKS5_CONNECT)); 148 149 scoped_refptr<IOBuffer> buffer(new IOBuffer(payload_write.size())); 150 memcpy(buffer->data(), payload_write.data(), payload_write.size()); 151 rv = user_sock_->Write(buffer, payload_write.size(), &callback_); 152 EXPECT_EQ(ERR_IO_PENDING, rv); 153 rv = callback_.WaitForResult(); 154 EXPECT_EQ(static_cast<int>(payload_write.size()), rv); 155 156 buffer = new IOBuffer(payload_read.size()); 157 rv = user_sock_->Read(buffer, payload_read.size(), &callback_); 158 EXPECT_EQ(ERR_IO_PENDING, rv); 159 rv = callback_.WaitForResult(); 160 EXPECT_EQ(static_cast<int>(payload_read.size()), rv); 161 EXPECT_EQ(payload_read, std::string(buffer->data(), payload_read.size())); 162 163 user_sock_->Disconnect(); 164 EXPECT_FALSE(tcp_sock_->IsConnected()); 165 EXPECT_FALSE(user_sock_->IsConnected()); 166 } 167 168 // Test that you can call Connect() again after having called Disconnect(). 169 TEST_F(SOCKS5ClientSocketTest, ConnectAndDisconnectTwice) { 170 const std::string hostname = "my-host-name"; 171 const char kSOCKS5DomainRequest[] = { 172 0x05, // VER 173 0x01, // CMD 174 0x00, // RSV 175 0x03, // ATYPE 176 }; 177 178 std::string request(kSOCKS5DomainRequest, arraysize(kSOCKS5DomainRequest)); 179 request.push_back(hostname.size()); 180 request.append(hostname); 181 request.append(reinterpret_cast<const char*>(&kNwPort), sizeof(kNwPort)); 182 183 for (int i = 0; i < 2; ++i) { 184 MockWrite data_writes[] = { 185 MockWrite(false, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength), 186 MockWrite(false, request.data(), request.size()) 187 }; 188 MockRead data_reads[] = { 189 MockRead(false, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength), 190 MockRead(false, kSOCKS5OkResponse, kSOCKS5OkResponseLength) 191 }; 192 193 user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), 194 data_writes, arraysize(data_writes), 195 hostname, 80, NULL)); 196 197 int rv = user_sock_->Connect(&callback_); 198 EXPECT_EQ(OK, rv); 199 EXPECT_TRUE(user_sock_->IsConnected()); 200 201 user_sock_->Disconnect(); 202 EXPECT_FALSE(user_sock_->IsConnected()); 203 } 204 } 205 206 // Test that we fail trying to connect to a hosname longer than 255 bytes. 207 TEST_F(SOCKS5ClientSocketTest, LargeHostNameFails) { 208 // Create a string of length 256, where each character is 'x'. 209 std::string large_host_name; 210 std::fill_n(std::back_inserter(large_host_name), 256, 'x'); 211 212 // Create a SOCKS socket, with mock transport socket. 213 MockWrite data_writes[] = {MockWrite()}; 214 MockRead data_reads[] = {MockRead()}; 215 user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), 216 data_writes, arraysize(data_writes), 217 large_host_name, 80, NULL)); 218 219 // Try to connect -- should fail (without having read/written anything to 220 // the transport socket first) because the hostname is too long. 221 TestCompletionCallback callback; 222 int rv = user_sock_->Connect(&callback); 223 EXPECT_EQ(ERR_SOCKS_CONNECTION_FAILED, rv); 224 } 225 226 TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) { 227 const std::string hostname = "www.google.com"; 228 229 const char kOkRequest[] = { 230 0x05, // Version 231 0x01, // Command (CONNECT) 232 0x00, // Reserved. 233 0x03, // Address type (DOMAINNAME). 234 0x0E, // Length of domain (14) 235 // Domain string: 236 'w', 'w', 'w', '.', 'g', 'o', 'o', 'g', 'l', 'e', '.', 'c', 'o', 'm', 237 0x00, 0x50, // 16-bit port (80) 238 }; 239 240 // Test for partial greet request write 241 { 242 const char partial1[] = { 0x05, 0x01 }; 243 const char partial2[] = { 0x00 }; 244 MockWrite data_writes[] = { 245 MockWrite(true, arraysize(partial1)), 246 MockWrite(true, partial2, arraysize(partial2)), 247 MockWrite(true, kOkRequest, arraysize(kOkRequest)) }; 248 MockRead data_reads[] = { 249 MockRead(true, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength), 250 MockRead(true, kSOCKS5OkResponse, kSOCKS5OkResponseLength) }; 251 user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), 252 data_writes, arraysize(data_writes), 253 hostname, 80, &net_log_)); 254 int rv = user_sock_->Connect(&callback_); 255 EXPECT_EQ(ERR_IO_PENDING, rv); 256 257 net::CapturingNetLog::EntryList net_log_entries; 258 net_log_.GetEntries(&net_log_entries); 259 EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0, 260 NetLog::TYPE_SOCKS5_CONNECT)); 261 262 rv = callback_.WaitForResult(); 263 EXPECT_EQ(OK, rv); 264 EXPECT_TRUE(user_sock_->IsConnected()); 265 266 net_log_.GetEntries(&net_log_entries); 267 EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1, 268 NetLog::TYPE_SOCKS5_CONNECT)); 269 } 270 271 // Test for partial greet response read 272 { 273 const char partial1[] = { 0x05 }; 274 const char partial2[] = { 0x00 }; 275 MockWrite data_writes[] = { 276 MockWrite(true, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength), 277 MockWrite(true, kOkRequest, arraysize(kOkRequest)) }; 278 MockRead data_reads[] = { 279 MockRead(true, partial1, arraysize(partial1)), 280 MockRead(true, partial2, arraysize(partial2)), 281 MockRead(true, kSOCKS5OkResponse, kSOCKS5OkResponseLength) }; 282 user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), 283 data_writes, arraysize(data_writes), 284 hostname, 80, &net_log_)); 285 int rv = user_sock_->Connect(&callback_); 286 EXPECT_EQ(ERR_IO_PENDING, rv); 287 288 net::CapturingNetLog::EntryList net_log_entries; 289 net_log_.GetEntries(&net_log_entries); 290 EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0, 291 NetLog::TYPE_SOCKS5_CONNECT)); 292 rv = callback_.WaitForResult(); 293 EXPECT_EQ(OK, rv); 294 EXPECT_TRUE(user_sock_->IsConnected()); 295 net_log_.GetEntries(&net_log_entries); 296 EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1, 297 NetLog::TYPE_SOCKS5_CONNECT)); 298 } 299 300 // Test for partial handshake request write. 301 { 302 const int kSplitPoint = 3; // Break handshake write into two parts. 303 MockWrite data_writes[] = { 304 MockWrite(true, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength), 305 MockWrite(true, kOkRequest, kSplitPoint), 306 MockWrite(true, kOkRequest + kSplitPoint, 307 arraysize(kOkRequest) - kSplitPoint) 308 }; 309 MockRead data_reads[] = { 310 MockRead(true, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength), 311 MockRead(true, kSOCKS5OkResponse, kSOCKS5OkResponseLength) }; 312 user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), 313 data_writes, arraysize(data_writes), 314 hostname, 80, &net_log_)); 315 int rv = user_sock_->Connect(&callback_); 316 EXPECT_EQ(ERR_IO_PENDING, rv); 317 net::CapturingNetLog::EntryList net_log_entries; 318 net_log_.GetEntries(&net_log_entries); 319 EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0, 320 NetLog::TYPE_SOCKS5_CONNECT)); 321 rv = callback_.WaitForResult(); 322 EXPECT_EQ(OK, rv); 323 EXPECT_TRUE(user_sock_->IsConnected()); 324 net_log_.GetEntries(&net_log_entries); 325 EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1, 326 NetLog::TYPE_SOCKS5_CONNECT)); 327 } 328 329 // Test for partial handshake response read 330 { 331 const int kSplitPoint = 6; // Break the handshake read into two parts. 332 MockWrite data_writes[] = { 333 MockWrite(true, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength), 334 MockWrite(true, kOkRequest, arraysize(kOkRequest)) 335 }; 336 MockRead data_reads[] = { 337 MockRead(true, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength), 338 MockRead(true, kSOCKS5OkResponse, kSplitPoint), 339 MockRead(true, kSOCKS5OkResponse + kSplitPoint, 340 kSOCKS5OkResponseLength - kSplitPoint) 341 }; 342 343 user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), 344 data_writes, arraysize(data_writes), 345 hostname, 80, &net_log_)); 346 int rv = user_sock_->Connect(&callback_); 347 EXPECT_EQ(ERR_IO_PENDING, rv); 348 net::CapturingNetLog::EntryList net_log_entries; 349 net_log_.GetEntries(&net_log_entries); 350 EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0, 351 NetLog::TYPE_SOCKS5_CONNECT)); 352 rv = callback_.WaitForResult(); 353 EXPECT_EQ(OK, rv); 354 EXPECT_TRUE(user_sock_->IsConnected()); 355 net_log_.GetEntries(&net_log_entries); 356 EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1, 357 NetLog::TYPE_SOCKS5_CONNECT)); 358 } 359 } 360 361 } // namespace 362 363 } // namespace net 364