Home | History | Annotate | Download | only in socket
      1 // Copyright (c) 2011 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/socket/socks_client_socket.h"
      6 
      7 #include "base/basictypes.h"
      8 #include "base/compiler_specific.h"
      9 #include "net/base/io_buffer.h"
     10 #include "net/base/net_log.h"
     11 #include "net/base/net_util.h"
     12 #include "net/base/sys_addrinfo.h"
     13 #include "net/socket/client_socket_handle.h"
     14 
     15 namespace net {
     16 
     17 // Every SOCKS server requests a user-id from the client. It is optional
     18 // and we send an empty string.
     19 static const char kEmptyUserId[] = "";
     20 
     21 // For SOCKS4, the client sends 8 bytes  plus the size of the user-id.
     22 static const unsigned int kWriteHeaderSize = 8;
     23 
     24 // For SOCKS4 the server sends 8 bytes for acknowledgement.
     25 static const unsigned int kReadHeaderSize = 8;
     26 
     27 // Server Response codes for SOCKS.
     28 static const uint8 kServerResponseOk  = 0x5A;
     29 static const uint8 kServerResponseRejected = 0x5B;
     30 static const uint8 kServerResponseNotReachable = 0x5C;
     31 static const uint8 kServerResponseMismatchedUserId = 0x5D;
     32 
     33 static const uint8 kSOCKSVersion4 = 0x04;
     34 static const uint8 kSOCKSStreamRequest = 0x01;
     35 
     36 // A struct holding the essential details of the SOCKS4 Server Request.
     37 // The port in the header is stored in network byte order.
     38 struct SOCKS4ServerRequest {
     39   uint8 version;
     40   uint8 command;
     41   uint16 nw_port;
     42   uint8 ip[4];
     43 };
     44 COMPILE_ASSERT(sizeof(SOCKS4ServerRequest) == kWriteHeaderSize,
     45                socks4_server_request_struct_wrong_size);
     46 
     47 // A struct holding details of the SOCKS4 Server Response.
     48 struct SOCKS4ServerResponse {
     49   uint8 reserved_null;
     50   uint8 code;
     51   uint16 port;
     52   uint8 ip[4];
     53 };
     54 COMPILE_ASSERT(sizeof(SOCKS4ServerResponse) == kReadHeaderSize,
     55                socks4_server_response_struct_wrong_size);
     56 
     57 SOCKSClientSocket::SOCKSClientSocket(ClientSocketHandle* transport_socket,
     58                                      const HostResolver::RequestInfo& req_info,
     59                                      HostResolver* host_resolver)
     60     : ALLOW_THIS_IN_INITIALIZER_LIST(
     61           io_callback_(this, &SOCKSClientSocket::OnIOComplete)),
     62       transport_(transport_socket),
     63       next_state_(STATE_NONE),
     64       user_callback_(NULL),
     65       completed_handshake_(false),
     66       bytes_sent_(0),
     67       bytes_received_(0),
     68       host_resolver_(host_resolver),
     69       host_request_info_(req_info),
     70       net_log_(transport_socket->socket()->NetLog()) {
     71 }
     72 
     73 SOCKSClientSocket::SOCKSClientSocket(ClientSocket* transport_socket,
     74                                      const HostResolver::RequestInfo& req_info,
     75                                      HostResolver* host_resolver)
     76     : ALLOW_THIS_IN_INITIALIZER_LIST(
     77           io_callback_(this, &SOCKSClientSocket::OnIOComplete)),
     78       transport_(new ClientSocketHandle()),
     79       next_state_(STATE_NONE),
     80       user_callback_(NULL),
     81       completed_handshake_(false),
     82       bytes_sent_(0),
     83       bytes_received_(0),
     84       host_resolver_(host_resolver),
     85       host_request_info_(req_info),
     86       net_log_(transport_socket->NetLog()) {
     87   transport_->set_socket(transport_socket);
     88 }
     89 
     90 SOCKSClientSocket::~SOCKSClientSocket() {
     91   Disconnect();
     92 }
     93 
     94 #ifdef ANDROID
     95 // TODO(kristianm): find out if Connect should block
     96 #endif
     97 int SOCKSClientSocket::Connect(CompletionCallback* callback
     98 #ifdef ANDROID
     99                                , bool wait_for_connect
    100                                , bool valid_uid
    101                                , uid_t calling_uid
    102 #endif
    103                               ) {
    104   DCHECK(transport_.get());
    105   DCHECK(transport_->socket());
    106   DCHECK_EQ(STATE_NONE, next_state_);
    107   DCHECK(!user_callback_);
    108 
    109   // If already connected, then just return OK.
    110   if (completed_handshake_)
    111     return OK;
    112 
    113   next_state_ = STATE_RESOLVE_HOST;
    114 
    115   net_log_.BeginEvent(NetLog::TYPE_SOCKS_CONNECT, NULL);
    116 
    117   int rv = DoLoop(OK);
    118   if (rv == ERR_IO_PENDING) {
    119     user_callback_ = callback;
    120   } else {
    121     net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS_CONNECT, rv);
    122   }
    123   return rv;
    124 }
    125 
    126 void SOCKSClientSocket::Disconnect() {
    127   completed_handshake_ = false;
    128   host_resolver_.Cancel();
    129   transport_->socket()->Disconnect();
    130 
    131   // Reset other states to make sure they aren't mistakenly used later.
    132   // These are the states initialized by Connect().
    133   next_state_ = STATE_NONE;
    134   user_callback_ = NULL;
    135 }
    136 
    137 bool SOCKSClientSocket::IsConnected() const {
    138   return completed_handshake_ && transport_->socket()->IsConnected();
    139 }
    140 
    141 bool SOCKSClientSocket::IsConnectedAndIdle() const {
    142   return completed_handshake_ && transport_->socket()->IsConnectedAndIdle();
    143 }
    144 
    145 const BoundNetLog& SOCKSClientSocket::NetLog() const {
    146   return net_log_;
    147 }
    148 
    149 void SOCKSClientSocket::SetSubresourceSpeculation() {
    150   if (transport_.get() && transport_->socket()) {
    151     transport_->socket()->SetSubresourceSpeculation();
    152   } else {
    153     NOTREACHED();
    154   }
    155 }
    156 
    157 void SOCKSClientSocket::SetOmniboxSpeculation() {
    158   if (transport_.get() && transport_->socket()) {
    159     transport_->socket()->SetOmniboxSpeculation();
    160   } else {
    161     NOTREACHED();
    162   }
    163 }
    164 
    165 bool SOCKSClientSocket::WasEverUsed() const {
    166   if (transport_.get() && transport_->socket()) {
    167     return transport_->socket()->WasEverUsed();
    168   }
    169   NOTREACHED();
    170   return false;
    171 }
    172 
    173 bool SOCKSClientSocket::UsingTCPFastOpen() const {
    174   if (transport_.get() && transport_->socket()) {
    175     return transport_->socket()->UsingTCPFastOpen();
    176   }
    177   NOTREACHED();
    178   return false;
    179 }
    180 
    181 
    182 // Read is called by the transport layer above to read. This can only be done
    183 // if the SOCKS handshake is complete.
    184 int SOCKSClientSocket::Read(IOBuffer* buf, int buf_len,
    185                             CompletionCallback* callback) {
    186   DCHECK(completed_handshake_);
    187   DCHECK_EQ(STATE_NONE, next_state_);
    188   DCHECK(!user_callback_);
    189 
    190   return transport_->socket()->Read(buf, buf_len, callback);
    191 }
    192 
    193 // Write is called by the transport layer. This can only be done if the
    194 // SOCKS handshake is complete.
    195 int SOCKSClientSocket::Write(IOBuffer* buf, int buf_len,
    196                              CompletionCallback* callback) {
    197   DCHECK(completed_handshake_);
    198   DCHECK_EQ(STATE_NONE, next_state_);
    199   DCHECK(!user_callback_);
    200 
    201   return transport_->socket()->Write(buf, buf_len, callback);
    202 }
    203 
    204 bool SOCKSClientSocket::SetReceiveBufferSize(int32 size) {
    205   return transport_->socket()->SetReceiveBufferSize(size);
    206 }
    207 
    208 bool SOCKSClientSocket::SetSendBufferSize(int32 size) {
    209   return transport_->socket()->SetSendBufferSize(size);
    210 }
    211 
    212 void SOCKSClientSocket::DoCallback(int result) {
    213   DCHECK_NE(ERR_IO_PENDING, result);
    214   DCHECK(user_callback_);
    215 
    216   // Since Run() may result in Read being called,
    217   // clear user_callback_ up front.
    218   CompletionCallback* c = user_callback_;
    219   user_callback_ = NULL;
    220   DVLOG(1) << "Finished setting up SOCKS handshake";
    221   c->Run(result);
    222 }
    223 
    224 void SOCKSClientSocket::OnIOComplete(int result) {
    225   DCHECK_NE(STATE_NONE, next_state_);
    226   int rv = DoLoop(result);
    227   if (rv != ERR_IO_PENDING) {
    228     net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS_CONNECT, rv);
    229     DoCallback(rv);
    230   }
    231 }
    232 
    233 int SOCKSClientSocket::DoLoop(int last_io_result) {
    234   DCHECK_NE(next_state_, STATE_NONE);
    235   int rv = last_io_result;
    236   do {
    237     State state = next_state_;
    238     next_state_ = STATE_NONE;
    239     switch (state) {
    240       case STATE_RESOLVE_HOST:
    241         DCHECK_EQ(OK, rv);
    242         rv = DoResolveHost();
    243         break;
    244       case STATE_RESOLVE_HOST_COMPLETE:
    245         rv = DoResolveHostComplete(rv);
    246         break;
    247       case STATE_HANDSHAKE_WRITE:
    248         DCHECK_EQ(OK, rv);
    249         rv = DoHandshakeWrite();
    250         break;
    251       case STATE_HANDSHAKE_WRITE_COMPLETE:
    252         rv = DoHandshakeWriteComplete(rv);
    253         break;
    254       case STATE_HANDSHAKE_READ:
    255         DCHECK_EQ(OK, rv);
    256         rv = DoHandshakeRead();
    257         break;
    258       case STATE_HANDSHAKE_READ_COMPLETE:
    259         rv = DoHandshakeReadComplete(rv);
    260         break;
    261       default:
    262         NOTREACHED() << "bad state";
    263         rv = ERR_UNEXPECTED;
    264         break;
    265     }
    266   } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
    267   return rv;
    268 }
    269 
    270 int SOCKSClientSocket::DoResolveHost() {
    271   next_state_ = STATE_RESOLVE_HOST_COMPLETE;
    272   // SOCKS4 only supports IPv4 addresses, so only try getting the IPv4
    273   // addresses for the target host.
    274   host_request_info_.set_address_family(ADDRESS_FAMILY_IPV4);
    275   return host_resolver_.Resolve(
    276       host_request_info_, &addresses_, &io_callback_, net_log_);
    277 }
    278 
    279 int SOCKSClientSocket::DoResolveHostComplete(int result) {
    280   if (result != OK) {
    281     // Resolving the hostname failed; fail the request rather than automatically
    282     // falling back to SOCKS4a (since it can be confusing to see invalid IP
    283     // addresses being sent to the SOCKS4 server when it doesn't support 4A.)
    284     return result;
    285   }
    286 
    287   next_state_ = STATE_HANDSHAKE_WRITE;
    288   return OK;
    289 }
    290 
    291 // Builds the buffer that is to be sent to the server.
    292 const std::string SOCKSClientSocket::BuildHandshakeWriteBuffer() const {
    293   SOCKS4ServerRequest request;
    294   request.version = kSOCKSVersion4;
    295   request.command = kSOCKSStreamRequest;
    296   request.nw_port = htons(host_request_info_.port());
    297 
    298   const struct addrinfo* ai = addresses_.head();
    299   DCHECK(ai);
    300 
    301   // We disabled IPv6 results when resolving the hostname, so none of the
    302   // results in the list will be IPv6.
    303   // TODO(eroman): we only ever use the first address in the list. It would be
    304   //               more robust to try all the IP addresses we have before
    305   //               failing the connect attempt.
    306   CHECK_EQ(AF_INET, ai->ai_addr->sa_family);
    307   struct sockaddr_in* ipv4_host =
    308       reinterpret_cast<struct sockaddr_in*>(ai->ai_addr);
    309   memcpy(&request.ip, &ipv4_host->sin_addr, sizeof(ipv4_host->sin_addr));
    310 
    311   DVLOG(1) << "Resolved Host is : " << NetAddressToString(ai);
    312 
    313   std::string handshake_data(reinterpret_cast<char*>(&request),
    314                              sizeof(request));
    315   handshake_data.append(kEmptyUserId, arraysize(kEmptyUserId));
    316 
    317   return handshake_data;
    318 }
    319 
    320 // Writes the SOCKS handshake data to the underlying socket connection.
    321 int SOCKSClientSocket::DoHandshakeWrite() {
    322   next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE;
    323 
    324   if (buffer_.empty()) {
    325     buffer_ = BuildHandshakeWriteBuffer();
    326     bytes_sent_ = 0;
    327   }
    328 
    329   int handshake_buf_len = buffer_.size() - bytes_sent_;
    330   DCHECK_GT(handshake_buf_len, 0);
    331   handshake_buf_ = new IOBuffer(handshake_buf_len);
    332   memcpy(handshake_buf_->data(), &buffer_[bytes_sent_],
    333          handshake_buf_len);
    334   return transport_->socket()->Write(handshake_buf_, handshake_buf_len,
    335                                      &io_callback_);
    336 }
    337 
    338 int SOCKSClientSocket::DoHandshakeWriteComplete(int result) {
    339   if (result < 0)
    340     return result;
    341 
    342   // We ignore the case when result is 0, since the underlying Write
    343   // may return spurious writes while waiting on the socket.
    344 
    345   bytes_sent_ += result;
    346   if (bytes_sent_ == buffer_.size()) {
    347     next_state_ = STATE_HANDSHAKE_READ;
    348     buffer_.clear();
    349   } else if (bytes_sent_ < buffer_.size()) {
    350     next_state_ = STATE_HANDSHAKE_WRITE;
    351   } else {
    352     return ERR_UNEXPECTED;
    353   }
    354 
    355   return OK;
    356 }
    357 
    358 int SOCKSClientSocket::DoHandshakeRead() {
    359   next_state_ = STATE_HANDSHAKE_READ_COMPLETE;
    360 
    361   if (buffer_.empty()) {
    362     bytes_received_ = 0;
    363   }
    364 
    365   int handshake_buf_len = kReadHeaderSize - bytes_received_;
    366   handshake_buf_ = new IOBuffer(handshake_buf_len);
    367   return transport_->socket()->Read(handshake_buf_, handshake_buf_len,
    368                                     &io_callback_);
    369 }
    370 
    371 int SOCKSClientSocket::DoHandshakeReadComplete(int result) {
    372   if (result < 0)
    373     return result;
    374 
    375   // The underlying socket closed unexpectedly.
    376   if (result == 0)
    377     return ERR_CONNECTION_CLOSED;
    378 
    379   if (bytes_received_ + result > kReadHeaderSize) {
    380     // TODO(eroman): Describe failure in NetLog.
    381     return ERR_SOCKS_CONNECTION_FAILED;
    382   }
    383 
    384   buffer_.append(handshake_buf_->data(), result);
    385   bytes_received_ += result;
    386   if (bytes_received_ < kReadHeaderSize) {
    387     next_state_ = STATE_HANDSHAKE_READ;
    388     return OK;
    389   }
    390 
    391   const SOCKS4ServerResponse* response =
    392       reinterpret_cast<const SOCKS4ServerResponse*>(buffer_.data());
    393 
    394   if (response->reserved_null != 0x00) {
    395     LOG(ERROR) << "Unknown response from SOCKS server.";
    396     return ERR_SOCKS_CONNECTION_FAILED;
    397   }
    398 
    399   switch (response->code) {
    400     case kServerResponseOk:
    401       completed_handshake_ = true;
    402       return OK;
    403     case kServerResponseRejected:
    404       LOG(ERROR) << "SOCKS request rejected or failed";
    405       return ERR_SOCKS_CONNECTION_FAILED;
    406     case kServerResponseNotReachable:
    407       LOG(ERROR) << "SOCKS request failed because client is not running "
    408                  << "identd (or not reachable from the server)";
    409       return ERR_SOCKS_CONNECTION_HOST_UNREACHABLE;
    410     case kServerResponseMismatchedUserId:
    411       LOG(ERROR) << "SOCKS request failed because client's identd could "
    412                  << "not confirm the user ID string in the request";
    413       return ERR_SOCKS_CONNECTION_FAILED;
    414     default:
    415       LOG(ERROR) << "SOCKS server sent unknown response";
    416       return ERR_SOCKS_CONNECTION_FAILED;
    417   }
    418 
    419   // Note: we ignore the last 6 bytes as specified by the SOCKS protocol
    420 }
    421 
    422 int SOCKSClientSocket::GetPeerAddress(AddressList* address) const {
    423   return transport_->socket()->GetPeerAddress(address);
    424 }
    425 
    426 int SOCKSClientSocket::GetLocalAddress(IPEndPoint* address) const {
    427   return transport_->socket()->GetLocalAddress(address);
    428 }
    429 
    430 }  // namespace net
    431