Home | History | Annotate | Download | only in test
      1 import os
      2 import sys
      3 import ssl
      4 import pprint
      5 import urllib
      6 import urlparse
      7 # Rename HTTPServer to _HTTPServer so as to avoid confusion with HTTPSServer.
      8 from BaseHTTPServer import HTTPServer as _HTTPServer, BaseHTTPRequestHandler
      9 from SimpleHTTPServer import SimpleHTTPRequestHandler
     10 
     11 from test import test_support as support
     12 threading = support.import_module("threading")
     13 
     14 here = os.path.dirname(__file__)
     15 
     16 HOST = support.HOST
     17 CERTFILE = os.path.join(here, 'keycert.pem')
     18 
# HTTPSServer is based on HTTPServer, which in turn is based on SocketServer.
     20 
     21 class HTTPSServer(_HTTPServer):
     22 
     23     def __init__(self, server_address, handler_class, context):
     24         _HTTPServer.__init__(self, server_address, handler_class)
     25         self.context = context
     26 
     27     def __str__(self):
     28         return ('<%s %s:%s>' %
     29                 (self.__class__.__name__,
     30                  self.server_name,
     31                  self.server_port))
     32 
     33     def get_request(self):
     34         # override this to wrap socket with SSL
     35         try:
     36             sock, addr = self.socket.accept()
     37             sslconn = self.context.wrap_socket(sock, server_side=True)
     38         except OSError as e:
     39             # socket errors are silenced by the caller, print them here
     40             if support.verbose:
     41                 sys.stderr.write("Got an error:\n%s\n" % e)
     42             raise
     43         return sslconn, addr
     44 
     45 class RootedHTTPRequestHandler(SimpleHTTPRequestHandler):
     46     # need to override translate_path to get a known root,
     47     # instead of using os.curdir, since the test could be
     48     # run from anywhere
     49 
     50     server_version = "TestHTTPS/1.0"
     51     root = here
     52     # Avoid hanging when a request gets interrupted by the client
     53     timeout = 5
     54 
     55     def translate_path(self, path):
     56         """Translate a /-separated PATH to the local filename syntax.
     57 
     58         Components that mean special things to the local file system
     59         (e.g. drive or directory names) are ignored.  (XXX They should
     60         probably be diagnosed.)
     61 
     62         """
     63         # abandon query parameters
     64         path = urlparse.urlparse(path)[2]
     65         path = os.path.normpath(urllib.unquote(path))
     66         words = path.split('/')
     67         words = filter(None, words)
     68         path = self.root
     69         for word in words:
     70             drive, word = os.path.splitdrive(word)
     71             head, word = os.path.split(word)
     72             path = os.path.join(path, word)
     73         return path
     74 
     75     def log_message(self, format, *args):
     76         # we override this to suppress logging unless "verbose"
     77         if support.verbose:
     78             sys.stdout.write(" server (%s:%d %s):\n   [%s] %s\n" %
     79                              (self.server.server_address,
     80                               self.server.server_port,
     81                               self.request.cipher(),
     82                               self.log_date_time_string(),
     83                               format%args))
     84 
     85 
     86 class StatsRequestHandler(BaseHTTPRequestHandler):
     87     """Example HTTP request handler which returns SSL statistics on GET
     88     requests.
     89     """
     90 
     91     server_version = "StatsHTTPS/1.0"
     92 
     93     def do_GET(self, send_body=True):
     94         """Serve a GET request."""
     95         sock = self.rfile.raw._sock
     96         context = sock.context
     97         stats = {
     98             'session_cache': context.session_stats(),
     99             'cipher': sock.cipher(),
    100             'compression': sock.compression(),
    101             }
    102         body = pprint.pformat(stats)
    103         body = body.encode('utf-8')
    104         self.send_response(200)
    105         self.send_header("Content-type", "text/plain; charset=utf-8")
    106         self.send_header("Content-Length", str(len(body)))
    107         self.end_headers()
    108         if send_body:
    109             self.wfile.write(body)
    110 
    111     def do_HEAD(self):
    112         """Serve a HEAD request."""
    113         self.do_GET(send_body=False)
    114 
    115     def log_request(self, format, *args):
    116         if support.verbose:
    117             BaseHTTPRequestHandler.log_request(self, format, *args)
    118 
    119 
    120 class HTTPSServerThread(threading.Thread):
    121 
    122     def __init__(self, context, host=HOST, handler_class=None):
    123         self.flag = None
    124         self.server = HTTPSServer((host, 0),
    125                                   handler_class or RootedHTTPRequestHandler,
    126                                   context)
    127         self.port = self.server.server_port
    128         threading.Thread.__init__(self)
    129         self.daemon = True
    130 
    131     def __str__(self):
    132         return "<%s %s>" % (self.__class__.__name__, self.server)
    133 
    134     def start(self, flag=None):
    135         self.flag = flag
    136         threading.Thread.start(self)
    137 
    138     def run(self):
    139         if self.flag:
    140             self.flag.set()
    141         try:
    142             self.server.serve_forever(0.05)
    143         finally:
    144             self.server.server_close()
    145 
    146     def stop(self):
    147         self.server.shutdown()
    148 
    149 
    150 def make_https_server(case, context=None, certfile=CERTFILE,
    151                       host=HOST, handler_class=None):
    152     if context is None:
    153         context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
    154     # We assume the certfile contains both private key and certificate
    155     context.load_cert_chain(certfile)
    156     server = HTTPSServerThread(context, host, handler_class)
    157     flag = threading.Event()
    158     server.start(flag)
    159     flag.wait()
    160     def cleanup():
    161         if support.verbose:
    162             sys.stdout.write('stopping HTTPS server\n')
    163         server.stop()
    164         if support.verbose:
    165             sys.stdout.write('joining HTTPS thread\n')
    166         server.join()
    167     case.addCleanup(cleanup)
    168     return server
    169 
    170 
    171 if __name__ == "__main__":
    172     import argparse
    173     parser = argparse.ArgumentParser(
    174         description='Run a test HTTPS server. '
    175                     'By default, the current directory is served.')
    176     parser.add_argument('-p', '--port', type=int, default=4433,
    177                         help='port to listen on (default: %(default)s)')
    178     parser.add_argument('-q', '--quiet', dest='verbose', default=True,
    179                         action='store_false', help='be less verbose')
    180     parser.add_argument('-s', '--stats', dest='use_stats_handler', default=False,
    181                         action='store_true', help='always return stats page')
    182     parser.add_argument('--curve-name', dest='curve_name', type=str,
    183                         action='store',
    184                         help='curve name for EC-based Diffie-Hellman')
    185     parser.add_argument('--ciphers', dest='ciphers', type=str,
    186                         help='allowed cipher list')
    187     parser.add_argument('--dh', dest='dh_file', type=str, action='store',
    188                         help='PEM file containing DH parameters')
    189     args = parser.parse_args()
    190 
    191     support.verbose = args.verbose
    192     if args.use_stats_handler:
    193         handler_class = StatsRequestHandler
    194     else:
    195         handler_class = RootedHTTPRequestHandler
    196         handler_class.root = os.getcwd()
    197     context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
    198     context.load_cert_chain(CERTFILE)
    199     if args.curve_name:
    200         context.set_ecdh_curve(args.curve_name)
    201     if args.dh_file:
    202         context.load_dh_params(args.dh_file)
    203     if args.ciphers:
    204         context.set_ciphers(args.ciphers)
    205 
    206     server = HTTPSServer(("", args.port), handler_class, context)
    207     if args.verbose:
    208         print("Listening on https://localhost:{0.port}".format(args))
    209     server.serve_forever(0.1)
    210