Home | History | Annotate | Download | only in socket
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/socket/socks_client_socket_pool.h"
      6 
      7 #include "base/bind.h"
      8 #include "base/bind_helpers.h"
      9 #include "base/time/time.h"
     10 #include "base/values.h"
     11 #include "net/base/net_errors.h"
     12 #include "net/socket/client_socket_factory.h"
     13 #include "net/socket/client_socket_handle.h"
     14 #include "net/socket/client_socket_pool_base.h"
     15 #include "net/socket/socks5_client_socket.h"
     16 #include "net/socket/socks_client_socket.h"
     17 #include "net/socket/transport_client_socket_pool.h"
     18 
     19 namespace net {
     20 
     21 SOCKSSocketParams::SOCKSSocketParams(
     22     const scoped_refptr<TransportSocketParams>& proxy_server,
     23     bool socks_v5,
     24     const HostPortPair& host_port_pair)
     25     : transport_params_(proxy_server),
     26       destination_(host_port_pair),
     27       socks_v5_(socks_v5) {
     28   if (transport_params_.get())
     29     ignore_limits_ = transport_params_->ignore_limits();
     30   else
     31     ignore_limits_ = false;
     32 }
     33 
     34 SOCKSSocketParams::~SOCKSSocketParams() {}
     35 
     36 // SOCKSConnectJobs will time out after this many seconds.  Note this is on
     37 // top of the timeout for the transport socket.
     38 static const int kSOCKSConnectJobTimeoutInSeconds = 30;
     39 
     40 SOCKSConnectJob::SOCKSConnectJob(
     41     const std::string& group_name,
     42     RequestPriority priority,
     43     const scoped_refptr<SOCKSSocketParams>& socks_params,
     44     const base::TimeDelta& timeout_duration,
     45     TransportClientSocketPool* transport_pool,
     46     HostResolver* host_resolver,
     47     Delegate* delegate,
     48     NetLog* net_log)
     49     : ConnectJob(group_name, timeout_duration, priority, delegate,
     50                  BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)),
     51       socks_params_(socks_params),
     52       transport_pool_(transport_pool),
     53       resolver_(host_resolver),
     54       callback_(base::Bind(&SOCKSConnectJob::OnIOComplete,
     55                            base::Unretained(this))) {
     56 }
     57 
     58 SOCKSConnectJob::~SOCKSConnectJob() {
     59   // We don't worry about cancelling the tcp socket since the destructor in
     60   // scoped_ptr<ClientSocketHandle> transport_socket_handle_ will take care of
     61   // it.
     62 }
     63 
     64 LoadState SOCKSConnectJob::GetLoadState() const {
     65   switch (next_state_) {
     66     case STATE_TRANSPORT_CONNECT:
     67     case STATE_TRANSPORT_CONNECT_COMPLETE:
     68       return transport_socket_handle_->GetLoadState();
     69     case STATE_SOCKS_CONNECT:
     70     case STATE_SOCKS_CONNECT_COMPLETE:
     71       return LOAD_STATE_CONNECTING;
     72     default:
     73       NOTREACHED();
     74       return LOAD_STATE_IDLE;
     75   }
     76 }
     77 
     78 void SOCKSConnectJob::OnIOComplete(int result) {
     79   int rv = DoLoop(result);
     80   if (rv != ERR_IO_PENDING)
     81     NotifyDelegateOfCompletion(rv);  // Deletes |this|
     82 }
     83 
     84 int SOCKSConnectJob::DoLoop(int result) {
     85   DCHECK_NE(next_state_, STATE_NONE);
     86 
     87   int rv = result;
     88   do {
     89     State state = next_state_;
     90     next_state_ = STATE_NONE;
     91     switch (state) {
     92       case STATE_TRANSPORT_CONNECT:
     93         DCHECK_EQ(OK, rv);
     94         rv = DoTransportConnect();
     95         break;
     96       case STATE_TRANSPORT_CONNECT_COMPLETE:
     97         rv = DoTransportConnectComplete(rv);
     98         break;
     99       case STATE_SOCKS_CONNECT:
    100         DCHECK_EQ(OK, rv);
    101         rv = DoSOCKSConnect();
    102         break;
    103       case STATE_SOCKS_CONNECT_COMPLETE:
    104         rv = DoSOCKSConnectComplete(rv);
    105         break;
    106       default:
    107         NOTREACHED() << "bad state";
    108         rv = ERR_FAILED;
    109         break;
    110     }
    111   } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
    112 
    113   return rv;
    114 }
    115 
    116 int SOCKSConnectJob::DoTransportConnect() {
    117   next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE;
    118   transport_socket_handle_.reset(new ClientSocketHandle());
    119   return transport_socket_handle_->Init(group_name(),
    120                                         socks_params_->transport_params(),
    121                                         priority(),
    122                                         callback_,
    123                                         transport_pool_,
    124                                         net_log());
    125 }
    126 
    127 int SOCKSConnectJob::DoTransportConnectComplete(int result) {
    128   if (result != OK)
    129     return ERR_PROXY_CONNECTION_FAILED;
    130 
    131   // Reset the timer to just the length of time allowed for SOCKS handshake
    132   // so that a fast TCP connection plus a slow SOCKS failure doesn't take
    133   // longer to timeout than it should.
    134   ResetTimer(base::TimeDelta::FromSeconds(kSOCKSConnectJobTimeoutInSeconds));
    135   next_state_ = STATE_SOCKS_CONNECT;
    136   return result;
    137 }
    138 
    139 int SOCKSConnectJob::DoSOCKSConnect() {
    140   next_state_ = STATE_SOCKS_CONNECT_COMPLETE;
    141 
    142   // Add a SOCKS connection on top of the tcp socket.
    143   if (socks_params_->is_socks_v5()) {
    144     socket_.reset(new SOCKS5ClientSocket(transport_socket_handle_.Pass(),
    145                                          socks_params_->destination()));
    146   } else {
    147     socket_.reset(new SOCKSClientSocket(transport_socket_handle_.Pass(),
    148                                         socks_params_->destination(),
    149                                         priority(),
    150                                         resolver_));
    151   }
    152   return socket_->Connect(
    153       base::Bind(&SOCKSConnectJob::OnIOComplete, base::Unretained(this)));
    154 }
    155 
    156 int SOCKSConnectJob::DoSOCKSConnectComplete(int result) {
    157   if (result != OK) {
    158     socket_->Disconnect();
    159     return result;
    160   }
    161 
    162   SetSocket(socket_.Pass());
    163   return result;
    164 }
    165 
    166 int SOCKSConnectJob::ConnectInternal() {
    167   next_state_ = STATE_TRANSPORT_CONNECT;
    168   return DoLoop(OK);
    169 }
    170 
    171 scoped_ptr<ConnectJob>
    172 SOCKSClientSocketPool::SOCKSConnectJobFactory::NewConnectJob(
    173     const std::string& group_name,
    174     const PoolBase::Request& request,
    175     ConnectJob::Delegate* delegate) const {
    176   return scoped_ptr<ConnectJob>(new SOCKSConnectJob(group_name,
    177                                                     request.priority(),
    178                                                     request.params(),
    179                                                     ConnectionTimeout(),
    180                                                     transport_pool_,
    181                                                     host_resolver_,
    182                                                     delegate,
    183                                                     net_log_));
    184 }
    185 
    186 base::TimeDelta
    187 SOCKSClientSocketPool::SOCKSConnectJobFactory::ConnectionTimeout() const {
    188   return transport_pool_->ConnectionTimeout() +
    189       base::TimeDelta::FromSeconds(kSOCKSConnectJobTimeoutInSeconds);
    190 }
    191 
    192 SOCKSClientSocketPool::SOCKSClientSocketPool(
    193     int max_sockets,
    194     int max_sockets_per_group,
    195     ClientSocketPoolHistograms* histograms,
    196     HostResolver* host_resolver,
    197     TransportClientSocketPool* transport_pool,
    198     NetLog* net_log)
    199     : transport_pool_(transport_pool),
    200       base_(this, max_sockets, max_sockets_per_group, histograms,
    201             ClientSocketPool::unused_idle_socket_timeout(),
    202             ClientSocketPool::used_idle_socket_timeout(),
    203             new SOCKSConnectJobFactory(transport_pool,
    204                                        host_resolver,
    205                                        net_log)) {
    206   // We should always have a |transport_pool_| except in unit tests.
    207   if (transport_pool_)
    208     base_.AddLowerLayeredPool(transport_pool_);
    209 }
    210 
    211 SOCKSClientSocketPool::~SOCKSClientSocketPool() {
    212 }
    213 
    214 int SOCKSClientSocketPool::RequestSocket(
    215     const std::string& group_name, const void* socket_params,
    216     RequestPriority priority, ClientSocketHandle* handle,
    217     const CompletionCallback& callback, const BoundNetLog& net_log) {
    218   const scoped_refptr<SOCKSSocketParams>* casted_socket_params =
    219       static_cast<const scoped_refptr<SOCKSSocketParams>*>(socket_params);
    220 
    221   return base_.RequestSocket(group_name, *casted_socket_params, priority,
    222                              handle, callback, net_log);
    223 }
    224 
    225 void SOCKSClientSocketPool::RequestSockets(
    226     const std::string& group_name,
    227     const void* params,
    228     int num_sockets,
    229     const BoundNetLog& net_log) {
    230   const scoped_refptr<SOCKSSocketParams>* casted_params =
    231       static_cast<const scoped_refptr<SOCKSSocketParams>*>(params);
    232 
    233   base_.RequestSockets(group_name, *casted_params, num_sockets, net_log);
    234 }
    235 
    236 void SOCKSClientSocketPool::CancelRequest(const std::string& group_name,
    237                                           ClientSocketHandle* handle) {
    238   base_.CancelRequest(group_name, handle);
    239 }
    240 
    241 void SOCKSClientSocketPool::ReleaseSocket(const std::string& group_name,
    242                                           scoped_ptr<StreamSocket> socket,
    243                                           int id) {
    244   base_.ReleaseSocket(group_name, socket.Pass(), id);
    245 }
    246 
    247 void SOCKSClientSocketPool::FlushWithError(int error) {
    248   base_.FlushWithError(error);
    249 }
    250 
    251 void SOCKSClientSocketPool::CloseIdleSockets() {
    252   base_.CloseIdleSockets();
    253 }
    254 
    255 int SOCKSClientSocketPool::IdleSocketCount() const {
    256   return base_.idle_socket_count();
    257 }
    258 
    259 int SOCKSClientSocketPool::IdleSocketCountInGroup(
    260     const std::string& group_name) const {
    261   return base_.IdleSocketCountInGroup(group_name);
    262 }
    263 
    264 LoadState SOCKSClientSocketPool::GetLoadState(
    265     const std::string& group_name, const ClientSocketHandle* handle) const {
    266   return base_.GetLoadState(group_name, handle);
    267 }
    268 
    269 base::DictionaryValue* SOCKSClientSocketPool::GetInfoAsValue(
    270     const std::string& name,
    271     const std::string& type,
    272     bool include_nested_pools) const {
    273   base::DictionaryValue* dict = base_.GetInfoAsValue(name, type);
    274   if (include_nested_pools) {
    275     base::ListValue* list = new base::ListValue();
    276     list->Append(transport_pool_->GetInfoAsValue("transport_socket_pool",
    277                                                  "transport_socket_pool",
    278                                                  false));
    279     dict->Set("nested_pools", list);
    280   }
    281   return dict;
    282 }
    283 
    284 base::TimeDelta SOCKSClientSocketPool::ConnectionTimeout() const {
    285   return base_.ConnectionTimeout();
    286 }
    287 
    288 ClientSocketPoolHistograms* SOCKSClientSocketPool::histograms() const {
    289   return base_.histograms();
    290 };
    291 
    292 bool SOCKSClientSocketPool::IsStalled() const {
    293   return base_.IsStalled();
    294 }
    295 
    296 void SOCKSClientSocketPool::AddHigherLayeredPool(
    297     HigherLayeredPool* higher_pool) {
    298   base_.AddHigherLayeredPool(higher_pool);
    299 }
    300 
    301 void SOCKSClientSocketPool::RemoveHigherLayeredPool(
    302     HigherLayeredPool* higher_pool) {
    303   base_.RemoveHigherLayeredPool(higher_pool);
    304 }
    305 
    306 bool SOCKSClientSocketPool::CloseOneIdleConnection() {
    307   if (base_.CloseOneIdleSocket())
    308     return true;
    309   return base_.CloseOneIdleConnectionInHigherLayeredPool();
    310 }
    311 
    312 }  // namespace net
    313