Home | History | Annotate | Download | only in websockets
      1 // Copyright 2013 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/websockets/websocket_basic_handshake_stream.h"
      6 
      7 #include <algorithm>
      8 #include <iterator>
      9 
     10 #include "base/base64.h"
     11 #include "base/basictypes.h"
     12 #include "base/bind.h"
     13 #include "base/containers/hash_tables.h"
     14 #include "base/stl_util.h"
     15 #include "base/strings/string_util.h"
     16 #include "crypto/random.h"
     17 #include "net/http/http_request_headers.h"
     18 #include "net/http/http_request_info.h"
     19 #include "net/http/http_response_body_drainer.h"
     20 #include "net/http/http_response_headers.h"
     21 #include "net/http/http_status_code.h"
     22 #include "net/http/http_stream_parser.h"
     23 #include "net/socket/client_socket_handle.h"
     24 #include "net/websockets/websocket_basic_stream.h"
     25 #include "net/websockets/websocket_handshake_constants.h"
     26 #include "net/websockets/websocket_handshake_handler.h"
     27 #include "net/websockets/websocket_stream.h"
     28 
     29 namespace net {
     30 namespace {
     31 
     32 std::string GenerateHandshakeChallenge() {
     33   std::string raw_challenge(websockets::kRawChallengeLength, '\0');
     34   crypto::RandBytes(string_as_array(&raw_challenge), raw_challenge.length());
     35   std::string encoded_challenge;
     36   base::Base64Encode(raw_challenge, &encoded_challenge);
     37   return encoded_challenge;
     38 }
     39 
     40 void AddVectorHeaderIfNonEmpty(const char* name,
     41                                const std::vector<std::string>& value,
     42                                HttpRequestHeaders* headers) {
     43   if (value.empty())
     44     return;
     45   headers->SetHeader(name, JoinString(value, ", "));
     46 }
     47 
     48 // If |case_sensitive| is false, then |value| must be in lower-case.
     49 bool ValidateSingleTokenHeader(
     50     const scoped_refptr<HttpResponseHeaders>& headers,
     51     const base::StringPiece& name,
     52     const std::string& value,
     53     bool case_sensitive) {
     54   void* state = NULL;
     55   std::string token;
     56   int tokens = 0;
     57   bool has_value = false;
     58   while (headers->EnumerateHeader(&state, name, &token)) {
     59     if (++tokens > 1)
     60       return false;
     61     has_value = case_sensitive ? value == token
     62                                : LowerCaseEqualsASCII(token, value.c_str());
     63   }
     64   return has_value;
     65 }
     66 
     67 bool ValidateSubProtocol(
     68     const scoped_refptr<HttpResponseHeaders>& headers,
     69     const std::vector<std::string>& requested_sub_protocols,
     70     std::string* sub_protocol) {
     71   void* state = NULL;
     72   std::string token;
     73   base::hash_set<std::string> requested_set(requested_sub_protocols.begin(),
     74                                             requested_sub_protocols.end());
     75   int accepted = 0;
     76   while (headers->EnumerateHeader(
     77       &state, websockets::kSecWebSocketProtocol, &token)) {
     78     if (requested_set.count(token) == 0)
     79       return false;
     80 
     81     *sub_protocol = token;
     82     // The server is only allowed to accept one protocol.
     83     if (++accepted > 1)
     84       return false;
     85   }
     86   // If the browser requested > 0 protocols, the server is required to accept
     87   // one.
     88   return requested_set.empty() || accepted == 1;
     89 }
     90 
     91 bool ValidateExtensions(const scoped_refptr<HttpResponseHeaders>& headers,
     92                         const std::vector<std::string>& requested_extensions,
     93                         std::string* extensions) {
     94   void* state = NULL;
     95   std::string token;
     96   while (headers->EnumerateHeader(
     97       &state, websockets::kSecWebSocketExtensions, &token)) {
     98     // TODO(ricea): Accept permessage-deflate with valid parameters.
     99     return false;
    100   }
    101   return true;
    102 }
    103 
    104 }  // namespace
    105 
    106 WebSocketBasicHandshakeStream::WebSocketBasicHandshakeStream(
    107     scoped_ptr<ClientSocketHandle> connection,
    108     bool using_proxy,
    109     std::vector<std::string> requested_sub_protocols,
    110     std::vector<std::string> requested_extensions)
    111     : state_(connection.release(), using_proxy),
    112       http_response_info_(NULL),
    113       requested_sub_protocols_(requested_sub_protocols),
    114       requested_extensions_(requested_extensions) {}
    115 
    116 WebSocketBasicHandshakeStream::~WebSocketBasicHandshakeStream() {}
    117 
    118 int WebSocketBasicHandshakeStream::InitializeStream(
    119     const HttpRequestInfo* request_info,
    120     RequestPriority priority,
    121     const BoundNetLog& net_log,
    122     const CompletionCallback& callback) {
    123   state_.Initialize(request_info, priority, net_log, callback);
    124   return OK;
    125 }
    126 
    127 int WebSocketBasicHandshakeStream::SendRequest(
    128     const HttpRequestHeaders& headers,
    129     HttpResponseInfo* response,
    130     const CompletionCallback& callback) {
    131   DCHECK(!headers.HasHeader(websockets::kSecWebSocketKey));
    132   DCHECK(!headers.HasHeader(websockets::kSecWebSocketProtocol));
    133   DCHECK(!headers.HasHeader(websockets::kSecWebSocketExtensions));
    134   DCHECK(headers.HasHeader(HttpRequestHeaders::kOrigin));
    135   DCHECK(headers.HasHeader(websockets::kUpgrade));
    136   DCHECK(headers.HasHeader(HttpRequestHeaders::kConnection));
    137   DCHECK(headers.HasHeader(websockets::kSecWebSocketVersion));
    138   DCHECK(parser());
    139 
    140   http_response_info_ = response;
    141 
    142   // Create a copy of the headers object, so that we can add the
    143   // Sec-WebSockey-Key header.
    144   HttpRequestHeaders enriched_headers;
    145   enriched_headers.CopyFrom(headers);
    146   std::string handshake_challenge;
    147   if (handshake_challenge_for_testing_) {
    148     handshake_challenge = *handshake_challenge_for_testing_;
    149     handshake_challenge_for_testing_.reset();
    150   } else {
    151     handshake_challenge = GenerateHandshakeChallenge();
    152   }
    153   enriched_headers.SetHeader(websockets::kSecWebSocketKey, handshake_challenge);
    154 
    155   AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol,
    156                             requested_sub_protocols_,
    157                             &enriched_headers);
    158   AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketExtensions,
    159                             requested_extensions_,
    160                             &enriched_headers);
    161 
    162   ComputeSecWebSocketAccept(handshake_challenge,
    163                             &handshake_challenge_response_);
    164 
    165   return parser()->SendRequest(
    166       state_.GenerateRequestLine(), enriched_headers, response, callback);
    167 }
    168 
    169 int WebSocketBasicHandshakeStream::ReadResponseHeaders(
    170     const CompletionCallback& callback) {
    171   // HttpStreamParser uses a weak pointer when reading from the
    172   // socket, so it won't be called back after being destroyed. The
    173   // HttpStreamParser is owned by HttpBasicState which is owned by this object,
    174   // so this use of base::Unretained() is safe.
    175   int rv = parser()->ReadResponseHeaders(
    176       base::Bind(&WebSocketBasicHandshakeStream::ReadResponseHeadersCallback,
    177                  base::Unretained(this),
    178                  callback));
    179   return rv == OK ? ValidateResponse() : rv;
    180 }
    181 
    182 const HttpResponseInfo* WebSocketBasicHandshakeStream::GetResponseInfo() const {
    183   return parser()->GetResponseInfo();
    184 }
    185 
    186 int WebSocketBasicHandshakeStream::ReadResponseBody(
    187     IOBuffer* buf,
    188     int buf_len,
    189     const CompletionCallback& callback) {
    190   return parser()->ReadResponseBody(buf, buf_len, callback);
    191 }
    192 
    193 void WebSocketBasicHandshakeStream::Close(bool not_reusable) {
    194   // This class ignores the value of |not_reusable| and never lets the socket be
    195   // re-used.
    196   if (parser())
    197     parser()->Close(true);
    198 }
    199 
    200 bool WebSocketBasicHandshakeStream::IsResponseBodyComplete() const {
    201   return parser()->IsResponseBodyComplete();
    202 }
    203 
    204 bool WebSocketBasicHandshakeStream::CanFindEndOfResponse() const {
    205   return parser() && parser()->CanFindEndOfResponse();
    206 }
    207 
    208 bool WebSocketBasicHandshakeStream::IsConnectionReused() const {
    209   return parser()->IsConnectionReused();
    210 }
    211 
    212 void WebSocketBasicHandshakeStream::SetConnectionReused() {
    213   parser()->SetConnectionReused();
    214 }
    215 
    216 bool WebSocketBasicHandshakeStream::IsConnectionReusable() const {
    217   return false;
    218 }
    219 
    220 int64 WebSocketBasicHandshakeStream::GetTotalReceivedBytes() const {
    221   return 0;
    222 }
    223 
    224 bool WebSocketBasicHandshakeStream::GetLoadTimingInfo(
    225     LoadTimingInfo* load_timing_info) const {
    226   return state_.connection()->GetLoadTimingInfo(IsConnectionReused(),
    227                                                 load_timing_info);
    228 }
    229 
    230 void WebSocketBasicHandshakeStream::GetSSLInfo(SSLInfo* ssl_info) {
    231   parser()->GetSSLInfo(ssl_info);
    232 }
    233 
    234 void WebSocketBasicHandshakeStream::GetSSLCertRequestInfo(
    235     SSLCertRequestInfo* cert_request_info) {
    236   parser()->GetSSLCertRequestInfo(cert_request_info);
    237 }
    238 
    239 bool WebSocketBasicHandshakeStream::IsSpdyHttpStream() const { return false; }
    240 
    241 void WebSocketBasicHandshakeStream::Drain(HttpNetworkSession* session) {
    242   HttpResponseBodyDrainer* drainer = new HttpResponseBodyDrainer(this);
    243   drainer->Start(session);
    244   // |drainer| will delete itself.
    245 }
    246 
    247 void WebSocketBasicHandshakeStream::SetPriority(RequestPriority priority) {
    248   // TODO(ricea): See TODO comment in HttpBasicStream::SetPriority(). If it is
    249   // gone, then copy whatever has happened there over here.
    250 }
    251 
    252 scoped_ptr<WebSocketStream> WebSocketBasicHandshakeStream::Upgrade() {
    253   // TODO(ricea): Add deflate support.
    254 
    255   // The HttpStreamParser object has a pointer to our ClientSocketHandle. Make
    256   // sure it does not touch it again before it is destroyed.
    257   state_.DeleteParser();
    258   return scoped_ptr<WebSocketStream>(
    259       new WebSocketBasicStream(state_.ReleaseConnection(),
    260                                state_.read_buf(),
    261                                sub_protocol_,
    262                                extensions_));
    263 }
    264 
    265 void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting(
    266     const std::string& key) {
    267   handshake_challenge_for_testing_.reset(new std::string(key));
    268 }
    269 
    270 void WebSocketBasicHandshakeStream::ReadResponseHeadersCallback(
    271     const CompletionCallback& callback,
    272     int result) {
    273   if (result == OK)
    274     result = ValidateResponse();
    275   callback.Run(result);
    276 }
    277 
    278 int WebSocketBasicHandshakeStream::ValidateResponse() {
    279   DCHECK(http_response_info_);
    280   const scoped_refptr<HttpResponseHeaders>& headers =
    281       http_response_info_->headers;
    282 
    283   switch (headers->response_code()) {
    284     case HTTP_SWITCHING_PROTOCOLS:
    285       return ValidateUpgradeResponse(headers);
    286 
    287     // We need to pass these through for authentication to work.
    288     case HTTP_UNAUTHORIZED:
    289     case HTTP_PROXY_AUTHENTICATION_REQUIRED:
    290       return OK;
    291 
    292     // Other status codes are potentially risky (see the warnings in the
    293     // WHATWG WebSocket API spec) and so are dropped by default.
    294     default:
    295       return ERR_INVALID_RESPONSE;
    296   }
    297 }
    298 
    299 int WebSocketBasicHandshakeStream::ValidateUpgradeResponse(
    300     const scoped_refptr<HttpResponseHeaders>& headers) {
    301   if (ValidateSingleTokenHeader(headers,
    302                                 websockets::kUpgrade,
    303                                 websockets::kWebSocketLowercase,
    304                                 false) &&
    305       ValidateSingleTokenHeader(headers,
    306                                 websockets::kSecWebSocketAccept,
    307                                 handshake_challenge_response_,
    308                                 true) &&
    309       headers->HasHeaderValue(HttpRequestHeaders::kConnection,
    310                               websockets::kUpgrade) &&
    311       ValidateSubProtocol(headers, requested_sub_protocols_, &sub_protocol_) &&
    312       ValidateExtensions(headers, requested_extensions_, &extensions_)) {
    313     return OK;
    314   }
    315   return ERR_INVALID_RESPONSE;
    316 }
    317 
    318 }  // namespace net
    319