1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/websockets/websocket_handshake_handler.h" 6 7 #include <limits> 8 9 #include "base/base64.h" 10 #include "base/sha1.h" 11 #include "base/strings/string_number_conversions.h" 12 #include "base/strings/string_piece.h" 13 #include "base/strings/string_tokenizer.h" 14 #include "base/strings/string_util.h" 15 #include "base/strings/stringprintf.h" 16 #include "net/http/http_request_headers.h" 17 #include "net/http/http_response_headers.h" 18 #include "net/http/http_util.h" 19 #include "net/websockets/websocket_handshake_constants.h" 20 #include "url/gurl.h" 21 22 namespace net { 23 namespace { 24 25 const int kVersionHeaderValueForRFC6455 = 13; 26 27 // Splits |handshake_message| into Status-Line or Request-Line (including CRLF) 28 // and headers (excluding 2nd CRLF of double CRLFs at the end of a handshake 29 // response). 30 void ParseHandshakeHeader( 31 const char* handshake_message, int len, 32 std::string* request_line, 33 std::string* headers) { 34 size_t i = base::StringPiece(handshake_message, len).find_first_of("\r\n"); 35 if (i == base::StringPiece::npos) { 36 *request_line = std::string(handshake_message, len); 37 *headers = ""; 38 return; 39 } 40 // |request_line| includes \r\n. 41 *request_line = std::string(handshake_message, i + 2); 42 43 int header_len = len - (i + 2) - 2; 44 if (header_len > 0) { 45 // |handshake_message| includes trailing \r\n\r\n. 46 // |headers| doesn't include 2nd \r\n. 47 *headers = std::string(handshake_message + i + 2, header_len); 48 } else { 49 *headers = ""; 50 } 51 } 52 53 void FetchHeaders(const std::string& headers, 54 const char* const headers_to_get[], 55 size_t headers_to_get_len, 56 std::vector<std::string>* values) { 57 net::HttpUtil::HeadersIterator iter(headers.begin(), headers.end(), "\r\n"); 58 while (iter.GetNext()) { 59 for (size_t i = 0; i < headers_to_get_len; i++) { 60 if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(), 61 headers_to_get[i])) { 62 values->push_back(iter.values()); 63 } 64 } 65 } 66 } 67 68 bool GetHeaderName(std::string::const_iterator line_begin, 69 std::string::const_iterator line_end, 70 std::string::const_iterator* name_begin, 71 std::string::const_iterator* name_end) { 72 std::string::const_iterator colon = std::find(line_begin, line_end, ':'); 73 if (colon == line_end) { 74 return false; 75 } 76 *name_begin = line_begin; 77 *name_end = colon; 78 if (*name_begin == *name_end || net::HttpUtil::IsLWS(**name_begin)) 79 return false; 80 net::HttpUtil::TrimLWS(name_begin, name_end); 81 return true; 82 } 83 84 // Similar to HttpUtil::StripHeaders, but it preserves malformed headers, that 85 // is, lines that are not formatted as "<name>: <value>\r\n". 86 std::string FilterHeaders( 87 const std::string& headers, 88 const char* const headers_to_remove[], 89 size_t headers_to_remove_len) { 90 std::string filtered_headers; 91 92 base::StringTokenizer lines(headers.begin(), headers.end(), "\r\n"); 93 while (lines.GetNext()) { 94 std::string::const_iterator line_begin = lines.token_begin(); 95 std::string::const_iterator line_end = lines.token_end(); 96 std::string::const_iterator name_begin; 97 std::string::const_iterator name_end; 98 bool should_remove = false; 99 if (GetHeaderName(line_begin, line_end, &name_begin, &name_end)) { 100 for (size_t i = 0; i < headers_to_remove_len; ++i) { 101 if (LowerCaseEqualsASCII(name_begin, name_end, headers_to_remove[i])) { 102 should_remove = true; 103 break; 104 } 105 } 106 } 107 if (!should_remove) { 108 filtered_headers.append(line_begin, line_end); 109 filtered_headers.append("\r\n"); 110 } 111 } 112 return filtered_headers; 113 } 114 115 bool CheckVersionInRequest(const std::string& request_headers) { 116 std::vector<std::string> values; 117 const char* const headers_to_get[1] = { 118 websockets::kSecWebSocketVersionLowercase}; 119 FetchHeaders(request_headers, headers_to_get, 1, &values); 120 DCHECK_LE(values.size(), 1U); 121 if (values.empty()) 122 return false; 123 124 int version; 125 bool conversion_success = base::StringToInt(values[0], &version); 126 if (!conversion_success) 127 return false; 128 129 return version == kVersionHeaderValueForRFC6455; 130 } 131 132 // Append a header to a string. Equivalent to 133 // response_message += header + ": " + value + "\r\n" 134 // but avoids unnecessary allocations and copies. 135 void AppendHeader(const base::StringPiece& header, 136 const base::StringPiece& value, 137 std::string* response_message) { 138 static const char kColonSpace[] = ": "; 139 const size_t kColonSpaceSize = sizeof(kColonSpace) - 1; 140 static const char kCrNl[] = "\r\n"; 141 const size_t kCrNlSize = sizeof(kCrNl) - 1; 142 143 size_t extra_size = 144 header.size() + kColonSpaceSize + value.size() + kCrNlSize; 145 response_message->reserve(response_message->size() + extra_size); 146 response_message->append(header.begin(), header.end()); 147 response_message->append(kColonSpace, kColonSpace + kColonSpaceSize); 148 response_message->append(value.begin(), value.end()); 149 response_message->append(kCrNl, kCrNl + kCrNlSize); 150 } 151 152 } // namespace 153 154 WebSocketHandshakeRequestHandler::WebSocketHandshakeRequestHandler() 155 : original_length_(0), 156 raw_length_(0) {} 157 158 bool WebSocketHandshakeRequestHandler::ParseRequest( 159 const char* data, int length) { 160 DCHECK_GT(length, 0); 161 std::string input(data, length); 162 int input_header_length = 163 HttpUtil::LocateEndOfHeaders(input.data(), input.size(), 0); 164 if (input_header_length <= 0) 165 return false; 166 167 ParseHandshakeHeader(input.data(), 168 input_header_length, 169 &request_line_, 170 &headers_); 171 172 if (!CheckVersionInRequest(headers_)) { 173 NOTREACHED(); 174 return false; 175 } 176 177 original_length_ = input_header_length; 178 return true; 179 } 180 181 size_t WebSocketHandshakeRequestHandler::original_length() const { 182 return original_length_; 183 } 184 185 void WebSocketHandshakeRequestHandler::AppendHeaderIfMissing( 186 const std::string& name, const std::string& value) { 187 DCHECK(!headers_.empty()); 188 HttpUtil::AppendHeaderIfMissing(name.c_str(), value, &headers_); 189 } 190 191 void WebSocketHandshakeRequestHandler::RemoveHeaders( 192 const char* const headers_to_remove[], 193 size_t headers_to_remove_len) { 194 DCHECK(!headers_.empty()); 195 headers_ = FilterHeaders( 196 headers_, headers_to_remove, headers_to_remove_len); 197 } 198 199 HttpRequestInfo WebSocketHandshakeRequestHandler::GetRequestInfo( 200 const GURL& url, std::string* challenge) { 201 HttpRequestInfo request_info; 202 request_info.url = url; 203 size_t method_end = base::StringPiece(request_line_).find_first_of(" "); 204 if (method_end != base::StringPiece::npos) 205 request_info.method = std::string(request_line_.data(), method_end); 206 207 request_info.extra_headers.Clear(); 208 request_info.extra_headers.AddHeadersFromString(headers_); 209 210 request_info.extra_headers.RemoveHeader(websockets::kUpgrade); 211 request_info.extra_headers.RemoveHeader(HttpRequestHeaders::kConnection); 212 213 std::string key; 214 bool header_present = request_info.extra_headers.GetHeader( 215 websockets::kSecWebSocketKey, &key); 216 DCHECK(header_present); 217 request_info.extra_headers.RemoveHeader(websockets::kSecWebSocketKey); 218 *challenge = key; 219 return request_info; 220 } 221 222 bool WebSocketHandshakeRequestHandler::GetRequestHeaderBlock( 223 const GURL& url, 224 SpdyHeaderBlock* headers, 225 std::string* challenge, 226 int spdy_protocol_version) { 227 // Construct opening handshake request headers as a SPDY header block. 228 // For details, see WebSocket Layering over SPDY/3 Draft 8. 229 if (spdy_protocol_version <= 2) { 230 (*headers)["path"] = url.path(); 231 (*headers)["version"] = "WebSocket/13"; 232 (*headers)["scheme"] = url.scheme(); 233 } else { 234 (*headers)[":path"] = url.path(); 235 (*headers)[":version"] = "WebSocket/13"; 236 (*headers)[":scheme"] = url.scheme(); 237 } 238 239 HttpUtil::HeadersIterator iter(headers_.begin(), headers_.end(), "\r\n"); 240 while (iter.GetNext()) { 241 if (LowerCaseEqualsASCII(iter.name_begin(), 242 iter.name_end(), 243 websockets::kUpgradeLowercase) || 244 LowerCaseEqualsASCII( 245 iter.name_begin(), iter.name_end(), "connection") || 246 LowerCaseEqualsASCII(iter.name_begin(), 247 iter.name_end(), 248 websockets::kSecWebSocketVersionLowercase)) { 249 // These headers must be ignored. 250 continue; 251 } else if (LowerCaseEqualsASCII(iter.name_begin(), 252 iter.name_end(), 253 websockets::kSecWebSocketKeyLowercase)) { 254 *challenge = iter.values(); 255 // Sec-WebSocket-Key is not sent to the server. 256 continue; 257 } else if (LowerCaseEqualsASCII( 258 iter.name_begin(), iter.name_end(), "host") || 259 LowerCaseEqualsASCII( 260 iter.name_begin(), iter.name_end(), "origin") || 261 LowerCaseEqualsASCII( 262 iter.name_begin(), 263 iter.name_end(), 264 websockets::kSecWebSocketProtocolLowercase) || 265 LowerCaseEqualsASCII( 266 iter.name_begin(), 267 iter.name_end(), 268 websockets::kSecWebSocketExtensionsLowercase)) { 269 // TODO(toyoshim): Some WebSocket extensions may not be compatible with 270 // SPDY. We should omit them from a Sec-WebSocket-Extension header. 271 std::string name; 272 if (spdy_protocol_version <= 2) 273 name = base::StringToLowerASCII(iter.name()); 274 else 275 name = ":" + base::StringToLowerASCII(iter.name()); 276 (*headers)[name] = iter.values(); 277 continue; 278 } 279 // Others should be sent out to |headers|. 280 std::string name = base::StringToLowerASCII(iter.name()); 281 SpdyHeaderBlock::iterator found = headers->find(name); 282 if (found == headers->end()) { 283 (*headers)[name] = iter.values(); 284 } else { 285 // For now, websocket doesn't use multiple headers, but follows to http. 286 found->second.append(1, '\0'); // +=() doesn't append 0's 287 found->second.append(iter.values()); 288 } 289 } 290 291 return true; 292 } 293 294 std::string WebSocketHandshakeRequestHandler::GetRawRequest() { 295 DCHECK(!request_line_.empty()); 296 DCHECK(!headers_.empty()); 297 298 std::string raw_request = request_line_ + headers_ + "\r\n"; 299 raw_length_ = raw_request.size(); 300 return raw_request; 301 } 302 303 size_t WebSocketHandshakeRequestHandler::raw_length() const { 304 DCHECK_GT(raw_length_, 0); 305 return raw_length_; 306 } 307 308 WebSocketHandshakeResponseHandler::WebSocketHandshakeResponseHandler() 309 : original_header_length_(0) {} 310 311 WebSocketHandshakeResponseHandler::~WebSocketHandshakeResponseHandler() {} 312 313 size_t WebSocketHandshakeResponseHandler::ParseRawResponse( 314 const char* data, int length) { 315 DCHECK_GT(length, 0); 316 if (HasResponse()) { 317 DCHECK(!status_line_.empty()); 318 // headers_ might be empty for wrong response from server. 319 320 return 0; 321 } 322 323 size_t old_original_length = original_.size(); 324 325 original_.append(data, length); 326 // TODO(ukai): fail fast when response gives wrong status code. 327 original_header_length_ = HttpUtil::LocateEndOfHeaders( 328 original_.data(), original_.size(), 0); 329 if (!HasResponse()) 330 return length; 331 332 ParseHandshakeHeader(original_.data(), 333 original_header_length_, 334 &status_line_, 335 &headers_); 336 int header_size = status_line_.size() + headers_.size(); 337 DCHECK_GE(original_header_length_, header_size); 338 header_separator_ = std::string(original_.data() + header_size, 339 original_header_length_ - header_size); 340 return original_header_length_ - old_original_length; 341 } 342 343 bool WebSocketHandshakeResponseHandler::HasResponse() const { 344 return original_header_length_ > 0 && 345 static_cast<size_t>(original_header_length_) <= original_.size(); 346 } 347 348 void ComputeSecWebSocketAccept(const std::string& key, 349 std::string* accept) { 350 DCHECK(accept); 351 352 std::string hash = 353 base::SHA1HashString(key + websockets::kWebSocketGuid); 354 base::Base64Encode(hash, accept); 355 } 356 357 bool WebSocketHandshakeResponseHandler::ParseResponseInfo( 358 const HttpResponseInfo& response_info, 359 const std::string& challenge) { 360 if (!response_info.headers.get()) 361 return false; 362 363 // TODO(ricea): Eliminate all the reallocations and string copies. 364 std::string response_message; 365 response_message = response_info.headers->GetStatusLine(); 366 response_message += "\r\n"; 367 368 AppendHeader(websockets::kUpgrade, 369 websockets::kWebSocketLowercase, 370 &response_message); 371 372 AppendHeader( 373 HttpRequestHeaders::kConnection, websockets::kUpgrade, &response_message); 374 375 std::string websocket_accept; 376 ComputeSecWebSocketAccept(challenge, &websocket_accept); 377 AppendHeader( 378 websockets::kSecWebSocketAccept, websocket_accept, &response_message); 379 380 void* iter = NULL; 381 std::string name; 382 std::string value; 383 while (response_info.headers->EnumerateHeaderLines(&iter, &name, &value)) { 384 AppendHeader(name, value, &response_message); 385 } 386 response_message += "\r\n"; 387 388 return ParseRawResponse(response_message.data(), 389 response_message.size()) == response_message.size(); 390 } 391 392 bool WebSocketHandshakeResponseHandler::ParseResponseHeaderBlock( 393 const SpdyHeaderBlock& headers, 394 const std::string& challenge, 395 int spdy_protocol_version) { 396 SpdyHeaderBlock::const_iterator status; 397 if (spdy_protocol_version <= 2) 398 status = headers.find("status"); 399 else 400 status = headers.find(":status"); 401 if (status == headers.end()) 402 return false; 403 404 std::string hash = 405 base::SHA1HashString(challenge + websockets::kWebSocketGuid); 406 std::string websocket_accept; 407 base::Base64Encode(hash, &websocket_accept); 408 409 std::string response_message = base::StringPrintf( 410 "%s %s\r\n", websockets::kHttpProtocolVersion, status->second.c_str()); 411 412 AppendHeader( 413 websockets::kUpgrade, websockets::kWebSocketLowercase, &response_message); 414 AppendHeader( 415 HttpRequestHeaders::kConnection, websockets::kUpgrade, &response_message); 416 AppendHeader( 417 websockets::kSecWebSocketAccept, websocket_accept, &response_message); 418 419 for (SpdyHeaderBlock::const_iterator iter = headers.begin(); 420 iter != headers.end(); 421 ++iter) { 422 // For each value, if the server sends a NUL-separated list of values, 423 // we separate that back out into individual headers for each value 424 // in the list. 425 if ((spdy_protocol_version <= 2 && 426 LowerCaseEqualsASCII(iter->first, "status")) || 427 (spdy_protocol_version >= 3 && 428 LowerCaseEqualsASCII(iter->first, ":status"))) { 429 // The status value is already handled as the first line of 430 // |response_message|. Just skip here. 431 continue; 432 } 433 const std::string& value = iter->second; 434 size_t start = 0; 435 size_t end = 0; 436 do { 437 end = value.find('\0', start); 438 std::string tval; 439 if (end != std::string::npos) 440 tval = value.substr(start, (end - start)); 441 else 442 tval = value.substr(start); 443 if (spdy_protocol_version >= 3 && 444 (LowerCaseEqualsASCII(iter->first, 445 websockets::kSecWebSocketProtocolSpdy3) || 446 LowerCaseEqualsASCII(iter->first, 447 websockets::kSecWebSocketExtensionsSpdy3))) 448 AppendHeader(iter->first.substr(1), tval, &response_message); 449 else 450 AppendHeader(iter->first, tval, &response_message); 451 start = end + 1; 452 } while (end != std::string::npos); 453 } 454 response_message += "\r\n"; 455 456 return ParseRawResponse(response_message.data(), 457 response_message.size()) == response_message.size(); 458 } 459 460 void WebSocketHandshakeResponseHandler::GetHeaders( 461 const char* const headers_to_get[], 462 size_t headers_to_get_len, 463 std::vector<std::string>* values) { 464 DCHECK(HasResponse()); 465 DCHECK(!status_line_.empty()); 466 // headers_ might be empty for wrong response from server. 467 if (headers_.empty()) 468 return; 469 470 FetchHeaders(headers_, headers_to_get, headers_to_get_len, values); 471 } 472 473 void WebSocketHandshakeResponseHandler::RemoveHeaders( 474 const char* const headers_to_remove[], 475 size_t headers_to_remove_len) { 476 DCHECK(HasResponse()); 477 DCHECK(!status_line_.empty()); 478 // headers_ might be empty for wrong response from server. 479 if (headers_.empty()) 480 return; 481 482 headers_ = FilterHeaders(headers_, headers_to_remove, headers_to_remove_len); 483 } 484 485 std::string WebSocketHandshakeResponseHandler::GetRawResponse() const { 486 DCHECK(HasResponse()); 487 return original_.substr(0, original_header_length_); 488 } 489 490 std::string WebSocketHandshakeResponseHandler::GetResponse() { 491 DCHECK(HasResponse()); 492 DCHECK(!status_line_.empty()); 493 // headers_ might be empty for wrong response from server. 494 495 return status_line_ + headers_ + header_separator_; 496 } 497 498 } // namespace net 499