Home | History | Annotate | Download | only in websockets
      1 // Copyright (c) 2009 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include <string>
      6 #include <vector>
      7 
      8 #include "base/callback.h"
      9 #include "net/base/completion_callback.h"
     10 #include "net/base/io_buffer.h"
     11 #include "net/base/mock_host_resolver.h"
     12 #include "net/base/test_completion_callback.h"
     13 #include "net/socket/socket_test_util.h"
     14 #include "net/url_request/url_request_test_util.h"
     15 #include "net/websockets/websocket.h"
     16 #include "testing/gtest/include/gtest/gtest.h"
     17 #include "testing/gmock/include/gmock/gmock.h"
     18 #include "testing/platform_test.h"
     19 
     20 struct WebSocketEvent {
     21   enum EventType {
     22     EVENT_OPEN, EVENT_MESSAGE, EVENT_ERROR, EVENT_CLOSE,
     23   };
     24 
     25   WebSocketEvent(EventType type, net::WebSocket* websocket,
     26                  const std::string& websocket_msg, bool websocket_flag)
     27       : event_type(type), socket(websocket), msg(websocket_msg),
     28         flag(websocket_flag) {}
     29 
     30   EventType event_type;
     31   net::WebSocket* socket;
     32   std::string msg;
     33   bool flag;
     34 };
     35 
     36 class WebSocketEventRecorder : public net::WebSocketDelegate {
     37  public:
     38   explicit WebSocketEventRecorder(net::CompletionCallback* callback)
     39       : onopen_(NULL),
     40         onmessage_(NULL),
     41         onerror_(NULL),
     42         onclose_(NULL),
     43         callback_(callback) {}
     44   virtual ~WebSocketEventRecorder() {
     45     delete onopen_;
     46     delete onmessage_;
     47     delete onerror_;
     48     delete onclose_;
     49   }
     50 
     51   void SetOnOpen(Callback1<WebSocketEvent*>::Type* callback) {
     52     onopen_ = callback;
     53   }
     54   void SetOnMessage(Callback1<WebSocketEvent*>::Type* callback) {
     55     onmessage_ = callback;
     56   }
     57   void SetOnClose(Callback1<WebSocketEvent*>::Type* callback) {
     58     onclose_ = callback;
     59   }
     60 
     61   virtual void OnOpen(net::WebSocket* socket) {
     62     events_.push_back(
     63         WebSocketEvent(WebSocketEvent::EVENT_OPEN, socket,
     64                        std::string(), false));
     65     if (onopen_)
     66       onopen_->Run(&events_.back());
     67   }
     68 
     69   virtual void OnMessage(net::WebSocket* socket, const std::string& msg) {
     70     events_.push_back(
     71         WebSocketEvent(WebSocketEvent::EVENT_MESSAGE, socket, msg, false));
     72     if (onmessage_)
     73       onmessage_->Run(&events_.back());
     74   }
     75   virtual void OnError(net::WebSocket* socket) {
     76     events_.push_back(
     77         WebSocketEvent(WebSocketEvent::EVENT_ERROR, socket,
     78                        std::string(), false));
     79     if (onerror_)
     80       onerror_->Run(&events_.back());
     81   }
     82   virtual void OnClose(net::WebSocket* socket, bool was_clean) {
     83     events_.push_back(
     84         WebSocketEvent(WebSocketEvent::EVENT_CLOSE, socket,
     85                        std::string(), was_clean));
     86     if (onclose_)
     87       onclose_->Run(&events_.back());
     88     if (callback_)
     89       callback_->Run(net::OK);
     90   }
     91 
     92   void DoClose(WebSocketEvent* event) {
     93     event->socket->Close();
     94   }
     95 
     96   const std::vector<WebSocketEvent>& GetSeenEvents() const {
     97     return events_;
     98   }
     99 
    100  private:
    101   std::vector<WebSocketEvent> events_;
    102   Callback1<WebSocketEvent*>::Type* onopen_;
    103   Callback1<WebSocketEvent*>::Type* onmessage_;
    104   Callback1<WebSocketEvent*>::Type* onerror_;
    105   Callback1<WebSocketEvent*>::Type* onclose_;
    106   net::CompletionCallback* callback_;
    107 
    108   DISALLOW_COPY_AND_ASSIGN(WebSocketEventRecorder);
    109 };
    110 
    111 namespace net {
    112 
    113 class WebSocketTest : public PlatformTest {
    114  protected:
    115   void InitReadBuf(WebSocket* websocket) {
    116     // Set up |current_read_buf_|.
    117     websocket->current_read_buf_ = new GrowableIOBuffer();
    118   }
    119   void SetReadConsumed(WebSocket* websocket, int consumed) {
    120     websocket->read_consumed_len_ = consumed;
    121   }
    122   void AddToReadBuf(WebSocket* websocket, const char* data, int len) {
    123     websocket->AddToReadBuffer(data, len);
    124   }
    125 
    126   void TestProcessFrameData(WebSocket* websocket,
    127                             const char* expected_remaining_data,
    128                             int expected_remaining_len) {
    129     websocket->ProcessFrameData();
    130 
    131     const char* actual_remaining_data =
    132         websocket->current_read_buf_->StartOfBuffer()
    133         + websocket->read_consumed_len_;
    134     int actual_remaining_len =
    135         websocket->current_read_buf_->offset() - websocket->read_consumed_len_;
    136 
    137     EXPECT_EQ(expected_remaining_len, actual_remaining_len);
    138     EXPECT_TRUE(!memcmp(expected_remaining_data, actual_remaining_data,
    139                         expected_remaining_len));
    140   }
    141 };
    142 
    143 TEST_F(WebSocketTest, Connect) {
    144   MockClientSocketFactory mock_socket_factory;
    145   MockRead data_reads[] = {
    146     MockRead("HTTP/1.1 101 Web Socket Protocol Handshake\r\n"
    147              "Upgrade: WebSocket\r\n"
    148              "Connection: Upgrade\r\n"
    149              "WebSocket-Origin: http://example.com\r\n"
    150              "WebSocket-Location: ws://example.com/demo\r\n"
    151              "WebSocket-Protocol: sample\r\n"
    152              "\r\n"),
    153     // Server doesn't close the connection after handshake.
    154     MockRead(true, ERR_IO_PENDING),
    155   };
    156   MockWrite data_writes[] = {
    157     MockWrite("GET /demo HTTP/1.1\r\n"
    158               "Upgrade: WebSocket\r\n"
    159               "Connection: Upgrade\r\n"
    160               "Host: example.com\r\n"
    161               "Origin: http://example.com\r\n"
    162               "WebSocket-Protocol: sample\r\n"
    163               "\r\n"),
    164   };
    165   StaticSocketDataProvider data(data_reads, arraysize(data_reads),
    166                                 data_writes, arraysize(data_writes));
    167   mock_socket_factory.AddSocketDataProvider(&data);
    168   MockHostResolver host_resolver;
    169 
    170   WebSocket::Request* request(
    171       new WebSocket::Request(GURL("ws://example.com/demo"),
    172                              "sample",
    173                              "http://example.com",
    174                              "ws://example.com/demo",
    175                              WebSocket::DRAFT75,
    176                              new TestURLRequestContext()));
    177   request->SetHostResolver(&host_resolver);
    178   request->SetClientSocketFactory(&mock_socket_factory);
    179 
    180   TestCompletionCallback callback;
    181 
    182   scoped_ptr<WebSocketEventRecorder> delegate(
    183       new WebSocketEventRecorder(&callback));
    184   delegate->SetOnOpen(NewCallback(delegate.get(),
    185                                   &WebSocketEventRecorder::DoClose));
    186 
    187   scoped_refptr<WebSocket> websocket(
    188       new WebSocket(request, delegate.get()));
    189 
    190   EXPECT_EQ(WebSocket::INITIALIZED, websocket->ready_state());
    191   websocket->Connect();
    192 
    193   callback.WaitForResult();
    194 
    195   const std::vector<WebSocketEvent>& events = delegate->GetSeenEvents();
    196   EXPECT_EQ(2U, events.size());
    197 
    198   EXPECT_EQ(WebSocketEvent::EVENT_OPEN, events[0].event_type);
    199   EXPECT_EQ(WebSocketEvent::EVENT_CLOSE, events[1].event_type);
    200 }
    201 
    202 TEST_F(WebSocketTest, ServerSentData) {
    203   MockClientSocketFactory mock_socket_factory;
    204   static const char kMessage[] = "Hello";
    205   static const char kFrame[] = "\x00Hello\xff";
    206   static const int kFrameLen = sizeof(kFrame) - 1;
    207   MockRead data_reads[] = {
    208     MockRead("HTTP/1.1 101 Web Socket Protocol Handshake\r\n"
    209              "Upgrade: WebSocket\r\n"
    210              "Connection: Upgrade\r\n"
    211              "WebSocket-Origin: http://example.com\r\n"
    212              "WebSocket-Location: ws://example.com/demo\r\n"
    213              "WebSocket-Protocol: sample\r\n"
    214              "\r\n"),
    215     MockRead(true, kFrame, kFrameLen),
    216     // Server doesn't close the connection after handshake.
    217     MockRead(true, ERR_IO_PENDING),
    218   };
    219   MockWrite data_writes[] = {
    220     MockWrite("GET /demo HTTP/1.1\r\n"
    221               "Upgrade: WebSocket\r\n"
    222               "Connection: Upgrade\r\n"
    223               "Host: example.com\r\n"
    224               "Origin: http://example.com\r\n"
    225               "WebSocket-Protocol: sample\r\n"
    226               "\r\n"),
    227   };
    228   StaticSocketDataProvider data(data_reads, arraysize(data_reads),
    229                                 data_writes, arraysize(data_writes));
    230   mock_socket_factory.AddSocketDataProvider(&data);
    231   MockHostResolver host_resolver;
    232 
    233   WebSocket::Request* request(
    234       new WebSocket::Request(GURL("ws://example.com/demo"),
    235                              "sample",
    236                              "http://example.com",
    237                              "ws://example.com/demo",
    238                              WebSocket::DRAFT75,
    239                              new TestURLRequestContext()));
    240   request->SetHostResolver(&host_resolver);
    241   request->SetClientSocketFactory(&mock_socket_factory);
    242 
    243   TestCompletionCallback callback;
    244 
    245   scoped_ptr<WebSocketEventRecorder> delegate(
    246       new WebSocketEventRecorder(&callback));
    247   delegate->SetOnMessage(NewCallback(delegate.get(),
    248                                      &WebSocketEventRecorder::DoClose));
    249 
    250   scoped_refptr<WebSocket> websocket(
    251       new WebSocket(request, delegate.get()));
    252 
    253   EXPECT_EQ(WebSocket::INITIALIZED, websocket->ready_state());
    254   websocket->Connect();
    255 
    256   callback.WaitForResult();
    257 
    258   const std::vector<WebSocketEvent>& events = delegate->GetSeenEvents();
    259   EXPECT_EQ(3U, events.size());
    260 
    261   EXPECT_EQ(WebSocketEvent::EVENT_OPEN, events[0].event_type);
    262   EXPECT_EQ(WebSocketEvent::EVENT_MESSAGE, events[1].event_type);
    263   EXPECT_EQ(kMessage, events[1].msg);
    264   EXPECT_EQ(WebSocketEvent::EVENT_CLOSE, events[2].event_type);
    265 }
    266 
    267 TEST_F(WebSocketTest, ProcessFrameDataForLengthCalculation) {
    268   WebSocket::Request* request(
    269       new WebSocket::Request(GURL("ws://example.com/demo"),
    270                              "sample",
    271                              "http://example.com",
    272                              "ws://example.com/demo",
    273                              WebSocket::DRAFT75,
    274                              new TestURLRequestContext()));
    275   TestCompletionCallback callback;
    276   scoped_ptr<WebSocketEventRecorder> delegate(
    277       new WebSocketEventRecorder(&callback));
    278 
    279   scoped_refptr<WebSocket> websocket(
    280       new WebSocket(request, delegate.get()));
    281 
    282   // Frame data: skip length 1 ('x'), and try to skip length 129
    283   // (1 * 128 + 1) bytes after \x81\x01, but buffer is too short to skip.
    284   static const char kTestLengthFrame[] =
    285       "\x80\x01x\x80\x81\x01\x01\x00unexpected data\xFF";
    286   const int kTestLengthFrameLength = sizeof(kTestLengthFrame) - 1;
    287   InitReadBuf(websocket.get());
    288   AddToReadBuf(websocket.get(), kTestLengthFrame, kTestLengthFrameLength);
    289   SetReadConsumed(websocket.get(), 0);
    290 
    291   static const char kExpectedRemainingFrame[] =
    292       "\x80\x81\x01\x01\x00unexpected data\xFF";
    293   const int kExpectedRemainingLength = sizeof(kExpectedRemainingFrame) - 1;
    294   TestProcessFrameData(websocket.get(),
    295                        kExpectedRemainingFrame, kExpectedRemainingLength);
    296   // No onmessage event expected.
    297   const std::vector<WebSocketEvent>& events = delegate->GetSeenEvents();
    298   EXPECT_EQ(1U, events.size());
    299 
    300   EXPECT_EQ(WebSocketEvent::EVENT_ERROR, events[0].event_type);
    301 
    302   websocket->DetachDelegate();
    303 }
    304 
    305 TEST_F(WebSocketTest, ProcessFrameDataForUnterminatedString) {
    306   WebSocket::Request* request(
    307       new WebSocket::Request(GURL("ws://example.com/demo"),
    308                              "sample",
    309                              "http://example.com",
    310                              "ws://example.com/demo",
    311                              WebSocket::DRAFT75,
    312                              new TestURLRequestContext()));
    313   TestCompletionCallback callback;
    314   scoped_ptr<WebSocketEventRecorder> delegate(
    315       new WebSocketEventRecorder(&callback));
    316 
    317   scoped_refptr<WebSocket> websocket(
    318       new WebSocket(request, delegate.get()));
    319 
    320   static const char kTestUnterminatedFrame[] =
    321       "\x00unterminated frame";
    322   const int kTestUnterminatedFrameLength = sizeof(kTestUnterminatedFrame) - 1;
    323   InitReadBuf(websocket.get());
    324   AddToReadBuf(websocket.get(), kTestUnterminatedFrame,
    325                kTestUnterminatedFrameLength);
    326   SetReadConsumed(websocket.get(), 0);
    327   TestProcessFrameData(websocket.get(),
    328                        kTestUnterminatedFrame, kTestUnterminatedFrameLength);
    329   {
    330     // No onmessage event expected.
    331     const std::vector<WebSocketEvent>& events = delegate->GetSeenEvents();
    332     EXPECT_EQ(0U, events.size());
    333   }
    334 
    335   static const char kTestTerminateFrame[] = " is terminated in next read\xff";
    336   const int kTestTerminateFrameLength = sizeof(kTestTerminateFrame) - 1;
    337   AddToReadBuf(websocket.get(), kTestTerminateFrame,
    338                kTestTerminateFrameLength);
    339   TestProcessFrameData(websocket.get(), "", 0);
    340 
    341   static const char kExpectedMsg[] =
    342       "unterminated frame is terminated in next read";
    343   {
    344     const std::vector<WebSocketEvent>& events = delegate->GetSeenEvents();
    345     EXPECT_EQ(1U, events.size());
    346 
    347     EXPECT_EQ(WebSocketEvent::EVENT_MESSAGE, events[0].event_type);
    348     EXPECT_EQ(kExpectedMsg, events[0].msg);
    349   }
    350 
    351   websocket->DetachDelegate();
    352 }
    353 
    354 }  // namespace net
    355