Home | History | Annotate | Download | only in websockets
      1 // Copyright (c) 2011 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/websockets/websocket_handshake_draft75.h"
      6 
      7 #include "base/memory/ref_counted.h"
      8 #include "base/string_util.h"
      9 #include "net/http/http_response_headers.h"
     10 #include "net/http/http_util.h"
     11 
     12 namespace net {
     13 
     14 const char WebSocketHandshakeDraft75::kServerHandshakeHeader[] =
     15     "HTTP/1.1 101 Web Socket Protocol Handshake\r\n";
     16 const size_t WebSocketHandshakeDraft75::kServerHandshakeHeaderLength =
     17     sizeof(kServerHandshakeHeader) - 1;
     18 
     19 const char WebSocketHandshakeDraft75::kUpgradeHeader[] =
     20     "Upgrade: WebSocket\r\n";
     21 const size_t WebSocketHandshakeDraft75::kUpgradeHeaderLength =
     22     sizeof(kUpgradeHeader) - 1;
     23 
     24 const char WebSocketHandshakeDraft75::kConnectionHeader[] =
     25     "Connection: Upgrade\r\n";
     26 const size_t WebSocketHandshakeDraft75::kConnectionHeaderLength =
     27     sizeof(kConnectionHeader) - 1;
     28 
     29 WebSocketHandshakeDraft75::WebSocketHandshakeDraft75(
     30     const GURL& url,
     31     const std::string& origin,
     32     const std::string& location,
     33     const std::string& protocol)
     34     : WebSocketHandshake(url, origin, location, protocol) {
     35 }
     36 
     37 WebSocketHandshakeDraft75::~WebSocketHandshakeDraft75() {
     38 }
     39 
     40 std::string WebSocketHandshakeDraft75::CreateClientHandshakeMessage() {
     41   std::string msg;
     42   msg = "GET ";
     43   msg += GetResourceName();
     44   msg += " HTTP/1.1\r\n";
     45   msg += kUpgradeHeader;
     46   msg += kConnectionHeader;
     47   msg += "Host: ";
     48   msg += GetHostFieldValue();
     49   msg += "\r\n";
     50   msg += "Origin: ";
     51   msg += GetOriginFieldValue();
     52   msg += "\r\n";
     53   if (!protocol_.empty()) {
     54     msg += "WebSocket-Protocol: ";
     55     msg += protocol_;
     56     msg += "\r\n";
     57   }
     58   // TODO(ukai): Add cookie if necessary.
     59   msg += "\r\n";
     60   return msg;
     61 }
     62 
     63 int WebSocketHandshakeDraft75::ReadServerHandshake(
     64     const char* data, size_t len) {
     65   mode_ = MODE_INCOMPLETE;
     66   if (len < kServerHandshakeHeaderLength) {
     67     return -1;
     68   }
     69   if (!memcmp(data, kServerHandshakeHeader, kServerHandshakeHeaderLength)) {
     70     mode_ = MODE_NORMAL;
     71   } else {
     72     int eoh = HttpUtil::LocateEndOfHeaders(data, len);
     73     if (eoh < 0)
     74       return -1;
     75     return eoh;
     76   }
     77   const char* p = data + kServerHandshakeHeaderLength;
     78   const char* end = data + len;
     79 
     80   if (mode_ == MODE_NORMAL) {
     81     size_t header_size = end - p;
     82     if (header_size < kUpgradeHeaderLength)
     83       return -1;
     84     if (memcmp(p, kUpgradeHeader, kUpgradeHeaderLength)) {
     85       mode_ = MODE_FAILED;
     86       DVLOG(1) << "Bad Upgrade Header " << std::string(p, kUpgradeHeaderLength);
     87       return p - data;
     88     }
     89     p += kUpgradeHeaderLength;
     90     header_size = end - p;
     91     if (header_size < kConnectionHeaderLength)
     92       return -1;
     93     if (memcmp(p, kConnectionHeader, kConnectionHeaderLength)) {
     94       mode_ = MODE_FAILED;
     95       DVLOG(1) << "Bad Connection Header "
     96                << std::string(p, kConnectionHeaderLength);
     97       return p - data;
     98     }
     99     p += kConnectionHeaderLength;
    100   }
    101 
    102   int eoh = HttpUtil::LocateEndOfHeaders(data, len);
    103   if (eoh == -1)
    104     return eoh;
    105 
    106   scoped_refptr<HttpResponseHeaders> headers(
    107       new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(data, eoh)));
    108   if (!ProcessHeaders(*headers)) {
    109     DVLOG(1) << "Process Headers failed: " << std::string(data, eoh);
    110     mode_ = MODE_FAILED;
    111   }
    112   switch (mode_) {
    113     case MODE_NORMAL:
    114       if (CheckResponseHeaders()) {
    115         mode_ = MODE_CONNECTED;
    116       } else {
    117         mode_ = MODE_FAILED;
    118       }
    119       break;
    120     default:
    121       mode_ = MODE_FAILED;
    122       break;
    123   }
    124   return eoh;
    125 }
    126 
    127 bool WebSocketHandshakeDraft75::ProcessHeaders(
    128     const HttpResponseHeaders& headers) {
    129   if (!GetSingleHeader(headers, "websocket-origin", &ws_origin_))
    130     return false;
    131 
    132   if (!GetSingleHeader(headers, "websocket-location", &ws_location_))
    133     return false;
    134 
    135   // If |protocol_| is not specified by client, we don't care if there's
    136   // protocol field or not as specified in the spec.
    137   if (!protocol_.empty()
    138       && !GetSingleHeader(headers, "websocket-protocol", &ws_protocol_))
    139     return false;
    140   return true;
    141 }
    142 
    143 bool WebSocketHandshakeDraft75::CheckResponseHeaders() const {
    144   DCHECK(mode_ == MODE_NORMAL);
    145   if (!LowerCaseEqualsASCII(origin_, ws_origin_.c_str()))
    146     return false;
    147   if (location_ != ws_location_)
    148     return false;
    149   if (!protocol_.empty() && protocol_ != ws_protocol_)
    150     return false;
    151   return true;
    152 }
    153 
    154 }  // namespace net
    155