Home | History | Annotate | Download | only in socket
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/socket/socks5_client_socket.h"
      6 
      7 #include "base/basictypes.h"
      8 #include "base/compiler_specific.h"
      9 #include "base/debug/trace_event.h"
     10 #include "base/format_macros.h"
     11 #include "base/strings/string_util.h"
     12 #include "base/sys_byteorder.h"
     13 #include "net/base/io_buffer.h"
     14 #include "net/base/net_log.h"
     15 #include "net/base/net_util.h"
     16 #include "net/socket/client_socket_handle.h"
     17 
     18 namespace net {
     19 
     20 const unsigned int SOCKS5ClientSocket::kGreetReadHeaderSize = 2;
     21 const unsigned int SOCKS5ClientSocket::kWriteHeaderSize = 10;
     22 const unsigned int SOCKS5ClientSocket::kReadHeaderSize = 5;
     23 const uint8 SOCKS5ClientSocket::kSOCKS5Version = 0x05;
     24 const uint8 SOCKS5ClientSocket::kTunnelCommand = 0x01;
     25 const uint8 SOCKS5ClientSocket::kNullByte = 0x00;
     26 
     27 COMPILE_ASSERT(sizeof(struct in_addr) == 4, incorrect_system_size_of_IPv4);
     28 COMPILE_ASSERT(sizeof(struct in6_addr) == 16, incorrect_system_size_of_IPv6);
     29 
     30 SOCKS5ClientSocket::SOCKS5ClientSocket(
     31     ClientSocketHandle* transport_socket,
     32     const HostResolver::RequestInfo& req_info)
     33     : io_callback_(base::Bind(&SOCKS5ClientSocket::OnIOComplete,
     34                               base::Unretained(this))),
     35       transport_(transport_socket),
     36       next_state_(STATE_NONE),
     37       completed_handshake_(false),
     38       bytes_sent_(0),
     39       bytes_received_(0),
     40       read_header_size(kReadHeaderSize),
     41       host_request_info_(req_info),
     42       net_log_(transport_socket->socket()->NetLog()) {
     43 }
     44 
     45 SOCKS5ClientSocket::SOCKS5ClientSocket(
     46     StreamSocket* transport_socket,
     47     const HostResolver::RequestInfo& req_info)
     48     : io_callback_(base::Bind(&SOCKS5ClientSocket::OnIOComplete,
     49                               base::Unretained(this))),
     50       transport_(new ClientSocketHandle()),
     51       next_state_(STATE_NONE),
     52       completed_handshake_(false),
     53       bytes_sent_(0),
     54       bytes_received_(0),
     55       read_header_size(kReadHeaderSize),
     56       host_request_info_(req_info),
     57       net_log_(transport_socket->NetLog()) {
     58   transport_->set_socket(transport_socket);
     59 }
     60 
     61 SOCKS5ClientSocket::~SOCKS5ClientSocket() {
     62   Disconnect();
     63 }
     64 
     65 int SOCKS5ClientSocket::Connect(const CompletionCallback& callback) {
     66   DCHECK(transport_.get());
     67   DCHECK(transport_->socket());
     68   DCHECK_EQ(STATE_NONE, next_state_);
     69   DCHECK(user_callback_.is_null());
     70 
     71   // If already connected, then just return OK.
     72   if (completed_handshake_)
     73     return OK;
     74 
     75   net_log_.BeginEvent(NetLog::TYPE_SOCKS5_CONNECT);
     76 
     77   next_state_ = STATE_GREET_WRITE;
     78   buffer_.clear();
     79 
     80   int rv = DoLoop(OK);
     81   if (rv == ERR_IO_PENDING) {
     82     user_callback_ = callback;
     83   } else {
     84     net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS5_CONNECT, rv);
     85   }
     86   return rv;
     87 }
     88 
     89 void SOCKS5ClientSocket::Disconnect() {
     90   completed_handshake_ = false;
     91   transport_->socket()->Disconnect();
     92 
     93   // Reset other states to make sure they aren't mistakenly used later.
     94   // These are the states initialized by Connect().
     95   next_state_ = STATE_NONE;
     96   user_callback_.Reset();
     97 }
     98 
     99 bool SOCKS5ClientSocket::IsConnected() const {
    100   return completed_handshake_ && transport_->socket()->IsConnected();
    101 }
    102 
    103 bool SOCKS5ClientSocket::IsConnectedAndIdle() const {
    104   return completed_handshake_ && transport_->socket()->IsConnectedAndIdle();
    105 }
    106 
    107 const BoundNetLog& SOCKS5ClientSocket::NetLog() const {
    108   return net_log_;
    109 }
    110 
    111 void SOCKS5ClientSocket::SetSubresourceSpeculation() {
    112   if (transport_.get() && transport_->socket()) {
    113     transport_->socket()->SetSubresourceSpeculation();
    114   } else {
    115     NOTREACHED();
    116   }
    117 }
    118 
    119 void SOCKS5ClientSocket::SetOmniboxSpeculation() {
    120   if (transport_.get() && transport_->socket()) {
    121     transport_->socket()->SetOmniboxSpeculation();
    122   } else {
    123     NOTREACHED();
    124   }
    125 }
    126 
    127 bool SOCKS5ClientSocket::WasEverUsed() const {
    128   if (transport_.get() && transport_->socket()) {
    129     return transport_->socket()->WasEverUsed();
    130   }
    131   NOTREACHED();
    132   return false;
    133 }
    134 
    135 bool SOCKS5ClientSocket::UsingTCPFastOpen() const {
    136   if (transport_.get() && transport_->socket()) {
    137     return transport_->socket()->UsingTCPFastOpen();
    138   }
    139   NOTREACHED();
    140   return false;
    141 }
    142 
    143 bool SOCKS5ClientSocket::WasNpnNegotiated() const {
    144   if (transport_.get() && transport_->socket()) {
    145     return transport_->socket()->WasNpnNegotiated();
    146   }
    147   NOTREACHED();
    148   return false;
    149 }
    150 
    151 NextProto SOCKS5ClientSocket::GetNegotiatedProtocol() const {
    152   if (transport_.get() && transport_->socket()) {
    153     return transport_->socket()->GetNegotiatedProtocol();
    154   }
    155   NOTREACHED();
    156   return kProtoUnknown;
    157 }
    158 
    159 bool SOCKS5ClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
    160   if (transport_.get() && transport_->socket()) {
    161     return transport_->socket()->GetSSLInfo(ssl_info);
    162   }
    163   NOTREACHED();
    164   return false;
    165 
    166 }
    167 
    168 // Read is called by the transport layer above to read. This can only be done
    169 // if the SOCKS handshake is complete.
    170 int SOCKS5ClientSocket::Read(IOBuffer* buf, int buf_len,
    171                              const CompletionCallback& callback) {
    172   DCHECK(completed_handshake_);
    173   DCHECK_EQ(STATE_NONE, next_state_);
    174   DCHECK(user_callback_.is_null());
    175 
    176   return transport_->socket()->Read(buf, buf_len, callback);
    177 }
    178 
    179 // Write is called by the transport layer. This can only be done if the
    180 // SOCKS handshake is complete.
    181 int SOCKS5ClientSocket::Write(IOBuffer* buf, int buf_len,
    182                              const CompletionCallback& callback) {
    183   DCHECK(completed_handshake_);
    184   DCHECK_EQ(STATE_NONE, next_state_);
    185   DCHECK(user_callback_.is_null());
    186 
    187   return transport_->socket()->Write(buf, buf_len, callback);
    188 }
    189 
    190 bool SOCKS5ClientSocket::SetReceiveBufferSize(int32 size) {
    191   return transport_->socket()->SetReceiveBufferSize(size);
    192 }
    193 
    194 bool SOCKS5ClientSocket::SetSendBufferSize(int32 size) {
    195   return transport_->socket()->SetSendBufferSize(size);
    196 }
    197 
    198 void SOCKS5ClientSocket::DoCallback(int result) {
    199   DCHECK_NE(ERR_IO_PENDING, result);
    200   DCHECK(!user_callback_.is_null());
    201 
    202   // Since Run() may result in Read being called,
    203   // clear user_callback_ up front.
    204   CompletionCallback c = user_callback_;
    205   user_callback_.Reset();
    206   c.Run(result);
    207 }
    208 
    209 void SOCKS5ClientSocket::OnIOComplete(int result) {
    210   DCHECK_NE(STATE_NONE, next_state_);
    211   int rv = DoLoop(result);
    212   if (rv != ERR_IO_PENDING) {
    213     net_log_.EndEvent(NetLog::TYPE_SOCKS5_CONNECT);
    214     DoCallback(rv);
    215   }
    216 }
    217 
    218 int SOCKS5ClientSocket::DoLoop(int last_io_result) {
    219   DCHECK_NE(next_state_, STATE_NONE);
    220   int rv = last_io_result;
    221   do {
    222     State state = next_state_;
    223     next_state_ = STATE_NONE;
    224     switch (state) {
    225       case STATE_GREET_WRITE:
    226         DCHECK_EQ(OK, rv);
    227         net_log_.BeginEvent(NetLog::TYPE_SOCKS5_GREET_WRITE);
    228         rv = DoGreetWrite();
    229         break;
    230       case STATE_GREET_WRITE_COMPLETE:
    231         rv = DoGreetWriteComplete(rv);
    232         net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS5_GREET_WRITE, rv);
    233         break;
    234       case STATE_GREET_READ:
    235         DCHECK_EQ(OK, rv);
    236         net_log_.BeginEvent(NetLog::TYPE_SOCKS5_GREET_READ);
    237         rv = DoGreetRead();
    238         break;
    239       case STATE_GREET_READ_COMPLETE:
    240         rv = DoGreetReadComplete(rv);
    241         net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS5_GREET_READ, rv);
    242         break;
    243       case STATE_HANDSHAKE_WRITE:
    244         DCHECK_EQ(OK, rv);
    245         net_log_.BeginEvent(NetLog::TYPE_SOCKS5_HANDSHAKE_WRITE);
    246         rv = DoHandshakeWrite();
    247         break;
    248       case STATE_HANDSHAKE_WRITE_COMPLETE:
    249         rv = DoHandshakeWriteComplete(rv);
    250         net_log_.EndEventWithNetErrorCode(
    251             NetLog::TYPE_SOCKS5_HANDSHAKE_WRITE, rv);
    252         break;
    253       case STATE_HANDSHAKE_READ:
    254         DCHECK_EQ(OK, rv);
    255         net_log_.BeginEvent(NetLog::TYPE_SOCKS5_HANDSHAKE_READ);
    256         rv = DoHandshakeRead();
    257         break;
    258       case STATE_HANDSHAKE_READ_COMPLETE:
    259         rv = DoHandshakeReadComplete(rv);
    260         net_log_.EndEventWithNetErrorCode(
    261             NetLog::TYPE_SOCKS5_HANDSHAKE_READ, rv);
    262         break;
    263       default:
    264         NOTREACHED() << "bad state";
    265         rv = ERR_UNEXPECTED;
    266         break;
    267     }
    268   } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
    269   return rv;
    270 }
    271 
    272 const char kSOCKS5GreetWriteData[] = { 0x05, 0x01, 0x00 };  // no authentication
    273 const char kSOCKS5GreetReadData[] = { 0x05, 0x00 };
    274 
    275 int SOCKS5ClientSocket::DoGreetWrite() {
    276   // Since we only have 1 byte to send the hostname length in, if the
    277   // URL has a hostname longer than 255 characters we can't send it.
    278   if (0xFF < host_request_info_.hostname().size()) {
    279     net_log_.AddEvent(NetLog::TYPE_SOCKS_HOSTNAME_TOO_BIG);
    280     return ERR_SOCKS_CONNECTION_FAILED;
    281   }
    282 
    283   if (buffer_.empty()) {
    284     buffer_ = std::string(kSOCKS5GreetWriteData,
    285                           arraysize(kSOCKS5GreetWriteData));
    286     bytes_sent_ = 0;
    287   }
    288 
    289   next_state_ = STATE_GREET_WRITE_COMPLETE;
    290   size_t handshake_buf_len = buffer_.size() - bytes_sent_;
    291   handshake_buf_ = new IOBuffer(handshake_buf_len);
    292   memcpy(handshake_buf_->data(), &buffer_.data()[bytes_sent_],
    293          handshake_buf_len);
    294   return transport_->socket()
    295       ->Write(handshake_buf_.get(), handshake_buf_len, io_callback_);
    296 }
    297 
    298 int SOCKS5ClientSocket::DoGreetWriteComplete(int result) {
    299   if (result < 0)
    300     return result;
    301 
    302   bytes_sent_ += result;
    303   if (bytes_sent_ == buffer_.size()) {
    304     buffer_.clear();
    305     bytes_received_ = 0;
    306     next_state_ = STATE_GREET_READ;
    307   } else {
    308     next_state_ = STATE_GREET_WRITE;
    309   }
    310   return OK;
    311 }
    312 
    313 int SOCKS5ClientSocket::DoGreetRead() {
    314   next_state_ = STATE_GREET_READ_COMPLETE;
    315   size_t handshake_buf_len = kGreetReadHeaderSize - bytes_received_;
    316   handshake_buf_ = new IOBuffer(handshake_buf_len);
    317   return transport_->socket()
    318       ->Read(handshake_buf_.get(), handshake_buf_len, io_callback_);
    319 }
    320 
    321 int SOCKS5ClientSocket::DoGreetReadComplete(int result) {
    322   if (result < 0)
    323     return result;
    324 
    325   if (result == 0) {
    326     net_log_.AddEvent(NetLog::TYPE_SOCKS_UNEXPECTEDLY_CLOSED_DURING_GREETING);
    327     return ERR_SOCKS_CONNECTION_FAILED;
    328   }
    329 
    330   bytes_received_ += result;
    331   buffer_.append(handshake_buf_->data(), result);
    332   if (bytes_received_ < kGreetReadHeaderSize) {
    333     next_state_ = STATE_GREET_READ;
    334     return OK;
    335   }
    336 
    337   // Got the greet data.
    338   if (buffer_[0] != kSOCKS5Version) {
    339     net_log_.AddEvent(NetLog::TYPE_SOCKS_UNEXPECTED_VERSION,
    340                       NetLog::IntegerCallback("version", buffer_[0]));
    341     return ERR_SOCKS_CONNECTION_FAILED;
    342   }
    343   if (buffer_[1] != 0x00) {
    344     net_log_.AddEvent(NetLog::TYPE_SOCKS_UNEXPECTED_AUTH,
    345                       NetLog::IntegerCallback("method", buffer_[1]));
    346     return ERR_SOCKS_CONNECTION_FAILED;
    347   }
    348 
    349   buffer_.clear();
    350   next_state_ = STATE_HANDSHAKE_WRITE;
    351   return OK;
    352 }
    353 
    354 int SOCKS5ClientSocket::BuildHandshakeWriteBuffer(std::string* handshake)
    355     const {
    356   DCHECK(handshake->empty());
    357 
    358   handshake->push_back(kSOCKS5Version);
    359   handshake->push_back(kTunnelCommand);  // Connect command
    360   handshake->push_back(kNullByte);  // Reserved null
    361 
    362   handshake->push_back(kEndPointDomain);  // The type of the address.
    363 
    364   DCHECK_GE(static_cast<size_t>(0xFF), host_request_info_.hostname().size());
    365 
    366   // First add the size of the hostname, followed by the hostname.
    367   handshake->push_back(static_cast<unsigned char>(
    368       host_request_info_.hostname().size()));
    369   handshake->append(host_request_info_.hostname());
    370 
    371   uint16 nw_port = base::HostToNet16(host_request_info_.port());
    372   handshake->append(reinterpret_cast<char*>(&nw_port), sizeof(nw_port));
    373   return OK;
    374 }
    375 
    376 // Writes the SOCKS handshake data to the underlying socket connection.
    377 int SOCKS5ClientSocket::DoHandshakeWrite() {
    378   next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE;
    379 
    380   if (buffer_.empty()) {
    381     int rv = BuildHandshakeWriteBuffer(&buffer_);
    382     if (rv != OK)
    383       return rv;
    384     bytes_sent_ = 0;
    385   }
    386 
    387   int handshake_buf_len = buffer_.size() - bytes_sent_;
    388   DCHECK_LT(0, handshake_buf_len);
    389   handshake_buf_ = new IOBuffer(handshake_buf_len);
    390   memcpy(handshake_buf_->data(), &buffer_[bytes_sent_],
    391          handshake_buf_len);
    392   return transport_->socket()
    393       ->Write(handshake_buf_.get(), handshake_buf_len, io_callback_);
    394 }
    395 
    396 int SOCKS5ClientSocket::DoHandshakeWriteComplete(int result) {
    397   if (result < 0)
    398     return result;
    399 
    400   // We ignore the case when result is 0, since the underlying Write
    401   // may return spurious writes while waiting on the socket.
    402 
    403   bytes_sent_ += result;
    404   if (bytes_sent_ == buffer_.size()) {
    405     next_state_ = STATE_HANDSHAKE_READ;
    406     buffer_.clear();
    407   } else if (bytes_sent_ < buffer_.size()) {
    408     next_state_ = STATE_HANDSHAKE_WRITE;
    409   } else {
    410     NOTREACHED();
    411   }
    412 
    413   return OK;
    414 }
    415 
    416 int SOCKS5ClientSocket::DoHandshakeRead() {
    417   next_state_ = STATE_HANDSHAKE_READ_COMPLETE;
    418 
    419   if (buffer_.empty()) {
    420     bytes_received_ = 0;
    421     read_header_size = kReadHeaderSize;
    422   }
    423 
    424   int handshake_buf_len = read_header_size - bytes_received_;
    425   handshake_buf_ = new IOBuffer(handshake_buf_len);
    426   return transport_->socket()
    427       ->Read(handshake_buf_.get(), handshake_buf_len, io_callback_);
    428 }
    429 
    430 int SOCKS5ClientSocket::DoHandshakeReadComplete(int result) {
    431   if (result < 0)
    432     return result;
    433 
    434   // The underlying socket closed unexpectedly.
    435   if (result == 0) {
    436     net_log_.AddEvent(NetLog::TYPE_SOCKS_UNEXPECTEDLY_CLOSED_DURING_HANDSHAKE);
    437     return ERR_SOCKS_CONNECTION_FAILED;
    438   }
    439 
    440   buffer_.append(handshake_buf_->data(), result);
    441   bytes_received_ += result;
    442 
    443   // When the first few bytes are read, check how many more are required
    444   // and accordingly increase them
    445   if (bytes_received_ == kReadHeaderSize) {
    446     if (buffer_[0] != kSOCKS5Version || buffer_[2] != kNullByte) {
    447       net_log_.AddEvent(NetLog::TYPE_SOCKS_UNEXPECTED_VERSION,
    448                         NetLog::IntegerCallback("version", buffer_[0]));
    449       return ERR_SOCKS_CONNECTION_FAILED;
    450     }
    451     if (buffer_[1] != 0x00) {
    452       net_log_.AddEvent(NetLog::TYPE_SOCKS_SERVER_ERROR,
    453                         NetLog::IntegerCallback("error_code", buffer_[1]));
    454       return ERR_SOCKS_CONNECTION_FAILED;
    455     }
    456 
    457     // We check the type of IP/Domain the server returns and accordingly
    458     // increase the size of the response. For domains, we need to read the
    459     // size of the domain, so the initial request size is upto the domain
    460     // size. Since for IPv4/IPv6 the size is fixed and hence no 'size' is
    461     // read, we substract 1 byte from the additional request size.
    462     SocksEndPointAddressType address_type =
    463         static_cast<SocksEndPointAddressType>(buffer_[3]);
    464     if (address_type == kEndPointDomain)
    465       read_header_size += static_cast<uint8>(buffer_[4]);
    466     else if (address_type == kEndPointResolvedIPv4)
    467       read_header_size += sizeof(struct in_addr) - 1;
    468     else if (address_type == kEndPointResolvedIPv6)
    469       read_header_size += sizeof(struct in6_addr) - 1;
    470     else {
    471       net_log_.AddEvent(NetLog::TYPE_SOCKS_UNKNOWN_ADDRESS_TYPE,
    472                         NetLog::IntegerCallback("address_type", buffer_[3]));
    473       return ERR_SOCKS_CONNECTION_FAILED;
    474     }
    475 
    476     read_header_size += 2;  // for the port.
    477     next_state_ = STATE_HANDSHAKE_READ;
    478     return OK;
    479   }
    480 
    481   // When the final bytes are read, setup handshake. We ignore the rest
    482   // of the response since they represent the SOCKSv5 endpoint and have
    483   // no use when doing a tunnel connection.
    484   if (bytes_received_ == read_header_size) {
    485     completed_handshake_ = true;
    486     buffer_.clear();
    487     next_state_ = STATE_NONE;
    488     return OK;
    489   }
    490 
    491   next_state_ = STATE_HANDSHAKE_READ;
    492   return OK;
    493 }
    494 
    495 int SOCKS5ClientSocket::GetPeerAddress(IPEndPoint* address) const {
    496   return transport_->socket()->GetPeerAddress(address);
    497 }
    498 
    499 int SOCKS5ClientSocket::GetLocalAddress(IPEndPoint* address) const {
    500   return transport_->socket()->GetLocalAddress(address);
    501 }
    502 
    503 }  // namespace net
    504