1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include <string> 6 #include <vector> 7 8 #include "base/bind.h" 9 #include "base/compiler_specific.h" 10 #include "base/location.h" 11 #include "base/memory/ref_counted.h" 12 #include "base/memory/scoped_ptr.h" 13 #include "base/message_loop/message_loop.h" 14 #include "base/message_loop/message_loop_proxy.h" 15 #include "base/run_loop.h" 16 #include "base/single_thread_task_runner.h" 17 #include "base/threading/thread.h" 18 #include "base/time/time.h" 19 #include "chrome/test/chromedriver/net/test_http_server.h" 20 #include "chrome/test/chromedriver/net/websocket.h" 21 #include "net/url_request/url_request_test_util.h" 22 #include "testing/gtest/include/gtest/gtest.h" 23 #include "url/gurl.h" 24 25 namespace { 26 27 void OnConnectFinished(base::RunLoop* run_loop, int* save_error, int error) { 28 *save_error = error; 29 run_loop->Quit(); 30 } 31 32 void RunPending(base::MessageLoop* loop) { 33 base::RunLoop run_loop; 34 loop->PostTask(FROM_HERE, run_loop.QuitClosure()); 35 run_loop.Run(); 36 } 37 38 class Listener : public WebSocketListener { 39 public: 40 explicit Listener(const std::vector<std::string>& messages) 41 : messages_(messages) {} 42 43 virtual ~Listener() { 44 EXPECT_TRUE(messages_.empty()); 45 } 46 47 virtual void OnMessageReceived(const std::string& message) OVERRIDE { 48 ASSERT_TRUE(messages_.size()); 49 EXPECT_EQ(messages_[0], message); 50 messages_.erase(messages_.begin()); 51 if (messages_.empty()) 52 base::MessageLoop::current()->Quit(); 53 } 54 55 virtual void OnClose() OVERRIDE { 56 EXPECT_TRUE(false); 57 } 58 59 private: 60 std::vector<std::string> messages_; 61 }; 62 63 class CloseListener : public WebSocketListener { 64 public: 65 explicit CloseListener(base::RunLoop* run_loop) 66 : run_loop_(run_loop) {} 67 68 virtual ~CloseListener() { 69 EXPECT_FALSE(run_loop_); 70 } 71 72 virtual void OnMessageReceived(const std::string& message) OVERRIDE {} 73 74 virtual void OnClose() OVERRIDE { 75 EXPECT_TRUE(run_loop_); 76 if (run_loop_) 77 run_loop_->Quit(); 78 run_loop_ = NULL; 79 } 80 81 private: 82 base::RunLoop* run_loop_; 83 }; 84 85 class WebSocketTest : public testing::Test { 86 public: 87 WebSocketTest() {} 88 virtual ~WebSocketTest() {} 89 90 virtual void SetUp() OVERRIDE { 91 ASSERT_TRUE(server_.Start()); 92 } 93 94 virtual void TearDown() OVERRIDE { 95 server_.Stop(); 96 } 97 98 protected: 99 scoped_ptr<WebSocket> CreateWebSocket(const GURL& url, 100 WebSocketListener* listener) { 101 int error; 102 scoped_ptr<WebSocket> sock(new WebSocket(url, listener)); 103 base::RunLoop run_loop; 104 sock->Connect(base::Bind(&OnConnectFinished, &run_loop, &error)); 105 loop_.PostDelayedTask( 106 FROM_HERE, run_loop.QuitClosure(), 107 base::TimeDelta::FromSeconds(10)); 108 run_loop.Run(); 109 if (error == net::OK) 110 return sock.Pass(); 111 return scoped_ptr<WebSocket>(); 112 } 113 114 scoped_ptr<WebSocket> CreateConnectedWebSocket(WebSocketListener* listener) { 115 return CreateWebSocket(server_.web_socket_url(), listener); 116 } 117 118 void SendReceive(const std::vector<std::string>& messages) { 119 Listener listener(messages); 120 scoped_ptr<WebSocket> sock(CreateConnectedWebSocket(&listener)); 121 ASSERT_TRUE(sock); 122 for (size_t i = 0; i < messages.size(); ++i) { 123 ASSERT_TRUE(sock->Send(messages[i])); 124 } 125 base::RunLoop run_loop; 126 loop_.PostDelayedTask( 127 FROM_HERE, run_loop.QuitClosure(), 128 base::TimeDelta::FromSeconds(10)); 129 run_loop.Run(); 130 } 131 132 base::MessageLoopForIO loop_; 133 TestHttpServer server_; 134 }; 135 136 } // namespace 137 138 TEST_F(WebSocketTest, CreateDestroy) { 139 CloseListener listener(NULL); 140 WebSocket sock(GURL("ws://127.0.0.1:2222"), &listener); 141 } 142 143 TEST_F(WebSocketTest, Connect) { 144 CloseListener listener(NULL); 145 ASSERT_TRUE(CreateWebSocket(server_.web_socket_url(), &listener)); 146 RunPending(&loop_); 147 ASSERT_TRUE(server_.WaitForConnectionsToClose()); 148 } 149 150 TEST_F(WebSocketTest, ConnectNoServer) { 151 CloseListener listener(NULL); 152 ASSERT_FALSE(CreateWebSocket(GURL("ws://127.0.0.1:33333"), NULL)); 153 } 154 155 TEST_F(WebSocketTest, Connect404) { 156 server_.SetRequestAction(TestHttpServer::kNotFound); 157 CloseListener listener(NULL); 158 ASSERT_FALSE(CreateWebSocket(server_.web_socket_url(), NULL)); 159 RunPending(&loop_); 160 ASSERT_TRUE(server_.WaitForConnectionsToClose()); 161 } 162 163 TEST_F(WebSocketTest, ConnectServerClosesConn) { 164 server_.SetRequestAction(TestHttpServer::kClose); 165 CloseListener listener(NULL); 166 ASSERT_FALSE(CreateWebSocket(server_.web_socket_url(), &listener)); 167 } 168 169 TEST_F(WebSocketTest, CloseOnReceive) { 170 server_.SetMessageAction(TestHttpServer::kCloseOnMessage); 171 base::RunLoop run_loop; 172 CloseListener listener(&run_loop); 173 scoped_ptr<WebSocket> sock(CreateConnectedWebSocket(&listener)); 174 ASSERT_TRUE(sock); 175 ASSERT_TRUE(sock->Send("hi")); 176 loop_.PostDelayedTask( 177 FROM_HERE, run_loop.QuitClosure(), 178 base::TimeDelta::FromSeconds(10)); 179 run_loop.Run(); 180 } 181 182 TEST_F(WebSocketTest, CloseOnSend) { 183 base::RunLoop run_loop; 184 CloseListener listener(&run_loop); 185 scoped_ptr<WebSocket> sock(CreateConnectedWebSocket(&listener)); 186 ASSERT_TRUE(sock); 187 server_.Stop(); 188 189 sock->Send("hi"); 190 loop_.PostDelayedTask( 191 FROM_HERE, run_loop.QuitClosure(), 192 base::TimeDelta::FromSeconds(10)); 193 run_loop.Run(); 194 ASSERT_FALSE(sock->Send("hi")); 195 } 196 197 TEST_F(WebSocketTest, SendReceive) { 198 std::vector<std::string> messages; 199 messages.push_back("hello"); 200 SendReceive(messages); 201 } 202 203 TEST_F(WebSocketTest, SendReceiveLarge) { 204 std::vector<std::string> messages; 205 messages.push_back(std::string(10 << 20, 'a')); 206 SendReceive(messages); 207 } 208 209 TEST_F(WebSocketTest, SendReceiveMultiple) { 210 std::vector<std::string> messages; 211 messages.push_back("1"); 212 messages.push_back("2"); 213 messages.push_back("3"); 214 SendReceive(messages); 215 } 216