Home | History | Annotate | Download | only in websockets
      1 // Copyright (c) 2011 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
#include <algorithm>
#include <cctype>
#include <cstring>
#include <string>
#include <vector>

#include "base/memory/scoped_ptr.h"
#include "base/string_split.h"
#include "base/string_util.h"
#include "base/stringprintf.h"
#include "net/websockets/websocket_handshake.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "testing/platform_test.h"
     16 
     17 namespace net {
     18 
     19 class WebSocketHandshakeTest : public testing::Test {
     20  public:
     21   static void SetUpParameter(WebSocketHandshake* handshake,
     22                              uint32 number_1, uint32 number_2,
     23                              const std::string& key_1, const std::string& key_2,
     24                              const std::string& key_3) {
     25     WebSocketHandshake::Parameter* parameter =
     26         new WebSocketHandshake::Parameter;
     27     parameter->number_1_ = number_1;
     28     parameter->number_2_ = number_2;
     29     parameter->key_1_ = key_1;
     30     parameter->key_2_ = key_2;
     31     parameter->key_3_ = key_3;
     32     handshake->parameter_.reset(parameter);
     33   }
     34 
     35   static void ExpectHeaderEquals(const std::string& expected,
     36                           const std::string& actual) {
     37     std::vector<std::string> expected_lines;
     38     Tokenize(expected, "\r\n", &expected_lines);
     39     std::vector<std::string> actual_lines;
     40     Tokenize(actual, "\r\n", &actual_lines);
     41     // Request lines.
     42     EXPECT_EQ(expected_lines[0], actual_lines[0]);
     43 
     44     std::vector<std::string> expected_headers;
     45     for (size_t i = 1; i < expected_lines.size(); i++) {
     46       // Finish at first CRLF CRLF.  Note that /key_3/ might include CRLF.
     47       if (expected_lines[i] == "")
     48         break;
     49       expected_headers.push_back(expected_lines[i]);
     50     }
     51     sort(expected_headers.begin(), expected_headers.end());
     52 
     53     std::vector<std::string> actual_headers;
     54     for (size_t i = 1; i < actual_lines.size(); i++) {
     55       // Finish at first CRLF CRLF.  Note that /key_3/ might include CRLF.
     56       if (actual_lines[i] == "")
     57         break;
     58       actual_headers.push_back(actual_lines[i]);
     59     }
     60     sort(actual_headers.begin(), actual_headers.end());
     61 
     62     EXPECT_EQ(expected_headers.size(), actual_headers.size())
     63         << "expected:" << expected
     64         << "\nactual:" << actual;
     65     for (size_t i = 0; i < expected_headers.size(); i++) {
     66       EXPECT_EQ(expected_headers[i], actual_headers[i]);
     67     }
     68   }
     69 
     70   static void ExpectHandshakeMessageEquals(const std::string& expected,
     71                                            const std::string& actual) {
     72     // Headers.
     73     ExpectHeaderEquals(expected, actual);
     74     // Compare tailing \r\n\r\n<key3> (4 + 8 bytes).
     75     ASSERT_GT(expected.size(), 12U);
     76     const char* expected_key3 = expected.data() + expected.size() - 12;
     77     EXPECT_GT(actual.size(), 12U);
     78     if (actual.size() <= 12U)
     79       return;
     80     const char* actual_key3 = actual.data() + actual.size() - 12;
     81     EXPECT_TRUE(memcmp(expected_key3, actual_key3, 12) == 0)
     82         << "expected_key3:" << DumpKey(expected_key3, 12)
     83         << ", actual_key3:" << DumpKey(actual_key3, 12);
     84   }
     85 
     86   static std::string DumpKey(const char* buf, int len) {
     87     std::string s;
     88     for (int i = 0; i < len; i++) {
     89       if (isprint(buf[i]))
     90         s += base::StringPrintf("%c", buf[i]);
     91       else
     92         s += base::StringPrintf("\\x%02x", buf[i]);
     93     }
     94     return s;
     95   }
     96 
     97   static std::string GetResourceName(WebSocketHandshake* handshake) {
     98     return handshake->GetResourceName();
     99   }
    100   static std::string GetHostFieldValue(WebSocketHandshake* handshake) {
    101     return handshake->GetHostFieldValue();
    102   }
    103   static std::string GetOriginFieldValue(WebSocketHandshake* handshake) {
    104     return handshake->GetOriginFieldValue();
    105   }
    106 };
    107 
    108 
    109 TEST_F(WebSocketHandshakeTest, Connect) {
    110   const std::string kExpectedClientHandshakeMessage =
    111       "GET /demo HTTP/1.1\r\n"
    112       "Upgrade: WebSocket\r\n"
    113       "Connection: Upgrade\r\n"
    114       "Host: example.com\r\n"
    115       "Origin: http://example.com\r\n"
    116       "Sec-WebSocket-Protocol: sample\r\n"
    117       "Sec-WebSocket-Key1: 388P O503D&ul7 {K%gX( %7  15\r\n"
    118       "Sec-WebSocket-Key2: 1 N ?|k UT0or 3o  4 I97N 5-S3O 31\r\n"
    119       "\r\n"
    120       "\x47\x30\x22\x2D\x5A\x3F\x47\x58";
    121 
    122   scoped_ptr<WebSocketHandshake> handshake(
    123       new WebSocketHandshake(GURL("ws://example.com/demo"),
    124                              "http://example.com",
    125                              "ws://example.com/demo",
    126                              "sample"));
    127   SetUpParameter(handshake.get(), 777007543U, 114997259U,
    128                  "388P O503D&ul7 {K%gX( %7  15",
    129                  "1 N ?|k UT0or 3o  4 I97N 5-S3O 31",
    130                  std::string("\x47\x30\x22\x2D\x5A\x3F\x47\x58", 8));
    131   EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode());
    132   ExpectHandshakeMessageEquals(
    133       kExpectedClientHandshakeMessage,
    134       handshake->CreateClientHandshakeMessage());
    135 
    136   const char kResponse[] = "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
    137       "Upgrade: WebSocket\r\n"
    138       "Connection: Upgrade\r\n"
    139       "Sec-WebSocket-Origin: http://example.com\r\n"
    140       "Sec-WebSocket-Location: ws://example.com/demo\r\n"
    141       "Sec-WebSocket-Protocol: sample\r\n"
    142       "\r\n"
    143       "\x30\x73\x74\x33\x52\x6C\x26\x71\x2D\x32\x5A\x55\x5E\x77\x65\x75";
    144   std::vector<std::string> response_lines;
    145   base::SplitStringDontTrim(kResponse, '\n', &response_lines);
    146 
    147   EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode());
    148   // too short
    149   EXPECT_EQ(-1, handshake->ReadServerHandshake(kResponse, 16));
    150   EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode());
    151 
    152   // only status line
    153   std::string response = response_lines[0];
    154   EXPECT_EQ(-1, handshake->ReadServerHandshake(
    155       response.data(), response.size()));
    156   EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode());
    157   // by upgrade header
    158   response += response_lines[1];
    159   EXPECT_EQ(-1, handshake->ReadServerHandshake(
    160       response.data(), response.size()));
    161   EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode());
    162   // by connection header
    163   response += response_lines[2];
    164   EXPECT_EQ(-1, handshake->ReadServerHandshake(
    165       response.data(), response.size()));
    166   EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode());
    167 
    168   response += response_lines[3];  // Sec-WebSocket-Origin
    169   response += response_lines[4];  // Sec-WebSocket-Location
    170   response += response_lines[5];  // Sec-WebSocket-Protocol
    171   EXPECT_EQ(-1, handshake->ReadServerHandshake(
    172       response.data(), response.size()));
    173   EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode());
    174 
    175   response += response_lines[6];  // \r\n
    176   EXPECT_EQ(-1, handshake->ReadServerHandshake(
    177       response.data(), response.size()));
    178   EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode());
    179 
    180   int handshake_length = sizeof(kResponse) - 1;  // -1 for terminating \0
    181   EXPECT_EQ(handshake_length, handshake->ReadServerHandshake(
    182       kResponse, handshake_length));  // -1 for terminating \0
    183   EXPECT_EQ(WebSocketHandshake::MODE_CONNECTED, handshake->mode());
    184 }
    185 
    186 TEST_F(WebSocketHandshakeTest, ServerSentData) {
    187   const std::string kExpectedClientHandshakeMessage =
    188       "GET /demo HTTP/1.1\r\n"
    189       "Upgrade: WebSocket\r\n"
    190       "Connection: Upgrade\r\n"
    191       "Host: example.com\r\n"
    192       "Origin: http://example.com\r\n"
    193       "Sec-WebSocket-Protocol: sample\r\n"
    194       "Sec-WebSocket-Key1: 388P O503D&ul7 {K%gX( %7  15\r\n"
    195       "Sec-WebSocket-Key2: 1 N ?|k UT0or 3o  4 I97N 5-S3O 31\r\n"
    196       "\r\n"
    197       "\x47\x30\x22\x2D\x5A\x3F\x47\x58";
    198   scoped_ptr<WebSocketHandshake> handshake(
    199       new WebSocketHandshake(GURL("ws://example.com/demo"),
    200                              "http://example.com",
    201                              "ws://example.com/demo",
    202                              "sample"));
    203   SetUpParameter(handshake.get(), 777007543U, 114997259U,
    204                  "388P O503D&ul7 {K%gX( %7  15",
    205                  "1 N ?|k UT0or 3o  4 I97N 5-S3O 31",
    206                  std::string("\x47\x30\x22\x2D\x5A\x3F\x47\x58", 8));
    207   EXPECT_EQ(WebSocketHandshake::MODE_INCOMPLETE, handshake->mode());
    208   ExpectHandshakeMessageEquals(
    209       kExpectedClientHandshakeMessage,
    210       handshake->CreateClientHandshakeMessage());
    211 
    212   const char kResponse[] = "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
    213       "Upgrade: WebSocket\r\n"
    214       "Connection: Upgrade\r\n"
    215       "Sec-WebSocket-Origin: http://example.com\r\n"
    216       "Sec-WebSocket-Location: ws://example.com/demo\r\n"
    217       "Sec-WebSocket-Protocol: sample\r\n"
    218       "\r\n"
    219       "\x30\x73\x74\x33\x52\x6C\x26\x71\x2D\x32\x5A\x55\x5E\x77\x65\x75"
    220       "\0Hello\xff";
    221 
    222   int handshake_length = strlen(kResponse);  // key3 doesn't contain \0.
    223   EXPECT_EQ(handshake_length, handshake->ReadServerHandshake(
    224       kResponse, sizeof(kResponse) - 1));  // -1 for terminating \0
    225   EXPECT_EQ(WebSocketHandshake::MODE_CONNECTED, handshake->mode());
    226 }
    227 
    228 TEST_F(WebSocketHandshakeTest, is_secure_false) {
    229   scoped_ptr<WebSocketHandshake> handshake(
    230       new WebSocketHandshake(GURL("ws://example.com/demo"),
    231                              "http://example.com",
    232                              "ws://example.com/demo",
    233                              "sample"));
    234   EXPECT_FALSE(handshake->is_secure());
    235 }
    236 
    237 TEST_F(WebSocketHandshakeTest, is_secure_true) {
    238   // wss:// is secure.
    239   scoped_ptr<WebSocketHandshake> handshake(
    240       new WebSocketHandshake(GURL("wss://example.com/demo"),
    241                              "http://example.com",
    242                              "wss://example.com/demo",
    243                              "sample"));
    244   EXPECT_TRUE(handshake->is_secure());
    245 }
    246 
    247 TEST_F(WebSocketHandshakeTest, CreateClientHandshakeMessage_ResourceName) {
    248   scoped_ptr<WebSocketHandshake> handshake(
    249       new WebSocketHandshake(GURL("ws://example.com/Test?q=xxx&p=%20"),
    250                              "http://example.com",
    251                              "ws://example.com/demo",
    252                              "sample"));
    253   // Path and query should be preserved as-is.
    254   EXPECT_EQ("/Test?q=xxx&p=%20", GetResourceName(handshake.get()));
    255 }
    256 
    257 TEST_F(WebSocketHandshakeTest, CreateClientHandshakeMessage_Host) {
    258   scoped_ptr<WebSocketHandshake> handshake(
    259       new WebSocketHandshake(GURL("ws://Example.Com/demo"),
    260                              "http://Example.Com",
    261                              "ws://Example.Com/demo",
    262                              "sample"));
    263   // Host should be lowercased
    264   EXPECT_EQ("example.com", GetHostFieldValue(handshake.get()));
    265   EXPECT_EQ("http://example.com", GetOriginFieldValue(handshake.get()));
    266 }
    267 
    268 TEST_F(WebSocketHandshakeTest, CreateClientHandshakeMessage_TrimPort80) {
    269   scoped_ptr<WebSocketHandshake> handshake(
    270       new WebSocketHandshake(GURL("ws://example.com:80/demo"),
    271                              "http://example.com",
    272                              "ws://example.com/demo",
    273                              "sample"));
    274   // :80 should be trimmed as it's the default port for ws://.
    275   EXPECT_EQ("example.com", GetHostFieldValue(handshake.get()));
    276 }
    277 
    278 TEST_F(WebSocketHandshakeTest, CreateClientHandshakeMessage_TrimPort443) {
    279   scoped_ptr<WebSocketHandshake> handshake(
    280       new WebSocketHandshake(GURL("wss://example.com:443/demo"),
    281                              "http://example.com",
    282                              "wss://example.com/demo",
    283                              "sample"));
    284   // :443 should be trimmed as it's the default port for wss://.
    285   EXPECT_EQ("example.com", GetHostFieldValue(handshake.get()));
    286 }
    287 
    288 TEST_F(WebSocketHandshakeTest,
    289        CreateClientHandshakeMessage_NonDefaultPortForWs) {
    290   scoped_ptr<WebSocketHandshake> handshake(
    291       new WebSocketHandshake(GURL("ws://example.com:8080/demo"),
    292                              "http://example.com",
    293                              "wss://example.com/demo",
    294                              "sample"));
    295   // :8080 should be preserved as it's not the default port for ws://.
    296   EXPECT_EQ("example.com:8080", GetHostFieldValue(handshake.get()));
    297 }
    298 
    299 TEST_F(WebSocketHandshakeTest,
    300      CreateClientHandshakeMessage_NonDefaultPortForWss) {
    301   scoped_ptr<WebSocketHandshake> handshake(
    302       new WebSocketHandshake(GURL("wss://example.com:4443/demo"),
    303                              "http://example.com",
    304                              "wss://example.com/demo",
    305                              "sample"));
    306   // :4443 should be preserved as it's not the default port for wss://.
    307   EXPECT_EQ("example.com:4443", GetHostFieldValue(handshake.get()));
    308 }
    309 
    310 TEST_F(WebSocketHandshakeTest, CreateClientHandshakeMessage_WsBut443) {
    311   scoped_ptr<WebSocketHandshake> handshake(
    312       new WebSocketHandshake(GURL("ws://example.com:443/demo"),
    313                              "http://example.com",
    314                              "ws://example.com/demo",
    315                              "sample"));
    316   // :443 should be preserved as it's not the default port for ws://.
    317   EXPECT_EQ("example.com:443", GetHostFieldValue(handshake.get()));
    318 }
    319 
    320 TEST_F(WebSocketHandshakeTest, CreateClientHandshakeMessage_WssBut80) {
    321   scoped_ptr<WebSocketHandshake> handshake(
    322       new WebSocketHandshake(GURL("wss://example.com:80/demo"),
    323                              "http://example.com",
    324                              "wss://example.com/demo",
    325                              "sample"));
    326   // :80 should be preserved as it's not the default port for wss://.
    327   EXPECT_EQ("example.com:80", GetHostFieldValue(handshake.get()));
    328 }
    329 
    330 }  // namespace net
    331