Home | History | Annotate | Download | only in websockets
      1 // Copyright (c) 2010 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/websockets/websocket_handshake_handler.h"
      6 
      7 #include "base/md5.h"
      8 #include "base/string_piece.h"
      9 #include "base/string_util.h"
     10 #include "googleurl/src/gurl.h"
     11 #include "net/http/http_response_headers.h"
     12 #include "net/http/http_util.h"
     13 
     14 namespace {
     15 
     16 const size_t kRequestKey3Size = 8U;
     17 const size_t kResponseKeySize = 16U;
     18 
     19 void ParseHandshakeHeader(
     20     const char* handshake_message, int len,
     21     std::string* status_line,
     22     std::string* headers) {
     23   size_t i = base::StringPiece(handshake_message, len).find_first_of("\r\n");
     24   if (i == base::StringPiece::npos) {
     25     *status_line = std::string(handshake_message, len);
     26     *headers = "";
     27     return;
     28   }
     29   // |status_line| includes \r\n.
     30   *status_line = std::string(handshake_message, i + 2);
     31 
     32   int header_len = len - (i + 2) - 2;
     33   if (header_len > 0) {
     34     // |handshake_message| includes tailing \r\n\r\n.
     35     // |headers| doesn't include 2nd \r\n.
     36     *headers = std::string(handshake_message + i + 2, header_len);
     37   } else {
     38     *headers = "";
     39   }
     40 }
     41 
     42 void FetchHeaders(const std::string& headers,
     43                   const char* const headers_to_get[],
     44                   size_t headers_to_get_len,
     45                   std::vector<std::string>* values) {
     46   net::HttpUtil::HeadersIterator iter(headers.begin(), headers.end(), "\r\n");
     47   while (iter.GetNext()) {
     48     for (size_t i = 0; i < headers_to_get_len; i++) {
     49       if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(),
     50                                headers_to_get[i])) {
     51         values->push_back(iter.values());
     52       }
     53     }
     54   }
     55 }
     56 
     57 bool GetHeaderName(std::string::const_iterator line_begin,
     58                    std::string::const_iterator line_end,
     59                    std::string::const_iterator* name_begin,
     60                    std::string::const_iterator* name_end) {
     61   std::string::const_iterator colon = std::find(line_begin, line_end, ':');
     62   if (colon == line_end) {
     63     return false;
     64   }
     65   *name_begin = line_begin;
     66   *name_end = colon;
     67   if (*name_begin == *name_end || net::HttpUtil::IsLWS(**name_begin))
     68     return false;
     69   net::HttpUtil::TrimLWS(name_begin, name_end);
     70   return true;
     71 }
     72 
     73 // Similar to HttpUtil::StripHeaders, but it preserves malformed headers, that
     74 // is, lines that are not formatted as "<name>: <value>\r\n".
     75 std::string FilterHeaders(
     76     const std::string& headers,
     77     const char* const headers_to_remove[],
     78     size_t headers_to_remove_len) {
     79   std::string filtered_headers;
     80 
     81   StringTokenizer lines(headers.begin(), headers.end(), "\r\n");
     82   while (lines.GetNext()) {
     83     std::string::const_iterator line_begin = lines.token_begin();
     84     std::string::const_iterator line_end = lines.token_end();
     85     std::string::const_iterator name_begin;
     86     std::string::const_iterator name_end;
     87     bool should_remove = false;
     88     if (GetHeaderName(line_begin, line_end, &name_begin, &name_end)) {
     89       for (size_t i = 0; i < headers_to_remove_len; ++i) {
     90         if (LowerCaseEqualsASCII(name_begin, name_end, headers_to_remove[i])) {
     91           should_remove = true;
     92           break;
     93         }
     94       }
     95     }
     96     if (!should_remove) {
     97       filtered_headers.append(line_begin, line_end);
     98       filtered_headers.append("\r\n");
     99     }
    100   }
    101   return filtered_headers;
    102 }
    103 
    104 // Gets a key number from |key| and appends the number to |challenge|.
    105 // The key number (/part_N/) is extracted as step 4.-8. in
    106 // 5.2. Sending the server's opening handshake of
    107 // http://www.ietf.org/id/draft-ietf-hybi-thewebsocketprotocol-00.txt
    108 void GetKeyNumber(const std::string& key, std::string* challenge) {
    109   uint32 key_number = 0;
    110   uint32 spaces = 0;
    111   for (size_t i = 0; i < key.size(); ++i) {
    112     if (isdigit(key[i])) {
    113       // key_number should not overflow. (it comes from
    114       // WebCore/websockets/WebSocketHandshake.cpp).
    115       key_number = key_number * 10 + key[i] - '0';
    116     } else if (key[i] == ' ') {
    117       ++spaces;
    118     }
    119   }
    120   // spaces should not be zero in valid handshake request.
    121   if (spaces == 0)
    122     return;
    123   key_number /= spaces;
    124 
    125   char part[4];
    126   for (int i = 0; i < 4; i++) {
    127     part[3 - i] = key_number & 0xFF;
    128     key_number >>= 8;
    129   }
    130   challenge->append(part, 4);
    131 }
    132 
    133 }  // anonymous namespace
    134 
    135 namespace net {
    136 
    137 WebSocketHandshakeRequestHandler::WebSocketHandshakeRequestHandler()
    138     : original_length_(0),
    139       raw_length_(0) {}
    140 
    141 bool WebSocketHandshakeRequestHandler::ParseRequest(
    142     const char* data, int length) {
    143   DCHECK_GT(length, 0);
    144   std::string input(data, length);
    145   int input_header_length =
    146       HttpUtil::LocateEndOfHeaders(input.data(), input.size(), 0);
    147   if (input_header_length <= 0 ||
    148       input_header_length + kRequestKey3Size > input.size())
    149     return false;
    150 
    151   ParseHandshakeHeader(input.data(),
    152                        input_header_length,
    153                        &status_line_,
    154                        &headers_);
    155 
    156   // draft-hixie-thewebsocketprotocol-76 or later will send /key3/
    157   // after handshake request header.
    158   // Assumes WebKit doesn't send any data after handshake request message
    159   // until handshake is finished.
    160   // Thus, |key3_| is part of handshake message, and not in part
    161   // of WebSocket frame stream.
    162   DCHECK_EQ(kRequestKey3Size,
    163             input.size() -
    164             input_header_length);
    165   key3_ = std::string(input.data() + input_header_length,
    166                       input.size() - input_header_length);
    167   original_length_ = input.size();
    168   return true;
    169 }
    170 
    171 size_t WebSocketHandshakeRequestHandler::original_length() const {
    172   return original_length_;
    173 }
    174 
    175 void WebSocketHandshakeRequestHandler::AppendHeaderIfMissing(
    176     const std::string& name, const std::string& value) {
    177   DCHECK(!headers_.empty());
    178   HttpUtil::AppendHeaderIfMissing(name.c_str(), value, &headers_);
    179 }
    180 
    181 void WebSocketHandshakeRequestHandler::RemoveHeaders(
    182     const char* const headers_to_remove[],
    183     size_t headers_to_remove_len) {
    184   DCHECK(!headers_.empty());
    185   headers_ = FilterHeaders(
    186       headers_, headers_to_remove, headers_to_remove_len);
    187 }
    188 
    189 HttpRequestInfo WebSocketHandshakeRequestHandler::GetRequestInfo(
    190     const GURL& url, std::string* challenge) {
    191   HttpRequestInfo request_info;
    192   request_info.url = url;
    193   base::StringPiece method = status_line_.data();
    194   size_t method_end = base::StringPiece(
    195       status_line_.data(), status_line_.size()).find_first_of(" ");
    196   if (method_end != base::StringPiece::npos)
    197     request_info.method = std::string(status_line_.data(), method_end);
    198 
    199   request_info.extra_headers.Clear();
    200   request_info.extra_headers.AddHeadersFromString(headers_);
    201 
    202   request_info.extra_headers.RemoveHeader("Upgrade");
    203   request_info.extra_headers.RemoveHeader("Connection");
    204 
    205   challenge->clear();
    206   std::string key;
    207   request_info.extra_headers.GetHeader("Sec-WebSocket-Key1", &key);
    208   request_info.extra_headers.RemoveHeader("Sec-WebSocket-Key1");
    209   GetKeyNumber(key, challenge);
    210 
    211   request_info.extra_headers.GetHeader("Sec-WebSocket-Key2", &key);
    212   request_info.extra_headers.RemoveHeader("Sec-WebSocket-Key2");
    213   GetKeyNumber(key, challenge);
    214 
    215   challenge->append(key3_);
    216 
    217   return request_info;
    218 }
    219 
    220 bool WebSocketHandshakeRequestHandler::GetRequestHeaderBlock(
    221     const GURL& url, spdy::SpdyHeaderBlock* headers, std::string* challenge) {
    222   // We don't set "method" and "version".  These are fixed value in WebSocket
    223   // protocol.
    224   (*headers)["url"] = url.spec();
    225 
    226   std::string key1;
    227   std::string key2;
    228   HttpUtil::HeadersIterator iter(headers_.begin(), headers_.end(), "\r\n");
    229   while (iter.GetNext()) {
    230     if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(),
    231                              "connection")) {
    232       // Ignore "Connection" header.
    233       continue;
    234     } else if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(),
    235                                     "upgrade")) {
    236       // Ignore "Upgrade" header.
    237       continue;
    238     } else if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(),
    239                                     "sec-websocket-key1")) {
    240       // Use only for generating challenge.
    241       key1 = iter.values();
    242       continue;
    243     } else if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(),
    244                                     "sec-websocket-key2")) {
    245       // Use only for generating challenge.
    246       key2 = iter.values();
    247       continue;
    248     }
    249     // Others should be sent out to |headers|.
    250     std::string name = StringToLowerASCII(iter.name());
    251     spdy::SpdyHeaderBlock::iterator found = headers->find(name);
    252     if (found == headers->end()) {
    253       (*headers)[name] = iter.values();
    254     } else {
    255       // For now, websocket doesn't use multiple headers, but follows to http.
    256       found->second.append(1, '\0');  // +=() doesn't append 0's
    257       found->second.append(iter.values());
    258     }
    259   }
    260 
    261   challenge->clear();
    262   GetKeyNumber(key1, challenge);
    263   GetKeyNumber(key2, challenge);
    264   challenge->append(key3_);
    265 
    266   return true;
    267 }
    268 
    269 std::string WebSocketHandshakeRequestHandler::GetRawRequest() {
    270   DCHECK(!status_line_.empty());
    271   DCHECK(!headers_.empty());
    272   DCHECK_EQ(kRequestKey3Size, key3_.size());
    273   std::string raw_request = status_line_ + headers_ + "\r\n" + key3_;
    274   raw_length_ = raw_request.size();
    275   return raw_request;
    276 }
    277 
    278 size_t WebSocketHandshakeRequestHandler::raw_length() const {
    279   DCHECK_GT(raw_length_, 0);
    280   return raw_length_;
    281 }
    282 
    283 WebSocketHandshakeResponseHandler::WebSocketHandshakeResponseHandler()
    284     : original_header_length_(0) {
    285 }
    286 
    287 WebSocketHandshakeResponseHandler::~WebSocketHandshakeResponseHandler() {}
    288 
    289 size_t WebSocketHandshakeResponseHandler::ParseRawResponse(
    290     const char* data, int length) {
    291   DCHECK_GT(length, 0);
    292   if (HasResponse()) {
    293     DCHECK(!status_line_.empty());
    294     DCHECK(!headers_.empty());
    295     DCHECK_EQ(kResponseKeySize, key_.size());
    296     return 0;
    297   }
    298 
    299   size_t old_original_length = original_.size();
    300 
    301   original_.append(data, length);
    302   // TODO(ukai): fail fast when response gives wrong status code.
    303   original_header_length_ = HttpUtil::LocateEndOfHeaders(
    304       original_.data(), original_.size(), 0);
    305   if (!HasResponse())
    306     return length;
    307 
    308   ParseHandshakeHeader(original_.data(),
    309                        original_header_length_,
    310                        &status_line_,
    311                        &headers_);
    312   int header_size = status_line_.size() + headers_.size();
    313   DCHECK_GE(original_header_length_, header_size);
    314   header_separator_ = std::string(original_.data() + header_size,
    315                                   original_header_length_ - header_size);
    316   key_ = std::string(original_.data() + original_header_length_,
    317                      kResponseKeySize);
    318 
    319   return original_header_length_ + kResponseKeySize - old_original_length;
    320 }
    321 
    322 bool WebSocketHandshakeResponseHandler::HasResponse() const {
    323   return original_header_length_ > 0 &&
    324       original_header_length_ + kResponseKeySize <= original_.size();
    325 }
    326 
    327 bool WebSocketHandshakeResponseHandler::ParseResponseInfo(
    328     const HttpResponseInfo& response_info,
    329     const std::string& challenge) {
    330   if (!response_info.headers.get())
    331     return false;
    332 
    333   std::string response_message;
    334   response_message = response_info.headers->GetStatusLine();
    335   response_message += "\r\n";
    336   response_message += "Upgrade: WebSocket\r\n";
    337   response_message += "Connection: Upgrade\r\n";
    338   void* iter = NULL;
    339   std::string name;
    340   std::string value;
    341   while (response_info.headers->EnumerateHeaderLines(&iter, &name, &value)) {
    342     response_message += name + ": " + value + "\r\n";
    343   }
    344   response_message += "\r\n";
    345 
    346   MD5Digest digest;
    347   MD5Sum(challenge.data(), challenge.size(), &digest);
    348 
    349   const char* digest_data = reinterpret_cast<char*>(digest.a);
    350   response_message.append(digest_data, sizeof(digest.a));
    351 
    352   return ParseRawResponse(response_message.data(),
    353                           response_message.size()) == response_message.size();
    354 }
    355 
    356 bool WebSocketHandshakeResponseHandler::ParseResponseHeaderBlock(
    357     const spdy::SpdyHeaderBlock& headers,
    358     const std::string& challenge) {
    359   std::string response_message;
    360   response_message = "HTTP/1.1 101 WebSocket Protocol Handshake\r\n";
    361   response_message += "Upgrade: WebSocket\r\n";
    362   response_message += "Connection: Upgrade\r\n";
    363   for (spdy::SpdyHeaderBlock::const_iterator iter = headers.begin();
    364        iter != headers.end();
    365        ++iter) {
    366     // For each value, if the server sends a NUL-separated list of values,
    367     // we separate that back out into individual headers for each value
    368     // in the list.
    369     const std::string& value = iter->second;
    370     size_t start = 0;
    371     size_t end = 0;
    372     do {
    373       end = value.find('\0', start);
    374       std::string tval;
    375       if (end != std::string::npos)
    376         tval = value.substr(start, (end - start));
    377       else
    378         tval = value.substr(start);
    379       response_message += iter->first + ": " + tval + "\r\n";
    380       start = end + 1;
    381     } while (end != std::string::npos);
    382   }
    383   response_message += "\r\n";
    384 
    385   MD5Digest digest;
    386   MD5Sum(challenge.data(), challenge.size(), &digest);
    387 
    388   const char* digest_data = reinterpret_cast<char*>(digest.a);
    389   response_message.append(digest_data, sizeof(digest.a));
    390 
    391   return ParseRawResponse(response_message.data(),
    392                           response_message.size()) == response_message.size();
    393 }
    394 
    395 void WebSocketHandshakeResponseHandler::GetHeaders(
    396     const char* const headers_to_get[],
    397     size_t headers_to_get_len,
    398     std::vector<std::string>* values) {
    399   DCHECK(HasResponse());
    400   DCHECK(!status_line_.empty());
    401   DCHECK(!headers_.empty());
    402   DCHECK_EQ(kResponseKeySize, key_.size());
    403 
    404   FetchHeaders(headers_, headers_to_get, headers_to_get_len, values);
    405 }
    406 
    407 void WebSocketHandshakeResponseHandler::RemoveHeaders(
    408     const char* const headers_to_remove[],
    409     size_t headers_to_remove_len) {
    410   DCHECK(HasResponse());
    411   DCHECK(!status_line_.empty());
    412   DCHECK(!headers_.empty());
    413   DCHECK_EQ(kResponseKeySize, key_.size());
    414 
    415   headers_ = FilterHeaders(headers_, headers_to_remove, headers_to_remove_len);
    416 }
    417 
    418 std::string WebSocketHandshakeResponseHandler::GetRawResponse() const {
    419   DCHECK(HasResponse());
    420   return std::string(original_.data(),
    421                      original_header_length_ + kResponseKeySize);
    422 }
    423 
    424 std::string WebSocketHandshakeResponseHandler::GetResponse() {
    425   DCHECK(HasResponse());
    426   DCHECK(!status_line_.empty());
    427   // headers_ might be empty for wrong response from server.
    428   DCHECK_EQ(kResponseKeySize, key_.size());
    429 
    430   return status_line_ + headers_ + header_separator_ + key_;
    431 }
    432 
    433 }  // namespace net
    434