1 // Copyright (c) 2009 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include <string> 6 #include <vector> 7 8 #include "base/callback.h" 9 #include "net/base/completion_callback.h" 10 #include "net/base/io_buffer.h" 11 #include "net/base/mock_host_resolver.h" 12 #include "net/base/test_completion_callback.h" 13 #include "net/socket/socket_test_util.h" 14 #include "net/url_request/url_request_test_util.h" 15 #include "net/websockets/websocket.h" 16 #include "testing/gtest/include/gtest/gtest.h" 17 #include "testing/gmock/include/gmock/gmock.h" 18 #include "testing/platform_test.h" 19 20 struct WebSocketEvent { 21 enum EventType { 22 EVENT_OPEN, EVENT_MESSAGE, EVENT_ERROR, EVENT_CLOSE, 23 }; 24 25 WebSocketEvent(EventType type, net::WebSocket* websocket, 26 const std::string& websocket_msg, bool websocket_flag) 27 : event_type(type), socket(websocket), msg(websocket_msg), 28 flag(websocket_flag) {} 29 30 EventType event_type; 31 net::WebSocket* socket; 32 std::string msg; 33 bool flag; 34 }; 35 36 class WebSocketEventRecorder : public net::WebSocketDelegate { 37 public: 38 explicit WebSocketEventRecorder(net::CompletionCallback* callback) 39 : onopen_(NULL), 40 onmessage_(NULL), 41 onerror_(NULL), 42 onclose_(NULL), 43 callback_(callback) {} 44 virtual ~WebSocketEventRecorder() { 45 delete onopen_; 46 delete onmessage_; 47 delete onerror_; 48 delete onclose_; 49 } 50 51 void SetOnOpen(Callback1<WebSocketEvent*>::Type* callback) { 52 onopen_ = callback; 53 } 54 void SetOnMessage(Callback1<WebSocketEvent*>::Type* callback) { 55 onmessage_ = callback; 56 } 57 void SetOnClose(Callback1<WebSocketEvent*>::Type* callback) { 58 onclose_ = callback; 59 } 60 61 virtual void OnOpen(net::WebSocket* socket) { 62 events_.push_back( 63 WebSocketEvent(WebSocketEvent::EVENT_OPEN, socket, 64 std::string(), false)); 65 if (onopen_) 66 onopen_->Run(&events_.back()); 67 } 68 69 virtual void OnMessage(net::WebSocket* socket, const std::string& msg) { 70 events_.push_back( 71 WebSocketEvent(WebSocketEvent::EVENT_MESSAGE, socket, msg, false)); 72 if (onmessage_) 73 onmessage_->Run(&events_.back()); 74 } 75 virtual void OnError(net::WebSocket* socket) { 76 events_.push_back( 77 WebSocketEvent(WebSocketEvent::EVENT_ERROR, socket, 78 std::string(), false)); 79 if (onerror_) 80 onerror_->Run(&events_.back()); 81 } 82 virtual void OnClose(net::WebSocket* socket, bool was_clean) { 83 events_.push_back( 84 WebSocketEvent(WebSocketEvent::EVENT_CLOSE, socket, 85 std::string(), was_clean)); 86 if (onclose_) 87 onclose_->Run(&events_.back()); 88 if (callback_) 89 callback_->Run(net::OK); 90 } 91 92 void DoClose(WebSocketEvent* event) { 93 event->socket->Close(); 94 } 95 96 const std::vector<WebSocketEvent>& GetSeenEvents() const { 97 return events_; 98 } 99 100 private: 101 std::vector<WebSocketEvent> events_; 102 Callback1<WebSocketEvent*>::Type* onopen_; 103 Callback1<WebSocketEvent*>::Type* onmessage_; 104 Callback1<WebSocketEvent*>::Type* onerror_; 105 Callback1<WebSocketEvent*>::Type* onclose_; 106 net::CompletionCallback* callback_; 107 108 DISALLOW_COPY_AND_ASSIGN(WebSocketEventRecorder); 109 }; 110 111 namespace net { 112 113 class WebSocketTest : public PlatformTest { 114 protected: 115 void InitReadBuf(WebSocket* websocket) { 116 // Set up |current_read_buf_|. 117 websocket->current_read_buf_ = new GrowableIOBuffer(); 118 } 119 void SetReadConsumed(WebSocket* websocket, int consumed) { 120 websocket->read_consumed_len_ = consumed; 121 } 122 void AddToReadBuf(WebSocket* websocket, const char* data, int len) { 123 websocket->AddToReadBuffer(data, len); 124 } 125 126 void TestProcessFrameData(WebSocket* websocket, 127 const char* expected_remaining_data, 128 int expected_remaining_len) { 129 websocket->ProcessFrameData(); 130 131 const char* actual_remaining_data = 132 websocket->current_read_buf_->StartOfBuffer() 133 + websocket->read_consumed_len_; 134 int actual_remaining_len = 135 websocket->current_read_buf_->offset() - websocket->read_consumed_len_; 136 137 EXPECT_EQ(expected_remaining_len, actual_remaining_len); 138 EXPECT_TRUE(!memcmp(expected_remaining_data, actual_remaining_data, 139 expected_remaining_len)); 140 } 141 }; 142 143 TEST_F(WebSocketTest, Connect) { 144 MockClientSocketFactory mock_socket_factory; 145 MockRead data_reads[] = { 146 MockRead("HTTP/1.1 101 Web Socket Protocol Handshake\r\n" 147 "Upgrade: WebSocket\r\n" 148 "Connection: Upgrade\r\n" 149 "WebSocket-Origin: http://example.com\r\n" 150 "WebSocket-Location: ws://example.com/demo\r\n" 151 "WebSocket-Protocol: sample\r\n" 152 "\r\n"), 153 // Server doesn't close the connection after handshake. 154 MockRead(true, ERR_IO_PENDING), 155 }; 156 MockWrite data_writes[] = { 157 MockWrite("GET /demo HTTP/1.1\r\n" 158 "Upgrade: WebSocket\r\n" 159 "Connection: Upgrade\r\n" 160 "Host: example.com\r\n" 161 "Origin: http://example.com\r\n" 162 "WebSocket-Protocol: sample\r\n" 163 "\r\n"), 164 }; 165 StaticSocketDataProvider data(data_reads, arraysize(data_reads), 166 data_writes, arraysize(data_writes)); 167 mock_socket_factory.AddSocketDataProvider(&data); 168 MockHostResolver host_resolver; 169 170 WebSocket::Request* request( 171 new WebSocket::Request(GURL("ws://example.com/demo"), 172 "sample", 173 "http://example.com", 174 "ws://example.com/demo", 175 WebSocket::DRAFT75, 176 new TestURLRequestContext())); 177 request->SetHostResolver(&host_resolver); 178 request->SetClientSocketFactory(&mock_socket_factory); 179 180 TestCompletionCallback callback; 181 182 scoped_ptr<WebSocketEventRecorder> delegate( 183 new WebSocketEventRecorder(&callback)); 184 delegate->SetOnOpen(NewCallback(delegate.get(), 185 &WebSocketEventRecorder::DoClose)); 186 187 scoped_refptr<WebSocket> websocket( 188 new WebSocket(request, delegate.get())); 189 190 EXPECT_EQ(WebSocket::INITIALIZED, websocket->ready_state()); 191 websocket->Connect(); 192 193 callback.WaitForResult(); 194 195 const std::vector<WebSocketEvent>& events = delegate->GetSeenEvents(); 196 EXPECT_EQ(2U, events.size()); 197 198 EXPECT_EQ(WebSocketEvent::EVENT_OPEN, events[0].event_type); 199 EXPECT_EQ(WebSocketEvent::EVENT_CLOSE, events[1].event_type); 200 } 201 202 TEST_F(WebSocketTest, ServerSentData) { 203 MockClientSocketFactory mock_socket_factory; 204 static const char kMessage[] = "Hello"; 205 static const char kFrame[] = "\x00Hello\xff"; 206 static const int kFrameLen = sizeof(kFrame) - 1; 207 MockRead data_reads[] = { 208 MockRead("HTTP/1.1 101 Web Socket Protocol Handshake\r\n" 209 "Upgrade: WebSocket\r\n" 210 "Connection: Upgrade\r\n" 211 "WebSocket-Origin: http://example.com\r\n" 212 "WebSocket-Location: ws://example.com/demo\r\n" 213 "WebSocket-Protocol: sample\r\n" 214 "\r\n"), 215 MockRead(true, kFrame, kFrameLen), 216 // Server doesn't close the connection after handshake. 217 MockRead(true, ERR_IO_PENDING), 218 }; 219 MockWrite data_writes[] = { 220 MockWrite("GET /demo HTTP/1.1\r\n" 221 "Upgrade: WebSocket\r\n" 222 "Connection: Upgrade\r\n" 223 "Host: example.com\r\n" 224 "Origin: http://example.com\r\n" 225 "WebSocket-Protocol: sample\r\n" 226 "\r\n"), 227 }; 228 StaticSocketDataProvider data(data_reads, arraysize(data_reads), 229 data_writes, arraysize(data_writes)); 230 mock_socket_factory.AddSocketDataProvider(&data); 231 MockHostResolver host_resolver; 232 233 WebSocket::Request* request( 234 new WebSocket::Request(GURL("ws://example.com/demo"), 235 "sample", 236 "http://example.com", 237 "ws://example.com/demo", 238 WebSocket::DRAFT75, 239 new TestURLRequestContext())); 240 request->SetHostResolver(&host_resolver); 241 request->SetClientSocketFactory(&mock_socket_factory); 242 243 TestCompletionCallback callback; 244 245 scoped_ptr<WebSocketEventRecorder> delegate( 246 new WebSocketEventRecorder(&callback)); 247 delegate->SetOnMessage(NewCallback(delegate.get(), 248 &WebSocketEventRecorder::DoClose)); 249 250 scoped_refptr<WebSocket> websocket( 251 new WebSocket(request, delegate.get())); 252 253 EXPECT_EQ(WebSocket::INITIALIZED, websocket->ready_state()); 254 websocket->Connect(); 255 256 callback.WaitForResult(); 257 258 const std::vector<WebSocketEvent>& events = delegate->GetSeenEvents(); 259 EXPECT_EQ(3U, events.size()); 260 261 EXPECT_EQ(WebSocketEvent::EVENT_OPEN, events[0].event_type); 262 EXPECT_EQ(WebSocketEvent::EVENT_MESSAGE, events[1].event_type); 263 EXPECT_EQ(kMessage, events[1].msg); 264 EXPECT_EQ(WebSocketEvent::EVENT_CLOSE, events[2].event_type); 265 } 266 267 TEST_F(WebSocketTest, ProcessFrameDataForLengthCalculation) { 268 WebSocket::Request* request( 269 new WebSocket::Request(GURL("ws://example.com/demo"), 270 "sample", 271 "http://example.com", 272 "ws://example.com/demo", 273 WebSocket::DRAFT75, 274 new TestURLRequestContext())); 275 TestCompletionCallback callback; 276 scoped_ptr<WebSocketEventRecorder> delegate( 277 new WebSocketEventRecorder(&callback)); 278 279 scoped_refptr<WebSocket> websocket( 280 new WebSocket(request, delegate.get())); 281 282 // Frame data: skip length 1 ('x'), and try to skip length 129 283 // (1 * 128 + 1) bytes after \x81\x01, but buffer is too short to skip. 284 static const char kTestLengthFrame[] = 285 "\x80\x01x\x80\x81\x01\x01\x00unexpected data\xFF"; 286 const int kTestLengthFrameLength = sizeof(kTestLengthFrame) - 1; 287 InitReadBuf(websocket.get()); 288 AddToReadBuf(websocket.get(), kTestLengthFrame, kTestLengthFrameLength); 289 SetReadConsumed(websocket.get(), 0); 290 291 static const char kExpectedRemainingFrame[] = 292 "\x80\x81\x01\x01\x00unexpected data\xFF"; 293 const int kExpectedRemainingLength = sizeof(kExpectedRemainingFrame) - 1; 294 TestProcessFrameData(websocket.get(), 295 kExpectedRemainingFrame, kExpectedRemainingLength); 296 // No onmessage event expected. 297 const std::vector<WebSocketEvent>& events = delegate->GetSeenEvents(); 298 EXPECT_EQ(1U, events.size()); 299 300 EXPECT_EQ(WebSocketEvent::EVENT_ERROR, events[0].event_type); 301 302 websocket->DetachDelegate(); 303 } 304 305 TEST_F(WebSocketTest, ProcessFrameDataForUnterminatedString) { 306 WebSocket::Request* request( 307 new WebSocket::Request(GURL("ws://example.com/demo"), 308 "sample", 309 "http://example.com", 310 "ws://example.com/demo", 311 WebSocket::DRAFT75, 312 new TestURLRequestContext())); 313 TestCompletionCallback callback; 314 scoped_ptr<WebSocketEventRecorder> delegate( 315 new WebSocketEventRecorder(&callback)); 316 317 scoped_refptr<WebSocket> websocket( 318 new WebSocket(request, delegate.get())); 319 320 static const char kTestUnterminatedFrame[] = 321 "\x00unterminated frame"; 322 const int kTestUnterminatedFrameLength = sizeof(kTestUnterminatedFrame) - 1; 323 InitReadBuf(websocket.get()); 324 AddToReadBuf(websocket.get(), kTestUnterminatedFrame, 325 kTestUnterminatedFrameLength); 326 SetReadConsumed(websocket.get(), 0); 327 TestProcessFrameData(websocket.get(), 328 kTestUnterminatedFrame, kTestUnterminatedFrameLength); 329 { 330 // No onmessage event expected. 331 const std::vector<WebSocketEvent>& events = delegate->GetSeenEvents(); 332 EXPECT_EQ(0U, events.size()); 333 } 334 335 static const char kTestTerminateFrame[] = " is terminated in next read\xff"; 336 const int kTestTerminateFrameLength = sizeof(kTestTerminateFrame) - 1; 337 AddToReadBuf(websocket.get(), kTestTerminateFrame, 338 kTestTerminateFrameLength); 339 TestProcessFrameData(websocket.get(), "", 0); 340 341 static const char kExpectedMsg[] = 342 "unterminated frame is terminated in next read"; 343 { 344 const std::vector<WebSocketEvent>& events = delegate->GetSeenEvents(); 345 EXPECT_EQ(1U, events.size()); 346 347 EXPECT_EQ(WebSocketEvent::EVENT_MESSAGE, events[0].event_type); 348 EXPECT_EQ(kExpectedMsg, events[0].msg); 349 } 350 351 websocket->DetachDelegate(); 352 } 353 354 } // namespace net 355