Home | History | Annotate | Download | only in websockets
      1 // Copyright (c) 2011 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/websockets/websocket_handshake.h"
      6 
      7 #include <algorithm>
      8 #include <vector>
      9 
     10 #include "base/logging.h"
     11 #include "base/md5.h"
     12 #include "base/memory/ref_counted.h"
     13 #include "base/rand_util.h"
     14 #include "base/string_number_conversions.h"
     15 #include "base/string_util.h"
     16 #include "base/stringprintf.h"
     17 #include "net/http/http_response_headers.h"
     18 #include "net/http/http_util.h"
     19 
     20 namespace net {
     21 
     22 const int WebSocketHandshake::kWebSocketPort = 80;
     23 const int WebSocketHandshake::kSecureWebSocketPort = 443;
     24 
     25 WebSocketHandshake::WebSocketHandshake(
     26     const GURL& url,
     27     const std::string& origin,
     28     const std::string& location,
     29     const std::string& protocol)
     30     : url_(url),
     31       origin_(origin),
     32       location_(location),
     33       protocol_(protocol),
     34       mode_(MODE_INCOMPLETE) {
     35 }
     36 
     37 WebSocketHandshake::~WebSocketHandshake() {
     38 }
     39 
     40 bool WebSocketHandshake::is_secure() const {
     41   return url_.SchemeIs("wss");
     42 }
     43 
     44 std::string WebSocketHandshake::CreateClientHandshakeMessage() {
     45   if (!parameter_.get()) {
     46     parameter_.reset(new Parameter);
     47     parameter_->GenerateKeys();
     48   }
     49   std::string msg;
     50 
     51   // WebSocket protocol 4.1 Opening handshake.
     52 
     53   msg = "GET ";
     54   msg += GetResourceName();
     55   msg += " HTTP/1.1\r\n";
     56 
     57   std::vector<std::string> fields;
     58 
     59   fields.push_back("Upgrade: WebSocket");
     60   fields.push_back("Connection: Upgrade");
     61 
     62   fields.push_back("Host: " + GetHostFieldValue());
     63 
     64   fields.push_back("Origin: " + GetOriginFieldValue());
     65 
     66   if (!protocol_.empty())
     67     fields.push_back("Sec-WebSocket-Protocol: " + protocol_);
     68 
     69   // TODO(ukai): Add cookie if necessary.
     70 
     71   fields.push_back("Sec-WebSocket-Key1: " + parameter_->GetSecWebSocketKey1());
     72   fields.push_back("Sec-WebSocket-Key2: " + parameter_->GetSecWebSocketKey2());
     73 
     74   std::random_shuffle(fields.begin(), fields.end(), base::RandGenerator);
     75 
     76   for (size_t i = 0; i < fields.size(); i++) {
     77     msg += fields[i] + "\r\n";
     78   }
     79   msg += "\r\n";
     80 
     81   msg.append(parameter_->GetKey3());
     82   return msg;
     83 }
     84 
     85 int WebSocketHandshake::ReadServerHandshake(const char* data, size_t len) {
     86   mode_ = MODE_INCOMPLETE;
     87   int eoh = HttpUtil::LocateEndOfHeaders(data, len);
     88   if (eoh < 0)
     89     return -1;
     90 
     91   scoped_refptr<HttpResponseHeaders> headers(
     92       new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(data, eoh)));
     93 
     94   if (headers->response_code() != 101) {
     95     mode_ = MODE_FAILED;
     96     DVLOG(1) << "Bad response code: " << headers->response_code();
     97     return eoh;
     98   }
     99   mode_ = MODE_NORMAL;
    100   if (!ProcessHeaders(*headers) || !CheckResponseHeaders()) {
    101     DVLOG(1) << "Process Headers failed: " << std::string(data, eoh);
    102     mode_ = MODE_FAILED;
    103     return eoh;
    104   }
    105   if (len < static_cast<size_t>(eoh + Parameter::kExpectedResponseSize)) {
    106     mode_ = MODE_INCOMPLETE;
    107     return -1;
    108   }
    109   uint8 expected[Parameter::kExpectedResponseSize];
    110   parameter_->GetExpectedResponse(expected);
    111   if (memcmp(&data[eoh], expected, Parameter::kExpectedResponseSize)) {
    112     mode_ = MODE_FAILED;
    113     return eoh + Parameter::kExpectedResponseSize;
    114   }
    115   mode_ = MODE_CONNECTED;
    116   return eoh + Parameter::kExpectedResponseSize;
    117 }
    118 
    119 std::string WebSocketHandshake::GetResourceName() const {
    120   std::string resource_name = url_.path();
    121   if (url_.has_query()) {
    122     resource_name += "?";
    123     resource_name += url_.query();
    124   }
    125   return resource_name;
    126 }
    127 
    128 std::string WebSocketHandshake::GetHostFieldValue() const {
    129   // url_.host() is expected to be encoded in punnycode here.
    130   std::string host = StringToLowerASCII(url_.host());
    131   if (url_.has_port()) {
    132     bool secure = is_secure();
    133     int port = url_.EffectiveIntPort();
    134     if ((!secure &&
    135          port != kWebSocketPort && port != url_parse::PORT_UNSPECIFIED) ||
    136         (secure &&
    137          port != kSecureWebSocketPort && port != url_parse::PORT_UNSPECIFIED)) {
    138       host += ":";
    139       host += base::IntToString(port);
    140     }
    141   }
    142   return host;
    143 }
    144 
    145 std::string WebSocketHandshake::GetOriginFieldValue() const {
    146   // It's OK to lowercase the origin as the Origin header does not contain
    147   // the path or query portions, as per
    148   // http://tools.ietf.org/html/draft-abarth-origin-00.
    149   //
    150   // TODO(satorux): Should we trim the port portion here if it's 80 for
    151   // http:// or 443 for https:// ? Or can we assume it's done by the
    152   // client of the library?
    153   return StringToLowerASCII(origin_);
    154 }
    155 
    156 /* static */
    157 bool WebSocketHandshake::GetSingleHeader(const HttpResponseHeaders& headers,
    158                                          const std::string& name,
    159                                          std::string* value) {
    160   std::string first_value;
    161   void* iter = NULL;
    162   if (!headers.EnumerateHeader(&iter, name, &first_value))
    163     return false;
    164 
    165   // Checks no more |name| found in |headers|.
    166   // Second call of EnumerateHeader() must return false.
    167   std::string second_value;
    168   if (headers.EnumerateHeader(&iter, name, &second_value))
    169     return false;
    170   *value = first_value;
    171   return true;
    172 }
    173 
    174 bool WebSocketHandshake::ProcessHeaders(const HttpResponseHeaders& headers) {
    175   std::string value;
    176   if (!GetSingleHeader(headers, "upgrade", &value) ||
    177       value != "WebSocket")
    178     return false;
    179 
    180   if (!GetSingleHeader(headers, "connection", &value) ||
    181       !LowerCaseEqualsASCII(value, "upgrade"))
    182     return false;
    183 
    184   if (!GetSingleHeader(headers, "sec-websocket-origin", &ws_origin_))
    185     return false;
    186 
    187   if (!GetSingleHeader(headers, "sec-websocket-location", &ws_location_))
    188     return false;
    189 
    190   // If |protocol_| is not specified by client, we don't care if there's
    191   // protocol field or not as specified in the spec.
    192   if (!protocol_.empty()
    193       && !GetSingleHeader(headers, "sec-websocket-protocol", &ws_protocol_))
    194     return false;
    195   return true;
    196 }
    197 
    198 bool WebSocketHandshake::CheckResponseHeaders() const {
    199   DCHECK(mode_ == MODE_NORMAL);
    200   if (!LowerCaseEqualsASCII(origin_, ws_origin_.c_str()))
    201     return false;
    202   if (location_ != ws_location_)
    203     return false;
    204   if (!protocol_.empty() && protocol_ != ws_protocol_)
    205     return false;
    206   return true;
    207 }
    208 
    209 namespace {
    210 
    211 // unsigned int version of base::RandInt().
    212 // we can't use base::RandInt(), because max would be negative if it is
    213 // represented as int, so DCHECK(min <= max) fails.
    214 uint32 RandUint32(uint32 min, uint32 max) {
    215   DCHECK(min <= max);
    216 
    217   uint64 range = static_cast<int64>(max) - min + 1;
    218   uint64 number = base::RandUint64();
    219   // TODO(ukai): fix to be uniform.
    220   // the distribution of the result of modulo will be biased.
    221   uint32 result = min + static_cast<uint32>(number % range);
    222   DCHECK(result >= min && result <= max);
    223   return result;
    224 }
    225 
    226 }
    227 
    228 uint32 (*WebSocketHandshake::Parameter::rand_)(uint32 min, uint32 max) =
    229     RandUint32;
    230 uint8 randomCharacterInSecWebSocketKey[0x2F - 0x20 + 0x7E - 0x39];
    231 
    232 WebSocketHandshake::Parameter::Parameter()
    233     : number_1_(0), number_2_(0) {
    234   if (randomCharacterInSecWebSocketKey[0] == '\0') {
    235     int i = 0;
    236     for (int ch = 0x21; ch <= 0x2F; ch++, i++)
    237       randomCharacterInSecWebSocketKey[i] = ch;
    238     for (int ch = 0x3A; ch <= 0x7E; ch++, i++)
    239       randomCharacterInSecWebSocketKey[i] = ch;
    240   }
    241 }
    242 
    243 WebSocketHandshake::Parameter::~Parameter() {}
    244 
    245 void WebSocketHandshake::Parameter::GenerateKeys() {
    246   GenerateSecWebSocketKey(&number_1_, &key_1_);
    247   GenerateSecWebSocketKey(&number_2_, &key_2_);
    248   GenerateKey3();
    249 }
    250 
    251 static void SetChallengeNumber(uint8* buf, uint32 number) {
    252   uint8* p = buf + 3;
    253   for (int i = 0; i < 4; i++) {
    254     *p = (uint8)(number & 0xFF);
    255     --p;
    256     number >>= 8;
    257   }
    258 }
    259 
    260 void WebSocketHandshake::Parameter::GetExpectedResponse(uint8 *expected) const {
    261   uint8 challenge[kExpectedResponseSize];
    262   SetChallengeNumber(&challenge[0], number_1_);
    263   SetChallengeNumber(&challenge[4], number_2_);
    264   memcpy(&challenge[8], key_3_.data(), kKey3Size);
    265   MD5Digest digest;
    266   MD5Sum(challenge, kExpectedResponseSize, &digest);
    267   memcpy(expected, digest.a, kExpectedResponseSize);
    268 }
    269 
    270 /* static */
    271 void WebSocketHandshake::Parameter::SetRandomNumberGenerator(
    272     uint32 (*rand)(uint32 min, uint32 max)) {
    273   rand_ = rand;
    274 }
    275 
    276 void WebSocketHandshake::Parameter::GenerateSecWebSocketKey(
    277     uint32* number, std::string* key) {
    278   uint32 space = rand_(1, 12);
    279   uint32 max = 4294967295U / space;
    280   *number = rand_(0, max);
    281   uint32 product = *number * space;
    282 
    283   std::string s = base::StringPrintf("%u", product);
    284   int n = rand_(1, 12);
    285   for (int i = 0; i < n; i++) {
    286     int pos = rand_(0, s.length());
    287     int chpos = rand_(0, sizeof(randomCharacterInSecWebSocketKey) - 1);
    288     s = s.substr(0, pos).append(1, randomCharacterInSecWebSocketKey[chpos]) +
    289         s.substr(pos);
    290   }
    291   for (uint32 i = 0; i < space; i++) {
    292     int pos = rand_(1, s.length() - 1);
    293     s = s.substr(0, pos) + " " + s.substr(pos);
    294   }
    295   *key = s;
    296 }
    297 
    298 void WebSocketHandshake::Parameter::GenerateKey3() {
    299   key_3_.clear();
    300   for (int i = 0; i < 8; i++) {
    301     key_3_.append(1, rand_(0, 255));
    302   }
    303 }
    304 
    305 }  // namespace net
    306