1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/server/web_socket.h" 6 7 #include <limits> 8 9 #include "base/base64.h" 10 #include "base/rand_util.h" 11 #include "base/logging.h" 12 #include "base/md5.h" 13 #include "base/sha1.h" 14 #include "base/strings/string_number_conversions.h" 15 #include "base/strings/stringprintf.h" 16 #include "base/sys_byteorder.h" 17 #include "net/server/http_connection.h" 18 #include "net/server/http_server_request_info.h" 19 #include "net/server/http_server_response_info.h" 20 21 namespace net { 22 23 namespace { 24 25 static uint32 WebSocketKeyFingerprint(const std::string& str) { 26 std::string result; 27 const char* p_char = str.c_str(); 28 int length = str.length(); 29 int spaces = 0; 30 for (int i = 0; i < length; ++i) { 31 if (p_char[i] >= '0' && p_char[i] <= '9') 32 result.append(&p_char[i], 1); 33 else if (p_char[i] == ' ') 34 spaces++; 35 } 36 if (spaces == 0) 37 return 0; 38 int64 number = 0; 39 if (!base::StringToInt64(result, &number)) 40 return 0; 41 return base::HostToNet32(static_cast<uint32>(number / spaces)); 42 } 43 44 class WebSocketHixie76 : public net::WebSocket { 45 public: 46 static net::WebSocket* Create(HttpConnection* connection, 47 const HttpServerRequestInfo& request, 48 size_t* pos) { 49 if (connection->recv_data().length() < *pos + kWebSocketHandshakeBodyLen) 50 return NULL; 51 return new WebSocketHixie76(connection, request, pos); 52 } 53 54 virtual void Accept(const HttpServerRequestInfo& request) OVERRIDE { 55 std::string key1 = request.GetHeaderValue("sec-websocket-key1"); 56 std::string key2 = request.GetHeaderValue("sec-websocket-key2"); 57 58 uint32 fp1 = WebSocketKeyFingerprint(key1); 59 uint32 fp2 = WebSocketKeyFingerprint(key2); 60 61 char data[16]; 62 memcpy(data, &fp1, 4); 63 memcpy(data + 4, &fp2, 4); 64 memcpy(data + 8, &key3_[0], 8); 65 66 base::MD5Digest digest; 67 base::MD5Sum(data, 16, &digest); 68 69 std::string origin = request.GetHeaderValue("origin"); 70 std::string host = request.GetHeaderValue("host"); 71 std::string location = "ws://" + host + request.path; 72 connection_->Send(base::StringPrintf( 73 "HTTP/1.1 101 WebSocket Protocol Handshake\r\n" 74 "Upgrade: WebSocket\r\n" 75 "Connection: Upgrade\r\n" 76 "Sec-WebSocket-Origin: %s\r\n" 77 "Sec-WebSocket-Location: %s\r\n" 78 "\r\n", 79 origin.c_str(), 80 location.c_str())); 81 connection_->Send(reinterpret_cast<char*>(digest.a), 16); 82 } 83 84 virtual ParseResult Read(std::string* message) OVERRIDE { 85 DCHECK(message); 86 const std::string& data = connection_->recv_data(); 87 if (data[0]) 88 return FRAME_ERROR; 89 90 size_t pos = data.find('\377', 1); 91 if (pos == std::string::npos) 92 return FRAME_INCOMPLETE; 93 94 std::string buffer(data.begin() + 1, data.begin() + pos); 95 message->swap(buffer); 96 connection_->Shift(pos + 1); 97 98 return FRAME_OK; 99 } 100 101 virtual void Send(const std::string& message) OVERRIDE { 102 char message_start = 0; 103 char message_end = -1; 104 connection_->Send(&message_start, 1); 105 connection_->Send(message); 106 connection_->Send(&message_end, 1); 107 } 108 109 private: 110 static const int kWebSocketHandshakeBodyLen; 111 112 WebSocketHixie76(HttpConnection* connection, 113 const HttpServerRequestInfo& request, 114 size_t* pos) : WebSocket(connection) { 115 std::string key1 = request.GetHeaderValue("sec-websocket-key1"); 116 std::string key2 = request.GetHeaderValue("sec-websocket-key2"); 117 118 if (key1.empty()) { 119 connection->Send(HttpServerResponseInfo::CreateFor500( 120 "Invalid request format. Sec-WebSocket-Key1 is empty or isn't " 121 "specified.")); 122 return; 123 } 124 125 if (key2.empty()) { 126 connection->Send(HttpServerResponseInfo::CreateFor500( 127 "Invalid request format. Sec-WebSocket-Key2 is empty or isn't " 128 "specified.")); 129 return; 130 } 131 132 key3_ = connection->recv_data().substr( 133 *pos, 134 *pos + kWebSocketHandshakeBodyLen); 135 *pos += kWebSocketHandshakeBodyLen; 136 } 137 138 std::string key3_; 139 140 DISALLOW_COPY_AND_ASSIGN(WebSocketHixie76); 141 }; 142 143 const int WebSocketHixie76::kWebSocketHandshakeBodyLen = 8; 144 145 146 // Constants for hybi-10 frame format. 147 148 typedef int OpCode; 149 150 const OpCode kOpCodeContinuation = 0x0; 151 const OpCode kOpCodeText = 0x1; 152 const OpCode kOpCodeBinary = 0x2; 153 const OpCode kOpCodeClose = 0x8; 154 const OpCode kOpCodePing = 0x9; 155 const OpCode kOpCodePong = 0xA; 156 157 const unsigned char kFinalBit = 0x80; 158 const unsigned char kReserved1Bit = 0x40; 159 const unsigned char kReserved2Bit = 0x20; 160 const unsigned char kReserved3Bit = 0x10; 161 const unsigned char kOpCodeMask = 0xF; 162 const unsigned char kMaskBit = 0x80; 163 const unsigned char kPayloadLengthMask = 0x7F; 164 165 const size_t kMaxSingleBytePayloadLength = 125; 166 const size_t kTwoBytePayloadLengthField = 126; 167 const size_t kEightBytePayloadLengthField = 127; 168 const size_t kMaskingKeyWidthInBytes = 4; 169 170 class WebSocketHybi17 : public WebSocket { 171 public: 172 static WebSocket* Create(HttpConnection* connection, 173 const HttpServerRequestInfo& request, 174 size_t* pos) { 175 std::string version = request.GetHeaderValue("sec-websocket-version"); 176 if (version != "8" && version != "13") 177 return NULL; 178 179 std::string key = request.GetHeaderValue("sec-websocket-key"); 180 if (key.empty()) { 181 connection->Send(HttpServerResponseInfo::CreateFor500( 182 "Invalid request format. Sec-WebSocket-Key is empty or isn't " 183 "specified.")); 184 return NULL; 185 } 186 return new WebSocketHybi17(connection, request, pos); 187 } 188 189 virtual void Accept(const HttpServerRequestInfo& request) OVERRIDE { 190 static const char* const kWebSocketGuid = 191 "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; 192 std::string key = request.GetHeaderValue("sec-websocket-key"); 193 std::string data = base::StringPrintf("%s%s", key.c_str(), kWebSocketGuid); 194 std::string encoded_hash; 195 base::Base64Encode(base::SHA1HashString(data), &encoded_hash); 196 197 std::string response = base::StringPrintf( 198 "HTTP/1.1 101 WebSocket Protocol Handshake\r\n" 199 "Upgrade: WebSocket\r\n" 200 "Connection: Upgrade\r\n" 201 "Sec-WebSocket-Accept: %s\r\n" 202 "\r\n", 203 encoded_hash.c_str()); 204 connection_->Send(response); 205 } 206 207 virtual ParseResult Read(std::string* message) OVERRIDE { 208 const std::string& frame = connection_->recv_data(); 209 int bytes_consumed = 0; 210 211 ParseResult result = 212 WebSocket::DecodeFrameHybi17(frame, true, &bytes_consumed, message); 213 if (result == FRAME_OK) 214 connection_->Shift(bytes_consumed); 215 if (result == FRAME_CLOSE) 216 closed_ = true; 217 return result; 218 } 219 220 virtual void Send(const std::string& message) OVERRIDE { 221 if (closed_) 222 return; 223 std::string data = WebSocket::EncodeFrameHybi17(message, 0); 224 connection_->Send(data); 225 } 226 227 private: 228 WebSocketHybi17(HttpConnection* connection, 229 const HttpServerRequestInfo& request, 230 size_t* pos) 231 : WebSocket(connection), 232 op_code_(0), 233 final_(false), 234 reserved1_(false), 235 reserved2_(false), 236 reserved3_(false), 237 masked_(false), 238 payload_(0), 239 payload_length_(0), 240 frame_end_(0), 241 closed_(false) { 242 } 243 244 OpCode op_code_; 245 bool final_; 246 bool reserved1_; 247 bool reserved2_; 248 bool reserved3_; 249 bool masked_; 250 const char* payload_; 251 size_t payload_length_; 252 const char* frame_end_; 253 bool closed_; 254 255 DISALLOW_COPY_AND_ASSIGN(WebSocketHybi17); 256 }; 257 258 } // anonymous namespace 259 260 WebSocket* WebSocket::CreateWebSocket(HttpConnection* connection, 261 const HttpServerRequestInfo& request, 262 size_t* pos) { 263 WebSocket* socket = WebSocketHybi17::Create(connection, request, pos); 264 if (socket) 265 return socket; 266 267 return WebSocketHixie76::Create(connection, request, pos); 268 } 269 270 // static 271 WebSocket::ParseResult WebSocket::DecodeFrameHybi17(const std::string& frame, 272 bool client_frame, 273 int* bytes_consumed, 274 std::string* output) { 275 size_t data_length = frame.length(); 276 if (data_length < 2) 277 return FRAME_INCOMPLETE; 278 279 const char* buffer_begin = const_cast<char*>(frame.data()); 280 const char* p = buffer_begin; 281 const char* buffer_end = p + data_length; 282 283 unsigned char first_byte = *p++; 284 unsigned char second_byte = *p++; 285 286 bool final = (first_byte & kFinalBit) != 0; 287 bool reserved1 = (first_byte & kReserved1Bit) != 0; 288 bool reserved2 = (first_byte & kReserved2Bit) != 0; 289 bool reserved3 = (first_byte & kReserved3Bit) != 0; 290 int op_code = first_byte & kOpCodeMask; 291 bool masked = (second_byte & kMaskBit) != 0; 292 if (!final || reserved1 || reserved2 || reserved3) 293 return FRAME_ERROR; // Extensions and not supported. 294 295 bool closed = false; 296 switch (op_code) { 297 case kOpCodeClose: 298 closed = true; 299 break; 300 case kOpCodeText: 301 break; 302 case kOpCodeBinary: // We don't support binary frames yet. 303 case kOpCodeContinuation: // We don't support binary frames yet. 304 case kOpCodePing: // We don't support binary frames yet. 305 case kOpCodePong: // We don't support binary frames yet. 306 default: 307 return FRAME_ERROR; 308 } 309 310 if (client_frame && !masked) // In Hybi-17 spec client MUST mask his frame. 311 return FRAME_ERROR; 312 313 uint64 payload_length64 = second_byte & kPayloadLengthMask; 314 if (payload_length64 > kMaxSingleBytePayloadLength) { 315 int extended_payload_length_size; 316 if (payload_length64 == kTwoBytePayloadLengthField) 317 extended_payload_length_size = 2; 318 else { 319 DCHECK(payload_length64 == kEightBytePayloadLengthField); 320 extended_payload_length_size = 8; 321 } 322 if (buffer_end - p < extended_payload_length_size) 323 return FRAME_INCOMPLETE; 324 payload_length64 = 0; 325 for (int i = 0; i < extended_payload_length_size; ++i) { 326 payload_length64 <<= 8; 327 payload_length64 |= static_cast<unsigned char>(*p++); 328 } 329 } 330 331 size_t actual_masking_key_length = masked ? kMaskingKeyWidthInBytes : 0; 332 static const uint64 max_payload_length = 0x7FFFFFFFFFFFFFFFull; 333 static size_t max_length = std::numeric_limits<size_t>::max(); 334 if (payload_length64 > max_payload_length || 335 payload_length64 + actual_masking_key_length > max_length) { 336 // WebSocket frame length too large. 337 return FRAME_ERROR; 338 } 339 size_t payload_length = static_cast<size_t>(payload_length64); 340 341 size_t total_length = actual_masking_key_length + payload_length; 342 if (static_cast<size_t>(buffer_end - p) < total_length) 343 return FRAME_INCOMPLETE; 344 345 if (masked) { 346 output->resize(payload_length); 347 const char* masking_key = p; 348 char* payload = const_cast<char*>(p + kMaskingKeyWidthInBytes); 349 for (size_t i = 0; i < payload_length; ++i) // Unmask the payload. 350 (*output)[i] = payload[i] ^ masking_key[i % kMaskingKeyWidthInBytes]; 351 } else { 352 std::string buffer(p, p + payload_length); 353 output->swap(buffer); 354 } 355 356 size_t pos = p + actual_masking_key_length + payload_length - buffer_begin; 357 *bytes_consumed = pos; 358 return closed ? FRAME_CLOSE : FRAME_OK; 359 } 360 361 // static 362 std::string WebSocket::EncodeFrameHybi17(const std::string& message, 363 int masking_key) { 364 std::vector<char> frame; 365 OpCode op_code = kOpCodeText; 366 size_t data_length = message.length(); 367 368 frame.push_back(kFinalBit | op_code); 369 char mask_key_bit = masking_key != 0 ? kMaskBit : 0; 370 if (data_length <= kMaxSingleBytePayloadLength) 371 frame.push_back(data_length | mask_key_bit); 372 else if (data_length <= 0xFFFF) { 373 frame.push_back(kTwoBytePayloadLengthField | mask_key_bit); 374 frame.push_back((data_length & 0xFF00) >> 8); 375 frame.push_back(data_length & 0xFF); 376 } else { 377 frame.push_back(kEightBytePayloadLengthField | mask_key_bit); 378 char extended_payload_length[8]; 379 size_t remaining = data_length; 380 // Fill the length into extended_payload_length in the network byte order. 381 for (int i = 0; i < 8; ++i) { 382 extended_payload_length[7 - i] = remaining & 0xFF; 383 remaining >>= 8; 384 } 385 frame.insert(frame.end(), 386 extended_payload_length, 387 extended_payload_length + 8); 388 DCHECK(!remaining); 389 } 390 391 const char* data = const_cast<char*>(message.data()); 392 if (masking_key != 0) { 393 const char* mask_bytes = reinterpret_cast<char*>(&masking_key); 394 frame.insert(frame.end(), mask_bytes, mask_bytes + 4); 395 for (size_t i = 0; i < data_length; ++i) // Mask the payload. 396 frame.push_back(data[i] ^ mask_bytes[i % kMaskingKeyWidthInBytes]); 397 } else { 398 frame.insert(frame.end(), data, data + data_length); 399 } 400 return std::string(&frame[0], frame.size()); 401 } 402 403 WebSocket::WebSocket(HttpConnection* connection) : connection_(connection) { 404 } 405 406 } // namespace net 407