Home | History | Annotate | Download | only in net
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "chrome/test/chromedriver/net/websocket.h"
      6 
      7 #include <string.h>
      8 
      9 #include "base/base64.h"
     10 #include "base/bind.h"
     11 #include "base/bind_helpers.h"
     12 #include "base/memory/scoped_vector.h"
     13 #include "base/rand_util.h"
     14 #include "base/sha1.h"
     15 #include "base/strings/string_number_conversions.h"
     16 #include "base/strings/stringprintf.h"
     17 #include "net/base/address_list.h"
     18 #include "net/base/io_buffer.h"
     19 #include "net/base/ip_endpoint.h"
     20 #include "net/base/net_errors.h"
     21 #include "net/base/net_util.h"
     22 #include "net/base/sys_addrinfo.h"
     23 #include "net/http/http_response_headers.h"
     24 #include "net/http/http_util.h"
     25 #include "net/websockets/websocket_frame.h"
     26 
     27 #if defined(OS_WIN)
     28 #include <Winsock2.h>
     29 #endif
     30 
     31 namespace {
     32 
     33 bool ResolveHost(const std::string& host, net::IPAddressNumber* address) {
     34   struct addrinfo hints;
     35   memset(&hints, 0, sizeof(hints));
     36   hints.ai_family = AF_UNSPEC;
     37   hints.ai_socktype = SOCK_STREAM;
     38 
     39   struct addrinfo* result;
     40   if (getaddrinfo(host.c_str(), NULL, &hints, &result))
     41     return false;
     42 
     43   for (struct addrinfo* addr = result; addr; addr = addr->ai_next) {
     44     if (addr->ai_family == AF_INET || addr->ai_family == AF_INET6) {
     45       net::IPEndPoint end_point;
     46       if (!end_point.FromSockAddr(addr->ai_addr, addr->ai_addrlen)) {
     47         freeaddrinfo(result);
     48         return false;
     49       }
     50       *address = end_point.address();
     51     }
     52   }
     53   freeaddrinfo(result);
     54   return true;
     55 }
     56 
     57 }  // namespace
     58 
     59 WebSocket::WebSocket(const GURL& url, WebSocketListener* listener)
     60     : url_(url),
     61       listener_(listener),
     62       state_(INITIALIZED),
     63       write_buffer_(new net::DrainableIOBuffer(new net::IOBuffer(0), 0)),
     64       read_buffer_(new net::IOBufferWithSize(4096)) {}
     65 
     66 WebSocket::~WebSocket() {
     67   CHECK(thread_checker_.CalledOnValidThread());
     68 }
     69 
     70 void WebSocket::Connect(const net::CompletionCallback& callback) {
     71   CHECK(thread_checker_.CalledOnValidThread());
     72   CHECK_EQ(INITIALIZED, state_);
     73 
     74   net::IPAddressNumber address;
     75   if (!net::ParseIPLiteralToNumber(url_.HostNoBrackets(), &address)) {
     76     if (!ResolveHost(url_.HostNoBrackets(), &address)) {
     77       callback.Run(net::ERR_ADDRESS_UNREACHABLE);
     78       return;
     79     }
     80   }
     81   int port = 80;
     82   base::StringToInt(url_.port(), &port);
     83   net::AddressList addresses(net::IPEndPoint(address, port));
     84   net::NetLog::Source source;
     85   socket_.reset(new net::TCPClientSocket(addresses, NULL, source));
     86 
     87   state_ = CONNECTING;
     88   connect_callback_ = callback;
     89   int code = socket_->Connect(base::Bind(
     90       &WebSocket::OnSocketConnect, base::Unretained(this)));
     91   if (code != net::ERR_IO_PENDING)
     92     OnSocketConnect(code);
     93 }
     94 
     95 bool WebSocket::Send(const std::string& message) {
     96   CHECK(thread_checker_.CalledOnValidThread());
     97   if (state_ != OPEN)
     98     return false;
     99 
    100   net::WebSocketFrameHeader header(net::WebSocketFrameHeader::kOpCodeText);
    101   header.final = true;
    102   header.masked = true;
    103   header.payload_length = message.length();
    104   int header_size = net::GetWebSocketFrameHeaderSize(header);
    105   net::WebSocketMaskingKey masking_key = net::GenerateWebSocketMaskingKey();
    106   std::string header_str;
    107   header_str.resize(header_size);
    108   CHECK_EQ(header_size, net::WriteWebSocketFrameHeader(
    109       header, &masking_key, &header_str[0], header_str.length()));
    110 
    111   std::string masked_message = message;
    112   net::MaskWebSocketFramePayload(
    113       masking_key, 0, &masked_message[0], masked_message.length());
    114   Write(header_str + masked_message);
    115   return true;
    116 }
    117 
    118 void WebSocket::OnSocketConnect(int code) {
    119   if (code != net::OK) {
    120     Close(code);
    121     return;
    122   }
    123 
    124   base::Base64Encode(base::RandBytesAsString(16), &sec_key_);
    125   std::string handshake = base::StringPrintf(
    126       "GET %s HTTP/1.1\r\n"
    127       "Host: %s\r\n"
    128       "Upgrade: websocket\r\n"
    129       "Connection: Upgrade\r\n"
    130       "Sec-WebSocket-Key: %s\r\n"
    131       "Sec-WebSocket-Version: 13\r\n"
    132       "Pragma: no-cache\r\n"
    133       "Cache-Control: no-cache\r\n"
    134       "\r\n",
    135       url_.path().c_str(),
    136       url_.host().c_str(),
    137       sec_key_.c_str());
    138   Write(handshake);
    139   Read();
    140 }
    141 
    142 void WebSocket::Write(const std::string& data) {
    143   pending_write_ += data;
    144   if (!write_buffer_->BytesRemaining())
    145     ContinueWritingIfNecessary();
    146 }
    147 
    148 void WebSocket::OnWrite(int code) {
    149   if (!socket_->IsConnected()) {
    150     // Supposedly if |StreamSocket| is closed, the error code may be undefined.
    151     Close(net::ERR_FAILED);
    152     return;
    153   }
    154   if (code < 0) {
    155     Close(code);
    156     return;
    157   }
    158 
    159   write_buffer_->DidConsume(code);
    160   ContinueWritingIfNecessary();
    161 }
    162 
    163 void WebSocket::ContinueWritingIfNecessary() {
    164   if (!write_buffer_->BytesRemaining()) {
    165     if (pending_write_.empty())
    166       return;
    167     write_buffer_ = new net::DrainableIOBuffer(
    168         new net::StringIOBuffer(pending_write_),
    169         pending_write_.length());
    170     pending_write_.clear();
    171   }
    172   int code =
    173       socket_->Write(write_buffer_.get(),
    174                      write_buffer_->BytesRemaining(),
    175                      base::Bind(&WebSocket::OnWrite, base::Unretained(this)));
    176   if (code != net::ERR_IO_PENDING)
    177     OnWrite(code);
    178 }
    179 
    180 void WebSocket::Read() {
    181   int code =
    182       socket_->Read(read_buffer_.get(),
    183                     read_buffer_->size(),
    184                     base::Bind(&WebSocket::OnRead, base::Unretained(this)));
    185   if (code != net::ERR_IO_PENDING)
    186     OnRead(code);
    187 }
    188 
    189 void WebSocket::OnRead(int code) {
    190   if (code <= 0) {
    191     Close(code ? code : net::ERR_FAILED);
    192     return;
    193   }
    194 
    195   if (state_ == CONNECTING)
    196     OnReadDuringHandshake(read_buffer_->data(), code);
    197   else if (state_ == OPEN)
    198     OnReadDuringOpen(read_buffer_->data(), code);
    199 
    200   if (state_ != CLOSED)
    201     Read();
    202 }
    203 
    204 void WebSocket::OnReadDuringHandshake(const char* data, int len) {
    205   handshake_response_ += std::string(data, len);
    206   int headers_end = net::HttpUtil::LocateEndOfHeaders(
    207       handshake_response_.data(), handshake_response_.size(), 0);
    208   if (headers_end == -1)
    209     return;
    210 
    211   const char kMagicKey[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
    212   std::string websocket_accept;
    213   base::Base64Encode(base::SHA1HashString(sec_key_ + kMagicKey),
    214                      &websocket_accept);
    215   scoped_refptr<net::HttpResponseHeaders> headers(
    216       new net::HttpResponseHeaders(
    217           net::HttpUtil::AssembleRawHeaders(
    218               handshake_response_.data(), headers_end)));
    219   if (headers->response_code() != 101 ||
    220       !headers->HasHeaderValue("Upgrade", "WebSocket") ||
    221       !headers->HasHeaderValue("Connection", "Upgrade") ||
    222       !headers->HasHeaderValue("Sec-WebSocket-Accept", websocket_accept)) {
    223     Close(net::ERR_FAILED);
    224     return;
    225   }
    226   std::string leftover_message = handshake_response_.substr(headers_end);
    227   handshake_response_.clear();
    228   sec_key_.clear();
    229   state_ = OPEN;
    230   InvokeConnectCallback(net::OK);
    231   if (!leftover_message.empty())
    232     OnReadDuringOpen(leftover_message.c_str(), leftover_message.length());
    233 }
    234 
    235 void WebSocket::OnReadDuringOpen(const char* data, int len) {
    236   ScopedVector<net::WebSocketFrameChunk> frame_chunks;
    237   CHECK(parser_.Decode(data, len, &frame_chunks));
    238   for (size_t i = 0; i < frame_chunks.size(); ++i) {
    239     scoped_refptr<net::IOBufferWithSize> buffer = frame_chunks[i]->data;
    240     if (buffer.get())
    241       next_message_ += std::string(buffer->data(), buffer->size());
    242     if (frame_chunks[i]->final_chunk) {
    243       listener_->OnMessageReceived(next_message_);
    244       next_message_.clear();
    245     }
    246   }
    247 }
    248 
    249 void WebSocket::InvokeConnectCallback(int code) {
    250   net::CompletionCallback temp = connect_callback_;
    251   connect_callback_.Reset();
    252   CHECK(!temp.is_null());
    253   temp.Run(code);
    254 }
    255 
    256 void WebSocket::Close(int code) {
    257   socket_->Disconnect();
    258   if (!connect_callback_.is_null())
    259     InvokeConnectCallback(code);
    260   if (state_ == OPEN)
    261     listener_->OnClose();
    262 
    263   state_ = CLOSED;
    264 }
    265