1 // Copyright (c) 2013 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "chrome/test/chromedriver/net/test_http_server.h" 6 7 #include "base/bind.h" 8 #include "base/location.h" 9 #include "base/message_loop/message_loop.h" 10 #include "base/message_loop/message_loop_proxy.h" 11 #include "base/strings/stringprintf.h" 12 #include "base/time/time.h" 13 #include "net/base/ip_endpoint.h" 14 #include "net/base/net_errors.h" 15 #include "net/server/http_server_request_info.h" 16 #include "net/socket/tcp_listen_socket.h" 17 #include "testing/gtest/include/gtest/gtest.h" 18 19 TestHttpServer::TestHttpServer() 20 : thread_("ServerThread"), 21 all_closed_event_(false, true), 22 request_action_(kAccept), 23 message_action_(kEchoMessage) { 24 } 25 26 TestHttpServer::~TestHttpServer() { 27 } 28 29 bool TestHttpServer::Start() { 30 base::Thread::Options options(base::MessageLoop::TYPE_IO, 0); 31 bool thread_started = thread_.StartWithOptions(options); 32 EXPECT_TRUE(thread_started); 33 if (!thread_started) 34 return false; 35 bool success; 36 base::WaitableEvent event(false, false); 37 thread_.message_loop_proxy()->PostTask( 38 FROM_HERE, 39 base::Bind(&TestHttpServer::StartOnServerThread, 40 base::Unretained(this), &success, &event)); 41 event.Wait(); 42 return success; 43 } 44 45 void TestHttpServer::Stop() { 46 if (!thread_.IsRunning()) 47 return; 48 base::WaitableEvent event(false, false); 49 thread_.message_loop_proxy()->PostTask( 50 FROM_HERE, 51 base::Bind(&TestHttpServer::StopOnServerThread, 52 base::Unretained(this), &event)); 53 event.Wait(); 54 thread_.Stop(); 55 } 56 57 bool TestHttpServer::WaitForConnectionsToClose() { 58 return all_closed_event_.TimedWait(base::TimeDelta::FromSeconds(10)); 59 } 60 61 void TestHttpServer::SetRequestAction(WebSocketRequestAction action) { 62 base::AutoLock lock(action_lock_); 63 request_action_ = action; 64 } 65 66 void TestHttpServer::SetMessageAction(WebSocketMessageAction action) { 67 base::AutoLock lock(action_lock_); 68 message_action_ = action; 69 } 70 71 GURL TestHttpServer::web_socket_url() const { 72 base::AutoLock lock(url_lock_); 73 return web_socket_url_; 74 } 75 76 void TestHttpServer::OnWebSocketRequest( 77 int connection_id, 78 const net::HttpServerRequestInfo& info) { 79 WebSocketRequestAction action; 80 { 81 base::AutoLock lock(action_lock_); 82 action = request_action_; 83 } 84 connections_.insert(connection_id); 85 all_closed_event_.Reset(); 86 87 switch (action) { 88 case kAccept: 89 server_->AcceptWebSocket(connection_id, info); 90 break; 91 case kNotFound: 92 server_->Send404(connection_id); 93 break; 94 case kClose: 95 // net::HttpServer doesn't allow us to close connection during callback. 96 base::MessageLoop::current()->PostTask( 97 FROM_HERE, 98 base::Bind(&net::HttpServer::Close, server_, connection_id)); 99 break; 100 } 101 } 102 103 void TestHttpServer::OnWebSocketMessage(int connection_id, 104 const std::string& data) { 105 WebSocketMessageAction action; 106 { 107 base::AutoLock lock(action_lock_); 108 action = message_action_; 109 } 110 switch (action) { 111 case kEchoMessage: 112 server_->SendOverWebSocket(connection_id, data); 113 break; 114 case kCloseOnMessage: 115 // net::HttpServer doesn't allow us to close connection during callback. 116 base::MessageLoop::current()->PostTask( 117 FROM_HERE, 118 base::Bind(&net::HttpServer::Close, server_, connection_id)); 119 break; 120 } 121 } 122 123 void TestHttpServer::OnClose(int connection_id) { 124 connections_.erase(connection_id); 125 if (connections_.empty()) 126 all_closed_event_.Signal(); 127 } 128 129 void TestHttpServer::StartOnServerThread(bool* success, 130 base::WaitableEvent* event) { 131 net::TCPListenSocketFactory factory("127.0.0.1", 0); 132 server_ = new net::HttpServer(factory, this); 133 134 net::IPEndPoint address; 135 int error = server_->GetLocalAddress(&address); 136 EXPECT_EQ(net::OK, error); 137 if (error == net::OK) { 138 base::AutoLock lock(url_lock_); 139 web_socket_url_ = GURL(base::StringPrintf("ws://127.0.0.1:%d", 140 address.port())); 141 } else { 142 server_ = NULL; 143 } 144 *success = server_.get(); 145 event->Signal(); 146 } 147 148 void TestHttpServer::StopOnServerThread(base::WaitableEvent* event) { 149 if (server_.get()) 150 server_ = NULL; 151 event->Signal(); 152 } 153