1 // Copyright 2014 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/socket/transport_client_socket_pool_test_util.h" 6 7 #include <string> 8 9 #include "base/logging.h" 10 #include "base/memory/weak_ptr.h" 11 #include "base/run_loop.h" 12 #include "net/base/ip_endpoint.h" 13 #include "net/base/load_timing_info.h" 14 #include "net/base/load_timing_info_test_util.h" 15 #include "net/base/net_util.h" 16 #include "net/socket/client_socket_handle.h" 17 #include "net/socket/ssl_client_socket.h" 18 #include "net/udp/datagram_client_socket.h" 19 #include "testing/gtest/include/gtest/gtest.h" 20 21 namespace net { 22 23 namespace { 24 25 IPAddressNumber ParseIP(const std::string& ip) { 26 IPAddressNumber number; 27 CHECK(ParseIPLiteralToNumber(ip, &number)); 28 return number; 29 } 30 31 // A StreamSocket which connects synchronously and successfully. 32 class MockConnectClientSocket : public StreamSocket { 33 public: 34 MockConnectClientSocket(const AddressList& addrlist, net::NetLog* net_log) 35 : connected_(false), 36 addrlist_(addrlist), 37 net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)), 38 use_tcp_fastopen_(false) {} 39 40 // StreamSocket implementation. 41 virtual int Connect(const CompletionCallback& callback) OVERRIDE { 42 connected_ = true; 43 return OK; 44 } 45 virtual void Disconnect() OVERRIDE { connected_ = false; } 46 virtual bool IsConnected() const OVERRIDE { return connected_; } 47 virtual bool IsConnectedAndIdle() const OVERRIDE { return connected_; } 48 49 virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { 50 *address = addrlist_.front(); 51 return OK; 52 } 53 virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { 54 if (!connected_) 55 return ERR_SOCKET_NOT_CONNECTED; 56 if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4) 57 SetIPv4Address(address); 58 else 59 SetIPv6Address(address); 60 return OK; 61 } 62 virtual const BoundNetLog& NetLog() const OVERRIDE { return net_log_; } 63 64 virtual void SetSubresourceSpeculation() OVERRIDE {} 65 virtual void SetOmniboxSpeculation() OVERRIDE {} 66 virtual bool WasEverUsed() const OVERRIDE { return false; } 67 virtual void EnableTCPFastOpenIfSupported() OVERRIDE { 68 use_tcp_fastopen_ = true; 69 } 70 virtual bool UsingTCPFastOpen() const OVERRIDE { return use_tcp_fastopen_; } 71 virtual bool WasNpnNegotiated() const OVERRIDE { return false; } 72 virtual NextProto GetNegotiatedProtocol() const OVERRIDE { 73 return kProtoUnknown; 74 } 75 virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { return false; } 76 77 // Socket implementation. 78 virtual int Read(IOBuffer* buf, 79 int buf_len, 80 const CompletionCallback& callback) OVERRIDE { 81 return ERR_FAILED; 82 } 83 virtual int Write(IOBuffer* buf, 84 int buf_len, 85 const CompletionCallback& callback) OVERRIDE { 86 return ERR_FAILED; 87 } 88 virtual int SetReceiveBufferSize(int32 size) OVERRIDE { return OK; } 89 virtual int SetSendBufferSize(int32 size) OVERRIDE { return OK; } 90 91 private: 92 bool connected_; 93 const AddressList addrlist_; 94 BoundNetLog net_log_; 95 bool use_tcp_fastopen_; 96 97 DISALLOW_COPY_AND_ASSIGN(MockConnectClientSocket); 98 }; 99 100 class MockFailingClientSocket : public StreamSocket { 101 public: 102 MockFailingClientSocket(const AddressList& addrlist, net::NetLog* net_log) 103 : addrlist_(addrlist), 104 net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)), 105 use_tcp_fastopen_(false) {} 106 107 // StreamSocket implementation. 108 virtual int Connect(const CompletionCallback& callback) OVERRIDE { 109 return ERR_CONNECTION_FAILED; 110 } 111 112 virtual void Disconnect() OVERRIDE {} 113 114 virtual bool IsConnected() const OVERRIDE { return false; } 115 virtual bool IsConnectedAndIdle() const OVERRIDE { return false; } 116 virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { 117 return ERR_UNEXPECTED; 118 } 119 virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { 120 return ERR_UNEXPECTED; 121 } 122 virtual const BoundNetLog& NetLog() const OVERRIDE { return net_log_; } 123 124 virtual void SetSubresourceSpeculation() OVERRIDE {} 125 virtual void SetOmniboxSpeculation() OVERRIDE {} 126 virtual bool WasEverUsed() const OVERRIDE { return false; } 127 virtual void EnableTCPFastOpenIfSupported() OVERRIDE { 128 use_tcp_fastopen_ = true; 129 } 130 virtual bool UsingTCPFastOpen() const OVERRIDE { return use_tcp_fastopen_; } 131 virtual bool WasNpnNegotiated() const OVERRIDE { return false; } 132 virtual NextProto GetNegotiatedProtocol() const OVERRIDE { 133 return kProtoUnknown; 134 } 135 virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { return false; } 136 137 // Socket implementation. 138 virtual int Read(IOBuffer* buf, 139 int buf_len, 140 const CompletionCallback& callback) OVERRIDE { 141 return ERR_FAILED; 142 } 143 144 virtual int Write(IOBuffer* buf, 145 int buf_len, 146 const CompletionCallback& callback) OVERRIDE { 147 return ERR_FAILED; 148 } 149 virtual int SetReceiveBufferSize(int32 size) OVERRIDE { return OK; } 150 virtual int SetSendBufferSize(int32 size) OVERRIDE { return OK; } 151 152 private: 153 const AddressList addrlist_; 154 BoundNetLog net_log_; 155 bool use_tcp_fastopen_; 156 157 DISALLOW_COPY_AND_ASSIGN(MockFailingClientSocket); 158 }; 159 160 class MockTriggerableClientSocket : public StreamSocket { 161 public: 162 // |should_connect| indicates whether the socket should successfully complete 163 // or fail. 164 MockTriggerableClientSocket(const AddressList& addrlist, 165 bool should_connect, 166 net::NetLog* net_log) 167 : should_connect_(should_connect), 168 is_connected_(false), 169 addrlist_(addrlist), 170 net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)), 171 use_tcp_fastopen_(false), 172 weak_factory_(this) {} 173 174 // Call this method to get a closure which will trigger the connect callback 175 // when called. The closure can be called even after the socket is deleted; it 176 // will safely do nothing. 177 base::Closure GetConnectCallback() { 178 return base::Bind(&MockTriggerableClientSocket::DoCallback, 179 weak_factory_.GetWeakPtr()); 180 } 181 182 static scoped_ptr<StreamSocket> MakeMockPendingClientSocket( 183 const AddressList& addrlist, 184 bool should_connect, 185 net::NetLog* net_log) { 186 scoped_ptr<MockTriggerableClientSocket> socket( 187 new MockTriggerableClientSocket(addrlist, should_connect, net_log)); 188 base::MessageLoop::current()->PostTask(FROM_HERE, 189 socket->GetConnectCallback()); 190 return socket.PassAs<StreamSocket>(); 191 } 192 193 static scoped_ptr<StreamSocket> MakeMockDelayedClientSocket( 194 const AddressList& addrlist, 195 bool should_connect, 196 const base::TimeDelta& delay, 197 net::NetLog* net_log) { 198 scoped_ptr<MockTriggerableClientSocket> socket( 199 new MockTriggerableClientSocket(addrlist, should_connect, net_log)); 200 base::MessageLoop::current()->PostDelayedTask( 201 FROM_HERE, socket->GetConnectCallback(), delay); 202 return socket.PassAs<StreamSocket>(); 203 } 204 205 static scoped_ptr<StreamSocket> MakeMockStalledClientSocket( 206 const AddressList& addrlist, 207 net::NetLog* net_log) { 208 scoped_ptr<MockTriggerableClientSocket> socket( 209 new MockTriggerableClientSocket(addrlist, true, net_log)); 210 return socket.PassAs<StreamSocket>(); 211 } 212 213 // StreamSocket implementation. 214 virtual int Connect(const CompletionCallback& callback) OVERRIDE { 215 DCHECK(callback_.is_null()); 216 callback_ = callback; 217 return ERR_IO_PENDING; 218 } 219 220 virtual void Disconnect() OVERRIDE {} 221 222 virtual bool IsConnected() const OVERRIDE { return is_connected_; } 223 virtual bool IsConnectedAndIdle() const OVERRIDE { return is_connected_; } 224 virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { 225 *address = addrlist_.front(); 226 return OK; 227 } 228 virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { 229 if (!is_connected_) 230 return ERR_SOCKET_NOT_CONNECTED; 231 if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4) 232 SetIPv4Address(address); 233 else 234 SetIPv6Address(address); 235 return OK; 236 } 237 virtual const BoundNetLog& NetLog() const OVERRIDE { return net_log_; } 238 239 virtual void SetSubresourceSpeculation() OVERRIDE {} 240 virtual void SetOmniboxSpeculation() OVERRIDE {} 241 virtual bool WasEverUsed() const OVERRIDE { return false; } 242 virtual void EnableTCPFastOpenIfSupported() OVERRIDE { 243 use_tcp_fastopen_ = true; 244 } 245 virtual bool UsingTCPFastOpen() const OVERRIDE { return use_tcp_fastopen_; } 246 virtual bool WasNpnNegotiated() const OVERRIDE { return false; } 247 virtual NextProto GetNegotiatedProtocol() const OVERRIDE { 248 return kProtoUnknown; 249 } 250 virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { return false; } 251 252 // Socket implementation. 253 virtual int Read(IOBuffer* buf, 254 int buf_len, 255 const CompletionCallback& callback) OVERRIDE { 256 return ERR_FAILED; 257 } 258 259 virtual int Write(IOBuffer* buf, 260 int buf_len, 261 const CompletionCallback& callback) OVERRIDE { 262 return ERR_FAILED; 263 } 264 virtual int SetReceiveBufferSize(int32 size) OVERRIDE { return OK; } 265 virtual int SetSendBufferSize(int32 size) OVERRIDE { return OK; } 266 267 private: 268 void DoCallback() { 269 is_connected_ = should_connect_; 270 callback_.Run(is_connected_ ? OK : ERR_CONNECTION_FAILED); 271 } 272 273 bool should_connect_; 274 bool is_connected_; 275 const AddressList addrlist_; 276 BoundNetLog net_log_; 277 CompletionCallback callback_; 278 bool use_tcp_fastopen_; 279 280 base::WeakPtrFactory<MockTriggerableClientSocket> weak_factory_; 281 282 DISALLOW_COPY_AND_ASSIGN(MockTriggerableClientSocket); 283 }; 284 285 } // namespace 286 287 void TestLoadTimingInfoConnectedReused(const ClientSocketHandle& handle) { 288 LoadTimingInfo load_timing_info; 289 // Only pass true in as |is_reused|, as in general, HttpStream types should 290 // have stricter concepts of reuse than socket pools. 291 EXPECT_TRUE(handle.GetLoadTimingInfo(true, &load_timing_info)); 292 293 EXPECT_TRUE(load_timing_info.socket_reused); 294 EXPECT_NE(NetLog::Source::kInvalidId, load_timing_info.socket_log_id); 295 296 ExpectConnectTimingHasNoTimes(load_timing_info.connect_timing); 297 ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info); 298 } 299 300 void TestLoadTimingInfoConnectedNotReused(const ClientSocketHandle& handle) { 301 EXPECT_FALSE(handle.is_reused()); 302 303 LoadTimingInfo load_timing_info; 304 EXPECT_TRUE(handle.GetLoadTimingInfo(false, &load_timing_info)); 305 306 EXPECT_FALSE(load_timing_info.socket_reused); 307 EXPECT_NE(NetLog::Source::kInvalidId, load_timing_info.socket_log_id); 308 309 ExpectConnectTimingHasTimes(load_timing_info.connect_timing, 310 CONNECT_TIMING_HAS_DNS_TIMES); 311 ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info); 312 313 TestLoadTimingInfoConnectedReused(handle); 314 } 315 316 void SetIPv4Address(IPEndPoint* address) { 317 *address = IPEndPoint(ParseIP("1.1.1.1"), 80); 318 } 319 320 void SetIPv6Address(IPEndPoint* address) { 321 *address = IPEndPoint(ParseIP("1:abcd::3:4:ff"), 80); 322 } 323 324 MockTransportClientSocketFactory::MockTransportClientSocketFactory( 325 NetLog* net_log) 326 : net_log_(net_log), 327 allocation_count_(0), 328 client_socket_type_(MOCK_CLIENT_SOCKET), 329 client_socket_types_(NULL), 330 client_socket_index_(0), 331 client_socket_index_max_(0), 332 delay_(base::TimeDelta::FromMilliseconds( 333 ClientSocketPool::kMaxConnectRetryIntervalMs)) {} 334 335 MockTransportClientSocketFactory::~MockTransportClientSocketFactory() {} 336 337 scoped_ptr<DatagramClientSocket> 338 MockTransportClientSocketFactory::CreateDatagramClientSocket( 339 DatagramSocket::BindType bind_type, 340 const RandIntCallback& rand_int_cb, 341 NetLog* net_log, 342 const NetLog::Source& source) { 343 NOTREACHED(); 344 return scoped_ptr<DatagramClientSocket>(); 345 } 346 347 scoped_ptr<StreamSocket> 348 MockTransportClientSocketFactory::CreateTransportClientSocket( 349 const AddressList& addresses, 350 NetLog* /* net_log */, 351 const NetLog::Source& /* source */) { 352 allocation_count_++; 353 354 ClientSocketType type = client_socket_type_; 355 if (client_socket_types_ && client_socket_index_ < client_socket_index_max_) { 356 type = client_socket_types_[client_socket_index_++]; 357 } 358 359 switch (type) { 360 case MOCK_CLIENT_SOCKET: 361 return scoped_ptr<StreamSocket>( 362 new MockConnectClientSocket(addresses, net_log_)); 363 case MOCK_FAILING_CLIENT_SOCKET: 364 return scoped_ptr<StreamSocket>( 365 new MockFailingClientSocket(addresses, net_log_)); 366 case MOCK_PENDING_CLIENT_SOCKET: 367 return MockTriggerableClientSocket::MakeMockPendingClientSocket( 368 addresses, true, net_log_); 369 case MOCK_PENDING_FAILING_CLIENT_SOCKET: 370 return MockTriggerableClientSocket::MakeMockPendingClientSocket( 371 addresses, false, net_log_); 372 case MOCK_DELAYED_CLIENT_SOCKET: 373 return MockTriggerableClientSocket::MakeMockDelayedClientSocket( 374 addresses, true, delay_, net_log_); 375 case MOCK_DELAYED_FAILING_CLIENT_SOCKET: 376 return MockTriggerableClientSocket::MakeMockDelayedClientSocket( 377 addresses, false, delay_, net_log_); 378 case MOCK_STALLED_CLIENT_SOCKET: 379 return MockTriggerableClientSocket::MakeMockStalledClientSocket(addresses, 380 net_log_); 381 case MOCK_TRIGGERABLE_CLIENT_SOCKET: { 382 scoped_ptr<MockTriggerableClientSocket> rv( 383 new MockTriggerableClientSocket(addresses, true, net_log_)); 384 triggerable_sockets_.push(rv->GetConnectCallback()); 385 // run_loop_quit_closure_ behaves like a condition variable. It will 386 // wake up WaitForTriggerableSocketCreation() if it is sleeping. We 387 // don't need to worry about atomicity because this code is 388 // single-threaded. 389 if (!run_loop_quit_closure_.is_null()) 390 run_loop_quit_closure_.Run(); 391 return rv.PassAs<StreamSocket>(); 392 } 393 default: 394 NOTREACHED(); 395 return scoped_ptr<StreamSocket>( 396 new MockConnectClientSocket(addresses, net_log_)); 397 } 398 } 399 400 scoped_ptr<SSLClientSocket> 401 MockTransportClientSocketFactory::CreateSSLClientSocket( 402 scoped_ptr<ClientSocketHandle> transport_socket, 403 const HostPortPair& host_and_port, 404 const SSLConfig& ssl_config, 405 const SSLClientSocketContext& context) { 406 NOTIMPLEMENTED(); 407 return scoped_ptr<SSLClientSocket>(); 408 } 409 410 void MockTransportClientSocketFactory::ClearSSLSessionCache() { 411 NOTIMPLEMENTED(); 412 } 413 414 void MockTransportClientSocketFactory::set_client_socket_types( 415 ClientSocketType* type_list, 416 int num_types) { 417 DCHECK_GT(num_types, 0); 418 client_socket_types_ = type_list; 419 client_socket_index_ = 0; 420 client_socket_index_max_ = num_types; 421 } 422 423 base::Closure 424 MockTransportClientSocketFactory::WaitForTriggerableSocketCreation() { 425 while (triggerable_sockets_.empty()) { 426 base::RunLoop run_loop; 427 run_loop_quit_closure_ = run_loop.QuitClosure(); 428 run_loop.Run(); 429 run_loop_quit_closure_.Reset(); 430 } 431 base::Closure trigger = triggerable_sockets_.front(); 432 triggerable_sockets_.pop(); 433 return trigger; 434 } 435 436 } // namespace net 437