1 // Copyright (c) 2011 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #ifndef NET_SOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_ 6 #define NET_SOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_ 7 #pragma once 8 9 #include <string> 10 11 #include "base/basictypes.h" 12 #include "base/memory/ref_counted.h" 13 #include "base/memory/scoped_ptr.h" 14 #include "base/time.h" 15 #include "base/timer.h" 16 #include "net/base/host_port_pair.h" 17 #include "net/base/host_resolver.h" 18 #include "net/socket/client_socket_pool_base.h" 19 #include "net/socket/client_socket_pool_histograms.h" 20 #include "net/socket/client_socket_pool.h" 21 22 namespace net { 23 24 class ClientSocketFactory; 25 26 class TransportSocketParams : public base::RefCounted<TransportSocketParams> { 27 public: 28 TransportSocketParams(const HostPortPair& host_port_pair, 29 RequestPriority priority, 30 const GURL& referrer, 31 bool disable_resolver_cache, 32 bool ignore_limits); 33 34 const HostResolver::RequestInfo& destination() const { return destination_; } 35 bool ignore_limits() const { return ignore_limits_; } 36 #ifdef ANDROID 37 // Gets the UID of the calling process 38 bool getUID(uid_t *uid) const; 39 void setUID(uid_t uid); 40 #endif 41 42 private: 43 friend class base::RefCounted<TransportSocketParams>; 44 ~TransportSocketParams(); 45 46 void Initialize(RequestPriority priority, const GURL& referrer, 47 bool disable_resolver_cache); 48 49 HostResolver::RequestInfo destination_; 50 bool ignore_limits_; 51 #ifdef ANDROID 52 // Gets the UID of the calling process 53 bool valid_uid_; 54 int calling_uid_; 55 #endif 56 DISALLOW_COPY_AND_ASSIGN(TransportSocketParams); 57 }; 58 59 // TransportConnectJob handles the host resolution necessary for socket creation 60 // and the transport (likely TCP) connect. TransportConnectJob also has fallback 61 // logic for IPv6 connect() timeouts (which may happen due to networks / routers 62 // with broken IPv6 support). Those timeouts take 20s, so rather than make the 63 // user wait 20s for the timeout to fire, we use a fallback timer 64 // (kIPv6FallbackTimerInMs) and start a connect() to a IPv4 address if the timer 65 // fires. Then we race the IPv4 connect() against the IPv6 connect() (which has 66 // a headstart) and return the one that completes first to the socket pool. 67 class TransportConnectJob : public ConnectJob { 68 public: 69 TransportConnectJob(const std::string& group_name, 70 const scoped_refptr<TransportSocketParams>& params, 71 base::TimeDelta timeout_duration, 72 ClientSocketFactory* client_socket_factory, 73 HostResolver* host_resolver, 74 Delegate* delegate, 75 NetLog* net_log); 76 virtual ~TransportConnectJob(); 77 78 // ConnectJob methods. 79 virtual LoadState GetLoadState() const; 80 81 // Makes |addrlist| start with an IPv4 address if |addrlist| contains any 82 // IPv4 address. 83 // 84 // WARNING: this method should only be used to implement the prefer-IPv4 85 // hack. It is a public method for the unit tests. 86 static void MakeAddrListStartWithIPv4(AddressList* addrlist); 87 88 static const int kIPv6FallbackTimerInMs; 89 90 private: 91 enum State { 92 STATE_RESOLVE_HOST, 93 STATE_RESOLVE_HOST_COMPLETE, 94 STATE_TRANSPORT_CONNECT, 95 STATE_TRANSPORT_CONNECT_COMPLETE, 96 STATE_NONE, 97 }; 98 99 void OnIOComplete(int result); 100 101 // Runs the state transition loop. 102 int DoLoop(int result); 103 104 int DoResolveHost(); 105 int DoResolveHostComplete(int result); 106 int DoTransportConnect(); 107 int DoTransportConnectComplete(int result); 108 109 // Not part of the state machine. 110 void DoIPv6FallbackTransportConnect(); 111 void DoIPv6FallbackTransportConnectComplete(int result); 112 113 // Begins the host resolution and the TCP connect. Returns OK on success 114 // and ERR_IO_PENDING if it cannot immediately service the request. 115 // Otherwise, it returns a net error code. 116 virtual int ConnectInternal(); 117 118 scoped_refptr<TransportSocketParams> params_; 119 ClientSocketFactory* const client_socket_factory_; 120 CompletionCallbackImpl<TransportConnectJob> callback_; 121 SingleRequestHostResolver resolver_; 122 AddressList addresses_; 123 State next_state_; 124 125 // The time Connect() was called. 126 base::TimeTicks start_time_; 127 128 // The time the connect was started (after DNS finished). 129 base::TimeTicks connect_start_time_; 130 131 scoped_ptr<ClientSocket> transport_socket_; 132 133 scoped_ptr<ClientSocket> fallback_transport_socket_; 134 scoped_ptr<AddressList> fallback_addresses_; 135 CompletionCallbackImpl<TransportConnectJob> fallback_callback_; 136 base::TimeTicks fallback_connect_start_time_; 137 base::OneShotTimer<TransportConnectJob> fallback_timer_; 138 139 DISALLOW_COPY_AND_ASSIGN(TransportConnectJob); 140 }; 141 142 class TransportClientSocketPool : public ClientSocketPool { 143 public: 144 TransportClientSocketPool( 145 int max_sockets, 146 int max_sockets_per_group, 147 ClientSocketPoolHistograms* histograms, 148 HostResolver* host_resolver, 149 ClientSocketFactory* client_socket_factory, 150 NetLog* net_log); 151 152 virtual ~TransportClientSocketPool(); 153 154 // ClientSocketPool methods: 155 156 virtual int RequestSocket(const std::string& group_name, 157 const void* resolve_info, 158 RequestPriority priority, 159 ClientSocketHandle* handle, 160 CompletionCallback* callback, 161 const BoundNetLog& net_log); 162 163 virtual void RequestSockets(const std::string& group_name, 164 const void* params, 165 int num_sockets, 166 const BoundNetLog& net_log); 167 168 virtual void CancelRequest(const std::string& group_name, 169 ClientSocketHandle* handle); 170 171 virtual void ReleaseSocket(const std::string& group_name, 172 ClientSocket* socket, 173 int id); 174 175 virtual void Flush(); 176 177 virtual void CloseIdleSockets(); 178 179 virtual int IdleSocketCount() const; 180 181 virtual int IdleSocketCountInGroup(const std::string& group_name) const; 182 183 virtual LoadState GetLoadState(const std::string& group_name, 184 const ClientSocketHandle* handle) const; 185 186 virtual DictionaryValue* GetInfoAsValue(const std::string& name, 187 const std::string& type, 188 bool include_nested_pools) const; 189 190 virtual base::TimeDelta ConnectionTimeout() const; 191 192 virtual ClientSocketPoolHistograms* histograms() const; 193 194 private: 195 typedef ClientSocketPoolBase<TransportSocketParams> PoolBase; 196 197 class TransportConnectJobFactory 198 : public PoolBase::ConnectJobFactory { 199 public: 200 TransportConnectJobFactory(ClientSocketFactory* client_socket_factory, 201 HostResolver* host_resolver, 202 NetLog* net_log) 203 : client_socket_factory_(client_socket_factory), 204 host_resolver_(host_resolver), 205 net_log_(net_log) {} 206 207 virtual ~TransportConnectJobFactory() {} 208 209 // ClientSocketPoolBase::ConnectJobFactory methods. 210 211 virtual ConnectJob* NewConnectJob( 212 const std::string& group_name, 213 const PoolBase::Request& request, 214 ConnectJob::Delegate* delegate) const; 215 216 virtual base::TimeDelta ConnectionTimeout() const; 217 218 private: 219 ClientSocketFactory* const client_socket_factory_; 220 HostResolver* const host_resolver_; 221 NetLog* net_log_; 222 223 DISALLOW_COPY_AND_ASSIGN(TransportConnectJobFactory); 224 }; 225 226 PoolBase base_; 227 228 DISALLOW_COPY_AND_ASSIGN(TransportClientSocketPool); 229 }; 230 231 REGISTER_SOCKET_PARAMS_FOR_POOL(TransportClientSocketPool, 232 TransportSocketParams); 233 234 } // namespace net 235 236 #endif // NET_SOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_ 237