Home | History | Annotate | Download | only in server
      1 // Copyright (c) 2013 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
#include <locale.h>
#include <stdio.h>
#include <stdlib.h>

#include <algorithm>
#include <locale>
#include <string>
#include <vector>
      9 
     10 #include "base/at_exit.h"
     11 #include "base/bind.h"
     12 #include "base/callback.h"
     13 #include "base/command_line.h"
     14 #include "base/files/file_path.h"
     15 #include "base/lazy_instance.h"
     16 #include "base/logging.h"
     17 #include "base/memory/scoped_ptr.h"
     18 #include "base/message_loop/message_loop.h"
     19 #include "base/run_loop.h"
     20 #include "base/strings/string_number_conversions.h"
     21 #include "base/strings/string_split.h"
     22 #include "base/strings/string_util.h"
     23 #include "base/strings/stringprintf.h"
     24 #include "base/synchronization/waitable_event.h"
     25 #include "base/threading/thread.h"
     26 #include "base/threading/thread_local.h"
     27 #include "chrome/test/chromedriver/logging.h"
     28 #include "chrome/test/chromedriver/net/port_server.h"
     29 #include "chrome/test/chromedriver/server/http_handler.h"
     30 #include "chrome/test/chromedriver/version.h"
     31 #include "net/base/ip_endpoint.h"
     32 #include "net/base/net_errors.h"
     33 #include "net/server/http_server.h"
     34 #include "net/server/http_server_request_info.h"
     35 #include "net/server/http_server_response_info.h"
     36 #include "net/socket/tcp_server_socket.h"
     37 
     38 namespace {
     39 
     40 const char* kLocalHostAddress = "127.0.0.1";
     41 const int kBufferSize = 100 * 1024 * 1024;  // 100 MB
     42 
     43 typedef base::Callback<
     44     void(const net::HttpServerRequestInfo&, const HttpResponseSenderFunc&)>
     45     HttpRequestHandlerFunc;
     46 
     47 class HttpServer : public net::HttpServer::Delegate {
     48  public:
     49   explicit HttpServer(const HttpRequestHandlerFunc& handle_request_func)
     50       : handle_request_func_(handle_request_func),
     51         weak_factory_(this) {}
     52 
     53   virtual ~HttpServer() {}
     54 
     55   bool Start(int port, bool allow_remote) {
     56     std::string binding_ip = kLocalHostAddress;
     57     if (allow_remote)
     58       binding_ip = "0.0.0.0";
     59     scoped_ptr<net::ServerSocket> server_socket(
     60         new net::TCPServerSocket(NULL, net::NetLog::Source()));
     61     server_socket->ListenWithAddressAndPort(binding_ip, port, 1);
     62     server_.reset(new net::HttpServer(server_socket.Pass(), this));
     63     net::IPEndPoint address;
     64     return server_->GetLocalAddress(&address) == net::OK;
     65   }
     66 
     67   // Overridden from net::HttpServer::Delegate:
     68   virtual void OnConnect(int connection_id) OVERRIDE {
     69     server_->SetSendBufferSize(connection_id, kBufferSize);
     70     server_->SetReceiveBufferSize(connection_id, kBufferSize);
     71   }
     72   virtual void OnHttpRequest(int connection_id,
     73                              const net::HttpServerRequestInfo& info) OVERRIDE {
     74     handle_request_func_.Run(
     75         info,
     76         base::Bind(&HttpServer::OnResponse,
     77                    weak_factory_.GetWeakPtr(),
     78                    connection_id));
     79   }
     80   virtual void OnWebSocketRequest(
     81       int connection_id,
     82       const net::HttpServerRequestInfo& info) OVERRIDE {}
     83   virtual void OnWebSocketMessage(int connection_id,
     84                                   const std::string& data) OVERRIDE {}
     85   virtual void OnClose(int connection_id) OVERRIDE {}
     86 
     87  private:
     88   void OnResponse(int connection_id,
     89                   scoped_ptr<net::HttpServerResponseInfo> response) {
     90     // Don't support keep-alive, since there's no way to detect if the
     91     // client is HTTP/1.0. In such cases, the client may hang waiting for
     92     // the connection to close (e.g., python 2.7 urllib).
     93     response->AddHeader("Connection", "close");
     94     server_->SendResponse(connection_id, *response);
     95     // Don't need to call server_->Close(), since SendResponse() will handle
     96     // this for us.
     97   }
     98 
     99   HttpRequestHandlerFunc handle_request_func_;
    100   scoped_ptr<net::HttpServer> server_;
    101   base::WeakPtrFactory<HttpServer> weak_factory_;  // Should be last.
    102 };
    103 
    104 void SendResponseOnCmdThread(
    105     const scoped_refptr<base::SingleThreadTaskRunner>& io_task_runner,
    106     const HttpResponseSenderFunc& send_response_on_io_func,
    107     scoped_ptr<net::HttpServerResponseInfo> response) {
    108   io_task_runner->PostTask(
    109       FROM_HERE, base::Bind(send_response_on_io_func, base::Passed(&response)));
    110 }
    111 
    112 void HandleRequestOnCmdThread(
    113     HttpHandler* handler,
    114     const std::vector<std::string>& whitelisted_ips,
    115     const net::HttpServerRequestInfo& request,
    116     const HttpResponseSenderFunc& send_response_func) {
    117   if (!whitelisted_ips.empty()) {
    118     std::string peer_address = request.peer.ToStringWithoutPort();
    119     if (peer_address != kLocalHostAddress &&
    120         std::find(whitelisted_ips.begin(), whitelisted_ips.end(),
    121                   peer_address) == whitelisted_ips.end()) {
    122       LOG(WARNING) << "unauthorized access from " << request.peer.ToString();
    123       scoped_ptr<net::HttpServerResponseInfo> response(
    124           new net::HttpServerResponseInfo(net::HTTP_UNAUTHORIZED));
    125       response->SetBody("Unauthorized access", "text/plain");
    126       send_response_func.Run(response.Pass());
    127       return;
    128     }
    129   }
    130 
    131   handler->Handle(request, send_response_func);
    132 }
    133 
    134 void HandleRequestOnIOThread(
    135     const scoped_refptr<base::SingleThreadTaskRunner>& cmd_task_runner,
    136     const HttpRequestHandlerFunc& handle_request_on_cmd_func,
    137     const net::HttpServerRequestInfo& request,
    138     const HttpResponseSenderFunc& send_response_func) {
    139   cmd_task_runner->PostTask(
    140       FROM_HERE,
    141       base::Bind(handle_request_on_cmd_func,
    142                  request,
    143                  base::Bind(&SendResponseOnCmdThread,
    144                             base::MessageLoopProxy::current(),
    145                             send_response_func)));
    146 }
    147 
    148 base::LazyInstance<base::ThreadLocalPointer<HttpServer> >
    149     lazy_tls_server = LAZY_INSTANCE_INITIALIZER;
    150 
    151 void StopServerOnIOThread() {
    152   // Note, |server| may be NULL.
    153   HttpServer* server = lazy_tls_server.Pointer()->Get();
    154   lazy_tls_server.Pointer()->Set(NULL);
    155   delete server;
    156 }
    157 
    158 void StartServerOnIOThread(int port,
    159                            bool allow_remote,
    160                            const HttpRequestHandlerFunc& handle_request_func) {
    161   scoped_ptr<HttpServer> temp_server(new HttpServer(handle_request_func));
    162   if (!temp_server->Start(port, allow_remote)) {
    163     printf("Port not available. Exiting...\n");
    164     exit(1);
    165   }
    166   lazy_tls_server.Pointer()->Set(temp_server.release());
    167 }
    168 
    169 void RunServer(int port,
    170                bool allow_remote,
    171                const std::vector<std::string>& whitelisted_ips,
    172                const std::string& url_base,
    173                int adb_port,
    174                scoped_ptr<PortServer> port_server) {
    175   base::Thread io_thread("ChromeDriver IO");
    176   CHECK(io_thread.StartWithOptions(
    177       base::Thread::Options(base::MessageLoop::TYPE_IO, 0)));
    178 
    179   base::MessageLoop cmd_loop;
    180   base::RunLoop cmd_run_loop;
    181   HttpHandler handler(cmd_run_loop.QuitClosure(),
    182                       io_thread.message_loop_proxy(),
    183                       url_base,
    184                       adb_port,
    185                       port_server.Pass());
    186   HttpRequestHandlerFunc handle_request_func =
    187       base::Bind(&HandleRequestOnCmdThread, &handler, whitelisted_ips);
    188 
    189   io_thread.message_loop()
    190       ->PostTask(FROM_HERE,
    191                  base::Bind(&StartServerOnIOThread,
    192                             port,
    193                             allow_remote,
    194                             base::Bind(&HandleRequestOnIOThread,
    195                                        cmd_loop.message_loop_proxy(),
    196                                        handle_request_func)));
    197   // Run the command loop. This loop is quit after the response for a shutdown
    198   // request is posted to the IO loop. After the command loop quits, a task
    199   // is posted to the IO loop to stop the server. Lastly, the IO thread is
    200   // destroyed, which waits until all pending tasks have been completed.
    201   // This assumes the response is sent synchronously as part of the IO task.
    202   cmd_run_loop.Run();
    203   io_thread.message_loop()
    204       ->PostTask(FROM_HERE, base::Bind(&StopServerOnIOThread));
    205 }
    206 
    207 }  // namespace
    208 
    209 int main(int argc, char *argv[]) {
    210   CommandLine::Init(argc, argv);
    211 
    212   base::AtExitManager at_exit;
    213   CommandLine* cmd_line = CommandLine::ForCurrentProcess();
    214 
    215 #if defined(OS_LINUX)
    216   // Select the locale from the environment by passing an empty string instead
    217   // of the default "C" locale. This is particularly needed for the keycode
    218   // conversion code to work.
    219   setlocale(LC_ALL, "");
    220 #endif
    221 
    222   // Parse command line flags.
    223   int port = 9515;
    224   int adb_port = 5037;
    225   bool allow_remote = false;
    226   std::vector<std::string> whitelisted_ips;
    227   std::string url_base;
    228   scoped_ptr<PortServer> port_server;
    229   if (cmd_line->HasSwitch("h") || cmd_line->HasSwitch("help")) {
    230     std::string options;
    231     const char* kOptionAndDescriptions[] = {
    232         "port=PORT", "port to listen on",
    233         "adb-port=PORT", "adb server port",
    234         "log-path=FILE", "write server log to file instead of stderr, "
    235             "increases log level to INFO",
    236         "verbose", "log verbosely",
    237         "version", "print the version number and exit",
    238         "silent", "log nothing",
    239         "url-base", "base URL path prefix for commands, e.g. wd/url",
    240         "port-server", "address of server to contact for reserving a port",
    241         "whitelisted-ips", "comma-separated whitelist of remote IPv4 addresses "
    242             "which are allowed to connect to ChromeDriver",
    243     };
    244     for (size_t i = 0; i < arraysize(kOptionAndDescriptions) - 1; i += 2) {
    245       options += base::StringPrintf(
    246           "  --%-30s%s\n",
    247           kOptionAndDescriptions[i], kOptionAndDescriptions[i + 1]);
    248     }
    249     printf("Usage: %s [OPTIONS]\n\nOptions\n%s", argv[0], options.c_str());
    250     return 0;
    251   }
    252   if (cmd_line->HasSwitch("v") || cmd_line->HasSwitch("version")) {
    253     printf("ChromeDriver %s\n", kChromeDriverVersion);
    254     return 0;
    255   }
    256   if (cmd_line->HasSwitch("port")) {
    257     if (!base::StringToInt(cmd_line->GetSwitchValueASCII("port"), &port)) {
    258       printf("Invalid port. Exiting...\n");
    259       return 1;
    260     }
    261   }
    262   if (cmd_line->HasSwitch("adb-port")) {
    263     if (!base::StringToInt(cmd_line->GetSwitchValueASCII("adb-port"),
    264                            &adb_port)) {
    265       printf("Invalid adb-port. Exiting...\n");
    266       return 1;
    267     }
    268   }
    269   if (cmd_line->HasSwitch("port-server")) {
    270 #if defined(OS_LINUX)
    271     std::string address = cmd_line->GetSwitchValueASCII("port-server");
    272     if (address.empty() || address[0] != '@') {
    273       printf("Invalid port-server. Exiting...\n");
    274       return 1;
    275     }
    276     std::string path;
    277     // First character of path is \0 to use Linux's abstract namespace.
    278     path.push_back(0);
    279     path += address.substr(1);
    280     port_server.reset(new PortServer(path));
    281 #else
    282     printf("Warning: port-server not implemented for this platform.\n");
    283 #endif
    284   }
    285   if (cmd_line->HasSwitch("url-base"))
    286     url_base = cmd_line->GetSwitchValueASCII("url-base");
    287   if (url_base.empty() || url_base[0] != '/')
    288     url_base = "/" + url_base;
    289   if (url_base[url_base.length() - 1] != '/')
    290     url_base = url_base + "/";
    291   if (cmd_line->HasSwitch("whitelisted-ips")) {
    292     allow_remote = true;
    293     std::string whitelist = cmd_line->GetSwitchValueASCII("whitelisted-ips");
    294     base::SplitString(whitelist, ',', &whitelisted_ips);
    295   }
    296   if (!cmd_line->HasSwitch("silent")) {
    297     printf(
    298         "Starting ChromeDriver (v%s) on port %d\n", kChromeDriverVersion, port);
    299     if (!allow_remote) {
    300       printf("Only local connections are allowed.\n");
    301     } else if (!whitelisted_ips.empty()) {
    302       printf("Remote connections are allowed by a whitelist (%s).\n",
    303              cmd_line->GetSwitchValueASCII("whitelisted-ips").c_str());
    304     } else {
    305       printf("All remote connections are allowed. Use a whitelist instead!\n");
    306     }
    307     fflush(stdout);
    308   }
    309 
    310   if (!InitLogging()) {
    311     printf("Unable to initialize logging. Exiting...\n");
    312     return 1;
    313   }
    314   RunServer(port, allow_remote, whitelisted_ips,
    315             url_base, adb_port, port_server.Pass());
    316   return 0;
    317 }
    318