Home | History | Annotate | Download | only in socket
      1 // Copyright (c) 2011 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/socket/socks_client_socket_pool.h"
      6 
      7 #include "base/time.h"
      8 #include "base/values.h"
      9 #include "googleurl/src/gurl.h"
     10 #include "net/base/net_errors.h"
     11 #include "net/socket/client_socket_factory.h"
     12 #include "net/socket/client_socket_handle.h"
     13 #include "net/socket/client_socket_pool_base.h"
     14 #include "net/socket/socks5_client_socket.h"
     15 #include "net/socket/socks_client_socket.h"
     16 #include "net/socket/transport_client_socket_pool.h"
     17 
     18 namespace net {
     19 
     20 SOCKSSocketParams::SOCKSSocketParams(
     21     const scoped_refptr<TransportSocketParams>& proxy_server,
     22     bool socks_v5,
     23     const HostPortPair& host_port_pair,
     24     RequestPriority priority,
     25     const GURL& referrer)
     26     : transport_params_(proxy_server),
     27       destination_(host_port_pair),
     28       socks_v5_(socks_v5) {
     29   if (transport_params_)
     30     ignore_limits_ = transport_params_->ignore_limits();
     31   else
     32     ignore_limits_ = false;
     33   // The referrer is used by the DNS prefetch system to correlate resolutions
     34   // with the page that triggered them. It doesn't impact the actual addresses
     35   // that we resolve to.
     36   destination_.set_referrer(referrer);
     37   destination_.set_priority(priority);
     38 }
     39 
     40 #ifdef ANDROID
     41 bool SOCKSSocketParams::getUID(uid_t *uid) const {
     42   if (transport_params_)
     43     return transport_params_->getUID(uid);
     44   else
     45     return false;
     46 }
     47 
     48 void SOCKSSocketParams::setUID(uid_t uid) {
     49   if (transport_params_)
     50     return transport_params_->setUID(uid);
     51 }
     52 #endif
     53 
     54 SOCKSSocketParams::~SOCKSSocketParams() {}
     55 
     56 // SOCKSConnectJobs will time out after this many seconds.  Note this is on
     57 // top of the timeout for the transport socket.
     58 static const int kSOCKSConnectJobTimeoutInSeconds = 30;
     59 
     60 SOCKSConnectJob::SOCKSConnectJob(
     61     const std::string& group_name,
     62     const scoped_refptr<SOCKSSocketParams>& socks_params,
     63     const base::TimeDelta& timeout_duration,
     64     TransportClientSocketPool* transport_pool,
     65     HostResolver* host_resolver,
     66     Delegate* delegate,
     67     NetLog* net_log)
     68     : ConnectJob(group_name, timeout_duration, delegate,
     69                  BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)),
     70       socks_params_(socks_params),
     71       transport_pool_(transport_pool),
     72       resolver_(host_resolver),
     73       ALLOW_THIS_IN_INITIALIZER_LIST(
     74           callback_(this, &SOCKSConnectJob::OnIOComplete)) {
     75 }
     76 
     77 SOCKSConnectJob::~SOCKSConnectJob() {
     78   // We don't worry about cancelling the tcp socket since the destructor in
     79   // scoped_ptr<ClientSocketHandle> transport_socket_handle_ will take care of
     80   // it.
     81 }
     82 
     83 LoadState SOCKSConnectJob::GetLoadState() const {
     84   switch (next_state_) {
     85     case STATE_TRANSPORT_CONNECT:
     86     case STATE_TRANSPORT_CONNECT_COMPLETE:
     87       return transport_socket_handle_->GetLoadState();
     88     case STATE_SOCKS_CONNECT:
     89     case STATE_SOCKS_CONNECT_COMPLETE:
     90       return LOAD_STATE_CONNECTING;
     91     default:
     92       NOTREACHED();
     93       return LOAD_STATE_IDLE;
     94   }
     95 }
     96 
     97 void SOCKSConnectJob::OnIOComplete(int result) {
     98   int rv = DoLoop(result);
     99   if (rv != ERR_IO_PENDING)
    100     NotifyDelegateOfCompletion(rv);  // Deletes |this|
    101 }
    102 
    103 int SOCKSConnectJob::DoLoop(int result) {
    104   DCHECK_NE(next_state_, STATE_NONE);
    105 
    106   int rv = result;
    107   do {
    108     State state = next_state_;
    109     next_state_ = STATE_NONE;
    110     switch (state) {
    111       case STATE_TRANSPORT_CONNECT:
    112         DCHECK_EQ(OK, rv);
    113         rv = DoTransportConnect();
    114         break;
    115       case STATE_TRANSPORT_CONNECT_COMPLETE:
    116         rv = DoTransportConnectComplete(rv);
    117         break;
    118       case STATE_SOCKS_CONNECT:
    119         DCHECK_EQ(OK, rv);
    120         rv = DoSOCKSConnect();
    121         break;
    122       case STATE_SOCKS_CONNECT_COMPLETE:
    123         rv = DoSOCKSConnectComplete(rv);
    124         break;
    125       default:
    126         NOTREACHED() << "bad state";
    127         rv = ERR_FAILED;
    128         break;
    129     }
    130   } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
    131 
    132   return rv;
    133 }
    134 
    135 int SOCKSConnectJob::DoTransportConnect() {
    136   next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE;
    137   transport_socket_handle_.reset(new ClientSocketHandle());
    138   return transport_socket_handle_->Init(group_name(),
    139                                         socks_params_->transport_params(),
    140                                         socks_params_->destination().priority(),
    141                                         &callback_,
    142                                         transport_pool_,
    143                                         net_log());
    144 }
    145 
    146 int SOCKSConnectJob::DoTransportConnectComplete(int result) {
    147   if (result != OK)
    148     return ERR_PROXY_CONNECTION_FAILED;
    149 
    150   // Reset the timer to just the length of time allowed for SOCKS handshake
    151   // so that a fast TCP connection plus a slow SOCKS failure doesn't take
    152   // longer to timeout than it should.
    153   ResetTimer(base::TimeDelta::FromSeconds(kSOCKSConnectJobTimeoutInSeconds));
    154   next_state_ = STATE_SOCKS_CONNECT;
    155   return result;
    156 }
    157 
    158 int SOCKSConnectJob::DoSOCKSConnect() {
    159   next_state_ = STATE_SOCKS_CONNECT_COMPLETE;
    160 
    161   // Add a SOCKS connection on top of the tcp socket.
    162   if (socks_params_->is_socks_v5()) {
    163     socket_.reset(new SOCKS5ClientSocket(transport_socket_handle_.release(),
    164                                          socks_params_->destination()));
    165   } else {
    166     socket_.reset(new SOCKSClientSocket(transport_socket_handle_.release(),
    167                                         socks_params_->destination(),
    168                                         resolver_));
    169   }
    170 
    171 #ifdef ANDROID
    172   uid_t calling_uid = 0;
    173   bool valid_uid = socks_params_->transport_params()->getUID(&calling_uid);
    174 #endif
    175 
    176   return socket_->Connect(&callback_
    177 #ifdef ANDROID
    178                           , socks_params_->ignore_limits()
    179                           , valid_uid
    180                           , calling_uid
    181 #endif
    182                          );
    183 }
    184 
    185 int SOCKSConnectJob::DoSOCKSConnectComplete(int result) {
    186   if (result != OK) {
    187     socket_->Disconnect();
    188     return result;
    189   }
    190 
    191   set_socket(socket_.release());
    192   return result;
    193 }
    194 
    195 int SOCKSConnectJob::ConnectInternal() {
    196   next_state_ = STATE_TRANSPORT_CONNECT;
    197   return DoLoop(OK);
    198 }
    199 
    200 ConnectJob* SOCKSClientSocketPool::SOCKSConnectJobFactory::NewConnectJob(
    201     const std::string& group_name,
    202     const PoolBase::Request& request,
    203     ConnectJob::Delegate* delegate) const {
    204   return new SOCKSConnectJob(group_name,
    205                              request.params(),
    206                              ConnectionTimeout(),
    207                              transport_pool_,
    208                              host_resolver_,
    209                              delegate,
    210                              net_log_);
    211 }
    212 
    213 base::TimeDelta
    214 SOCKSClientSocketPool::SOCKSConnectJobFactory::ConnectionTimeout() const {
    215   return transport_pool_->ConnectionTimeout() +
    216       base::TimeDelta::FromSeconds(kSOCKSConnectJobTimeoutInSeconds);
    217 }
    218 
    219 SOCKSClientSocketPool::SOCKSClientSocketPool(
    220     int max_sockets,
    221     int max_sockets_per_group,
    222     ClientSocketPoolHistograms* histograms,
    223     HostResolver* host_resolver,
    224     TransportClientSocketPool* transport_pool,
    225     NetLog* net_log)
    226     : transport_pool_(transport_pool),
    227       base_(max_sockets, max_sockets_per_group, histograms,
    228             base::TimeDelta::FromSeconds(
    229                 ClientSocketPool::unused_idle_socket_timeout()),
    230             base::TimeDelta::FromSeconds(kUsedIdleSocketTimeout),
    231             new SOCKSConnectJobFactory(transport_pool,
    232                                        host_resolver,
    233                                        net_log)) {
    234 }
    235 
    236 SOCKSClientSocketPool::~SOCKSClientSocketPool() {}
    237 
    238 int SOCKSClientSocketPool::RequestSocket(const std::string& group_name,
    239                                          const void* socket_params,
    240                                          RequestPriority priority,
    241                                          ClientSocketHandle* handle,
    242                                          CompletionCallback* callback,
    243                                          const BoundNetLog& net_log) {
    244   const scoped_refptr<SOCKSSocketParams>* casted_socket_params =
    245       static_cast<const scoped_refptr<SOCKSSocketParams>*>(socket_params);
    246 
    247   return base_.RequestSocket(group_name, *casted_socket_params, priority,
    248                              handle, callback, net_log);
    249 }
    250 
    251 void SOCKSClientSocketPool::RequestSockets(
    252     const std::string& group_name,
    253     const void* params,
    254     int num_sockets,
    255     const BoundNetLog& net_log) {
    256   const scoped_refptr<SOCKSSocketParams>* casted_params =
    257       static_cast<const scoped_refptr<SOCKSSocketParams>*>(params);
    258 
    259   base_.RequestSockets(group_name, *casted_params, num_sockets, net_log);
    260 }
    261 
    262 void SOCKSClientSocketPool::CancelRequest(const std::string& group_name,
    263                                           ClientSocketHandle* handle) {
    264   base_.CancelRequest(group_name, handle);
    265 }
    266 
    267 void SOCKSClientSocketPool::ReleaseSocket(const std::string& group_name,
    268                                           ClientSocket* socket, int id) {
    269   base_.ReleaseSocket(group_name, socket, id);
    270 }
    271 
    272 void SOCKSClientSocketPool::Flush() {
    273   base_.Flush();
    274 }
    275 
    276 void SOCKSClientSocketPool::CloseIdleSockets() {
    277   base_.CloseIdleSockets();
    278 }
    279 
    280 int SOCKSClientSocketPool::IdleSocketCount() const {
    281   return base_.idle_socket_count();
    282 }
    283 
    284 int SOCKSClientSocketPool::IdleSocketCountInGroup(
    285     const std::string& group_name) const {
    286   return base_.IdleSocketCountInGroup(group_name);
    287 }
    288 
    289 LoadState SOCKSClientSocketPool::GetLoadState(
    290     const std::string& group_name, const ClientSocketHandle* handle) const {
    291   return base_.GetLoadState(group_name, handle);
    292 }
    293 
    294 DictionaryValue* SOCKSClientSocketPool::GetInfoAsValue(
    295     const std::string& name,
    296     const std::string& type,
    297     bool include_nested_pools) const {
    298   DictionaryValue* dict = base_.GetInfoAsValue(name, type);
    299   if (include_nested_pools) {
    300     ListValue* list = new ListValue();
    301     list->Append(transport_pool_->GetInfoAsValue("transport_socket_pool",
    302                                                  "transport_socket_pool",
    303                                                  false));
    304     dict->Set("nested_pools", list);
    305   }
    306   return dict;
    307 }
    308 
    309 base::TimeDelta SOCKSClientSocketPool::ConnectionTimeout() const {
    310   return base_.ConnectionTimeout();
    311 }
    312 
    313 ClientSocketPoolHistograms* SOCKSClientSocketPool::histograms() const {
    314   return base_.histograms();
    315 };
    316 
    317 }  // namespace net
    318