Home | History | Annotate | Download | only in websockets
      1 // Copyright (c) 2009 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include <algorithm>
      6 #include <limits>
      7 
      8 #include "net/websockets/websocket.h"
      9 
     10 #include "base/message_loop.h"
     11 #include "net/http/http_response_headers.h"
     12 #include "net/http/http_util.h"
     13 
     14 namespace net {
     15 
     16 static const int kWebSocketPort = 80;
     17 static const int kSecureWebSocketPort = 443;
     18 
     19 static const char kServerHandshakeHeader[] =
     20     "HTTP/1.1 101 Web Socket Protocol Handshake\r\n";
     21 static const size_t kServerHandshakeHeaderLength =
     22     sizeof(kServerHandshakeHeader) - 1;
     23 
     24 static const char kUpgradeHeader[] = "Upgrade: WebSocket\r\n";
     25 static const size_t kUpgradeHeaderLength = sizeof(kUpgradeHeader) - 1;
     26 
     27 static const char kConnectionHeader[] = "Connection: Upgrade\r\n";
     28 static const size_t kConnectionHeaderLength = sizeof(kConnectionHeader) - 1;
     29 
     30 bool WebSocket::Request::is_secure() const {
     31   return url_.SchemeIs("wss");
     32 }
     33 
     34 WebSocket::WebSocket(Request* request, WebSocketDelegate* delegate)
     35     : ready_state_(INITIALIZED),
     36       mode_(MODE_INCOMPLETE),
     37       request_(request),
     38       delegate_(delegate),
     39       origin_loop_(MessageLoop::current()),
     40       socket_stream_(NULL),
     41       max_pending_send_allowed_(0),
     42       current_read_buf_(NULL),
     43       read_consumed_len_(0),
     44       current_write_buf_(NULL) {
     45   DCHECK(request_.get());
     46   DCHECK(delegate_);
     47   DCHECK(origin_loop_);
     48 }
     49 
     50 WebSocket::~WebSocket() {
     51   DCHECK(ready_state_ == INITIALIZED || !delegate_);
     52   DCHECK(!socket_stream_);
     53   DCHECK(!delegate_);
     54 }
     55 
     56 void WebSocket::Connect() {
     57   DCHECK(ready_state_ == INITIALIZED);
     58   DCHECK(request_.get());
     59   DCHECK(delegate_);
     60   DCHECK(!socket_stream_);
     61   DCHECK(MessageLoop::current() == origin_loop_);
     62 
     63   socket_stream_ = new SocketStream(request_->url(), this);
     64   socket_stream_->set_context(request_->context());
     65 
     66   if (request_->host_resolver())
     67     socket_stream_->SetHostResolver(request_->host_resolver());
     68   if (request_->client_socket_factory())
     69     socket_stream_->SetClientSocketFactory(request_->client_socket_factory());
     70 
     71   AddRef();  // Release in DoClose().
     72   ready_state_ = CONNECTING;
     73   socket_stream_->Connect();
     74 }
     75 
     76 void WebSocket::Send(const std::string& msg) {
     77   DCHECK(ready_state_ == OPEN);
     78   DCHECK(MessageLoop::current() == origin_loop_);
     79 
     80   IOBufferWithSize* buf = new IOBufferWithSize(msg.size() + 2);
     81   char* p = buf->data();
     82   *p = '\0';
     83   memcpy(p + 1, msg.data(), msg.size());
     84   *(p + 1 + msg.size()) = '\xff';
     85   pending_write_bufs_.push_back(buf);
     86   SendPending();
     87 }
     88 
     89 void WebSocket::Close() {
     90   DCHECK(MessageLoop::current() == origin_loop_);
     91 
     92   if (ready_state_ == INITIALIZED) {
     93     DCHECK(!socket_stream_);
     94     ready_state_ = CLOSED;
     95     return;
     96   }
     97   if (ready_state_ != CLOSED) {
     98     DCHECK(socket_stream_);
     99     socket_stream_->Close();
    100     return;
    101   }
    102 }
    103 
    104 void WebSocket::DetachDelegate() {
    105   if (!delegate_)
    106     return;
    107   delegate_ = NULL;
    108   Close();
    109 }
    110 
    111 void WebSocket::OnConnected(SocketStream* socket_stream,
    112                             int max_pending_send_allowed) {
    113   DCHECK(socket_stream == socket_stream_);
    114   max_pending_send_allowed_ = max_pending_send_allowed;
    115 
    116   // Use |max_pending_send_allowed| as hint for initial size of read buffer.
    117   current_read_buf_ = new GrowableIOBuffer();
    118   current_read_buf_->SetCapacity(max_pending_send_allowed_);
    119   read_consumed_len_ = 0;
    120 
    121   DCHECK(!current_write_buf_);
    122   const std::string msg = request_->CreateClientHandshakeMessage();
    123   IOBufferWithSize* buf = new IOBufferWithSize(msg.size());
    124   memcpy(buf->data(), msg.data(), msg.size());
    125   pending_write_bufs_.push_back(buf);
    126   origin_loop_->PostTask(FROM_HERE,
    127                          NewRunnableMethod(this, &WebSocket::SendPending));
    128 }
    129 
    130 void WebSocket::OnSentData(SocketStream* socket_stream, int amount_sent) {
    131   DCHECK(socket_stream == socket_stream_);
    132   DCHECK(current_write_buf_);
    133   current_write_buf_->DidConsume(amount_sent);
    134   DCHECK_GE(current_write_buf_->BytesRemaining(), 0);
    135   if (current_write_buf_->BytesRemaining() == 0) {
    136     current_write_buf_ = NULL;
    137     pending_write_bufs_.pop_front();
    138   }
    139   origin_loop_->PostTask(FROM_HERE,
    140                          NewRunnableMethod(this, &WebSocket::SendPending));
    141 }
    142 
    143 void WebSocket::OnReceivedData(SocketStream* socket_stream,
    144                                const char* data, int len) {
    145   DCHECK(socket_stream == socket_stream_);
    146   AddToReadBuffer(data, len);
    147   origin_loop_->PostTask(FROM_HERE,
    148                          NewRunnableMethod(this, &WebSocket::DoReceivedData));
    149 }
    150 
    151 void WebSocket::OnClose(SocketStream* socket_stream) {
    152   origin_loop_->PostTask(FROM_HERE,
    153                          NewRunnableMethod(this, &WebSocket::DoClose));
    154 }
    155 
    156 void WebSocket::OnError(const SocketStream* socket_stream, int error) {
    157   origin_loop_->PostTask(FROM_HERE,
    158                          NewRunnableMethod(this, &WebSocket::DoError, error));
    159 }
    160 
    161 std::string WebSocket::Request::CreateClientHandshakeMessage() const {
    162   std::string msg;
    163   msg = "GET ";
    164   msg += url_.path();
    165   if (url_.has_query()) {
    166     msg += "?";
    167     msg += url_.query();
    168   }
    169   msg += " HTTP/1.1\r\n";
    170   msg += kUpgradeHeader;
    171   msg += kConnectionHeader;
    172   msg += "Host: ";
    173   msg += StringToLowerASCII(url_.host());
    174   if (url_.has_port()) {
    175     bool secure = is_secure();
    176     int port = url_.EffectiveIntPort();
    177     if ((!secure &&
    178          port != kWebSocketPort && port != url_parse::PORT_UNSPECIFIED) ||
    179         (secure &&
    180          port != kSecureWebSocketPort && port != url_parse::PORT_UNSPECIFIED)) {
    181       msg += ":";
    182       msg += IntToString(port);
    183     }
    184   }
    185   msg += "\r\n";
    186   msg += "Origin: ";
    187   // It's OK to lowercase the origin as the Origin header does not contain
    188   // the path or query portions, as per
    189   // http://tools.ietf.org/html/draft-abarth-origin-00.
    190   //
    191   // TODO(satorux): Should we trim the port portion here if it's 80 for
    192   // http:// or 443 for https:// ? Or can we assume it's done by the
    193   // client of the library?
    194   msg += StringToLowerASCII(origin_);
    195   msg += "\r\n";
    196   if (!protocol_.empty()) {
    197     msg += "WebSocket-Protocol: ";
    198     msg += protocol_;
    199     msg += "\r\n";
    200   }
    201   // TODO(ukai): Add cookie if necessary.
    202   msg += "\r\n";
    203   return msg;
    204 }
    205 
    206 int WebSocket::CheckHandshake() {
    207   DCHECK(current_read_buf_);
    208   DCHECK(ready_state_ == CONNECTING);
    209   mode_ = MODE_INCOMPLETE;
    210   const char *start = current_read_buf_->StartOfBuffer() + read_consumed_len_;
    211   const char *p = start;
    212   size_t len = current_read_buf_->offset() - read_consumed_len_;
    213   if (len < kServerHandshakeHeaderLength) {
    214     return -1;
    215   }
    216   if (!memcmp(p, kServerHandshakeHeader, kServerHandshakeHeaderLength)) {
    217     mode_ = MODE_NORMAL;
    218   } else {
    219     int eoh = HttpUtil::LocateEndOfHeaders(p, len);
    220     if (eoh < 0)
    221       return -1;
    222     scoped_refptr<HttpResponseHeaders> headers(
    223         new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(p, eoh)));
    224     if (headers->response_code() == 407) {
    225       mode_ = MODE_AUTHENTICATE;
    226       // TODO(ukai): Implement authentication handlers.
    227     }
    228     DLOG(INFO) << "non-normal websocket connection. "
    229                << "response_code=" << headers->response_code()
    230                << " mode=" << mode_;
    231     // Invalid response code.
    232     ready_state_ = CLOSED;
    233     return eoh;
    234   }
    235   const char* end = p + len + 1;
    236   p += kServerHandshakeHeaderLength;
    237 
    238   if (mode_ == MODE_NORMAL) {
    239     size_t header_size = end - p;
    240     if (header_size < kUpgradeHeaderLength)
    241       return -1;
    242     if (memcmp(p, kUpgradeHeader, kUpgradeHeaderLength)) {
    243       DLOG(INFO) << "Bad Upgrade Header "
    244                  << std::string(p, kUpgradeHeaderLength);
    245       ready_state_ = CLOSED;
    246       return p - start;
    247     }
    248     p += kUpgradeHeaderLength;
    249 
    250     header_size = end - p;
    251     if (header_size < kConnectionHeaderLength)
    252       return -1;
    253     if (memcmp(p, kConnectionHeader, kConnectionHeaderLength)) {
    254       DLOG(INFO) << "Bad Connection Header "
    255                  << std::string(p, kConnectionHeaderLength);
    256       ready_state_ = CLOSED;
    257       return p - start;
    258     }
    259     p += kConnectionHeaderLength;
    260   }
    261   int eoh = HttpUtil::LocateEndOfHeaders(start, len);
    262   if (eoh == -1)
    263     return eoh;
    264   scoped_refptr<HttpResponseHeaders> headers(
    265       new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(start, eoh)));
    266   if (!ProcessHeaders(*headers)) {
    267     DLOG(INFO) << "Process Headers failed: "
    268                << std::string(start, eoh);
    269     ready_state_ = CLOSED;
    270     return eoh;
    271   }
    272   switch (mode_) {
    273     case MODE_NORMAL:
    274       if (CheckResponseHeaders()) {
    275         ready_state_ = OPEN;
    276       } else {
    277         ready_state_ = CLOSED;
    278       }
    279       break;
    280     default:
    281       ready_state_ = CLOSED;
    282       break;
    283   }
    284   if (ready_state_ == CLOSED)
    285     DLOG(INFO) << "CheckHandshake mode=" << mode_
    286                << " " << std::string(start, eoh);
    287   return eoh;
    288 }
    289 
    290 // Gets the value of the specified header.
    291 // It assures only one header of |name| in |headers|.
    292 // Returns true iff single header of |name| is found in |headers|
    293 // and |value| is filled with the value.
    294 // Returns false otherwise.
    295 static bool GetSingleHeader(const HttpResponseHeaders& headers,
    296                             const std::string& name,
    297                             std::string* value) {
    298   std::string first_value;
    299   void* iter = NULL;
    300   if (!headers.EnumerateHeader(&iter, name, &first_value))
    301     return false;
    302 
    303   // Checks no more |name| found in |headers|.
    304   // Second call of EnumerateHeader() must return false.
    305   std::string second_value;
    306   if (headers.EnumerateHeader(&iter, name, &second_value))
    307     return false;
    308   *value = first_value;
    309   return true;
    310 }
    311 
    312 bool WebSocket::ProcessHeaders(const HttpResponseHeaders& headers) {
    313   if (!GetSingleHeader(headers, "websocket-origin", &ws_origin_))
    314     return false;
    315 
    316   if (!GetSingleHeader(headers, "websocket-location", &ws_location_))
    317     return false;
    318 
    319   if (!request_->protocol().empty()
    320       && !GetSingleHeader(headers, "websocket-protocol", &ws_protocol_))
    321     return false;
    322   return true;
    323 }
    324 
    325 bool WebSocket::CheckResponseHeaders() const {
    326   DCHECK(mode_ == MODE_NORMAL);
    327   if (!LowerCaseEqualsASCII(request_->origin(), ws_origin_.c_str()))
    328     return false;
    329   if (request_->location() != ws_location_)
    330     return false;
    331   if (request_->protocol() != ws_protocol_)
    332     return false;
    333   return true;
    334 }
    335 
    336 void WebSocket::SendPending() {
    337   DCHECK(MessageLoop::current() == origin_loop_);
    338   DCHECK(socket_stream_);
    339   if (!current_write_buf_) {
    340     if (pending_write_bufs_.empty())
    341       return;
    342     current_write_buf_ = new DrainableIOBuffer(
    343         pending_write_bufs_.front(), pending_write_bufs_.front()->size());
    344   }
    345   DCHECK_GT(current_write_buf_->BytesRemaining(), 0);
    346   bool sent = socket_stream_->SendData(
    347       current_write_buf_->data(),
    348       std::min(current_write_buf_->BytesRemaining(),
    349                max_pending_send_allowed_));
    350   DCHECK(sent);
    351 }
    352 
    353 void WebSocket::DoReceivedData() {
    354   DCHECK(MessageLoop::current() == origin_loop_);
    355   switch (ready_state_) {
    356     case CONNECTING:
    357       {
    358         int eoh = CheckHandshake();
    359         if (eoh < 0) {
    360           // Not enough data,  Retry when more data is available.
    361           return;
    362         }
    363         SkipReadBuffer(eoh);
    364       }
    365       if (ready_state_ != OPEN) {
    366         // Handshake failed.
    367         socket_stream_->Close();
    368         return;
    369       }
    370       if (delegate_)
    371         delegate_->OnOpen(this);
    372       if (current_read_buf_->offset() == read_consumed_len_) {
    373         // No remaining data after handshake message.
    374         break;
    375       }
    376       // FALL THROUGH
    377     case OPEN:
    378       ProcessFrameData();
    379       break;
    380 
    381     case CLOSED:
    382       // Closed just after DoReceivedData is queued on |origin_loop_|.
    383       break;
    384     default:
    385       NOTREACHED();
    386       break;
    387   }
    388 }
    389 
    390 void WebSocket::ProcessFrameData() {
    391   DCHECK(current_read_buf_);
    392   const char* start_frame =
    393       current_read_buf_->StartOfBuffer() + read_consumed_len_;
    394   const char* next_frame = start_frame;
    395   const char* p = next_frame;
    396   const char* end =
    397       current_read_buf_->StartOfBuffer() + current_read_buf_->offset();
    398   while (p < end) {
    399     unsigned char frame_byte = static_cast<unsigned char>(*p++);
    400     if ((frame_byte & 0x80) == 0x80) {
    401       int length = 0;
    402       while (p < end) {
    403         if (length > std::numeric_limits<int>::max() / 128) {
    404           // frame length overflow.
    405           socket_stream_->Close();
    406           return;
    407         }
    408         unsigned char c = static_cast<unsigned char>(*p);
    409         length = length * 128 + (c & 0x7f);
    410         ++p;
    411         if ((c & 0x80) != 0x80)
    412           break;
    413       }
    414       // Checks if the frame body hasn't been completely received yet.
    415       // It also checks the case the frame length bytes haven't been completely
    416       // received yet, because p == end and length > 0 in such case.
    417       if (p + length < end) {
    418         p += length;
    419         next_frame = p;
    420       } else {
    421         break;
    422       }
    423     } else {
    424       const char* msg_start = p;
    425       while (p < end && *p != '\xff')
    426         ++p;
    427       if (p < end && *p == '\xff') {
    428         if (frame_byte == 0x00 && delegate_)
    429           delegate_->OnMessage(this, std::string(msg_start, p - msg_start));
    430         ++p;
    431         next_frame = p;
    432       }
    433     }
    434   }
    435   SkipReadBuffer(next_frame - start_frame);
    436 }
    437 
    438 void WebSocket::AddToReadBuffer(const char* data, int len) {
    439   DCHECK(current_read_buf_);
    440   // Check if |current_read_buf_| has enough space to store |len| of |data|.
    441   if (len >= current_read_buf_->RemainingCapacity()) {
    442     current_read_buf_->SetCapacity(
    443         current_read_buf_->offset() + len);
    444   }
    445 
    446   DCHECK(current_read_buf_->RemainingCapacity() >= len);
    447   memcpy(current_read_buf_->data(), data, len);
    448   current_read_buf_->set_offset(current_read_buf_->offset() + len);
    449 }
    450 
    451 void WebSocket::SkipReadBuffer(int len) {
    452   if (len == 0)
    453     return;
    454   DCHECK_GT(len, 0);
    455   read_consumed_len_ += len;
    456   int remaining = current_read_buf_->offset() - read_consumed_len_;
    457   DCHECK_GE(remaining, 0);
    458   if (remaining < read_consumed_len_ &&
    459       current_read_buf_->RemainingCapacity() < read_consumed_len_) {
    460     // Pre compaction:
    461     // 0             v-read_consumed_len_  v-offset               v- capacity
    462     // |..processed..| .. remaining ..     | .. RemainingCapacity |
    463     //
    464     memmove(current_read_buf_->StartOfBuffer(),
    465             current_read_buf_->StartOfBuffer() + read_consumed_len_,
    466             remaining);
    467     read_consumed_len_ = 0;
    468     current_read_buf_->set_offset(remaining);
    469     // Post compaction:
    470     // 0read_consumed_len_  v- offset                             v- capacity
    471     // |.. remaining ..     | ..  RemainingCapacity  ...          |
    472     //
    473   }
    474 }
    475 
    476 void WebSocket::DoClose() {
    477   DCHECK(MessageLoop::current() == origin_loop_);
    478   WebSocketDelegate* delegate = delegate_;
    479   delegate_ = NULL;
    480   ready_state_ = CLOSED;
    481   if (!socket_stream_)
    482     return;
    483   socket_stream_ = NULL;
    484   if (delegate)
    485     delegate->OnClose(this);
    486   Release();
    487 }
    488 
    489 void WebSocket::DoError(int error) {
    490   DCHECK(MessageLoop::current() == origin_loop_);
    491   if (delegate_)
    492     delegate_->OnError(this, error);
    493 }
    494 
    495 }  // namespace net
    496