Home | History | Annotate | Download | only in net
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "chrome/test/chromedriver/net/websocket.h"
      6 
      7 #include "base/base64.h"
      8 #include "base/bind.h"
      9 #include "base/bind_helpers.h"
     10 #include "base/memory/scoped_vector.h"
     11 #include "base/rand_util.h"
     12 #include "base/sha1.h"
     13 #include "base/strings/string_number_conversions.h"
     14 #include "base/strings/stringprintf.h"
     15 #include "net/base/address_list.h"
     16 #include "net/base/io_buffer.h"
     17 #include "net/base/ip_endpoint.h"
     18 #include "net/base/net_errors.h"
     19 #include "net/base/net_util.h"
     20 #include "net/http/http_response_headers.h"
     21 #include "net/http/http_util.h"
     22 #include "net/websockets/websocket_frame.h"
     23 
     24 WebSocket::WebSocket(const GURL& url, WebSocketListener* listener)
     25     : url_(url),
     26       listener_(listener),
     27       state_(INITIALIZED),
     28       write_buffer_(new net::DrainableIOBuffer(new net::IOBuffer(0), 0)),
     29       read_buffer_(new net::IOBufferWithSize(4096)) {
     30   net::IPAddressNumber address;
     31   CHECK(net::ParseIPLiteralToNumber(url_.HostNoBrackets(), &address));
     32   int port = 80;
     33   base::StringToInt(url_.port(), &port);
     34   net::AddressList addresses(net::IPEndPoint(address, port));
     35   net::NetLog::Source source;
     36   socket_.reset(new net::TCPClientSocket(addresses, NULL, source));
     37 }
     38 
     39 WebSocket::~WebSocket() {
     40   CHECK(thread_checker_.CalledOnValidThread());
     41 }
     42 
     43 void WebSocket::Connect(const net::CompletionCallback& callback) {
     44   CHECK(thread_checker_.CalledOnValidThread());
     45   CHECK_EQ(INITIALIZED, state_);
     46   state_ = CONNECTING;
     47   connect_callback_ = callback;
     48   int code = socket_->Connect(base::Bind(
     49       &WebSocket::OnSocketConnect, base::Unretained(this)));
     50   if (code != net::ERR_IO_PENDING)
     51     OnSocketConnect(code);
     52 }
     53 
     54 bool WebSocket::Send(const std::string& message) {
     55   CHECK(thread_checker_.CalledOnValidThread());
     56   if (state_ != OPEN)
     57     return false;
     58 
     59   net::WebSocketFrameHeader header(net::WebSocketFrameHeader::kOpCodeText);
     60   header.final = true;
     61   header.masked = true;
     62   header.payload_length = message.length();
     63   int header_size = net::GetWebSocketFrameHeaderSize(header);
     64   net::WebSocketMaskingKey masking_key = net::GenerateWebSocketMaskingKey();
     65   std::string header_str;
     66   header_str.resize(header_size);
     67   CHECK_EQ(header_size, net::WriteWebSocketFrameHeader(
     68       header, &masking_key, &header_str[0], header_str.length()));
     69 
     70   std::string masked_message = message;
     71   net::MaskWebSocketFramePayload(
     72       masking_key, 0, &masked_message[0], masked_message.length());
     73   Write(header_str + masked_message);
     74   return true;
     75 }
     76 
     77 void WebSocket::OnSocketConnect(int code) {
     78   if (code != net::OK) {
     79     Close(code);
     80     return;
     81   }
     82 
     83   CHECK(base::Base64Encode(base::RandBytesAsString(16), &sec_key_));
     84   std::string handshake = base::StringPrintf(
     85       "GET %s HTTP/1.1\r\n"
     86       "Host: %s\r\n"
     87       "Upgrade: websocket\r\n"
     88       "Connection: Upgrade\r\n"
     89       "Sec-WebSocket-Key: %s\r\n"
     90       "Sec-WebSocket-Version: 13\r\n"
     91       "Pragma: no-cache\r\n"
     92       "Cache-Control: no-cache\r\n"
     93       "\r\n",
     94       url_.path().c_str(),
     95       url_.host().c_str(),
     96       sec_key_.c_str());
     97   Write(handshake);
     98   Read();
     99 }
    100 
    101 void WebSocket::Write(const std::string& data) {
    102   pending_write_ += data;
    103   if (!write_buffer_->BytesRemaining())
    104     ContinueWritingIfNecessary();
    105 }
    106 
    107 void WebSocket::OnWrite(int code) {
    108   if (!socket_->IsConnected()) {
    109     // Supposedly if |StreamSocket| is closed, the error code may be undefined.
    110     Close(net::ERR_FAILED);
    111     return;
    112   }
    113   if (code < 0) {
    114     Close(code);
    115     return;
    116   }
    117 
    118   write_buffer_->DidConsume(code);
    119   ContinueWritingIfNecessary();
    120 }
    121 
    122 void WebSocket::ContinueWritingIfNecessary() {
    123   if (!write_buffer_->BytesRemaining()) {
    124     if (pending_write_.empty())
    125       return;
    126     write_buffer_ = new net::DrainableIOBuffer(
    127         new net::StringIOBuffer(pending_write_),
    128         pending_write_.length());
    129     pending_write_.clear();
    130   }
    131   int code =
    132       socket_->Write(write_buffer_.get(),
    133                      write_buffer_->BytesRemaining(),
    134                      base::Bind(&WebSocket::OnWrite, base::Unretained(this)));
    135   if (code != net::ERR_IO_PENDING)
    136     OnWrite(code);
    137 }
    138 
    139 void WebSocket::Read() {
    140   int code =
    141       socket_->Read(read_buffer_.get(),
    142                     read_buffer_->size(),
    143                     base::Bind(&WebSocket::OnRead, base::Unretained(this)));
    144   if (code != net::ERR_IO_PENDING)
    145     OnRead(code);
    146 }
    147 
    148 void WebSocket::OnRead(int code) {
    149   if (code <= 0) {
    150     Close(code ? code : net::ERR_FAILED);
    151     return;
    152   }
    153 
    154   if (state_ == CONNECTING)
    155     OnReadDuringHandshake(read_buffer_->data(), code);
    156   else if (state_ == OPEN)
    157     OnReadDuringOpen(read_buffer_->data(), code);
    158 
    159   if (state_ != CLOSED)
    160     Read();
    161 }
    162 
    163 void WebSocket::OnReadDuringHandshake(const char* data, int len) {
    164   handshake_response_ += std::string(data, len);
    165   int headers_end = net::HttpUtil::LocateEndOfHeaders(
    166       handshake_response_.data(), handshake_response_.size(), 0);
    167   if (headers_end == -1)
    168     return;
    169 
    170   const char kMagicKey[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
    171   std::string websocket_accept;
    172   CHECK(base::Base64Encode(base::SHA1HashString(sec_key_ + kMagicKey),
    173                            &websocket_accept));
    174   scoped_refptr<net::HttpResponseHeaders> headers(
    175       new net::HttpResponseHeaders(
    176           net::HttpUtil::AssembleRawHeaders(
    177               handshake_response_.data(), headers_end)));
    178   if (headers->response_code() != 101 ||
    179       !headers->HasHeaderValue("Upgrade", "WebSocket") ||
    180       !headers->HasHeaderValue("Connection", "Upgrade") ||
    181       !headers->HasHeaderValue("Sec-WebSocket-Accept", websocket_accept)) {
    182     Close(net::ERR_FAILED);
    183     return;
    184   }
    185   std::string leftover_message = handshake_response_.substr(headers_end);
    186   handshake_response_.clear();
    187   sec_key_.clear();
    188   state_ = OPEN;
    189   InvokeConnectCallback(net::OK);
    190   if (!leftover_message.empty())
    191     OnReadDuringOpen(leftover_message.c_str(), leftover_message.length());
    192 }
    193 
    194 void WebSocket::OnReadDuringOpen(const char* data, int len) {
    195   ScopedVector<net::WebSocketFrameChunk> frame_chunks;
    196   CHECK(parser_.Decode(data, len, &frame_chunks));
    197   for (size_t i = 0; i < frame_chunks.size(); ++i) {
    198     scoped_refptr<net::IOBufferWithSize> buffer = frame_chunks[i]->data;
    199     if (buffer.get())
    200       next_message_ += std::string(buffer->data(), buffer->size());
    201     if (frame_chunks[i]->final_chunk) {
    202       listener_->OnMessageReceived(next_message_);
    203       next_message_.clear();
    204     }
    205   }
    206 }
    207 
    208 void WebSocket::InvokeConnectCallback(int code) {
    209   net::CompletionCallback temp = connect_callback_;
    210   connect_callback_.Reset();
    211   CHECK(!temp.is_null());
    212   temp.Run(code);
    213 }
    214 
    215 void WebSocket::Close(int code) {
    216   socket_->Disconnect();
    217   if (!connect_callback_.is_null())
    218     InvokeConnectCallback(code);
    219   if (state_ == OPEN)
    220     listener_->OnClose();
    221 
    222   state_ = CLOSED;
    223 }
    224 
    225 
    226