Home | History | Annotate | Download | only in socket
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/socket/socks_client_socket.h"
      6 
      7 #include "base/basictypes.h"
      8 #include "base/bind.h"
      9 #include "base/compiler_specific.h"
     10 #include "base/sys_byteorder.h"
     11 #include "net/base/io_buffer.h"
     12 #include "net/base/net_log.h"
     13 #include "net/base/net_util.h"
     14 #include "net/socket/client_socket_handle.h"
     15 
     16 namespace net {
     17 
     18 // Every SOCKS server requests a user-id from the client. It is optional
     19 // and we send an empty string.
     20 static const char kEmptyUserId[] = "";
     21 
     22 // For SOCKS4, the client sends 8 bytes  plus the size of the user-id.
     23 static const unsigned int kWriteHeaderSize = 8;
     24 
     25 // For SOCKS4 the server sends 8 bytes for acknowledgement.
     26 static const unsigned int kReadHeaderSize = 8;
     27 
     28 // Server Response codes for SOCKS.
     29 static const uint8 kServerResponseOk  = 0x5A;
     30 static const uint8 kServerResponseRejected = 0x5B;
     31 static const uint8 kServerResponseNotReachable = 0x5C;
     32 static const uint8 kServerResponseMismatchedUserId = 0x5D;
     33 
     34 static const uint8 kSOCKSVersion4 = 0x04;
     35 static const uint8 kSOCKSStreamRequest = 0x01;
     36 
     37 // A struct holding the essential details of the SOCKS4 Server Request.
     38 // The port in the header is stored in network byte order.
     39 struct SOCKS4ServerRequest {
     40   uint8 version;
     41   uint8 command;
     42   uint16 nw_port;
     43   uint8 ip[4];
     44 };
     45 COMPILE_ASSERT(sizeof(SOCKS4ServerRequest) == kWriteHeaderSize,
     46                socks4_server_request_struct_wrong_size);
     47 
     48 // A struct holding details of the SOCKS4 Server Response.
     49 struct SOCKS4ServerResponse {
     50   uint8 reserved_null;
     51   uint8 code;
     52   uint16 port;
     53   uint8 ip[4];
     54 };
     55 COMPILE_ASSERT(sizeof(SOCKS4ServerResponse) == kReadHeaderSize,
     56                socks4_server_response_struct_wrong_size);
     57 
     58 SOCKSClientSocket::SOCKSClientSocket(ClientSocketHandle* transport_socket,
     59                                      const HostResolver::RequestInfo& req_info,
     60                                      HostResolver* host_resolver)
     61     : transport_(transport_socket),
     62       next_state_(STATE_NONE),
     63       completed_handshake_(false),
     64       bytes_sent_(0),
     65       bytes_received_(0),
     66       host_resolver_(host_resolver),
     67       host_request_info_(req_info),
     68       net_log_(transport_socket->socket()->NetLog()) {
     69 }
     70 
     71 SOCKSClientSocket::SOCKSClientSocket(StreamSocket* transport_socket,
     72                                      const HostResolver::RequestInfo& req_info,
     73                                      HostResolver* host_resolver)
     74     : transport_(new ClientSocketHandle()),
     75       next_state_(STATE_NONE),
     76       completed_handshake_(false),
     77       bytes_sent_(0),
     78       bytes_received_(0),
     79       host_resolver_(host_resolver),
     80       host_request_info_(req_info),
     81       net_log_(transport_socket->NetLog()) {
     82   transport_->set_socket(transport_socket);
     83 }
     84 
     85 SOCKSClientSocket::~SOCKSClientSocket() {
     86   Disconnect();
     87 }
     88 
     89 int SOCKSClientSocket::Connect(const CompletionCallback& callback) {
     90   DCHECK(transport_.get());
     91   DCHECK(transport_->socket());
     92   DCHECK_EQ(STATE_NONE, next_state_);
     93   DCHECK(user_callback_.is_null());
     94 
     95   // If already connected, then just return OK.
     96   if (completed_handshake_)
     97     return OK;
     98 
     99   next_state_ = STATE_RESOLVE_HOST;
    100 
    101   net_log_.BeginEvent(NetLog::TYPE_SOCKS_CONNECT);
    102 
    103   int rv = DoLoop(OK);
    104   if (rv == ERR_IO_PENDING) {
    105     user_callback_ = callback;
    106   } else {
    107     net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS_CONNECT, rv);
    108   }
    109   return rv;
    110 }
    111 
    112 void SOCKSClientSocket::Disconnect() {
    113   completed_handshake_ = false;
    114   host_resolver_.Cancel();
    115   transport_->socket()->Disconnect();
    116 
    117   // Reset other states to make sure they aren't mistakenly used later.
    118   // These are the states initialized by Connect().
    119   next_state_ = STATE_NONE;
    120   user_callback_.Reset();
    121 }
    122 
    123 bool SOCKSClientSocket::IsConnected() const {
    124   return completed_handshake_ && transport_->socket()->IsConnected();
    125 }
    126 
    127 bool SOCKSClientSocket::IsConnectedAndIdle() const {
    128   return completed_handshake_ && transport_->socket()->IsConnectedAndIdle();
    129 }
    130 
    131 const BoundNetLog& SOCKSClientSocket::NetLog() const {
    132   return net_log_;
    133 }
    134 
    135 void SOCKSClientSocket::SetSubresourceSpeculation() {
    136   if (transport_.get() && transport_->socket()) {
    137     transport_->socket()->SetSubresourceSpeculation();
    138   } else {
    139     NOTREACHED();
    140   }
    141 }
    142 
    143 void SOCKSClientSocket::SetOmniboxSpeculation() {
    144   if (transport_.get() && transport_->socket()) {
    145     transport_->socket()->SetOmniboxSpeculation();
    146   } else {
    147     NOTREACHED();
    148   }
    149 }
    150 
    151 bool SOCKSClientSocket::WasEverUsed() const {
    152   if (transport_.get() && transport_->socket()) {
    153     return transport_->socket()->WasEverUsed();
    154   }
    155   NOTREACHED();
    156   return false;
    157 }
    158 
    159 bool SOCKSClientSocket::UsingTCPFastOpen() const {
    160   if (transport_.get() && transport_->socket()) {
    161     return transport_->socket()->UsingTCPFastOpen();
    162   }
    163   NOTREACHED();
    164   return false;
    165 }
    166 
    167 bool SOCKSClientSocket::WasNpnNegotiated() const {
    168   if (transport_.get() && transport_->socket()) {
    169     return transport_->socket()->WasNpnNegotiated();
    170   }
    171   NOTREACHED();
    172   return false;
    173 }
    174 
    175 NextProto SOCKSClientSocket::GetNegotiatedProtocol() const {
    176   if (transport_.get() && transport_->socket()) {
    177     return transport_->socket()->GetNegotiatedProtocol();
    178   }
    179   NOTREACHED();
    180   return kProtoUnknown;
    181 }
    182 
    183 bool SOCKSClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
    184   if (transport_.get() && transport_->socket()) {
    185     return transport_->socket()->GetSSLInfo(ssl_info);
    186   }
    187   NOTREACHED();
    188   return false;
    189 
    190 }
    191 
    192 // Read is called by the transport layer above to read. This can only be done
    193 // if the SOCKS handshake is complete.
    194 int SOCKSClientSocket::Read(IOBuffer* buf, int buf_len,
    195                             const CompletionCallback& callback) {
    196   DCHECK(completed_handshake_);
    197   DCHECK_EQ(STATE_NONE, next_state_);
    198   DCHECK(user_callback_.is_null());
    199 
    200   return transport_->socket()->Read(buf, buf_len, callback);
    201 }
    202 
    203 // Write is called by the transport layer. This can only be done if the
    204 // SOCKS handshake is complete.
    205 int SOCKSClientSocket::Write(IOBuffer* buf, int buf_len,
    206                              const CompletionCallback& callback) {
    207   DCHECK(completed_handshake_);
    208   DCHECK_EQ(STATE_NONE, next_state_);
    209   DCHECK(user_callback_.is_null());
    210 
    211   return transport_->socket()->Write(buf, buf_len, callback);
    212 }
    213 
    214 bool SOCKSClientSocket::SetReceiveBufferSize(int32 size) {
    215   return transport_->socket()->SetReceiveBufferSize(size);
    216 }
    217 
    218 bool SOCKSClientSocket::SetSendBufferSize(int32 size) {
    219   return transport_->socket()->SetSendBufferSize(size);
    220 }
    221 
    222 void SOCKSClientSocket::DoCallback(int result) {
    223   DCHECK_NE(ERR_IO_PENDING, result);
    224   DCHECK(!user_callback_.is_null());
    225 
    226   // Since Run() may result in Read being called,
    227   // clear user_callback_ up front.
    228   CompletionCallback c = user_callback_;
    229   user_callback_.Reset();
    230   DVLOG(1) << "Finished setting up SOCKS handshake";
    231   c.Run(result);
    232 }
    233 
    234 void SOCKSClientSocket::OnIOComplete(int result) {
    235   DCHECK_NE(STATE_NONE, next_state_);
    236   int rv = DoLoop(result);
    237   if (rv != ERR_IO_PENDING) {
    238     net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS_CONNECT, rv);
    239     DoCallback(rv);
    240   }
    241 }
    242 
    243 int SOCKSClientSocket::DoLoop(int last_io_result) {
    244   DCHECK_NE(next_state_, STATE_NONE);
    245   int rv = last_io_result;
    246   do {
    247     State state = next_state_;
    248     next_state_ = STATE_NONE;
    249     switch (state) {
    250       case STATE_RESOLVE_HOST:
    251         DCHECK_EQ(OK, rv);
    252         rv = DoResolveHost();
    253         break;
    254       case STATE_RESOLVE_HOST_COMPLETE:
    255         rv = DoResolveHostComplete(rv);
    256         break;
    257       case STATE_HANDSHAKE_WRITE:
    258         DCHECK_EQ(OK, rv);
    259         rv = DoHandshakeWrite();
    260         break;
    261       case STATE_HANDSHAKE_WRITE_COMPLETE:
    262         rv = DoHandshakeWriteComplete(rv);
    263         break;
    264       case STATE_HANDSHAKE_READ:
    265         DCHECK_EQ(OK, rv);
    266         rv = DoHandshakeRead();
    267         break;
    268       case STATE_HANDSHAKE_READ_COMPLETE:
    269         rv = DoHandshakeReadComplete(rv);
    270         break;
    271       default:
    272         NOTREACHED() << "bad state";
    273         rv = ERR_UNEXPECTED;
    274         break;
    275     }
    276   } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
    277   return rv;
    278 }
    279 
    280 int SOCKSClientSocket::DoResolveHost() {
    281   next_state_ = STATE_RESOLVE_HOST_COMPLETE;
    282   // SOCKS4 only supports IPv4 addresses, so only try getting the IPv4
    283   // addresses for the target host.
    284   host_request_info_.set_address_family(ADDRESS_FAMILY_IPV4);
    285   return host_resolver_.Resolve(
    286       host_request_info_, &addresses_,
    287       base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)),
    288       net_log_);
    289 }
    290 
    291 int SOCKSClientSocket::DoResolveHostComplete(int result) {
    292   if (result != OK) {
    293     // Resolving the hostname failed; fail the request rather than automatically
    294     // falling back to SOCKS4a (since it can be confusing to see invalid IP
    295     // addresses being sent to the SOCKS4 server when it doesn't support 4A.)
    296     return result;
    297   }
    298 
    299   next_state_ = STATE_HANDSHAKE_WRITE;
    300   return OK;
    301 }
    302 
    303 // Builds the buffer that is to be sent to the server.
    304 const std::string SOCKSClientSocket::BuildHandshakeWriteBuffer() const {
    305   SOCKS4ServerRequest request;
    306   request.version = kSOCKSVersion4;
    307   request.command = kSOCKSStreamRequest;
    308   request.nw_port = base::HostToNet16(host_request_info_.port());
    309 
    310   DCHECK(!addresses_.empty());
    311   const IPEndPoint& endpoint = addresses_.front();
    312 
    313   // We disabled IPv6 results when resolving the hostname, so none of the
    314   // results in the list will be IPv6.
    315   // TODO(eroman): we only ever use the first address in the list. It would be
    316   //               more robust to try all the IP addresses we have before
    317   //               failing the connect attempt.
    318   CHECK_EQ(ADDRESS_FAMILY_IPV4, endpoint.GetFamily());
    319   CHECK_LE(endpoint.address().size(), sizeof(request.ip));
    320   memcpy(&request.ip, &endpoint.address()[0], endpoint.address().size());
    321 
    322   DVLOG(1) << "Resolved Host is : " << endpoint.ToStringWithoutPort();
    323 
    324   std::string handshake_data(reinterpret_cast<char*>(&request),
    325                              sizeof(request));
    326   handshake_data.append(kEmptyUserId, arraysize(kEmptyUserId));
    327 
    328   return handshake_data;
    329 }
    330 
    331 // Writes the SOCKS handshake data to the underlying socket connection.
    332 int SOCKSClientSocket::DoHandshakeWrite() {
    333   next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE;
    334 
    335   if (buffer_.empty()) {
    336     buffer_ = BuildHandshakeWriteBuffer();
    337     bytes_sent_ = 0;
    338   }
    339 
    340   int handshake_buf_len = buffer_.size() - bytes_sent_;
    341   DCHECK_GT(handshake_buf_len, 0);
    342   handshake_buf_ = new IOBuffer(handshake_buf_len);
    343   memcpy(handshake_buf_->data(), &buffer_[bytes_sent_],
    344          handshake_buf_len);
    345   return transport_->socket()->Write(
    346       handshake_buf_.get(),
    347       handshake_buf_len,
    348       base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)));
    349 }
    350 
    351 int SOCKSClientSocket::DoHandshakeWriteComplete(int result) {
    352   if (result < 0)
    353     return result;
    354 
    355   // We ignore the case when result is 0, since the underlying Write
    356   // may return spurious writes while waiting on the socket.
    357 
    358   bytes_sent_ += result;
    359   if (bytes_sent_ == buffer_.size()) {
    360     next_state_ = STATE_HANDSHAKE_READ;
    361     buffer_.clear();
    362   } else if (bytes_sent_ < buffer_.size()) {
    363     next_state_ = STATE_HANDSHAKE_WRITE;
    364   } else {
    365     return ERR_UNEXPECTED;
    366   }
    367 
    368   return OK;
    369 }
    370 
    371 int SOCKSClientSocket::DoHandshakeRead() {
    372   next_state_ = STATE_HANDSHAKE_READ_COMPLETE;
    373 
    374   if (buffer_.empty()) {
    375     bytes_received_ = 0;
    376   }
    377 
    378   int handshake_buf_len = kReadHeaderSize - bytes_received_;
    379   handshake_buf_ = new IOBuffer(handshake_buf_len);
    380   return transport_->socket()->Read(
    381       handshake_buf_.get(),
    382       handshake_buf_len,
    383       base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)));
    384 }
    385 
    386 int SOCKSClientSocket::DoHandshakeReadComplete(int result) {
    387   if (result < 0)
    388     return result;
    389 
    390   // The underlying socket closed unexpectedly.
    391   if (result == 0)
    392     return ERR_CONNECTION_CLOSED;
    393 
    394   if (bytes_received_ + result > kReadHeaderSize) {
    395     // TODO(eroman): Describe failure in NetLog.
    396     return ERR_SOCKS_CONNECTION_FAILED;
    397   }
    398 
    399   buffer_.append(handshake_buf_->data(), result);
    400   bytes_received_ += result;
    401   if (bytes_received_ < kReadHeaderSize) {
    402     next_state_ = STATE_HANDSHAKE_READ;
    403     return OK;
    404   }
    405 
    406   const SOCKS4ServerResponse* response =
    407       reinterpret_cast<const SOCKS4ServerResponse*>(buffer_.data());
    408 
    409   if (response->reserved_null != 0x00) {
    410     LOG(ERROR) << "Unknown response from SOCKS server.";
    411     return ERR_SOCKS_CONNECTION_FAILED;
    412   }
    413 
    414   switch (response->code) {
    415     case kServerResponseOk:
    416       completed_handshake_ = true;
    417       return OK;
    418     case kServerResponseRejected:
    419       LOG(ERROR) << "SOCKS request rejected or failed";
    420       return ERR_SOCKS_CONNECTION_FAILED;
    421     case kServerResponseNotReachable:
    422       LOG(ERROR) << "SOCKS request failed because client is not running "
    423                  << "identd (or not reachable from the server)";
    424       return ERR_SOCKS_CONNECTION_HOST_UNREACHABLE;
    425     case kServerResponseMismatchedUserId:
    426       LOG(ERROR) << "SOCKS request failed because client's identd could "
    427                  << "not confirm the user ID string in the request";
    428       return ERR_SOCKS_CONNECTION_FAILED;
    429     default:
    430       LOG(ERROR) << "SOCKS server sent unknown response";
    431       return ERR_SOCKS_CONNECTION_FAILED;
    432   }
    433 
    434   // Note: we ignore the last 6 bytes as specified by the SOCKS protocol
    435 }
    436 
    437 int SOCKSClientSocket::GetPeerAddress(IPEndPoint* address) const {
    438   return transport_->socket()->GetPeerAddress(address);
    439 }
    440 
    441 int SOCKSClientSocket::GetLocalAddress(IPEndPoint* address) const {
    442   return transport_->socket()->GetLocalAddress(address);
    443 }
    444 
    445 }  // namespace net
    446