1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/test/embedded_test_server/embedded_test_server.h" 6 7 #include "base/bind.h" 8 #include "base/files/file_path.h" 9 #include "base/files/file_util.h" 10 #include "base/message_loop/message_loop.h" 11 #include "base/path_service.h" 12 #include "base/process/process_metrics.h" 13 #include "base/run_loop.h" 14 #include "base/stl_util.h" 15 #include "base/strings/string_util.h" 16 #include "base/strings/stringprintf.h" 17 #include "base/threading/thread_restrictions.h" 18 #include "net/base/ip_endpoint.h" 19 #include "net/base/net_errors.h" 20 #include "net/test/embedded_test_server/http_connection.h" 21 #include "net/test/embedded_test_server/http_request.h" 22 #include "net/test/embedded_test_server/http_response.h" 23 24 namespace net { 25 namespace test_server { 26 27 namespace { 28 29 class CustomHttpResponse : public HttpResponse { 30 public: 31 CustomHttpResponse(const std::string& headers, const std::string& contents) 32 : headers_(headers), contents_(contents) { 33 } 34 35 virtual std::string ToResponseString() const OVERRIDE { 36 return headers_ + "\r\n" + contents_; 37 } 38 39 private: 40 std::string headers_; 41 std::string contents_; 42 43 DISALLOW_COPY_AND_ASSIGN(CustomHttpResponse); 44 }; 45 46 // Handles |request| by serving a file from under |server_root|. 47 scoped_ptr<HttpResponse> HandleFileRequest( 48 const base::FilePath& server_root, 49 const HttpRequest& request) { 50 // This is a test-only server. Ignore I/O thread restrictions. 51 base::ThreadRestrictions::ScopedAllowIO allow_io; 52 53 // Trim the first byte ('/'). 54 std::string request_path(request.relative_url.substr(1)); 55 56 // Remove the query string if present. 57 size_t query_pos = request_path.find('?'); 58 if (query_pos != std::string::npos) 59 request_path = request_path.substr(0, query_pos); 60 61 base::FilePath file_path(server_root.AppendASCII(request_path)); 62 std::string file_contents; 63 if (!base::ReadFileToString(file_path, &file_contents)) 64 return scoped_ptr<HttpResponse>(); 65 66 base::FilePath headers_path( 67 file_path.AddExtension(FILE_PATH_LITERAL("mock-http-headers"))); 68 69 if (base::PathExists(headers_path)) { 70 std::string headers_contents; 71 if (!base::ReadFileToString(headers_path, &headers_contents)) 72 return scoped_ptr<HttpResponse>(); 73 74 scoped_ptr<CustomHttpResponse> http_response( 75 new CustomHttpResponse(headers_contents, file_contents)); 76 return http_response.PassAs<HttpResponse>(); 77 } 78 79 scoped_ptr<BasicHttpResponse> http_response(new BasicHttpResponse); 80 http_response->set_code(HTTP_OK); 81 http_response->set_content(file_contents); 82 return http_response.PassAs<HttpResponse>(); 83 } 84 85 } // namespace 86 87 HttpListenSocket::HttpListenSocket(const SocketDescriptor socket_descriptor, 88 StreamListenSocket::Delegate* delegate) 89 : TCPListenSocket(socket_descriptor, delegate) { 90 DCHECK(thread_checker_.CalledOnValidThread()); 91 } 92 93 void HttpListenSocket::Listen() { 94 DCHECK(thread_checker_.CalledOnValidThread()); 95 TCPListenSocket::Listen(); 96 } 97 98 void HttpListenSocket::ListenOnIOThread() { 99 DCHECK(thread_checker_.CalledOnValidThread()); 100 #if !defined(OS_POSIX) 101 // This method may be called after the IO thread is changed, thus we need to 102 // call |WatchSocket| again to make sure it listens on the current IO thread. 103 // Only needed for non POSIX platforms, since on POSIX platforms 104 // StreamListenSocket::Listen already calls WatchSocket inside the function. 105 WatchSocket(WAITING_ACCEPT); 106 #endif 107 Listen(); 108 } 109 110 HttpListenSocket::~HttpListenSocket() { 111 DCHECK(thread_checker_.CalledOnValidThread()); 112 } 113 114 void HttpListenSocket::DetachFromThread() { 115 thread_checker_.DetachFromThread(); 116 } 117 118 EmbeddedTestServer::EmbeddedTestServer() 119 : port_(-1), 120 weak_factory_(this) { 121 DCHECK(thread_checker_.CalledOnValidThread()); 122 } 123 124 EmbeddedTestServer::~EmbeddedTestServer() { 125 DCHECK(thread_checker_.CalledOnValidThread()); 126 127 if (Started() && !ShutdownAndWaitUntilComplete()) { 128 LOG(ERROR) << "EmbeddedTestServer failed to shut down."; 129 } 130 } 131 132 bool EmbeddedTestServer::InitializeAndWaitUntilReady() { 133 StartThread(); 134 DCHECK(thread_checker_.CalledOnValidThread()); 135 if (!PostTaskToIOThreadAndWait(base::Bind( 136 &EmbeddedTestServer::InitializeOnIOThread, base::Unretained(this)))) { 137 return false; 138 } 139 return Started() && base_url_.is_valid(); 140 } 141 142 void EmbeddedTestServer::StopThread() { 143 DCHECK(io_thread_ && io_thread_->IsRunning()); 144 145 #if defined(OS_LINUX) 146 const int thread_count = 147 base::GetNumberOfThreads(base::GetCurrentProcessHandle()); 148 #endif 149 150 io_thread_->Stop(); 151 io_thread_.reset(); 152 thread_checker_.DetachFromThread(); 153 listen_socket_->DetachFromThread(); 154 155 #if defined(OS_LINUX) 156 // Busy loop to wait for thread count to decrease. This is needed because 157 // pthread_join does not guarantee that kernel stat is updated when it 158 // returns. Thus, GetNumberOfThreads does not immediately reflect the stopped 159 // thread and hits the thread number DCHECK in render_sandbox_host_linux.cc 160 // in browser_tests. 161 while (thread_count == 162 base::GetNumberOfThreads(base::GetCurrentProcessHandle())) { 163 base::PlatformThread::YieldCurrentThread(); 164 } 165 #endif 166 } 167 168 void EmbeddedTestServer::RestartThreadAndListen() { 169 StartThread(); 170 CHECK(PostTaskToIOThreadAndWait(base::Bind( 171 &EmbeddedTestServer::ListenOnIOThread, base::Unretained(this)))); 172 } 173 174 bool EmbeddedTestServer::ShutdownAndWaitUntilComplete() { 175 DCHECK(thread_checker_.CalledOnValidThread()); 176 177 return PostTaskToIOThreadAndWait(base::Bind( 178 &EmbeddedTestServer::ShutdownOnIOThread, base::Unretained(this))); 179 } 180 181 void EmbeddedTestServer::StartThread() { 182 DCHECK(!io_thread_.get()); 183 base::Thread::Options thread_options; 184 thread_options.message_loop_type = base::MessageLoop::TYPE_IO; 185 io_thread_.reset(new base::Thread("EmbeddedTestServer io thread")); 186 CHECK(io_thread_->StartWithOptions(thread_options)); 187 } 188 189 void EmbeddedTestServer::InitializeOnIOThread() { 190 DCHECK(io_thread_->message_loop_proxy()->BelongsToCurrentThread()); 191 DCHECK(!Started()); 192 193 SocketDescriptor socket_descriptor = 194 TCPListenSocket::CreateAndBindAnyPort("127.0.0.1", &port_); 195 if (socket_descriptor == kInvalidSocket) 196 return; 197 198 listen_socket_.reset(new HttpListenSocket(socket_descriptor, this)); 199 listen_socket_->Listen(); 200 201 IPEndPoint address; 202 int result = listen_socket_->GetLocalAddress(&address); 203 if (result == OK) { 204 base_url_ = GURL(std::string("http://") + address.ToString()); 205 } else { 206 LOG(ERROR) << "GetLocalAddress failed: " << ErrorToString(result); 207 } 208 } 209 210 void EmbeddedTestServer::ListenOnIOThread() { 211 DCHECK(io_thread_->message_loop_proxy()->BelongsToCurrentThread()); 212 DCHECK(Started()); 213 listen_socket_->ListenOnIOThread(); 214 } 215 216 void EmbeddedTestServer::ShutdownOnIOThread() { 217 DCHECK(io_thread_->message_loop_proxy()->BelongsToCurrentThread()); 218 219 listen_socket_.reset(); 220 STLDeleteContainerPairSecondPointers(connections_.begin(), 221 connections_.end()); 222 connections_.clear(); 223 } 224 225 void EmbeddedTestServer::HandleRequest(HttpConnection* connection, 226 scoped_ptr<HttpRequest> request) { 227 DCHECK(io_thread_->message_loop_proxy()->BelongsToCurrentThread()); 228 229 bool request_handled = false; 230 231 for (size_t i = 0; i < request_handlers_.size(); ++i) { 232 scoped_ptr<HttpResponse> response = 233 request_handlers_[i].Run(*request.get()); 234 if (response.get()) { 235 connection->SendResponse(response.Pass()); 236 request_handled = true; 237 break; 238 } 239 } 240 241 if (!request_handled) { 242 LOG(WARNING) << "Request not handled. Returning 404: " 243 << request->relative_url; 244 scoped_ptr<BasicHttpResponse> not_found_response(new BasicHttpResponse); 245 not_found_response->set_code(HTTP_NOT_FOUND); 246 connection->SendResponse( 247 not_found_response.PassAs<HttpResponse>()); 248 } 249 250 // Drop the connection, since we do not support multiple requests per 251 // connection. 252 connections_.erase(connection->socket_.get()); 253 delete connection; 254 } 255 256 GURL EmbeddedTestServer::GetURL(const std::string& relative_url) const { 257 DCHECK(Started()) << "You must start the server first."; 258 DCHECK(StartsWithASCII(relative_url, "/", true /* case_sensitive */)) 259 << relative_url; 260 return base_url_.Resolve(relative_url); 261 } 262 263 void EmbeddedTestServer::ServeFilesFromDirectory( 264 const base::FilePath& directory) { 265 RegisterRequestHandler(base::Bind(&HandleFileRequest, directory)); 266 } 267 268 void EmbeddedTestServer::RegisterRequestHandler( 269 const HandleRequestCallback& callback) { 270 request_handlers_.push_back(callback); 271 } 272 273 void EmbeddedTestServer::DidAccept( 274 StreamListenSocket* server, 275 scoped_ptr<StreamListenSocket> connection) { 276 DCHECK(io_thread_->message_loop_proxy()->BelongsToCurrentThread()); 277 278 HttpConnection* http_connection = new HttpConnection( 279 connection.Pass(), 280 base::Bind(&EmbeddedTestServer::HandleRequest, 281 weak_factory_.GetWeakPtr())); 282 // TODO(szym): Make HttpConnection the StreamListenSocket delegate. 283 connections_[http_connection->socket_.get()] = http_connection; 284 } 285 286 void EmbeddedTestServer::DidRead(StreamListenSocket* connection, 287 const char* data, 288 int length) { 289 DCHECK(io_thread_->message_loop_proxy()->BelongsToCurrentThread()); 290 291 HttpConnection* http_connection = FindConnection(connection); 292 if (http_connection == NULL) { 293 LOG(WARNING) << "Unknown connection."; 294 return; 295 } 296 http_connection->ReceiveData(std::string(data, length)); 297 } 298 299 void EmbeddedTestServer::DidClose(StreamListenSocket* connection) { 300 DCHECK(io_thread_->message_loop_proxy()->BelongsToCurrentThread()); 301 302 HttpConnection* http_connection = FindConnection(connection); 303 if (http_connection == NULL) { 304 LOG(WARNING) << "Unknown connection."; 305 return; 306 } 307 delete http_connection; 308 connections_.erase(connection); 309 } 310 311 HttpConnection* EmbeddedTestServer::FindConnection( 312 StreamListenSocket* socket) { 313 DCHECK(io_thread_->message_loop_proxy()->BelongsToCurrentThread()); 314 315 std::map<StreamListenSocket*, HttpConnection*>::iterator it = 316 connections_.find(socket); 317 if (it == connections_.end()) { 318 return NULL; 319 } 320 return it->second; 321 } 322 323 bool EmbeddedTestServer::PostTaskToIOThreadAndWait( 324 const base::Closure& closure) { 325 // Note that PostTaskAndReply below requires base::MessageLoopProxy::current() 326 // to return a loop for posting the reply task. However, in order to make 327 // EmbeddedTestServer universally usable, it needs to cope with the situation 328 // where it's running on a thread on which a message loop is not (yet) 329 // available or as has been destroyed already. 330 // 331 // To handle this situation, create temporary message loop to support the 332 // PostTaskAndReply operation if the current thread as no message loop. 333 scoped_ptr<base::MessageLoop> temporary_loop; 334 if (!base::MessageLoop::current()) 335 temporary_loop.reset(new base::MessageLoop()); 336 337 base::RunLoop run_loop; 338 if (!io_thread_->message_loop_proxy()->PostTaskAndReply( 339 FROM_HERE, closure, run_loop.QuitClosure())) { 340 return false; 341 } 342 run_loop.Run(); 343 344 return true; 345 } 346 347 } // namespace test_server 348 } // namespace net 349