1 // Copyright 2013 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/websockets/websocket_basic_handshake_stream.h" 6 7 #include <algorithm> 8 #include <iterator> 9 #include <set> 10 #include <string> 11 #include <vector> 12 13 #include "base/base64.h" 14 #include "base/basictypes.h" 15 #include "base/bind.h" 16 #include "base/containers/hash_tables.h" 17 #include "base/logging.h" 18 #include "base/metrics/histogram.h" 19 #include "base/metrics/sparse_histogram.h" 20 #include "base/stl_util.h" 21 #include "base/strings/string_number_conversions.h" 22 #include "base/strings/string_piece.h" 23 #include "base/strings/string_util.h" 24 #include "base/strings/stringprintf.h" 25 #include "base/time/time.h" 26 #include "crypto/random.h" 27 #include "net/http/http_request_headers.h" 28 #include "net/http/http_request_info.h" 29 #include "net/http/http_response_body_drainer.h" 30 #include "net/http/http_response_headers.h" 31 #include "net/http/http_status_code.h" 32 #include "net/http/http_stream_parser.h" 33 #include "net/socket/client_socket_handle.h" 34 #include "net/socket/websocket_transport_client_socket_pool.h" 35 #include "net/websockets/websocket_basic_stream.h" 36 #include "net/websockets/websocket_deflate_predictor.h" 37 #include "net/websockets/websocket_deflate_predictor_impl.h" 38 #include "net/websockets/websocket_deflate_stream.h" 39 #include "net/websockets/websocket_deflater.h" 40 #include "net/websockets/websocket_extension_parser.h" 41 #include "net/websockets/websocket_handshake_constants.h" 42 #include "net/websockets/websocket_handshake_handler.h" 43 #include "net/websockets/websocket_handshake_request_info.h" 44 #include "net/websockets/websocket_handshake_response_info.h" 45 #include "net/websockets/websocket_stream.h" 46 47 namespace net { 48 49 // TODO(ricea): If more extensions are added, replace this with a more general 50 // mechanism. 51 struct WebSocketExtensionParams { 52 WebSocketExtensionParams() 53 : deflate_enabled(false), 54 client_window_bits(15), 55 deflate_mode(WebSocketDeflater::TAKE_OVER_CONTEXT) {} 56 57 bool deflate_enabled; 58 int client_window_bits; 59 WebSocketDeflater::ContextTakeOverMode deflate_mode; 60 }; 61 62 namespace { 63 64 enum GetHeaderResult { 65 GET_HEADER_OK, 66 GET_HEADER_MISSING, 67 GET_HEADER_MULTIPLE, 68 }; 69 70 std::string MissingHeaderMessage(const std::string& header_name) { 71 return std::string("'") + header_name + "' header is missing"; 72 } 73 74 std::string MultipleHeaderValuesMessage(const std::string& header_name) { 75 return 76 std::string("'") + 77 header_name + 78 "' header must not appear more than once in a response"; 79 } 80 81 std::string GenerateHandshakeChallenge() { 82 std::string raw_challenge(websockets::kRawChallengeLength, '\0'); 83 crypto::RandBytes(string_as_array(&raw_challenge), raw_challenge.length()); 84 std::string encoded_challenge; 85 base::Base64Encode(raw_challenge, &encoded_challenge); 86 return encoded_challenge; 87 } 88 89 void AddVectorHeaderIfNonEmpty(const char* name, 90 const std::vector<std::string>& value, 91 HttpRequestHeaders* headers) { 92 if (value.empty()) 93 return; 94 headers->SetHeader(name, JoinString(value, ", ")); 95 } 96 97 GetHeaderResult GetSingleHeaderValue(const HttpResponseHeaders* headers, 98 const base::StringPiece& name, 99 std::string* value) { 100 void* state = NULL; 101 size_t num_values = 0; 102 std::string temp_value; 103 while (headers->EnumerateHeader(&state, name, &temp_value)) { 104 if (++num_values > 1) 105 return GET_HEADER_MULTIPLE; 106 *value = temp_value; 107 } 108 return num_values > 0 ? GET_HEADER_OK : GET_HEADER_MISSING; 109 } 110 111 bool ValidateHeaderHasSingleValue(GetHeaderResult result, 112 const std::string& header_name, 113 std::string* failure_message) { 114 if (result == GET_HEADER_MISSING) { 115 *failure_message = MissingHeaderMessage(header_name); 116 return false; 117 } 118 if (result == GET_HEADER_MULTIPLE) { 119 *failure_message = MultipleHeaderValuesMessage(header_name); 120 return false; 121 } 122 DCHECK_EQ(result, GET_HEADER_OK); 123 return true; 124 } 125 126 bool ValidateUpgrade(const HttpResponseHeaders* headers, 127 std::string* failure_message) { 128 std::string value; 129 GetHeaderResult result = 130 GetSingleHeaderValue(headers, websockets::kUpgrade, &value); 131 if (!ValidateHeaderHasSingleValue(result, 132 websockets::kUpgrade, 133 failure_message)) { 134 return false; 135 } 136 137 if (!LowerCaseEqualsASCII(value, websockets::kWebSocketLowercase)) { 138 *failure_message = 139 "'Upgrade' header value is not 'WebSocket': " + value; 140 return false; 141 } 142 return true; 143 } 144 145 bool ValidateSecWebSocketAccept(const HttpResponseHeaders* headers, 146 const std::string& expected, 147 std::string* failure_message) { 148 std::string actual; 149 GetHeaderResult result = 150 GetSingleHeaderValue(headers, websockets::kSecWebSocketAccept, &actual); 151 if (!ValidateHeaderHasSingleValue(result, 152 websockets::kSecWebSocketAccept, 153 failure_message)) { 154 return false; 155 } 156 157 if (expected != actual) { 158 *failure_message = "Incorrect 'Sec-WebSocket-Accept' header value"; 159 return false; 160 } 161 return true; 162 } 163 164 bool ValidateConnection(const HttpResponseHeaders* headers, 165 std::string* failure_message) { 166 // Connection header is permitted to contain other tokens. 167 if (!headers->HasHeader(HttpRequestHeaders::kConnection)) { 168 *failure_message = MissingHeaderMessage(HttpRequestHeaders::kConnection); 169 return false; 170 } 171 if (!headers->HasHeaderValue(HttpRequestHeaders::kConnection, 172 websockets::kUpgrade)) { 173 *failure_message = "'Connection' header value must contain 'Upgrade'"; 174 return false; 175 } 176 return true; 177 } 178 179 bool ValidateSubProtocol( 180 const HttpResponseHeaders* headers, 181 const std::vector<std::string>& requested_sub_protocols, 182 std::string* sub_protocol, 183 std::string* failure_message) { 184 void* state = NULL; 185 std::string value; 186 base::hash_set<std::string> requested_set(requested_sub_protocols.begin(), 187 requested_sub_protocols.end()); 188 int count = 0; 189 bool has_multiple_protocols = false; 190 bool has_invalid_protocol = false; 191 192 while (!has_invalid_protocol || !has_multiple_protocols) { 193 std::string temp_value; 194 if (!headers->EnumerateHeader( 195 &state, websockets::kSecWebSocketProtocol, &temp_value)) 196 break; 197 value = temp_value; 198 if (requested_set.count(value) == 0) 199 has_invalid_protocol = true; 200 if (++count > 1) 201 has_multiple_protocols = true; 202 } 203 204 if (has_multiple_protocols) { 205 *failure_message = 206 MultipleHeaderValuesMessage(websockets::kSecWebSocketProtocol); 207 return false; 208 } else if (count > 0 && requested_sub_protocols.size() == 0) { 209 *failure_message = 210 std::string("Response must not include 'Sec-WebSocket-Protocol' " 211 "header if not present in request: ") 212 + value; 213 return false; 214 } else if (has_invalid_protocol) { 215 *failure_message = 216 "'Sec-WebSocket-Protocol' header value '" + 217 value + 218 "' in response does not match any of sent values"; 219 return false; 220 } else if (requested_sub_protocols.size() > 0 && count == 0) { 221 *failure_message = 222 "Sent non-empty 'Sec-WebSocket-Protocol' header " 223 "but no response was received"; 224 return false; 225 } 226 *sub_protocol = value; 227 return true; 228 } 229 230 bool DeflateError(std::string* message, const base::StringPiece& piece) { 231 *message = "Error in permessage-deflate: "; 232 piece.AppendToString(message); 233 return false; 234 } 235 236 bool ValidatePerMessageDeflateExtension(const WebSocketExtension& extension, 237 std::string* failure_message, 238 WebSocketExtensionParams* params) { 239 static const char kClientPrefix[] = "client_"; 240 static const char kServerPrefix[] = "server_"; 241 static const char kNoContextTakeover[] = "no_context_takeover"; 242 static const char kMaxWindowBits[] = "max_window_bits"; 243 const size_t kPrefixLen = arraysize(kClientPrefix) - 1; 244 COMPILE_ASSERT(kPrefixLen == arraysize(kServerPrefix) - 1, 245 the_strings_server_and_client_must_be_the_same_length); 246 typedef std::vector<WebSocketExtension::Parameter> ParameterVector; 247 248 DCHECK_EQ("permessage-deflate", extension.name()); 249 const ParameterVector& parameters = extension.parameters(); 250 std::set<std::string> seen_names; 251 for (ParameterVector::const_iterator it = parameters.begin(); 252 it != parameters.end(); ++it) { 253 const std::string& name = it->name(); 254 if (seen_names.count(name) != 0) { 255 return DeflateError( 256 failure_message, 257 "Received duplicate permessage-deflate extension parameter " + name); 258 } 259 seen_names.insert(name); 260 const std::string client_or_server(name, 0, kPrefixLen); 261 const bool is_client = (client_or_server == kClientPrefix); 262 if (!is_client && client_or_server != kServerPrefix) { 263 return DeflateError( 264 failure_message, 265 "Received an unexpected permessage-deflate extension parameter"); 266 } 267 const std::string rest(name, kPrefixLen); 268 if (rest == kNoContextTakeover) { 269 if (it->HasValue()) { 270 return DeflateError(failure_message, 271 "Received invalid " + name + " parameter"); 272 } 273 if (is_client) 274 params->deflate_mode = WebSocketDeflater::DO_NOT_TAKE_OVER_CONTEXT; 275 } else if (rest == kMaxWindowBits) { 276 if (!it->HasValue()) 277 return DeflateError(failure_message, name + " must have value"); 278 int bits = 0; 279 if (!base::StringToInt(it->value(), &bits) || bits < 8 || bits > 15 || 280 it->value()[0] == '0' || 281 it->value().find_first_not_of("0123456789") != std::string::npos) { 282 return DeflateError(failure_message, 283 "Received invalid " + name + " parameter"); 284 } 285 if (is_client) 286 params->client_window_bits = bits; 287 } else { 288 return DeflateError( 289 failure_message, 290 "Received an unexpected permessage-deflate extension parameter"); 291 } 292 } 293 params->deflate_enabled = true; 294 return true; 295 } 296 297 bool ValidateExtensions(const HttpResponseHeaders* headers, 298 const std::vector<std::string>& requested_extensions, 299 std::string* extensions, 300 std::string* failure_message, 301 WebSocketExtensionParams* params) { 302 void* state = NULL; 303 std::string value; 304 std::vector<std::string> accepted_extensions; 305 // TODO(ricea): If adding support for additional extensions, generalise this 306 // code. 307 bool seen_permessage_deflate = false; 308 while (headers->EnumerateHeader( 309 &state, websockets::kSecWebSocketExtensions, &value)) { 310 WebSocketExtensionParser parser; 311 parser.Parse(value); 312 if (parser.has_error()) { 313 // TODO(yhirano) Set appropriate failure message. 314 *failure_message = 315 "'Sec-WebSocket-Extensions' header value is " 316 "rejected by the parser: " + 317 value; 318 return false; 319 } 320 if (parser.extension().name() == "permessage-deflate") { 321 if (seen_permessage_deflate) { 322 *failure_message = "Received duplicate permessage-deflate response"; 323 return false; 324 } 325 seen_permessage_deflate = true; 326 if (!ValidatePerMessageDeflateExtension( 327 parser.extension(), failure_message, params)) 328 return false; 329 } else { 330 *failure_message = 331 "Found an unsupported extension '" + 332 parser.extension().name() + 333 "' in 'Sec-WebSocket-Extensions' header"; 334 return false; 335 } 336 accepted_extensions.push_back(value); 337 } 338 *extensions = JoinString(accepted_extensions, ", "); 339 return true; 340 } 341 342 } // namespace 343 344 WebSocketBasicHandshakeStream::WebSocketBasicHandshakeStream( 345 scoped_ptr<ClientSocketHandle> connection, 346 WebSocketStream::ConnectDelegate* connect_delegate, 347 bool using_proxy, 348 std::vector<std::string> requested_sub_protocols, 349 std::vector<std::string> requested_extensions, 350 std::string* failure_message) 351 : state_(connection.release(), using_proxy), 352 connect_delegate_(connect_delegate), 353 http_response_info_(NULL), 354 requested_sub_protocols_(requested_sub_protocols), 355 requested_extensions_(requested_extensions), 356 failure_message_(failure_message) { 357 DCHECK(connect_delegate); 358 DCHECK(failure_message); 359 } 360 361 WebSocketBasicHandshakeStream::~WebSocketBasicHandshakeStream() {} 362 363 int WebSocketBasicHandshakeStream::InitializeStream( 364 const HttpRequestInfo* request_info, 365 RequestPriority priority, 366 const BoundNetLog& net_log, 367 const CompletionCallback& callback) { 368 url_ = request_info->url; 369 state_.Initialize(request_info, priority, net_log, callback); 370 return OK; 371 } 372 373 int WebSocketBasicHandshakeStream::SendRequest( 374 const HttpRequestHeaders& headers, 375 HttpResponseInfo* response, 376 const CompletionCallback& callback) { 377 DCHECK(!headers.HasHeader(websockets::kSecWebSocketKey)); 378 DCHECK(!headers.HasHeader(websockets::kSecWebSocketProtocol)); 379 DCHECK(!headers.HasHeader(websockets::kSecWebSocketExtensions)); 380 DCHECK(headers.HasHeader(HttpRequestHeaders::kOrigin)); 381 DCHECK(headers.HasHeader(websockets::kUpgrade)); 382 DCHECK(headers.HasHeader(HttpRequestHeaders::kConnection)); 383 DCHECK(headers.HasHeader(websockets::kSecWebSocketVersion)); 384 DCHECK(parser()); 385 386 http_response_info_ = response; 387 388 // Create a copy of the headers object, so that we can add the 389 // Sec-WebSockey-Key header. 390 HttpRequestHeaders enriched_headers; 391 enriched_headers.CopyFrom(headers); 392 std::string handshake_challenge; 393 if (handshake_challenge_for_testing_) { 394 handshake_challenge = *handshake_challenge_for_testing_; 395 handshake_challenge_for_testing_.reset(); 396 } else { 397 handshake_challenge = GenerateHandshakeChallenge(); 398 } 399 enriched_headers.SetHeader(websockets::kSecWebSocketKey, handshake_challenge); 400 401 AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketExtensions, 402 requested_extensions_, 403 &enriched_headers); 404 AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol, 405 requested_sub_protocols_, 406 &enriched_headers); 407 408 ComputeSecWebSocketAccept(handshake_challenge, 409 &handshake_challenge_response_); 410 411 DCHECK(connect_delegate_); 412 scoped_ptr<WebSocketHandshakeRequestInfo> request( 413 new WebSocketHandshakeRequestInfo(url_, base::Time::Now())); 414 request->headers.CopyFrom(enriched_headers); 415 connect_delegate_->OnStartOpeningHandshake(request.Pass()); 416 417 return parser()->SendRequest( 418 state_.GenerateRequestLine(), enriched_headers, response, callback); 419 } 420 421 int WebSocketBasicHandshakeStream::ReadResponseHeaders( 422 const CompletionCallback& callback) { 423 // HttpStreamParser uses a weak pointer when reading from the 424 // socket, so it won't be called back after being destroyed. The 425 // HttpStreamParser is owned by HttpBasicState which is owned by this object, 426 // so this use of base::Unretained() is safe. 427 int rv = parser()->ReadResponseHeaders( 428 base::Bind(&WebSocketBasicHandshakeStream::ReadResponseHeadersCallback, 429 base::Unretained(this), 430 callback)); 431 if (rv == ERR_IO_PENDING) 432 return rv; 433 return ValidateResponse(rv); 434 } 435 436 int WebSocketBasicHandshakeStream::ReadResponseBody( 437 IOBuffer* buf, 438 int buf_len, 439 const CompletionCallback& callback) { 440 return parser()->ReadResponseBody(buf, buf_len, callback); 441 } 442 443 void WebSocketBasicHandshakeStream::Close(bool not_reusable) { 444 // This class ignores the value of |not_reusable| and never lets the socket be 445 // re-used. 446 if (parser()) 447 parser()->Close(true); 448 } 449 450 bool WebSocketBasicHandshakeStream::IsResponseBodyComplete() const { 451 return parser()->IsResponseBodyComplete(); 452 } 453 454 bool WebSocketBasicHandshakeStream::CanFindEndOfResponse() const { 455 return parser() && parser()->CanFindEndOfResponse(); 456 } 457 458 bool WebSocketBasicHandshakeStream::IsConnectionReused() const { 459 return parser()->IsConnectionReused(); 460 } 461 462 void WebSocketBasicHandshakeStream::SetConnectionReused() { 463 parser()->SetConnectionReused(); 464 } 465 466 bool WebSocketBasicHandshakeStream::IsConnectionReusable() const { 467 return false; 468 } 469 470 int64 WebSocketBasicHandshakeStream::GetTotalReceivedBytes() const { 471 return 0; 472 } 473 474 bool WebSocketBasicHandshakeStream::GetLoadTimingInfo( 475 LoadTimingInfo* load_timing_info) const { 476 return state_.connection()->GetLoadTimingInfo(IsConnectionReused(), 477 load_timing_info); 478 } 479 480 void WebSocketBasicHandshakeStream::GetSSLInfo(SSLInfo* ssl_info) { 481 parser()->GetSSLInfo(ssl_info); 482 } 483 484 void WebSocketBasicHandshakeStream::GetSSLCertRequestInfo( 485 SSLCertRequestInfo* cert_request_info) { 486 parser()->GetSSLCertRequestInfo(cert_request_info); 487 } 488 489 bool WebSocketBasicHandshakeStream::IsSpdyHttpStream() const { return false; } 490 491 void WebSocketBasicHandshakeStream::Drain(HttpNetworkSession* session) { 492 HttpResponseBodyDrainer* drainer = new HttpResponseBodyDrainer(this); 493 drainer->Start(session); 494 // |drainer| will delete itself. 495 } 496 497 void WebSocketBasicHandshakeStream::SetPriority(RequestPriority priority) { 498 // TODO(ricea): See TODO comment in HttpBasicStream::SetPriority(). If it is 499 // gone, then copy whatever has happened there over here. 500 } 501 502 scoped_ptr<WebSocketStream> WebSocketBasicHandshakeStream::Upgrade() { 503 // The HttpStreamParser object has a pointer to our ClientSocketHandle. Make 504 // sure it does not touch it again before it is destroyed. 505 state_.DeleteParser(); 506 WebSocketTransportClientSocketPool::UnlockEndpoint(state_.connection()); 507 scoped_ptr<WebSocketStream> basic_stream( 508 new WebSocketBasicStream(state_.ReleaseConnection(), 509 state_.read_buf(), 510 sub_protocol_, 511 extensions_)); 512 DCHECK(extension_params_.get()); 513 if (extension_params_->deflate_enabled) { 514 UMA_HISTOGRAM_ENUMERATION( 515 "Net.WebSocket.DeflateMode", 516 extension_params_->deflate_mode, 517 WebSocketDeflater::NUM_CONTEXT_TAKEOVER_MODE_TYPES); 518 519 return scoped_ptr<WebSocketStream>( 520 new WebSocketDeflateStream(basic_stream.Pass(), 521 extension_params_->deflate_mode, 522 extension_params_->client_window_bits, 523 scoped_ptr<WebSocketDeflatePredictor>( 524 new WebSocketDeflatePredictorImpl))); 525 } else { 526 return basic_stream.Pass(); 527 } 528 } 529 530 void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting( 531 const std::string& key) { 532 handshake_challenge_for_testing_.reset(new std::string(key)); 533 } 534 535 void WebSocketBasicHandshakeStream::ReadResponseHeadersCallback( 536 const CompletionCallback& callback, 537 int result) { 538 callback.Run(ValidateResponse(result)); 539 } 540 541 void WebSocketBasicHandshakeStream::OnFinishOpeningHandshake() { 542 DCHECK(http_response_info_); 543 WebSocketDispatchOnFinishOpeningHandshake(connect_delegate_, 544 url_, 545 http_response_info_->headers, 546 http_response_info_->response_time); 547 } 548 549 int WebSocketBasicHandshakeStream::ValidateResponse(int rv) { 550 DCHECK(http_response_info_); 551 // Most net errors happen during connection, so they are not seen by this 552 // method. The histogram for error codes is created in 553 // Delegate::OnResponseStarted in websocket_stream.cc instead. 554 if (rv >= 0) { 555 const HttpResponseHeaders* headers = http_response_info_->headers.get(); 556 const int response_code = headers->response_code(); 557 UMA_HISTOGRAM_SPARSE_SLOWLY("Net.WebSocket.ResponseCode", response_code); 558 switch (response_code) { 559 case HTTP_SWITCHING_PROTOCOLS: 560 OnFinishOpeningHandshake(); 561 return ValidateUpgradeResponse(headers); 562 563 // We need to pass these through for authentication to work. 564 case HTTP_UNAUTHORIZED: 565 case HTTP_PROXY_AUTHENTICATION_REQUIRED: 566 return OK; 567 568 // Other status codes are potentially risky (see the warnings in the 569 // WHATWG WebSocket API spec) and so are dropped by default. 570 default: 571 // A WebSocket server cannot be using HTTP/0.9, so if we see version 572 // 0.9, it means the response was garbage. 573 // Reporting "Unexpected response code: 200" in this case is not 574 // helpful, so use a different error message. 575 if (headers->GetHttpVersion() == HttpVersion(0, 9)) { 576 set_failure_message( 577 "Error during WebSocket handshake: Invalid status line"); 578 } else { 579 set_failure_message(base::StringPrintf( 580 "Error during WebSocket handshake: Unexpected response code: %d", 581 headers->response_code())); 582 } 583 OnFinishOpeningHandshake(); 584 return ERR_INVALID_RESPONSE; 585 } 586 } else { 587 if (rv == ERR_EMPTY_RESPONSE) { 588 set_failure_message( 589 "Connection closed before receiving a handshake response"); 590 return rv; 591 } 592 set_failure_message(std::string("Error during WebSocket handshake: ") + 593 ErrorToString(rv)); 594 OnFinishOpeningHandshake(); 595 return rv; 596 } 597 } 598 599 int WebSocketBasicHandshakeStream::ValidateUpgradeResponse( 600 const HttpResponseHeaders* headers) { 601 extension_params_.reset(new WebSocketExtensionParams); 602 std::string failure_message; 603 if (ValidateUpgrade(headers, &failure_message) && 604 ValidateSecWebSocketAccept( 605 headers, handshake_challenge_response_, &failure_message) && 606 ValidateConnection(headers, &failure_message) && 607 ValidateSubProtocol(headers, 608 requested_sub_protocols_, 609 &sub_protocol_, 610 &failure_message) && 611 ValidateExtensions(headers, 612 requested_extensions_, 613 &extensions_, 614 &failure_message, 615 extension_params_.get())) { 616 return OK; 617 } 618 set_failure_message("Error during WebSocket handshake: " + failure_message); 619 return ERR_INVALID_RESPONSE; 620 } 621 622 void WebSocketBasicHandshakeStream::set_failure_message( 623 const std::string& failure_message) { 624 *failure_message_ = failure_message; 625 } 626 627 } // namespace net 628