1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/socket/socks_client_socket_pool.h" 6 7 #include "base/bind.h" 8 #include "base/bind_helpers.h" 9 #include "base/time/time.h" 10 #include "base/values.h" 11 #include "net/base/net_errors.h" 12 #include "net/socket/client_socket_factory.h" 13 #include "net/socket/client_socket_handle.h" 14 #include "net/socket/client_socket_pool_base.h" 15 #include "net/socket/socks5_client_socket.h" 16 #include "net/socket/socks_client_socket.h" 17 #include "net/socket/transport_client_socket_pool.h" 18 19 namespace net { 20 21 SOCKSSocketParams::SOCKSSocketParams( 22 const scoped_refptr<TransportSocketParams>& proxy_server, 23 bool socks_v5, 24 const HostPortPair& host_port_pair) 25 : transport_params_(proxy_server), 26 destination_(host_port_pair), 27 socks_v5_(socks_v5) { 28 if (transport_params_.get()) 29 ignore_limits_ = transport_params_->ignore_limits(); 30 else 31 ignore_limits_ = false; 32 } 33 34 SOCKSSocketParams::~SOCKSSocketParams() {} 35 36 // SOCKSConnectJobs will time out after this many seconds. Note this is on 37 // top of the timeout for the transport socket. 38 static const int kSOCKSConnectJobTimeoutInSeconds = 30; 39 40 SOCKSConnectJob::SOCKSConnectJob( 41 const std::string& group_name, 42 RequestPriority priority, 43 const scoped_refptr<SOCKSSocketParams>& socks_params, 44 const base::TimeDelta& timeout_duration, 45 TransportClientSocketPool* transport_pool, 46 HostResolver* host_resolver, 47 Delegate* delegate, 48 NetLog* net_log) 49 : ConnectJob(group_name, timeout_duration, priority, delegate, 50 BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)), 51 socks_params_(socks_params), 52 transport_pool_(transport_pool), 53 resolver_(host_resolver), 54 callback_(base::Bind(&SOCKSConnectJob::OnIOComplete, 55 base::Unretained(this))) { 56 } 57 58 SOCKSConnectJob::~SOCKSConnectJob() { 59 // We don't worry about cancelling the tcp socket since the destructor in 60 // scoped_ptr<ClientSocketHandle> transport_socket_handle_ will take care of 61 // it. 62 } 63 64 LoadState SOCKSConnectJob::GetLoadState() const { 65 switch (next_state_) { 66 case STATE_TRANSPORT_CONNECT: 67 case STATE_TRANSPORT_CONNECT_COMPLETE: 68 return transport_socket_handle_->GetLoadState(); 69 case STATE_SOCKS_CONNECT: 70 case STATE_SOCKS_CONNECT_COMPLETE: 71 return LOAD_STATE_CONNECTING; 72 default: 73 NOTREACHED(); 74 return LOAD_STATE_IDLE; 75 } 76 } 77 78 void SOCKSConnectJob::OnIOComplete(int result) { 79 int rv = DoLoop(result); 80 if (rv != ERR_IO_PENDING) 81 NotifyDelegateOfCompletion(rv); // Deletes |this| 82 } 83 84 int SOCKSConnectJob::DoLoop(int result) { 85 DCHECK_NE(next_state_, STATE_NONE); 86 87 int rv = result; 88 do { 89 State state = next_state_; 90 next_state_ = STATE_NONE; 91 switch (state) { 92 case STATE_TRANSPORT_CONNECT: 93 DCHECK_EQ(OK, rv); 94 rv = DoTransportConnect(); 95 break; 96 case STATE_TRANSPORT_CONNECT_COMPLETE: 97 rv = DoTransportConnectComplete(rv); 98 break; 99 case STATE_SOCKS_CONNECT: 100 DCHECK_EQ(OK, rv); 101 rv = DoSOCKSConnect(); 102 break; 103 case STATE_SOCKS_CONNECT_COMPLETE: 104 rv = DoSOCKSConnectComplete(rv); 105 break; 106 default: 107 NOTREACHED() << "bad state"; 108 rv = ERR_FAILED; 109 break; 110 } 111 } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); 112 113 return rv; 114 } 115 116 int SOCKSConnectJob::DoTransportConnect() { 117 next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE; 118 transport_socket_handle_.reset(new ClientSocketHandle()); 119 return transport_socket_handle_->Init(group_name(), 120 socks_params_->transport_params(), 121 priority(), 122 callback_, 123 transport_pool_, 124 net_log()); 125 } 126 127 int SOCKSConnectJob::DoTransportConnectComplete(int result) { 128 if (result != OK) 129 return ERR_PROXY_CONNECTION_FAILED; 130 131 // Reset the timer to just the length of time allowed for SOCKS handshake 132 // so that a fast TCP connection plus a slow SOCKS failure doesn't take 133 // longer to timeout than it should. 134 ResetTimer(base::TimeDelta::FromSeconds(kSOCKSConnectJobTimeoutInSeconds)); 135 next_state_ = STATE_SOCKS_CONNECT; 136 return result; 137 } 138 139 int SOCKSConnectJob::DoSOCKSConnect() { 140 next_state_ = STATE_SOCKS_CONNECT_COMPLETE; 141 142 // Add a SOCKS connection on top of the tcp socket. 143 if (socks_params_->is_socks_v5()) { 144 socket_.reset(new SOCKS5ClientSocket(transport_socket_handle_.Pass(), 145 socks_params_->destination())); 146 } else { 147 socket_.reset(new SOCKSClientSocket(transport_socket_handle_.Pass(), 148 socks_params_->destination(), 149 priority(), 150 resolver_)); 151 } 152 return socket_->Connect( 153 base::Bind(&SOCKSConnectJob::OnIOComplete, base::Unretained(this))); 154 } 155 156 int SOCKSConnectJob::DoSOCKSConnectComplete(int result) { 157 if (result != OK) { 158 socket_->Disconnect(); 159 return result; 160 } 161 162 SetSocket(socket_.Pass()); 163 return result; 164 } 165 166 int SOCKSConnectJob::ConnectInternal() { 167 next_state_ = STATE_TRANSPORT_CONNECT; 168 return DoLoop(OK); 169 } 170 171 scoped_ptr<ConnectJob> 172 SOCKSClientSocketPool::SOCKSConnectJobFactory::NewConnectJob( 173 const std::string& group_name, 174 const PoolBase::Request& request, 175 ConnectJob::Delegate* delegate) const { 176 return scoped_ptr<ConnectJob>(new SOCKSConnectJob(group_name, 177 request.priority(), 178 request.params(), 179 ConnectionTimeout(), 180 transport_pool_, 181 host_resolver_, 182 delegate, 183 net_log_)); 184 } 185 186 base::TimeDelta 187 SOCKSClientSocketPool::SOCKSConnectJobFactory::ConnectionTimeout() const { 188 return transport_pool_->ConnectionTimeout() + 189 base::TimeDelta::FromSeconds(kSOCKSConnectJobTimeoutInSeconds); 190 } 191 192 SOCKSClientSocketPool::SOCKSClientSocketPool( 193 int max_sockets, 194 int max_sockets_per_group, 195 ClientSocketPoolHistograms* histograms, 196 HostResolver* host_resolver, 197 TransportClientSocketPool* transport_pool, 198 NetLog* net_log) 199 : transport_pool_(transport_pool), 200 base_(this, max_sockets, max_sockets_per_group, histograms, 201 ClientSocketPool::unused_idle_socket_timeout(), 202 ClientSocketPool::used_idle_socket_timeout(), 203 new SOCKSConnectJobFactory(transport_pool, 204 host_resolver, 205 net_log)) { 206 // We should always have a |transport_pool_| except in unit tests. 207 if (transport_pool_) 208 base_.AddLowerLayeredPool(transport_pool_); 209 } 210 211 SOCKSClientSocketPool::~SOCKSClientSocketPool() { 212 } 213 214 int SOCKSClientSocketPool::RequestSocket( 215 const std::string& group_name, const void* socket_params, 216 RequestPriority priority, ClientSocketHandle* handle, 217 const CompletionCallback& callback, const BoundNetLog& net_log) { 218 const scoped_refptr<SOCKSSocketParams>* casted_socket_params = 219 static_cast<const scoped_refptr<SOCKSSocketParams>*>(socket_params); 220 221 return base_.RequestSocket(group_name, *casted_socket_params, priority, 222 handle, callback, net_log); 223 } 224 225 void SOCKSClientSocketPool::RequestSockets( 226 const std::string& group_name, 227 const void* params, 228 int num_sockets, 229 const BoundNetLog& net_log) { 230 const scoped_refptr<SOCKSSocketParams>* casted_params = 231 static_cast<const scoped_refptr<SOCKSSocketParams>*>(params); 232 233 base_.RequestSockets(group_name, *casted_params, num_sockets, net_log); 234 } 235 236 void SOCKSClientSocketPool::CancelRequest(const std::string& group_name, 237 ClientSocketHandle* handle) { 238 base_.CancelRequest(group_name, handle); 239 } 240 241 void SOCKSClientSocketPool::ReleaseSocket(const std::string& group_name, 242 scoped_ptr<StreamSocket> socket, 243 int id) { 244 base_.ReleaseSocket(group_name, socket.Pass(), id); 245 } 246 247 void SOCKSClientSocketPool::FlushWithError(int error) { 248 base_.FlushWithError(error); 249 } 250 251 void SOCKSClientSocketPool::CloseIdleSockets() { 252 base_.CloseIdleSockets(); 253 } 254 255 int SOCKSClientSocketPool::IdleSocketCount() const { 256 return base_.idle_socket_count(); 257 } 258 259 int SOCKSClientSocketPool::IdleSocketCountInGroup( 260 const std::string& group_name) const { 261 return base_.IdleSocketCountInGroup(group_name); 262 } 263 264 LoadState SOCKSClientSocketPool::GetLoadState( 265 const std::string& group_name, const ClientSocketHandle* handle) const { 266 return base_.GetLoadState(group_name, handle); 267 } 268 269 base::DictionaryValue* SOCKSClientSocketPool::GetInfoAsValue( 270 const std::string& name, 271 const std::string& type, 272 bool include_nested_pools) const { 273 base::DictionaryValue* dict = base_.GetInfoAsValue(name, type); 274 if (include_nested_pools) { 275 base::ListValue* list = new base::ListValue(); 276 list->Append(transport_pool_->GetInfoAsValue("transport_socket_pool", 277 "transport_socket_pool", 278 false)); 279 dict->Set("nested_pools", list); 280 } 281 return dict; 282 } 283 284 base::TimeDelta SOCKSClientSocketPool::ConnectionTimeout() const { 285 return base_.ConnectionTimeout(); 286 } 287 288 ClientSocketPoolHistograms* SOCKSClientSocketPool::histograms() const { 289 return base_.histograms(); 290 }; 291 292 bool SOCKSClientSocketPool::IsStalled() const { 293 return base_.IsStalled(); 294 } 295 296 void SOCKSClientSocketPool::AddHigherLayeredPool( 297 HigherLayeredPool* higher_pool) { 298 base_.AddHigherLayeredPool(higher_pool); 299 } 300 301 void SOCKSClientSocketPool::RemoveHigherLayeredPool( 302 HigherLayeredPool* higher_pool) { 303 base_.RemoveHigherLayeredPool(higher_pool); 304 } 305 306 bool SOCKSClientSocketPool::CloseOneIdleConnection() { 307 if (base_.CloseOneIdleSocket()) 308 return true; 309 return base_.CloseOneIdleConnectionInHigherLayeredPool(); 310 } 311 312 } // namespace net 313