1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/socket/socks_client_socket.h" 6 7 #include "base/basictypes.h" 8 #include "base/bind.h" 9 #include "base/compiler_specific.h" 10 #include "base/sys_byteorder.h" 11 #include "net/base/io_buffer.h" 12 #include "net/base/net_log.h" 13 #include "net/base/net_util.h" 14 #include "net/socket/client_socket_handle.h" 15 16 namespace net { 17 18 // Every SOCKS server requests a user-id from the client. It is optional 19 // and we send an empty string. 20 static const char kEmptyUserId[] = ""; 21 22 // For SOCKS4, the client sends 8 bytes plus the size of the user-id. 23 static const unsigned int kWriteHeaderSize = 8; 24 25 // For SOCKS4 the server sends 8 bytes for acknowledgement. 26 static const unsigned int kReadHeaderSize = 8; 27 28 // Server Response codes for SOCKS. 29 static const uint8 kServerResponseOk = 0x5A; 30 static const uint8 kServerResponseRejected = 0x5B; 31 static const uint8 kServerResponseNotReachable = 0x5C; 32 static const uint8 kServerResponseMismatchedUserId = 0x5D; 33 34 static const uint8 kSOCKSVersion4 = 0x04; 35 static const uint8 kSOCKSStreamRequest = 0x01; 36 37 // A struct holding the essential details of the SOCKS4 Server Request. 38 // The port in the header is stored in network byte order. 39 struct SOCKS4ServerRequest { 40 uint8 version; 41 uint8 command; 42 uint16 nw_port; 43 uint8 ip[4]; 44 }; 45 COMPILE_ASSERT(sizeof(SOCKS4ServerRequest) == kWriteHeaderSize, 46 socks4_server_request_struct_wrong_size); 47 48 // A struct holding details of the SOCKS4 Server Response. 49 struct SOCKS4ServerResponse { 50 uint8 reserved_null; 51 uint8 code; 52 uint16 port; 53 uint8 ip[4]; 54 }; 55 COMPILE_ASSERT(sizeof(SOCKS4ServerResponse) == kReadHeaderSize, 56 socks4_server_response_struct_wrong_size); 57 58 SOCKSClientSocket::SOCKSClientSocket(ClientSocketHandle* transport_socket, 59 const HostResolver::RequestInfo& req_info, 60 HostResolver* host_resolver) 61 : transport_(transport_socket), 62 next_state_(STATE_NONE), 63 completed_handshake_(false), 64 bytes_sent_(0), 65 bytes_received_(0), 66 host_resolver_(host_resolver), 67 host_request_info_(req_info), 68 net_log_(transport_socket->socket()->NetLog()) { 69 } 70 71 SOCKSClientSocket::SOCKSClientSocket(StreamSocket* transport_socket, 72 const HostResolver::RequestInfo& req_info, 73 HostResolver* host_resolver) 74 : transport_(new ClientSocketHandle()), 75 next_state_(STATE_NONE), 76 completed_handshake_(false), 77 bytes_sent_(0), 78 bytes_received_(0), 79 host_resolver_(host_resolver), 80 host_request_info_(req_info), 81 net_log_(transport_socket->NetLog()) { 82 transport_->set_socket(transport_socket); 83 } 84 85 SOCKSClientSocket::~SOCKSClientSocket() { 86 Disconnect(); 87 } 88 89 int SOCKSClientSocket::Connect(const CompletionCallback& callback) { 90 DCHECK(transport_.get()); 91 DCHECK(transport_->socket()); 92 DCHECK_EQ(STATE_NONE, next_state_); 93 DCHECK(user_callback_.is_null()); 94 95 // If already connected, then just return OK. 96 if (completed_handshake_) 97 return OK; 98 99 next_state_ = STATE_RESOLVE_HOST; 100 101 net_log_.BeginEvent(NetLog::TYPE_SOCKS_CONNECT); 102 103 int rv = DoLoop(OK); 104 if (rv == ERR_IO_PENDING) { 105 user_callback_ = callback; 106 } else { 107 net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS_CONNECT, rv); 108 } 109 return rv; 110 } 111 112 void SOCKSClientSocket::Disconnect() { 113 completed_handshake_ = false; 114 host_resolver_.Cancel(); 115 transport_->socket()->Disconnect(); 116 117 // Reset other states to make sure they aren't mistakenly used later. 118 // These are the states initialized by Connect(). 119 next_state_ = STATE_NONE; 120 user_callback_.Reset(); 121 } 122 123 bool SOCKSClientSocket::IsConnected() const { 124 return completed_handshake_ && transport_->socket()->IsConnected(); 125 } 126 127 bool SOCKSClientSocket::IsConnectedAndIdle() const { 128 return completed_handshake_ && transport_->socket()->IsConnectedAndIdle(); 129 } 130 131 const BoundNetLog& SOCKSClientSocket::NetLog() const { 132 return net_log_; 133 } 134 135 void SOCKSClientSocket::SetSubresourceSpeculation() { 136 if (transport_.get() && transport_->socket()) { 137 transport_->socket()->SetSubresourceSpeculation(); 138 } else { 139 NOTREACHED(); 140 } 141 } 142 143 void SOCKSClientSocket::SetOmniboxSpeculation() { 144 if (transport_.get() && transport_->socket()) { 145 transport_->socket()->SetOmniboxSpeculation(); 146 } else { 147 NOTREACHED(); 148 } 149 } 150 151 bool SOCKSClientSocket::WasEverUsed() const { 152 if (transport_.get() && transport_->socket()) { 153 return transport_->socket()->WasEverUsed(); 154 } 155 NOTREACHED(); 156 return false; 157 } 158 159 bool SOCKSClientSocket::UsingTCPFastOpen() const { 160 if (transport_.get() && transport_->socket()) { 161 return transport_->socket()->UsingTCPFastOpen(); 162 } 163 NOTREACHED(); 164 return false; 165 } 166 167 bool SOCKSClientSocket::WasNpnNegotiated() const { 168 if (transport_.get() && transport_->socket()) { 169 return transport_->socket()->WasNpnNegotiated(); 170 } 171 NOTREACHED(); 172 return false; 173 } 174 175 NextProto SOCKSClientSocket::GetNegotiatedProtocol() const { 176 if (transport_.get() && transport_->socket()) { 177 return transport_->socket()->GetNegotiatedProtocol(); 178 } 179 NOTREACHED(); 180 return kProtoUnknown; 181 } 182 183 bool SOCKSClientSocket::GetSSLInfo(SSLInfo* ssl_info) { 184 if (transport_.get() && transport_->socket()) { 185 return transport_->socket()->GetSSLInfo(ssl_info); 186 } 187 NOTREACHED(); 188 return false; 189 190 } 191 192 // Read is called by the transport layer above to read. This can only be done 193 // if the SOCKS handshake is complete. 194 int SOCKSClientSocket::Read(IOBuffer* buf, int buf_len, 195 const CompletionCallback& callback) { 196 DCHECK(completed_handshake_); 197 DCHECK_EQ(STATE_NONE, next_state_); 198 DCHECK(user_callback_.is_null()); 199 200 return transport_->socket()->Read(buf, buf_len, callback); 201 } 202 203 // Write is called by the transport layer. This can only be done if the 204 // SOCKS handshake is complete. 205 int SOCKSClientSocket::Write(IOBuffer* buf, int buf_len, 206 const CompletionCallback& callback) { 207 DCHECK(completed_handshake_); 208 DCHECK_EQ(STATE_NONE, next_state_); 209 DCHECK(user_callback_.is_null()); 210 211 return transport_->socket()->Write(buf, buf_len, callback); 212 } 213 214 bool SOCKSClientSocket::SetReceiveBufferSize(int32 size) { 215 return transport_->socket()->SetReceiveBufferSize(size); 216 } 217 218 bool SOCKSClientSocket::SetSendBufferSize(int32 size) { 219 return transport_->socket()->SetSendBufferSize(size); 220 } 221 222 void SOCKSClientSocket::DoCallback(int result) { 223 DCHECK_NE(ERR_IO_PENDING, result); 224 DCHECK(!user_callback_.is_null()); 225 226 // Since Run() may result in Read being called, 227 // clear user_callback_ up front. 228 CompletionCallback c = user_callback_; 229 user_callback_.Reset(); 230 DVLOG(1) << "Finished setting up SOCKS handshake"; 231 c.Run(result); 232 } 233 234 void SOCKSClientSocket::OnIOComplete(int result) { 235 DCHECK_NE(STATE_NONE, next_state_); 236 int rv = DoLoop(result); 237 if (rv != ERR_IO_PENDING) { 238 net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS_CONNECT, rv); 239 DoCallback(rv); 240 } 241 } 242 243 int SOCKSClientSocket::DoLoop(int last_io_result) { 244 DCHECK_NE(next_state_, STATE_NONE); 245 int rv = last_io_result; 246 do { 247 State state = next_state_; 248 next_state_ = STATE_NONE; 249 switch (state) { 250 case STATE_RESOLVE_HOST: 251 DCHECK_EQ(OK, rv); 252 rv = DoResolveHost(); 253 break; 254 case STATE_RESOLVE_HOST_COMPLETE: 255 rv = DoResolveHostComplete(rv); 256 break; 257 case STATE_HANDSHAKE_WRITE: 258 DCHECK_EQ(OK, rv); 259 rv = DoHandshakeWrite(); 260 break; 261 case STATE_HANDSHAKE_WRITE_COMPLETE: 262 rv = DoHandshakeWriteComplete(rv); 263 break; 264 case STATE_HANDSHAKE_READ: 265 DCHECK_EQ(OK, rv); 266 rv = DoHandshakeRead(); 267 break; 268 case STATE_HANDSHAKE_READ_COMPLETE: 269 rv = DoHandshakeReadComplete(rv); 270 break; 271 default: 272 NOTREACHED() << "bad state"; 273 rv = ERR_UNEXPECTED; 274 break; 275 } 276 } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); 277 return rv; 278 } 279 280 int SOCKSClientSocket::DoResolveHost() { 281 next_state_ = STATE_RESOLVE_HOST_COMPLETE; 282 // SOCKS4 only supports IPv4 addresses, so only try getting the IPv4 283 // addresses for the target host. 284 host_request_info_.set_address_family(ADDRESS_FAMILY_IPV4); 285 return host_resolver_.Resolve( 286 host_request_info_, &addresses_, 287 base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)), 288 net_log_); 289 } 290 291 int SOCKSClientSocket::DoResolveHostComplete(int result) { 292 if (result != OK) { 293 // Resolving the hostname failed; fail the request rather than automatically 294 // falling back to SOCKS4a (since it can be confusing to see invalid IP 295 // addresses being sent to the SOCKS4 server when it doesn't support 4A.) 296 return result; 297 } 298 299 next_state_ = STATE_HANDSHAKE_WRITE; 300 return OK; 301 } 302 303 // Builds the buffer that is to be sent to the server. 304 const std::string SOCKSClientSocket::BuildHandshakeWriteBuffer() const { 305 SOCKS4ServerRequest request; 306 request.version = kSOCKSVersion4; 307 request.command = kSOCKSStreamRequest; 308 request.nw_port = base::HostToNet16(host_request_info_.port()); 309 310 DCHECK(!addresses_.empty()); 311 const IPEndPoint& endpoint = addresses_.front(); 312 313 // We disabled IPv6 results when resolving the hostname, so none of the 314 // results in the list will be IPv6. 315 // TODO(eroman): we only ever use the first address in the list. It would be 316 // more robust to try all the IP addresses we have before 317 // failing the connect attempt. 318 CHECK_EQ(ADDRESS_FAMILY_IPV4, endpoint.GetFamily()); 319 CHECK_LE(endpoint.address().size(), sizeof(request.ip)); 320 memcpy(&request.ip, &endpoint.address()[0], endpoint.address().size()); 321 322 DVLOG(1) << "Resolved Host is : " << endpoint.ToStringWithoutPort(); 323 324 std::string handshake_data(reinterpret_cast<char*>(&request), 325 sizeof(request)); 326 handshake_data.append(kEmptyUserId, arraysize(kEmptyUserId)); 327 328 return handshake_data; 329 } 330 331 // Writes the SOCKS handshake data to the underlying socket connection. 332 int SOCKSClientSocket::DoHandshakeWrite() { 333 next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE; 334 335 if (buffer_.empty()) { 336 buffer_ = BuildHandshakeWriteBuffer(); 337 bytes_sent_ = 0; 338 } 339 340 int handshake_buf_len = buffer_.size() - bytes_sent_; 341 DCHECK_GT(handshake_buf_len, 0); 342 handshake_buf_ = new IOBuffer(handshake_buf_len); 343 memcpy(handshake_buf_->data(), &buffer_[bytes_sent_], 344 handshake_buf_len); 345 return transport_->socket()->Write( 346 handshake_buf_.get(), 347 handshake_buf_len, 348 base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this))); 349 } 350 351 int SOCKSClientSocket::DoHandshakeWriteComplete(int result) { 352 if (result < 0) 353 return result; 354 355 // We ignore the case when result is 0, since the underlying Write 356 // may return spurious writes while waiting on the socket. 357 358 bytes_sent_ += result; 359 if (bytes_sent_ == buffer_.size()) { 360 next_state_ = STATE_HANDSHAKE_READ; 361 buffer_.clear(); 362 } else if (bytes_sent_ < buffer_.size()) { 363 next_state_ = STATE_HANDSHAKE_WRITE; 364 } else { 365 return ERR_UNEXPECTED; 366 } 367 368 return OK; 369 } 370 371 int SOCKSClientSocket::DoHandshakeRead() { 372 next_state_ = STATE_HANDSHAKE_READ_COMPLETE; 373 374 if (buffer_.empty()) { 375 bytes_received_ = 0; 376 } 377 378 int handshake_buf_len = kReadHeaderSize - bytes_received_; 379 handshake_buf_ = new IOBuffer(handshake_buf_len); 380 return transport_->socket()->Read( 381 handshake_buf_.get(), 382 handshake_buf_len, 383 base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this))); 384 } 385 386 int SOCKSClientSocket::DoHandshakeReadComplete(int result) { 387 if (result < 0) 388 return result; 389 390 // The underlying socket closed unexpectedly. 391 if (result == 0) 392 return ERR_CONNECTION_CLOSED; 393 394 if (bytes_received_ + result > kReadHeaderSize) { 395 // TODO(eroman): Describe failure in NetLog. 396 return ERR_SOCKS_CONNECTION_FAILED; 397 } 398 399 buffer_.append(handshake_buf_->data(), result); 400 bytes_received_ += result; 401 if (bytes_received_ < kReadHeaderSize) { 402 next_state_ = STATE_HANDSHAKE_READ; 403 return OK; 404 } 405 406 const SOCKS4ServerResponse* response = 407 reinterpret_cast<const SOCKS4ServerResponse*>(buffer_.data()); 408 409 if (response->reserved_null != 0x00) { 410 LOG(ERROR) << "Unknown response from SOCKS server."; 411 return ERR_SOCKS_CONNECTION_FAILED; 412 } 413 414 switch (response->code) { 415 case kServerResponseOk: 416 completed_handshake_ = true; 417 return OK; 418 case kServerResponseRejected: 419 LOG(ERROR) << "SOCKS request rejected or failed"; 420 return ERR_SOCKS_CONNECTION_FAILED; 421 case kServerResponseNotReachable: 422 LOG(ERROR) << "SOCKS request failed because client is not running " 423 << "identd (or not reachable from the server)"; 424 return ERR_SOCKS_CONNECTION_HOST_UNREACHABLE; 425 case kServerResponseMismatchedUserId: 426 LOG(ERROR) << "SOCKS request failed because client's identd could " 427 << "not confirm the user ID string in the request"; 428 return ERR_SOCKS_CONNECTION_FAILED; 429 default: 430 LOG(ERROR) << "SOCKS server sent unknown response"; 431 return ERR_SOCKS_CONNECTION_FAILED; 432 } 433 434 // Note: we ignore the last 6 bytes as specified by the SOCKS protocol 435 } 436 437 int SOCKSClientSocket::GetPeerAddress(IPEndPoint* address) const { 438 return transport_->socket()->GetPeerAddress(address); 439 } 440 441 int SOCKSClientSocket::GetLocalAddress(IPEndPoint* address) const { 442 return transport_->socket()->GetLocalAddress(address); 443 } 444 445 } // namespace net 446