1 // Copyright (c) 2010 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/websockets/websocket_handshake_handler.h" 6 7 #include "base/md5.h" 8 #include "base/string_piece.h" 9 #include "base/string_util.h" 10 #include "googleurl/src/gurl.h" 11 #include "net/http/http_response_headers.h" 12 #include "net/http/http_util.h" 13 14 namespace { 15 16 const size_t kRequestKey3Size = 8U; 17 const size_t kResponseKeySize = 16U; 18 19 void ParseHandshakeHeader( 20 const char* handshake_message, int len, 21 std::string* status_line, 22 std::string* headers) { 23 size_t i = base::StringPiece(handshake_message, len).find_first_of("\r\n"); 24 if (i == base::StringPiece::npos) { 25 *status_line = std::string(handshake_message, len); 26 *headers = ""; 27 return; 28 } 29 // |status_line| includes \r\n. 30 *status_line = std::string(handshake_message, i + 2); 31 32 int header_len = len - (i + 2) - 2; 33 if (header_len > 0) { 34 // |handshake_message| includes tailing \r\n\r\n. 35 // |headers| doesn't include 2nd \r\n. 36 *headers = std::string(handshake_message + i + 2, header_len); 37 } else { 38 *headers = ""; 39 } 40 } 41 42 void FetchHeaders(const std::string& headers, 43 const char* const headers_to_get[], 44 size_t headers_to_get_len, 45 std::vector<std::string>* values) { 46 net::HttpUtil::HeadersIterator iter(headers.begin(), headers.end(), "\r\n"); 47 while (iter.GetNext()) { 48 for (size_t i = 0; i < headers_to_get_len; i++) { 49 if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(), 50 headers_to_get[i])) { 51 values->push_back(iter.values()); 52 } 53 } 54 } 55 } 56 57 bool GetHeaderName(std::string::const_iterator line_begin, 58 std::string::const_iterator line_end, 59 std::string::const_iterator* name_begin, 60 std::string::const_iterator* name_end) { 61 std::string::const_iterator colon = std::find(line_begin, line_end, ':'); 62 if (colon == line_end) { 63 return false; 64 } 65 *name_begin = line_begin; 66 *name_end = colon; 67 if (*name_begin == *name_end || net::HttpUtil::IsLWS(**name_begin)) 68 return false; 69 net::HttpUtil::TrimLWS(name_begin, name_end); 70 return true; 71 } 72 73 // Similar to HttpUtil::StripHeaders, but it preserves malformed headers, that 74 // is, lines that are not formatted as "<name>: <value>\r\n". 75 std::string FilterHeaders( 76 const std::string& headers, 77 const char* const headers_to_remove[], 78 size_t headers_to_remove_len) { 79 std::string filtered_headers; 80 81 StringTokenizer lines(headers.begin(), headers.end(), "\r\n"); 82 while (lines.GetNext()) { 83 std::string::const_iterator line_begin = lines.token_begin(); 84 std::string::const_iterator line_end = lines.token_end(); 85 std::string::const_iterator name_begin; 86 std::string::const_iterator name_end; 87 bool should_remove = false; 88 if (GetHeaderName(line_begin, line_end, &name_begin, &name_end)) { 89 for (size_t i = 0; i < headers_to_remove_len; ++i) { 90 if (LowerCaseEqualsASCII(name_begin, name_end, headers_to_remove[i])) { 91 should_remove = true; 92 break; 93 } 94 } 95 } 96 if (!should_remove) { 97 filtered_headers.append(line_begin, line_end); 98 filtered_headers.append("\r\n"); 99 } 100 } 101 return filtered_headers; 102 } 103 104 // Gets a key number from |key| and appends the number to |challenge|. 105 // The key number (/part_N/) is extracted as step 4.-8. in 106 // 5.2. Sending the server's opening handshake of 107 // http://www.ietf.org/id/draft-ietf-hybi-thewebsocketprotocol-00.txt 108 void GetKeyNumber(const std::string& key, std::string* challenge) { 109 uint32 key_number = 0; 110 uint32 spaces = 0; 111 for (size_t i = 0; i < key.size(); ++i) { 112 if (isdigit(key[i])) { 113 // key_number should not overflow. (it comes from 114 // WebCore/websockets/WebSocketHandshake.cpp). 115 key_number = key_number * 10 + key[i] - '0'; 116 } else if (key[i] == ' ') { 117 ++spaces; 118 } 119 } 120 // spaces should not be zero in valid handshake request. 121 if (spaces == 0) 122 return; 123 key_number /= spaces; 124 125 char part[4]; 126 for (int i = 0; i < 4; i++) { 127 part[3 - i] = key_number & 0xFF; 128 key_number >>= 8; 129 } 130 challenge->append(part, 4); 131 } 132 133 } // anonymous namespace 134 135 namespace net { 136 137 WebSocketHandshakeRequestHandler::WebSocketHandshakeRequestHandler() 138 : original_length_(0), 139 raw_length_(0) {} 140 141 bool WebSocketHandshakeRequestHandler::ParseRequest( 142 const char* data, int length) { 143 DCHECK_GT(length, 0); 144 std::string input(data, length); 145 int input_header_length = 146 HttpUtil::LocateEndOfHeaders(input.data(), input.size(), 0); 147 if (input_header_length <= 0 || 148 input_header_length + kRequestKey3Size > input.size()) 149 return false; 150 151 ParseHandshakeHeader(input.data(), 152 input_header_length, 153 &status_line_, 154 &headers_); 155 156 // draft-hixie-thewebsocketprotocol-76 or later will send /key3/ 157 // after handshake request header. 158 // Assumes WebKit doesn't send any data after handshake request message 159 // until handshake is finished. 160 // Thus, |key3_| is part of handshake message, and not in part 161 // of WebSocket frame stream. 162 DCHECK_EQ(kRequestKey3Size, 163 input.size() - 164 input_header_length); 165 key3_ = std::string(input.data() + input_header_length, 166 input.size() - input_header_length); 167 original_length_ = input.size(); 168 return true; 169 } 170 171 size_t WebSocketHandshakeRequestHandler::original_length() const { 172 return original_length_; 173 } 174 175 void WebSocketHandshakeRequestHandler::AppendHeaderIfMissing( 176 const std::string& name, const std::string& value) { 177 DCHECK(!headers_.empty()); 178 HttpUtil::AppendHeaderIfMissing(name.c_str(), value, &headers_); 179 } 180 181 void WebSocketHandshakeRequestHandler::RemoveHeaders( 182 const char* const headers_to_remove[], 183 size_t headers_to_remove_len) { 184 DCHECK(!headers_.empty()); 185 headers_ = FilterHeaders( 186 headers_, headers_to_remove, headers_to_remove_len); 187 } 188 189 HttpRequestInfo WebSocketHandshakeRequestHandler::GetRequestInfo( 190 const GURL& url, std::string* challenge) { 191 HttpRequestInfo request_info; 192 request_info.url = url; 193 base::StringPiece method = status_line_.data(); 194 size_t method_end = base::StringPiece( 195 status_line_.data(), status_line_.size()).find_first_of(" "); 196 if (method_end != base::StringPiece::npos) 197 request_info.method = std::string(status_line_.data(), method_end); 198 199 request_info.extra_headers.Clear(); 200 request_info.extra_headers.AddHeadersFromString(headers_); 201 202 request_info.extra_headers.RemoveHeader("Upgrade"); 203 request_info.extra_headers.RemoveHeader("Connection"); 204 205 challenge->clear(); 206 std::string key; 207 request_info.extra_headers.GetHeader("Sec-WebSocket-Key1", &key); 208 request_info.extra_headers.RemoveHeader("Sec-WebSocket-Key1"); 209 GetKeyNumber(key, challenge); 210 211 request_info.extra_headers.GetHeader("Sec-WebSocket-Key2", &key); 212 request_info.extra_headers.RemoveHeader("Sec-WebSocket-Key2"); 213 GetKeyNumber(key, challenge); 214 215 challenge->append(key3_); 216 217 return request_info; 218 } 219 220 bool WebSocketHandshakeRequestHandler::GetRequestHeaderBlock( 221 const GURL& url, spdy::SpdyHeaderBlock* headers, std::string* challenge) { 222 // We don't set "method" and "version". These are fixed value in WebSocket 223 // protocol. 224 (*headers)["url"] = url.spec(); 225 226 std::string key1; 227 std::string key2; 228 HttpUtil::HeadersIterator iter(headers_.begin(), headers_.end(), "\r\n"); 229 while (iter.GetNext()) { 230 if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(), 231 "connection")) { 232 // Ignore "Connection" header. 233 continue; 234 } else if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(), 235 "upgrade")) { 236 // Ignore "Upgrade" header. 237 continue; 238 } else if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(), 239 "sec-websocket-key1")) { 240 // Use only for generating challenge. 241 key1 = iter.values(); 242 continue; 243 } else if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(), 244 "sec-websocket-key2")) { 245 // Use only for generating challenge. 246 key2 = iter.values(); 247 continue; 248 } 249 // Others should be sent out to |headers|. 250 std::string name = StringToLowerASCII(iter.name()); 251 spdy::SpdyHeaderBlock::iterator found = headers->find(name); 252 if (found == headers->end()) { 253 (*headers)[name] = iter.values(); 254 } else { 255 // For now, websocket doesn't use multiple headers, but follows to http. 256 found->second.append(1, '\0'); // +=() doesn't append 0's 257 found->second.append(iter.values()); 258 } 259 } 260 261 challenge->clear(); 262 GetKeyNumber(key1, challenge); 263 GetKeyNumber(key2, challenge); 264 challenge->append(key3_); 265 266 return true; 267 } 268 269 std::string WebSocketHandshakeRequestHandler::GetRawRequest() { 270 DCHECK(!status_line_.empty()); 271 DCHECK(!headers_.empty()); 272 DCHECK_EQ(kRequestKey3Size, key3_.size()); 273 std::string raw_request = status_line_ + headers_ + "\r\n" + key3_; 274 raw_length_ = raw_request.size(); 275 return raw_request; 276 } 277 278 size_t WebSocketHandshakeRequestHandler::raw_length() const { 279 DCHECK_GT(raw_length_, 0); 280 return raw_length_; 281 } 282 283 WebSocketHandshakeResponseHandler::WebSocketHandshakeResponseHandler() 284 : original_header_length_(0) { 285 } 286 287 WebSocketHandshakeResponseHandler::~WebSocketHandshakeResponseHandler() {} 288 289 size_t WebSocketHandshakeResponseHandler::ParseRawResponse( 290 const char* data, int length) { 291 DCHECK_GT(length, 0); 292 if (HasResponse()) { 293 DCHECK(!status_line_.empty()); 294 DCHECK(!headers_.empty()); 295 DCHECK_EQ(kResponseKeySize, key_.size()); 296 return 0; 297 } 298 299 size_t old_original_length = original_.size(); 300 301 original_.append(data, length); 302 // TODO(ukai): fail fast when response gives wrong status code. 303 original_header_length_ = HttpUtil::LocateEndOfHeaders( 304 original_.data(), original_.size(), 0); 305 if (!HasResponse()) 306 return length; 307 308 ParseHandshakeHeader(original_.data(), 309 original_header_length_, 310 &status_line_, 311 &headers_); 312 int header_size = status_line_.size() + headers_.size(); 313 DCHECK_GE(original_header_length_, header_size); 314 header_separator_ = std::string(original_.data() + header_size, 315 original_header_length_ - header_size); 316 key_ = std::string(original_.data() + original_header_length_, 317 kResponseKeySize); 318 319 return original_header_length_ + kResponseKeySize - old_original_length; 320 } 321 322 bool WebSocketHandshakeResponseHandler::HasResponse() const { 323 return original_header_length_ > 0 && 324 original_header_length_ + kResponseKeySize <= original_.size(); 325 } 326 327 bool WebSocketHandshakeResponseHandler::ParseResponseInfo( 328 const HttpResponseInfo& response_info, 329 const std::string& challenge) { 330 if (!response_info.headers.get()) 331 return false; 332 333 std::string response_message; 334 response_message = response_info.headers->GetStatusLine(); 335 response_message += "\r\n"; 336 response_message += "Upgrade: WebSocket\r\n"; 337 response_message += "Connection: Upgrade\r\n"; 338 void* iter = NULL; 339 std::string name; 340 std::string value; 341 while (response_info.headers->EnumerateHeaderLines(&iter, &name, &value)) { 342 response_message += name + ": " + value + "\r\n"; 343 } 344 response_message += "\r\n"; 345 346 MD5Digest digest; 347 MD5Sum(challenge.data(), challenge.size(), &digest); 348 349 const char* digest_data = reinterpret_cast<char*>(digest.a); 350 response_message.append(digest_data, sizeof(digest.a)); 351 352 return ParseRawResponse(response_message.data(), 353 response_message.size()) == response_message.size(); 354 } 355 356 bool WebSocketHandshakeResponseHandler::ParseResponseHeaderBlock( 357 const spdy::SpdyHeaderBlock& headers, 358 const std::string& challenge) { 359 std::string response_message; 360 response_message = "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"; 361 response_message += "Upgrade: WebSocket\r\n"; 362 response_message += "Connection: Upgrade\r\n"; 363 for (spdy::SpdyHeaderBlock::const_iterator iter = headers.begin(); 364 iter != headers.end(); 365 ++iter) { 366 // For each value, if the server sends a NUL-separated list of values, 367 // we separate that back out into individual headers for each value 368 // in the list. 369 const std::string& value = iter->second; 370 size_t start = 0; 371 size_t end = 0; 372 do { 373 end = value.find('\0', start); 374 std::string tval; 375 if (end != std::string::npos) 376 tval = value.substr(start, (end - start)); 377 else 378 tval = value.substr(start); 379 response_message += iter->first + ": " + tval + "\r\n"; 380 start = end + 1; 381 } while (end != std::string::npos); 382 } 383 response_message += "\r\n"; 384 385 MD5Digest digest; 386 MD5Sum(challenge.data(), challenge.size(), &digest); 387 388 const char* digest_data = reinterpret_cast<char*>(digest.a); 389 response_message.append(digest_data, sizeof(digest.a)); 390 391 return ParseRawResponse(response_message.data(), 392 response_message.size()) == response_message.size(); 393 } 394 395 void WebSocketHandshakeResponseHandler::GetHeaders( 396 const char* const headers_to_get[], 397 size_t headers_to_get_len, 398 std::vector<std::string>* values) { 399 DCHECK(HasResponse()); 400 DCHECK(!status_line_.empty()); 401 DCHECK(!headers_.empty()); 402 DCHECK_EQ(kResponseKeySize, key_.size()); 403 404 FetchHeaders(headers_, headers_to_get, headers_to_get_len, values); 405 } 406 407 void WebSocketHandshakeResponseHandler::RemoveHeaders( 408 const char* const headers_to_remove[], 409 size_t headers_to_remove_len) { 410 DCHECK(HasResponse()); 411 DCHECK(!status_line_.empty()); 412 DCHECK(!headers_.empty()); 413 DCHECK_EQ(kResponseKeySize, key_.size()); 414 415 headers_ = FilterHeaders(headers_, headers_to_remove, headers_to_remove_len); 416 } 417 418 std::string WebSocketHandshakeResponseHandler::GetRawResponse() const { 419 DCHECK(HasResponse()); 420 return std::string(original_.data(), 421 original_header_length_ + kResponseKeySize); 422 } 423 424 std::string WebSocketHandshakeResponseHandler::GetResponse() { 425 DCHECK(HasResponse()); 426 DCHECK(!status_line_.empty()); 427 // headers_ might be empty for wrong response from server. 428 DCHECK_EQ(kResponseKeySize, key_.size()); 429 430 return status_line_ + headers_ + header_separator_ + key_; 431 } 432 433 } // namespace net 434