1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/socket/socks5_client_socket.h" 6 7 #include <algorithm> 8 #include <iterator> 9 #include <map> 10 11 #include "base/sys_byteorder.h" 12 #include "net/base/address_list.h" 13 #include "net/base/net_log.h" 14 #include "net/base/net_log_unittest.h" 15 #include "net/base/test_completion_callback.h" 16 #include "net/base/winsock_init.h" 17 #include "net/dns/mock_host_resolver.h" 18 #include "net/socket/client_socket_factory.h" 19 #include "net/socket/socket_test_util.h" 20 #include "net/socket/tcp_client_socket.h" 21 #include "testing/gtest/include/gtest/gtest.h" 22 #include "testing/platform_test.h" 23 24 //----------------------------------------------------------------------------- 25 26 namespace net { 27 28 namespace { 29 30 // Base class to test SOCKS5ClientSocket 31 class SOCKS5ClientSocketTest : public PlatformTest { 32 public: 33 SOCKS5ClientSocketTest(); 34 // Create a SOCKSClientSocket on top of a MockSocket. 35 scoped_ptr<SOCKS5ClientSocket> BuildMockSocket(MockRead reads[], 36 size_t reads_count, 37 MockWrite writes[], 38 size_t writes_count, 39 const std::string& hostname, 40 int port, 41 NetLog* net_log); 42 43 virtual void SetUp(); 44 45 protected: 46 const uint16 kNwPort; 47 CapturingNetLog net_log_; 48 scoped_ptr<SOCKS5ClientSocket> user_sock_; 49 AddressList address_list_; 50 // Filled in by BuildMockSocket() and owned by its return value 51 // (which |user_sock| is set to). 52 StreamSocket* tcp_sock_; 53 TestCompletionCallback callback_; 54 scoped_ptr<MockHostResolver> host_resolver_; 55 scoped_ptr<SocketDataProvider> data_; 56 57 private: 58 DISALLOW_COPY_AND_ASSIGN(SOCKS5ClientSocketTest); 59 }; 60 61 SOCKS5ClientSocketTest::SOCKS5ClientSocketTest() 62 : kNwPort(base::HostToNet16(80)), 63 host_resolver_(new MockHostResolver) { 64 } 65 66 // Set up platform before every test case 67 void SOCKS5ClientSocketTest::SetUp() { 68 PlatformTest::SetUp(); 69 70 // Resolve the "localhost" AddressList used by the TCP connection to connect. 71 HostResolver::RequestInfo info(HostPortPair("www.socks-proxy.com", 1080)); 72 TestCompletionCallback callback; 73 int rv = host_resolver_->Resolve(info, 74 DEFAULT_PRIORITY, 75 &address_list_, 76 callback.callback(), 77 NULL, 78 BoundNetLog()); 79 ASSERT_EQ(ERR_IO_PENDING, rv); 80 rv = callback.WaitForResult(); 81 ASSERT_EQ(OK, rv); 82 } 83 84 scoped_ptr<SOCKS5ClientSocket> SOCKS5ClientSocketTest::BuildMockSocket( 85 MockRead reads[], 86 size_t reads_count, 87 MockWrite writes[], 88 size_t writes_count, 89 const std::string& hostname, 90 int port, 91 NetLog* net_log) { 92 TestCompletionCallback callback; 93 data_.reset(new StaticSocketDataProvider(reads, reads_count, 94 writes, writes_count)); 95 tcp_sock_ = new MockTCPClientSocket(address_list_, net_log, data_.get()); 96 97 int rv = tcp_sock_->Connect(callback.callback()); 98 EXPECT_EQ(ERR_IO_PENDING, rv); 99 rv = callback.WaitForResult(); 100 EXPECT_EQ(OK, rv); 101 EXPECT_TRUE(tcp_sock_->IsConnected()); 102 103 scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle); 104 // |connection| takes ownership of |tcp_sock_|, but keep a 105 // non-owning pointer to it. 106 connection->SetSocket(scoped_ptr<StreamSocket>(tcp_sock_)); 107 return scoped_ptr<SOCKS5ClientSocket>(new SOCKS5ClientSocket( 108 connection.Pass(), 109 HostResolver::RequestInfo(HostPortPair(hostname, port)))); 110 } 111 112 // Tests a complete SOCKS5 handshake and the disconnection. 113 TEST_F(SOCKS5ClientSocketTest, CompleteHandshake) { 114 const std::string payload_write = "random data"; 115 const std::string payload_read = "moar random data"; 116 117 const char kOkRequest[] = { 118 0x05, // Version 119 0x01, // Command (CONNECT) 120 0x00, // Reserved. 121 0x03, // Address type (DOMAINNAME). 122 0x09, // Length of domain (9) 123 // Domain string: 124 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't', 125 0x00, 0x50, // 16-bit port (80) 126 }; 127 128 MockWrite data_writes[] = { 129 MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength), 130 MockWrite(ASYNC, kOkRequest, arraysize(kOkRequest)), 131 MockWrite(ASYNC, payload_write.data(), payload_write.size()) }; 132 MockRead data_reads[] = { 133 MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength), 134 MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength), 135 MockRead(ASYNC, payload_read.data(), payload_read.size()) }; 136 137 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), 138 data_writes, arraysize(data_writes), 139 "localhost", 80, &net_log_); 140 141 // At this state the TCP connection is completed but not the SOCKS handshake. 142 EXPECT_TRUE(tcp_sock_->IsConnected()); 143 EXPECT_FALSE(user_sock_->IsConnected()); 144 145 int rv = user_sock_->Connect(callback_.callback()); 146 EXPECT_EQ(ERR_IO_PENDING, rv); 147 EXPECT_FALSE(user_sock_->IsConnected()); 148 149 CapturingNetLog::CapturedEntryList net_log_entries; 150 net_log_.GetEntries(&net_log_entries); 151 EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0, 152 NetLog::TYPE_SOCKS5_CONNECT)); 153 154 rv = callback_.WaitForResult(); 155 156 EXPECT_EQ(OK, rv); 157 EXPECT_TRUE(user_sock_->IsConnected()); 158 159 net_log_.GetEntries(&net_log_entries); 160 EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1, 161 NetLog::TYPE_SOCKS5_CONNECT)); 162 163 scoped_refptr<IOBuffer> buffer(new IOBuffer(payload_write.size())); 164 memcpy(buffer->data(), payload_write.data(), payload_write.size()); 165 rv = user_sock_->Write( 166 buffer.get(), payload_write.size(), callback_.callback()); 167 EXPECT_EQ(ERR_IO_PENDING, rv); 168 rv = callback_.WaitForResult(); 169 EXPECT_EQ(static_cast<int>(payload_write.size()), rv); 170 171 buffer = new IOBuffer(payload_read.size()); 172 rv = 173 user_sock_->Read(buffer.get(), payload_read.size(), callback_.callback()); 174 EXPECT_EQ(ERR_IO_PENDING, rv); 175 rv = callback_.WaitForResult(); 176 EXPECT_EQ(static_cast<int>(payload_read.size()), rv); 177 EXPECT_EQ(payload_read, std::string(buffer->data(), payload_read.size())); 178 179 user_sock_->Disconnect(); 180 EXPECT_FALSE(tcp_sock_->IsConnected()); 181 EXPECT_FALSE(user_sock_->IsConnected()); 182 } 183 184 // Test that you can call Connect() again after having called Disconnect(). 185 TEST_F(SOCKS5ClientSocketTest, ConnectAndDisconnectTwice) { 186 const std::string hostname = "my-host-name"; 187 const char kSOCKS5DomainRequest[] = { 188 0x05, // VER 189 0x01, // CMD 190 0x00, // RSV 191 0x03, // ATYPE 192 }; 193 194 std::string request(kSOCKS5DomainRequest, arraysize(kSOCKS5DomainRequest)); 195 request.push_back(hostname.size()); 196 request.append(hostname); 197 request.append(reinterpret_cast<const char*>(&kNwPort), sizeof(kNwPort)); 198 199 for (int i = 0; i < 2; ++i) { 200 MockWrite data_writes[] = { 201 MockWrite(SYNCHRONOUS, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength), 202 MockWrite(SYNCHRONOUS, request.data(), request.size()) 203 }; 204 MockRead data_reads[] = { 205 MockRead(SYNCHRONOUS, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength), 206 MockRead(SYNCHRONOUS, kSOCKS5OkResponse, kSOCKS5OkResponseLength) 207 }; 208 209 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), 210 data_writes, arraysize(data_writes), 211 hostname, 80, NULL); 212 213 int rv = user_sock_->Connect(callback_.callback()); 214 EXPECT_EQ(OK, rv); 215 EXPECT_TRUE(user_sock_->IsConnected()); 216 217 user_sock_->Disconnect(); 218 EXPECT_FALSE(user_sock_->IsConnected()); 219 } 220 } 221 222 // Test that we fail trying to connect to a hosname longer than 255 bytes. 223 TEST_F(SOCKS5ClientSocketTest, LargeHostNameFails) { 224 // Create a string of length 256, where each character is 'x'. 225 std::string large_host_name; 226 std::fill_n(std::back_inserter(large_host_name), 256, 'x'); 227 228 // Create a SOCKS socket, with mock transport socket. 229 MockWrite data_writes[] = {MockWrite()}; 230 MockRead data_reads[] = {MockRead()}; 231 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), 232 data_writes, arraysize(data_writes), 233 large_host_name, 80, NULL); 234 235 // Try to connect -- should fail (without having read/written anything to 236 // the transport socket first) because the hostname is too long. 237 TestCompletionCallback callback; 238 int rv = user_sock_->Connect(callback.callback()); 239 EXPECT_EQ(ERR_SOCKS_CONNECTION_FAILED, rv); 240 } 241 242 TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) { 243 const std::string hostname = "www.google.com"; 244 245 const char kOkRequest[] = { 246 0x05, // Version 247 0x01, // Command (CONNECT) 248 0x00, // Reserved. 249 0x03, // Address type (DOMAINNAME). 250 0x0E, // Length of domain (14) 251 // Domain string: 252 'w', 'w', 'w', '.', 'g', 'o', 'o', 'g', 'l', 'e', '.', 'c', 'o', 'm', 253 0x00, 0x50, // 16-bit port (80) 254 }; 255 256 // Test for partial greet request write 257 { 258 const char partial1[] = { 0x05, 0x01 }; 259 const char partial2[] = { 0x00 }; 260 MockWrite data_writes[] = { 261 MockWrite(ASYNC, arraysize(partial1)), 262 MockWrite(ASYNC, partial2, arraysize(partial2)), 263 MockWrite(ASYNC, kOkRequest, arraysize(kOkRequest)) }; 264 MockRead data_reads[] = { 265 MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength), 266 MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) }; 267 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), 268 data_writes, arraysize(data_writes), 269 hostname, 80, &net_log_); 270 int rv = user_sock_->Connect(callback_.callback()); 271 EXPECT_EQ(ERR_IO_PENDING, rv); 272 273 CapturingNetLog::CapturedEntryList net_log_entries; 274 net_log_.GetEntries(&net_log_entries); 275 EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0, 276 NetLog::TYPE_SOCKS5_CONNECT)); 277 278 rv = callback_.WaitForResult(); 279 EXPECT_EQ(OK, rv); 280 EXPECT_TRUE(user_sock_->IsConnected()); 281 282 net_log_.GetEntries(&net_log_entries); 283 EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1, 284 NetLog::TYPE_SOCKS5_CONNECT)); 285 } 286 287 // Test for partial greet response read 288 { 289 const char partial1[] = { 0x05 }; 290 const char partial2[] = { 0x00 }; 291 MockWrite data_writes[] = { 292 MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength), 293 MockWrite(ASYNC, kOkRequest, arraysize(kOkRequest)) }; 294 MockRead data_reads[] = { 295 MockRead(ASYNC, partial1, arraysize(partial1)), 296 MockRead(ASYNC, partial2, arraysize(partial2)), 297 MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) }; 298 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), 299 data_writes, arraysize(data_writes), 300 hostname, 80, &net_log_); 301 int rv = user_sock_->Connect(callback_.callback()); 302 EXPECT_EQ(ERR_IO_PENDING, rv); 303 304 CapturingNetLog::CapturedEntryList net_log_entries; 305 net_log_.GetEntries(&net_log_entries); 306 EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0, 307 NetLog::TYPE_SOCKS5_CONNECT)); 308 rv = callback_.WaitForResult(); 309 EXPECT_EQ(OK, rv); 310 EXPECT_TRUE(user_sock_->IsConnected()); 311 net_log_.GetEntries(&net_log_entries); 312 EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1, 313 NetLog::TYPE_SOCKS5_CONNECT)); 314 } 315 316 // Test for partial handshake request write. 317 { 318 const int kSplitPoint = 3; // Break handshake write into two parts. 319 MockWrite data_writes[] = { 320 MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength), 321 MockWrite(ASYNC, kOkRequest, kSplitPoint), 322 MockWrite(ASYNC, kOkRequest + kSplitPoint, 323 arraysize(kOkRequest) - kSplitPoint) 324 }; 325 MockRead data_reads[] = { 326 MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength), 327 MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) }; 328 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), 329 data_writes, arraysize(data_writes), 330 hostname, 80, &net_log_); 331 int rv = user_sock_->Connect(callback_.callback()); 332 EXPECT_EQ(ERR_IO_PENDING, rv); 333 CapturingNetLog::CapturedEntryList net_log_entries; 334 net_log_.GetEntries(&net_log_entries); 335 EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0, 336 NetLog::TYPE_SOCKS5_CONNECT)); 337 rv = callback_.WaitForResult(); 338 EXPECT_EQ(OK, rv); 339 EXPECT_TRUE(user_sock_->IsConnected()); 340 net_log_.GetEntries(&net_log_entries); 341 EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1, 342 NetLog::TYPE_SOCKS5_CONNECT)); 343 } 344 345 // Test for partial handshake response read 346 { 347 const int kSplitPoint = 6; // Break the handshake read into two parts. 348 MockWrite data_writes[] = { 349 MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength), 350 MockWrite(ASYNC, kOkRequest, arraysize(kOkRequest)) 351 }; 352 MockRead data_reads[] = { 353 MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength), 354 MockRead(ASYNC, kSOCKS5OkResponse, kSplitPoint), 355 MockRead(ASYNC, kSOCKS5OkResponse + kSplitPoint, 356 kSOCKS5OkResponseLength - kSplitPoint) 357 }; 358 359 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), 360 data_writes, arraysize(data_writes), 361 hostname, 80, &net_log_); 362 int rv = user_sock_->Connect(callback_.callback()); 363 EXPECT_EQ(ERR_IO_PENDING, rv); 364 CapturingNetLog::CapturedEntryList net_log_entries; 365 net_log_.GetEntries(&net_log_entries); 366 EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0, 367 NetLog::TYPE_SOCKS5_CONNECT)); 368 rv = callback_.WaitForResult(); 369 EXPECT_EQ(OK, rv); 370 EXPECT_TRUE(user_sock_->IsConnected()); 371 net_log_.GetEntries(&net_log_entries); 372 EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1, 373 NetLog::TYPE_SOCKS5_CONNECT)); 374 } 375 } 376 377 } // namespace 378 379 } // namespace net 380