Home | History | Annotate | Download | only in socket
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/socket/socks_client_socket_pool.h"
      6 
      7 #include "base/bind.h"
      8 #include "base/bind_helpers.h"
      9 #include "base/time/time.h"
     10 #include "base/values.h"
     11 #include "net/base/net_errors.h"
     12 #include "net/socket/client_socket_factory.h"
     13 #include "net/socket/client_socket_handle.h"
     14 #include "net/socket/client_socket_pool_base.h"
     15 #include "net/socket/socks5_client_socket.h"
     16 #include "net/socket/socks_client_socket.h"
     17 #include "net/socket/transport_client_socket_pool.h"
     18 
     19 namespace net {
     20 
     21 SOCKSSocketParams::SOCKSSocketParams(
     22     const scoped_refptr<TransportSocketParams>& proxy_server,
     23     bool socks_v5,
     24     const HostPortPair& host_port_pair,
     25     RequestPriority priority)
     26     : transport_params_(proxy_server),
     27       destination_(host_port_pair),
     28       socks_v5_(socks_v5) {
     29   if (transport_params_.get())
     30     ignore_limits_ = transport_params_->ignore_limits();
     31   else
     32     ignore_limits_ = false;
     33   destination_.set_priority(priority);
     34 }
     35 
     36 SOCKSSocketParams::~SOCKSSocketParams() {}
     37 
     38 // SOCKSConnectJobs will time out after this many seconds.  Note this is on
     39 // top of the timeout for the transport socket.
     40 static const int kSOCKSConnectJobTimeoutInSeconds = 30;
     41 
     42 SOCKSConnectJob::SOCKSConnectJob(
     43     const std::string& group_name,
     44     const scoped_refptr<SOCKSSocketParams>& socks_params,
     45     const base::TimeDelta& timeout_duration,
     46     TransportClientSocketPool* transport_pool,
     47     HostResolver* host_resolver,
     48     Delegate* delegate,
     49     NetLog* net_log)
     50     : ConnectJob(group_name, timeout_duration, delegate,
     51                  BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)),
     52       socks_params_(socks_params),
     53       transport_pool_(transport_pool),
     54       resolver_(host_resolver),
     55       callback_(base::Bind(&SOCKSConnectJob::OnIOComplete,
     56                            base::Unretained(this))) {
     57 }
     58 
     59 SOCKSConnectJob::~SOCKSConnectJob() {
     60   // We don't worry about cancelling the tcp socket since the destructor in
     61   // scoped_ptr<ClientSocketHandle> transport_socket_handle_ will take care of
     62   // it.
     63 }
     64 
     65 LoadState SOCKSConnectJob::GetLoadState() const {
     66   switch (next_state_) {
     67     case STATE_TRANSPORT_CONNECT:
     68     case STATE_TRANSPORT_CONNECT_COMPLETE:
     69       return transport_socket_handle_->GetLoadState();
     70     case STATE_SOCKS_CONNECT:
     71     case STATE_SOCKS_CONNECT_COMPLETE:
     72       return LOAD_STATE_CONNECTING;
     73     default:
     74       NOTREACHED();
     75       return LOAD_STATE_IDLE;
     76   }
     77 }
     78 
     79 void SOCKSConnectJob::OnIOComplete(int result) {
     80   int rv = DoLoop(result);
     81   if (rv != ERR_IO_PENDING)
     82     NotifyDelegateOfCompletion(rv);  // Deletes |this|
     83 }
     84 
     85 int SOCKSConnectJob::DoLoop(int result) {
     86   DCHECK_NE(next_state_, STATE_NONE);
     87 
     88   int rv = result;
     89   do {
     90     State state = next_state_;
     91     next_state_ = STATE_NONE;
     92     switch (state) {
     93       case STATE_TRANSPORT_CONNECT:
     94         DCHECK_EQ(OK, rv);
     95         rv = DoTransportConnect();
     96         break;
     97       case STATE_TRANSPORT_CONNECT_COMPLETE:
     98         rv = DoTransportConnectComplete(rv);
     99         break;
    100       case STATE_SOCKS_CONNECT:
    101         DCHECK_EQ(OK, rv);
    102         rv = DoSOCKSConnect();
    103         break;
    104       case STATE_SOCKS_CONNECT_COMPLETE:
    105         rv = DoSOCKSConnectComplete(rv);
    106         break;
    107       default:
    108         NOTREACHED() << "bad state";
    109         rv = ERR_FAILED;
    110         break;
    111     }
    112   } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
    113 
    114   return rv;
    115 }
    116 
    117 int SOCKSConnectJob::DoTransportConnect() {
    118   next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE;
    119   transport_socket_handle_.reset(new ClientSocketHandle());
    120   return transport_socket_handle_->Init(
    121       group_name(), socks_params_->transport_params(),
    122       socks_params_->destination().priority(), callback_, transport_pool_,
    123       net_log());
    124 }
    125 
    126 int SOCKSConnectJob::DoTransportConnectComplete(int result) {
    127   if (result != OK)
    128     return ERR_PROXY_CONNECTION_FAILED;
    129 
    130   // Reset the timer to just the length of time allowed for SOCKS handshake
    131   // so that a fast TCP connection plus a slow SOCKS failure doesn't take
    132   // longer to timeout than it should.
    133   ResetTimer(base::TimeDelta::FromSeconds(kSOCKSConnectJobTimeoutInSeconds));
    134   next_state_ = STATE_SOCKS_CONNECT;
    135   return result;
    136 }
    137 
    138 int SOCKSConnectJob::DoSOCKSConnect() {
    139   next_state_ = STATE_SOCKS_CONNECT_COMPLETE;
    140 
    141   // Add a SOCKS connection on top of the tcp socket.
    142   if (socks_params_->is_socks_v5()) {
    143     socket_.reset(new SOCKS5ClientSocket(transport_socket_handle_.release(),
    144                                          socks_params_->destination()));
    145   } else {
    146     socket_.reset(new SOCKSClientSocket(transport_socket_handle_.release(),
    147                                         socks_params_->destination(),
    148                                         resolver_));
    149   }
    150   return socket_->Connect(
    151       base::Bind(&SOCKSConnectJob::OnIOComplete, base::Unretained(this)));
    152 }
    153 
    154 int SOCKSConnectJob::DoSOCKSConnectComplete(int result) {
    155   if (result != OK) {
    156     socket_->Disconnect();
    157     return result;
    158   }
    159 
    160   set_socket(socket_.release());
    161   return result;
    162 }
    163 
    164 int SOCKSConnectJob::ConnectInternal() {
    165   next_state_ = STATE_TRANSPORT_CONNECT;
    166   return DoLoop(OK);
    167 }
    168 
    169 ConnectJob* SOCKSClientSocketPool::SOCKSConnectJobFactory::NewConnectJob(
    170     const std::string& group_name,
    171     const PoolBase::Request& request,
    172     ConnectJob::Delegate* delegate) const {
    173   return new SOCKSConnectJob(group_name,
    174                              request.params(),
    175                              ConnectionTimeout(),
    176                              transport_pool_,
    177                              host_resolver_,
    178                              delegate,
    179                              net_log_);
    180 }
    181 
    182 base::TimeDelta
    183 SOCKSClientSocketPool::SOCKSConnectJobFactory::ConnectionTimeout() const {
    184   return transport_pool_->ConnectionTimeout() +
    185       base::TimeDelta::FromSeconds(kSOCKSConnectJobTimeoutInSeconds);
    186 }
    187 
    188 SOCKSClientSocketPool::SOCKSClientSocketPool(
    189     int max_sockets,
    190     int max_sockets_per_group,
    191     ClientSocketPoolHistograms* histograms,
    192     HostResolver* host_resolver,
    193     TransportClientSocketPool* transport_pool,
    194     NetLog* net_log)
    195     : transport_pool_(transport_pool),
    196       base_(max_sockets, max_sockets_per_group, histograms,
    197             ClientSocketPool::unused_idle_socket_timeout(),
    198             ClientSocketPool::used_idle_socket_timeout(),
    199             new SOCKSConnectJobFactory(transport_pool,
    200                                        host_resolver,
    201                                        net_log)) {
    202   // We should always have a |transport_pool_| except in unit tests.
    203   if (transport_pool_)
    204     transport_pool_->AddLayeredPool(this);
    205 }
    206 
    207 SOCKSClientSocketPool::~SOCKSClientSocketPool() {
    208   // We should always have a |transport_pool_| except in unit tests.
    209   if (transport_pool_)
    210     transport_pool_->RemoveLayeredPool(this);
    211 }
    212 
    213 int SOCKSClientSocketPool::RequestSocket(
    214     const std::string& group_name, const void* socket_params,
    215     RequestPriority priority, ClientSocketHandle* handle,
    216     const CompletionCallback& callback, const BoundNetLog& net_log) {
    217   const scoped_refptr<SOCKSSocketParams>* casted_socket_params =
    218       static_cast<const scoped_refptr<SOCKSSocketParams>*>(socket_params);
    219 
    220   return base_.RequestSocket(group_name, *casted_socket_params, priority,
    221                              handle, callback, net_log);
    222 }
    223 
    224 void SOCKSClientSocketPool::RequestSockets(
    225     const std::string& group_name,
    226     const void* params,
    227     int num_sockets,
    228     const BoundNetLog& net_log) {
    229   const scoped_refptr<SOCKSSocketParams>* casted_params =
    230       static_cast<const scoped_refptr<SOCKSSocketParams>*>(params);
    231 
    232   base_.RequestSockets(group_name, *casted_params, num_sockets, net_log);
    233 }
    234 
    235 void SOCKSClientSocketPool::CancelRequest(const std::string& group_name,
    236                                           ClientSocketHandle* handle) {
    237   base_.CancelRequest(group_name, handle);
    238 }
    239 
    240 void SOCKSClientSocketPool::ReleaseSocket(const std::string& group_name,
    241                                           StreamSocket* socket, int id) {
    242   base_.ReleaseSocket(group_name, socket, id);
    243 }
    244 
    245 void SOCKSClientSocketPool::FlushWithError(int error) {
    246   base_.FlushWithError(error);
    247 }
    248 
    249 bool SOCKSClientSocketPool::IsStalled() const {
    250   return base_.IsStalled() || transport_pool_->IsStalled();
    251 }
    252 
    253 void SOCKSClientSocketPool::CloseIdleSockets() {
    254   base_.CloseIdleSockets();
    255 }
    256 
    257 int SOCKSClientSocketPool::IdleSocketCount() const {
    258   return base_.idle_socket_count();
    259 }
    260 
    261 int SOCKSClientSocketPool::IdleSocketCountInGroup(
    262     const std::string& group_name) const {
    263   return base_.IdleSocketCountInGroup(group_name);
    264 }
    265 
    266 LoadState SOCKSClientSocketPool::GetLoadState(
    267     const std::string& group_name, const ClientSocketHandle* handle) const {
    268   return base_.GetLoadState(group_name, handle);
    269 }
    270 
    271 void SOCKSClientSocketPool::AddLayeredPool(LayeredPool* layered_pool) {
    272   base_.AddLayeredPool(layered_pool);
    273 }
    274 
    275 void SOCKSClientSocketPool::RemoveLayeredPool(LayeredPool* layered_pool) {
    276   base_.RemoveLayeredPool(layered_pool);
    277 }
    278 
    279 base::DictionaryValue* SOCKSClientSocketPool::GetInfoAsValue(
    280     const std::string& name,
    281     const std::string& type,
    282     bool include_nested_pools) const {
    283   base::DictionaryValue* dict = base_.GetInfoAsValue(name, type);
    284   if (include_nested_pools) {
    285     base::ListValue* list = new base::ListValue();
    286     list->Append(transport_pool_->GetInfoAsValue("transport_socket_pool",
    287                                                  "transport_socket_pool",
    288                                                  false));
    289     dict->Set("nested_pools", list);
    290   }
    291   return dict;
    292 }
    293 
    294 base::TimeDelta SOCKSClientSocketPool::ConnectionTimeout() const {
    295   return base_.ConnectionTimeout();
    296 }
    297 
    298 ClientSocketPoolHistograms* SOCKSClientSocketPool::histograms() const {
    299   return base_.histograms();
    300 };
    301 
    302 bool SOCKSClientSocketPool::CloseOneIdleConnection() {
    303   if (base_.CloseOneIdleSocket())
    304     return true;
    305   return base_.CloseOneIdleConnectionInLayeredPool();
    306 }
    307 
    308 }  // namespace net
    309