Home | History | Annotate | Download | only in socket
      1 // Copyright (c) 2009 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/socket/socks_client_socket.h"
      6 
      7 #include "base/basictypes.h"
      8 #include "base/compiler_specific.h"
      9 #include "base/trace_event.h"
     10 #include "net/base/io_buffer.h"
     11 #include "net/base/load_log.h"
     12 #include "net/base/net_util.h"
     13 #include "net/base/sys_addrinfo.h"
     14 
     15 namespace net {
     16 
     17 // Every SOCKS server requests a user-id from the client. It is optional
     18 // and we send an empty string.
     19 static const char kEmptyUserId[] = "";
     20 
     21 // The SOCKS4a implementation suggests to use an invalid IP in case the DNS
     22 // resolution at client fails.
     23 static const uint8 kInvalidIp[] = { 0, 0, 0, 127 };
     24 
     25 // For SOCKS4, the client sends 8 bytes  plus the size of the user-id.
     26 // For SOCKS4A, this increases to accomodate the unresolved hostname.
     27 static const unsigned int kWriteHeaderSize = 8;
     28 
     29 // For SOCKS4 and SOCKS4a, the server sends 8 bytes for acknowledgement.
     30 static const unsigned int kReadHeaderSize = 8;
     31 
     32 // Server Response codes for SOCKS.
     33 static const uint8 kServerResponseOk  = 0x5A;
     34 static const uint8 kServerResponseRejected = 0x5B;
     35 static const uint8 kServerResponseNotReachable = 0x5C;
     36 static const uint8 kServerResponseMismatchedUserId = 0x5D;
     37 
     38 static const uint8 kSOCKSVersion4 = 0x04;
     39 static const uint8 kSOCKSStreamRequest = 0x01;
     40 
// A struct holding the essential details of the SOCKS4/4a Server Request.
// The port in the header is stored in network byte order.
// The field order and types ARE the on-the-wire layout of the fixed request
// header; the COMPILE_ASSERT below guards against any padding creeping in.
struct SOCKS4ServerRequest {
  uint8 version;   // Protocol version; always kSOCKSVersion4.
  uint8 command;   // Request command; always kSOCKSStreamRequest here.
  uint16 nw_port;  // Destination port, already converted with htons().
  uint8 ip[4];     // Destination IPv4 address, or kInvalidIp for SOCKS4a.
};
COMPILE_ASSERT(sizeof(SOCKS4ServerRequest) == kWriteHeaderSize,
               socks4_server_request_struct_wrong_size);

// A struct holding details of the SOCKS4/4a Server Response.
// Like the request, this mirrors the 8-byte wire layout of the reply.
struct SOCKS4ServerResponse {
  uint8 reserved_null;  // Must be 0x00; anything else is rejected.
  uint8 code;           // One of the kServerResponse* reply codes.
  // The remaining fields are not interpreted by this client.
  // NOTE(review): presumably in network byte order like the request — confirm
  // before ever reading them.
  uint16 port;
  uint8 ip[4];
};
COMPILE_ASSERT(sizeof(SOCKS4ServerResponse) == kReadHeaderSize,
               socks4_server_response_struct_wrong_size);
     61 
// Wraps |transport_socket| in a SOCKS4/4a client. The transport is expected
// to be connected to the proxy already (Connect() DCHECKs this). |req_info|
// names the destination host/port to ask the proxy for, and |host_resolver|
// is used to try resolving that host on the client before the handshake.
// NOTE(review): |transport_| presumably takes ownership of |transport_socket|
// (its declared type lives in the header) — confirm.
SOCKSClientSocket::SOCKSClientSocket(ClientSocket* transport_socket,
                                     const HostResolver::RequestInfo& req_info,
                                     HostResolver* host_resolver)
    : ALLOW_THIS_IN_INITIALIZER_LIST(
          io_callback_(this, &SOCKSClientSocket::OnIOComplete)),
      transport_(transport_socket),
      next_state_(STATE_NONE),
      // The concrete SOCKS variant is chosen later, in DoResolveHostComplete().
      socks_version_(kSOCKS4Unresolved),
      user_callback_(NULL),
      completed_handshake_(false),
      bytes_sent_(0),
      bytes_received_(0),
      host_resolver_(host_resolver),
      host_request_info_(req_info) {
}
     77 
     78 SOCKSClientSocket::~SOCKSClientSocket() {
     79   Disconnect();
     80 }
     81 
     82 int SOCKSClientSocket::Connect(CompletionCallback* callback,
     83                                LoadLog* load_log) {
     84   DCHECK(transport_.get());
     85   DCHECK(transport_->IsConnected());
     86   DCHECK_EQ(STATE_NONE, next_state_);
     87   DCHECK(!user_callback_);
     88 
     89   // If already connected, then just return OK.
     90   if (completed_handshake_)
     91     return OK;
     92 
     93   next_state_ = STATE_RESOLVE_HOST;
     94   load_log_ = load_log;
     95 
     96   LoadLog::BeginEvent(load_log, LoadLog::TYPE_SOCKS_CONNECT);
     97 
     98   int rv = DoLoop(OK);
     99   if (rv == ERR_IO_PENDING) {
    100     user_callback_ = callback;
    101   } else {
    102     LoadLog::EndEvent(load_log, LoadLog::TYPE_SOCKS_CONNECT);
    103     load_log_ = NULL;
    104   }
    105   return rv;
    106 }
    107 
    108 void SOCKSClientSocket::Disconnect() {
    109   completed_handshake_ = false;
    110   host_resolver_.Cancel();
    111   transport_->Disconnect();
    112 
    113   // Reset other states to make sure they aren't mistakenly used later.
    114   // These are the states initialized by Connect().
    115   next_state_ = STATE_NONE;
    116   user_callback_ = NULL;
    117   load_log_ = NULL;
    118 }
    119 
    120 bool SOCKSClientSocket::IsConnected() const {
    121   return completed_handshake_ && transport_->IsConnected();
    122 }
    123 
    124 bool SOCKSClientSocket::IsConnectedAndIdle() const {
    125   return completed_handshake_ && transport_->IsConnectedAndIdle();
    126 }
    127 
    128 // Read is called by the transport layer above to read. This can only be done
    129 // if the SOCKS handshake is complete.
    130 int SOCKSClientSocket::Read(IOBuffer* buf, int buf_len,
    131                             CompletionCallback* callback) {
    132   DCHECK(completed_handshake_);
    133   DCHECK_EQ(STATE_NONE, next_state_);
    134   DCHECK(!user_callback_);
    135 
    136   return transport_->Read(buf, buf_len, callback);
    137 }
    138 
    139 // Write is called by the transport layer. This can only be done if the
    140 // SOCKS handshake is complete.
    141 int SOCKSClientSocket::Write(IOBuffer* buf, int buf_len,
    142                              CompletionCallback* callback) {
    143   DCHECK(completed_handshake_);
    144   DCHECK_EQ(STATE_NONE, next_state_);
    145   DCHECK(!user_callback_);
    146 
    147   return transport_->Write(buf, buf_len, callback);
    148 }
    149 
    150 bool SOCKSClientSocket::SetReceiveBufferSize(int32 size) {
    151   return transport_->SetReceiveBufferSize(size);
    152 }
    153 
    154 bool SOCKSClientSocket::SetSendBufferSize(int32 size) {
    155   return transport_->SetSendBufferSize(size);
    156 }
    157 
    158 void SOCKSClientSocket::DoCallback(int result) {
    159   DCHECK_NE(ERR_IO_PENDING, result);
    160   DCHECK(user_callback_);
    161 
    162   // Since Run() may result in Read being called,
    163   // clear user_callback_ up front.
    164   CompletionCallback* c = user_callback_;
    165   user_callback_ = NULL;
    166   DLOG(INFO) << "Finished setting up SOCKS handshake";
    167   c->Run(result);
    168 }
    169 
    170 void SOCKSClientSocket::OnIOComplete(int result) {
    171   DCHECK_NE(STATE_NONE, next_state_);
    172   int rv = DoLoop(result);
    173   if (rv != ERR_IO_PENDING) {
    174     LoadLog::EndEvent(load_log_, LoadLog::TYPE_SOCKS_CONNECT);
    175     load_log_ = NULL;
    176     DoCallback(rv);
    177   }
    178 }
    179 
    180 int SOCKSClientSocket::DoLoop(int last_io_result) {
    181   DCHECK_NE(next_state_, STATE_NONE);
    182   int rv = last_io_result;
    183   do {
    184     State state = next_state_;
    185     next_state_ = STATE_NONE;
    186     switch (state) {
    187       case STATE_RESOLVE_HOST:
    188         DCHECK_EQ(OK, rv);
    189         rv = DoResolveHost();
    190         break;
    191       case STATE_RESOLVE_HOST_COMPLETE:
    192         rv = DoResolveHostComplete(rv);
    193         break;
    194       case STATE_HANDSHAKE_WRITE:
    195         DCHECK_EQ(OK, rv);
    196         rv = DoHandshakeWrite();
    197         break;
    198       case STATE_HANDSHAKE_WRITE_COMPLETE:
    199         rv = DoHandshakeWriteComplete(rv);
    200         break;
    201       case STATE_HANDSHAKE_READ:
    202         DCHECK_EQ(OK, rv);
    203         rv = DoHandshakeRead();
    204         break;
    205       case STATE_HANDSHAKE_READ_COMPLETE:
    206         rv = DoHandshakeReadComplete(rv);
    207         break;
    208       default:
    209         NOTREACHED() << "bad state";
    210         rv = ERR_UNEXPECTED;
    211         break;
    212     }
    213   } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
    214   return rv;
    215 }
    216 
    217 int SOCKSClientSocket::DoResolveHost() {
    218   DCHECK_EQ(kSOCKS4Unresolved, socks_version_);
    219 
    220   next_state_ = STATE_RESOLVE_HOST_COMPLETE;
    221   return host_resolver_.Resolve(
    222       host_request_info_, &addresses_, &io_callback_, load_log_);
    223 }
    224 
    225 int SOCKSClientSocket::DoResolveHostComplete(int result) {
    226   DCHECK_EQ(kSOCKS4Unresolved, socks_version_);
    227 
    228   bool ok = (result == OK);
    229   next_state_ = STATE_HANDSHAKE_WRITE;
    230   if (ok) {
    231     DCHECK(addresses_.head());
    232 
    233     // If the host is resolved to an IPv6 address, we revert to SOCKS4a
    234     // since IPv6 is unsupported by SOCKS4/4a protocol.
    235     struct sockaddr *host_info = addresses_.head()->ai_addr;
    236     if (host_info->sa_family == AF_INET) {
    237       DLOG(INFO) << "Resolved host. Using SOCKS4 to communicate";
    238       socks_version_ = kSOCKS4;
    239     } else {
    240       DLOG(INFO) << "Resolved host but to IPv6. Using SOCKS4a to communicate";
    241       socks_version_ = kSOCKS4a;
    242     }
    243   } else {
    244     DLOG(INFO) << "Could not resolve host. Using SOCKS4a to communicate";
    245     socks_version_ = kSOCKS4a;
    246   }
    247 
    248   // Even if DNS resolution fails, we send OK since the server
    249   // resolves the domain.
    250   return OK;
    251 }
    252 
    253 // Builds the buffer that is to be sent to the server.
    254 // We check whether the SOCKS proxy is 4 or 4A.
    255 // In case it is 4A, the record size increases by size of the hostname.
    256 const std::string SOCKSClientSocket::BuildHandshakeWriteBuffer() const {
    257   DCHECK_NE(kSOCKS4Unresolved, socks_version_);
    258 
    259   SOCKS4ServerRequest request;
    260   request.version = kSOCKSVersion4;
    261   request.command = kSOCKSStreamRequest;
    262   request.nw_port = htons(host_request_info_.port());
    263 
    264   if (socks_version_ == kSOCKS4) {
    265     const struct addrinfo* ai = addresses_.head();
    266     DCHECK(ai);
    267     // If the sockaddr is IPv6, we have already marked the version to socks4a
    268     // and so this step does not get hit.
    269     struct sockaddr_in* ipv4_host =
    270         reinterpret_cast<struct sockaddr_in*>(ai->ai_addr);
    271     memcpy(&request.ip, &(ipv4_host->sin_addr), sizeof(ipv4_host->sin_addr));
    272 
    273     DLOG(INFO) << "Resolved Host is : " << NetAddressToString(ai);
    274   } else if (socks_version_ == kSOCKS4a) {
    275     // invalid IP of the form 0.0.0.127
    276     memcpy(&request.ip, kInvalidIp, arraysize(kInvalidIp));
    277   } else {
    278     NOTREACHED();
    279   }
    280 
    281   std::string handshake_data(reinterpret_cast<char*>(&request),
    282                              sizeof(request));
    283   handshake_data.append(kEmptyUserId, arraysize(kEmptyUserId));
    284 
    285   // In case we are passing the domain also, pass the hostname
    286   // terminated with a null character.
    287   if (socks_version_ == kSOCKS4a) {
    288     handshake_data.append(host_request_info_.hostname());
    289     handshake_data.push_back('\0');
    290   }
    291 
    292   return handshake_data;
    293 }
    294 
    295 // Writes the SOCKS handshake data to the underlying socket connection.
    296 int SOCKSClientSocket::DoHandshakeWrite() {
    297   next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE;
    298 
    299   if (buffer_.empty()) {
    300     buffer_ = BuildHandshakeWriteBuffer();
    301     bytes_sent_ = 0;
    302   }
    303 
    304   int handshake_buf_len = buffer_.size() - bytes_sent_;
    305   DCHECK_GT(handshake_buf_len, 0);
    306   handshake_buf_ = new IOBuffer(handshake_buf_len);
    307   memcpy(handshake_buf_->data(), &buffer_[bytes_sent_],
    308          handshake_buf_len);
    309   return transport_->Write(handshake_buf_, handshake_buf_len, &io_callback_);
    310 }
    311 
    312 int SOCKSClientSocket::DoHandshakeWriteComplete(int result) {
    313   DCHECK_NE(kSOCKS4Unresolved, socks_version_);
    314 
    315   if (result < 0)
    316     return result;
    317 
    318   // We ignore the case when result is 0, since the underlying Write
    319   // may return spurious writes while waiting on the socket.
    320 
    321   bytes_sent_ += result;
    322   if (bytes_sent_ == buffer_.size()) {
    323     next_state_ = STATE_HANDSHAKE_READ;
    324     buffer_.clear();
    325   } else if (bytes_sent_ < buffer_.size()) {
    326     next_state_ = STATE_HANDSHAKE_WRITE;
    327   } else {
    328     return ERR_UNEXPECTED;
    329   }
    330 
    331   return OK;
    332 }
    333 
    334 int SOCKSClientSocket::DoHandshakeRead() {
    335   DCHECK_NE(kSOCKS4Unresolved, socks_version_);
    336 
    337   next_state_ = STATE_HANDSHAKE_READ_COMPLETE;
    338 
    339   if (buffer_.empty()) {
    340     bytes_received_ = 0;
    341   }
    342 
    343   int handshake_buf_len = kReadHeaderSize - bytes_received_;
    344   handshake_buf_ = new IOBuffer(handshake_buf_len);
    345   return transport_->Read(handshake_buf_, handshake_buf_len, &io_callback_);
    346 }
    347 
    348 int SOCKSClientSocket::DoHandshakeReadComplete(int result) {
    349   DCHECK_NE(kSOCKS4Unresolved, socks_version_);
    350 
    351   if (result < 0)
    352     return result;
    353 
    354   // The underlying socket closed unexpectedly.
    355   if (result == 0)
    356     return ERR_CONNECTION_CLOSED;
    357 
    358   if (bytes_received_ + result > kReadHeaderSize)
    359     return ERR_INVALID_RESPONSE;
    360 
    361   buffer_.append(handshake_buf_->data(), result);
    362   bytes_received_ += result;
    363   if (bytes_received_ < kReadHeaderSize) {
    364     next_state_ = STATE_HANDSHAKE_READ;
    365     return OK;
    366   }
    367 
    368   const SOCKS4ServerResponse* response =
    369       reinterpret_cast<const SOCKS4ServerResponse*>(buffer_.data());
    370 
    371   if (response->reserved_null != 0x00) {
    372     LOG(ERROR) << "Unknown response from SOCKS server.";
    373     return ERR_INVALID_RESPONSE;
    374   }
    375 
    376   // TODO(arindam): Add SOCKS specific failure codes in net_error_list.h
    377   switch (response->code) {
    378     case kServerResponseOk:
    379       completed_handshake_ = true;
    380       return OK;
    381     case kServerResponseRejected:
    382       LOG(ERROR) << "SOCKS request rejected or failed";
    383       return ERR_FAILED;
    384     case kServerResponseNotReachable:
    385       LOG(ERROR) << "SOCKS request failed because client is not running "
    386                  << "identd (or not reachable from the server)";
    387       return ERR_NAME_NOT_RESOLVED;
    388     case kServerResponseMismatchedUserId:
    389       LOG(ERROR) << "SOCKS request failed because client's identd could "
    390                  << "not confirm the user ID string in the request";
    391       return ERR_FAILED;
    392     default:
    393       LOG(ERROR) << "SOCKS server sent unknown response";
    394       return ERR_INVALID_RESPONSE;
    395   }
    396 
    397   // Note: we ignore the last 6 bytes as specified by the SOCKS protocol
    398 }
    399 
    400 int SOCKSClientSocket::GetPeerName(struct sockaddr* name,
    401                                    socklen_t* namelen) {
    402   return transport_->GetPeerName(name, namelen);
    403 }
    404 
    405 }  // namespace net
    406