1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/socket/socks_client_socket.h" 6 7 #include "base/basictypes.h" 8 #include "base/bind.h" 9 #include "base/compiler_specific.h" 10 #include "base/sys_byteorder.h" 11 #include "net/base/io_buffer.h" 12 #include "net/base/net_log.h" 13 #include "net/base/net_util.h" 14 #include "net/socket/client_socket_handle.h" 15 16 namespace net { 17 18 // Every SOCKS server requests a user-id from the client. It is optional 19 // and we send an empty string. 20 static const char kEmptyUserId[] = ""; 21 22 // For SOCKS4, the client sends 8 bytes plus the size of the user-id. 23 static const unsigned int kWriteHeaderSize = 8; 24 25 // For SOCKS4 the server sends 8 bytes for acknowledgement. 26 static const unsigned int kReadHeaderSize = 8; 27 28 // Server Response codes for SOCKS. 29 static const uint8 kServerResponseOk = 0x5A; 30 static const uint8 kServerResponseRejected = 0x5B; 31 static const uint8 kServerResponseNotReachable = 0x5C; 32 static const uint8 kServerResponseMismatchedUserId = 0x5D; 33 34 static const uint8 kSOCKSVersion4 = 0x04; 35 static const uint8 kSOCKSStreamRequest = 0x01; 36 37 // A struct holding the essential details of the SOCKS4 Server Request. 38 // The port in the header is stored in network byte order. 39 struct SOCKS4ServerRequest { 40 uint8 version; 41 uint8 command; 42 uint16 nw_port; 43 uint8 ip[4]; 44 }; 45 COMPILE_ASSERT(sizeof(SOCKS4ServerRequest) == kWriteHeaderSize, 46 socks4_server_request_struct_wrong_size); 47 48 // A struct holding details of the SOCKS4 Server Response. 49 struct SOCKS4ServerResponse { 50 uint8 reserved_null; 51 uint8 code; 52 uint16 port; 53 uint8 ip[4]; 54 }; 55 COMPILE_ASSERT(sizeof(SOCKS4ServerResponse) == kReadHeaderSize, 56 socks4_server_response_struct_wrong_size); 57 58 SOCKSClientSocket::SOCKSClientSocket( 59 scoped_ptr<ClientSocketHandle> transport_socket, 60 const HostResolver::RequestInfo& req_info, 61 RequestPriority priority, 62 HostResolver* host_resolver) 63 : transport_(transport_socket.Pass()), 64 next_state_(STATE_NONE), 65 completed_handshake_(false), 66 bytes_sent_(0), 67 bytes_received_(0), 68 host_resolver_(host_resolver), 69 host_request_info_(req_info), 70 priority_(priority), 71 net_log_(transport_->socket()->NetLog()) {} 72 73 SOCKSClientSocket::~SOCKSClientSocket() { 74 Disconnect(); 75 } 76 77 int SOCKSClientSocket::Connect(const CompletionCallback& callback) { 78 DCHECK(transport_.get()); 79 DCHECK(transport_->socket()); 80 DCHECK_EQ(STATE_NONE, next_state_); 81 DCHECK(user_callback_.is_null()); 82 83 // If already connected, then just return OK. 84 if (completed_handshake_) 85 return OK; 86 87 next_state_ = STATE_RESOLVE_HOST; 88 89 net_log_.BeginEvent(NetLog::TYPE_SOCKS_CONNECT); 90 91 int rv = DoLoop(OK); 92 if (rv == ERR_IO_PENDING) { 93 user_callback_ = callback; 94 } else { 95 net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS_CONNECT, rv); 96 } 97 return rv; 98 } 99 100 void SOCKSClientSocket::Disconnect() { 101 completed_handshake_ = false; 102 host_resolver_.Cancel(); 103 transport_->socket()->Disconnect(); 104 105 // Reset other states to make sure they aren't mistakenly used later. 106 // These are the states initialized by Connect(). 107 next_state_ = STATE_NONE; 108 user_callback_.Reset(); 109 } 110 111 bool SOCKSClientSocket::IsConnected() const { 112 return completed_handshake_ && transport_->socket()->IsConnected(); 113 } 114 115 bool SOCKSClientSocket::IsConnectedAndIdle() const { 116 return completed_handshake_ && transport_->socket()->IsConnectedAndIdle(); 117 } 118 119 const BoundNetLog& SOCKSClientSocket::NetLog() const { 120 return net_log_; 121 } 122 123 void SOCKSClientSocket::SetSubresourceSpeculation() { 124 if (transport_.get() && transport_->socket()) { 125 transport_->socket()->SetSubresourceSpeculation(); 126 } else { 127 NOTREACHED(); 128 } 129 } 130 131 void SOCKSClientSocket::SetOmniboxSpeculation() { 132 if (transport_.get() && transport_->socket()) { 133 transport_->socket()->SetOmniboxSpeculation(); 134 } else { 135 NOTREACHED(); 136 } 137 } 138 139 bool SOCKSClientSocket::WasEverUsed() const { 140 if (transport_.get() && transport_->socket()) { 141 return transport_->socket()->WasEverUsed(); 142 } 143 NOTREACHED(); 144 return false; 145 } 146 147 bool SOCKSClientSocket::UsingTCPFastOpen() const { 148 if (transport_.get() && transport_->socket()) { 149 return transport_->socket()->UsingTCPFastOpen(); 150 } 151 NOTREACHED(); 152 return false; 153 } 154 155 bool SOCKSClientSocket::WasNpnNegotiated() const { 156 if (transport_.get() && transport_->socket()) { 157 return transport_->socket()->WasNpnNegotiated(); 158 } 159 NOTREACHED(); 160 return false; 161 } 162 163 NextProto SOCKSClientSocket::GetNegotiatedProtocol() const { 164 if (transport_.get() && transport_->socket()) { 165 return transport_->socket()->GetNegotiatedProtocol(); 166 } 167 NOTREACHED(); 168 return kProtoUnknown; 169 } 170 171 bool SOCKSClientSocket::GetSSLInfo(SSLInfo* ssl_info) { 172 if (transport_.get() && transport_->socket()) { 173 return transport_->socket()->GetSSLInfo(ssl_info); 174 } 175 NOTREACHED(); 176 return false; 177 178 } 179 180 // Read is called by the transport layer above to read. This can only be done 181 // if the SOCKS handshake is complete. 182 int SOCKSClientSocket::Read(IOBuffer* buf, int buf_len, 183 const CompletionCallback& callback) { 184 DCHECK(completed_handshake_); 185 DCHECK_EQ(STATE_NONE, next_state_); 186 DCHECK(user_callback_.is_null()); 187 188 return transport_->socket()->Read(buf, buf_len, callback); 189 } 190 191 // Write is called by the transport layer. This can only be done if the 192 // SOCKS handshake is complete. 193 int SOCKSClientSocket::Write(IOBuffer* buf, int buf_len, 194 const CompletionCallback& callback) { 195 DCHECK(completed_handshake_); 196 DCHECK_EQ(STATE_NONE, next_state_); 197 DCHECK(user_callback_.is_null()); 198 199 return transport_->socket()->Write(buf, buf_len, callback); 200 } 201 202 bool SOCKSClientSocket::SetReceiveBufferSize(int32 size) { 203 return transport_->socket()->SetReceiveBufferSize(size); 204 } 205 206 bool SOCKSClientSocket::SetSendBufferSize(int32 size) { 207 return transport_->socket()->SetSendBufferSize(size); 208 } 209 210 void SOCKSClientSocket::DoCallback(int result) { 211 DCHECK_NE(ERR_IO_PENDING, result); 212 DCHECK(!user_callback_.is_null()); 213 214 // Since Run() may result in Read being called, 215 // clear user_callback_ up front. 216 CompletionCallback c = user_callback_; 217 user_callback_.Reset(); 218 DVLOG(1) << "Finished setting up SOCKS handshake"; 219 c.Run(result); 220 } 221 222 void SOCKSClientSocket::OnIOComplete(int result) { 223 DCHECK_NE(STATE_NONE, next_state_); 224 int rv = DoLoop(result); 225 if (rv != ERR_IO_PENDING) { 226 net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS_CONNECT, rv); 227 DoCallback(rv); 228 } 229 } 230 231 int SOCKSClientSocket::DoLoop(int last_io_result) { 232 DCHECK_NE(next_state_, STATE_NONE); 233 int rv = last_io_result; 234 do { 235 State state = next_state_; 236 next_state_ = STATE_NONE; 237 switch (state) { 238 case STATE_RESOLVE_HOST: 239 DCHECK_EQ(OK, rv); 240 rv = DoResolveHost(); 241 break; 242 case STATE_RESOLVE_HOST_COMPLETE: 243 rv = DoResolveHostComplete(rv); 244 break; 245 case STATE_HANDSHAKE_WRITE: 246 DCHECK_EQ(OK, rv); 247 rv = DoHandshakeWrite(); 248 break; 249 case STATE_HANDSHAKE_WRITE_COMPLETE: 250 rv = DoHandshakeWriteComplete(rv); 251 break; 252 case STATE_HANDSHAKE_READ: 253 DCHECK_EQ(OK, rv); 254 rv = DoHandshakeRead(); 255 break; 256 case STATE_HANDSHAKE_READ_COMPLETE: 257 rv = DoHandshakeReadComplete(rv); 258 break; 259 default: 260 NOTREACHED() << "bad state"; 261 rv = ERR_UNEXPECTED; 262 break; 263 } 264 } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); 265 return rv; 266 } 267 268 int SOCKSClientSocket::DoResolveHost() { 269 next_state_ = STATE_RESOLVE_HOST_COMPLETE; 270 // SOCKS4 only supports IPv4 addresses, so only try getting the IPv4 271 // addresses for the target host. 272 host_request_info_.set_address_family(ADDRESS_FAMILY_IPV4); 273 return host_resolver_.Resolve( 274 host_request_info_, 275 priority_, 276 &addresses_, 277 base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)), 278 net_log_); 279 } 280 281 int SOCKSClientSocket::DoResolveHostComplete(int result) { 282 if (result != OK) { 283 // Resolving the hostname failed; fail the request rather than automatically 284 // falling back to SOCKS4a (since it can be confusing to see invalid IP 285 // addresses being sent to the SOCKS4 server when it doesn't support 4A.) 286 return result; 287 } 288 289 next_state_ = STATE_HANDSHAKE_WRITE; 290 return OK; 291 } 292 293 // Builds the buffer that is to be sent to the server. 294 const std::string SOCKSClientSocket::BuildHandshakeWriteBuffer() const { 295 SOCKS4ServerRequest request; 296 request.version = kSOCKSVersion4; 297 request.command = kSOCKSStreamRequest; 298 request.nw_port = base::HostToNet16(host_request_info_.port()); 299 300 DCHECK(!addresses_.empty()); 301 const IPEndPoint& endpoint = addresses_.front(); 302 303 // We disabled IPv6 results when resolving the hostname, so none of the 304 // results in the list will be IPv6. 305 // TODO(eroman): we only ever use the first address in the list. It would be 306 // more robust to try all the IP addresses we have before 307 // failing the connect attempt. 308 CHECK_EQ(ADDRESS_FAMILY_IPV4, endpoint.GetFamily()); 309 CHECK_LE(endpoint.address().size(), sizeof(request.ip)); 310 memcpy(&request.ip, &endpoint.address()[0], endpoint.address().size()); 311 312 DVLOG(1) << "Resolved Host is : " << endpoint.ToStringWithoutPort(); 313 314 std::string handshake_data(reinterpret_cast<char*>(&request), 315 sizeof(request)); 316 handshake_data.append(kEmptyUserId, arraysize(kEmptyUserId)); 317 318 return handshake_data; 319 } 320 321 // Writes the SOCKS handshake data to the underlying socket connection. 322 int SOCKSClientSocket::DoHandshakeWrite() { 323 next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE; 324 325 if (buffer_.empty()) { 326 buffer_ = BuildHandshakeWriteBuffer(); 327 bytes_sent_ = 0; 328 } 329 330 int handshake_buf_len = buffer_.size() - bytes_sent_; 331 DCHECK_GT(handshake_buf_len, 0); 332 handshake_buf_ = new IOBuffer(handshake_buf_len); 333 memcpy(handshake_buf_->data(), &buffer_[bytes_sent_], 334 handshake_buf_len); 335 return transport_->socket()->Write( 336 handshake_buf_.get(), 337 handshake_buf_len, 338 base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this))); 339 } 340 341 int SOCKSClientSocket::DoHandshakeWriteComplete(int result) { 342 if (result < 0) 343 return result; 344 345 // We ignore the case when result is 0, since the underlying Write 346 // may return spurious writes while waiting on the socket. 347 348 bytes_sent_ += result; 349 if (bytes_sent_ == buffer_.size()) { 350 next_state_ = STATE_HANDSHAKE_READ; 351 buffer_.clear(); 352 } else if (bytes_sent_ < buffer_.size()) { 353 next_state_ = STATE_HANDSHAKE_WRITE; 354 } else { 355 return ERR_UNEXPECTED; 356 } 357 358 return OK; 359 } 360 361 int SOCKSClientSocket::DoHandshakeRead() { 362 next_state_ = STATE_HANDSHAKE_READ_COMPLETE; 363 364 if (buffer_.empty()) { 365 bytes_received_ = 0; 366 } 367 368 int handshake_buf_len = kReadHeaderSize - bytes_received_; 369 handshake_buf_ = new IOBuffer(handshake_buf_len); 370 return transport_->socket()->Read( 371 handshake_buf_.get(), 372 handshake_buf_len, 373 base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this))); 374 } 375 376 int SOCKSClientSocket::DoHandshakeReadComplete(int result) { 377 if (result < 0) 378 return result; 379 380 // The underlying socket closed unexpectedly. 381 if (result == 0) 382 return ERR_CONNECTION_CLOSED; 383 384 if (bytes_received_ + result > kReadHeaderSize) { 385 // TODO(eroman): Describe failure in NetLog. 386 return ERR_SOCKS_CONNECTION_FAILED; 387 } 388 389 buffer_.append(handshake_buf_->data(), result); 390 bytes_received_ += result; 391 if (bytes_received_ < kReadHeaderSize) { 392 next_state_ = STATE_HANDSHAKE_READ; 393 return OK; 394 } 395 396 const SOCKS4ServerResponse* response = 397 reinterpret_cast<const SOCKS4ServerResponse*>(buffer_.data()); 398 399 if (response->reserved_null != 0x00) { 400 LOG(ERROR) << "Unknown response from SOCKS server."; 401 return ERR_SOCKS_CONNECTION_FAILED; 402 } 403 404 switch (response->code) { 405 case kServerResponseOk: 406 completed_handshake_ = true; 407 return OK; 408 case kServerResponseRejected: 409 LOG(ERROR) << "SOCKS request rejected or failed"; 410 return ERR_SOCKS_CONNECTION_FAILED; 411 case kServerResponseNotReachable: 412 LOG(ERROR) << "SOCKS request failed because client is not running " 413 << "identd (or not reachable from the server)"; 414 return ERR_SOCKS_CONNECTION_HOST_UNREACHABLE; 415 case kServerResponseMismatchedUserId: 416 LOG(ERROR) << "SOCKS request failed because client's identd could " 417 << "not confirm the user ID string in the request"; 418 return ERR_SOCKS_CONNECTION_FAILED; 419 default: 420 LOG(ERROR) << "SOCKS server sent unknown response"; 421 return ERR_SOCKS_CONNECTION_FAILED; 422 } 423 424 // Note: we ignore the last 6 bytes as specified by the SOCKS protocol 425 } 426 427 int SOCKSClientSocket::GetPeerAddress(IPEndPoint* address) const { 428 return transport_->socket()->GetPeerAddress(address); 429 } 430 431 int SOCKSClientSocket::GetLocalAddress(IPEndPoint* address) const { 432 return transport_->socket()->GetLocalAddress(address); 433 } 434 435 } // namespace net 436