1 // Copyright (c) 2013 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include <stdio.h> 6 #include <locale> 7 #include <string> 8 #include <vector> 9 10 #include "base/at_exit.h" 11 #include "base/bind.h" 12 #include "base/callback.h" 13 #include "base/command_line.h" 14 #include "base/files/file_path.h" 15 #include "base/lazy_instance.h" 16 #include "base/logging.h" 17 #include "base/memory/scoped_ptr.h" 18 #include "base/message_loop/message_loop.h" 19 #include "base/run_loop.h" 20 #include "base/strings/string_number_conversions.h" 21 #include "base/strings/string_split.h" 22 #include "base/strings/string_util.h" 23 #include "base/strings/stringprintf.h" 24 #include "base/synchronization/waitable_event.h" 25 #include "base/threading/thread.h" 26 #include "base/threading/thread_local.h" 27 #include "chrome/test/chromedriver/logging.h" 28 #include "chrome/test/chromedriver/net/port_server.h" 29 #include "chrome/test/chromedriver/server/http_handler.h" 30 #include "chrome/test/chromedriver/version.h" 31 #include "net/base/ip_endpoint.h" 32 #include "net/base/net_errors.h" 33 #include "net/server/http_server.h" 34 #include "net/server/http_server_request_info.h" 35 #include "net/server/http_server_response_info.h" 36 #include "net/socket/tcp_server_socket.h" 37 38 namespace { 39 40 const char* kLocalHostAddress = "127.0.0.1"; 41 const int kBufferSize = 100 * 1024 * 1024; // 100 MB 42 43 typedef base::Callback< 44 void(const net::HttpServerRequestInfo&, const HttpResponseSenderFunc&)> 45 HttpRequestHandlerFunc; 46 47 class HttpServer : public net::HttpServer::Delegate { 48 public: 49 explicit HttpServer(const HttpRequestHandlerFunc& handle_request_func) 50 : handle_request_func_(handle_request_func), 51 weak_factory_(this) {} 52 53 virtual ~HttpServer() {} 54 55 bool Start(int port, bool allow_remote) { 56 std::string binding_ip = kLocalHostAddress; 57 if (allow_remote) 58 binding_ip = "0.0.0.0"; 59 scoped_ptr<net::ServerSocket> server_socket( 60 new net::TCPServerSocket(NULL, net::NetLog::Source())); 61 server_socket->ListenWithAddressAndPort(binding_ip, port, 1); 62 server_.reset(new net::HttpServer(server_socket.Pass(), this)); 63 net::IPEndPoint address; 64 return server_->GetLocalAddress(&address) == net::OK; 65 } 66 67 // Overridden from net::HttpServer::Delegate: 68 virtual void OnConnect(int connection_id) OVERRIDE { 69 server_->SetSendBufferSize(connection_id, kBufferSize); 70 server_->SetReceiveBufferSize(connection_id, kBufferSize); 71 } 72 virtual void OnHttpRequest(int connection_id, 73 const net::HttpServerRequestInfo& info) OVERRIDE { 74 handle_request_func_.Run( 75 info, 76 base::Bind(&HttpServer::OnResponse, 77 weak_factory_.GetWeakPtr(), 78 connection_id)); 79 } 80 virtual void OnWebSocketRequest( 81 int connection_id, 82 const net::HttpServerRequestInfo& info) OVERRIDE {} 83 virtual void OnWebSocketMessage(int connection_id, 84 const std::string& data) OVERRIDE {} 85 virtual void OnClose(int connection_id) OVERRIDE {} 86 87 private: 88 void OnResponse(int connection_id, 89 scoped_ptr<net::HttpServerResponseInfo> response) { 90 // Don't support keep-alive, since there's no way to detect if the 91 // client is HTTP/1.0. In such cases, the client may hang waiting for 92 // the connection to close (e.g., python 2.7 urllib). 93 response->AddHeader("Connection", "close"); 94 server_->SendResponse(connection_id, *response); 95 // Don't need to call server_->Close(), since SendResponse() will handle 96 // this for us. 97 } 98 99 HttpRequestHandlerFunc handle_request_func_; 100 scoped_ptr<net::HttpServer> server_; 101 base::WeakPtrFactory<HttpServer> weak_factory_; // Should be last. 102 }; 103 104 void SendResponseOnCmdThread( 105 const scoped_refptr<base::SingleThreadTaskRunner>& io_task_runner, 106 const HttpResponseSenderFunc& send_response_on_io_func, 107 scoped_ptr<net::HttpServerResponseInfo> response) { 108 io_task_runner->PostTask( 109 FROM_HERE, base::Bind(send_response_on_io_func, base::Passed(&response))); 110 } 111 112 void HandleRequestOnCmdThread( 113 HttpHandler* handler, 114 const std::vector<std::string>& whitelisted_ips, 115 const net::HttpServerRequestInfo& request, 116 const HttpResponseSenderFunc& send_response_func) { 117 if (!whitelisted_ips.empty()) { 118 std::string peer_address = request.peer.ToStringWithoutPort(); 119 if (peer_address != kLocalHostAddress && 120 std::find(whitelisted_ips.begin(), whitelisted_ips.end(), 121 peer_address) == whitelisted_ips.end()) { 122 LOG(WARNING) << "unauthorized access from " << request.peer.ToString(); 123 scoped_ptr<net::HttpServerResponseInfo> response( 124 new net::HttpServerResponseInfo(net::HTTP_UNAUTHORIZED)); 125 response->SetBody("Unauthorized access", "text/plain"); 126 send_response_func.Run(response.Pass()); 127 return; 128 } 129 } 130 131 handler->Handle(request, send_response_func); 132 } 133 134 void HandleRequestOnIOThread( 135 const scoped_refptr<base::SingleThreadTaskRunner>& cmd_task_runner, 136 const HttpRequestHandlerFunc& handle_request_on_cmd_func, 137 const net::HttpServerRequestInfo& request, 138 const HttpResponseSenderFunc& send_response_func) { 139 cmd_task_runner->PostTask( 140 FROM_HERE, 141 base::Bind(handle_request_on_cmd_func, 142 request, 143 base::Bind(&SendResponseOnCmdThread, 144 base::MessageLoopProxy::current(), 145 send_response_func))); 146 } 147 148 base::LazyInstance<base::ThreadLocalPointer<HttpServer> > 149 lazy_tls_server = LAZY_INSTANCE_INITIALIZER; 150 151 void StopServerOnIOThread() { 152 // Note, |server| may be NULL. 153 HttpServer* server = lazy_tls_server.Pointer()->Get(); 154 lazy_tls_server.Pointer()->Set(NULL); 155 delete server; 156 } 157 158 void StartServerOnIOThread(int port, 159 bool allow_remote, 160 const HttpRequestHandlerFunc& handle_request_func) { 161 scoped_ptr<HttpServer> temp_server(new HttpServer(handle_request_func)); 162 if (!temp_server->Start(port, allow_remote)) { 163 printf("Port not available. Exiting...\n"); 164 exit(1); 165 } 166 lazy_tls_server.Pointer()->Set(temp_server.release()); 167 } 168 169 void RunServer(int port, 170 bool allow_remote, 171 const std::vector<std::string>& whitelisted_ips, 172 const std::string& url_base, 173 int adb_port, 174 scoped_ptr<PortServer> port_server) { 175 base::Thread io_thread("ChromeDriver IO"); 176 CHECK(io_thread.StartWithOptions( 177 base::Thread::Options(base::MessageLoop::TYPE_IO, 0))); 178 179 base::MessageLoop cmd_loop; 180 base::RunLoop cmd_run_loop; 181 HttpHandler handler(cmd_run_loop.QuitClosure(), 182 io_thread.message_loop_proxy(), 183 url_base, 184 adb_port, 185 port_server.Pass()); 186 HttpRequestHandlerFunc handle_request_func = 187 base::Bind(&HandleRequestOnCmdThread, &handler, whitelisted_ips); 188 189 io_thread.message_loop() 190 ->PostTask(FROM_HERE, 191 base::Bind(&StartServerOnIOThread, 192 port, 193 allow_remote, 194 base::Bind(&HandleRequestOnIOThread, 195 cmd_loop.message_loop_proxy(), 196 handle_request_func))); 197 // Run the command loop. This loop is quit after the response for a shutdown 198 // request is posted to the IO loop. After the command loop quits, a task 199 // is posted to the IO loop to stop the server. Lastly, the IO thread is 200 // destroyed, which waits until all pending tasks have been completed. 201 // This assumes the response is sent synchronously as part of the IO task. 202 cmd_run_loop.Run(); 203 io_thread.message_loop() 204 ->PostTask(FROM_HERE, base::Bind(&StopServerOnIOThread)); 205 } 206 207 } // namespace 208 209 int main(int argc, char *argv[]) { 210 CommandLine::Init(argc, argv); 211 212 base::AtExitManager at_exit; 213 CommandLine* cmd_line = CommandLine::ForCurrentProcess(); 214 215 #if defined(OS_LINUX) 216 // Select the locale from the environment by passing an empty string instead 217 // of the default "C" locale. This is particularly needed for the keycode 218 // conversion code to work. 219 setlocale(LC_ALL, ""); 220 #endif 221 222 // Parse command line flags. 223 int port = 9515; 224 int adb_port = 5037; 225 bool allow_remote = false; 226 std::vector<std::string> whitelisted_ips; 227 std::string url_base; 228 scoped_ptr<PortServer> port_server; 229 if (cmd_line->HasSwitch("h") || cmd_line->HasSwitch("help")) { 230 std::string options; 231 const char* kOptionAndDescriptions[] = { 232 "port=PORT", "port to listen on", 233 "adb-port=PORT", "adb server port", 234 "log-path=FILE", "write server log to file instead of stderr, " 235 "increases log level to INFO", 236 "verbose", "log verbosely", 237 "version", "print the version number and exit", 238 "silent", "log nothing", 239 "url-base", "base URL path prefix for commands, e.g. wd/url", 240 "port-server", "address of server to contact for reserving a port", 241 "whitelisted-ips", "comma-separated whitelist of remote IPv4 addresses " 242 "which are allowed to connect to ChromeDriver", 243 }; 244 for (size_t i = 0; i < arraysize(kOptionAndDescriptions) - 1; i += 2) { 245 options += base::StringPrintf( 246 " --%-30s%s\n", 247 kOptionAndDescriptions[i], kOptionAndDescriptions[i + 1]); 248 } 249 printf("Usage: %s [OPTIONS]\n\nOptions\n%s", argv[0], options.c_str()); 250 return 0; 251 } 252 if (cmd_line->HasSwitch("v") || cmd_line->HasSwitch("version")) { 253 printf("ChromeDriver %s\n", kChromeDriverVersion); 254 return 0; 255 } 256 if (cmd_line->HasSwitch("port")) { 257 if (!base::StringToInt(cmd_line->GetSwitchValueASCII("port"), &port)) { 258 printf("Invalid port. Exiting...\n"); 259 return 1; 260 } 261 } 262 if (cmd_line->HasSwitch("adb-port")) { 263 if (!base::StringToInt(cmd_line->GetSwitchValueASCII("adb-port"), 264 &adb_port)) { 265 printf("Invalid adb-port. Exiting...\n"); 266 return 1; 267 } 268 } 269 if (cmd_line->HasSwitch("port-server")) { 270 #if defined(OS_LINUX) 271 std::string address = cmd_line->GetSwitchValueASCII("port-server"); 272 if (address.empty() || address[0] != '@') { 273 printf("Invalid port-server. Exiting...\n"); 274 return 1; 275 } 276 std::string path; 277 // First character of path is \0 to use Linux's abstract namespace. 278 path.push_back(0); 279 path += address.substr(1); 280 port_server.reset(new PortServer(path)); 281 #else 282 printf("Warning: port-server not implemented for this platform.\n"); 283 #endif 284 } 285 if (cmd_line->HasSwitch("url-base")) 286 url_base = cmd_line->GetSwitchValueASCII("url-base"); 287 if (url_base.empty() || url_base[0] != '/') 288 url_base = "/" + url_base; 289 if (url_base[url_base.length() - 1] != '/') 290 url_base = url_base + "/"; 291 if (cmd_line->HasSwitch("whitelisted-ips")) { 292 allow_remote = true; 293 std::string whitelist = cmd_line->GetSwitchValueASCII("whitelisted-ips"); 294 base::SplitString(whitelist, ',', &whitelisted_ips); 295 } 296 if (!cmd_line->HasSwitch("silent")) { 297 printf( 298 "Starting ChromeDriver (v%s) on port %d\n", kChromeDriverVersion, port); 299 if (!allow_remote) { 300 printf("Only local connections are allowed.\n"); 301 } else if (!whitelisted_ips.empty()) { 302 printf("Remote connections are allowed by a whitelist (%s).\n", 303 cmd_line->GetSwitchValueASCII("whitelisted-ips").c_str()); 304 } else { 305 printf("All remote connections are allowed. Use a whitelist instead!\n"); 306 } 307 fflush(stdout); 308 } 309 310 if (!InitLogging()) { 311 printf("Unable to initialize logging. Exiting...\n"); 312 return 1; 313 } 314 RunServer(port, allow_remote, whitelisted_ips, 315 url_base, adb_port, port_server.Pass()); 316 return 0; 317 } 318