Home | History | Annotate | Download | only in websockets
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/websockets/websocket_handshake_handler.h"
      6 
      7 #include <limits>
      8 
      9 #include "base/base64.h"
     10 #include "base/sha1.h"
     11 #include "base/strings/string_number_conversions.h"
     12 #include "base/strings/string_piece.h"
     13 #include "base/strings/string_tokenizer.h"
     14 #include "base/strings/string_util.h"
     15 #include "base/strings/stringprintf.h"
     16 #include "net/http/http_request_headers.h"
     17 #include "net/http/http_response_headers.h"
     18 #include "net/http/http_util.h"
     19 #include "net/websockets/websocket_handshake_constants.h"
     20 #include "url/gurl.h"
     21 
     22 namespace net {
     23 namespace {
     24 
     25 const int kVersionHeaderValueForRFC6455 = 13;
     26 
     27 // Splits |handshake_message| into Status-Line or Request-Line (including CRLF)
     28 // and headers (excluding 2nd CRLF of double CRLFs at the end of a handshake
     29 // response).
     30 void ParseHandshakeHeader(
     31     const char* handshake_message, int len,
     32     std::string* request_line,
     33     std::string* headers) {
     34   size_t i = base::StringPiece(handshake_message, len).find_first_of("\r\n");
     35   if (i == base::StringPiece::npos) {
     36     *request_line = std::string(handshake_message, len);
     37     *headers = "";
     38     return;
     39   }
     40   // |request_line| includes \r\n.
     41   *request_line = std::string(handshake_message, i + 2);
     42 
     43   int header_len = len - (i + 2) - 2;
     44   if (header_len > 0) {
     45     // |handshake_message| includes trailing \r\n\r\n.
     46     // |headers| doesn't include 2nd \r\n.
     47     *headers = std::string(handshake_message + i + 2, header_len);
     48   } else {
     49     *headers = "";
     50   }
     51 }
     52 
     53 void FetchHeaders(const std::string& headers,
     54                   const char* const headers_to_get[],
     55                   size_t headers_to_get_len,
     56                   std::vector<std::string>* values) {
     57   net::HttpUtil::HeadersIterator iter(headers.begin(), headers.end(), "\r\n");
     58   while (iter.GetNext()) {
     59     for (size_t i = 0; i < headers_to_get_len; i++) {
     60       if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(),
     61                                headers_to_get[i])) {
     62         values->push_back(iter.values());
     63       }
     64     }
     65   }
     66 }
     67 
     68 bool GetHeaderName(std::string::const_iterator line_begin,
     69                    std::string::const_iterator line_end,
     70                    std::string::const_iterator* name_begin,
     71                    std::string::const_iterator* name_end) {
     72   std::string::const_iterator colon = std::find(line_begin, line_end, ':');
     73   if (colon == line_end) {
     74     return false;
     75   }
     76   *name_begin = line_begin;
     77   *name_end = colon;
     78   if (*name_begin == *name_end || net::HttpUtil::IsLWS(**name_begin))
     79     return false;
     80   net::HttpUtil::TrimLWS(name_begin, name_end);
     81   return true;
     82 }
     83 
     84 // Similar to HttpUtil::StripHeaders, but it preserves malformed headers, that
     85 // is, lines that are not formatted as "<name>: <value>\r\n".
     86 std::string FilterHeaders(
     87     const std::string& headers,
     88     const char* const headers_to_remove[],
     89     size_t headers_to_remove_len) {
     90   std::string filtered_headers;
     91 
     92   base::StringTokenizer lines(headers.begin(), headers.end(), "\r\n");
     93   while (lines.GetNext()) {
     94     std::string::const_iterator line_begin = lines.token_begin();
     95     std::string::const_iterator line_end = lines.token_end();
     96     std::string::const_iterator name_begin;
     97     std::string::const_iterator name_end;
     98     bool should_remove = false;
     99     if (GetHeaderName(line_begin, line_end, &name_begin, &name_end)) {
    100       for (size_t i = 0; i < headers_to_remove_len; ++i) {
    101         if (LowerCaseEqualsASCII(name_begin, name_end, headers_to_remove[i])) {
    102           should_remove = true;
    103           break;
    104         }
    105       }
    106     }
    107     if (!should_remove) {
    108       filtered_headers.append(line_begin, line_end);
    109       filtered_headers.append("\r\n");
    110     }
    111   }
    112   return filtered_headers;
    113 }
    114 
    115 bool CheckVersionInRequest(const std::string& request_headers) {
    116   std::vector<std::string> values;
    117   const char* const headers_to_get[1] = {
    118     websockets::kSecWebSocketVersionLowercase};
    119   FetchHeaders(request_headers, headers_to_get, 1, &values);
    120   DCHECK_LE(values.size(), 1U);
    121   if (values.empty())
    122     return false;
    123 
    124   int version;
    125   bool conversion_success = base::StringToInt(values[0], &version);
    126   if (!conversion_success)
    127     return false;
    128 
    129   return version == kVersionHeaderValueForRFC6455;
    130 }
    131 
    132 // Append a header to a string. Equivalent to
    133 //   response_message += header + ": " + value + "\r\n"
    134 // but avoids unnecessary allocations and copies.
    135 void AppendHeader(const base::StringPiece& header,
    136                   const base::StringPiece& value,
    137                   std::string* response_message) {
    138   static const char kColonSpace[] = ": ";
    139   const size_t kColonSpaceSize = sizeof(kColonSpace) - 1;
    140   static const char kCrNl[] = "\r\n";
    141   const size_t kCrNlSize = sizeof(kCrNl) - 1;
    142 
    143   size_t extra_size =
    144       header.size() + kColonSpaceSize + value.size() + kCrNlSize;
    145   response_message->reserve(response_message->size() + extra_size);
    146   response_message->append(header.begin(), header.end());
    147   response_message->append(kColonSpace, kColonSpace + kColonSpaceSize);
    148   response_message->append(value.begin(), value.end());
    149   response_message->append(kCrNl, kCrNl + kCrNlSize);
    150 }
    151 
    152 }  // namespace
    153 
    154 WebSocketHandshakeRequestHandler::WebSocketHandshakeRequestHandler()
    155     : original_length_(0),
    156       raw_length_(0) {}
    157 
    158 bool WebSocketHandshakeRequestHandler::ParseRequest(
    159     const char* data, int length) {
    160   DCHECK_GT(length, 0);
    161   std::string input(data, length);
    162   int input_header_length =
    163       HttpUtil::LocateEndOfHeaders(input.data(), input.size(), 0);
    164   if (input_header_length <= 0)
    165     return false;
    166 
    167   ParseHandshakeHeader(input.data(),
    168                        input_header_length,
    169                        &request_line_,
    170                        &headers_);
    171 
    172   if (!CheckVersionInRequest(headers_)) {
    173     NOTREACHED();
    174     return false;
    175   }
    176 
    177   original_length_ = input_header_length;
    178   return true;
    179 }
    180 
    181 size_t WebSocketHandshakeRequestHandler::original_length() const {
    182   return original_length_;
    183 }
    184 
    185 void WebSocketHandshakeRequestHandler::AppendHeaderIfMissing(
    186     const std::string& name, const std::string& value) {
    187   DCHECK(!headers_.empty());
    188   HttpUtil::AppendHeaderIfMissing(name.c_str(), value, &headers_);
    189 }
    190 
    191 void WebSocketHandshakeRequestHandler::RemoveHeaders(
    192     const char* const headers_to_remove[],
    193     size_t headers_to_remove_len) {
    194   DCHECK(!headers_.empty());
    195   headers_ = FilterHeaders(
    196       headers_, headers_to_remove, headers_to_remove_len);
    197 }
    198 
    199 HttpRequestInfo WebSocketHandshakeRequestHandler::GetRequestInfo(
    200     const GURL& url, std::string* challenge) {
    201   HttpRequestInfo request_info;
    202   request_info.url = url;
    203   size_t method_end = base::StringPiece(request_line_).find_first_of(" ");
    204   if (method_end != base::StringPiece::npos)
    205     request_info.method = std::string(request_line_.data(), method_end);
    206 
    207   request_info.extra_headers.Clear();
    208   request_info.extra_headers.AddHeadersFromString(headers_);
    209 
    210   request_info.extra_headers.RemoveHeader(websockets::kUpgrade);
    211   request_info.extra_headers.RemoveHeader(HttpRequestHeaders::kConnection);
    212 
    213   std::string key;
    214   bool header_present = request_info.extra_headers.GetHeader(
    215       websockets::kSecWebSocketKey, &key);
    216   DCHECK(header_present);
    217   request_info.extra_headers.RemoveHeader(websockets::kSecWebSocketKey);
    218   *challenge = key;
    219   return request_info;
    220 }
    221 
    222 bool WebSocketHandshakeRequestHandler::GetRequestHeaderBlock(
    223     const GURL& url,
    224     SpdyHeaderBlock* headers,
    225     std::string* challenge,
    226     int spdy_protocol_version) {
    227   // Construct opening handshake request headers as a SPDY header block.
    228   // For details, see WebSocket Layering over SPDY/3 Draft 8.
    229   if (spdy_protocol_version <= 2) {
    230     (*headers)["path"] = url.path();
    231     (*headers)["version"] = "WebSocket/13";
    232     (*headers)["scheme"] = url.scheme();
    233   } else {
    234     (*headers)[":path"] = url.path();
    235     (*headers)[":version"] = "WebSocket/13";
    236     (*headers)[":scheme"] = url.scheme();
    237   }
    238 
    239   HttpUtil::HeadersIterator iter(headers_.begin(), headers_.end(), "\r\n");
    240   while (iter.GetNext()) {
    241     if (LowerCaseEqualsASCII(iter.name_begin(),
    242                              iter.name_end(),
    243                              websockets::kUpgradeLowercase) ||
    244         LowerCaseEqualsASCII(
    245             iter.name_begin(), iter.name_end(), "connection") ||
    246         LowerCaseEqualsASCII(iter.name_begin(),
    247                              iter.name_end(),
    248                              websockets::kSecWebSocketVersionLowercase)) {
    249       // These headers must be ignored.
    250       continue;
    251     } else if (LowerCaseEqualsASCII(iter.name_begin(),
    252                                     iter.name_end(),
    253                                     websockets::kSecWebSocketKeyLowercase)) {
    254       *challenge = iter.values();
    255       // Sec-WebSocket-Key is not sent to the server.
    256       continue;
    257     } else if (LowerCaseEqualsASCII(
    258                    iter.name_begin(), iter.name_end(), "host") ||
    259                LowerCaseEqualsASCII(
    260                    iter.name_begin(), iter.name_end(), "origin") ||
    261                LowerCaseEqualsASCII(
    262                    iter.name_begin(),
    263                    iter.name_end(),
    264                    websockets::kSecWebSocketProtocolLowercase) ||
    265                LowerCaseEqualsASCII(
    266                    iter.name_begin(),
    267                    iter.name_end(),
    268                    websockets::kSecWebSocketExtensionsLowercase)) {
    269       // TODO(toyoshim): Some WebSocket extensions may not be compatible with
    270       // SPDY. We should omit them from a Sec-WebSocket-Extension header.
    271       std::string name;
    272       if (spdy_protocol_version <= 2)
    273         name = StringToLowerASCII(iter.name());
    274       else
    275         name = ":" + StringToLowerASCII(iter.name());
    276       (*headers)[name] = iter.values();
    277       continue;
    278     }
    279     // Others should be sent out to |headers|.
    280     std::string name = StringToLowerASCII(iter.name());
    281     SpdyHeaderBlock::iterator found = headers->find(name);
    282     if (found == headers->end()) {
    283       (*headers)[name] = iter.values();
    284     } else {
    285       // For now, websocket doesn't use multiple headers, but follows to http.
    286       found->second.append(1, '\0');  // +=() doesn't append 0's
    287       found->second.append(iter.values());
    288     }
    289   }
    290 
    291   return true;
    292 }
    293 
    294 std::string WebSocketHandshakeRequestHandler::GetRawRequest() {
    295   DCHECK(!request_line_.empty());
    296   DCHECK(!headers_.empty());
    297 
    298   std::string raw_request = request_line_ + headers_ + "\r\n";
    299   raw_length_ = raw_request.size();
    300   return raw_request;
    301 }
    302 
    303 size_t WebSocketHandshakeRequestHandler::raw_length() const {
    304   DCHECK_GT(raw_length_, 0);
    305   return raw_length_;
    306 }
    307 
    308 WebSocketHandshakeResponseHandler::WebSocketHandshakeResponseHandler()
    309     : original_header_length_(0) {}
    310 
    311 WebSocketHandshakeResponseHandler::~WebSocketHandshakeResponseHandler() {}
    312 
    313 size_t WebSocketHandshakeResponseHandler::ParseRawResponse(
    314     const char* data, int length) {
    315   DCHECK_GT(length, 0);
    316   if (HasResponse()) {
    317     DCHECK(!status_line_.empty());
    318     // headers_ might be empty for wrong response from server.
    319 
    320     return 0;
    321   }
    322 
    323   size_t old_original_length = original_.size();
    324 
    325   original_.append(data, length);
    326   // TODO(ukai): fail fast when response gives wrong status code.
    327   original_header_length_ = HttpUtil::LocateEndOfHeaders(
    328       original_.data(), original_.size(), 0);
    329   if (!HasResponse())
    330     return length;
    331 
    332   ParseHandshakeHeader(original_.data(),
    333                        original_header_length_,
    334                        &status_line_,
    335                        &headers_);
    336   int header_size = status_line_.size() + headers_.size();
    337   DCHECK_GE(original_header_length_, header_size);
    338   header_separator_ = std::string(original_.data() + header_size,
    339                                   original_header_length_ - header_size);
    340   return original_header_length_ - old_original_length;
    341 }
    342 
    343 bool WebSocketHandshakeResponseHandler::HasResponse() const {
    344   return original_header_length_ > 0 &&
    345       static_cast<size_t>(original_header_length_) <= original_.size();
    346 }
    347 
    348 void ComputeSecWebSocketAccept(const std::string& key,
    349                                std::string* accept) {
    350   DCHECK(accept);
    351 
    352   std::string hash =
    353       base::SHA1HashString(key + websockets::kWebSocketGuid);
    354   base::Base64Encode(hash, accept);
    355 }
    356 
    357 bool WebSocketHandshakeResponseHandler::ParseResponseInfo(
    358     const HttpResponseInfo& response_info,
    359     const std::string& challenge) {
    360   if (!response_info.headers.get())
    361     return false;
    362 
    363   // TODO(ricea): Eliminate all the reallocations and string copies.
    364   std::string response_message;
    365   response_message = response_info.headers->GetStatusLine();
    366   response_message += "\r\n";
    367 
    368   AppendHeader(websockets::kUpgrade,
    369                websockets::kWebSocketLowercase,
    370                &response_message);
    371 
    372   AppendHeader(
    373       HttpRequestHeaders::kConnection, websockets::kUpgrade, &response_message);
    374 
    375   std::string websocket_accept;
    376   ComputeSecWebSocketAccept(challenge, &websocket_accept);
    377   AppendHeader(
    378       websockets::kSecWebSocketAccept, websocket_accept, &response_message);
    379 
    380   void* iter = NULL;
    381   std::string name;
    382   std::string value;
    383   while (response_info.headers->EnumerateHeaderLines(&iter, &name, &value)) {
    384     AppendHeader(name, value, &response_message);
    385   }
    386   response_message += "\r\n";
    387 
    388   return ParseRawResponse(response_message.data(),
    389                           response_message.size()) == response_message.size();
    390 }
    391 
    392 bool WebSocketHandshakeResponseHandler::ParseResponseHeaderBlock(
    393     const SpdyHeaderBlock& headers,
    394     const std::string& challenge,
    395     int spdy_protocol_version) {
    396   SpdyHeaderBlock::const_iterator status;
    397   if (spdy_protocol_version <= 2)
    398     status = headers.find("status");
    399   else
    400     status = headers.find(":status");
    401   if (status == headers.end())
    402     return false;
    403 
    404   std::string hash =
    405       base::SHA1HashString(challenge + websockets::kWebSocketGuid);
    406   std::string websocket_accept;
    407   base::Base64Encode(hash, &websocket_accept);
    408 
    409   std::string response_message = base::StringPrintf(
    410       "%s %s\r\n", websockets::kHttpProtocolVersion, status->second.c_str());
    411 
    412   AppendHeader(
    413       websockets::kUpgrade, websockets::kWebSocketLowercase, &response_message);
    414   AppendHeader(
    415       HttpRequestHeaders::kConnection, websockets::kUpgrade, &response_message);
    416   AppendHeader(
    417       websockets::kSecWebSocketAccept, websocket_accept, &response_message);
    418 
    419   for (SpdyHeaderBlock::const_iterator iter = headers.begin();
    420        iter != headers.end();
    421        ++iter) {
    422     // For each value, if the server sends a NUL-separated list of values,
    423     // we separate that back out into individual headers for each value
    424     // in the list.
    425     if ((spdy_protocol_version <= 2 &&
    426          LowerCaseEqualsASCII(iter->first, "status")) ||
    427         (spdy_protocol_version >= 3 &&
    428          LowerCaseEqualsASCII(iter->first, ":status"))) {
    429       // The status value is already handled as the first line of
    430       // |response_message|. Just skip here.
    431       continue;
    432     }
    433     const std::string& value = iter->second;
    434     size_t start = 0;
    435     size_t end = 0;
    436     do {
    437       end = value.find('\0', start);
    438       std::string tval;
    439       if (end != std::string::npos)
    440         tval = value.substr(start, (end - start));
    441       else
    442         tval = value.substr(start);
    443       if (spdy_protocol_version >= 3 &&
    444           (LowerCaseEqualsASCII(iter->first,
    445                                 websockets::kSecWebSocketProtocolSpdy3) ||
    446            LowerCaseEqualsASCII(iter->first,
    447                                 websockets::kSecWebSocketExtensionsSpdy3)))
    448         AppendHeader(iter->first.substr(1), tval, &response_message);
    449       else
    450         AppendHeader(iter->first, tval, &response_message);
    451       start = end + 1;
    452     } while (end != std::string::npos);
    453   }
    454   response_message += "\r\n";
    455 
    456   return ParseRawResponse(response_message.data(),
    457                           response_message.size()) == response_message.size();
    458 }
    459 
    460 void WebSocketHandshakeResponseHandler::GetHeaders(
    461     const char* const headers_to_get[],
    462     size_t headers_to_get_len,
    463     std::vector<std::string>* values) {
    464   DCHECK(HasResponse());
    465   DCHECK(!status_line_.empty());
    466   // headers_ might be empty for wrong response from server.
    467   if (headers_.empty())
    468     return;
    469 
    470   FetchHeaders(headers_, headers_to_get, headers_to_get_len, values);
    471 }
    472 
    473 void WebSocketHandshakeResponseHandler::RemoveHeaders(
    474     const char* const headers_to_remove[],
    475     size_t headers_to_remove_len) {
    476   DCHECK(HasResponse());
    477   DCHECK(!status_line_.empty());
    478   // headers_ might be empty for wrong response from server.
    479   if (headers_.empty())
    480     return;
    481 
    482   headers_ = FilterHeaders(headers_, headers_to_remove, headers_to_remove_len);
    483 }
    484 
    485 std::string WebSocketHandshakeResponseHandler::GetRawResponse() const {
    486   DCHECK(HasResponse());
    487   return original_.substr(0, original_header_length_);
    488 }
    489 
    490 std::string WebSocketHandshakeResponseHandler::GetResponse() {
    491   DCHECK(HasResponse());
    492   DCHECK(!status_line_.empty());
    493   // headers_ might be empty for wrong response from server.
    494 
    495   return status_line_ + headers_ + header_separator_;
    496 }
    497 
    498 }  // namespace net
    499