Home | History | Annotate | Download | only in socket
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #ifndef NET_SOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_
      6 #define NET_SOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_
      7 
      8 #include <string>
      9 
     10 #include "base/basictypes.h"
     11 #include "base/memory/ref_counted.h"
     12 #include "base/memory/scoped_ptr.h"
     13 #include "base/time/time.h"
     14 #include "base/timer/timer.h"
     15 #include "net/base/host_port_pair.h"
     16 #include "net/dns/host_resolver.h"
     17 #include "net/dns/single_request_host_resolver.h"
     18 #include "net/socket/client_socket_pool.h"
     19 #include "net/socket/client_socket_pool_base.h"
     20 #include "net/socket/client_socket_pool_histograms.h"
     21 
     22 namespace net {
     23 
     24 class ClientSocketFactory;
     25 
     26 typedef base::Callback<int(const AddressList&, const BoundNetLog& net_log)>
     27 OnHostResolutionCallback;
     28 
     29 class NET_EXPORT_PRIVATE TransportSocketParams
     30     : public base::RefCounted<TransportSocketParams> {
     31  public:
     32   // |host_resolution_callback| will be invoked after the the hostname is
     33   // resolved.  If |host_resolution_callback| does not return OK, then the
     34   // connection will be aborted with that value.
     35   TransportSocketParams(
     36       const HostPortPair& host_port_pair,
     37       RequestPriority priority,
     38       bool disable_resolver_cache,
     39       bool ignore_limits,
     40       const OnHostResolutionCallback& host_resolution_callback);
     41 
     42   const HostResolver::RequestInfo& destination() const { return destination_; }
     43   bool ignore_limits() const { return ignore_limits_; }
     44   const OnHostResolutionCallback& host_resolution_callback() const {
     45     return host_resolution_callback_;
     46   }
     47 
     48  private:
     49   friend class base::RefCounted<TransportSocketParams>;
     50   ~TransportSocketParams();
     51 
     52   void Initialize(RequestPriority priority, bool disable_resolver_cache);
     53 
     54   HostResolver::RequestInfo destination_;
     55   bool ignore_limits_;
     56   const OnHostResolutionCallback host_resolution_callback_;
     57 
     58   DISALLOW_COPY_AND_ASSIGN(TransportSocketParams);
     59 };
     60 
     61 // TransportConnectJob handles the host resolution necessary for socket creation
     62 // and the transport (likely TCP) connect. TransportConnectJob also has fallback
     63 // logic for IPv6 connect() timeouts (which may happen due to networks / routers
     64 // with broken IPv6 support). Those timeouts take 20s, so rather than make the
     65 // user wait 20s for the timeout to fire, we use a fallback timer
     66 // (kIPv6FallbackTimerInMs) and start a connect() to a IPv4 address if the timer
     67 // fires. Then we race the IPv4 connect() against the IPv6 connect() (which has
     68 // a headstart) and return the one that completes first to the socket pool.
     69 class NET_EXPORT_PRIVATE TransportConnectJob : public ConnectJob {
     70  public:
     71   TransportConnectJob(const std::string& group_name,
     72                       const scoped_refptr<TransportSocketParams>& params,
     73                       base::TimeDelta timeout_duration,
     74                       ClientSocketFactory* client_socket_factory,
     75                       HostResolver* host_resolver,
     76                       Delegate* delegate,
     77                       NetLog* net_log);
     78   virtual ~TransportConnectJob();
     79 
     80   // ConnectJob methods.
     81   virtual LoadState GetLoadState() const OVERRIDE;
     82 
     83   // Rolls |addrlist| forward until the first IPv4 address, if any.
     84   // WARNING: this method should only be used to implement the prefer-IPv4 hack.
     85   static void MakeAddressListStartWithIPv4(AddressList* addrlist);
     86 
     87   static const int kIPv6FallbackTimerInMs;
     88 
     89  private:
     90   enum State {
     91     STATE_RESOLVE_HOST,
     92     STATE_RESOLVE_HOST_COMPLETE,
     93     STATE_TRANSPORT_CONNECT,
     94     STATE_TRANSPORT_CONNECT_COMPLETE,
     95     STATE_NONE,
     96   };
     97 
     98   void OnIOComplete(int result);
     99 
    100   // Runs the state transition loop.
    101   int DoLoop(int result);
    102 
    103   int DoResolveHost();
    104   int DoResolveHostComplete(int result);
    105   int DoTransportConnect();
    106   int DoTransportConnectComplete(int result);
    107 
    108   // Not part of the state machine.
    109   void DoIPv6FallbackTransportConnect();
    110   void DoIPv6FallbackTransportConnectComplete(int result);
    111 
    112   // Begins the host resolution and the TCP connect.  Returns OK on success
    113   // and ERR_IO_PENDING if it cannot immediately service the request.
    114   // Otherwise, it returns a net error code.
    115   virtual int ConnectInternal() OVERRIDE;
    116 
    117   scoped_refptr<TransportSocketParams> params_;
    118   ClientSocketFactory* const client_socket_factory_;
    119   SingleRequestHostResolver resolver_;
    120   AddressList addresses_;
    121   State next_state_;
    122 
    123   scoped_ptr<StreamSocket> transport_socket_;
    124 
    125   scoped_ptr<StreamSocket> fallback_transport_socket_;
    126   scoped_ptr<AddressList> fallback_addresses_;
    127   base::TimeTicks fallback_connect_start_time_;
    128   base::OneShotTimer<TransportConnectJob> fallback_timer_;
    129 
    130   DISALLOW_COPY_AND_ASSIGN(TransportConnectJob);
    131 };
    132 
    133 class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool {
    134  public:
    135   TransportClientSocketPool(
    136       int max_sockets,
    137       int max_sockets_per_group,
    138       ClientSocketPoolHistograms* histograms,
    139       HostResolver* host_resolver,
    140       ClientSocketFactory* client_socket_factory,
    141       NetLog* net_log);
    142 
    143   virtual ~TransportClientSocketPool();
    144 
    145   // ClientSocketPool implementation.
    146   virtual int RequestSocket(const std::string& group_name,
    147                             const void* resolve_info,
    148                             RequestPriority priority,
    149                             ClientSocketHandle* handle,
    150                             const CompletionCallback& callback,
    151                             const BoundNetLog& net_log) OVERRIDE;
    152   virtual void RequestSockets(const std::string& group_name,
    153                               const void* params,
    154                               int num_sockets,
    155                               const BoundNetLog& net_log) OVERRIDE;
    156   virtual void CancelRequest(const std::string& group_name,
    157                              ClientSocketHandle* handle) OVERRIDE;
    158   virtual void ReleaseSocket(const std::string& group_name,
    159                              StreamSocket* socket,
    160                              int id) OVERRIDE;
    161   virtual void FlushWithError(int error) OVERRIDE;
    162   virtual bool IsStalled() const OVERRIDE;
    163   virtual void CloseIdleSockets() OVERRIDE;
    164   virtual int IdleSocketCount() const OVERRIDE;
    165   virtual int IdleSocketCountInGroup(
    166       const std::string& group_name) const OVERRIDE;
    167   virtual LoadState GetLoadState(
    168       const std::string& group_name,
    169       const ClientSocketHandle* handle) const OVERRIDE;
    170   virtual void AddLayeredPool(LayeredPool* layered_pool) OVERRIDE;
    171   virtual void RemoveLayeredPool(LayeredPool* layered_pool) OVERRIDE;
    172   virtual base::DictionaryValue* GetInfoAsValue(
    173       const std::string& name,
    174       const std::string& type,
    175       bool include_nested_pools) const OVERRIDE;
    176   virtual base::TimeDelta ConnectionTimeout() const OVERRIDE;
    177   virtual ClientSocketPoolHistograms* histograms() const OVERRIDE;
    178 
    179  private:
    180   typedef ClientSocketPoolBase<TransportSocketParams> PoolBase;
    181 
    182   class TransportConnectJobFactory
    183       : public PoolBase::ConnectJobFactory {
    184    public:
    185     TransportConnectJobFactory(ClientSocketFactory* client_socket_factory,
    186                          HostResolver* host_resolver,
    187                          NetLog* net_log)
    188         : client_socket_factory_(client_socket_factory),
    189           host_resolver_(host_resolver),
    190           net_log_(net_log) {}
    191 
    192     virtual ~TransportConnectJobFactory() {}
    193 
    194     // ClientSocketPoolBase::ConnectJobFactory methods.
    195 
    196     virtual ConnectJob* NewConnectJob(
    197         const std::string& group_name,
    198         const PoolBase::Request& request,
    199         ConnectJob::Delegate* delegate) const OVERRIDE;
    200 
    201     virtual base::TimeDelta ConnectionTimeout() const OVERRIDE;
    202 
    203    private:
    204     ClientSocketFactory* const client_socket_factory_;
    205     HostResolver* const host_resolver_;
    206     NetLog* net_log_;
    207 
    208     DISALLOW_COPY_AND_ASSIGN(TransportConnectJobFactory);
    209   };
    210 
    211   PoolBase base_;
    212 
    213   DISALLOW_COPY_AND_ASSIGN(TransportClientSocketPool);
    214 };
    215 
    216 REGISTER_SOCKET_PARAMS_FOR_POOL(TransportClientSocketPool,
    217                                 TransportSocketParams);
    218 
    219 }  // namespace net
    220 
    221 #endif  // NET_SOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_
    222