Home | History | Annotate | Download | only in socket
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/socket/socks5_client_socket.h"
      6 
      7 #include "base/basictypes.h"
      8 #include "base/callback_helpers.h"
      9 #include "base/compiler_specific.h"
     10 #include "base/debug/trace_event.h"
     11 #include "base/format_macros.h"
     12 #include "base/strings/string_util.h"
     13 #include "base/sys_byteorder.h"
     14 #include "net/base/io_buffer.h"
     15 #include "net/base/net_log.h"
     16 #include "net/base/net_util.h"
     17 #include "net/socket/client_socket_handle.h"
     18 
     19 namespace net {
     20 
     21 const unsigned int SOCKS5ClientSocket::kGreetReadHeaderSize = 2;
     22 const unsigned int SOCKS5ClientSocket::kWriteHeaderSize = 10;
     23 const unsigned int SOCKS5ClientSocket::kReadHeaderSize = 5;
     24 const uint8 SOCKS5ClientSocket::kSOCKS5Version = 0x05;
     25 const uint8 SOCKS5ClientSocket::kTunnelCommand = 0x01;
     26 const uint8 SOCKS5ClientSocket::kNullByte = 0x00;
     27 
     28 COMPILE_ASSERT(sizeof(struct in_addr) == 4, incorrect_system_size_of_IPv4);
     29 COMPILE_ASSERT(sizeof(struct in6_addr) == 16, incorrect_system_size_of_IPv6);
     30 
     31 SOCKS5ClientSocket::SOCKS5ClientSocket(
     32     scoped_ptr<ClientSocketHandle> transport_socket,
     33     const HostResolver::RequestInfo& req_info)
     34     : io_callback_(base::Bind(&SOCKS5ClientSocket::OnIOComplete,
     35                               base::Unretained(this))),
     36       transport_(transport_socket.Pass()),
     37       next_state_(STATE_NONE),
     38       completed_handshake_(false),
     39       bytes_sent_(0),
     40       bytes_received_(0),
     41       read_header_size(kReadHeaderSize),
     42       was_ever_used_(false),
     43       host_request_info_(req_info),
     44       net_log_(transport_->socket()->NetLog()) {
     45 }
     46 
     47 SOCKS5ClientSocket::~SOCKS5ClientSocket() {
     48   Disconnect();
     49 }
     50 
     51 int SOCKS5ClientSocket::Connect(const CompletionCallback& callback) {
     52   DCHECK(transport_.get());
     53   DCHECK(transport_->socket());
     54   DCHECK_EQ(STATE_NONE, next_state_);
     55   DCHECK(user_callback_.is_null());
     56 
     57   // If already connected, then just return OK.
     58   if (completed_handshake_)
     59     return OK;
     60 
     61   net_log_.BeginEvent(NetLog::TYPE_SOCKS5_CONNECT);
     62 
     63   next_state_ = STATE_GREET_WRITE;
     64   buffer_.clear();
     65 
     66   int rv = DoLoop(OK);
     67   if (rv == ERR_IO_PENDING) {
     68     user_callback_ = callback;
     69   } else {
     70     net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS5_CONNECT, rv);
     71   }
     72   return rv;
     73 }
     74 
     75 void SOCKS5ClientSocket::Disconnect() {
     76   completed_handshake_ = false;
     77   transport_->socket()->Disconnect();
     78 
     79   // Reset other states to make sure they aren't mistakenly used later.
     80   // These are the states initialized by Connect().
     81   next_state_ = STATE_NONE;
     82   user_callback_.Reset();
     83 }
     84 
     85 bool SOCKS5ClientSocket::IsConnected() const {
     86   return completed_handshake_ && transport_->socket()->IsConnected();
     87 }
     88 
     89 bool SOCKS5ClientSocket::IsConnectedAndIdle() const {
     90   return completed_handshake_ && transport_->socket()->IsConnectedAndIdle();
     91 }
     92 
     93 const BoundNetLog& SOCKS5ClientSocket::NetLog() const {
     94   return net_log_;
     95 }
     96 
     97 void SOCKS5ClientSocket::SetSubresourceSpeculation() {
     98   if (transport_.get() && transport_->socket()) {
     99     transport_->socket()->SetSubresourceSpeculation();
    100   } else {
    101     NOTREACHED();
    102   }
    103 }
    104 
    105 void SOCKS5ClientSocket::SetOmniboxSpeculation() {
    106   if (transport_.get() && transport_->socket()) {
    107     transport_->socket()->SetOmniboxSpeculation();
    108   } else {
    109     NOTREACHED();
    110   }
    111 }
    112 
    113 bool SOCKS5ClientSocket::WasEverUsed() const {
    114   return was_ever_used_;
    115 }
    116 
    117 bool SOCKS5ClientSocket::UsingTCPFastOpen() const {
    118   if (transport_.get() && transport_->socket()) {
    119     return transport_->socket()->UsingTCPFastOpen();
    120   }
    121   NOTREACHED();
    122   return false;
    123 }
    124 
    125 bool SOCKS5ClientSocket::WasNpnNegotiated() const {
    126   if (transport_.get() && transport_->socket()) {
    127     return transport_->socket()->WasNpnNegotiated();
    128   }
    129   NOTREACHED();
    130   return false;
    131 }
    132 
    133 NextProto SOCKS5ClientSocket::GetNegotiatedProtocol() const {
    134   if (transport_.get() && transport_->socket()) {
    135     return transport_->socket()->GetNegotiatedProtocol();
    136   }
    137   NOTREACHED();
    138   return kProtoUnknown;
    139 }
    140 
    141 bool SOCKS5ClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
    142   if (transport_.get() && transport_->socket()) {
    143     return transport_->socket()->GetSSLInfo(ssl_info);
    144   }
    145   NOTREACHED();
    146   return false;
    147 
    148 }
    149 
    150 // Read is called by the transport layer above to read. This can only be done
    151 // if the SOCKS handshake is complete.
    152 int SOCKS5ClientSocket::Read(IOBuffer* buf, int buf_len,
    153                              const CompletionCallback& callback) {
    154   DCHECK(completed_handshake_);
    155   DCHECK_EQ(STATE_NONE, next_state_);
    156   DCHECK(user_callback_.is_null());
    157   DCHECK(!callback.is_null());
    158 
    159   int rv = transport_->socket()->Read(
    160       buf, buf_len,
    161       base::Bind(&SOCKS5ClientSocket::OnReadWriteComplete,
    162                  base::Unretained(this), callback));
    163   if (rv > 0)
    164     was_ever_used_ = true;
    165   return rv;
    166 }
    167 
    168 // Write is called by the transport layer. This can only be done if the
    169 // SOCKS handshake is complete.
    170 int SOCKS5ClientSocket::Write(IOBuffer* buf, int buf_len,
    171                               const CompletionCallback& callback) {
    172   DCHECK(completed_handshake_);
    173   DCHECK_EQ(STATE_NONE, next_state_);
    174   DCHECK(user_callback_.is_null());
    175   DCHECK(!callback.is_null());
    176 
    177   int rv = transport_->socket()->Write(
    178       buf, buf_len,
    179       base::Bind(&SOCKS5ClientSocket::OnReadWriteComplete,
    180                  base::Unretained(this), callback));
    181   if (rv > 0)
    182     was_ever_used_ = true;
    183   return rv;
    184 }
    185 
    186 int SOCKS5ClientSocket::SetReceiveBufferSize(int32 size) {
    187   return transport_->socket()->SetReceiveBufferSize(size);
    188 }
    189 
    190 int SOCKS5ClientSocket::SetSendBufferSize(int32 size) {
    191   return transport_->socket()->SetSendBufferSize(size);
    192 }
    193 
    194 void SOCKS5ClientSocket::DoCallback(int result) {
    195   DCHECK_NE(ERR_IO_PENDING, result);
    196   DCHECK(!user_callback_.is_null());
    197 
    198   // Since Run() may result in Read being called,
    199   // clear user_callback_ up front.
    200   base::ResetAndReturn(&user_callback_).Run(result);
    201 }
    202 
    203 void SOCKS5ClientSocket::OnIOComplete(int result) {
    204   DCHECK_NE(STATE_NONE, next_state_);
    205   int rv = DoLoop(result);
    206   if (rv != ERR_IO_PENDING) {
    207     net_log_.EndEvent(NetLog::TYPE_SOCKS5_CONNECT);
    208     DoCallback(rv);
    209   }
    210 }
    211 
    212 void SOCKS5ClientSocket::OnReadWriteComplete(const CompletionCallback& callback,
    213                                              int result) {
    214   DCHECK_NE(ERR_IO_PENDING, result);
    215   DCHECK(!callback.is_null());
    216 
    217   if (result > 0)
    218     was_ever_used_ = true;
    219   callback.Run(result);
    220 }
    221 
    222 int SOCKS5ClientSocket::DoLoop(int last_io_result) {
    223   DCHECK_NE(next_state_, STATE_NONE);
    224   int rv = last_io_result;
    225   do {
    226     State state = next_state_;
    227     next_state_ = STATE_NONE;
    228     switch (state) {
    229       case STATE_GREET_WRITE:
    230         DCHECK_EQ(OK, rv);
    231         net_log_.BeginEvent(NetLog::TYPE_SOCKS5_GREET_WRITE);
    232         rv = DoGreetWrite();
    233         break;
    234       case STATE_GREET_WRITE_COMPLETE:
    235         rv = DoGreetWriteComplete(rv);
    236         net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS5_GREET_WRITE, rv);
    237         break;
    238       case STATE_GREET_READ:
    239         DCHECK_EQ(OK, rv);
    240         net_log_.BeginEvent(NetLog::TYPE_SOCKS5_GREET_READ);
    241         rv = DoGreetRead();
    242         break;
    243       case STATE_GREET_READ_COMPLETE:
    244         rv = DoGreetReadComplete(rv);
    245         net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS5_GREET_READ, rv);
    246         break;
    247       case STATE_HANDSHAKE_WRITE:
    248         DCHECK_EQ(OK, rv);
    249         net_log_.BeginEvent(NetLog::TYPE_SOCKS5_HANDSHAKE_WRITE);
    250         rv = DoHandshakeWrite();
    251         break;
    252       case STATE_HANDSHAKE_WRITE_COMPLETE:
    253         rv = DoHandshakeWriteComplete(rv);
    254         net_log_.EndEventWithNetErrorCode(
    255             NetLog::TYPE_SOCKS5_HANDSHAKE_WRITE, rv);
    256         break;
    257       case STATE_HANDSHAKE_READ:
    258         DCHECK_EQ(OK, rv);
    259         net_log_.BeginEvent(NetLog::TYPE_SOCKS5_HANDSHAKE_READ);
    260         rv = DoHandshakeRead();
    261         break;
    262       case STATE_HANDSHAKE_READ_COMPLETE:
    263         rv = DoHandshakeReadComplete(rv);
    264         net_log_.EndEventWithNetErrorCode(
    265             NetLog::TYPE_SOCKS5_HANDSHAKE_READ, rv);
    266         break;
    267       default:
    268         NOTREACHED() << "bad state";
    269         rv = ERR_UNEXPECTED;
    270         break;
    271     }
    272   } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
    273   return rv;
    274 }
    275 
    276 const char kSOCKS5GreetWriteData[] = { 0x05, 0x01, 0x00 };  // no authentication
    277 
    278 int SOCKS5ClientSocket::DoGreetWrite() {
    279   // Since we only have 1 byte to send the hostname length in, if the
    280   // URL has a hostname longer than 255 characters we can't send it.
    281   if (0xFF < host_request_info_.hostname().size()) {
    282     net_log_.AddEvent(NetLog::TYPE_SOCKS_HOSTNAME_TOO_BIG);
    283     return ERR_SOCKS_CONNECTION_FAILED;
    284   }
    285 
    286   if (buffer_.empty()) {
    287     buffer_ = std::string(kSOCKS5GreetWriteData,
    288                           arraysize(kSOCKS5GreetWriteData));
    289     bytes_sent_ = 0;
    290   }
    291 
    292   next_state_ = STATE_GREET_WRITE_COMPLETE;
    293   size_t handshake_buf_len = buffer_.size() - bytes_sent_;
    294   handshake_buf_ = new IOBuffer(handshake_buf_len);
    295   memcpy(handshake_buf_->data(), &buffer_.data()[bytes_sent_],
    296          handshake_buf_len);
    297   return transport_->socket()
    298       ->Write(handshake_buf_.get(), handshake_buf_len, io_callback_);
    299 }
    300 
    301 int SOCKS5ClientSocket::DoGreetWriteComplete(int result) {
    302   if (result < 0)
    303     return result;
    304 
    305   bytes_sent_ += result;
    306   if (bytes_sent_ == buffer_.size()) {
    307     buffer_.clear();
    308     bytes_received_ = 0;
    309     next_state_ = STATE_GREET_READ;
    310   } else {
    311     next_state_ = STATE_GREET_WRITE;
    312   }
    313   return OK;
    314 }
    315 
    316 int SOCKS5ClientSocket::DoGreetRead() {
    317   next_state_ = STATE_GREET_READ_COMPLETE;
    318   size_t handshake_buf_len = kGreetReadHeaderSize - bytes_received_;
    319   handshake_buf_ = new IOBuffer(handshake_buf_len);
    320   return transport_->socket()
    321       ->Read(handshake_buf_.get(), handshake_buf_len, io_callback_);
    322 }
    323 
    324 int SOCKS5ClientSocket::DoGreetReadComplete(int result) {
    325   if (result < 0)
    326     return result;
    327 
    328   if (result == 0) {
    329     net_log_.AddEvent(NetLog::TYPE_SOCKS_UNEXPECTEDLY_CLOSED_DURING_GREETING);
    330     return ERR_SOCKS_CONNECTION_FAILED;
    331   }
    332 
    333   bytes_received_ += result;
    334   buffer_.append(handshake_buf_->data(), result);
    335   if (bytes_received_ < kGreetReadHeaderSize) {
    336     next_state_ = STATE_GREET_READ;
    337     return OK;
    338   }
    339 
    340   // Got the greet data.
    341   if (buffer_[0] != kSOCKS5Version) {
    342     net_log_.AddEvent(NetLog::TYPE_SOCKS_UNEXPECTED_VERSION,
    343                       NetLog::IntegerCallback("version", buffer_[0]));
    344     return ERR_SOCKS_CONNECTION_FAILED;
    345   }
    346   if (buffer_[1] != 0x00) {
    347     net_log_.AddEvent(NetLog::TYPE_SOCKS_UNEXPECTED_AUTH,
    348                       NetLog::IntegerCallback("method", buffer_[1]));
    349     return ERR_SOCKS_CONNECTION_FAILED;
    350   }
    351 
    352   buffer_.clear();
    353   next_state_ = STATE_HANDSHAKE_WRITE;
    354   return OK;
    355 }
    356 
    357 int SOCKS5ClientSocket::BuildHandshakeWriteBuffer(std::string* handshake)
    358     const {
    359   DCHECK(handshake->empty());
    360 
    361   handshake->push_back(kSOCKS5Version);
    362   handshake->push_back(kTunnelCommand);  // Connect command
    363   handshake->push_back(kNullByte);  // Reserved null
    364 
    365   handshake->push_back(kEndPointDomain);  // The type of the address.
    366 
    367   DCHECK_GE(static_cast<size_t>(0xFF), host_request_info_.hostname().size());
    368 
    369   // First add the size of the hostname, followed by the hostname.
    370   handshake->push_back(static_cast<unsigned char>(
    371       host_request_info_.hostname().size()));
    372   handshake->append(host_request_info_.hostname());
    373 
    374   uint16 nw_port = base::HostToNet16(host_request_info_.port());
    375   handshake->append(reinterpret_cast<char*>(&nw_port), sizeof(nw_port));
    376   return OK;
    377 }
    378 
    379 // Writes the SOCKS handshake data to the underlying socket connection.
    380 int SOCKS5ClientSocket::DoHandshakeWrite() {
    381   next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE;
    382 
    383   if (buffer_.empty()) {
    384     int rv = BuildHandshakeWriteBuffer(&buffer_);
    385     if (rv != OK)
    386       return rv;
    387     bytes_sent_ = 0;
    388   }
    389 
    390   int handshake_buf_len = buffer_.size() - bytes_sent_;
    391   DCHECK_LT(0, handshake_buf_len);
    392   handshake_buf_ = new IOBuffer(handshake_buf_len);
    393   memcpy(handshake_buf_->data(), &buffer_[bytes_sent_],
    394          handshake_buf_len);
    395   return transport_->socket()
    396       ->Write(handshake_buf_.get(), handshake_buf_len, io_callback_);
    397 }
    398 
    399 int SOCKS5ClientSocket::DoHandshakeWriteComplete(int result) {
    400   if (result < 0)
    401     return result;
    402 
    403   // We ignore the case when result is 0, since the underlying Write
    404   // may return spurious writes while waiting on the socket.
    405 
    406   bytes_sent_ += result;
    407   if (bytes_sent_ == buffer_.size()) {
    408     next_state_ = STATE_HANDSHAKE_READ;
    409     buffer_.clear();
    410   } else if (bytes_sent_ < buffer_.size()) {
    411     next_state_ = STATE_HANDSHAKE_WRITE;
    412   } else {
    413     NOTREACHED();
    414   }
    415 
    416   return OK;
    417 }
    418 
    419 int SOCKS5ClientSocket::DoHandshakeRead() {
    420   next_state_ = STATE_HANDSHAKE_READ_COMPLETE;
    421 
    422   if (buffer_.empty()) {
    423     bytes_received_ = 0;
    424     read_header_size = kReadHeaderSize;
    425   }
    426 
    427   int handshake_buf_len = read_header_size - bytes_received_;
    428   handshake_buf_ = new IOBuffer(handshake_buf_len);
    429   return transport_->socket()
    430       ->Read(handshake_buf_.get(), handshake_buf_len, io_callback_);
    431 }
    432 
    433 int SOCKS5ClientSocket::DoHandshakeReadComplete(int result) {
    434   if (result < 0)
    435     return result;
    436 
    437   // The underlying socket closed unexpectedly.
    438   if (result == 0) {
    439     net_log_.AddEvent(NetLog::TYPE_SOCKS_UNEXPECTEDLY_CLOSED_DURING_HANDSHAKE);
    440     return ERR_SOCKS_CONNECTION_FAILED;
    441   }
    442 
    443   buffer_.append(handshake_buf_->data(), result);
    444   bytes_received_ += result;
    445 
    446   // When the first few bytes are read, check how many more are required
    447   // and accordingly increase them
    448   if (bytes_received_ == kReadHeaderSize) {
    449     if (buffer_[0] != kSOCKS5Version || buffer_[2] != kNullByte) {
    450       net_log_.AddEvent(NetLog::TYPE_SOCKS_UNEXPECTED_VERSION,
    451                         NetLog::IntegerCallback("version", buffer_[0]));
    452       return ERR_SOCKS_CONNECTION_FAILED;
    453     }
    454     if (buffer_[1] != 0x00) {
    455       net_log_.AddEvent(NetLog::TYPE_SOCKS_SERVER_ERROR,
    456                         NetLog::IntegerCallback("error_code", buffer_[1]));
    457       return ERR_SOCKS_CONNECTION_FAILED;
    458     }
    459 
    460     // We check the type of IP/Domain the server returns and accordingly
    461     // increase the size of the response. For domains, we need to read the
    462     // size of the domain, so the initial request size is upto the domain
    463     // size. Since for IPv4/IPv6 the size is fixed and hence no 'size' is
    464     // read, we substract 1 byte from the additional request size.
    465     SocksEndPointAddressType address_type =
    466         static_cast<SocksEndPointAddressType>(buffer_[3]);
    467     if (address_type == kEndPointDomain)
    468       read_header_size += static_cast<uint8>(buffer_[4]);
    469     else if (address_type == kEndPointResolvedIPv4)
    470       read_header_size += sizeof(struct in_addr) - 1;
    471     else if (address_type == kEndPointResolvedIPv6)
    472       read_header_size += sizeof(struct in6_addr) - 1;
    473     else {
    474       net_log_.AddEvent(NetLog::TYPE_SOCKS_UNKNOWN_ADDRESS_TYPE,
    475                         NetLog::IntegerCallback("address_type", buffer_[3]));
    476       return ERR_SOCKS_CONNECTION_FAILED;
    477     }
    478 
    479     read_header_size += 2;  // for the port.
    480     next_state_ = STATE_HANDSHAKE_READ;
    481     return OK;
    482   }
    483 
    484   // When the final bytes are read, setup handshake. We ignore the rest
    485   // of the response since they represent the SOCKSv5 endpoint and have
    486   // no use when doing a tunnel connection.
    487   if (bytes_received_ == read_header_size) {
    488     completed_handshake_ = true;
    489     buffer_.clear();
    490     next_state_ = STATE_NONE;
    491     return OK;
    492   }
    493 
    494   next_state_ = STATE_HANDSHAKE_READ;
    495   return OK;
    496 }
    497 
    498 int SOCKS5ClientSocket::GetPeerAddress(IPEndPoint* address) const {
    499   return transport_->socket()->GetPeerAddress(address);
    500 }
    501 
    502 int SOCKS5ClientSocket::GetLocalAddress(IPEndPoint* address) const {
    503   return transport_->socket()->GetLocalAddress(address);
    504 }
    505 
    506 }  // namespace net
    507