1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/socket/socks_client_socket.h" 6 7 #include "base/basictypes.h" 8 #include "base/bind.h" 9 #include "base/callback_helpers.h" 10 #include "base/compiler_specific.h" 11 #include "base/sys_byteorder.h" 12 #include "net/base/io_buffer.h" 13 #include "net/base/net_log.h" 14 #include "net/base/net_util.h" 15 #include "net/socket/client_socket_handle.h" 16 17 namespace net { 18 19 // Every SOCKS server requests a user-id from the client. It is optional 20 // and we send an empty string. 21 static const char kEmptyUserId[] = ""; 22 23 // For SOCKS4, the client sends 8 bytes plus the size of the user-id. 24 static const unsigned int kWriteHeaderSize = 8; 25 26 // For SOCKS4 the server sends 8 bytes for acknowledgement. 27 static const unsigned int kReadHeaderSize = 8; 28 29 // Server Response codes for SOCKS. 30 static const uint8 kServerResponseOk = 0x5A; 31 static const uint8 kServerResponseRejected = 0x5B; 32 static const uint8 kServerResponseNotReachable = 0x5C; 33 static const uint8 kServerResponseMismatchedUserId = 0x5D; 34 35 static const uint8 kSOCKSVersion4 = 0x04; 36 static const uint8 kSOCKSStreamRequest = 0x01; 37 38 // A struct holding the essential details of the SOCKS4 Server Request. 39 // The port in the header is stored in network byte order. 40 struct SOCKS4ServerRequest { 41 uint8 version; 42 uint8 command; 43 uint16 nw_port; 44 uint8 ip[4]; 45 }; 46 COMPILE_ASSERT(sizeof(SOCKS4ServerRequest) == kWriteHeaderSize, 47 socks4_server_request_struct_wrong_size); 48 49 // A struct holding details of the SOCKS4 Server Response. 50 struct SOCKS4ServerResponse { 51 uint8 reserved_null; 52 uint8 code; 53 uint16 port; 54 uint8 ip[4]; 55 }; 56 COMPILE_ASSERT(sizeof(SOCKS4ServerResponse) == kReadHeaderSize, 57 socks4_server_response_struct_wrong_size); 58 59 SOCKSClientSocket::SOCKSClientSocket( 60 scoped_ptr<ClientSocketHandle> transport_socket, 61 const HostResolver::RequestInfo& req_info, 62 RequestPriority priority, 63 HostResolver* host_resolver) 64 : transport_(transport_socket.Pass()), 65 next_state_(STATE_NONE), 66 completed_handshake_(false), 67 bytes_sent_(0), 68 bytes_received_(0), 69 was_ever_used_(false), 70 host_resolver_(host_resolver), 71 host_request_info_(req_info), 72 priority_(priority), 73 net_log_(transport_->socket()->NetLog()) {} 74 75 SOCKSClientSocket::~SOCKSClientSocket() { 76 Disconnect(); 77 } 78 79 int SOCKSClientSocket::Connect(const CompletionCallback& callback) { 80 DCHECK(transport_.get()); 81 DCHECK(transport_->socket()); 82 DCHECK_EQ(STATE_NONE, next_state_); 83 DCHECK(user_callback_.is_null()); 84 85 // If already connected, then just return OK. 86 if (completed_handshake_) 87 return OK; 88 89 next_state_ = STATE_RESOLVE_HOST; 90 91 net_log_.BeginEvent(NetLog::TYPE_SOCKS_CONNECT); 92 93 int rv = DoLoop(OK); 94 if (rv == ERR_IO_PENDING) { 95 user_callback_ = callback; 96 } else { 97 net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS_CONNECT, rv); 98 } 99 return rv; 100 } 101 102 void SOCKSClientSocket::Disconnect() { 103 completed_handshake_ = false; 104 host_resolver_.Cancel(); 105 transport_->socket()->Disconnect(); 106 107 // Reset other states to make sure they aren't mistakenly used later. 108 // These are the states initialized by Connect(). 109 next_state_ = STATE_NONE; 110 user_callback_.Reset(); 111 } 112 113 bool SOCKSClientSocket::IsConnected() const { 114 return completed_handshake_ && transport_->socket()->IsConnected(); 115 } 116 117 bool SOCKSClientSocket::IsConnectedAndIdle() const { 118 return completed_handshake_ && transport_->socket()->IsConnectedAndIdle(); 119 } 120 121 const BoundNetLog& SOCKSClientSocket::NetLog() const { 122 return net_log_; 123 } 124 125 void SOCKSClientSocket::SetSubresourceSpeculation() { 126 if (transport_.get() && transport_->socket()) { 127 transport_->socket()->SetSubresourceSpeculation(); 128 } else { 129 NOTREACHED(); 130 } 131 } 132 133 void SOCKSClientSocket::SetOmniboxSpeculation() { 134 if (transport_.get() && transport_->socket()) { 135 transport_->socket()->SetOmniboxSpeculation(); 136 } else { 137 NOTREACHED(); 138 } 139 } 140 141 bool SOCKSClientSocket::WasEverUsed() const { 142 return was_ever_used_; 143 } 144 145 bool SOCKSClientSocket::UsingTCPFastOpen() const { 146 if (transport_.get() && transport_->socket()) { 147 return transport_->socket()->UsingTCPFastOpen(); 148 } 149 NOTREACHED(); 150 return false; 151 } 152 153 bool SOCKSClientSocket::WasNpnNegotiated() const { 154 if (transport_.get() && transport_->socket()) { 155 return transport_->socket()->WasNpnNegotiated(); 156 } 157 NOTREACHED(); 158 return false; 159 } 160 161 NextProto SOCKSClientSocket::GetNegotiatedProtocol() const { 162 if (transport_.get() && transport_->socket()) { 163 return transport_->socket()->GetNegotiatedProtocol(); 164 } 165 NOTREACHED(); 166 return kProtoUnknown; 167 } 168 169 bool SOCKSClientSocket::GetSSLInfo(SSLInfo* ssl_info) { 170 if (transport_.get() && transport_->socket()) { 171 return transport_->socket()->GetSSLInfo(ssl_info); 172 } 173 NOTREACHED(); 174 return false; 175 176 } 177 178 // Read is called by the transport layer above to read. This can only be done 179 // if the SOCKS handshake is complete. 180 int SOCKSClientSocket::Read(IOBuffer* buf, int buf_len, 181 const CompletionCallback& callback) { 182 DCHECK(completed_handshake_); 183 DCHECK_EQ(STATE_NONE, next_state_); 184 DCHECK(user_callback_.is_null()); 185 DCHECK(!callback.is_null()); 186 187 int rv = transport_->socket()->Read( 188 buf, buf_len, 189 base::Bind(&SOCKSClientSocket::OnReadWriteComplete, 190 base::Unretained(this), callback)); 191 if (rv > 0) 192 was_ever_used_ = true; 193 return rv; 194 } 195 196 // Write is called by the transport layer. This can only be done if the 197 // SOCKS handshake is complete. 198 int SOCKSClientSocket::Write(IOBuffer* buf, int buf_len, 199 const CompletionCallback& callback) { 200 DCHECK(completed_handshake_); 201 DCHECK_EQ(STATE_NONE, next_state_); 202 DCHECK(user_callback_.is_null()); 203 DCHECK(!callback.is_null()); 204 205 int rv = transport_->socket()->Write( 206 buf, buf_len, 207 base::Bind(&SOCKSClientSocket::OnReadWriteComplete, 208 base::Unretained(this), callback)); 209 if (rv > 0) 210 was_ever_used_ = true; 211 return rv; 212 } 213 214 int SOCKSClientSocket::SetReceiveBufferSize(int32 size) { 215 return transport_->socket()->SetReceiveBufferSize(size); 216 } 217 218 int SOCKSClientSocket::SetSendBufferSize(int32 size) { 219 return transport_->socket()->SetSendBufferSize(size); 220 } 221 222 void SOCKSClientSocket::DoCallback(int result) { 223 DCHECK_NE(ERR_IO_PENDING, result); 224 DCHECK(!user_callback_.is_null()); 225 226 // Since Run() may result in Read being called, 227 // clear user_callback_ up front. 228 DVLOG(1) << "Finished setting up SOCKS handshake"; 229 base::ResetAndReturn(&user_callback_).Run(result); 230 } 231 232 void SOCKSClientSocket::OnIOComplete(int result) { 233 DCHECK_NE(STATE_NONE, next_state_); 234 int rv = DoLoop(result); 235 if (rv != ERR_IO_PENDING) { 236 net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS_CONNECT, rv); 237 DoCallback(rv); 238 } 239 } 240 241 void SOCKSClientSocket::OnReadWriteComplete(const CompletionCallback& callback, 242 int result) { 243 DCHECK_NE(ERR_IO_PENDING, result); 244 DCHECK(!callback.is_null()); 245 246 if (result > 0) 247 was_ever_used_ = true; 248 callback.Run(result); 249 } 250 251 int SOCKSClientSocket::DoLoop(int last_io_result) { 252 DCHECK_NE(next_state_, STATE_NONE); 253 int rv = last_io_result; 254 do { 255 State state = next_state_; 256 next_state_ = STATE_NONE; 257 switch (state) { 258 case STATE_RESOLVE_HOST: 259 DCHECK_EQ(OK, rv); 260 rv = DoResolveHost(); 261 break; 262 case STATE_RESOLVE_HOST_COMPLETE: 263 rv = DoResolveHostComplete(rv); 264 break; 265 case STATE_HANDSHAKE_WRITE: 266 DCHECK_EQ(OK, rv); 267 rv = DoHandshakeWrite(); 268 break; 269 case STATE_HANDSHAKE_WRITE_COMPLETE: 270 rv = DoHandshakeWriteComplete(rv); 271 break; 272 case STATE_HANDSHAKE_READ: 273 DCHECK_EQ(OK, rv); 274 rv = DoHandshakeRead(); 275 break; 276 case STATE_HANDSHAKE_READ_COMPLETE: 277 rv = DoHandshakeReadComplete(rv); 278 break; 279 default: 280 NOTREACHED() << "bad state"; 281 rv = ERR_UNEXPECTED; 282 break; 283 } 284 } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); 285 return rv; 286 } 287 288 int SOCKSClientSocket::DoResolveHost() { 289 next_state_ = STATE_RESOLVE_HOST_COMPLETE; 290 // SOCKS4 only supports IPv4 addresses, so only try getting the IPv4 291 // addresses for the target host. 292 host_request_info_.set_address_family(ADDRESS_FAMILY_IPV4); 293 return host_resolver_.Resolve( 294 host_request_info_, 295 priority_, 296 &addresses_, 297 base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)), 298 net_log_); 299 } 300 301 int SOCKSClientSocket::DoResolveHostComplete(int result) { 302 if (result != OK) { 303 // Resolving the hostname failed; fail the request rather than automatically 304 // falling back to SOCKS4a (since it can be confusing to see invalid IP 305 // addresses being sent to the SOCKS4 server when it doesn't support 4A.) 306 return result; 307 } 308 309 next_state_ = STATE_HANDSHAKE_WRITE; 310 return OK; 311 } 312 313 // Builds the buffer that is to be sent to the server. 314 const std::string SOCKSClientSocket::BuildHandshakeWriteBuffer() const { 315 SOCKS4ServerRequest request; 316 request.version = kSOCKSVersion4; 317 request.command = kSOCKSStreamRequest; 318 request.nw_port = base::HostToNet16(host_request_info_.port()); 319 320 DCHECK(!addresses_.empty()); 321 const IPEndPoint& endpoint = addresses_.front(); 322 323 // We disabled IPv6 results when resolving the hostname, so none of the 324 // results in the list will be IPv6. 325 // TODO(eroman): we only ever use the first address in the list. It would be 326 // more robust to try all the IP addresses we have before 327 // failing the connect attempt. 328 CHECK_EQ(ADDRESS_FAMILY_IPV4, endpoint.GetFamily()); 329 CHECK_LE(endpoint.address().size(), sizeof(request.ip)); 330 memcpy(&request.ip, &endpoint.address()[0], endpoint.address().size()); 331 332 DVLOG(1) << "Resolved Host is : " << endpoint.ToStringWithoutPort(); 333 334 std::string handshake_data(reinterpret_cast<char*>(&request), 335 sizeof(request)); 336 handshake_data.append(kEmptyUserId, arraysize(kEmptyUserId)); 337 338 return handshake_data; 339 } 340 341 // Writes the SOCKS handshake data to the underlying socket connection. 342 int SOCKSClientSocket::DoHandshakeWrite() { 343 next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE; 344 345 if (buffer_.empty()) { 346 buffer_ = BuildHandshakeWriteBuffer(); 347 bytes_sent_ = 0; 348 } 349 350 int handshake_buf_len = buffer_.size() - bytes_sent_; 351 DCHECK_GT(handshake_buf_len, 0); 352 handshake_buf_ = new IOBuffer(handshake_buf_len); 353 memcpy(handshake_buf_->data(), &buffer_[bytes_sent_], 354 handshake_buf_len); 355 return transport_->socket()->Write( 356 handshake_buf_.get(), 357 handshake_buf_len, 358 base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this))); 359 } 360 361 int SOCKSClientSocket::DoHandshakeWriteComplete(int result) { 362 if (result < 0) 363 return result; 364 365 // We ignore the case when result is 0, since the underlying Write 366 // may return spurious writes while waiting on the socket. 367 368 bytes_sent_ += result; 369 if (bytes_sent_ == buffer_.size()) { 370 next_state_ = STATE_HANDSHAKE_READ; 371 buffer_.clear(); 372 } else if (bytes_sent_ < buffer_.size()) { 373 next_state_ = STATE_HANDSHAKE_WRITE; 374 } else { 375 return ERR_UNEXPECTED; 376 } 377 378 return OK; 379 } 380 381 int SOCKSClientSocket::DoHandshakeRead() { 382 next_state_ = STATE_HANDSHAKE_READ_COMPLETE; 383 384 if (buffer_.empty()) { 385 bytes_received_ = 0; 386 } 387 388 int handshake_buf_len = kReadHeaderSize - bytes_received_; 389 handshake_buf_ = new IOBuffer(handshake_buf_len); 390 return transport_->socket()->Read( 391 handshake_buf_.get(), 392 handshake_buf_len, 393 base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this))); 394 } 395 396 int SOCKSClientSocket::DoHandshakeReadComplete(int result) { 397 if (result < 0) 398 return result; 399 400 // The underlying socket closed unexpectedly. 401 if (result == 0) 402 return ERR_CONNECTION_CLOSED; 403 404 if (bytes_received_ + result > kReadHeaderSize) { 405 // TODO(eroman): Describe failure in NetLog. 406 return ERR_SOCKS_CONNECTION_FAILED; 407 } 408 409 buffer_.append(handshake_buf_->data(), result); 410 bytes_received_ += result; 411 if (bytes_received_ < kReadHeaderSize) { 412 next_state_ = STATE_HANDSHAKE_READ; 413 return OK; 414 } 415 416 const SOCKS4ServerResponse* response = 417 reinterpret_cast<const SOCKS4ServerResponse*>(buffer_.data()); 418 419 if (response->reserved_null != 0x00) { 420 LOG(ERROR) << "Unknown response from SOCKS server."; 421 return ERR_SOCKS_CONNECTION_FAILED; 422 } 423 424 switch (response->code) { 425 case kServerResponseOk: 426 completed_handshake_ = true; 427 return OK; 428 case kServerResponseRejected: 429 LOG(ERROR) << "SOCKS request rejected or failed"; 430 return ERR_SOCKS_CONNECTION_FAILED; 431 case kServerResponseNotReachable: 432 LOG(ERROR) << "SOCKS request failed because client is not running " 433 << "identd (or not reachable from the server)"; 434 return ERR_SOCKS_CONNECTION_HOST_UNREACHABLE; 435 case kServerResponseMismatchedUserId: 436 LOG(ERROR) << "SOCKS request failed because client's identd could " 437 << "not confirm the user ID string in the request"; 438 return ERR_SOCKS_CONNECTION_FAILED; 439 default: 440 LOG(ERROR) << "SOCKS server sent unknown response"; 441 return ERR_SOCKS_CONNECTION_FAILED; 442 } 443 444 // Note: we ignore the last 6 bytes as specified by the SOCKS protocol 445 } 446 447 int SOCKSClientSocket::GetPeerAddress(IPEndPoint* address) const { 448 return transport_->socket()->GetPeerAddress(address); 449 } 450 451 int SOCKSClientSocket::GetLocalAddress(IPEndPoint* address) const { 452 return transport_->socket()->GetLocalAddress(address); 453 } 454 455 } // namespace net 456