Home | History | Annotate | Download | only in embedded_test_server
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/test/embedded_test_server/embedded_test_server.h"
      6 
      7 #include "base/bind.h"
      8 #include "base/file_util.h"
      9 #include "base/files/file_path.h"
     10 #include "base/message_loop/message_loop.h"
     11 #include "base/path_service.h"
     12 #include "base/process/process_metrics.h"
     13 #include "base/run_loop.h"
     14 #include "base/stl_util.h"
     15 #include "base/strings/string_util.h"
     16 #include "base/strings/stringprintf.h"
     17 #include "base/threading/thread_restrictions.h"
     18 #include "net/base/ip_endpoint.h"
     19 #include "net/base/net_errors.h"
     20 #include "net/test/embedded_test_server/http_connection.h"
     21 #include "net/test/embedded_test_server/http_request.h"
     22 #include "net/test/embedded_test_server/http_response.h"
     23 #include "net/tools/fetch/http_listen_socket.h"
     24 
     25 namespace net {
     26 namespace test_server {
     27 
     28 namespace {
     29 
     30 class CustomHttpResponse : public HttpResponse {
     31  public:
     32   CustomHttpResponse(const std::string& headers, const std::string& contents)
     33       : headers_(headers), contents_(contents) {
     34   }
     35 
     36   virtual std::string ToResponseString() const OVERRIDE {
     37     return headers_ + "\r\n" + contents_;
     38   }
     39 
     40  private:
     41   std::string headers_;
     42   std::string contents_;
     43 
     44   DISALLOW_COPY_AND_ASSIGN(CustomHttpResponse);
     45 };
     46 
     47 // Handles |request| by serving a file from under |server_root|.
     48 scoped_ptr<HttpResponse> HandleFileRequest(
     49     const base::FilePath& server_root,
     50     const HttpRequest& request) {
     51   // This is a test-only server. Ignore I/O thread restrictions.
     52   base::ThreadRestrictions::ScopedAllowIO allow_io;
     53 
     54   // Trim the first byte ('/').
     55   std::string request_path(request.relative_url.substr(1));
     56 
     57   // Remove the query string if present.
     58   size_t query_pos = request_path.find('?');
     59   if (query_pos != std::string::npos)
     60     request_path = request_path.substr(0, query_pos);
     61 
     62   base::FilePath file_path(server_root.AppendASCII(request_path));
     63   std::string file_contents;
     64   if (!base::ReadFileToString(file_path, &file_contents))
     65     return scoped_ptr<HttpResponse>();
     66 
     67   base::FilePath headers_path(
     68       file_path.AddExtension(FILE_PATH_LITERAL("mock-http-headers")));
     69 
     70   if (base::PathExists(headers_path)) {
     71     std::string headers_contents;
     72     if (!base::ReadFileToString(headers_path, &headers_contents))
     73       return scoped_ptr<HttpResponse>();
     74 
     75     scoped_ptr<CustomHttpResponse> http_response(
     76         new CustomHttpResponse(headers_contents, file_contents));
     77     return http_response.PassAs<HttpResponse>();
     78   }
     79 
     80   scoped_ptr<BasicHttpResponse> http_response(new BasicHttpResponse);
     81   http_response->set_code(HTTP_OK);
     82   http_response->set_content(file_contents);
     83   return http_response.PassAs<HttpResponse>();
     84 }
     85 
     86 }  // namespace
     87 
     88 HttpListenSocket::HttpListenSocket(const SocketDescriptor socket_descriptor,
     89                                    StreamListenSocket::Delegate* delegate)
     90     : TCPListenSocket(socket_descriptor, delegate) {
     91   DCHECK(thread_checker_.CalledOnValidThread());
     92 }
     93 
     94 void HttpListenSocket::Listen() {
     95   DCHECK(thread_checker_.CalledOnValidThread());
     96   TCPListenSocket::Listen();
     97 }
     98 
     99 HttpListenSocket::~HttpListenSocket() {
    100   DCHECK(thread_checker_.CalledOnValidThread());
    101 }
    102 
    103 void HttpListenSocket::DetachFromThread() {
    104   thread_checker_.DetachFromThread();
    105 }
    106 
    107 EmbeddedTestServer::EmbeddedTestServer()
    108     : port_(-1),
    109       weak_factory_(this) {
    110   DCHECK(thread_checker_.CalledOnValidThread());
    111 }
    112 
    113 EmbeddedTestServer::~EmbeddedTestServer() {
    114   DCHECK(thread_checker_.CalledOnValidThread());
    115 
    116   if (Started() && !ShutdownAndWaitUntilComplete()) {
    117     LOG(ERROR) << "EmbeddedTestServer failed to shut down.";
    118   }
    119 }
    120 
    121 bool EmbeddedTestServer::InitializeAndWaitUntilReady() {
    122   StartThread();
    123   DCHECK(thread_checker_.CalledOnValidThread());
    124   if (!PostTaskToIOThreadAndWait(base::Bind(
    125           &EmbeddedTestServer::InitializeOnIOThread, base::Unretained(this)))) {
    126     return false;
    127   }
    128   return Started() && base_url_.is_valid();
    129 }
    130 
    131 void EmbeddedTestServer::StopThread() {
    132   DCHECK(io_thread_ && io_thread_->IsRunning());
    133 
    134 #if defined(OS_LINUX)
    135   const int thread_count =
    136       base::GetNumberOfThreads(base::GetCurrentProcessHandle());
    137 #endif
    138 
    139   io_thread_->Stop();
    140   io_thread_.reset();
    141   thread_checker_.DetachFromThread();
    142   listen_socket_->DetachFromThread();
    143 
    144 #if defined(OS_LINUX)
    145   // Busy loop to wait for thread count to decrease. This is needed because
    146   // pthread_join does not guarantee that kernel stat is updated when it
    147   // returns. Thus, GetNumberOfThreads does not immediately reflect the stopped
    148   // thread and hits the thread number DCHECK in render_sandbox_host_linux.cc
    149   // in browser_tests.
    150   while (thread_count ==
    151          base::GetNumberOfThreads(base::GetCurrentProcessHandle())) {
    152     base::PlatformThread::YieldCurrentThread();
    153   }
    154 #endif
    155 }
    156 
    157 void EmbeddedTestServer::RestartThreadAndListen() {
    158   StartThread();
    159   CHECK(PostTaskToIOThreadAndWait(base::Bind(
    160       &EmbeddedTestServer::ListenOnIOThread, base::Unretained(this))));
    161 }
    162 
    163 bool EmbeddedTestServer::ShutdownAndWaitUntilComplete() {
    164   DCHECK(thread_checker_.CalledOnValidThread());
    165 
    166   return PostTaskToIOThreadAndWait(base::Bind(
    167       &EmbeddedTestServer::ShutdownOnIOThread, base::Unretained(this)));
    168 }
    169 
    170 void EmbeddedTestServer::StartThread() {
    171   DCHECK(!io_thread_.get());
    172   base::Thread::Options thread_options;
    173   thread_options.message_loop_type = base::MessageLoop::TYPE_IO;
    174   io_thread_.reset(new base::Thread("EmbeddedTestServer io thread"));
    175   CHECK(io_thread_->StartWithOptions(thread_options));
    176 }
    177 
    178 void EmbeddedTestServer::InitializeOnIOThread() {
    179   DCHECK(io_thread_->message_loop_proxy()->BelongsToCurrentThread());
    180   DCHECK(!Started());
    181 
    182   SocketDescriptor socket_descriptor =
    183       TCPListenSocket::CreateAndBindAnyPort("127.0.0.1", &port_);
    184   if (socket_descriptor == kInvalidSocket)
    185     return;
    186 
    187   listen_socket_.reset(new HttpListenSocket(socket_descriptor, this));
    188   listen_socket_->Listen();
    189 
    190   IPEndPoint address;
    191   int result = listen_socket_->GetLocalAddress(&address);
    192   if (result == OK) {
    193     base_url_ = GURL(std::string("http://") + address.ToString());
    194   } else {
    195     LOG(ERROR) << "GetLocalAddress failed: " << ErrorToString(result);
    196   }
    197 }
    198 
    199 void EmbeddedTestServer::ListenOnIOThread() {
    200   DCHECK(io_thread_->message_loop_proxy()->BelongsToCurrentThread());
    201   DCHECK(Started());
    202   listen_socket_->Listen();
    203 }
    204 
    205 void EmbeddedTestServer::ShutdownOnIOThread() {
    206   DCHECK(io_thread_->message_loop_proxy()->BelongsToCurrentThread());
    207 
    208   listen_socket_.reset();
    209   STLDeleteContainerPairSecondPointers(connections_.begin(),
    210                                        connections_.end());
    211   connections_.clear();
    212 }
    213 
    214 void EmbeddedTestServer::HandleRequest(HttpConnection* connection,
    215                                scoped_ptr<HttpRequest> request) {
    216   DCHECK(io_thread_->message_loop_proxy()->BelongsToCurrentThread());
    217 
    218   bool request_handled = false;
    219 
    220   for (size_t i = 0; i < request_handlers_.size(); ++i) {
    221     scoped_ptr<HttpResponse> response =
    222         request_handlers_[i].Run(*request.get());
    223     if (response.get()) {
    224       connection->SendResponse(response.Pass());
    225       request_handled = true;
    226       break;
    227     }
    228   }
    229 
    230   if (!request_handled) {
    231     LOG(WARNING) << "Request not handled. Returning 404: "
    232                  << request->relative_url;
    233     scoped_ptr<BasicHttpResponse> not_found_response(new BasicHttpResponse);
    234     not_found_response->set_code(HTTP_NOT_FOUND);
    235     connection->SendResponse(
    236         not_found_response.PassAs<HttpResponse>());
    237   }
    238 
    239   // Drop the connection, since we do not support multiple requests per
    240   // connection.
    241   connections_.erase(connection->socket_.get());
    242   delete connection;
    243 }
    244 
    245 GURL EmbeddedTestServer::GetURL(const std::string& relative_url) const {
    246   DCHECK(Started()) << "You must start the server first.";
    247   DCHECK(StartsWithASCII(relative_url, "/", true /* case_sensitive */))
    248       << relative_url;
    249   return base_url_.Resolve(relative_url);
    250 }
    251 
    252 void EmbeddedTestServer::ServeFilesFromDirectory(
    253     const base::FilePath& directory) {
    254   RegisterRequestHandler(base::Bind(&HandleFileRequest, directory));
    255 }
    256 
    257 void EmbeddedTestServer::RegisterRequestHandler(
    258     const HandleRequestCallback& callback) {
    259   request_handlers_.push_back(callback);
    260 }
    261 
    262 void EmbeddedTestServer::DidAccept(
    263     StreamListenSocket* server,
    264     scoped_ptr<StreamListenSocket> connection) {
    265   DCHECK(io_thread_->message_loop_proxy()->BelongsToCurrentThread());
    266 
    267   HttpConnection* http_connection = new HttpConnection(
    268       connection.Pass(),
    269       base::Bind(&EmbeddedTestServer::HandleRequest,
    270                  weak_factory_.GetWeakPtr()));
    271   // TODO(szym): Make HttpConnection the StreamListenSocket delegate.
    272   connections_[http_connection->socket_.get()] = http_connection;
    273 }
    274 
    275 void EmbeddedTestServer::DidRead(StreamListenSocket* connection,
    276                          const char* data,
    277                          int length) {
    278   DCHECK(io_thread_->message_loop_proxy()->BelongsToCurrentThread());
    279 
    280   HttpConnection* http_connection = FindConnection(connection);
    281   if (http_connection == NULL) {
    282     LOG(WARNING) << "Unknown connection.";
    283     return;
    284   }
    285   http_connection->ReceiveData(std::string(data, length));
    286 }
    287 
    288 void EmbeddedTestServer::DidClose(StreamListenSocket* connection) {
    289   DCHECK(io_thread_->message_loop_proxy()->BelongsToCurrentThread());
    290 
    291   HttpConnection* http_connection = FindConnection(connection);
    292   if (http_connection == NULL) {
    293     LOG(WARNING) << "Unknown connection.";
    294     return;
    295   }
    296   delete http_connection;
    297   connections_.erase(connection);
    298 }
    299 
    300 HttpConnection* EmbeddedTestServer::FindConnection(
    301     StreamListenSocket* socket) {
    302   DCHECK(io_thread_->message_loop_proxy()->BelongsToCurrentThread());
    303 
    304   std::map<StreamListenSocket*, HttpConnection*>::iterator it =
    305       connections_.find(socket);
    306   if (it == connections_.end()) {
    307     return NULL;
    308   }
    309   return it->second;
    310 }
    311 
    312 bool EmbeddedTestServer::PostTaskToIOThreadAndWait(
    313     const base::Closure& closure) {
    314   // Note that PostTaskAndReply below requires base::MessageLoopProxy::current()
    315   // to return a loop for posting the reply task. However, in order to make
    316   // EmbeddedTestServer universally usable, it needs to cope with the situation
    317   // where it's running on a thread on which a message loop is not (yet)
    318   // available or as has been destroyed already.
    319   //
    320   // To handle this situation, create temporary message loop to support the
    321   // PostTaskAndReply operation if the current thread as no message loop.
    322   scoped_ptr<base::MessageLoop> temporary_loop;
    323   if (!base::MessageLoop::current())
    324     temporary_loop.reset(new base::MessageLoop());
    325 
    326   base::RunLoop run_loop;
    327   if (!io_thread_->message_loop_proxy()->PostTaskAndReply(
    328           FROM_HERE, closure, run_loop.QuitClosure())) {
    329     return false;
    330   }
    331   run_loop.Run();
    332 
    333   return true;
    334 }
    335 
    336 }  // namespace test_server
    337 }  // namespace net
    338