1 // Copyright 2013 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/websockets/websocket_basic_handshake_stream.h" 6 7 #include <algorithm> 8 #include <iterator> 9 10 #include "base/base64.h" 11 #include "base/basictypes.h" 12 #include "base/bind.h" 13 #include "base/containers/hash_tables.h" 14 #include "base/stl_util.h" 15 #include "base/strings/string_util.h" 16 #include "crypto/random.h" 17 #include "net/http/http_request_headers.h" 18 #include "net/http/http_request_info.h" 19 #include "net/http/http_response_body_drainer.h" 20 #include "net/http/http_response_headers.h" 21 #include "net/http/http_status_code.h" 22 #include "net/http/http_stream_parser.h" 23 #include "net/socket/client_socket_handle.h" 24 #include "net/websockets/websocket_basic_stream.h" 25 #include "net/websockets/websocket_handshake_constants.h" 26 #include "net/websockets/websocket_handshake_handler.h" 27 #include "net/websockets/websocket_stream.h" 28 29 namespace net { 30 namespace { 31 32 std::string GenerateHandshakeChallenge() { 33 std::string raw_challenge(websockets::kRawChallengeLength, '\0'); 34 crypto::RandBytes(string_as_array(&raw_challenge), raw_challenge.length()); 35 std::string encoded_challenge; 36 base::Base64Encode(raw_challenge, &encoded_challenge); 37 return encoded_challenge; 38 } 39 40 void AddVectorHeaderIfNonEmpty(const char* name, 41 const std::vector<std::string>& value, 42 HttpRequestHeaders* headers) { 43 if (value.empty()) 44 return; 45 headers->SetHeader(name, JoinString(value, ", ")); 46 } 47 48 // If |case_sensitive| is false, then |value| must be in lower-case. 49 bool ValidateSingleTokenHeader( 50 const scoped_refptr<HttpResponseHeaders>& headers, 51 const base::StringPiece& name, 52 const std::string& value, 53 bool case_sensitive) { 54 void* state = NULL; 55 std::string token; 56 int tokens = 0; 57 bool has_value = false; 58 while (headers->EnumerateHeader(&state, name, &token)) { 59 if (++tokens > 1) 60 return false; 61 has_value = case_sensitive ? value == token 62 : LowerCaseEqualsASCII(token, value.c_str()); 63 } 64 return has_value; 65 } 66 67 bool ValidateSubProtocol( 68 const scoped_refptr<HttpResponseHeaders>& headers, 69 const std::vector<std::string>& requested_sub_protocols, 70 std::string* sub_protocol) { 71 void* state = NULL; 72 std::string token; 73 base::hash_set<std::string> requested_set(requested_sub_protocols.begin(), 74 requested_sub_protocols.end()); 75 int accepted = 0; 76 while (headers->EnumerateHeader( 77 &state, websockets::kSecWebSocketProtocol, &token)) { 78 if (requested_set.count(token) == 0) 79 return false; 80 81 *sub_protocol = token; 82 // The server is only allowed to accept one protocol. 83 if (++accepted > 1) 84 return false; 85 } 86 // If the browser requested > 0 protocols, the server is required to accept 87 // one. 88 return requested_set.empty() || accepted == 1; 89 } 90 91 bool ValidateExtensions(const scoped_refptr<HttpResponseHeaders>& headers, 92 const std::vector<std::string>& requested_extensions, 93 std::string* extensions) { 94 void* state = NULL; 95 std::string token; 96 while (headers->EnumerateHeader( 97 &state, websockets::kSecWebSocketExtensions, &token)) { 98 // TODO(ricea): Accept permessage-deflate with valid parameters. 99 return false; 100 } 101 return true; 102 } 103 104 } // namespace 105 106 WebSocketBasicHandshakeStream::WebSocketBasicHandshakeStream( 107 scoped_ptr<ClientSocketHandle> connection, 108 bool using_proxy, 109 std::vector<std::string> requested_sub_protocols, 110 std::vector<std::string> requested_extensions) 111 : state_(connection.release(), using_proxy), 112 http_response_info_(NULL), 113 requested_sub_protocols_(requested_sub_protocols), 114 requested_extensions_(requested_extensions) {} 115 116 WebSocketBasicHandshakeStream::~WebSocketBasicHandshakeStream() {} 117 118 int WebSocketBasicHandshakeStream::InitializeStream( 119 const HttpRequestInfo* request_info, 120 RequestPriority priority, 121 const BoundNetLog& net_log, 122 const CompletionCallback& callback) { 123 state_.Initialize(request_info, priority, net_log, callback); 124 return OK; 125 } 126 127 int WebSocketBasicHandshakeStream::SendRequest( 128 const HttpRequestHeaders& headers, 129 HttpResponseInfo* response, 130 const CompletionCallback& callback) { 131 DCHECK(!headers.HasHeader(websockets::kSecWebSocketKey)); 132 DCHECK(!headers.HasHeader(websockets::kSecWebSocketProtocol)); 133 DCHECK(!headers.HasHeader(websockets::kSecWebSocketExtensions)); 134 DCHECK(headers.HasHeader(HttpRequestHeaders::kOrigin)); 135 DCHECK(headers.HasHeader(websockets::kUpgrade)); 136 DCHECK(headers.HasHeader(HttpRequestHeaders::kConnection)); 137 DCHECK(headers.HasHeader(websockets::kSecWebSocketVersion)); 138 DCHECK(parser()); 139 140 http_response_info_ = response; 141 142 // Create a copy of the headers object, so that we can add the 143 // Sec-WebSockey-Key header. 144 HttpRequestHeaders enriched_headers; 145 enriched_headers.CopyFrom(headers); 146 std::string handshake_challenge; 147 if (handshake_challenge_for_testing_) { 148 handshake_challenge = *handshake_challenge_for_testing_; 149 handshake_challenge_for_testing_.reset(); 150 } else { 151 handshake_challenge = GenerateHandshakeChallenge(); 152 } 153 enriched_headers.SetHeader(websockets::kSecWebSocketKey, handshake_challenge); 154 155 AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol, 156 requested_sub_protocols_, 157 &enriched_headers); 158 AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketExtensions, 159 requested_extensions_, 160 &enriched_headers); 161 162 ComputeSecWebSocketAccept(handshake_challenge, 163 &handshake_challenge_response_); 164 165 return parser()->SendRequest( 166 state_.GenerateRequestLine(), enriched_headers, response, callback); 167 } 168 169 int WebSocketBasicHandshakeStream::ReadResponseHeaders( 170 const CompletionCallback& callback) { 171 // HttpStreamParser uses a weak pointer when reading from the 172 // socket, so it won't be called back after being destroyed. The 173 // HttpStreamParser is owned by HttpBasicState which is owned by this object, 174 // so this use of base::Unretained() is safe. 175 int rv = parser()->ReadResponseHeaders( 176 base::Bind(&WebSocketBasicHandshakeStream::ReadResponseHeadersCallback, 177 base::Unretained(this), 178 callback)); 179 return rv == OK ? ValidateResponse() : rv; 180 } 181 182 const HttpResponseInfo* WebSocketBasicHandshakeStream::GetResponseInfo() const { 183 return parser()->GetResponseInfo(); 184 } 185 186 int WebSocketBasicHandshakeStream::ReadResponseBody( 187 IOBuffer* buf, 188 int buf_len, 189 const CompletionCallback& callback) { 190 return parser()->ReadResponseBody(buf, buf_len, callback); 191 } 192 193 void WebSocketBasicHandshakeStream::Close(bool not_reusable) { 194 // This class ignores the value of |not_reusable| and never lets the socket be 195 // re-used. 196 if (parser()) 197 parser()->Close(true); 198 } 199 200 bool WebSocketBasicHandshakeStream::IsResponseBodyComplete() const { 201 return parser()->IsResponseBodyComplete(); 202 } 203 204 bool WebSocketBasicHandshakeStream::CanFindEndOfResponse() const { 205 return parser() && parser()->CanFindEndOfResponse(); 206 } 207 208 bool WebSocketBasicHandshakeStream::IsConnectionReused() const { 209 return parser()->IsConnectionReused(); 210 } 211 212 void WebSocketBasicHandshakeStream::SetConnectionReused() { 213 parser()->SetConnectionReused(); 214 } 215 216 bool WebSocketBasicHandshakeStream::IsConnectionReusable() const { 217 return false; 218 } 219 220 int64 WebSocketBasicHandshakeStream::GetTotalReceivedBytes() const { 221 return 0; 222 } 223 224 bool WebSocketBasicHandshakeStream::GetLoadTimingInfo( 225 LoadTimingInfo* load_timing_info) const { 226 return state_.connection()->GetLoadTimingInfo(IsConnectionReused(), 227 load_timing_info); 228 } 229 230 void WebSocketBasicHandshakeStream::GetSSLInfo(SSLInfo* ssl_info) { 231 parser()->GetSSLInfo(ssl_info); 232 } 233 234 void WebSocketBasicHandshakeStream::GetSSLCertRequestInfo( 235 SSLCertRequestInfo* cert_request_info) { 236 parser()->GetSSLCertRequestInfo(cert_request_info); 237 } 238 239 bool WebSocketBasicHandshakeStream::IsSpdyHttpStream() const { return false; } 240 241 void WebSocketBasicHandshakeStream::Drain(HttpNetworkSession* session) { 242 HttpResponseBodyDrainer* drainer = new HttpResponseBodyDrainer(this); 243 drainer->Start(session); 244 // |drainer| will delete itself. 245 } 246 247 void WebSocketBasicHandshakeStream::SetPriority(RequestPriority priority) { 248 // TODO(ricea): See TODO comment in HttpBasicStream::SetPriority(). If it is 249 // gone, then copy whatever has happened there over here. 250 } 251 252 scoped_ptr<WebSocketStream> WebSocketBasicHandshakeStream::Upgrade() { 253 // TODO(ricea): Add deflate support. 254 255 // The HttpStreamParser object has a pointer to our ClientSocketHandle. Make 256 // sure it does not touch it again before it is destroyed. 257 state_.DeleteParser(); 258 return scoped_ptr<WebSocketStream>( 259 new WebSocketBasicStream(state_.ReleaseConnection(), 260 state_.read_buf(), 261 sub_protocol_, 262 extensions_)); 263 } 264 265 void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting( 266 const std::string& key) { 267 handshake_challenge_for_testing_.reset(new std::string(key)); 268 } 269 270 void WebSocketBasicHandshakeStream::ReadResponseHeadersCallback( 271 const CompletionCallback& callback, 272 int result) { 273 if (result == OK) 274 result = ValidateResponse(); 275 callback.Run(result); 276 } 277 278 int WebSocketBasicHandshakeStream::ValidateResponse() { 279 DCHECK(http_response_info_); 280 const scoped_refptr<HttpResponseHeaders>& headers = 281 http_response_info_->headers; 282 283 switch (headers->response_code()) { 284 case HTTP_SWITCHING_PROTOCOLS: 285 return ValidateUpgradeResponse(headers); 286 287 // We need to pass these through for authentication to work. 288 case HTTP_UNAUTHORIZED: 289 case HTTP_PROXY_AUTHENTICATION_REQUIRED: 290 return OK; 291 292 // Other status codes are potentially risky (see the warnings in the 293 // WHATWG WebSocket API spec) and so are dropped by default. 294 default: 295 return ERR_INVALID_RESPONSE; 296 } 297 } 298 299 int WebSocketBasicHandshakeStream::ValidateUpgradeResponse( 300 const scoped_refptr<HttpResponseHeaders>& headers) { 301 if (ValidateSingleTokenHeader(headers, 302 websockets::kUpgrade, 303 websockets::kWebSocketLowercase, 304 false) && 305 ValidateSingleTokenHeader(headers, 306 websockets::kSecWebSocketAccept, 307 handshake_challenge_response_, 308 true) && 309 headers->HasHeaderValue(HttpRequestHeaders::kConnection, 310 websockets::kUpgrade) && 311 ValidateSubProtocol(headers, requested_sub_protocols_, &sub_protocol_) && 312 ValidateExtensions(headers, requested_extensions_, &extensions_)) { 313 return OK; 314 } 315 return ERR_INVALID_RESPONSE; 316 } 317 318 } // namespace net 319