1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/socket/socks_client_socket_pool.h" 6 7 #include "base/bind.h" 8 #include "base/bind_helpers.h" 9 #include "base/time/time.h" 10 #include "base/values.h" 11 #include "net/base/net_errors.h" 12 #include "net/socket/client_socket_factory.h" 13 #include "net/socket/client_socket_handle.h" 14 #include "net/socket/client_socket_pool_base.h" 15 #include "net/socket/socks5_client_socket.h" 16 #include "net/socket/socks_client_socket.h" 17 #include "net/socket/transport_client_socket_pool.h" 18 19 namespace net { 20 21 SOCKSSocketParams::SOCKSSocketParams( 22 const scoped_refptr<TransportSocketParams>& proxy_server, 23 bool socks_v5, 24 const HostPortPair& host_port_pair, 25 RequestPriority priority) 26 : transport_params_(proxy_server), 27 destination_(host_port_pair), 28 socks_v5_(socks_v5) { 29 if (transport_params_.get()) 30 ignore_limits_ = transport_params_->ignore_limits(); 31 else 32 ignore_limits_ = false; 33 destination_.set_priority(priority); 34 } 35 36 SOCKSSocketParams::~SOCKSSocketParams() {} 37 38 // SOCKSConnectJobs will time out after this many seconds. Note this is on 39 // top of the timeout for the transport socket. 40 static const int kSOCKSConnectJobTimeoutInSeconds = 30; 41 42 SOCKSConnectJob::SOCKSConnectJob( 43 const std::string& group_name, 44 const scoped_refptr<SOCKSSocketParams>& socks_params, 45 const base::TimeDelta& timeout_duration, 46 TransportClientSocketPool* transport_pool, 47 HostResolver* host_resolver, 48 Delegate* delegate, 49 NetLog* net_log) 50 : ConnectJob(group_name, timeout_duration, delegate, 51 BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)), 52 socks_params_(socks_params), 53 transport_pool_(transport_pool), 54 resolver_(host_resolver), 55 callback_(base::Bind(&SOCKSConnectJob::OnIOComplete, 56 base::Unretained(this))) { 57 } 58 59 SOCKSConnectJob::~SOCKSConnectJob() { 60 // We don't worry about cancelling the tcp socket since the destructor in 61 // scoped_ptr<ClientSocketHandle> transport_socket_handle_ will take care of 62 // it. 63 } 64 65 LoadState SOCKSConnectJob::GetLoadState() const { 66 switch (next_state_) { 67 case STATE_TRANSPORT_CONNECT: 68 case STATE_TRANSPORT_CONNECT_COMPLETE: 69 return transport_socket_handle_->GetLoadState(); 70 case STATE_SOCKS_CONNECT: 71 case STATE_SOCKS_CONNECT_COMPLETE: 72 return LOAD_STATE_CONNECTING; 73 default: 74 NOTREACHED(); 75 return LOAD_STATE_IDLE; 76 } 77 } 78 79 void SOCKSConnectJob::OnIOComplete(int result) { 80 int rv = DoLoop(result); 81 if (rv != ERR_IO_PENDING) 82 NotifyDelegateOfCompletion(rv); // Deletes |this| 83 } 84 85 int SOCKSConnectJob::DoLoop(int result) { 86 DCHECK_NE(next_state_, STATE_NONE); 87 88 int rv = result; 89 do { 90 State state = next_state_; 91 next_state_ = STATE_NONE; 92 switch (state) { 93 case STATE_TRANSPORT_CONNECT: 94 DCHECK_EQ(OK, rv); 95 rv = DoTransportConnect(); 96 break; 97 case STATE_TRANSPORT_CONNECT_COMPLETE: 98 rv = DoTransportConnectComplete(rv); 99 break; 100 case STATE_SOCKS_CONNECT: 101 DCHECK_EQ(OK, rv); 102 rv = DoSOCKSConnect(); 103 break; 104 case STATE_SOCKS_CONNECT_COMPLETE: 105 rv = DoSOCKSConnectComplete(rv); 106 break; 107 default: 108 NOTREACHED() << "bad state"; 109 rv = ERR_FAILED; 110 break; 111 } 112 } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); 113 114 return rv; 115 } 116 117 int SOCKSConnectJob::DoTransportConnect() { 118 next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE; 119 transport_socket_handle_.reset(new ClientSocketHandle()); 120 return transport_socket_handle_->Init( 121 group_name(), socks_params_->transport_params(), 122 socks_params_->destination().priority(), callback_, transport_pool_, 123 net_log()); 124 } 125 126 int SOCKSConnectJob::DoTransportConnectComplete(int result) { 127 if (result != OK) 128 return ERR_PROXY_CONNECTION_FAILED; 129 130 // Reset the timer to just the length of time allowed for SOCKS handshake 131 // so that a fast TCP connection plus a slow SOCKS failure doesn't take 132 // longer to timeout than it should. 133 ResetTimer(base::TimeDelta::FromSeconds(kSOCKSConnectJobTimeoutInSeconds)); 134 next_state_ = STATE_SOCKS_CONNECT; 135 return result; 136 } 137 138 int SOCKSConnectJob::DoSOCKSConnect() { 139 next_state_ = STATE_SOCKS_CONNECT_COMPLETE; 140 141 // Add a SOCKS connection on top of the tcp socket. 142 if (socks_params_->is_socks_v5()) { 143 socket_.reset(new SOCKS5ClientSocket(transport_socket_handle_.release(), 144 socks_params_->destination())); 145 } else { 146 socket_.reset(new SOCKSClientSocket(transport_socket_handle_.release(), 147 socks_params_->destination(), 148 resolver_)); 149 } 150 return socket_->Connect( 151 base::Bind(&SOCKSConnectJob::OnIOComplete, base::Unretained(this))); 152 } 153 154 int SOCKSConnectJob::DoSOCKSConnectComplete(int result) { 155 if (result != OK) { 156 socket_->Disconnect(); 157 return result; 158 } 159 160 set_socket(socket_.release()); 161 return result; 162 } 163 164 int SOCKSConnectJob::ConnectInternal() { 165 next_state_ = STATE_TRANSPORT_CONNECT; 166 return DoLoop(OK); 167 } 168 169 ConnectJob* SOCKSClientSocketPool::SOCKSConnectJobFactory::NewConnectJob( 170 const std::string& group_name, 171 const PoolBase::Request& request, 172 ConnectJob::Delegate* delegate) const { 173 return new SOCKSConnectJob(group_name, 174 request.params(), 175 ConnectionTimeout(), 176 transport_pool_, 177 host_resolver_, 178 delegate, 179 net_log_); 180 } 181 182 base::TimeDelta 183 SOCKSClientSocketPool::SOCKSConnectJobFactory::ConnectionTimeout() const { 184 return transport_pool_->ConnectionTimeout() + 185 base::TimeDelta::FromSeconds(kSOCKSConnectJobTimeoutInSeconds); 186 } 187 188 SOCKSClientSocketPool::SOCKSClientSocketPool( 189 int max_sockets, 190 int max_sockets_per_group, 191 ClientSocketPoolHistograms* histograms, 192 HostResolver* host_resolver, 193 TransportClientSocketPool* transport_pool, 194 NetLog* net_log) 195 : transport_pool_(transport_pool), 196 base_(max_sockets, max_sockets_per_group, histograms, 197 ClientSocketPool::unused_idle_socket_timeout(), 198 ClientSocketPool::used_idle_socket_timeout(), 199 new SOCKSConnectJobFactory(transport_pool, 200 host_resolver, 201 net_log)) { 202 // We should always have a |transport_pool_| except in unit tests. 203 if (transport_pool_) 204 transport_pool_->AddLayeredPool(this); 205 } 206 207 SOCKSClientSocketPool::~SOCKSClientSocketPool() { 208 // We should always have a |transport_pool_| except in unit tests. 209 if (transport_pool_) 210 transport_pool_->RemoveLayeredPool(this); 211 } 212 213 int SOCKSClientSocketPool::RequestSocket( 214 const std::string& group_name, const void* socket_params, 215 RequestPriority priority, ClientSocketHandle* handle, 216 const CompletionCallback& callback, const BoundNetLog& net_log) { 217 const scoped_refptr<SOCKSSocketParams>* casted_socket_params = 218 static_cast<const scoped_refptr<SOCKSSocketParams>*>(socket_params); 219 220 return base_.RequestSocket(group_name, *casted_socket_params, priority, 221 handle, callback, net_log); 222 } 223 224 void SOCKSClientSocketPool::RequestSockets( 225 const std::string& group_name, 226 const void* params, 227 int num_sockets, 228 const BoundNetLog& net_log) { 229 const scoped_refptr<SOCKSSocketParams>* casted_params = 230 static_cast<const scoped_refptr<SOCKSSocketParams>*>(params); 231 232 base_.RequestSockets(group_name, *casted_params, num_sockets, net_log); 233 } 234 235 void SOCKSClientSocketPool::CancelRequest(const std::string& group_name, 236 ClientSocketHandle* handle) { 237 base_.CancelRequest(group_name, handle); 238 } 239 240 void SOCKSClientSocketPool::ReleaseSocket(const std::string& group_name, 241 StreamSocket* socket, int id) { 242 base_.ReleaseSocket(group_name, socket, id); 243 } 244 245 void SOCKSClientSocketPool::FlushWithError(int error) { 246 base_.FlushWithError(error); 247 } 248 249 bool SOCKSClientSocketPool::IsStalled() const { 250 return base_.IsStalled() || transport_pool_->IsStalled(); 251 } 252 253 void SOCKSClientSocketPool::CloseIdleSockets() { 254 base_.CloseIdleSockets(); 255 } 256 257 int SOCKSClientSocketPool::IdleSocketCount() const { 258 return base_.idle_socket_count(); 259 } 260 261 int SOCKSClientSocketPool::IdleSocketCountInGroup( 262 const std::string& group_name) const { 263 return base_.IdleSocketCountInGroup(group_name); 264 } 265 266 LoadState SOCKSClientSocketPool::GetLoadState( 267 const std::string& group_name, const ClientSocketHandle* handle) const { 268 return base_.GetLoadState(group_name, handle); 269 } 270 271 void SOCKSClientSocketPool::AddLayeredPool(LayeredPool* layered_pool) { 272 base_.AddLayeredPool(layered_pool); 273 } 274 275 void SOCKSClientSocketPool::RemoveLayeredPool(LayeredPool* layered_pool) { 276 base_.RemoveLayeredPool(layered_pool); 277 } 278 279 base::DictionaryValue* SOCKSClientSocketPool::GetInfoAsValue( 280 const std::string& name, 281 const std::string& type, 282 bool include_nested_pools) const { 283 base::DictionaryValue* dict = base_.GetInfoAsValue(name, type); 284 if (include_nested_pools) { 285 base::ListValue* list = new base::ListValue(); 286 list->Append(transport_pool_->GetInfoAsValue("transport_socket_pool", 287 "transport_socket_pool", 288 false)); 289 dict->Set("nested_pools", list); 290 } 291 return dict; 292 } 293 294 base::TimeDelta SOCKSClientSocketPool::ConnectionTimeout() const { 295 return base_.ConnectionTimeout(); 296 } 297 298 ClientSocketPoolHistograms* SOCKSClientSocketPool::histograms() const { 299 return base_.histograms(); 300 }; 301 302 bool SOCKSClientSocketPool::CloseOneIdleConnection() { 303 if (base_.CloseOneIdleSocket()) 304 return true; 305 return base_.CloseOneIdleConnectionInLayeredPool(); 306 } 307 308 } // namespace net 309