Home | History | Annotate | Download | only in webdriver
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "chrome/test/webdriver/webdriver_dispatch.h"
      6 
      7 #include <sstream>
      8 #include <string>
      9 #include <vector>
     10 
     11 #include "base/command_line.h"
     12 #include "base/format_macros.h"
     13 #include "base/json/json_reader.h"
     14 #include "base/logging.h"
     15 #include "base/memory/scoped_ptr.h"
     16 #include "base/message_loop/message_loop_proxy.h"
     17 #include "base/strings/string_split.h"
     18 #include "base/strings/string_util.h"
     19 #include "base/strings/stringprintf.h"
     20 #include "base/synchronization/waitable_event.h"
     21 #include "base/sys_info.h"
     22 #include "base/threading/platform_thread.h"
     23 #include "base/threading/thread.h"
     24 #include "chrome/common/chrome_version_info.h"
     25 #include "chrome/test/webdriver/commands/command.h"
     26 #include "chrome/test/webdriver/http_response.h"
     27 #include "chrome/test/webdriver/webdriver_logging.h"
     28 #include "chrome/test/webdriver/webdriver_session_manager.h"
     29 #include "chrome/test/webdriver/webdriver_switches.h"
     30 #include "chrome/test/webdriver/webdriver_util.h"
     31 
     32 namespace webdriver {
     33 
     34 namespace {
     35 
     36 // Maximum safe size of HTTP response message. Any larger than this,
     37 // the message may not be transferred at all.
     38 const size_t kMaxHttpMessageSize = 1024 * 1024 * 16;  // 16MB
     39 
     40 void ReadRequestBody(const struct mg_request_info* const request_info,
     41                      struct mg_connection* const connection,
     42                      std::string* request_body) {
     43   int content_length = 0;
     44   // 64 maximum header count hard-coded in mongoose.h
     45   for (int header_index = 0; header_index < 64; ++header_index) {
     46     if (request_info->http_headers[header_index].name == NULL) {
     47       break;
     48     }
     49     if (strcmp(request_info->http_headers[header_index].name,
     50                "Content-Length") == 0) {
     51       content_length = atoi(request_info->http_headers[header_index].value);
     52       break;
     53     }
     54   }
     55   if (content_length > 0) {
     56     request_body->resize(content_length);
     57     int bytes_read = 0;
     58     while (bytes_read < content_length) {
     59       bytes_read += mg_read(connection,
     60                             &(*request_body)[bytes_read],
     61                             content_length - bytes_read);
     62     }
     63   }
     64 }
     65 
     66 void WriteHttpResponse(struct mg_connection* connection,
     67                        const HttpResponse& response) {
     68   HttpResponse modified_response(response);
     69   if (!CommandLine::ForCurrentProcess()->HasSwitch(kEnableKeepAlive))
     70     modified_response.AddHeader("connection", "close");
     71   std::string data;
     72   modified_response.GetData(&data);
     73   mg_write(connection, data.data(), data.length());
     74 }
     75 
     76 void DispatchCommand(Command* const command,
     77                      const std::string& method,
     78                      Response* response) {
     79   if (!command->Init(response))
     80     return;
     81 
     82   if (method == "POST") {
     83     command->ExecutePost(response);
     84   } else if (method == "GET") {
     85     command->ExecuteGet(response);
     86   } else if (method == "DELETE") {
     87     command->ExecuteDelete(response);
     88   } else {
     89     NOTREACHED();
     90   }
     91   command->Finish(response);
     92 }
     93 
     94 void SendOkWithBody(struct mg_connection* connection,
     95                     const std::string& content) {
     96   HttpResponse response;
     97   response.set_body(content);
     98   WriteHttpResponse(connection, response);
     99 }
    100 
    101 void Shutdown(struct mg_connection* connection,
    102               const struct mg_request_info* request_info,
    103               void* user_data) {
    104   base::WaitableEvent* shutdown_event =
    105       reinterpret_cast<base::WaitableEvent*>(user_data);
    106   WriteHttpResponse(connection, HttpResponse());
    107   shutdown_event->Signal();
    108 }
    109 
    110 void SendStatus(struct mg_connection* connection,
    111                 const struct mg_request_info* request_info,
    112                 void* user_data) {
    113   chrome::VersionInfo version_info;
    114   base::DictionaryValue* build_info = new base::DictionaryValue;
    115   build_info->SetString("time",
    116                         base::StringPrintf("%s %s PST", __DATE__, __TIME__));
    117   build_info->SetString("version", version_info.Version());
    118   build_info->SetString("revision", version_info.LastChange());
    119 
    120   base::DictionaryValue* os_info = new base::DictionaryValue;
    121   os_info->SetString("name", base::SysInfo::OperatingSystemName());
    122   os_info->SetString("version", base::SysInfo::OperatingSystemVersion());
    123   os_info->SetString("arch", base::SysInfo::OperatingSystemArchitecture());
    124 
    125   base::DictionaryValue* status = new base::DictionaryValue;
    126   status->Set("build", build_info);
    127   status->Set("os", os_info);
    128 
    129   Response response;
    130   response.SetStatus(kSuccess);
    131   response.SetValue(status);  // Assumes ownership of |status|.
    132 
    133   internal::SendResponse(connection,
    134                          request_info->request_method,
    135                          response);
    136 }
    137 
    138 void SendLog(struct mg_connection* connection,
    139              const struct mg_request_info* request_info,
    140              void* user_data) {
    141   std::string content, log;
    142   if (FileLog::Get()->GetLogContents(&log)) {
    143     content = "START ChromeDriver log";
    144     const size_t kMaxSizeWithoutHeaders = kMaxHttpMessageSize - 10000;
    145     if (log.size() > kMaxSizeWithoutHeaders) {
    146       log = log.substr(log.size() - kMaxSizeWithoutHeaders);
    147       content += " (only last several MB)";
    148     }
    149     content += ":\n" + log + "END ChromeDriver log";
    150   } else {
    151     content = "No ChromeDriver log found";
    152   }
    153   SendOkWithBody(connection, content);
    154 }
    155 
    156 void SimulateHang(struct mg_connection* connection,
    157                   const struct mg_request_info* request_info,
    158                   void* user_data) {
    159   base::PlatformThread::Sleep(base::TimeDelta::FromMinutes(5));
    160 }
    161 
    162 void SendNoContentResponse(struct mg_connection* connection,
    163                            const struct mg_request_info* request_info,
    164                            void* user_data) {
    165   WriteHttpResponse(connection, HttpResponse(HttpResponse::kNoContent));
    166 }
    167 
    168 void SendForbidden(struct mg_connection* connection,
    169                    const struct mg_request_info* request_info,
    170                    void* user_data) {
    171   WriteHttpResponse(connection, HttpResponse(HttpResponse::kForbidden));
    172 }
    173 
    174 void SendNotImplementedError(struct mg_connection* connection,
    175                              const struct mg_request_info* request_info,
    176                              void* user_data) {
    177   // Send a well-formed WebDriver JSON error response to ensure clients
    178   // handle it correctly.
    179   std::string body = base::StringPrintf(
    180       "{\"status\":%d,\"value\":{\"message\":"
    181       "\"Command has not been implemented yet: %s %s\"}}",
    182       kUnknownCommand, request_info->request_method, request_info->uri);
    183 
    184   HttpResponse response(HttpResponse::kNotImplemented);
    185   response.AddHeader("Content-Type", "application/json");
    186   response.set_body(body);
    187   WriteHttpResponse(connection, response);
    188 }
    189 
    190 }  // namespace
    191 
    192 namespace internal {
    193 
    194 void PrepareHttpResponse(const Response& command_response,
    195                          HttpResponse* const http_response) {
    196   ErrorCode status = command_response.GetStatus();
    197   switch (status) {
    198     case kSuccess:
    199       http_response->set_status(HttpResponse::kOk);
    200       break;
    201 
    202     // TODO(jleyba): kSeeOther, kBadRequest, kSessionNotFound,
    203     // and kMethodNotAllowed should be detected before creating
    204     // a command_response, and should thus not need conversion.
    205     case kSeeOther: {
    206       const base::Value* const value = command_response.GetValue();
    207       std::string location;
    208       if (!value->GetAsString(&location)) {
    209         // This should never happen.
    210         http_response->set_status(HttpResponse::kInternalServerError);
    211         http_response->set_body("Unable to set 'Location' header: response "
    212                                "value is not a string: " +
    213                                command_response.ToJSON());
    214         return;
    215       }
    216       http_response->AddHeader("Location", location);
    217       http_response->set_status(HttpResponse::kSeeOther);
    218       break;
    219     }
    220 
    221     case kBadRequest:
    222     case kSessionNotFound:
    223       http_response->set_status(status);
    224       break;
    225 
    226     case kMethodNotAllowed: {
    227       const base::Value* const value = command_response.GetValue();
    228       if (!value->IsType(base::Value::TYPE_LIST)) {
    229         // This should never happen.
    230         http_response->set_status(HttpResponse::kInternalServerError);
    231         http_response->set_body(
    232             "Unable to set 'Allow' header: response value was "
    233             "not a list of strings: " + command_response.ToJSON());
    234         return;
    235       }
    236 
    237       const base::ListValue* const list_value =
    238           static_cast<const base::ListValue* const>(value);
    239       std::vector<std::string> allowed_methods;
    240       for (size_t i = 0; i < list_value->GetSize(); ++i) {
    241         std::string method;
    242         if (list_value->GetString(i, &method)) {
    243           allowed_methods.push_back(method);
    244         } else {
    245           // This should never happen.
    246           http_response->set_status(HttpResponse::kInternalServerError);
    247           http_response->set_body(
    248               "Unable to set 'Allow' header: response value was "
    249               "not a list of strings: " + command_response.ToJSON());
    250           return;
    251         }
    252       }
    253       http_response->AddHeader("Allow", JoinString(allowed_methods, ','));
    254       http_response->set_status(HttpResponse::kMethodNotAllowed);
    255       break;
    256     }
    257 
    258     // All other errors should be treated as generic 500s. The client
    259     // will be responsible for inspecting the message body for details.
    260     case kInternalServerError:
    261     default:
    262       http_response->set_status(HttpResponse::kInternalServerError);
    263       break;
    264   }
    265 
    266   http_response->SetMimeType("application/json; charset=utf-8");
    267   http_response->set_body(command_response.ToJSON());
    268 }
    269 
    270 void SendResponse(struct mg_connection* const connection,
    271                   const std::string& request_method,
    272                   const Response& response) {
    273   HttpResponse http_response;
    274   PrepareHttpResponse(response, &http_response);
    275   WriteHttpResponse(connection, http_response);
    276 }
    277 
    278 bool ParseRequestInfo(const struct mg_request_info* const request_info,
    279                       struct mg_connection* const connection,
    280                       std::string* method,
    281                       std::vector<std::string>* path_segments,
    282                       base::DictionaryValue** parameters,
    283                       Response* const response) {
    284   *method = request_info->request_method;
    285   if (*method == "HEAD")
    286     *method = "GET";
    287   else if (*method == "PUT")
    288     *method = "POST";
    289 
    290   std::string uri(request_info->uri);
    291   SessionManager* manager = SessionManager::GetInstance();
    292   uri = uri.substr(manager->url_base().length());
    293 
    294   base::SplitString(uri, '/', path_segments);
    295 
    296   if (*method == "POST") {
    297     std::string json;
    298     ReadRequestBody(request_info, connection, &json);
    299     if (json.length() > 0) {
    300       std::string error_msg;
    301       scoped_ptr<base::Value> params(base::JSONReader::ReadAndReturnError(
    302           json, base::JSON_ALLOW_TRAILING_COMMAS, NULL, &error_msg));
    303       if (!params.get()) {
    304         response->SetError(new Error(
    305             kBadRequest,
    306             "Failed to parse command data: " + error_msg +
    307                 "\n  Data: " + json));
    308         return false;
    309       }
    310       if (!params->IsType(base::Value::TYPE_DICTIONARY)) {
    311         response->SetError(new Error(
    312             kBadRequest,
    313             "Data passed in URL must be a dictionary. Data: " + json));
    314         return false;
    315       }
    316       *parameters = static_cast<base::DictionaryValue*>(params.release());
    317     }
    318   }
    319   return true;
    320 }
    321 
    322 void DispatchHelper(Command* command_ptr,
    323                     const std::string& method,
    324                     Response* response) {
    325   CHECK(method == "GET" || method == "POST" || method == "DELETE");
    326   scoped_ptr<Command> command(command_ptr);
    327 
    328   if ((method == "GET" && !command->DoesGet()) ||
    329       (method == "POST" && !command->DoesPost()) ||
    330       (method == "DELETE" && !command->DoesDelete())) {
    331     base::ListValue* methods = new base::ListValue;
    332     if (command->DoesPost())
    333       methods->Append(new base::StringValue("POST"));
    334     if (command->DoesGet()) {
    335       methods->Append(new base::StringValue("GET"));
    336       methods->Append(new base::StringValue("HEAD"));
    337     }
    338     if (command->DoesDelete())
    339       methods->Append(new base::StringValue("DELETE"));
    340     response->SetStatus(kMethodNotAllowed);
    341     response->SetValue(methods);
    342     return;
    343   }
    344 
    345   DispatchCommand(command.get(), method, response);
    346 }
    347 
    348 }  // namespace internal
    349 
    350 Dispatcher::Dispatcher(const std::string& url_base)
    351     : url_base_(url_base) {
    352   // Overwrite mongoose's default handler for /favicon.ico to always return a
    353   // 204 response so we don't spam the logs with 404s.
    354   AddCallback("/favicon.ico", &SendNoContentResponse, NULL);
    355   AddCallback("/hang", &SimulateHang, NULL);
    356 }
    357 
    358 Dispatcher::~Dispatcher() {}
    359 
    360 void Dispatcher::AddShutdown(const std::string& pattern,
    361                              base::WaitableEvent* shutdown_event) {
    362   AddCallback(url_base_ + pattern, &Shutdown, shutdown_event);
    363 }
    364 
    365 void Dispatcher::AddStatus(const std::string& pattern) {
    366   AddCallback(url_base_ + pattern, &SendStatus, NULL);
    367 }
    368 
    369 void Dispatcher::AddLog(const std::string& pattern) {
    370   AddCallback(url_base_ + pattern, &SendLog, NULL);
    371 }
    372 
    373 void Dispatcher::SetNotImplemented(const std::string& pattern) {
    374   AddCallback(url_base_ + pattern, &SendNotImplementedError, NULL);
    375 }
    376 
    377 void Dispatcher::ForbidAllOtherRequests() {
    378   AddCallback("*", &SendForbidden, NULL);
    379 }
    380 
    381 void Dispatcher::AddCallback(const std::string& uri_pattern,
    382                              webdriver::mongoose::HttpCallback callback,
    383                              void* user_data) {
    384   callbacks_.push_back(webdriver::mongoose::CallbackDetails(
    385     uri_pattern,
    386     callback,
    387     user_data));
    388 }
    389 
    390 
    391 bool Dispatcher::ProcessHttpRequest(
    392     struct mg_connection* connection,
    393     const struct mg_request_info* request_info) {
    394   std::vector<webdriver::mongoose::CallbackDetails>::const_iterator callback;
    395   for (callback = callbacks_.begin();
    396        callback < callbacks_.end();
    397        ++callback) {
    398     if (MatchPattern(request_info->uri, callback->uri_regex_)) {
    399       callback->func_(connection, request_info, callback->user_data_);
    400       return true;
    401     }
    402   }
    403   return false;
    404 }
    405 
    406 }  // namespace webdriver
    407