Home | History | Annotate | Download | only in cros
      1 # Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
      2 # Use of this source code is governed by a BSD-style license that can be
      3 # found in the LICENSE file.
      4 
      5 """Spins up a trivial HTTP cgi form listener in a thread.
      6 
      7    This HTTPThread class is a utility for use with test cases that
      8    need to call back to the Autotest test case with some form value, e.g.
      9    http://localhost:nnnn/?status="Browser started!"
     10 """
     11 
     12 import cgi, errno, logging, os, posixpath, SimpleHTTPServer, socket, ssl, sys
     13 import threading, urllib, urlparse
     14 from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer
     15 from SocketServer import BaseServer, ThreadingMixIn
     16 
     17 
     18 def _handle_http_errors(func):
     19     """Decorator function for cleaner presentation of certain exceptions."""
     20     def wrapper(self):
     21         try:
     22             func(self)
     23         except IOError, e:
     24             if e.errno == errno.EPIPE or e.errno == errno.ECONNRESET:
     25                 # Instead of dumping a stack trace, a single line is sufficient.
     26                 self.log_error(str(e))
     27             else:
     28                 raise
     29 
     30     return wrapper
     31 
     32 
     33 class FormHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
     34     """Implements a form handler (for POST requests only) which simply
     35     echoes the key=value parameters back in the response.
     36 
     37     If the form submission is a file upload, the file will be written
     38     to disk with the name contained in the 'filename' field.
     39     """
     40 
     41     SimpleHTTPServer.SimpleHTTPRequestHandler.extensions_map.update({
     42         '.webm': 'video/webm',
     43     })
     44 
     45     # Override the default logging methods to use the logging module directly.
     46     def log_error(self, format, *args):
     47         logging.warning("(httpd error) %s - - [%s] %s\n" %
     48                      (self.address_string(), self.log_date_time_string(),
     49                       format%args))
     50 
     51     def log_message(self, format, *args):
     52         logging.debug("%s - - [%s] %s\n" %
     53                      (self.address_string(), self.log_date_time_string(),
     54                       format%args))
     55 
     56     @_handle_http_errors
     57     def do_POST(self):
     58         form = cgi.FieldStorage(
     59             fp=self.rfile,
     60             headers=self.headers,
     61             environ={'REQUEST_METHOD': 'POST',
     62                      'CONTENT_TYPE': self.headers['Content-Type']})
     63         # You'd think form.keys() would just return [], like it does for empty
     64         # python dicts; you'd be wrong. It raises TypeError if called when it
     65         # has no keys.
     66         if form:
     67             for field in form.keys():
     68                 field_item = form[field]
     69                 self.server._form_entries[field] = field_item.value
     70         path = urlparse.urlparse(self.path)[2]
     71         if path in self.server._url_handlers:
     72             self.server._url_handlers[path](self, form)
     73         else:
     74             # Echo back information about what was posted in the form.
     75             self.write_post_response(form)
     76         self._fire_event()
     77 
     78 
     79     def write_post_response(self, form):
     80         """Called to fill out the response to an HTTP POST.
     81 
     82         Override this class to give custom responses.
     83         """
     84         # Send response boilerplate
     85         self.send_response(200)
     86         self.end_headers()
     87         self.wfile.write('Hello from Autotest!\nClient: %s\n' %
     88                          str(self.client_address))
     89         self.wfile.write('Request for path: %s\n' % self.path)
     90         self.wfile.write('Got form data:\n')
     91 
     92         # See the note in do_POST about form.keys().
     93         if form:
     94             for field in form.keys():
     95                 field_item = form[field]
     96                 if field_item.filename:
     97                     # The field contains an uploaded file
     98                     upload = field_item.file.read()
     99                     self.wfile.write('\tUploaded %s (%d bytes)<br>' %
    100                                      (field, len(upload)))
    101                     # Write submitted file to specified filename.
    102                     file(field_item.filename, 'w').write(upload)
    103                     del upload
    104                 else:
    105                     self.wfile.write('\t%s=%s<br>' % (field, form[field].value))
    106 
    107 
    108     def translate_path(self, path):
    109         """Override SimpleHTTPRequestHandler's translate_path to serve
    110         from arbitrary docroot
    111         """
    112         # abandon query parameters
    113         path = urlparse.urlparse(path)[2]
    114         path = posixpath.normpath(urllib.unquote(path))
    115         words = path.split('/')
    116         words = filter(None, words)
    117         path = self.server.docroot
    118         for word in words:
    119             drive, word = os.path.splitdrive(word)
    120             head, word = os.path.split(word)
    121             if word in (os.curdir, os.pardir): continue
    122             path = os.path.join(path, word)
    123         logging.debug('Translated path: %s', path)
    124         return path
    125 
    126 
    127     def _fire_event(self):
    128         wait_urls = self.server._wait_urls
    129         if self.path in wait_urls:
    130             _, e = wait_urls[self.path]
    131             e.set()
    132             del wait_urls[self.path]
    133         else:
    134             logging.debug('URL %s not in watch list' % self.path)
    135 
    136 
    137     @_handle_http_errors
    138     def do_GET(self):
    139         form = cgi.FieldStorage(
    140             fp=self.rfile,
    141             headers=self.headers,
    142             environ={'REQUEST_METHOD': 'GET'})
    143         split_url = urlparse.urlsplit(self.path)
    144         path = split_url[2]
    145         # Strip off query parameters to ensure that the url path
    146         # matches any registered events.
    147         self.path = path
    148         args = urlparse.parse_qs(split_url[3])
    149         if path in self.server._url_handlers:
    150             self.server._url_handlers[path](self, args)
    151         else:
    152             SimpleHTTPServer.SimpleHTTPRequestHandler.do_GET(self)
    153         self._fire_event()
    154 
    155 
    156     @_handle_http_errors
    157     def do_HEAD(self):
    158         SimpleHTTPServer.SimpleHTTPRequestHandler.do_HEAD(self)
    159 
    160 
    161 class ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
    162     def __init__(self, server_address, HandlerClass):
    163         HTTPServer.__init__(self, server_address, HandlerClass)
    164 
    165 
    166 class HTTPListener(object):
    167     # Point default docroot to a non-existent directory (instead of None) to
    168     # avoid exceptions when page content is served through handlers only.
    169     def __init__(self, port=0, docroot='/_', wait_urls={}, url_handlers={}):
    170         self._server = ThreadedHTTPServer(('', port), FormHandler)
    171         self.config_server(self._server, docroot, wait_urls, url_handlers)
    172 
    173     def config_server(self, server, docroot, wait_urls, url_handlers):
    174         # Stuff some convenient data fields into the server object.
    175         self._server.docroot = docroot
    176         self._server._wait_urls = wait_urls
    177         self._server._url_handlers = url_handlers
    178         self._server._form_entries = {}
    179         self._server_thread = threading.Thread(
    180             target=self._server.serve_forever)
    181 
    182 
    183     def add_wait_url(self, url='/', matchParams={}):
    184         e = threading.Event()
    185         self._server._wait_urls[url] = (matchParams, e)
    186         return e
    187 
    188 
    189     def add_url_handler(self, url, handler_func):
    190         self._server._url_handlers[url] = handler_func
    191 
    192 
    193     def clear_form_entries(self):
    194         self._server._form_entries = {}
    195 
    196 
    197     def get_form_entries(self):
    198         """Returns a dictionary of all field=values recieved by the server.
    199         """
    200         return self._server._form_entries
    201 
    202 
    203     def run(self):
    204         logging.debug('http server on %s:%d' %
    205                       (self._server.server_name, self._server.server_port))
    206         self._server_thread.start()
    207 
    208 
    209     def stop(self):
    210         self._server.shutdown()
    211         self._server.socket.close()
    212         self._server_thread.join()
    213 
    214 
    215 class SecureHTTPServer(ThreadingMixIn, HTTPServer):
    216     def __init__(self, server_address, HandlerClass, cert_path, key_path):
    217         _socket = socket.socket(self.address_family, self.socket_type)
    218         self.socket = ssl.wrap_socket(_socket,
    219                                       server_side=True,
    220                                       ssl_version=ssl.PROTOCOL_TLSv1,
    221                                       certfile=cert_path,
    222                                       keyfile=key_path)
    223         BaseServer.__init__(self, server_address, HandlerClass)
    224         self.server_bind()
    225         self.server_activate()
    226 
    227 
    228 class SecureHTTPRequestHandler(FormHandler):
    229     def setup(self):
    230         self.connection = self.request
    231         self.rfile = socket._fileobject(self.request, 'rb', self.rbufsize)
    232         self.wfile = socket._fileobject(self.request, 'wb', self.wbufsize)
    233 
    234     # Override the default logging methods to use the logging module directly.
    235     def log_error(self, format, *args):
    236         logging.warning("(httpd error) %s - - [%s] %s\n" %
    237                      (self.address_string(), self.log_date_time_string(),
    238                       format%args))
    239 
    240     def log_message(self, format, *args):
    241         logging.debug("%s - - [%s] %s\n" %
    242                      (self.address_string(), self.log_date_time_string(),
    243                       format%args))
    244 
    245 
    246 class SecureHTTPListener(HTTPListener):
    247     def __init__(self,
    248                  cert_path='/etc/login_trust_root.pem',
    249                  key_path='/etc/mock_server.key',
    250                  port=0,
    251                  docroot='/_',
    252                  wait_urls={},
    253                  url_handlers={}):
    254         self._server = SecureHTTPServer(('', port),
    255                                         SecureHTTPRequestHandler,
    256                                         cert_path,
    257                                         key_path)
    258         self.config_server(self._server, docroot, wait_urls, url_handlers)
    259 
    260 
    261     def getsockname(self):
    262         return self._server.socket.getsockname()
    263 
    264