1 // Copyright (c) 2011 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/socket/socks_client_socket_pool.h" 6 7 #include "base/time.h" 8 #include "base/values.h" 9 #include "googleurl/src/gurl.h" 10 #include "net/base/net_errors.h" 11 #include "net/socket/client_socket_factory.h" 12 #include "net/socket/client_socket_handle.h" 13 #include "net/socket/client_socket_pool_base.h" 14 #include "net/socket/socks5_client_socket.h" 15 #include "net/socket/socks_client_socket.h" 16 #include "net/socket/transport_client_socket_pool.h" 17 18 namespace net { 19 20 SOCKSSocketParams::SOCKSSocketParams( 21 const scoped_refptr<TransportSocketParams>& proxy_server, 22 bool socks_v5, 23 const HostPortPair& host_port_pair, 24 RequestPriority priority, 25 const GURL& referrer) 26 : transport_params_(proxy_server), 27 destination_(host_port_pair), 28 socks_v5_(socks_v5) { 29 if (transport_params_) 30 ignore_limits_ = transport_params_->ignore_limits(); 31 else 32 ignore_limits_ = false; 33 // The referrer is used by the DNS prefetch system to correlate resolutions 34 // with the page that triggered them. It doesn't impact the actual addresses 35 // that we resolve to. 36 destination_.set_referrer(referrer); 37 destination_.set_priority(priority); 38 } 39 40 #ifdef ANDROID 41 bool SOCKSSocketParams::getUID(uid_t *uid) const { 42 if (transport_params_) 43 return transport_params_->getUID(uid); 44 else 45 return false; 46 } 47 48 void SOCKSSocketParams::setUID(uid_t uid) { 49 if (transport_params_) 50 return transport_params_->setUID(uid); 51 } 52 #endif 53 54 SOCKSSocketParams::~SOCKSSocketParams() {} 55 56 // SOCKSConnectJobs will time out after this many seconds. Note this is on 57 // top of the timeout for the transport socket. 58 static const int kSOCKSConnectJobTimeoutInSeconds = 30; 59 60 SOCKSConnectJob::SOCKSConnectJob( 61 const std::string& group_name, 62 const scoped_refptr<SOCKSSocketParams>& socks_params, 63 const base::TimeDelta& timeout_duration, 64 TransportClientSocketPool* transport_pool, 65 HostResolver* host_resolver, 66 Delegate* delegate, 67 NetLog* net_log) 68 : ConnectJob(group_name, timeout_duration, delegate, 69 BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)), 70 socks_params_(socks_params), 71 transport_pool_(transport_pool), 72 resolver_(host_resolver), 73 ALLOW_THIS_IN_INITIALIZER_LIST( 74 callback_(this, &SOCKSConnectJob::OnIOComplete)) { 75 } 76 77 SOCKSConnectJob::~SOCKSConnectJob() { 78 // We don't worry about cancelling the tcp socket since the destructor in 79 // scoped_ptr<ClientSocketHandle> transport_socket_handle_ will take care of 80 // it. 81 } 82 83 LoadState SOCKSConnectJob::GetLoadState() const { 84 switch (next_state_) { 85 case STATE_TRANSPORT_CONNECT: 86 case STATE_TRANSPORT_CONNECT_COMPLETE: 87 return transport_socket_handle_->GetLoadState(); 88 case STATE_SOCKS_CONNECT: 89 case STATE_SOCKS_CONNECT_COMPLETE: 90 return LOAD_STATE_CONNECTING; 91 default: 92 NOTREACHED(); 93 return LOAD_STATE_IDLE; 94 } 95 } 96 97 void SOCKSConnectJob::OnIOComplete(int result) { 98 int rv = DoLoop(result); 99 if (rv != ERR_IO_PENDING) 100 NotifyDelegateOfCompletion(rv); // Deletes |this| 101 } 102 103 int SOCKSConnectJob::DoLoop(int result) { 104 DCHECK_NE(next_state_, STATE_NONE); 105 106 int rv = result; 107 do { 108 State state = next_state_; 109 next_state_ = STATE_NONE; 110 switch (state) { 111 case STATE_TRANSPORT_CONNECT: 112 DCHECK_EQ(OK, rv); 113 rv = DoTransportConnect(); 114 break; 115 case STATE_TRANSPORT_CONNECT_COMPLETE: 116 rv = DoTransportConnectComplete(rv); 117 break; 118 case STATE_SOCKS_CONNECT: 119 DCHECK_EQ(OK, rv); 120 rv = DoSOCKSConnect(); 121 break; 122 case STATE_SOCKS_CONNECT_COMPLETE: 123 rv = DoSOCKSConnectComplete(rv); 124 break; 125 default: 126 NOTREACHED() << "bad state"; 127 rv = ERR_FAILED; 128 break; 129 } 130 } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); 131 132 return rv; 133 } 134 135 int SOCKSConnectJob::DoTransportConnect() { 136 next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE; 137 transport_socket_handle_.reset(new ClientSocketHandle()); 138 return transport_socket_handle_->Init(group_name(), 139 socks_params_->transport_params(), 140 socks_params_->destination().priority(), 141 &callback_, 142 transport_pool_, 143 net_log()); 144 } 145 146 int SOCKSConnectJob::DoTransportConnectComplete(int result) { 147 if (result != OK) 148 return ERR_PROXY_CONNECTION_FAILED; 149 150 // Reset the timer to just the length of time allowed for SOCKS handshake 151 // so that a fast TCP connection plus a slow SOCKS failure doesn't take 152 // longer to timeout than it should. 153 ResetTimer(base::TimeDelta::FromSeconds(kSOCKSConnectJobTimeoutInSeconds)); 154 next_state_ = STATE_SOCKS_CONNECT; 155 return result; 156 } 157 158 int SOCKSConnectJob::DoSOCKSConnect() { 159 next_state_ = STATE_SOCKS_CONNECT_COMPLETE; 160 161 // Add a SOCKS connection on top of the tcp socket. 162 if (socks_params_->is_socks_v5()) { 163 socket_.reset(new SOCKS5ClientSocket(transport_socket_handle_.release(), 164 socks_params_->destination())); 165 } else { 166 socket_.reset(new SOCKSClientSocket(transport_socket_handle_.release(), 167 socks_params_->destination(), 168 resolver_)); 169 } 170 171 #ifdef ANDROID 172 uid_t calling_uid = 0; 173 bool valid_uid = socks_params_->transport_params()->getUID(&calling_uid); 174 #endif 175 176 return socket_->Connect(&callback_ 177 #ifdef ANDROID 178 , socks_params_->ignore_limits() 179 , valid_uid 180 , calling_uid 181 #endif 182 ); 183 } 184 185 int SOCKSConnectJob::DoSOCKSConnectComplete(int result) { 186 if (result != OK) { 187 socket_->Disconnect(); 188 return result; 189 } 190 191 set_socket(socket_.release()); 192 return result; 193 } 194 195 int SOCKSConnectJob::ConnectInternal() { 196 next_state_ = STATE_TRANSPORT_CONNECT; 197 return DoLoop(OK); 198 } 199 200 ConnectJob* SOCKSClientSocketPool::SOCKSConnectJobFactory::NewConnectJob( 201 const std::string& group_name, 202 const PoolBase::Request& request, 203 ConnectJob::Delegate* delegate) const { 204 return new SOCKSConnectJob(group_name, 205 request.params(), 206 ConnectionTimeout(), 207 transport_pool_, 208 host_resolver_, 209 delegate, 210 net_log_); 211 } 212 213 base::TimeDelta 214 SOCKSClientSocketPool::SOCKSConnectJobFactory::ConnectionTimeout() const { 215 return transport_pool_->ConnectionTimeout() + 216 base::TimeDelta::FromSeconds(kSOCKSConnectJobTimeoutInSeconds); 217 } 218 219 SOCKSClientSocketPool::SOCKSClientSocketPool( 220 int max_sockets, 221 int max_sockets_per_group, 222 ClientSocketPoolHistograms* histograms, 223 HostResolver* host_resolver, 224 TransportClientSocketPool* transport_pool, 225 NetLog* net_log) 226 : transport_pool_(transport_pool), 227 base_(max_sockets, max_sockets_per_group, histograms, 228 base::TimeDelta::FromSeconds( 229 ClientSocketPool::unused_idle_socket_timeout()), 230 base::TimeDelta::FromSeconds(kUsedIdleSocketTimeout), 231 new SOCKSConnectJobFactory(transport_pool, 232 host_resolver, 233 net_log)) { 234 } 235 236 SOCKSClientSocketPool::~SOCKSClientSocketPool() {} 237 238 int SOCKSClientSocketPool::RequestSocket(const std::string& group_name, 239 const void* socket_params, 240 RequestPriority priority, 241 ClientSocketHandle* handle, 242 CompletionCallback* callback, 243 const BoundNetLog& net_log) { 244 const scoped_refptr<SOCKSSocketParams>* casted_socket_params = 245 static_cast<const scoped_refptr<SOCKSSocketParams>*>(socket_params); 246 247 return base_.RequestSocket(group_name, *casted_socket_params, priority, 248 handle, callback, net_log); 249 } 250 251 void SOCKSClientSocketPool::RequestSockets( 252 const std::string& group_name, 253 const void* params, 254 int num_sockets, 255 const BoundNetLog& net_log) { 256 const scoped_refptr<SOCKSSocketParams>* casted_params = 257 static_cast<const scoped_refptr<SOCKSSocketParams>*>(params); 258 259 base_.RequestSockets(group_name, *casted_params, num_sockets, net_log); 260 } 261 262 void SOCKSClientSocketPool::CancelRequest(const std::string& group_name, 263 ClientSocketHandle* handle) { 264 base_.CancelRequest(group_name, handle); 265 } 266 267 void SOCKSClientSocketPool::ReleaseSocket(const std::string& group_name, 268 ClientSocket* socket, int id) { 269 base_.ReleaseSocket(group_name, socket, id); 270 } 271 272 void SOCKSClientSocketPool::Flush() { 273 base_.Flush(); 274 } 275 276 void SOCKSClientSocketPool::CloseIdleSockets() { 277 base_.CloseIdleSockets(); 278 } 279 280 int SOCKSClientSocketPool::IdleSocketCount() const { 281 return base_.idle_socket_count(); 282 } 283 284 int SOCKSClientSocketPool::IdleSocketCountInGroup( 285 const std::string& group_name) const { 286 return base_.IdleSocketCountInGroup(group_name); 287 } 288 289 LoadState SOCKSClientSocketPool::GetLoadState( 290 const std::string& group_name, const ClientSocketHandle* handle) const { 291 return base_.GetLoadState(group_name, handle); 292 } 293 294 DictionaryValue* SOCKSClientSocketPool::GetInfoAsValue( 295 const std::string& name, 296 const std::string& type, 297 bool include_nested_pools) const { 298 DictionaryValue* dict = base_.GetInfoAsValue(name, type); 299 if (include_nested_pools) { 300 ListValue* list = new ListValue(); 301 list->Append(transport_pool_->GetInfoAsValue("transport_socket_pool", 302 "transport_socket_pool", 303 false)); 304 dict->Set("nested_pools", list); 305 } 306 return dict; 307 } 308 309 base::TimeDelta SOCKSClientSocketPool::ConnectionTimeout() const { 310 return base_.ConnectionTimeout(); 311 } 312 313 ClientSocketPoolHistograms* SOCKSClientSocketPool::histograms() const { 314 return base_.histograms(); 315 }; 316 317 } // namespace net 318