1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/test/embedded_test_server/embedded_test_server.h" 6 7 #include "base/bind.h" 8 #include "base/files/file_path.h" 9 #include "base/file_util.h" 10 #include "base/path_service.h" 11 #include "base/run_loop.h" 12 #include "base/stl_util.h" 13 #include "base/strings/string_util.h" 14 #include "base/strings/stringprintf.h" 15 #include "base/threading/thread_restrictions.h" 16 #include "net/base/ip_endpoint.h" 17 #include "net/base/net_errors.h" 18 #include "net/test/embedded_test_server/http_connection.h" 19 #include "net/test/embedded_test_server/http_request.h" 20 #include "net/test/embedded_test_server/http_response.h" 21 #include "net/tools/fetch/http_listen_socket.h" 22 23 namespace net { 24 namespace test_server { 25 26 namespace { 27 28 class CustomHttpResponse : public HttpResponse { 29 public: 30 CustomHttpResponse(const std::string& headers, const std::string& contents) 31 : headers_(headers), contents_(contents) { 32 } 33 34 virtual std::string ToResponseString() const OVERRIDE { 35 return headers_ + "\r\n" + contents_; 36 } 37 38 private: 39 std::string headers_; 40 std::string contents_; 41 42 DISALLOW_COPY_AND_ASSIGN(CustomHttpResponse); 43 }; 44 45 // Handles |request| by serving a file from under |server_root|. 46 scoped_ptr<HttpResponse> HandleFileRequest( 47 const base::FilePath& server_root, 48 const HttpRequest& request) { 49 // This is a test-only server. Ignore I/O thread restrictions. 50 base::ThreadRestrictions::ScopedAllowIO allow_io; 51 52 // Trim the first byte ('/'). 53 std::string request_path(request.relative_url.substr(1)); 54 55 // Remove the query string if present. 56 size_t query_pos = request_path.find('?'); 57 if (query_pos != std::string::npos) 58 request_path = request_path.substr(0, query_pos); 59 60 base::FilePath file_path(server_root.AppendASCII(request_path)); 61 std::string file_contents; 62 if (!file_util::ReadFileToString(file_path, &file_contents)) 63 return scoped_ptr<HttpResponse>(); 64 65 base::FilePath headers_path( 66 file_path.AddExtension(FILE_PATH_LITERAL("mock-http-headers"))); 67 68 if (base::PathExists(headers_path)) { 69 std::string headers_contents; 70 if (!file_util::ReadFileToString(headers_path, &headers_contents)) 71 return scoped_ptr<HttpResponse>(); 72 73 scoped_ptr<CustomHttpResponse> http_response( 74 new CustomHttpResponse(headers_contents, file_contents)); 75 return http_response.PassAs<HttpResponse>(); 76 } 77 78 scoped_ptr<BasicHttpResponse> http_response(new BasicHttpResponse); 79 http_response->set_code(HTTP_OK); 80 http_response->set_content(file_contents); 81 return http_response.PassAs<HttpResponse>(); 82 } 83 84 } // namespace 85 86 HttpListenSocket::HttpListenSocket(const SocketDescriptor socket_descriptor, 87 StreamListenSocket::Delegate* delegate) 88 : TCPListenSocket(socket_descriptor, delegate) { 89 DCHECK(thread_checker_.CalledOnValidThread()); 90 } 91 92 void HttpListenSocket::Listen() { 93 DCHECK(thread_checker_.CalledOnValidThread()); 94 TCPListenSocket::Listen(); 95 } 96 97 HttpListenSocket::~HttpListenSocket() { 98 DCHECK(thread_checker_.CalledOnValidThread()); 99 } 100 101 EmbeddedTestServer::EmbeddedTestServer( 102 const scoped_refptr<base::SingleThreadTaskRunner>& io_thread) 103 : io_thread_(io_thread), 104 port_(-1), 105 weak_factory_(this) { 106 DCHECK(io_thread_.get()); 107 DCHECK(thread_checker_.CalledOnValidThread()); 108 } 109 110 EmbeddedTestServer::~EmbeddedTestServer() { 111 DCHECK(thread_checker_.CalledOnValidThread()); 112 113 if (Started() && !ShutdownAndWaitUntilComplete()) { 114 LOG(ERROR) << "EmbeddedTestServer failed to shut down."; 115 } 116 } 117 118 bool EmbeddedTestServer::InitializeAndWaitUntilReady() { 119 DCHECK(thread_checker_.CalledOnValidThread()); 120 121 base::RunLoop run_loop; 122 if (!io_thread_->PostTaskAndReply( 123 FROM_HERE, 124 base::Bind(&EmbeddedTestServer::InitializeOnIOThread, 125 base::Unretained(this)), 126 run_loop.QuitClosure())) { 127 return false; 128 } 129 run_loop.Run(); 130 131 return Started() && base_url_.is_valid(); 132 } 133 134 bool EmbeddedTestServer::ShutdownAndWaitUntilComplete() { 135 DCHECK(thread_checker_.CalledOnValidThread()); 136 137 base::RunLoop run_loop; 138 if (!io_thread_->PostTaskAndReply( 139 FROM_HERE, 140 base::Bind(&EmbeddedTestServer::ShutdownOnIOThread, 141 base::Unretained(this)), 142 run_loop.QuitClosure())) { 143 return false; 144 } 145 run_loop.Run(); 146 147 return true; 148 } 149 150 void EmbeddedTestServer::InitializeOnIOThread() { 151 DCHECK(io_thread_->BelongsToCurrentThread()); 152 DCHECK(!Started()); 153 154 SocketDescriptor socket_descriptor = 155 TCPListenSocket::CreateAndBindAnyPort("127.0.0.1", &port_); 156 if (socket_descriptor == TCPListenSocket::kInvalidSocket) 157 return; 158 159 listen_socket_ = new HttpListenSocket(socket_descriptor, this); 160 listen_socket_->Listen(); 161 162 IPEndPoint address; 163 int result = listen_socket_->GetLocalAddress(&address); 164 if (result == OK) { 165 base_url_ = GURL(std::string("http://") + address.ToString()); 166 } else { 167 LOG(ERROR) << "GetLocalAddress failed: " << ErrorToString(result); 168 } 169 } 170 171 void EmbeddedTestServer::ShutdownOnIOThread() { 172 DCHECK(io_thread_->BelongsToCurrentThread()); 173 174 listen_socket_ = NULL; // Release the listen socket. 175 STLDeleteContainerPairSecondPointers(connections_.begin(), 176 connections_.end()); 177 connections_.clear(); 178 } 179 180 void EmbeddedTestServer::HandleRequest(HttpConnection* connection, 181 scoped_ptr<HttpRequest> request) { 182 DCHECK(io_thread_->BelongsToCurrentThread()); 183 184 bool request_handled = false; 185 186 for (size_t i = 0; i < request_handlers_.size(); ++i) { 187 scoped_ptr<HttpResponse> response = 188 request_handlers_[i].Run(*request.get()); 189 if (response.get()) { 190 connection->SendResponse(response.Pass()); 191 request_handled = true; 192 break; 193 } 194 } 195 196 if (!request_handled) { 197 LOG(WARNING) << "Request not handled. Returning 404: " 198 << request->relative_url; 199 scoped_ptr<BasicHttpResponse> not_found_response(new BasicHttpResponse); 200 not_found_response->set_code(HTTP_NOT_FOUND); 201 connection->SendResponse( 202 not_found_response.PassAs<HttpResponse>()); 203 } 204 205 // Drop the connection, since we do not support multiple requests per 206 // connection. 207 connections_.erase(connection->socket_.get()); 208 delete connection; 209 } 210 211 GURL EmbeddedTestServer::GetURL(const std::string& relative_url) const { 212 DCHECK(StartsWithASCII(relative_url, "/", true /* case_sensitive */)) 213 << relative_url; 214 return base_url_.Resolve(relative_url); 215 } 216 217 void EmbeddedTestServer::ServeFilesFromDirectory( 218 const base::FilePath& directory) { 219 RegisterRequestHandler(base::Bind(&HandleFileRequest, directory)); 220 } 221 222 void EmbeddedTestServer::RegisterRequestHandler( 223 const HandleRequestCallback& callback) { 224 request_handlers_.push_back(callback); 225 } 226 227 void EmbeddedTestServer::DidAccept(StreamListenSocket* server, 228 StreamListenSocket* connection) { 229 DCHECK(io_thread_->BelongsToCurrentThread()); 230 231 HttpConnection* http_connection = new HttpConnection( 232 connection, 233 base::Bind(&EmbeddedTestServer::HandleRequest, 234 weak_factory_.GetWeakPtr())); 235 connections_[connection] = http_connection; 236 } 237 238 void EmbeddedTestServer::DidRead(StreamListenSocket* connection, 239 const char* data, 240 int length) { 241 DCHECK(io_thread_->BelongsToCurrentThread()); 242 243 HttpConnection* http_connection = FindConnection(connection); 244 if (http_connection == NULL) { 245 LOG(WARNING) << "Unknown connection."; 246 return; 247 } 248 http_connection->ReceiveData(std::string(data, length)); 249 } 250 251 void EmbeddedTestServer::DidClose(StreamListenSocket* connection) { 252 DCHECK(io_thread_->BelongsToCurrentThread()); 253 254 HttpConnection* http_connection = FindConnection(connection); 255 if (http_connection == NULL) { 256 LOG(WARNING) << "Unknown connection."; 257 return; 258 } 259 delete http_connection; 260 connections_.erase(connection); 261 } 262 263 HttpConnection* EmbeddedTestServer::FindConnection( 264 StreamListenSocket* socket) { 265 DCHECK(io_thread_->BelongsToCurrentThread()); 266 267 std::map<StreamListenSocket*, HttpConnection*>::iterator it = 268 connections_.find(socket); 269 if (it == connections_.end()) { 270 return NULL; 271 } 272 return it->second; 273 } 274 275 } // namespace test_server 276 } // namespace net 277