Home | History | Annotate | Download | only in socket
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/socket/socks_client_socket.h"
      6 
      7 #include "base/basictypes.h"
      8 #include "base/bind.h"
      9 #include "base/callback_helpers.h"
     10 #include "base/compiler_specific.h"
     11 #include "base/sys_byteorder.h"
     12 #include "net/base/io_buffer.h"
     13 #include "net/base/net_log.h"
     14 #include "net/base/net_util.h"
     15 #include "net/socket/client_socket_handle.h"
     16 
     17 namespace net {
     18 
     19 // Every SOCKS server requests a user-id from the client. It is optional
     20 // and we send an empty string.
     21 static const char kEmptyUserId[] = "";
     22 
     23 // For SOCKS4, the client sends 8 bytes  plus the size of the user-id.
     24 static const unsigned int kWriteHeaderSize = 8;
     25 
     26 // For SOCKS4 the server sends 8 bytes for acknowledgement.
     27 static const unsigned int kReadHeaderSize = 8;
     28 
     29 // Server Response codes for SOCKS.
     30 static const uint8 kServerResponseOk  = 0x5A;
     31 static const uint8 kServerResponseRejected = 0x5B;
     32 static const uint8 kServerResponseNotReachable = 0x5C;
     33 static const uint8 kServerResponseMismatchedUserId = 0x5D;
     34 
     35 static const uint8 kSOCKSVersion4 = 0x04;
     36 static const uint8 kSOCKSStreamRequest = 0x01;
     37 
// A struct holding the essential details of the SOCKS4 Server Request.
// The port in the header is stored in network byte order.
// The struct's raw bytes are copied verbatim into the outgoing handshake
// buffer, so its layout must match the wire format exactly; the size check
// below guards against padding being introduced.
struct SOCKS4ServerRequest {
  uint8 version;   // Protocol version byte (set to kSOCKSVersion4).
  uint8 command;   // Request command (set to kSOCKSStreamRequest).
  uint16 nw_port;  // Destination port, already in network byte order.
  uint8 ip[4];     // Destination IPv4 address bytes.
};
COMPILE_ASSERT(sizeof(SOCKS4ServerRequest) == kWriteHeaderSize,
               socks4_server_request_struct_wrong_size);
     48 
// A struct holding details of the SOCKS4 Server Response.
// Its layout mirrors the 8-byte reply read from the server; the size check
// below guards against padding being introduced. Only |reserved_null| and
// |code| are inspected by this client.
struct SOCKS4ServerResponse {
  uint8 reserved_null;  // Must be 0x00; any other value fails the handshake.
  uint8 code;           // Status, compared against the kServerResponse* codes.
  uint16 port;          // Not examined by this client.
  uint8 ip[4];          // Not examined by this client.
};
COMPILE_ASSERT(sizeof(SOCKS4ServerResponse) == kReadHeaderSize,
               socks4_server_response_struct_wrong_size);
     58 
// Takes ownership of |transport_socket| and, when Connect() is called, runs a
// SOCKS4 handshake over it to reach the host in |req_info|. That host is
// resolved through |host_resolver| at |priority| (ownership/lifetime of
// |host_resolver| is not visible here — presumably it must outlive this
// object; confirm against the header).
//
// NOTE: the initializer for |net_log_| dereferences |transport_|, so
// |transport_socket| must already hold a non-null socket, and |transport_|
// must be declared before |net_log_| in the class (members are initialized
// in declaration order) — TODO confirm against socks_client_socket.h.
SOCKSClientSocket::SOCKSClientSocket(
    scoped_ptr<ClientSocketHandle> transport_socket,
    const HostResolver::RequestInfo& req_info,
    RequestPriority priority,
    HostResolver* host_resolver)
    : transport_(transport_socket.Pass()),
      next_state_(STATE_NONE),
      completed_handshake_(false),
      bytes_sent_(0),
      bytes_received_(0),
      was_ever_used_(false),
      host_resolver_(host_resolver),
      host_request_info_(req_info),
      priority_(priority),
      net_log_(transport_->socket()->NetLog()) {}
     74 
     75 SOCKSClientSocket::~SOCKSClientSocket() {
     76   Disconnect();
     77 }
     78 
     79 int SOCKSClientSocket::Connect(const CompletionCallback& callback) {
     80   DCHECK(transport_.get());
     81   DCHECK(transport_->socket());
     82   DCHECK_EQ(STATE_NONE, next_state_);
     83   DCHECK(user_callback_.is_null());
     84 
     85   // If already connected, then just return OK.
     86   if (completed_handshake_)
     87     return OK;
     88 
     89   next_state_ = STATE_RESOLVE_HOST;
     90 
     91   net_log_.BeginEvent(NetLog::TYPE_SOCKS_CONNECT);
     92 
     93   int rv = DoLoop(OK);
     94   if (rv == ERR_IO_PENDING) {
     95     user_callback_ = callback;
     96   } else {
     97     net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS_CONNECT, rv);
     98   }
     99   return rv;
    100 }
    101 
    102 void SOCKSClientSocket::Disconnect() {
    103   completed_handshake_ = false;
    104   host_resolver_.Cancel();
    105   transport_->socket()->Disconnect();
    106 
    107   // Reset other states to make sure they aren't mistakenly used later.
    108   // These are the states initialized by Connect().
    109   next_state_ = STATE_NONE;
    110   user_callback_.Reset();
    111 }
    112 
    113 bool SOCKSClientSocket::IsConnected() const {
    114   return completed_handshake_ && transport_->socket()->IsConnected();
    115 }
    116 
    117 bool SOCKSClientSocket::IsConnectedAndIdle() const {
    118   return completed_handshake_ && transport_->socket()->IsConnectedAndIdle();
    119 }
    120 
    121 const BoundNetLog& SOCKSClientSocket::NetLog() const {
    122   return net_log_;
    123 }
    124 
    125 void SOCKSClientSocket::SetSubresourceSpeculation() {
    126   if (transport_.get() && transport_->socket()) {
    127     transport_->socket()->SetSubresourceSpeculation();
    128   } else {
    129     NOTREACHED();
    130   }
    131 }
    132 
    133 void SOCKSClientSocket::SetOmniboxSpeculation() {
    134   if (transport_.get() && transport_->socket()) {
    135     transport_->socket()->SetOmniboxSpeculation();
    136   } else {
    137     NOTREACHED();
    138   }
    139 }
    140 
    141 bool SOCKSClientSocket::WasEverUsed() const {
    142   return was_ever_used_;
    143 }
    144 
    145 bool SOCKSClientSocket::UsingTCPFastOpen() const {
    146   if (transport_.get() && transport_->socket()) {
    147     return transport_->socket()->UsingTCPFastOpen();
    148   }
    149   NOTREACHED();
    150   return false;
    151 }
    152 
    153 bool SOCKSClientSocket::WasNpnNegotiated() const {
    154   if (transport_.get() && transport_->socket()) {
    155     return transport_->socket()->WasNpnNegotiated();
    156   }
    157   NOTREACHED();
    158   return false;
    159 }
    160 
    161 NextProto SOCKSClientSocket::GetNegotiatedProtocol() const {
    162   if (transport_.get() && transport_->socket()) {
    163     return transport_->socket()->GetNegotiatedProtocol();
    164   }
    165   NOTREACHED();
    166   return kProtoUnknown;
    167 }
    168 
    169 bool SOCKSClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
    170   if (transport_.get() && transport_->socket()) {
    171     return transport_->socket()->GetSSLInfo(ssl_info);
    172   }
    173   NOTREACHED();
    174   return false;
    175 
    176 }
    177 
    178 // Read is called by the transport layer above to read. This can only be done
    179 // if the SOCKS handshake is complete.
    180 int SOCKSClientSocket::Read(IOBuffer* buf, int buf_len,
    181                             const CompletionCallback& callback) {
    182   DCHECK(completed_handshake_);
    183   DCHECK_EQ(STATE_NONE, next_state_);
    184   DCHECK(user_callback_.is_null());
    185   DCHECK(!callback.is_null());
    186 
    187   int rv = transport_->socket()->Read(
    188       buf, buf_len,
    189       base::Bind(&SOCKSClientSocket::OnReadWriteComplete,
    190                  base::Unretained(this), callback));
    191   if (rv > 0)
    192     was_ever_used_ = true;
    193   return rv;
    194 }
    195 
    196 // Write is called by the transport layer. This can only be done if the
    197 // SOCKS handshake is complete.
    198 int SOCKSClientSocket::Write(IOBuffer* buf, int buf_len,
    199                              const CompletionCallback& callback) {
    200   DCHECK(completed_handshake_);
    201   DCHECK_EQ(STATE_NONE, next_state_);
    202   DCHECK(user_callback_.is_null());
    203   DCHECK(!callback.is_null());
    204 
    205   int rv = transport_->socket()->Write(
    206       buf, buf_len,
    207       base::Bind(&SOCKSClientSocket::OnReadWriteComplete,
    208                  base::Unretained(this), callback));
    209   if (rv > 0)
    210     was_ever_used_ = true;
    211   return rv;
    212 }
    213 
    214 int SOCKSClientSocket::SetReceiveBufferSize(int32 size) {
    215   return transport_->socket()->SetReceiveBufferSize(size);
    216 }
    217 
    218 int SOCKSClientSocket::SetSendBufferSize(int32 size) {
    219   return transport_->socket()->SetSendBufferSize(size);
    220 }
    221 
    222 void SOCKSClientSocket::DoCallback(int result) {
    223   DCHECK_NE(ERR_IO_PENDING, result);
    224   DCHECK(!user_callback_.is_null());
    225 
    226   // Since Run() may result in Read being called,
    227   // clear user_callback_ up front.
    228   DVLOG(1) << "Finished setting up SOCKS handshake";
    229   base::ResetAndReturn(&user_callback_).Run(result);
    230 }
    231 
    232 void SOCKSClientSocket::OnIOComplete(int result) {
    233   DCHECK_NE(STATE_NONE, next_state_);
    234   int rv = DoLoop(result);
    235   if (rv != ERR_IO_PENDING) {
    236     net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS_CONNECT, rv);
    237     DoCallback(rv);
    238   }
    239 }
    240 
    241 void SOCKSClientSocket::OnReadWriteComplete(const CompletionCallback& callback,
    242                                             int result) {
    243   DCHECK_NE(ERR_IO_PENDING, result);
    244   DCHECK(!callback.is_null());
    245 
    246   if (result > 0)
    247     was_ever_used_ = true;
    248   callback.Run(result);
    249 }
    250 
    251 int SOCKSClientSocket::DoLoop(int last_io_result) {
    252   DCHECK_NE(next_state_, STATE_NONE);
    253   int rv = last_io_result;
    254   do {
    255     State state = next_state_;
    256     next_state_ = STATE_NONE;
    257     switch (state) {
    258       case STATE_RESOLVE_HOST:
    259         DCHECK_EQ(OK, rv);
    260         rv = DoResolveHost();
    261         break;
    262       case STATE_RESOLVE_HOST_COMPLETE:
    263         rv = DoResolveHostComplete(rv);
    264         break;
    265       case STATE_HANDSHAKE_WRITE:
    266         DCHECK_EQ(OK, rv);
    267         rv = DoHandshakeWrite();
    268         break;
    269       case STATE_HANDSHAKE_WRITE_COMPLETE:
    270         rv = DoHandshakeWriteComplete(rv);
    271         break;
    272       case STATE_HANDSHAKE_READ:
    273         DCHECK_EQ(OK, rv);
    274         rv = DoHandshakeRead();
    275         break;
    276       case STATE_HANDSHAKE_READ_COMPLETE:
    277         rv = DoHandshakeReadComplete(rv);
    278         break;
    279       default:
    280         NOTREACHED() << "bad state";
    281         rv = ERR_UNEXPECTED;
    282         break;
    283     }
    284   } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
    285   return rv;
    286 }
    287 
    288 int SOCKSClientSocket::DoResolveHost() {
    289   next_state_ = STATE_RESOLVE_HOST_COMPLETE;
    290   // SOCKS4 only supports IPv4 addresses, so only try getting the IPv4
    291   // addresses for the target host.
    292   host_request_info_.set_address_family(ADDRESS_FAMILY_IPV4);
    293   return host_resolver_.Resolve(
    294       host_request_info_,
    295       priority_,
    296       &addresses_,
    297       base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)),
    298       net_log_);
    299 }
    300 
    301 int SOCKSClientSocket::DoResolveHostComplete(int result) {
    302   if (result != OK) {
    303     // Resolving the hostname failed; fail the request rather than automatically
    304     // falling back to SOCKS4a (since it can be confusing to see invalid IP
    305     // addresses being sent to the SOCKS4 server when it doesn't support 4A.)
    306     return result;
    307   }
    308 
    309   next_state_ = STATE_HANDSHAKE_WRITE;
    310   return OK;
    311 }
    312 
    313 // Builds the buffer that is to be sent to the server.
    314 const std::string SOCKSClientSocket::BuildHandshakeWriteBuffer() const {
    315   SOCKS4ServerRequest request;
    316   request.version = kSOCKSVersion4;
    317   request.command = kSOCKSStreamRequest;
    318   request.nw_port = base::HostToNet16(host_request_info_.port());
    319 
    320   DCHECK(!addresses_.empty());
    321   const IPEndPoint& endpoint = addresses_.front();
    322 
    323   // We disabled IPv6 results when resolving the hostname, so none of the
    324   // results in the list will be IPv6.
    325   // TODO(eroman): we only ever use the first address in the list. It would be
    326   //               more robust to try all the IP addresses we have before
    327   //               failing the connect attempt.
    328   CHECK_EQ(ADDRESS_FAMILY_IPV4, endpoint.GetFamily());
    329   CHECK_LE(endpoint.address().size(), sizeof(request.ip));
    330   memcpy(&request.ip, &endpoint.address()[0], endpoint.address().size());
    331 
    332   DVLOG(1) << "Resolved Host is : " << endpoint.ToStringWithoutPort();
    333 
    334   std::string handshake_data(reinterpret_cast<char*>(&request),
    335                              sizeof(request));
    336   handshake_data.append(kEmptyUserId, arraysize(kEmptyUserId));
    337 
    338   return handshake_data;
    339 }
    340 
    341 // Writes the SOCKS handshake data to the underlying socket connection.
    342 int SOCKSClientSocket::DoHandshakeWrite() {
    343   next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE;
    344 
    345   if (buffer_.empty()) {
    346     buffer_ = BuildHandshakeWriteBuffer();
    347     bytes_sent_ = 0;
    348   }
    349 
    350   int handshake_buf_len = buffer_.size() - bytes_sent_;
    351   DCHECK_GT(handshake_buf_len, 0);
    352   handshake_buf_ = new IOBuffer(handshake_buf_len);
    353   memcpy(handshake_buf_->data(), &buffer_[bytes_sent_],
    354          handshake_buf_len);
    355   return transport_->socket()->Write(
    356       handshake_buf_.get(),
    357       handshake_buf_len,
    358       base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)));
    359 }
    360 
    361 int SOCKSClientSocket::DoHandshakeWriteComplete(int result) {
    362   if (result < 0)
    363     return result;
    364 
    365   // We ignore the case when result is 0, since the underlying Write
    366   // may return spurious writes while waiting on the socket.
    367 
    368   bytes_sent_ += result;
    369   if (bytes_sent_ == buffer_.size()) {
    370     next_state_ = STATE_HANDSHAKE_READ;
    371     buffer_.clear();
    372   } else if (bytes_sent_ < buffer_.size()) {
    373     next_state_ = STATE_HANDSHAKE_WRITE;
    374   } else {
    375     return ERR_UNEXPECTED;
    376   }
    377 
    378   return OK;
    379 }
    380 
    381 int SOCKSClientSocket::DoHandshakeRead() {
    382   next_state_ = STATE_HANDSHAKE_READ_COMPLETE;
    383 
    384   if (buffer_.empty()) {
    385     bytes_received_ = 0;
    386   }
    387 
    388   int handshake_buf_len = kReadHeaderSize - bytes_received_;
    389   handshake_buf_ = new IOBuffer(handshake_buf_len);
    390   return transport_->socket()->Read(
    391       handshake_buf_.get(),
    392       handshake_buf_len,
    393       base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)));
    394 }
    395 
    396 int SOCKSClientSocket::DoHandshakeReadComplete(int result) {
    397   if (result < 0)
    398     return result;
    399 
    400   // The underlying socket closed unexpectedly.
    401   if (result == 0)
    402     return ERR_CONNECTION_CLOSED;
    403 
    404   if (bytes_received_ + result > kReadHeaderSize) {
    405     // TODO(eroman): Describe failure in NetLog.
    406     return ERR_SOCKS_CONNECTION_FAILED;
    407   }
    408 
    409   buffer_.append(handshake_buf_->data(), result);
    410   bytes_received_ += result;
    411   if (bytes_received_ < kReadHeaderSize) {
    412     next_state_ = STATE_HANDSHAKE_READ;
    413     return OK;
    414   }
    415 
    416   const SOCKS4ServerResponse* response =
    417       reinterpret_cast<const SOCKS4ServerResponse*>(buffer_.data());
    418 
    419   if (response->reserved_null != 0x00) {
    420     LOG(ERROR) << "Unknown response from SOCKS server.";
    421     return ERR_SOCKS_CONNECTION_FAILED;
    422   }
    423 
    424   switch (response->code) {
    425     case kServerResponseOk:
    426       completed_handshake_ = true;
    427       return OK;
    428     case kServerResponseRejected:
    429       LOG(ERROR) << "SOCKS request rejected or failed";
    430       return ERR_SOCKS_CONNECTION_FAILED;
    431     case kServerResponseNotReachable:
    432       LOG(ERROR) << "SOCKS request failed because client is not running "
    433                  << "identd (or not reachable from the server)";
    434       return ERR_SOCKS_CONNECTION_HOST_UNREACHABLE;
    435     case kServerResponseMismatchedUserId:
    436       LOG(ERROR) << "SOCKS request failed because client's identd could "
    437                  << "not confirm the user ID string in the request";
    438       return ERR_SOCKS_CONNECTION_FAILED;
    439     default:
    440       LOG(ERROR) << "SOCKS server sent unknown response";
    441       return ERR_SOCKS_CONNECTION_FAILED;
    442   }
    443 
    444   // Note: we ignore the last 6 bytes as specified by the SOCKS protocol
    445 }
    446 
    447 int SOCKSClientSocket::GetPeerAddress(IPEndPoint* address) const {
    448   return transport_->socket()->GetPeerAddress(address);
    449 }
    450 
    451 int SOCKSClientSocket::GetLocalAddress(IPEndPoint* address) const {
    452   return transport_->socket()->GetLocalAddress(address);
    453 }
    454 
    455 }  // namespace net
    456