1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/server/web_socket.h" 6 7 #include <limits> 8 9 #include "base/base64.h" 10 #include "base/rand_util.h" 11 #include "base/logging.h" 12 #include "base/md5.h" 13 #include "base/sha1.h" 14 #include "base/strings/string_number_conversions.h" 15 #include "base/strings/stringprintf.h" 16 #include "base/sys_byteorder.h" 17 #include "net/server/http_connection.h" 18 #include "net/server/http_server.h" 19 #include "net/server/http_server_request_info.h" 20 #include "net/server/http_server_response_info.h" 21 22 namespace net { 23 24 namespace { 25 26 static uint32 WebSocketKeyFingerprint(const std::string& str) { 27 std::string result; 28 const char* p_char = str.c_str(); 29 int length = str.length(); 30 int spaces = 0; 31 for (int i = 0; i < length; ++i) { 32 if (p_char[i] >= '0' && p_char[i] <= '9') 33 result.append(&p_char[i], 1); 34 else if (p_char[i] == ' ') 35 spaces++; 36 } 37 if (spaces == 0) 38 return 0; 39 int64 number = 0; 40 if (!base::StringToInt64(result, &number)) 41 return 0; 42 return base::HostToNet32(static_cast<uint32>(number / spaces)); 43 } 44 45 class WebSocketHixie76 : public net::WebSocket { 46 public: 47 static net::WebSocket* Create(HttpServer* server, 48 HttpConnection* connection, 49 const HttpServerRequestInfo& request, 50 size_t* pos) { 51 if (connection->read_buf()->GetSize() < 52 static_cast<int>(*pos + kWebSocketHandshakeBodyLen)) 53 return NULL; 54 return new WebSocketHixie76(server, connection, request, pos); 55 } 56 57 virtual void Accept(const HttpServerRequestInfo& request) OVERRIDE { 58 std::string key1 = request.GetHeaderValue("sec-websocket-key1"); 59 std::string key2 = request.GetHeaderValue("sec-websocket-key2"); 60 61 uint32 fp1 = WebSocketKeyFingerprint(key1); 62 uint32 fp2 = WebSocketKeyFingerprint(key2); 63 64 char data[16]; 65 memcpy(data, &fp1, 4); 66 memcpy(data + 4, &fp2, 4); 67 memcpy(data + 8, &key3_[0], 8); 68 69 base::MD5Digest digest; 70 base::MD5Sum(data, 16, &digest); 71 72 std::string origin = request.GetHeaderValue("origin"); 73 std::string host = request.GetHeaderValue("host"); 74 std::string location = "ws://" + host + request.path; 75 server_->SendRaw( 76 connection_->id(), 77 base::StringPrintf("HTTP/1.1 101 WebSocket Protocol Handshake\r\n" 78 "Upgrade: WebSocket\r\n" 79 "Connection: Upgrade\r\n" 80 "Sec-WebSocket-Origin: %s\r\n" 81 "Sec-WebSocket-Location: %s\r\n" 82 "\r\n", 83 origin.c_str(), 84 location.c_str())); 85 server_->SendRaw(connection_->id(), 86 std::string(reinterpret_cast<char*>(digest.a), 16)); 87 } 88 89 virtual ParseResult Read(std::string* message) OVERRIDE { 90 DCHECK(message); 91 HttpConnection::ReadIOBuffer* read_buf = connection_->read_buf(); 92 if (read_buf->StartOfBuffer()[0]) 93 return FRAME_ERROR; 94 95 base::StringPiece data(read_buf->StartOfBuffer(), read_buf->GetSize()); 96 size_t pos = data.find('\377', 1); 97 if (pos == base::StringPiece::npos) 98 return FRAME_INCOMPLETE; 99 100 message->assign(data.data() + 1, pos - 1); 101 read_buf->DidConsume(pos + 1); 102 103 return FRAME_OK; 104 } 105 106 virtual void Send(const std::string& message) OVERRIDE { 107 char message_start = 0; 108 char message_end = -1; 109 server_->SendRaw(connection_->id(), std::string(1, message_start)); 110 server_->SendRaw(connection_->id(), message); 111 server_->SendRaw(connection_->id(), std::string(1, message_end)); 112 } 113 114 private: 115 static const int kWebSocketHandshakeBodyLen; 116 117 WebSocketHixie76(HttpServer* server, 118 HttpConnection* connection, 119 const HttpServerRequestInfo& request, 120 size_t* pos) 121 : WebSocket(server, connection) { 122 std::string key1 = request.GetHeaderValue("sec-websocket-key1"); 123 std::string key2 = request.GetHeaderValue("sec-websocket-key2"); 124 125 if (key1.empty()) { 126 server->SendResponse( 127 connection->id(), 128 HttpServerResponseInfo::CreateFor500( 129 "Invalid request format. Sec-WebSocket-Key1 is empty or isn't " 130 "specified.")); 131 return; 132 } 133 134 if (key2.empty()) { 135 server->SendResponse( 136 connection->id(), 137 HttpServerResponseInfo::CreateFor500( 138 "Invalid request format. Sec-WebSocket-Key2 is empty or isn't " 139 "specified.")); 140 return; 141 } 142 143 key3_.assign(connection->read_buf()->StartOfBuffer() + *pos, 144 kWebSocketHandshakeBodyLen); 145 *pos += kWebSocketHandshakeBodyLen; 146 } 147 148 std::string key3_; 149 150 DISALLOW_COPY_AND_ASSIGN(WebSocketHixie76); 151 }; 152 153 const int WebSocketHixie76::kWebSocketHandshakeBodyLen = 8; 154 155 156 // Constants for hybi-10 frame format. 157 158 typedef int OpCode; 159 160 const OpCode kOpCodeContinuation = 0x0; 161 const OpCode kOpCodeText = 0x1; 162 const OpCode kOpCodeBinary = 0x2; 163 const OpCode kOpCodeClose = 0x8; 164 const OpCode kOpCodePing = 0x9; 165 const OpCode kOpCodePong = 0xA; 166 167 const unsigned char kFinalBit = 0x80; 168 const unsigned char kReserved1Bit = 0x40; 169 const unsigned char kReserved2Bit = 0x20; 170 const unsigned char kReserved3Bit = 0x10; 171 const unsigned char kOpCodeMask = 0xF; 172 const unsigned char kMaskBit = 0x80; 173 const unsigned char kPayloadLengthMask = 0x7F; 174 175 const size_t kMaxSingleBytePayloadLength = 125; 176 const size_t kTwoBytePayloadLengthField = 126; 177 const size_t kEightBytePayloadLengthField = 127; 178 const size_t kMaskingKeyWidthInBytes = 4; 179 180 class WebSocketHybi17 : public WebSocket { 181 public: 182 static WebSocket* Create(HttpServer* server, 183 HttpConnection* connection, 184 const HttpServerRequestInfo& request, 185 size_t* pos) { 186 std::string version = request.GetHeaderValue("sec-websocket-version"); 187 if (version != "8" && version != "13") 188 return NULL; 189 190 std::string key = request.GetHeaderValue("sec-websocket-key"); 191 if (key.empty()) { 192 server->SendResponse( 193 connection->id(), 194 HttpServerResponseInfo::CreateFor500( 195 "Invalid request format. Sec-WebSocket-Key is empty or isn't " 196 "specified.")); 197 return NULL; 198 } 199 return new WebSocketHybi17(server, connection, request, pos); 200 } 201 202 virtual void Accept(const HttpServerRequestInfo& request) OVERRIDE { 203 static const char* const kWebSocketGuid = 204 "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; 205 std::string key = request.GetHeaderValue("sec-websocket-key"); 206 std::string data = base::StringPrintf("%s%s", key.c_str(), kWebSocketGuid); 207 std::string encoded_hash; 208 base::Base64Encode(base::SHA1HashString(data), &encoded_hash); 209 210 server_->SendRaw( 211 connection_->id(), 212 base::StringPrintf("HTTP/1.1 101 WebSocket Protocol Handshake\r\n" 213 "Upgrade: WebSocket\r\n" 214 "Connection: Upgrade\r\n" 215 "Sec-WebSocket-Accept: %s\r\n" 216 "\r\n", 217 encoded_hash.c_str())); 218 } 219 220 virtual ParseResult Read(std::string* message) OVERRIDE { 221 HttpConnection::ReadIOBuffer* read_buf = connection_->read_buf(); 222 base::StringPiece frame(read_buf->StartOfBuffer(), read_buf->GetSize()); 223 int bytes_consumed = 0; 224 ParseResult result = 225 WebSocket::DecodeFrameHybi17(frame, true, &bytes_consumed, message); 226 if (result == FRAME_OK) 227 read_buf->DidConsume(bytes_consumed); 228 if (result == FRAME_CLOSE) 229 closed_ = true; 230 return result; 231 } 232 233 virtual void Send(const std::string& message) OVERRIDE { 234 if (closed_) 235 return; 236 server_->SendRaw(connection_->id(), 237 WebSocket::EncodeFrameHybi17(message, 0)); 238 } 239 240 private: 241 WebSocketHybi17(HttpServer* server, 242 HttpConnection* connection, 243 const HttpServerRequestInfo& request, 244 size_t* pos) 245 : WebSocket(server, connection), 246 op_code_(0), 247 final_(false), 248 reserved1_(false), 249 reserved2_(false), 250 reserved3_(false), 251 masked_(false), 252 payload_(0), 253 payload_length_(0), 254 frame_end_(0), 255 closed_(false) { 256 } 257 258 OpCode op_code_; 259 bool final_; 260 bool reserved1_; 261 bool reserved2_; 262 bool reserved3_; 263 bool masked_; 264 const char* payload_; 265 size_t payload_length_; 266 const char* frame_end_; 267 bool closed_; 268 269 DISALLOW_COPY_AND_ASSIGN(WebSocketHybi17); 270 }; 271 272 } // anonymous namespace 273 274 WebSocket* WebSocket::CreateWebSocket(HttpServer* server, 275 HttpConnection* connection, 276 const HttpServerRequestInfo& request, 277 size_t* pos) { 278 WebSocket* socket = WebSocketHybi17::Create(server, connection, request, pos); 279 if (socket) 280 return socket; 281 282 return WebSocketHixie76::Create(server, connection, request, pos); 283 } 284 285 // static 286 WebSocket::ParseResult WebSocket::DecodeFrameHybi17( 287 const base::StringPiece& frame, 288 bool client_frame, 289 int* bytes_consumed, 290 std::string* output) { 291 size_t data_length = frame.length(); 292 if (data_length < 2) 293 return FRAME_INCOMPLETE; 294 295 const char* buffer_begin = const_cast<char*>(frame.data()); 296 const char* p = buffer_begin; 297 const char* buffer_end = p + data_length; 298 299 unsigned char first_byte = *p++; 300 unsigned char second_byte = *p++; 301 302 bool final = (first_byte & kFinalBit) != 0; 303 bool reserved1 = (first_byte & kReserved1Bit) != 0; 304 bool reserved2 = (first_byte & kReserved2Bit) != 0; 305 bool reserved3 = (first_byte & kReserved3Bit) != 0; 306 int op_code = first_byte & kOpCodeMask; 307 bool masked = (second_byte & kMaskBit) != 0; 308 if (!final || reserved1 || reserved2 || reserved3) 309 return FRAME_ERROR; // Extensions and not supported. 310 311 bool closed = false; 312 switch (op_code) { 313 case kOpCodeClose: 314 closed = true; 315 break; 316 case kOpCodeText: 317 break; 318 case kOpCodeBinary: // We don't support binary frames yet. 319 case kOpCodeContinuation: // We don't support binary frames yet. 320 case kOpCodePing: // We don't support binary frames yet. 321 case kOpCodePong: // We don't support binary frames yet. 322 default: 323 return FRAME_ERROR; 324 } 325 326 if (client_frame && !masked) // In Hybi-17 spec client MUST mask his frame. 327 return FRAME_ERROR; 328 329 uint64 payload_length64 = second_byte & kPayloadLengthMask; 330 if (payload_length64 > kMaxSingleBytePayloadLength) { 331 int extended_payload_length_size; 332 if (payload_length64 == kTwoBytePayloadLengthField) 333 extended_payload_length_size = 2; 334 else { 335 DCHECK(payload_length64 == kEightBytePayloadLengthField); 336 extended_payload_length_size = 8; 337 } 338 if (buffer_end - p < extended_payload_length_size) 339 return FRAME_INCOMPLETE; 340 payload_length64 = 0; 341 for (int i = 0; i < extended_payload_length_size; ++i) { 342 payload_length64 <<= 8; 343 payload_length64 |= static_cast<unsigned char>(*p++); 344 } 345 } 346 347 size_t actual_masking_key_length = masked ? kMaskingKeyWidthInBytes : 0; 348 static const uint64 max_payload_length = 0x7FFFFFFFFFFFFFFFull; 349 static size_t max_length = std::numeric_limits<size_t>::max(); 350 if (payload_length64 > max_payload_length || 351 payload_length64 + actual_masking_key_length > max_length) { 352 // WebSocket frame length too large. 353 return FRAME_ERROR; 354 } 355 size_t payload_length = static_cast<size_t>(payload_length64); 356 357 size_t total_length = actual_masking_key_length + payload_length; 358 if (static_cast<size_t>(buffer_end - p) < total_length) 359 return FRAME_INCOMPLETE; 360 361 if (masked) { 362 output->resize(payload_length); 363 const char* masking_key = p; 364 char* payload = const_cast<char*>(p + kMaskingKeyWidthInBytes); 365 for (size_t i = 0; i < payload_length; ++i) // Unmask the payload. 366 (*output)[i] = payload[i] ^ masking_key[i % kMaskingKeyWidthInBytes]; 367 } else { 368 output->assign(p, p + payload_length); 369 } 370 371 size_t pos = p + actual_masking_key_length + payload_length - buffer_begin; 372 *bytes_consumed = pos; 373 return closed ? FRAME_CLOSE : FRAME_OK; 374 } 375 376 // static 377 std::string WebSocket::EncodeFrameHybi17(const std::string& message, 378 int masking_key) { 379 std::vector<char> frame; 380 OpCode op_code = kOpCodeText; 381 size_t data_length = message.length(); 382 383 frame.push_back(kFinalBit | op_code); 384 char mask_key_bit = masking_key != 0 ? kMaskBit : 0; 385 if (data_length <= kMaxSingleBytePayloadLength) 386 frame.push_back(data_length | mask_key_bit); 387 else if (data_length <= 0xFFFF) { 388 frame.push_back(kTwoBytePayloadLengthField | mask_key_bit); 389 frame.push_back((data_length & 0xFF00) >> 8); 390 frame.push_back(data_length & 0xFF); 391 } else { 392 frame.push_back(kEightBytePayloadLengthField | mask_key_bit); 393 char extended_payload_length[8]; 394 size_t remaining = data_length; 395 // Fill the length into extended_payload_length in the network byte order. 396 for (int i = 0; i < 8; ++i) { 397 extended_payload_length[7 - i] = remaining & 0xFF; 398 remaining >>= 8; 399 } 400 frame.insert(frame.end(), 401 extended_payload_length, 402 extended_payload_length + 8); 403 DCHECK(!remaining); 404 } 405 406 const char* data = const_cast<char*>(message.data()); 407 if (masking_key != 0) { 408 const char* mask_bytes = reinterpret_cast<char*>(&masking_key); 409 frame.insert(frame.end(), mask_bytes, mask_bytes + 4); 410 for (size_t i = 0; i < data_length; ++i) // Mask the payload. 411 frame.push_back(data[i] ^ mask_bytes[i % kMaskingKeyWidthInBytes]); 412 } else { 413 frame.insert(frame.end(), data, data + data_length); 414 } 415 return std::string(&frame[0], frame.size()); 416 } 417 418 WebSocket::WebSocket(HttpServer* server, HttpConnection* connection) 419 : server_(server), 420 connection_(connection) { 421 } 422 423 } // namespace net 424