Home | History | Annotate | Download | only in cros
      1 # Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
      2 # Use of this source code is governed by a BSD-style license that can be
      3 # found in the LICENSE file.
      4 
      5 """Spins up a trivial HTTP cgi form listener in a thread.
      6 
      7    This HTTPListener class is a utility for use with test cases that
      8    need to call back to the Autotest test case with some form value, e.g.
      9    http://localhost:nnnn/?status="Browser started!"
     10 """
     11 
     12 import cgi, errno, logging, os, posixpath, SimpleHTTPServer, socket, ssl, sys
     13 import threading, urllib, urlparse
     14 from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer
     15 from SocketServer import BaseServer, ThreadingMixIn
     16 
     17 
     18 def _handle_http_errors(func):
     19     """Decorator function for cleaner presentation of certain exceptions."""
     20     def wrapper(self):
     21         try:
     22             func(self)
     23         except IOError, e:
     24             if e.errno == errno.EPIPE or e.errno == errno.ECONNRESET:
     25                 # Instead of dumping a stack trace, a single line is sufficient.
     26                 self.log_error(str(e))
     27             else:
     28                 raise
     29 
     30     return wrapper
     31 
     32 
     33 class FormHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
     34     """Implements a form handler (for POST requests only) which simply
     35     echoes the key=value parameters back in the response.
     36 
     37     If the form submission is a file upload, the file will be written
     38     to disk with the name contained in the 'filename' field.
     39     """
     40 
     41     SimpleHTTPServer.SimpleHTTPRequestHandler.extensions_map.update({
     42         '.webm': 'video/webm',
     43     })
     44 
     45     # Override the default logging methods to use the logging module directly.
     46     def log_error(self, format, *args):
     47         logging.warning("(httpd error) %s - - [%s] %s\n" %
     48                      (self.address_string(), self.log_date_time_string(),
     49                       format%args))
     50 
     51     def log_message(self, format, *args):
     52         logging.debug("%s - - [%s] %s\n" %
     53                      (self.address_string(), self.log_date_time_string(),
     54                       format%args))
     55 
     56     @_handle_http_errors
     57     def do_POST(self):
     58         form = cgi.FieldStorage(
     59             fp=self.rfile,
     60             headers=self.headers,
     61             environ={'REQUEST_METHOD': 'POST',
     62                      'CONTENT_TYPE': self.headers['Content-Type']})
     63         # You'd think form.keys() would just return [], like it does for empty
     64         # python dicts; you'd be wrong. It raises TypeError if called when it
     65         # has no keys.
     66         if form:
     67             for field in form.keys():
     68                 field_item = form[field]
     69                 self.server._form_entries[field] = field_item.value
     70         path = urlparse.urlparse(self.path)[2]
     71         if path in self.server._url_handlers:
     72             self.server._url_handlers[path](self, form)
     73         else:
     74             # Echo back information about what was posted in the form.
     75             self.write_post_response(form)
     76         self._fire_event()
     77 
     78 
     79     def write_post_response(self, form):
     80         """Called to fill out the response to an HTTP POST.
     81 
     82         Override this class to give custom responses.
     83         """
     84         # Send response boilerplate
     85         self.send_response(200)
     86         self.end_headers()
     87         self.wfile.write('Hello from Autotest!\nClient: %s\n' %
     88                          str(self.client_address))
     89         self.wfile.write('Request for path: %s\n' % self.path)
     90         self.wfile.write('Got form data:\n')
     91 
     92         # See the note in do_POST about form.keys().
     93         if form:
     94             for field in form.keys():
     95                 field_item = form[field]
     96                 if field_item.filename:
     97                     # The field contains an uploaded file
     98                     upload = field_item.file.read()
     99                     self.wfile.write('\tUploaded %s (%d bytes)<br>' %
    100                                      (field, len(upload)))
    101                     # Write submitted file to specified filename.
    102                     file(field_item.filename, 'w').write(upload)
    103                     del upload
    104                 else:
    105                     self.wfile.write('\t%s=%s<br>' % (field, form[field].value))
    106 
    107 
    108     def translate_path(self, path):
    109         """Override SimpleHTTPRequestHandler's translate_path to serve
    110         from arbitrary docroot
    111         """
    112         # abandon query parameters
    113         path = urlparse.urlparse(path)[2]
    114         path = posixpath.normpath(urllib.unquote(path))
    115         words = path.split('/')
    116         words = filter(None, words)
    117         path = self.server.docroot
    118         for word in words:
    119             drive, word = os.path.splitdrive(word)
    120             head, word = os.path.split(word)
    121             if word in (os.curdir, os.pardir): continue
    122             path = os.path.join(path, word)
    123         logging.debug('Translated path: %s', path)
    124         return path
    125 
    126 
    127     def _fire_event(self):
    128         wait_urls = self.server._wait_urls
    129         if self.path in wait_urls:
    130             _, e = wait_urls[self.path]
    131             e.set()
    132             del wait_urls[self.path]
    133         else:
    134           if self.path not in self.server._urls:
    135               # if the url is not in _urls, this means it was neither setup
    136               # as a permanent, or event url.
    137               logging.debug('URL %s not in watch list' % self.path)
    138 
    139 
    140     @_handle_http_errors
    141     def do_GET(self):
    142         form = cgi.FieldStorage(
    143             fp=self.rfile,
    144             headers=self.headers,
    145             environ={'REQUEST_METHOD': 'GET'})
    146         split_url = urlparse.urlsplit(self.path)
    147         path = split_url[2]
    148         # Strip off query parameters to ensure that the url path
    149         # matches any registered events.
    150         self.path = path
    151         args = urlparse.parse_qs(split_url[3])
    152         if path in self.server._url_handlers:
    153             self.server._url_handlers[path](self, args)
    154         else:
    155             SimpleHTTPServer.SimpleHTTPRequestHandler.do_GET(self)
    156         self._fire_event()
    157 
    158 
    159     @_handle_http_errors
    160     def do_HEAD(self):
    161         SimpleHTTPServer.SimpleHTTPRequestHandler.do_HEAD(self)
    162 
    163 
    164 class ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
    165     def __init__(self, server_address, HandlerClass):
    166         HTTPServer.__init__(self, server_address, HandlerClass)
    167 
    168 
    169 class HTTPListener(object):
    170     # Point default docroot to a non-existent directory (instead of None) to
    171     # avoid exceptions when page content is served through handlers only.
    172     def __init__(self, port=0, docroot='/_', wait_urls={}, url_handlers={}):
    173         self._server = ThreadedHTTPServer(('', port), FormHandler)
    174         self.config_server(self._server, docroot, wait_urls, url_handlers)
    175 
    176     def config_server(self, server, docroot, wait_urls, url_handlers):
    177         # Stuff some convenient data fields into the server object.
    178         self._server.docroot = docroot
    179         self._server._urls = set()
    180         self._server._wait_urls = wait_urls
    181         self._server._url_handlers = url_handlers
    182         self._server._form_entries = {}
    183         self._server_thread = threading.Thread(
    184             target=self._server.serve_forever)
    185 
    186     def add_url(self, url):
    187         """
    188           Add a url to the urls that the http server is actively watching for.
    189 
    190           Not adding a url via add_url or add_wait_url, and only installing a
    191           handler will still result in that handler being executed, but this
    192           server will warn in the debug logs that it does not expect that url.
    193 
    194           Args:
    195             url (string): url suffix to listen to
    196         """
    197         self._server._urls.add(url)
    198 
    199     def add_wait_url(self, url='/', matchParams={}):
    200         """
    201           Add a wait url to the urls that the http server is aware of.
    202 
    203           Not adding a url via add_url or add_wait_url, and only installing a
    204           handler will still result in that handler being executed, but this
    205           server will warn in the debug logs that it does not expect that url.
    206 
    207           Args:
    208             url (string): url suffix to listen to
    209             matchParams (dictionary): an unused dictionary
    210 
    211           Returns:
    212             e, and event object. Call e.wait() on the object to wait (block)
    213             until the server receives the first request for the wait url.
    214 
    215         """
    216         e = threading.Event()
    217         self._server._wait_urls[url] = (matchParams, e)
    218         self._server._urls.add(url)
    219         return e
    220 
    221     def add_url_handler(self, url, handler_func):
    222         self._server._url_handlers[url] = handler_func
    223 
    224     def clear_form_entries(self):
    225         self._server._form_entries = {}
    226 
    227 
    228     def get_form_entries(self):
    229         """Returns a dictionary of all field=values recieved by the server.
    230         """
    231         return self._server._form_entries
    232 
    233 
    234     def run(self):
    235         logging.debug('http server on %s:%d' %
    236                       (self._server.server_name, self._server.server_port))
    237         self._server_thread.start()
    238 
    239 
    240     def stop(self):
    241         self._server.shutdown()
    242         self._server.socket.close()
    243         self._server_thread.join()
    244 
    245 
    246 class SecureHTTPServer(ThreadingMixIn, HTTPServer):
    247     def __init__(self, server_address, HandlerClass, cert_path, key_path):
    248         _socket = socket.socket(self.address_family, self.socket_type)
    249         self.socket = ssl.wrap_socket(_socket,
    250                                       server_side=True,
    251                                       ssl_version=ssl.PROTOCOL_TLSv1,
    252                                       certfile=cert_path,
    253                                       keyfile=key_path)
    254         BaseServer.__init__(self, server_address, HandlerClass)
    255         self.server_bind()
    256         self.server_activate()
    257 
    258 
    259 class SecureHTTPRequestHandler(FormHandler):
    260     def setup(self):
    261         self.connection = self.request
    262         self.rfile = socket._fileobject(self.request, 'rb', self.rbufsize)
    263         self.wfile = socket._fileobject(self.request, 'wb', self.wbufsize)
    264 
    265     # Override the default logging methods to use the logging module directly.
    266     def log_error(self, format, *args):
    267         logging.warning("(httpd error) %s - - [%s] %s\n" %
    268                      (self.address_string(), self.log_date_time_string(),
    269                       format%args))
    270 
    271     def log_message(self, format, *args):
    272         logging.debug("%s - - [%s] %s\n" %
    273                      (self.address_string(), self.log_date_time_string(),
    274                       format%args))
    275 
    276 
    277 class SecureHTTPListener(HTTPListener):
    278     def __init__(self,
    279                  cert_path='/etc/login_trust_root.pem',
    280                  key_path='/etc/mock_server.key',
    281                  port=0,
    282                  docroot='/_',
    283                  wait_urls={},
    284                  url_handlers={}):
    285         self._server = SecureHTTPServer(('', port),
    286                                         SecureHTTPRequestHandler,
    287                                         cert_path,
    288                                         key_path)
    289         self.config_server(self._server, docroot, wait_urls, url_handlers)
    290 
    291 
    292     def getsockname(self):
    293         return self._server.socket.getsockname()
    294 
    295