Home | History | Annotate | Download | only in server
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/server/web_socket.h"
      6 
      7 #include <limits>
      8 
      9 #include "base/base64.h"
     10 #include "base/rand_util.h"
     11 #include "base/logging.h"
     12 #include "base/md5.h"
     13 #include "base/sha1.h"
     14 #include "base/strings/string_number_conversions.h"
     15 #include "base/strings/stringprintf.h"
     16 #include "base/sys_byteorder.h"
     17 #include "net/server/http_connection.h"
     18 #include "net/server/http_server_request_info.h"
     19 #include "net/server/http_server_response_info.h"
     20 
     21 namespace net {
     22 
     23 namespace {
     24 
     25 static uint32 WebSocketKeyFingerprint(const std::string& str) {
     26   std::string result;
     27   const char* p_char = str.c_str();
     28   int length = str.length();
     29   int spaces = 0;
     30   for (int i = 0; i < length; ++i) {
     31     if (p_char[i] >= '0' && p_char[i] <= '9')
     32       result.append(&p_char[i], 1);
     33     else if (p_char[i] == ' ')
     34       spaces++;
     35   }
     36   if (spaces == 0)
     37     return 0;
     38   int64 number = 0;
     39   if (!base::StringToInt64(result, &number))
     40     return 0;
     41   return base::HostToNet32(static_cast<uint32>(number / spaces));
     42 }
     43 
     44 class WebSocketHixie76 : public net::WebSocket {
     45  public:
     46   static net::WebSocket* Create(HttpConnection* connection,
     47                                 const HttpServerRequestInfo& request,
     48                                 size_t* pos) {
     49     if (connection->recv_data().length() < *pos + kWebSocketHandshakeBodyLen)
     50       return NULL;
     51     return new WebSocketHixie76(connection, request, pos);
     52   }
     53 
     54   virtual void Accept(const HttpServerRequestInfo& request) OVERRIDE {
     55     std::string key1 = request.GetHeaderValue("sec-websocket-key1");
     56     std::string key2 = request.GetHeaderValue("sec-websocket-key2");
     57 
     58     uint32 fp1 = WebSocketKeyFingerprint(key1);
     59     uint32 fp2 = WebSocketKeyFingerprint(key2);
     60 
     61     char data[16];
     62     memcpy(data, &fp1, 4);
     63     memcpy(data + 4, &fp2, 4);
     64     memcpy(data + 8, &key3_[0], 8);
     65 
     66     base::MD5Digest digest;
     67     base::MD5Sum(data, 16, &digest);
     68 
     69     std::string origin = request.GetHeaderValue("origin");
     70     std::string host = request.GetHeaderValue("host");
     71     std::string location = "ws://" + host + request.path;
     72     connection_->Send(base::StringPrintf(
     73         "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
     74         "Upgrade: WebSocket\r\n"
     75         "Connection: Upgrade\r\n"
     76         "Sec-WebSocket-Origin: %s\r\n"
     77         "Sec-WebSocket-Location: %s\r\n"
     78         "\r\n",
     79         origin.c_str(),
     80         location.c_str()));
     81     connection_->Send(reinterpret_cast<char*>(digest.a), 16);
     82   }
     83 
     84   virtual ParseResult Read(std::string* message) OVERRIDE {
     85     DCHECK(message);
     86     const std::string& data = connection_->recv_data();
     87     if (data[0])
     88       return FRAME_ERROR;
     89 
     90     size_t pos = data.find('\377', 1);
     91     if (pos == std::string::npos)
     92       return FRAME_INCOMPLETE;
     93 
     94     std::string buffer(data.begin() + 1, data.begin() + pos);
     95     message->swap(buffer);
     96     connection_->Shift(pos + 1);
     97 
     98     return FRAME_OK;
     99   }
    100 
    101   virtual void Send(const std::string& message) OVERRIDE {
    102     char message_start = 0;
    103     char message_end = -1;
    104     connection_->Send(&message_start, 1);
    105     connection_->Send(message);
    106     connection_->Send(&message_end, 1);
    107   }
    108 
    109  private:
    110   static const int kWebSocketHandshakeBodyLen;
    111 
    112   WebSocketHixie76(HttpConnection* connection,
    113                    const HttpServerRequestInfo& request,
    114                    size_t* pos) : WebSocket(connection) {
    115     std::string key1 = request.GetHeaderValue("sec-websocket-key1");
    116     std::string key2 = request.GetHeaderValue("sec-websocket-key2");
    117 
    118     if (key1.empty()) {
    119       connection->Send(HttpServerResponseInfo::CreateFor500(
    120           "Invalid request format. Sec-WebSocket-Key1 is empty or isn't "
    121           "specified."));
    122       return;
    123     }
    124 
    125     if (key2.empty()) {
    126       connection->Send(HttpServerResponseInfo::CreateFor500(
    127           "Invalid request format. Sec-WebSocket-Key2 is empty or isn't "
    128           "specified."));
    129       return;
    130     }
    131 
    132     key3_ = connection->recv_data().substr(
    133         *pos,
    134         *pos + kWebSocketHandshakeBodyLen);
    135     *pos += kWebSocketHandshakeBodyLen;
    136   }
    137 
    138   std::string key3_;
    139 
    140   DISALLOW_COPY_AND_ASSIGN(WebSocketHixie76);
    141 };
    142 
    143 const int WebSocketHixie76::kWebSocketHandshakeBodyLen = 8;
    144 
    145 
    146 // Constants for hybi-10 frame format.
    147 
    148 typedef int OpCode;
    149 
    150 const OpCode kOpCodeContinuation = 0x0;
    151 const OpCode kOpCodeText = 0x1;
    152 const OpCode kOpCodeBinary = 0x2;
    153 const OpCode kOpCodeClose = 0x8;
    154 const OpCode kOpCodePing = 0x9;
    155 const OpCode kOpCodePong = 0xA;
    156 
    157 const unsigned char kFinalBit = 0x80;
    158 const unsigned char kReserved1Bit = 0x40;
    159 const unsigned char kReserved2Bit = 0x20;
    160 const unsigned char kReserved3Bit = 0x10;
    161 const unsigned char kOpCodeMask = 0xF;
    162 const unsigned char kMaskBit = 0x80;
    163 const unsigned char kPayloadLengthMask = 0x7F;
    164 
    165 const size_t kMaxSingleBytePayloadLength = 125;
    166 const size_t kTwoBytePayloadLengthField = 126;
    167 const size_t kEightBytePayloadLengthField = 127;
    168 const size_t kMaskingKeyWidthInBytes = 4;
    169 
    170 class WebSocketHybi17 : public WebSocket {
    171  public:
    172   static WebSocket* Create(HttpConnection* connection,
    173                            const HttpServerRequestInfo& request,
    174                            size_t* pos) {
    175     std::string version = request.GetHeaderValue("sec-websocket-version");
    176     if (version != "8" && version != "13")
    177       return NULL;
    178 
    179     std::string key = request.GetHeaderValue("sec-websocket-key");
    180     if (key.empty()) {
    181       connection->Send(HttpServerResponseInfo::CreateFor500(
    182           "Invalid request format. Sec-WebSocket-Key is empty or isn't "
    183           "specified."));
    184       return NULL;
    185     }
    186     return new WebSocketHybi17(connection, request, pos);
    187   }
    188 
    189   virtual void Accept(const HttpServerRequestInfo& request) OVERRIDE {
    190     static const char* const kWebSocketGuid =
    191         "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
    192     std::string key = request.GetHeaderValue("sec-websocket-key");
    193     std::string data = base::StringPrintf("%s%s", key.c_str(), kWebSocketGuid);
    194     std::string encoded_hash;
    195     base::Base64Encode(base::SHA1HashString(data), &encoded_hash);
    196 
    197     std::string response = base::StringPrintf(
    198         "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
    199         "Upgrade: WebSocket\r\n"
    200         "Connection: Upgrade\r\n"
    201         "Sec-WebSocket-Accept: %s\r\n"
    202         "\r\n",
    203         encoded_hash.c_str());
    204     connection_->Send(response);
    205   }
    206 
    207   virtual ParseResult Read(std::string* message) OVERRIDE {
    208     const std::string& frame = connection_->recv_data();
    209     int bytes_consumed = 0;
    210 
    211     ParseResult result =
    212         WebSocket::DecodeFrameHybi17(frame, true, &bytes_consumed, message);
    213     if (result == FRAME_OK)
    214       connection_->Shift(bytes_consumed);
    215     if (result == FRAME_CLOSE)
    216       closed_ = true;
    217     return result;
    218   }
    219 
    220   virtual void Send(const std::string& message) OVERRIDE {
    221     if (closed_)
    222       return;
    223     std::string data = WebSocket::EncodeFrameHybi17(message, 0);
    224     connection_->Send(data);
    225   }
    226 
    227  private:
    228   WebSocketHybi17(HttpConnection* connection,
    229                   const HttpServerRequestInfo& request,
    230                   size_t* pos)
    231     : WebSocket(connection),
    232       op_code_(0),
    233       final_(false),
    234       reserved1_(false),
    235       reserved2_(false),
    236       reserved3_(false),
    237       masked_(false),
    238       payload_(0),
    239       payload_length_(0),
    240       frame_end_(0),
    241       closed_(false) {
    242   }
    243 
    244   OpCode op_code_;
    245   bool final_;
    246   bool reserved1_;
    247   bool reserved2_;
    248   bool reserved3_;
    249   bool masked_;
    250   const char* payload_;
    251   size_t payload_length_;
    252   const char* frame_end_;
    253   bool closed_;
    254 
    255   DISALLOW_COPY_AND_ASSIGN(WebSocketHybi17);
    256 };
    257 
    258 }  // anonymous namespace
    259 
    260 WebSocket* WebSocket::CreateWebSocket(HttpConnection* connection,
    261                                       const HttpServerRequestInfo& request,
    262                                       size_t* pos) {
    263   WebSocket* socket = WebSocketHybi17::Create(connection, request, pos);
    264   if (socket)
    265     return socket;
    266 
    267   return WebSocketHixie76::Create(connection, request, pos);
    268 }
    269 
    270 // static
    271 WebSocket::ParseResult WebSocket::DecodeFrameHybi17(const std::string& frame,
    272                                                     bool client_frame,
    273                                                     int* bytes_consumed,
    274                                                     std::string* output) {
    275   size_t data_length = frame.length();
    276   if (data_length < 2)
    277     return FRAME_INCOMPLETE;
    278 
    279   const char* buffer_begin = const_cast<char*>(frame.data());
    280   const char* p = buffer_begin;
    281   const char* buffer_end = p + data_length;
    282 
    283   unsigned char first_byte = *p++;
    284   unsigned char second_byte = *p++;
    285 
    286   bool final = (first_byte & kFinalBit) != 0;
    287   bool reserved1 = (first_byte & kReserved1Bit) != 0;
    288   bool reserved2 = (first_byte & kReserved2Bit) != 0;
    289   bool reserved3 = (first_byte & kReserved3Bit) != 0;
    290   int op_code = first_byte & kOpCodeMask;
    291   bool masked = (second_byte & kMaskBit) != 0;
    292   if (!final || reserved1 || reserved2 || reserved3)
    293     return FRAME_ERROR;  // Extensions and not supported.
    294 
    295   bool closed = false;
    296   switch (op_code) {
    297   case kOpCodeClose:
    298     closed = true;
    299     break;
    300   case kOpCodeText:
    301     break;
    302   case kOpCodeBinary: // We don't support binary frames yet.
    303   case kOpCodeContinuation: // We don't support binary frames yet.
    304   case kOpCodePing: // We don't support binary frames yet.
    305   case kOpCodePong: // We don't support binary frames yet.
    306   default:
    307     return FRAME_ERROR;
    308   }
    309 
    310   if (client_frame && !masked) // In Hybi-17 spec client MUST mask his frame.
    311     return FRAME_ERROR;
    312 
    313   uint64 payload_length64 = second_byte & kPayloadLengthMask;
    314   if (payload_length64 > kMaxSingleBytePayloadLength) {
    315     int extended_payload_length_size;
    316     if (payload_length64 == kTwoBytePayloadLengthField)
    317       extended_payload_length_size = 2;
    318     else {
    319       DCHECK(payload_length64 == kEightBytePayloadLengthField);
    320       extended_payload_length_size = 8;
    321     }
    322     if (buffer_end - p < extended_payload_length_size)
    323       return FRAME_INCOMPLETE;
    324     payload_length64 = 0;
    325     for (int i = 0; i < extended_payload_length_size; ++i) {
    326       payload_length64 <<= 8;
    327       payload_length64 |= static_cast<unsigned char>(*p++);
    328     }
    329   }
    330 
    331   size_t actual_masking_key_length = masked ? kMaskingKeyWidthInBytes : 0;
    332   static const uint64 max_payload_length = 0x7FFFFFFFFFFFFFFFull;
    333   static size_t max_length = std::numeric_limits<size_t>::max();
    334   if (payload_length64 > max_payload_length ||
    335       payload_length64 + actual_masking_key_length > max_length) {
    336     // WebSocket frame length too large.
    337     return FRAME_ERROR;
    338   }
    339   size_t payload_length = static_cast<size_t>(payload_length64);
    340 
    341   size_t total_length = actual_masking_key_length + payload_length;
    342   if (static_cast<size_t>(buffer_end - p) < total_length)
    343     return FRAME_INCOMPLETE;
    344 
    345   if (masked) {
    346     output->resize(payload_length);
    347     const char* masking_key = p;
    348     char* payload = const_cast<char*>(p + kMaskingKeyWidthInBytes);
    349     for (size_t i = 0; i < payload_length; ++i)  // Unmask the payload.
    350       (*output)[i] = payload[i] ^ masking_key[i % kMaskingKeyWidthInBytes];
    351   } else {
    352     std::string buffer(p, p + payload_length);
    353     output->swap(buffer);
    354   }
    355 
    356   size_t pos = p + actual_masking_key_length + payload_length - buffer_begin;
    357   *bytes_consumed = pos;
    358   return closed ? FRAME_CLOSE : FRAME_OK;
    359 }
    360 
    361 // static
    362 std::string WebSocket::EncodeFrameHybi17(const std::string& message,
    363                                          int masking_key) {
    364   std::vector<char> frame;
    365   OpCode op_code = kOpCodeText;
    366   size_t data_length = message.length();
    367 
    368   frame.push_back(kFinalBit | op_code);
    369   char mask_key_bit = masking_key != 0 ? kMaskBit : 0;
    370   if (data_length <= kMaxSingleBytePayloadLength)
    371     frame.push_back(data_length | mask_key_bit);
    372   else if (data_length <= 0xFFFF) {
    373     frame.push_back(kTwoBytePayloadLengthField | mask_key_bit);
    374     frame.push_back((data_length & 0xFF00) >> 8);
    375     frame.push_back(data_length & 0xFF);
    376   } else {
    377     frame.push_back(kEightBytePayloadLengthField | mask_key_bit);
    378     char extended_payload_length[8];
    379     size_t remaining = data_length;
    380     // Fill the length into extended_payload_length in the network byte order.
    381     for (int i = 0; i < 8; ++i) {
    382       extended_payload_length[7 - i] = remaining & 0xFF;
    383       remaining >>= 8;
    384     }
    385     frame.insert(frame.end(),
    386                  extended_payload_length,
    387                  extended_payload_length + 8);
    388     DCHECK(!remaining);
    389   }
    390 
    391   const char* data = const_cast<char*>(message.data());
    392   if (masking_key != 0) {
    393     const char* mask_bytes = reinterpret_cast<char*>(&masking_key);
    394     frame.insert(frame.end(), mask_bytes, mask_bytes + 4);
    395     for (size_t i = 0; i < data_length; ++i)  // Mask the payload.
    396       frame.push_back(data[i] ^ mask_bytes[i % kMaskingKeyWidthInBytes]);
    397   } else {
    398     frame.insert(frame.end(), data, data + data_length);
    399   }
    400   return std::string(&frame[0], frame.size());
    401 }
    402 
    403 WebSocket::WebSocket(HttpConnection* connection) : connection_(connection) {
    404 }
    405 
    406 }  // namespace net
    407