# Copyright 2013 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

try:
  import BaseHTTPServer  # Python 2.
except ImportError:
  import http.server as BaseHTTPServer  # Python 3.
import os
import ssl
import threading


def _ToBytes(body):
  """Returns |body| as bytes, UTF-8 encoding it if it is a text string.

  Byte strings (Python 2 str, Python 3 bytes) and bytearrays are returned
  unchanged.
  """
  if isinstance(body, (bytes, bytearray)):
    return body
  return body.encode('utf-8')


class Responder(object):
  """Sends a HTTP response. Used with WebServer and SyncWebServer."""

  def __init__(self, handler):
    """Args:
      handler: the BaseHTTPRequestHandler serving the current request.
    """
    self._handler = handler

  def SendResponse(self, body):
    """Sends OK response with body.

    Args:
      body: response body. Text is UTF-8 encoded first so that the
          Content-Length header counts bytes, not characters.
    """
    body = _ToBytes(body)
    self.SendHeaders(len(body))
    self.SendBody(body)

  def SendResponseFromFile(self, path):
    """Sends OK response with the given file as the body."""
    # Binary mode: the file must go out byte-for-byte (no newline
    # translation), otherwise binary files are corrupted and Content-Length
    # no longer matches the bytes written.
    with open(path, 'rb') as f:
      self.SendResponse(f.read())

  def SendHeaders(self, content_length=None):
    """Sends headers for OK response.

    Args:
      content_length: value for the Content-Length header, or None to omit
          the header.
    """
    self._handler.send_response(200)
    # Compare with None so that an empty body still advertises length 0.
    if content_length is not None:
      self._handler.send_header('Content-Length', content_length)
    self._handler.end_headers()

  def SendError(self, code):
    """Sends response for the given HTTP error code."""
    self._handler.send_error(code)

  def SendBody(self, body):
    """Just sends the body, no headers. Text is UTF-8 encoded."""
    self._handler.wfile.write(_ToBytes(body))


class Request(object):
  """An HTTP request."""

  def __init__(self, handler):
    """Args:
      handler: the BaseHTTPRequestHandler serving the current request.
    """
    self._handler = handler

  def GetPath(self):
    """Returns the raw request path, including any query string."""
    return self._handler.path

  def GetHeader(self, name):
    """Returns the value of request header |name|, or None if absent."""
    # get() exists on both Python 2's mimetools.Message and Python 3's
    # email.message.Message; getheader() is Python 2 only.
    return self._handler.headers.get(name)


class _BaseServer(BaseHTTPServer.HTTPServer):
  """Internal server that throws if timed out waiting for a request."""

  def __init__(self, on_request, server_cert_and_key_path=None):
    """Starts the server.

    It is an HTTP server if parameter server_cert_and_key_path is not provided.
    Otherwise, it is an HTTPS server.

    Args:
      on_request: called as on_request(request, responder) for every GET
          request, except requests for favicon.ico, which get a 404.
      server_cert_and_key_path: path to a PEM file containing the cert and key.
          if it is None, start the server as an HTTP one.
    """
    class _Handler(BaseHTTPServer.BaseHTTPRequestHandler):
      """Internal handler that just asks the server to handle the request."""

      def do_GET(self):
        if self.path.endswith('favicon.ico'):
          self.send_error(404)
          return
        on_request(Request(self), Responder(self))

      def log_message(self, *args, **kwargs):
        """Overrides base class method to disable logging."""
        pass

    # Port 0 lets the OS choose a free port; see GetUrl for the result.
    BaseHTTPServer.HTTPServer.__init__(self, ('127.0.0.1', 0), _Handler)

    if server_cert_and_key_path is not None:
      self._is_https_enabled = True
      # Wrap this server's own listening socket so that accepted connections
      # speak TLS.
      if hasattr(ssl, 'SSLContext') and hasattr(ssl, 'PROTOCOL_TLS_SERVER'):
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        # With no keyfile, the private key is read from the same PEM file.
        context.load_cert_chain(server_cert_and_key_path)
        self.socket = context.wrap_socket(self.socket, server_side=True)
      else:
        # Older Pythons only; ssl.wrap_socket was removed in Python 3.12.
        self.socket = ssl.wrap_socket(
            self.socket, certfile=server_cert_and_key_path,
            server_side=True)
    else:
      self._is_https_enabled = False

  def handle_timeout(self):
    """Overridden from SocketServer.

    Raises instead of returning so that handle_request() fails loudly when
    no request arrives within self.timeout.
    """
    raise RuntimeError('Timed out waiting for http request')

  def GetUrl(self):
    """Returns the base URL of the server."""
    postfix = '://127.0.0.1:%s' % self.server_port
    if self._is_https_enabled:
      return 'https' + postfix
    return 'http' + postfix


class WebServer(object):
  """An HTTP or HTTPS server that serves on its own thread.

  Serves files from given directory but may use custom data for specific paths.
  """

  def __init__(self, root_dir, server_cert_and_key_path=None):
    """Starts the server.

    It is an HTTP server if parameter server_cert_and_key_path is not provided.
    Otherwise, it is an HTTPS server.

    Args:
      root_dir: root path to serve files from. This parameter is required.
      server_cert_and_key_path: path to a PEM file containing the cert and key.
          if it is None, start the server as an HTTP one.
    """
    self._root_dir = os.path.abspath(root_dir)
    # Create the routing state before the serving thread starts, so that
    # _OnRequest never runs against a partially constructed instance.
    self._path_data_map = {}
    self._path_callback_map = {}
    self._path_maps_lock = threading.Lock()
    self._server = _BaseServer(self._OnRequest, server_cert_and_key_path)
    self._thread = threading.Thread(target=self._server.serve_forever)
    self._thread.daemon = True
    self._thread.start()

  def _OnRequest(self, request, responder):
    """Serves one request.

    Lookup order is: a callback registered for the path, then data
    registered for the path, then a regular file under root_dir. The query
    string is ignored when matching.
    """
    path = request.GetPath().split('?')[0]

    # Look up under the lock, but run the callback and send the response
    # without holding it. The lock is not reentrant, so a callback that calls
    # SetDataForPath/SetCallbackForPath would otherwise deadlock.
    with self._path_maps_lock:
      callback = self._path_callback_map.get(path)
      has_data = path in self._path_data_map
      data = self._path_data_map.get(path)

    if callback is not None:
      body = callback(request)
      if body:
        responder.SendResponse(body)
      else:
        responder.SendError(503)
      return

    if has_data:
      responder.SendResponse(data)
      return

    self._ServeFile(path, responder)

  def _ServeFile(self, path, responder):
    """Serves the file at URL |path| relative to root_dir.

    Sends 403 if the path escapes root_dir, and 404 if it is not a regular
    file.
    """
    file_path = os.path.normpath(
        os.path.join(self._root_dir, *path.split('/')))
    # Require a real path-component boundary. A bare prefix test would also
    # accept siblings such as '<root>_other' reached through '..'.
    root_prefix = os.path.join(self._root_dir, '')
    if file_path != self._root_dir and not file_path.startswith(root_prefix):
      responder.SendError(403)
      return
    # isfile (not exists): directories cannot be opened and sent as a body.
    if not os.path.isfile(file_path):
      responder.SendError(404)
      return
    responder.SendResponseFromFile(file_path)

  def SetDataForPath(self, path, data):
    """Serves |data| for requests to |path|."""
    with self._path_maps_lock:
      self._path_data_map[path] = data

  def SetCallbackForPath(self, path, func):
    """Serves requests to |path| with the body returned by func(request).

    A falsy return value from func produces a 503 response.
    """
    with self._path_maps_lock:
      self._path_callback_map[path] = func

  def GetUrl(self):
    """Returns the base URL of the server."""
    return self._server.GetUrl()

  def Shutdown(self):
    """Shuts down the server synchronously and closes its socket."""
    self._server.shutdown()
    self._thread.join()
    self._server.server_close()


class SyncWebServer(object):
  """WebServer for testing.

  Incoming requests are blocked until explicitly handled.
  This was designed for single thread use. All requests should be handled on
  the same thread.
  """

  def __init__(self):
    self._server = _BaseServer(self._OnRequest)
    # Recognized by SocketServer.
    self._server.timeout = 10
    self._on_request = None

  def _OnRequest(self, request, responder):
    self._on_request(responder)
    self._on_request = None

  def Respond(self, on_request):
    """Blocks until request comes in, then calls given handler function.

    Args:
      on_request: Function that handles the request. Invoked with single
          parameter, an instance of Responder.

    Raises:
      RuntimeError: if called re-entrantly, or if no request arrives within
          the server timeout.
    """
    if self._on_request:
      raise RuntimeError('Must handle 1 request at a time.')

    self._on_request = on_request
    try:
      while self._on_request:
        # Don't use handle_one_request, because it won't work with the timeout.
        self._server.handle_request()
    finally:
      # If handle_request raised (e.g. the timeout RuntimeError), forget the
      # pending handler so that later Respond() calls are not rejected.
      self._on_request = None

  def RespondWithContent(self, content):
    """Blocks until request comes in, then handles it with the given content."""
    def SendContent(responder):
      responder.SendResponse(content)
    self.Respond(SendContent)

  def GetUrl(self):
    """Returns the base URL of the server."""
    return self._server.GetUrl()

  def Shutdown(self):
    """Closes the server's listening socket."""
    self._server.server_close()