# Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""Spins up a trivial HTTP cgi form listener in a thread.

The HTTPListener class is a utility for use with test cases that
need to call back to the Autotest test case with some form value, e.g.
http://localhost:nnnn/?status="Browser started!"
"""

import cgi, errno, functools, logging, os, posixpath, SimpleHTTPServer
import socket, ssl, sys, threading, urllib, urlparse
from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer
from SocketServer import BaseServer, ThreadingMixIn


def _handle_http_errors(func):
    """Decorator for cleaner presentation of broken-connection errors.

    Wraps a request-handler method (e.g. do_GET) so that an IOError whose
    errno is EPIPE or ECONNRESET is reported as a single line through the
    handler's log_error() instead of propagating. Any other IOError is
    re-raised unchanged.
    """
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        try:
            return func(self, *args, **kwargs)
        except IOError as e:
            if e.errno in (errno.EPIPE, errno.ECONNRESET):
                # Instead of dumping a stack trace, a single line is sufficient.
                # Pass the message as an argument, not as the format string,
                # so a '%' inside it cannot break log_error's formatting.
                self.log_error('%s', e)
            else:
                raise

    return wrapper


class FormHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
    """Request handler serving files, registered URL handlers and form posts.

    GET requests are dispatched to a handler registered for the URL path or,
    failing that, served as static files from the server's docroot. POST
    requests are parsed as forms: every key=value parameter is recorded in
    the server's form entries and, unless a handler is registered for the
    path, echoed back in the response.

    If the form submission is a file upload, the file will be written
    to disk with the name contained in the 'filename' field.

    After either kind of request, a wait event registered for the request
    URL (if any) is set and unregistered.
    """

    # NOTE: this updates the base class's map in place, so the .webm mapping
    # is visible to every SimpleHTTPRequestHandler in the process.
    SimpleHTTPServer.SimpleHTTPRequestHandler.extensions_map.update({
        '.webm': 'video/webm',
    })

    # Override the default logging methods to use the logging module directly.
    # No trailing newline is added: the logging module terminates each record.
    def log_error(self, format, *args):
        """Logs a request error at warning level."""
        logging.warning('(httpd error) %s - - [%s] %s',
                        self.address_string(), self.log_date_time_string(),
                        format % args)

    def log_message(self, format, *args):
        """Logs request activity at debug level."""
        logging.debug('%s - - [%s] %s',
                      self.address_string(), self.log_date_time_string(),
                      format % args)

    @_handle_http_errors
    def do_POST(self):
        """Parses a form POST, records its fields and sends a response.

        Each field's value is stored in self.server._form_entries. If a
        handler is registered for the URL path it is called as
        handler(self, form); otherwise write_post_response() echoes the form
        back. Finally the wait event for self.path (if any) is fired.
        """
        form = cgi.FieldStorage(
            fp=self.rfile,
            headers=self.headers,
            environ={'REQUEST_METHOD': 'POST',
                     'CONTENT_TYPE': self.headers['Content-Type']})
        # You'd think form.keys() would just return [], like it does for empty
        # python dicts; you'd be wrong. It raises TypeError if called when it
        # has no keys.
        if form:
            for field in form.keys():
                self.server._form_entries[field] = form[field].value
        path = urlparse.urlparse(self.path)[2]
        if path in self.server._url_handlers:
            self.server._url_handlers[path](self, form)
        else:
            # Echo back information about what was posted in the form.
            self.write_post_response(form)
        self._fire_event()

    def write_post_response(self, form):
        """Called to fill out the response to an HTTP POST.

        Override this method to give custom responses. The default sends a
        200 response listing the client address, the request path and every
        submitted field; uploaded files are also saved to disk under the
        filename supplied with the upload.

        @param form: the cgi.FieldStorage parsed from the request body.
        """
        # Send response boilerplate
        self.send_response(200)
        self.end_headers()
        self.wfile.write('Hello from Autotest!\nClient: %s\n' %
                         str(self.client_address))
        self.wfile.write('Request for path: %s\n' % self.path)
        self.wfile.write('Got form data:\n')

        # See the note in do_POST about form.keys().
        if not form:
            return
        for field in form.keys():
            field_item = form[field]
            if field_item.filename:
                # The field contains an uploaded file
                upload = field_item.file.read()
                self.wfile.write('\tUploaded %s (%d bytes)<br>' %
                                 (field, len(upload)))
                # Write submitted file to specified filename. Binary mode
                # keeps the upload byte-exact, and 'with' closes the file
                # promptly instead of relying on garbage collection.
                # NOTE(review): the filename comes from the client and is
                # not sanitized; it may contain directory components.
                with open(field_item.filename, 'wb') as out:
                    out.write(upload)
                del upload
            else:
                self.wfile.write('\t%s=%s<br>' % (field, field_item.value))

    def translate_path(self, path):
        """Override SimpleHTTPRequestHandler's translate_path to serve
        from arbitrary docroot.

        Query parameters are dropped, the path is URL-unquoted and
        normalized, and '.'/'..' components are discarded, so the result is
        built from plain path components joined onto self.server.docroot.

        @param path: the request path, e.g. self.path.
        @return the filesystem path to serve.
        """
        # abandon query parameters
        path = urlparse.urlparse(path)[2]
        path = posixpath.normpath(urllib.unquote(path))
        translated = self.server.docroot
        for word in filter(None, path.split('/')):
            # Strip any drive letter and directory part from the component.
            word = os.path.split(os.path.splitdrive(word)[1])[1]
            if word in (os.curdir, os.pardir):
                continue
            translated = os.path.join(translated, word)
        logging.debug('Translated path: %s', translated)
        return translated

    def _fire_event(self):
        """Sets and unregisters the wait event registered for self.path.

        The server is threaded, so two requests for the same URL can race
        here; a single pop() lets exactly one of them claim the event instead
        of both passing a membership test and one failing on the delete.
        """
        entry = self.server._wait_urls.pop(self.path, None)
        if entry is None:
            logging.debug('URL %s not in watch list', self.path)
            return
        _, event = entry
        event.set()

    @_handle_http_errors
    def do_GET(self):
        """Serves a GET through a registered URL handler or from docroot.

        A handler registered for the URL path is called as
        handler(self, args), where args is urlparse.parse_qs() of the query
        string. Otherwise the file is served by SimpleHTTPRequestHandler.
        Finally the wait event for the path (if any) is fired.
        """
        split_url = urlparse.urlsplit(self.path)
        path = split_url[2]
        # Strip off query parameters to ensure that the url path
        # matches any registered events.
        self.path = path
        args = urlparse.parse_qs(split_url[3])
        if path in self.server._url_handlers:
            self.server._url_handlers[path](self, args)
        else:
            SimpleHTTPServer.SimpleHTTPRequestHandler.do_GET(self)
        self._fire_event()

    @_handle_http_errors
    def do_HEAD(self):
        """Serves a HEAD request, logging clients that disconnect early."""
        SimpleHTTPServer.SimpleHTTPRequestHandler.do_HEAD(self)


class ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
    """HTTPServer that handles each request in its own thread."""

    def __init__(self, server_address, HandlerClass):
        HTTPServer.__init__(self, server_address, HandlerClass)


class HTTPListener(object):
    """Runs a ThreadedHTTPServer with FormHandler on a background thread."""

    # Point default docroot to a non-existent directory (instead of None) to
    # avoid exceptions when page content is served through handlers only.
    def __init__(self, port=0, docroot='/_', wait_urls=None,
                 url_handlers=None):
        """Creates (but does not start) the HTTP server.

        @param port: port to listen on; 0 lets the OS pick one.
        @param docroot: directory static files are served from.
        @param wait_urls: optional dict of URL -> (matchParams, Event), the
                format add_wait_url() fills in. A new dict is used when
                omitted, so separate listeners never share state.
        @param url_handlers: optional dict of URL path -> handler function,
                the format add_url_handler() fills in. A new dict is used
                when omitted.
        """
        self._server = ThreadedHTTPServer(('', port), FormHandler)
        self.config_server(self._server, docroot, wait_urls, url_handlers)

    def config_server(self, server, docroot, wait_urls, url_handlers):
        """Attaches listener state to server and prepares its thread.

        @param server: the HTTP server to configure.
        @param docroot: directory static files are served from.
        @param wait_urls: dict of URL -> (matchParams, Event), or None for a
                new empty dict.
        @param url_handlers: dict of URL path -> handler function, or None
                for a new empty dict.
        """
        # Stuff some convenient data fields into the server object. Empty
        # dicts are created here, per server, because the request handlers
        # and the add_* methods mutate them.
        server.docroot = docroot
        server._wait_urls = {} if wait_urls is None else wait_urls
        server._url_handlers = {} if url_handlers is None else url_handlers
        server._form_entries = {}
        self._server_thread = threading.Thread(target=server.serve_forever)

    def add_wait_url(self, url='/', matchParams=None):
        """Registers url and returns an Event set when a request for it ends.

        The registration is one-shot: it is removed when fired. GET requests
        are matched on the path with the query string stripped; POST
        requests are matched on the full request path.

        @param url: request URL to watch.
        @param matchParams: stored alongside the event (FormHandler does not
                read it). Defaults to a new empty dict.
        @return threading.Event that is set once url has been served.
        """
        if matchParams is None:
            matchParams = {}
        e = threading.Event()
        self._server._wait_urls[url] = (matchParams, e)
        return e

    def add_url_handler(self, url, handler_func):
        """Registers handler_func to serve requests for the URL path url.

        It is called as handler_func(request_handler, form) for POST, with
        form a cgi.FieldStorage, and as handler_func(request_handler, args)
        for GET, with args the parse_qs() dict of the query string.
        """
        self._server._url_handlers[url] = handler_func

    def clear_form_entries(self):
        """Discards all form field values recorded so far."""
        self._server._form_entries = {}

    def get_form_entries(self):
        """Returns a dictionary of all field=values received by the server.
        """
        return self._server._form_entries

    def run(self):
        """Starts serving requests on the background thread."""
        logging.debug('http server on %s:%d',
                      self._server.server_name, self._server.server_port)
        self._server_thread.start()

    def stop(self):
        """Stops serving, joins the serving thread and closes the socket.

        Safe to call even if run() was never called: shutdown() may only be
        called while serve_forever() runs in another thread (otherwise it
        deadlocks), and an unstarted thread cannot be joined.
        """
        if self._server_thread.is_alive():
            self._server.shutdown()
            self._server_thread.join()
        self._server.socket.close()


class SecureHTTPServer(ThreadingMixIn, HTTPServer):
    """Threaded HTTP server whose listening socket is wrapped with SSL."""

    def __init__(self, server_address, HandlerClass, cert_path, key_path):
        """Creates, binds and activates an SSL-wrapped listening socket.

        BaseServer.__init__ is called directly (skipping HTTPServer's
        constructor) so the wrapped socket installed here is the one bound.

        @param server_address: (host, port) to listen on.
        @param HandlerClass: request handler class.
        @param cert_path: path to the server certificate file.
        @param key_path: path to the server private key file.
        """
        _socket = socket.socket(self.address_family, self.socket_type)
        # NOTE(review): PROTOCOL_TLSv1 limits clients to TLS 1.0.
        self.socket = ssl.wrap_socket(_socket,
                                      server_side=True,
                                      ssl_version=ssl.PROTOCOL_TLSv1,
                                      certfile=cert_path,
                                      keyfile=key_path)
        BaseServer.__init__(self, server_address, HandlerClass)
        try:
            self.server_bind()
            self.server_activate()
        except BaseException:
            # Like TCPServer.__init__: don't leak the socket on failure.
            self.server_close()
            raise


class SecureHTTPRequestHandler(FormHandler):
    """FormHandler for connections accepted by SecureHTTPServer.

    Logging is inherited from FormHandler.
    """

    def setup(self):
        """Wraps the accepted connection in buffered read/write files."""
        self.connection = self.request
        self.rfile = socket._fileobject(self.request, 'rb', self.rbufsize)
        self.wfile = socket._fileobject(self.request, 'wb', self.wbufsize)


class SecureHTTPListener(HTTPListener):
    """HTTPListener that serves through SecureHTTPServer."""

    def __init__(self,
                 cert_path='/etc/login_trust_root.pem',
                 key_path='/etc/mock_server.key',
                 port=0,
                 docroot='/_',
                 wait_urls=None,
                 url_handlers=None):
        """Creates (but does not start) the SSL server.

        @param cert_path: path to the server certificate file.
        @param key_path: path to the server private key file.
        @param port: port to listen on; 0 lets the OS pick one.
        @param docroot: directory static files are served from.
        @param wait_urls: optional dict of URL -> (matchParams, Event); a new
                dict is used when omitted.
        @param url_handlers: optional dict of URL path -> handler function;
                a new dict is used when omitted.
        """
        self._server = SecureHTTPServer(('', port),
                                        SecureHTTPRequestHandler,
                                        cert_path,
                                        key_path)
        self.config_server(self._server, docroot, wait_urls, url_handlers)

    def getsockname(self):
        """Returns the address the server socket is bound to."""
        return self._server.socket.getsockname()