1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/socket/socks_client_socket.h" 6 7 #include "base/memory/scoped_ptr.h" 8 #include "net/base/address_list.h" 9 #include "net/base/net_log.h" 10 #include "net/base/net_log_unittest.h" 11 #include "net/base/test_completion_callback.h" 12 #include "net/base/winsock_init.h" 13 #include "net/dns/mock_host_resolver.h" 14 #include "net/socket/client_socket_factory.h" 15 #include "net/socket/socket_test_util.h" 16 #include "net/socket/tcp_client_socket.h" 17 #include "testing/gtest/include/gtest/gtest.h" 18 #include "testing/platform_test.h" 19 20 //----------------------------------------------------------------------------- 21 22 namespace net { 23 24 const char kSOCKSOkRequest[] = { 0x04, 0x01, 0x00, 0x50, 127, 0, 0, 1, 0 }; 25 const char kSOCKSOkReply[] = { 0x00, 0x5A, 0x00, 0x00, 0, 0, 0, 0 }; 26 27 class SOCKSClientSocketTest : public PlatformTest { 28 public: 29 SOCKSClientSocketTest(); 30 // Create a SOCKSClientSocket on top of a MockSocket. 31 scoped_ptr<SOCKSClientSocket> BuildMockSocket( 32 MockRead reads[], size_t reads_count, 33 MockWrite writes[], size_t writes_count, 34 HostResolver* host_resolver, 35 const std::string& hostname, int port, 36 NetLog* net_log); 37 virtual void SetUp(); 38 39 protected: 40 scoped_ptr<SOCKSClientSocket> user_sock_; 41 AddressList address_list_; 42 // Filled in by BuildMockSocket() and owned by its return value 43 // (which |user_sock| is set to). 44 StreamSocket* tcp_sock_; 45 TestCompletionCallback callback_; 46 scoped_ptr<MockHostResolver> host_resolver_; 47 scoped_ptr<SocketDataProvider> data_; 48 }; 49 50 SOCKSClientSocketTest::SOCKSClientSocketTest() 51 : host_resolver_(new MockHostResolver) { 52 } 53 54 // Set up platform before every test case 55 void SOCKSClientSocketTest::SetUp() { 56 PlatformTest::SetUp(); 57 } 58 59 scoped_ptr<SOCKSClientSocket> SOCKSClientSocketTest::BuildMockSocket( 60 MockRead reads[], 61 size_t reads_count, 62 MockWrite writes[], 63 size_t writes_count, 64 HostResolver* host_resolver, 65 const std::string& hostname, 66 int port, 67 NetLog* net_log) { 68 69 TestCompletionCallback callback; 70 data_.reset(new StaticSocketDataProvider(reads, reads_count, 71 writes, writes_count)); 72 tcp_sock_ = new MockTCPClientSocket(address_list_, net_log, data_.get()); 73 74 int rv = tcp_sock_->Connect(callback.callback()); 75 EXPECT_EQ(ERR_IO_PENDING, rv); 76 rv = callback.WaitForResult(); 77 EXPECT_EQ(OK, rv); 78 EXPECT_TRUE(tcp_sock_->IsConnected()); 79 80 scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle); 81 // |connection| takes ownership of |tcp_sock_|, but keep a 82 // non-owning pointer to it. 83 connection->SetSocket(scoped_ptr<StreamSocket>(tcp_sock_)); 84 return scoped_ptr<SOCKSClientSocket>(new SOCKSClientSocket( 85 connection.Pass(), 86 HostResolver::RequestInfo(HostPortPair(hostname, port)), 87 DEFAULT_PRIORITY, 88 host_resolver)); 89 } 90 91 // Implementation of HostResolver that never completes its resolve request. 92 // We use this in the test "DisconnectWhileHostResolveInProgress" to make 93 // sure that the outstanding resolve request gets cancelled. 94 class HangingHostResolverWithCancel : public HostResolver { 95 public: 96 HangingHostResolverWithCancel() : outstanding_request_(NULL) {} 97 98 virtual int Resolve(const RequestInfo& info, 99 RequestPriority priority, 100 AddressList* addresses, 101 const CompletionCallback& callback, 102 RequestHandle* out_req, 103 const BoundNetLog& net_log) OVERRIDE { 104 DCHECK(addresses); 105 DCHECK_EQ(false, callback.is_null()); 106 EXPECT_FALSE(HasOutstandingRequest()); 107 outstanding_request_ = reinterpret_cast<RequestHandle>(1); 108 *out_req = outstanding_request_; 109 return ERR_IO_PENDING; 110 } 111 112 virtual int ResolveFromCache(const RequestInfo& info, 113 AddressList* addresses, 114 const BoundNetLog& net_log) OVERRIDE { 115 NOTIMPLEMENTED(); 116 return ERR_UNEXPECTED; 117 } 118 119 virtual void CancelRequest(RequestHandle req) OVERRIDE { 120 EXPECT_TRUE(HasOutstandingRequest()); 121 EXPECT_EQ(outstanding_request_, req); 122 outstanding_request_ = NULL; 123 } 124 125 bool HasOutstandingRequest() { 126 return outstanding_request_ != NULL; 127 } 128 129 private: 130 RequestHandle outstanding_request_; 131 132 DISALLOW_COPY_AND_ASSIGN(HangingHostResolverWithCancel); 133 }; 134 135 // Tests a complete handshake and the disconnection. 136 TEST_F(SOCKSClientSocketTest, CompleteHandshake) { 137 const std::string payload_write = "random data"; 138 const std::string payload_read = "moar random data"; 139 140 MockWrite data_writes[] = { 141 MockWrite(ASYNC, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)), 142 MockWrite(ASYNC, payload_write.data(), payload_write.size()) }; 143 MockRead data_reads[] = { 144 MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply)), 145 MockRead(ASYNC, payload_read.data(), payload_read.size()) }; 146 CapturingNetLog log; 147 148 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), 149 data_writes, arraysize(data_writes), 150 host_resolver_.get(), 151 "localhost", 80, 152 &log); 153 154 // At this state the TCP connection is completed but not the SOCKS handshake. 155 EXPECT_TRUE(tcp_sock_->IsConnected()); 156 EXPECT_FALSE(user_sock_->IsConnected()); 157 158 int rv = user_sock_->Connect(callback_.callback()); 159 EXPECT_EQ(ERR_IO_PENDING, rv); 160 161 CapturingNetLog::CapturedEntryList entries; 162 log.GetEntries(&entries); 163 EXPECT_TRUE( 164 LogContainsBeginEvent(entries, 0, NetLog::TYPE_SOCKS_CONNECT)); 165 EXPECT_FALSE(user_sock_->IsConnected()); 166 167 rv = callback_.WaitForResult(); 168 EXPECT_EQ(OK, rv); 169 EXPECT_TRUE(user_sock_->IsConnected()); 170 log.GetEntries(&entries); 171 EXPECT_TRUE(LogContainsEndEvent( 172 entries, -1, NetLog::TYPE_SOCKS_CONNECT)); 173 174 scoped_refptr<IOBuffer> buffer(new IOBuffer(payload_write.size())); 175 memcpy(buffer->data(), payload_write.data(), payload_write.size()); 176 rv = user_sock_->Write( 177 buffer.get(), payload_write.size(), callback_.callback()); 178 EXPECT_EQ(ERR_IO_PENDING, rv); 179 rv = callback_.WaitForResult(); 180 EXPECT_EQ(static_cast<int>(payload_write.size()), rv); 181 182 buffer = new IOBuffer(payload_read.size()); 183 rv = 184 user_sock_->Read(buffer.get(), payload_read.size(), callback_.callback()); 185 EXPECT_EQ(ERR_IO_PENDING, rv); 186 rv = callback_.WaitForResult(); 187 EXPECT_EQ(static_cast<int>(payload_read.size()), rv); 188 EXPECT_EQ(payload_read, std::string(buffer->data(), payload_read.size())); 189 190 user_sock_->Disconnect(); 191 EXPECT_FALSE(tcp_sock_->IsConnected()); 192 EXPECT_FALSE(user_sock_->IsConnected()); 193 } 194 195 // List of responses from the socks server and the errors they should 196 // throw up are tested here. 197 TEST_F(SOCKSClientSocketTest, HandshakeFailures) { 198 const struct { 199 const char fail_reply[8]; 200 Error fail_code; 201 } tests[] = { 202 // Failure of the server response code 203 { 204 { 0x01, 0x5A, 0x00, 0x00, 0, 0, 0, 0 }, 205 ERR_SOCKS_CONNECTION_FAILED, 206 }, 207 // Failure of the null byte 208 { 209 { 0x00, 0x5B, 0x00, 0x00, 0, 0, 0, 0 }, 210 ERR_SOCKS_CONNECTION_FAILED, 211 }, 212 }; 213 214 //--------------------------------------- 215 216 for (size_t i = 0; i < ARRAYSIZE_UNSAFE(tests); ++i) { 217 MockWrite data_writes[] = { 218 MockWrite(SYNCHRONOUS, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) }; 219 MockRead data_reads[] = { 220 MockRead(SYNCHRONOUS, tests[i].fail_reply, 221 arraysize(tests[i].fail_reply)) }; 222 CapturingNetLog log; 223 224 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), 225 data_writes, arraysize(data_writes), 226 host_resolver_.get(), 227 "localhost", 80, 228 &log); 229 230 int rv = user_sock_->Connect(callback_.callback()); 231 EXPECT_EQ(ERR_IO_PENDING, rv); 232 233 CapturingNetLog::CapturedEntryList entries; 234 log.GetEntries(&entries); 235 EXPECT_TRUE(LogContainsBeginEvent( 236 entries, 0, NetLog::TYPE_SOCKS_CONNECT)); 237 238 rv = callback_.WaitForResult(); 239 EXPECT_EQ(tests[i].fail_code, rv); 240 EXPECT_FALSE(user_sock_->IsConnected()); 241 EXPECT_TRUE(tcp_sock_->IsConnected()); 242 log.GetEntries(&entries); 243 EXPECT_TRUE(LogContainsEndEvent( 244 entries, -1, NetLog::TYPE_SOCKS_CONNECT)); 245 } 246 } 247 248 // Tests scenario when the server sends the handshake response in 249 // more than one packet. 250 TEST_F(SOCKSClientSocketTest, PartialServerReads) { 251 const char kSOCKSPartialReply1[] = { 0x00 }; 252 const char kSOCKSPartialReply2[] = { 0x5A, 0x00, 0x00, 0, 0, 0, 0 }; 253 254 MockWrite data_writes[] = { 255 MockWrite(ASYNC, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) }; 256 MockRead data_reads[] = { 257 MockRead(ASYNC, kSOCKSPartialReply1, arraysize(kSOCKSPartialReply1)), 258 MockRead(ASYNC, kSOCKSPartialReply2, arraysize(kSOCKSPartialReply2)) }; 259 CapturingNetLog log; 260 261 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), 262 data_writes, arraysize(data_writes), 263 host_resolver_.get(), 264 "localhost", 80, 265 &log); 266 267 int rv = user_sock_->Connect(callback_.callback()); 268 EXPECT_EQ(ERR_IO_PENDING, rv); 269 CapturingNetLog::CapturedEntryList entries; 270 log.GetEntries(&entries); 271 EXPECT_TRUE(LogContainsBeginEvent( 272 entries, 0, NetLog::TYPE_SOCKS_CONNECT)); 273 274 rv = callback_.WaitForResult(); 275 EXPECT_EQ(OK, rv); 276 EXPECT_TRUE(user_sock_->IsConnected()); 277 log.GetEntries(&entries); 278 EXPECT_TRUE(LogContainsEndEvent( 279 entries, -1, NetLog::TYPE_SOCKS_CONNECT)); 280 } 281 282 // Tests scenario when the client sends the handshake request in 283 // more than one packet. 284 TEST_F(SOCKSClientSocketTest, PartialClientWrites) { 285 const char kSOCKSPartialRequest1[] = { 0x04, 0x01 }; 286 const char kSOCKSPartialRequest2[] = { 0x00, 0x50, 127, 0, 0, 1, 0 }; 287 288 MockWrite data_writes[] = { 289 MockWrite(ASYNC, arraysize(kSOCKSPartialRequest1)), 290 // simulate some empty writes 291 MockWrite(ASYNC, 0), 292 MockWrite(ASYNC, 0), 293 MockWrite(ASYNC, kSOCKSPartialRequest2, 294 arraysize(kSOCKSPartialRequest2)) }; 295 MockRead data_reads[] = { 296 MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply)) }; 297 CapturingNetLog log; 298 299 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), 300 data_writes, arraysize(data_writes), 301 host_resolver_.get(), 302 "localhost", 80, 303 &log); 304 305 int rv = user_sock_->Connect(callback_.callback()); 306 EXPECT_EQ(ERR_IO_PENDING, rv); 307 CapturingNetLog::CapturedEntryList entries; 308 log.GetEntries(&entries); 309 EXPECT_TRUE(LogContainsBeginEvent( 310 entries, 0, NetLog::TYPE_SOCKS_CONNECT)); 311 312 rv = callback_.WaitForResult(); 313 EXPECT_EQ(OK, rv); 314 EXPECT_TRUE(user_sock_->IsConnected()); 315 log.GetEntries(&entries); 316 EXPECT_TRUE(LogContainsEndEvent( 317 entries, -1, NetLog::TYPE_SOCKS_CONNECT)); 318 } 319 320 // Tests the case when the server sends a smaller sized handshake data 321 // and closes the connection. 322 TEST_F(SOCKSClientSocketTest, FailedSocketRead) { 323 MockWrite data_writes[] = { 324 MockWrite(ASYNC, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) }; 325 MockRead data_reads[] = { 326 MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply) - 2), 327 // close connection unexpectedly 328 MockRead(SYNCHRONOUS, 0) }; 329 CapturingNetLog log; 330 331 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), 332 data_writes, arraysize(data_writes), 333 host_resolver_.get(), 334 "localhost", 80, 335 &log); 336 337 int rv = user_sock_->Connect(callback_.callback()); 338 EXPECT_EQ(ERR_IO_PENDING, rv); 339 CapturingNetLog::CapturedEntryList entries; 340 log.GetEntries(&entries); 341 EXPECT_TRUE(LogContainsBeginEvent( 342 entries, 0, NetLog::TYPE_SOCKS_CONNECT)); 343 344 rv = callback_.WaitForResult(); 345 EXPECT_EQ(ERR_CONNECTION_CLOSED, rv); 346 EXPECT_FALSE(user_sock_->IsConnected()); 347 log.GetEntries(&entries); 348 EXPECT_TRUE(LogContainsEndEvent( 349 entries, -1, NetLog::TYPE_SOCKS_CONNECT)); 350 } 351 352 // Tries to connect to an unknown hostname. Should fail rather than 353 // falling back to SOCKS4a. 354 TEST_F(SOCKSClientSocketTest, FailedDNS) { 355 const char hostname[] = "unresolved.ipv4.address"; 356 357 host_resolver_->rules()->AddSimulatedFailure(hostname); 358 359 CapturingNetLog log; 360 361 user_sock_ = BuildMockSocket(NULL, 0, 362 NULL, 0, 363 host_resolver_.get(), 364 hostname, 80, 365 &log); 366 367 int rv = user_sock_->Connect(callback_.callback()); 368 EXPECT_EQ(ERR_IO_PENDING, rv); 369 CapturingNetLog::CapturedEntryList entries; 370 log.GetEntries(&entries); 371 EXPECT_TRUE(LogContainsBeginEvent( 372 entries, 0, NetLog::TYPE_SOCKS_CONNECT)); 373 374 rv = callback_.WaitForResult(); 375 EXPECT_EQ(ERR_NAME_NOT_RESOLVED, rv); 376 EXPECT_FALSE(user_sock_->IsConnected()); 377 log.GetEntries(&entries); 378 EXPECT_TRUE(LogContainsEndEvent( 379 entries, -1, NetLog::TYPE_SOCKS_CONNECT)); 380 } 381 382 // Calls Disconnect() while a host resolve is in progress. The outstanding host 383 // resolve should be cancelled. 384 TEST_F(SOCKSClientSocketTest, DisconnectWhileHostResolveInProgress) { 385 scoped_ptr<HangingHostResolverWithCancel> hanging_resolver( 386 new HangingHostResolverWithCancel()); 387 388 // Doesn't matter what the socket data is, we will never use it -- garbage. 389 MockWrite data_writes[] = { MockWrite(SYNCHRONOUS, "", 0) }; 390 MockRead data_reads[] = { MockRead(SYNCHRONOUS, "", 0) }; 391 392 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), 393 data_writes, arraysize(data_writes), 394 hanging_resolver.get(), 395 "foo", 80, 396 NULL); 397 398 // Start connecting (will get stuck waiting for the host to resolve). 399 int rv = user_sock_->Connect(callback_.callback()); 400 EXPECT_EQ(ERR_IO_PENDING, rv); 401 402 EXPECT_FALSE(user_sock_->IsConnected()); 403 EXPECT_FALSE(user_sock_->IsConnectedAndIdle()); 404 405 // The host resolver should have received the resolve request. 406 EXPECT_TRUE(hanging_resolver->HasOutstandingRequest()); 407 408 // Disconnect the SOCKS socket -- this should cancel the outstanding resolve. 409 user_sock_->Disconnect(); 410 411 EXPECT_FALSE(hanging_resolver->HasOutstandingRequest()); 412 413 EXPECT_FALSE(user_sock_->IsConnected()); 414 EXPECT_FALSE(user_sock_->IsConnectedAndIdle()); 415 } 416 417 } // namespace net 418