Home | History | Annotate | Download | only in embedded_test_server
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/test/embedded_test_server/embedded_test_server.h"
      6 
      7 #include "base/bind.h"
      8 #include "base/files/file_path.h"
      9 #include "base/file_util.h"
     10 #include "base/path_service.h"
     11 #include "base/run_loop.h"
     12 #include "base/stl_util.h"
     13 #include "base/strings/string_util.h"
     14 #include "base/strings/stringprintf.h"
     15 #include "base/threading/thread_restrictions.h"
     16 #include "net/base/ip_endpoint.h"
     17 #include "net/base/net_errors.h"
     18 #include "net/test/embedded_test_server/http_connection.h"
     19 #include "net/test/embedded_test_server/http_request.h"
     20 #include "net/test/embedded_test_server/http_response.h"
     21 #include "net/tools/fetch/http_listen_socket.h"
     22 
     23 namespace net {
     24 namespace test_server {
     25 
     26 namespace {
     27 
     28 class CustomHttpResponse : public HttpResponse {
     29  public:
     30   CustomHttpResponse(const std::string& headers, const std::string& contents)
     31       : headers_(headers), contents_(contents) {
     32   }
     33 
     34   virtual std::string ToResponseString() const OVERRIDE {
     35     return headers_ + "\r\n" + contents_;
     36   }
     37 
     38  private:
     39   std::string headers_;
     40   std::string contents_;
     41 
     42   DISALLOW_COPY_AND_ASSIGN(CustomHttpResponse);
     43 };
     44 
     45 // Handles |request| by serving a file from under |server_root|.
     46 scoped_ptr<HttpResponse> HandleFileRequest(
     47     const base::FilePath& server_root,
     48     const HttpRequest& request) {
     49   // This is a test-only server. Ignore I/O thread restrictions.
     50   base::ThreadRestrictions::ScopedAllowIO allow_io;
     51 
     52   // Trim the first byte ('/').
     53   std::string request_path(request.relative_url.substr(1));
     54 
     55   // Remove the query string if present.
     56   size_t query_pos = request_path.find('?');
     57   if (query_pos != std::string::npos)
     58     request_path = request_path.substr(0, query_pos);
     59 
     60   base::FilePath file_path(server_root.AppendASCII(request_path));
     61   std::string file_contents;
     62   if (!file_util::ReadFileToString(file_path, &file_contents))
     63     return scoped_ptr<HttpResponse>();
     64 
     65   base::FilePath headers_path(
     66       file_path.AddExtension(FILE_PATH_LITERAL("mock-http-headers")));
     67 
     68   if (base::PathExists(headers_path)) {
     69     std::string headers_contents;
     70     if (!file_util::ReadFileToString(headers_path, &headers_contents))
     71       return scoped_ptr<HttpResponse>();
     72 
     73     scoped_ptr<CustomHttpResponse> http_response(
     74         new CustomHttpResponse(headers_contents, file_contents));
     75     return http_response.PassAs<HttpResponse>();
     76   }
     77 
     78   scoped_ptr<BasicHttpResponse> http_response(new BasicHttpResponse);
     79   http_response->set_code(HTTP_OK);
     80   http_response->set_content(file_contents);
     81   return http_response.PassAs<HttpResponse>();
     82 }
     83 
     84 }  // namespace
     85 
     86 HttpListenSocket::HttpListenSocket(const SocketDescriptor socket_descriptor,
     87                                    StreamListenSocket::Delegate* delegate)
     88     : TCPListenSocket(socket_descriptor, delegate) {
     89   DCHECK(thread_checker_.CalledOnValidThread());
     90 }
     91 
     92 void HttpListenSocket::Listen() {
     93   DCHECK(thread_checker_.CalledOnValidThread());
     94   TCPListenSocket::Listen();
     95 }
     96 
     97 HttpListenSocket::~HttpListenSocket() {
     98   DCHECK(thread_checker_.CalledOnValidThread());
     99 }
    100 
    101 EmbeddedTestServer::EmbeddedTestServer(
    102     const scoped_refptr<base::SingleThreadTaskRunner>& io_thread)
    103     : io_thread_(io_thread),
    104       port_(-1),
    105       weak_factory_(this) {
    106   DCHECK(io_thread_.get());
    107   DCHECK(thread_checker_.CalledOnValidThread());
    108 }
    109 
    110 EmbeddedTestServer::~EmbeddedTestServer() {
    111   DCHECK(thread_checker_.CalledOnValidThread());
    112 
    113   if (Started() && !ShutdownAndWaitUntilComplete()) {
    114     LOG(ERROR) << "EmbeddedTestServer failed to shut down.";
    115   }
    116 }
    117 
    118 bool EmbeddedTestServer::InitializeAndWaitUntilReady() {
    119   DCHECK(thread_checker_.CalledOnValidThread());
    120 
    121   base::RunLoop run_loop;
    122   if (!io_thread_->PostTaskAndReply(
    123           FROM_HERE,
    124           base::Bind(&EmbeddedTestServer::InitializeOnIOThread,
    125                      base::Unretained(this)),
    126           run_loop.QuitClosure())) {
    127     return false;
    128   }
    129   run_loop.Run();
    130 
    131   return Started() && base_url_.is_valid();
    132 }
    133 
    134 bool EmbeddedTestServer::ShutdownAndWaitUntilComplete() {
    135   DCHECK(thread_checker_.CalledOnValidThread());
    136 
    137   base::RunLoop run_loop;
    138   if (!io_thread_->PostTaskAndReply(
    139           FROM_HERE,
    140           base::Bind(&EmbeddedTestServer::ShutdownOnIOThread,
    141                      base::Unretained(this)),
    142           run_loop.QuitClosure())) {
    143     return false;
    144   }
    145   run_loop.Run();
    146 
    147   return true;
    148 }
    149 
    150 void EmbeddedTestServer::InitializeOnIOThread() {
    151   DCHECK(io_thread_->BelongsToCurrentThread());
    152   DCHECK(!Started());
    153 
    154   SocketDescriptor socket_descriptor =
    155       TCPListenSocket::CreateAndBindAnyPort("127.0.0.1", &port_);
    156   if (socket_descriptor == TCPListenSocket::kInvalidSocket)
    157     return;
    158 
    159   listen_socket_ = new HttpListenSocket(socket_descriptor, this);
    160   listen_socket_->Listen();
    161 
    162   IPEndPoint address;
    163   int result = listen_socket_->GetLocalAddress(&address);
    164   if (result == OK) {
    165     base_url_ = GURL(std::string("http://") + address.ToString());
    166   } else {
    167     LOG(ERROR) << "GetLocalAddress failed: " << ErrorToString(result);
    168   }
    169 }
    170 
    171 void EmbeddedTestServer::ShutdownOnIOThread() {
    172   DCHECK(io_thread_->BelongsToCurrentThread());
    173 
    174   listen_socket_ = NULL;  // Release the listen socket.
    175   STLDeleteContainerPairSecondPointers(connections_.begin(),
    176                                        connections_.end());
    177   connections_.clear();
    178 }
    179 
    180 void EmbeddedTestServer::HandleRequest(HttpConnection* connection,
    181                                scoped_ptr<HttpRequest> request) {
    182   DCHECK(io_thread_->BelongsToCurrentThread());
    183 
    184   bool request_handled = false;
    185 
    186   for (size_t i = 0; i < request_handlers_.size(); ++i) {
    187     scoped_ptr<HttpResponse> response =
    188         request_handlers_[i].Run(*request.get());
    189     if (response.get()) {
    190       connection->SendResponse(response.Pass());
    191       request_handled = true;
    192       break;
    193     }
    194   }
    195 
    196   if (!request_handled) {
    197     LOG(WARNING) << "Request not handled. Returning 404: "
    198                  << request->relative_url;
    199     scoped_ptr<BasicHttpResponse> not_found_response(new BasicHttpResponse);
    200     not_found_response->set_code(HTTP_NOT_FOUND);
    201     connection->SendResponse(
    202         not_found_response.PassAs<HttpResponse>());
    203   }
    204 
    205   // Drop the connection, since we do not support multiple requests per
    206   // connection.
    207   connections_.erase(connection->socket_.get());
    208   delete connection;
    209 }
    210 
    211 GURL EmbeddedTestServer::GetURL(const std::string& relative_url) const {
    212   DCHECK(StartsWithASCII(relative_url, "/", true /* case_sensitive */))
    213       << relative_url;
    214   return base_url_.Resolve(relative_url);
    215 }
    216 
    217 void EmbeddedTestServer::ServeFilesFromDirectory(
    218     const base::FilePath& directory) {
    219   RegisterRequestHandler(base::Bind(&HandleFileRequest, directory));
    220 }
    221 
    222 void EmbeddedTestServer::RegisterRequestHandler(
    223     const HandleRequestCallback& callback) {
    224   request_handlers_.push_back(callback);
    225 }
    226 
    227 void EmbeddedTestServer::DidAccept(StreamListenSocket* server,
    228                            StreamListenSocket* connection) {
    229   DCHECK(io_thread_->BelongsToCurrentThread());
    230 
    231   HttpConnection* http_connection = new HttpConnection(
    232       connection,
    233       base::Bind(&EmbeddedTestServer::HandleRequest,
    234                  weak_factory_.GetWeakPtr()));
    235   connections_[connection] = http_connection;
    236 }
    237 
    238 void EmbeddedTestServer::DidRead(StreamListenSocket* connection,
    239                          const char* data,
    240                          int length) {
    241   DCHECK(io_thread_->BelongsToCurrentThread());
    242 
    243   HttpConnection* http_connection = FindConnection(connection);
    244   if (http_connection == NULL) {
    245     LOG(WARNING) << "Unknown connection.";
    246     return;
    247   }
    248   http_connection->ReceiveData(std::string(data, length));
    249 }
    250 
    251 void EmbeddedTestServer::DidClose(StreamListenSocket* connection) {
    252   DCHECK(io_thread_->BelongsToCurrentThread());
    253 
    254   HttpConnection* http_connection = FindConnection(connection);
    255   if (http_connection == NULL) {
    256     LOG(WARNING) << "Unknown connection.";
    257     return;
    258   }
    259   delete http_connection;
    260   connections_.erase(connection);
    261 }
    262 
    263 HttpConnection* EmbeddedTestServer::FindConnection(
    264     StreamListenSocket* socket) {
    265   DCHECK(io_thread_->BelongsToCurrentThread());
    266 
    267   std::map<StreamListenSocket*, HttpConnection*>::iterator it =
    268       connections_.find(socket);
    269   if (it == connections_.end()) {
    270     return NULL;
    271   }
    272   return it->second;
    273 }
    274 
    275 }  // namespace test_server
    276 }  // namespace net
    277