Home | History | Annotate | Download | only in socket
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/socket/socks_client_socket.h"
      6 
      7 #include "base/basictypes.h"
      8 #include "base/bind.h"
      9 #include "base/compiler_specific.h"
     10 #include "base/sys_byteorder.h"
     11 #include "net/base/io_buffer.h"
     12 #include "net/base/net_log.h"
     13 #include "net/base/net_util.h"
     14 #include "net/socket/client_socket_handle.h"
     15 
     16 namespace net {
     17 
     18 // Every SOCKS server requests a user-id from the client. It is optional
     19 // and we send an empty string.
     20 static const char kEmptyUserId[] = "";
     21 
     22 // For SOCKS4, the client sends 8 bytes  plus the size of the user-id.
     23 static const unsigned int kWriteHeaderSize = 8;
     24 
     25 // For SOCKS4 the server sends 8 bytes for acknowledgement.
     26 static const unsigned int kReadHeaderSize = 8;
     27 
     28 // Server Response codes for SOCKS.
     29 static const uint8 kServerResponseOk  = 0x5A;
     30 static const uint8 kServerResponseRejected = 0x5B;
     31 static const uint8 kServerResponseNotReachable = 0x5C;
     32 static const uint8 kServerResponseMismatchedUserId = 0x5D;
     33 
     34 static const uint8 kSOCKSVersion4 = 0x04;
     35 static const uint8 kSOCKSStreamRequest = 0x01;
     36 
// A struct holding the essential details of the SOCKS4 Server Request.
// The port in the header is stored in network byte order.
// The field order and sizes mirror the bytes sent on the wire; the
// COMPILE_ASSERT below guards against padding changing the size.
struct SOCKS4ServerRequest {
  uint8 version;   // Protocol version byte (kSOCKSVersion4).
  uint8 command;   // Command code (kSOCKSStreamRequest).
  uint16 nw_port;  // Destination port, already converted to network order.
  uint8 ip[4];     // Destination IPv4 address bytes.
};
COMPILE_ASSERT(sizeof(SOCKS4ServerRequest) == kWriteHeaderSize,
               socks4_server_request_struct_wrong_size);

// A struct holding details of the SOCKS4 Server Response.
// Layout mirrors the 8-byte reply; only |reserved_null| and |code| are
// inspected by this client.
struct SOCKS4ServerResponse {
  uint8 reserved_null;  // Expected to be 0x00 in a valid reply.
  uint8 code;           // Reply code, compared against kServerResponse*.
  uint16 port;          // Not used by this client.
  uint8 ip[4];          // Not used by this client.
};
COMPILE_ASSERT(sizeof(SOCKS4ServerResponse) == kReadHeaderSize,
               socks4_server_response_struct_wrong_size);
     57 
// Takes ownership of |transport_socket|, over which the SOCKS4 handshake and
// all later traffic are carried. |req_info| names the destination host/port
// that will be resolved (through |host_resolver|) and sent to the proxy;
// |priority| is passed along with that resolution request. The handshake does
// not start until Connect() is called.
SOCKSClientSocket::SOCKSClientSocket(
    scoped_ptr<ClientSocketHandle> transport_socket,
    const HostResolver::RequestInfo& req_info,
    RequestPriority priority,
    HostResolver* host_resolver)
    : transport_(transport_socket.Pass()),
      next_state_(STATE_NONE),
      completed_handshake_(false),
      bytes_sent_(0),
      bytes_received_(0),
      host_resolver_(host_resolver),
      host_request_info_(req_info),
      priority_(priority),
      // NOTE(review): this dereferences transport_, so it relies on transport_
      // being declared before net_log_ in the header (members are initialized
      // in declaration order) and on |transport_socket| holding a socket.
      net_log_(transport_->socket()->NetLog()) {}
     72 
     73 SOCKSClientSocket::~SOCKSClientSocket() {
     74   Disconnect();
     75 }
     76 
     77 int SOCKSClientSocket::Connect(const CompletionCallback& callback) {
     78   DCHECK(transport_.get());
     79   DCHECK(transport_->socket());
     80   DCHECK_EQ(STATE_NONE, next_state_);
     81   DCHECK(user_callback_.is_null());
     82 
     83   // If already connected, then just return OK.
     84   if (completed_handshake_)
     85     return OK;
     86 
     87   next_state_ = STATE_RESOLVE_HOST;
     88 
     89   net_log_.BeginEvent(NetLog::TYPE_SOCKS_CONNECT);
     90 
     91   int rv = DoLoop(OK);
     92   if (rv == ERR_IO_PENDING) {
     93     user_callback_ = callback;
     94   } else {
     95     net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS_CONNECT, rv);
     96   }
     97   return rv;
     98 }
     99 
    100 void SOCKSClientSocket::Disconnect() {
    101   completed_handshake_ = false;
    102   host_resolver_.Cancel();
    103   transport_->socket()->Disconnect();
    104 
    105   // Reset other states to make sure they aren't mistakenly used later.
    106   // These are the states initialized by Connect().
    107   next_state_ = STATE_NONE;
    108   user_callback_.Reset();
    109 }
    110 
    111 bool SOCKSClientSocket::IsConnected() const {
    112   return completed_handshake_ && transport_->socket()->IsConnected();
    113 }
    114 
    115 bool SOCKSClientSocket::IsConnectedAndIdle() const {
    116   return completed_handshake_ && transport_->socket()->IsConnectedAndIdle();
    117 }
    118 
    119 const BoundNetLog& SOCKSClientSocket::NetLog() const {
    120   return net_log_;
    121 }
    122 
    123 void SOCKSClientSocket::SetSubresourceSpeculation() {
    124   if (transport_.get() && transport_->socket()) {
    125     transport_->socket()->SetSubresourceSpeculation();
    126   } else {
    127     NOTREACHED();
    128   }
    129 }
    130 
    131 void SOCKSClientSocket::SetOmniboxSpeculation() {
    132   if (transport_.get() && transport_->socket()) {
    133     transport_->socket()->SetOmniboxSpeculation();
    134   } else {
    135     NOTREACHED();
    136   }
    137 }
    138 
    139 bool SOCKSClientSocket::WasEverUsed() const {
    140   if (transport_.get() && transport_->socket()) {
    141     return transport_->socket()->WasEverUsed();
    142   }
    143   NOTREACHED();
    144   return false;
    145 }
    146 
    147 bool SOCKSClientSocket::UsingTCPFastOpen() const {
    148   if (transport_.get() && transport_->socket()) {
    149     return transport_->socket()->UsingTCPFastOpen();
    150   }
    151   NOTREACHED();
    152   return false;
    153 }
    154 
    155 bool SOCKSClientSocket::WasNpnNegotiated() const {
    156   if (transport_.get() && transport_->socket()) {
    157     return transport_->socket()->WasNpnNegotiated();
    158   }
    159   NOTREACHED();
    160   return false;
    161 }
    162 
    163 NextProto SOCKSClientSocket::GetNegotiatedProtocol() const {
    164   if (transport_.get() && transport_->socket()) {
    165     return transport_->socket()->GetNegotiatedProtocol();
    166   }
    167   NOTREACHED();
    168   return kProtoUnknown;
    169 }
    170 
    171 bool SOCKSClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
    172   if (transport_.get() && transport_->socket()) {
    173     return transport_->socket()->GetSSLInfo(ssl_info);
    174   }
    175   NOTREACHED();
    176   return false;
    177 
    178 }
    179 
    180 // Read is called by the transport layer above to read. This can only be done
    181 // if the SOCKS handshake is complete.
    182 int SOCKSClientSocket::Read(IOBuffer* buf, int buf_len,
    183                             const CompletionCallback& callback) {
    184   DCHECK(completed_handshake_);
    185   DCHECK_EQ(STATE_NONE, next_state_);
    186   DCHECK(user_callback_.is_null());
    187 
    188   return transport_->socket()->Read(buf, buf_len, callback);
    189 }
    190 
    191 // Write is called by the transport layer. This can only be done if the
    192 // SOCKS handshake is complete.
    193 int SOCKSClientSocket::Write(IOBuffer* buf, int buf_len,
    194                              const CompletionCallback& callback) {
    195   DCHECK(completed_handshake_);
    196   DCHECK_EQ(STATE_NONE, next_state_);
    197   DCHECK(user_callback_.is_null());
    198 
    199   return transport_->socket()->Write(buf, buf_len, callback);
    200 }
    201 
    202 bool SOCKSClientSocket::SetReceiveBufferSize(int32 size) {
    203   return transport_->socket()->SetReceiveBufferSize(size);
    204 }
    205 
    206 bool SOCKSClientSocket::SetSendBufferSize(int32 size) {
    207   return transport_->socket()->SetSendBufferSize(size);
    208 }
    209 
    210 void SOCKSClientSocket::DoCallback(int result) {
    211   DCHECK_NE(ERR_IO_PENDING, result);
    212   DCHECK(!user_callback_.is_null());
    213 
    214   // Since Run() may result in Read being called,
    215   // clear user_callback_ up front.
    216   CompletionCallback c = user_callback_;
    217   user_callback_.Reset();
    218   DVLOG(1) << "Finished setting up SOCKS handshake";
    219   c.Run(result);
    220 }
    221 
    222 void SOCKSClientSocket::OnIOComplete(int result) {
    223   DCHECK_NE(STATE_NONE, next_state_);
    224   int rv = DoLoop(result);
    225   if (rv != ERR_IO_PENDING) {
    226     net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS_CONNECT, rv);
    227     DoCallback(rv);
    228   }
    229 }
    230 
    231 int SOCKSClientSocket::DoLoop(int last_io_result) {
    232   DCHECK_NE(next_state_, STATE_NONE);
    233   int rv = last_io_result;
    234   do {
    235     State state = next_state_;
    236     next_state_ = STATE_NONE;
    237     switch (state) {
    238       case STATE_RESOLVE_HOST:
    239         DCHECK_EQ(OK, rv);
    240         rv = DoResolveHost();
    241         break;
    242       case STATE_RESOLVE_HOST_COMPLETE:
    243         rv = DoResolveHostComplete(rv);
    244         break;
    245       case STATE_HANDSHAKE_WRITE:
    246         DCHECK_EQ(OK, rv);
    247         rv = DoHandshakeWrite();
    248         break;
    249       case STATE_HANDSHAKE_WRITE_COMPLETE:
    250         rv = DoHandshakeWriteComplete(rv);
    251         break;
    252       case STATE_HANDSHAKE_READ:
    253         DCHECK_EQ(OK, rv);
    254         rv = DoHandshakeRead();
    255         break;
    256       case STATE_HANDSHAKE_READ_COMPLETE:
    257         rv = DoHandshakeReadComplete(rv);
    258         break;
    259       default:
    260         NOTREACHED() << "bad state";
    261         rv = ERR_UNEXPECTED;
    262         break;
    263     }
    264   } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
    265   return rv;
    266 }
    267 
    268 int SOCKSClientSocket::DoResolveHost() {
    269   next_state_ = STATE_RESOLVE_HOST_COMPLETE;
    270   // SOCKS4 only supports IPv4 addresses, so only try getting the IPv4
    271   // addresses for the target host.
    272   host_request_info_.set_address_family(ADDRESS_FAMILY_IPV4);
    273   return host_resolver_.Resolve(
    274       host_request_info_,
    275       priority_,
    276       &addresses_,
    277       base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)),
    278       net_log_);
    279 }
    280 
    281 int SOCKSClientSocket::DoResolveHostComplete(int result) {
    282   if (result != OK) {
    283     // Resolving the hostname failed; fail the request rather than automatically
    284     // falling back to SOCKS4a (since it can be confusing to see invalid IP
    285     // addresses being sent to the SOCKS4 server when it doesn't support 4A.)
    286     return result;
    287   }
    288 
    289   next_state_ = STATE_HANDSHAKE_WRITE;
    290   return OK;
    291 }
    292 
    293 // Builds the buffer that is to be sent to the server.
    294 const std::string SOCKSClientSocket::BuildHandshakeWriteBuffer() const {
    295   SOCKS4ServerRequest request;
    296   request.version = kSOCKSVersion4;
    297   request.command = kSOCKSStreamRequest;
    298   request.nw_port = base::HostToNet16(host_request_info_.port());
    299 
    300   DCHECK(!addresses_.empty());
    301   const IPEndPoint& endpoint = addresses_.front();
    302 
    303   // We disabled IPv6 results when resolving the hostname, so none of the
    304   // results in the list will be IPv6.
    305   // TODO(eroman): we only ever use the first address in the list. It would be
    306   //               more robust to try all the IP addresses we have before
    307   //               failing the connect attempt.
    308   CHECK_EQ(ADDRESS_FAMILY_IPV4, endpoint.GetFamily());
    309   CHECK_LE(endpoint.address().size(), sizeof(request.ip));
    310   memcpy(&request.ip, &endpoint.address()[0], endpoint.address().size());
    311 
    312   DVLOG(1) << "Resolved Host is : " << endpoint.ToStringWithoutPort();
    313 
    314   std::string handshake_data(reinterpret_cast<char*>(&request),
    315                              sizeof(request));
    316   handshake_data.append(kEmptyUserId, arraysize(kEmptyUserId));
    317 
    318   return handshake_data;
    319 }
    320 
    321 // Writes the SOCKS handshake data to the underlying socket connection.
    322 int SOCKSClientSocket::DoHandshakeWrite() {
    323   next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE;
    324 
    325   if (buffer_.empty()) {
    326     buffer_ = BuildHandshakeWriteBuffer();
    327     bytes_sent_ = 0;
    328   }
    329 
    330   int handshake_buf_len = buffer_.size() - bytes_sent_;
    331   DCHECK_GT(handshake_buf_len, 0);
    332   handshake_buf_ = new IOBuffer(handshake_buf_len);
    333   memcpy(handshake_buf_->data(), &buffer_[bytes_sent_],
    334          handshake_buf_len);
    335   return transport_->socket()->Write(
    336       handshake_buf_.get(),
    337       handshake_buf_len,
    338       base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)));
    339 }
    340 
    341 int SOCKSClientSocket::DoHandshakeWriteComplete(int result) {
    342   if (result < 0)
    343     return result;
    344 
    345   // We ignore the case when result is 0, since the underlying Write
    346   // may return spurious writes while waiting on the socket.
    347 
    348   bytes_sent_ += result;
    349   if (bytes_sent_ == buffer_.size()) {
    350     next_state_ = STATE_HANDSHAKE_READ;
    351     buffer_.clear();
    352   } else if (bytes_sent_ < buffer_.size()) {
    353     next_state_ = STATE_HANDSHAKE_WRITE;
    354   } else {
    355     return ERR_UNEXPECTED;
    356   }
    357 
    358   return OK;
    359 }
    360 
    361 int SOCKSClientSocket::DoHandshakeRead() {
    362   next_state_ = STATE_HANDSHAKE_READ_COMPLETE;
    363 
    364   if (buffer_.empty()) {
    365     bytes_received_ = 0;
    366   }
    367 
    368   int handshake_buf_len = kReadHeaderSize - bytes_received_;
    369   handshake_buf_ = new IOBuffer(handshake_buf_len);
    370   return transport_->socket()->Read(
    371       handshake_buf_.get(),
    372       handshake_buf_len,
    373       base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)));
    374 }
    375 
    376 int SOCKSClientSocket::DoHandshakeReadComplete(int result) {
    377   if (result < 0)
    378     return result;
    379 
    380   // The underlying socket closed unexpectedly.
    381   if (result == 0)
    382     return ERR_CONNECTION_CLOSED;
    383 
    384   if (bytes_received_ + result > kReadHeaderSize) {
    385     // TODO(eroman): Describe failure in NetLog.
    386     return ERR_SOCKS_CONNECTION_FAILED;
    387   }
    388 
    389   buffer_.append(handshake_buf_->data(), result);
    390   bytes_received_ += result;
    391   if (bytes_received_ < kReadHeaderSize) {
    392     next_state_ = STATE_HANDSHAKE_READ;
    393     return OK;
    394   }
    395 
    396   const SOCKS4ServerResponse* response =
    397       reinterpret_cast<const SOCKS4ServerResponse*>(buffer_.data());
    398 
    399   if (response->reserved_null != 0x00) {
    400     LOG(ERROR) << "Unknown response from SOCKS server.";
    401     return ERR_SOCKS_CONNECTION_FAILED;
    402   }
    403 
    404   switch (response->code) {
    405     case kServerResponseOk:
    406       completed_handshake_ = true;
    407       return OK;
    408     case kServerResponseRejected:
    409       LOG(ERROR) << "SOCKS request rejected or failed";
    410       return ERR_SOCKS_CONNECTION_FAILED;
    411     case kServerResponseNotReachable:
    412       LOG(ERROR) << "SOCKS request failed because client is not running "
    413                  << "identd (or not reachable from the server)";
    414       return ERR_SOCKS_CONNECTION_HOST_UNREACHABLE;
    415     case kServerResponseMismatchedUserId:
    416       LOG(ERROR) << "SOCKS request failed because client's identd could "
    417                  << "not confirm the user ID string in the request";
    418       return ERR_SOCKS_CONNECTION_FAILED;
    419     default:
    420       LOG(ERROR) << "SOCKS server sent unknown response";
    421       return ERR_SOCKS_CONNECTION_FAILED;
    422   }
    423 
    424   // Note: we ignore the last 6 bytes as specified by the SOCKS protocol
    425 }
    426 
    427 int SOCKSClientSocket::GetPeerAddress(IPEndPoint* address) const {
    428   return transport_->socket()->GetPeerAddress(address);
    429 }
    430 
    431 int SOCKSClientSocket::GetLocalAddress(IPEndPoint* address) const {
    432   return transport_->socket()->GetLocalAddress(address);
    433 }
    434 
    435 }  // namespace net
    436