1 // Copyright (c) 2011 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/socket/socks_client_socket.h" 6 7 #include "net/base/address_list.h" 8 #include "net/base/net_log.h" 9 #include "net/base/net_log_unittest.h" 10 #include "net/base/mock_host_resolver.h" 11 #include "net/base/test_completion_callback.h" 12 #include "net/base/winsock_init.h" 13 #include "net/socket/client_socket_factory.h" 14 #include "net/socket/tcp_client_socket.h" 15 #include "net/socket/socket_test_util.h" 16 #include "testing/gtest/include/gtest/gtest.h" 17 #include "testing/platform_test.h" 18 19 //----------------------------------------------------------------------------- 20 21 namespace net { 22 23 const char kSOCKSOkRequest[] = { 0x04, 0x01, 0x00, 0x50, 127, 0, 0, 1, 0 }; 24 const char kSOCKSOkReply[] = { 0x00, 0x5A, 0x00, 0x00, 0, 0, 0, 0 }; 25 26 class SOCKSClientSocketTest : public PlatformTest { 27 public: 28 SOCKSClientSocketTest(); 29 // Create a SOCKSClientSocket on top of a MockSocket. 30 SOCKSClientSocket* BuildMockSocket(MockRead reads[], size_t reads_count, 31 MockWrite writes[], size_t writes_count, 32 HostResolver* host_resolver, 33 const std::string& hostname, int port, 34 NetLog* net_log); 35 virtual void SetUp(); 36 37 protected: 38 scoped_ptr<SOCKSClientSocket> user_sock_; 39 AddressList address_list_; 40 ClientSocket* tcp_sock_; 41 TestCompletionCallback callback_; 42 scoped_ptr<MockHostResolver> host_resolver_; 43 scoped_ptr<SocketDataProvider> data_; 44 }; 45 46 SOCKSClientSocketTest::SOCKSClientSocketTest() 47 : host_resolver_(new MockHostResolver) { 48 } 49 50 // Set up platform before every test case 51 void SOCKSClientSocketTest::SetUp() { 52 PlatformTest::SetUp(); 53 } 54 55 SOCKSClientSocket* SOCKSClientSocketTest::BuildMockSocket( 56 MockRead reads[], 57 size_t reads_count, 58 MockWrite writes[], 59 size_t writes_count, 60 HostResolver* host_resolver, 61 const std::string& hostname, 62 int port, 63 NetLog* net_log) { 64 65 TestCompletionCallback callback; 66 data_.reset(new StaticSocketDataProvider(reads, reads_count, 67 writes, writes_count)); 68 tcp_sock_ = new MockTCPClientSocket(address_list_, net_log, data_.get()); 69 70 int rv = tcp_sock_->Connect(&callback); 71 EXPECT_EQ(ERR_IO_PENDING, rv); 72 rv = callback.WaitForResult(); 73 EXPECT_EQ(OK, rv); 74 EXPECT_TRUE(tcp_sock_->IsConnected()); 75 76 return new SOCKSClientSocket(tcp_sock_, 77 HostResolver::RequestInfo(HostPortPair(hostname, port)), 78 host_resolver); 79 } 80 81 // Implementation of HostResolver that never completes its resolve request. 82 // We use this in the test "DisconnectWhileHostResolveInProgress" to make 83 // sure that the outstanding resolve request gets cancelled. 84 class HangingHostResolver : public HostResolver { 85 public: 86 HangingHostResolver() : outstanding_request_(NULL) {} 87 88 virtual int Resolve(const RequestInfo& info, 89 AddressList* addresses, 90 CompletionCallback* callback, 91 RequestHandle* out_req, 92 const BoundNetLog& net_log) { 93 EXPECT_FALSE(HasOutstandingRequest()); 94 outstanding_request_ = reinterpret_cast<RequestHandle>(1); 95 *out_req = outstanding_request_; 96 return ERR_IO_PENDING; 97 } 98 99 virtual void CancelRequest(RequestHandle req) { 100 EXPECT_TRUE(HasOutstandingRequest()); 101 EXPECT_EQ(outstanding_request_, req); 102 outstanding_request_ = NULL; 103 } 104 105 virtual void AddObserver(Observer* observer) {} 106 virtual void RemoveObserver(Observer* observer) {} 107 108 bool HasOutstandingRequest() { 109 return outstanding_request_ != NULL; 110 } 111 112 private: 113 RequestHandle outstanding_request_; 114 115 DISALLOW_COPY_AND_ASSIGN(HangingHostResolver); 116 }; 117 118 // Tests a complete handshake and the disconnection. 119 TEST_F(SOCKSClientSocketTest, CompleteHandshake) { 120 const std::string payload_write = "random data"; 121 const std::string payload_read = "moar random data"; 122 123 MockWrite data_writes[] = { 124 MockWrite(true, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)), 125 MockWrite(true, payload_write.data(), payload_write.size()) }; 126 MockRead data_reads[] = { 127 MockRead(true, kSOCKSOkReply, arraysize(kSOCKSOkReply)), 128 MockRead(true, payload_read.data(), payload_read.size()) }; 129 CapturingNetLog log(CapturingNetLog::kUnbounded); 130 131 user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), 132 data_writes, arraysize(data_writes), 133 host_resolver_.get(), 134 "localhost", 80, 135 &log)); 136 137 // At this state the TCP connection is completed but not the SOCKS handshake. 138 EXPECT_TRUE(tcp_sock_->IsConnected()); 139 EXPECT_FALSE(user_sock_->IsConnected()); 140 141 int rv = user_sock_->Connect(&callback_); 142 EXPECT_EQ(ERR_IO_PENDING, rv); 143 144 net::CapturingNetLog::EntryList entries; 145 log.GetEntries(&entries); 146 EXPECT_TRUE( 147 LogContainsBeginEvent(entries, 0, NetLog::TYPE_SOCKS_CONNECT)); 148 EXPECT_FALSE(user_sock_->IsConnected()); 149 150 rv = callback_.WaitForResult(); 151 EXPECT_EQ(OK, rv); 152 EXPECT_TRUE(user_sock_->IsConnected()); 153 log.GetEntries(&entries); 154 EXPECT_TRUE(LogContainsEndEvent( 155 entries, -1, NetLog::TYPE_SOCKS_CONNECT)); 156 157 scoped_refptr<IOBuffer> buffer(new IOBuffer(payload_write.size())); 158 memcpy(buffer->data(), payload_write.data(), payload_write.size()); 159 rv = user_sock_->Write(buffer, payload_write.size(), &callback_); 160 EXPECT_EQ(ERR_IO_PENDING, rv); 161 rv = callback_.WaitForResult(); 162 EXPECT_EQ(static_cast<int>(payload_write.size()), rv); 163 164 buffer = new IOBuffer(payload_read.size()); 165 rv = user_sock_->Read(buffer, payload_read.size(), &callback_); 166 EXPECT_EQ(ERR_IO_PENDING, rv); 167 rv = callback_.WaitForResult(); 168 EXPECT_EQ(static_cast<int>(payload_read.size()), rv); 169 EXPECT_EQ(payload_read, std::string(buffer->data(), payload_read.size())); 170 171 user_sock_->Disconnect(); 172 EXPECT_FALSE(tcp_sock_->IsConnected()); 173 EXPECT_FALSE(user_sock_->IsConnected()); 174 } 175 176 // List of responses from the socks server and the errors they should 177 // throw up are tested here. 178 TEST_F(SOCKSClientSocketTest, HandshakeFailures) { 179 const struct { 180 const char fail_reply[8]; 181 Error fail_code; 182 } tests[] = { 183 // Failure of the server response code 184 { 185 { 0x01, 0x5A, 0x00, 0x00, 0, 0, 0, 0 }, 186 ERR_SOCKS_CONNECTION_FAILED, 187 }, 188 // Failure of the null byte 189 { 190 { 0x00, 0x5B, 0x00, 0x00, 0, 0, 0, 0 }, 191 ERR_SOCKS_CONNECTION_FAILED, 192 }, 193 }; 194 195 //--------------------------------------- 196 197 for (size_t i = 0; i < ARRAYSIZE_UNSAFE(tests); ++i) { 198 MockWrite data_writes[] = { 199 MockWrite(false, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) }; 200 MockRead data_reads[] = { 201 MockRead(false, tests[i].fail_reply, arraysize(tests[i].fail_reply)) }; 202 CapturingNetLog log(CapturingNetLog::kUnbounded); 203 204 user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), 205 data_writes, arraysize(data_writes), 206 host_resolver_.get(), 207 "localhost", 80, 208 &log)); 209 210 int rv = user_sock_->Connect(&callback_); 211 EXPECT_EQ(ERR_IO_PENDING, rv); 212 213 net::CapturingNetLog::EntryList entries; 214 log.GetEntries(&entries); 215 EXPECT_TRUE(LogContainsBeginEvent( 216 entries, 0, NetLog::TYPE_SOCKS_CONNECT)); 217 218 rv = callback_.WaitForResult(); 219 EXPECT_EQ(tests[i].fail_code, rv); 220 EXPECT_FALSE(user_sock_->IsConnected()); 221 EXPECT_TRUE(tcp_sock_->IsConnected()); 222 log.GetEntries(&entries); 223 EXPECT_TRUE(LogContainsEndEvent( 224 entries, -1, NetLog::TYPE_SOCKS_CONNECT)); 225 } 226 } 227 228 // Tests scenario when the server sends the handshake response in 229 // more than one packet. 230 TEST_F(SOCKSClientSocketTest, PartialServerReads) { 231 const char kSOCKSPartialReply1[] = { 0x00 }; 232 const char kSOCKSPartialReply2[] = { 0x5A, 0x00, 0x00, 0, 0, 0, 0 }; 233 234 MockWrite data_writes[] = { 235 MockWrite(true, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) }; 236 MockRead data_reads[] = { 237 MockRead(true, kSOCKSPartialReply1, arraysize(kSOCKSPartialReply1)), 238 MockRead(true, kSOCKSPartialReply2, arraysize(kSOCKSPartialReply2)) }; 239 CapturingNetLog log(CapturingNetLog::kUnbounded); 240 241 user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), 242 data_writes, arraysize(data_writes), 243 host_resolver_.get(), 244 "localhost", 80, 245 &log)); 246 247 int rv = user_sock_->Connect(&callback_); 248 EXPECT_EQ(ERR_IO_PENDING, rv); 249 net::CapturingNetLog::EntryList entries; 250 log.GetEntries(&entries); 251 EXPECT_TRUE(LogContainsBeginEvent( 252 entries, 0, NetLog::TYPE_SOCKS_CONNECT)); 253 254 rv = callback_.WaitForResult(); 255 EXPECT_EQ(OK, rv); 256 EXPECT_TRUE(user_sock_->IsConnected()); 257 log.GetEntries(&entries); 258 EXPECT_TRUE(LogContainsEndEvent( 259 entries, -1, NetLog::TYPE_SOCKS_CONNECT)); 260 } 261 262 // Tests scenario when the client sends the handshake request in 263 // more than one packet. 264 TEST_F(SOCKSClientSocketTest, PartialClientWrites) { 265 const char kSOCKSPartialRequest1[] = { 0x04, 0x01 }; 266 const char kSOCKSPartialRequest2[] = { 0x00, 0x50, 127, 0, 0, 1, 0 }; 267 268 MockWrite data_writes[] = { 269 MockWrite(true, arraysize(kSOCKSPartialRequest1)), 270 // simulate some empty writes 271 MockWrite(true, 0), 272 MockWrite(true, 0), 273 MockWrite(true, kSOCKSPartialRequest2, 274 arraysize(kSOCKSPartialRequest2)) }; 275 MockRead data_reads[] = { 276 MockRead(true, kSOCKSOkReply, arraysize(kSOCKSOkReply)) }; 277 CapturingNetLog log(CapturingNetLog::kUnbounded); 278 279 user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), 280 data_writes, arraysize(data_writes), 281 host_resolver_.get(), 282 "localhost", 80, 283 &log)); 284 285 int rv = user_sock_->Connect(&callback_); 286 EXPECT_EQ(ERR_IO_PENDING, rv); 287 net::CapturingNetLog::EntryList entries; 288 log.GetEntries(&entries); 289 EXPECT_TRUE(LogContainsBeginEvent( 290 entries, 0, NetLog::TYPE_SOCKS_CONNECT)); 291 292 rv = callback_.WaitForResult(); 293 EXPECT_EQ(OK, rv); 294 EXPECT_TRUE(user_sock_->IsConnected()); 295 log.GetEntries(&entries); 296 EXPECT_TRUE(LogContainsEndEvent( 297 entries, -1, NetLog::TYPE_SOCKS_CONNECT)); 298 } 299 300 // Tests the case when the server sends a smaller sized handshake data 301 // and closes the connection. 302 TEST_F(SOCKSClientSocketTest, FailedSocketRead) { 303 MockWrite data_writes[] = { 304 MockWrite(true, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) }; 305 MockRead data_reads[] = { 306 MockRead(true, kSOCKSOkReply, arraysize(kSOCKSOkReply) - 2), 307 // close connection unexpectedly 308 MockRead(false, 0) }; 309 CapturingNetLog log(CapturingNetLog::kUnbounded); 310 311 user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), 312 data_writes, arraysize(data_writes), 313 host_resolver_.get(), 314 "localhost", 80, 315 &log)); 316 317 int rv = user_sock_->Connect(&callback_); 318 EXPECT_EQ(ERR_IO_PENDING, rv); 319 net::CapturingNetLog::EntryList entries; 320 log.GetEntries(&entries); 321 EXPECT_TRUE(LogContainsBeginEvent( 322 entries, 0, NetLog::TYPE_SOCKS_CONNECT)); 323 324 rv = callback_.WaitForResult(); 325 EXPECT_EQ(ERR_CONNECTION_CLOSED, rv); 326 EXPECT_FALSE(user_sock_->IsConnected()); 327 log.GetEntries(&entries); 328 EXPECT_TRUE(LogContainsEndEvent( 329 entries, -1, NetLog::TYPE_SOCKS_CONNECT)); 330 } 331 332 // Tries to connect to an unknown hostname. Should fail rather than 333 // falling back to SOCKS4a. 334 TEST_F(SOCKSClientSocketTest, FailedDNS) { 335 const char hostname[] = "unresolved.ipv4.address"; 336 337 host_resolver_->rules()->AddSimulatedFailure(hostname); 338 339 CapturingNetLog log(CapturingNetLog::kUnbounded); 340 341 user_sock_.reset(BuildMockSocket(NULL, 0, 342 NULL, 0, 343 host_resolver_.get(), 344 hostname, 80, 345 &log)); 346 347 int rv = user_sock_->Connect(&callback_); 348 EXPECT_EQ(ERR_IO_PENDING, rv); 349 net::CapturingNetLog::EntryList entries; 350 log.GetEntries(&entries); 351 EXPECT_TRUE(LogContainsBeginEvent( 352 entries, 0, NetLog::TYPE_SOCKS_CONNECT)); 353 354 rv = callback_.WaitForResult(); 355 EXPECT_EQ(ERR_NAME_NOT_RESOLVED, rv); 356 EXPECT_FALSE(user_sock_->IsConnected()); 357 log.GetEntries(&entries); 358 EXPECT_TRUE(LogContainsEndEvent( 359 entries, -1, NetLog::TYPE_SOCKS_CONNECT)); 360 } 361 362 // Calls Disconnect() while a host resolve is in progress. The outstanding host 363 // resolve should be cancelled. 364 TEST_F(SOCKSClientSocketTest, DisconnectWhileHostResolveInProgress) { 365 scoped_ptr<HangingHostResolver> hanging_resolver(new HangingHostResolver()); 366 367 // Doesn't matter what the socket data is, we will never use it -- garbage. 368 MockWrite data_writes[] = { MockWrite(false, "", 0) }; 369 MockRead data_reads[] = { MockRead(false, "", 0) }; 370 371 user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), 372 data_writes, arraysize(data_writes), 373 hanging_resolver.get(), 374 "foo", 80, 375 NULL)); 376 377 // Start connecting (will get stuck waiting for the host to resolve). 378 int rv = user_sock_->Connect(&callback_); 379 EXPECT_EQ(ERR_IO_PENDING, rv); 380 381 EXPECT_FALSE(user_sock_->IsConnected()); 382 EXPECT_FALSE(user_sock_->IsConnectedAndIdle()); 383 384 // The host resolver should have received the resolve request. 385 EXPECT_TRUE(hanging_resolver->HasOutstandingRequest()); 386 387 // Disconnect the SOCKS socket -- this should cancel the outstanding resolve. 388 user_sock_->Disconnect(); 389 390 EXPECT_FALSE(hanging_resolver->HasOutstandingRequest()); 391 392 EXPECT_FALSE(user_sock_->IsConnected()); 393 EXPECT_FALSE(user_sock_->IsConnectedAndIdle()); 394 } 395 396 } // namespace net 397