Home | History | Annotate | Download | only in web-page-replay
      1 # Copyright 2014 Google Inc. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #      http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 
     15 """Test routines to generate dummy certificates."""
     16 
import BaseHTTPServer
import os
import shutil
import signal
import socket
import tempfile
import threading
import time
import unittest

import certutils
import sslproxy
     28 
     29 
     30 class Client(object):
     31 
     32   def __init__(self, ca_cert_path, verify_cb, port, host_name='foo.com',
     33                host='localhost'):
     34     self.host_name = host_name
     35     self.verify_cb = verify_cb
     36     self.ca_cert_path = ca_cert_path
     37     self.port = port
     38     self.host_name = host_name
     39     self.host = host
     40     self.connection = None
     41 
     42   def run_request(self):
     43     context = certutils.get_ssl_context()
     44     context.set_verify(certutils.VERIFY_PEER, self.verify_cb)  # Demand a cert
     45     context.use_certificate_file(self.ca_cert_path)
     46     context.load_verify_locations(self.ca_cert_path)
     47 
     48     s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     49     self.connection = certutils.get_ssl_connection(context, s)
     50     self.connection.connect((self.host, self.port))
     51     self.connection.set_tlsext_host_name(self.host_name)
     52 
     53     try:
     54       self.connection.send('\r\n\r\n')
     55     finally:
     56       self.connection.shutdown()
     57       self.connection.close()
     58 
     59 
     60 class Handler(BaseHTTPServer.BaseHTTPRequestHandler):
     61   protocol_version = 'HTTP/1.1'  # override BaseHTTPServer setting
     62 
     63   def handle_one_request(self):
     64     """Handle a single HTTP request."""
     65     self.raw_requestline = self.rfile.readline(65537)
     66 
     67 
     68 class WrappedErrorHandler(Handler):
     69   """Wraps handler to verify expected sslproxy errors are being raised."""
     70 
     71   def setup(self):
     72     Handler.setup(self)
     73     try:
     74       sslproxy._SetUpUsingDummyCert(self)
     75     except certutils.Error:
     76       self.server.error_function = certutils.Error
     77 
     78   def finish(self):
     79     Handler.finish(self)
     80     self.connection.shutdown()
     81     self.connection.close()
     82 
     83 
     84 class DummyArchive(object):
     85 
     86   def __init__(self):
     87     pass
     88 
     89 
     90 class DummyFetch(object):
     91 
     92   def __init__(self):
     93     self.http_archive = DummyArchive()
     94 
     95 
     96 class Server(BaseHTTPServer.HTTPServer):
     97   """SSL server."""
     98 
     99   def __init__(self, ca_cert_path, use_error_handler=False, port=0,
    100                host='localhost'):
    101     self.ca_cert_path = ca_cert_path
    102     with open(ca_cert_path, 'r') as ca_file:
    103       self.ca_cert_str = ca_file.read()
    104     self.http_archive_fetch = DummyFetch()
    105     if use_error_handler:
    106       self.HANDLER = WrappedErrorHandler
    107     else:
    108       self.HANDLER = sslproxy.wrap_handler(Handler)
    109     try:
    110       BaseHTTPServer.HTTPServer.__init__(self, (host, port), self.HANDLER)
    111     except Exception, e:
    112       raise RuntimeError('Could not start HTTPSServer on port %d: %s'
    113                          % (port, e))
    114 
    115   def __enter__(self):
    116     thread = threading.Thread(target=self.serve_forever)
    117     thread.daemon = True
    118     thread.start()
    119     return self
    120 
    121   def cleanup(self):
    122     try:
    123       self.shutdown()
    124     except KeyboardInterrupt:
    125       pass
    126 
    127   def __exit__(self, type_, value_, traceback_):
    128     self.cleanup()
    129 
    130   def get_certificate(self, host):
    131     return certutils.generate_cert(self.ca_cert_str, '', host)
    132 
    133 
    134 class TestClient(unittest.TestCase):
    135   _temp_dir = None
    136 
    137   def setUp(self):
    138     self._temp_dir = tempfile.mkdtemp(prefix='sslproxy_', dir='/tmp')
    139     self.ca_cert_path = self._temp_dir + 'testCA.pem'
    140     self.cert_path = self._temp_dir + 'testCA-cert.cer'
    141     self.wrong_ca_cert_path = self._temp_dir + 'wrong.pem'
    142     self.wrong_cert_path = self._temp_dir + 'wrong-cert.cer'
    143 
    144     # Write both pem and cer files for certificates
    145     certutils.write_dummy_ca_cert(*certutils.generate_dummy_ca_cert(),
    146                                   cert_path=self.ca_cert_path)
    147     certutils.write_dummy_ca_cert(*certutils.generate_dummy_ca_cert(),
    148                                   cert_path=self.ca_cert_path)
    149 
    150   def tearDown(self):
    151     if self._temp_dir:
    152       shutil.rmtree(self._temp_dir)
    153 
    154   def verify_cb(self, conn, cert, errnum, depth, ok):
    155     """A callback that verifies the certificate authentication worked.
    156 
    157     Args:
    158       conn: Connection object
    159       cert: x509 object
    160       errnum: possible error number
    161       depth: error depth
    162       ok: 1 if the authentication worked 0 if it didnt.
    163     Returns:
    164       1 or 0 depending on if the verification worked
    165     """
    166     self.assertFalse(cert.has_expired())
    167     self.assertGreater(time.strftime('%Y%m%d%H%M%SZ', time.gmtime()),
    168                        cert.get_notBefore())
    169     return ok
    170 
    171   def test_no_host(self):
    172     with Server(self.ca_cert_path) as server:
    173       c = Client(self.cert_path, self.verify_cb, server.server_port, '')
    174       self.assertRaises(certutils.Error, c.run_request)
    175 
    176   def test_client_connection(self):
    177     with Server(self.ca_cert_path) as server:
    178       c = Client(self.cert_path, self.verify_cb, server.server_port, 'foo.com')
    179       c.run_request()
    180 
    181       c = Client(self.cert_path, self.verify_cb, server.server_port,
    182                  'random.host')
    183       c.run_request()
    184 
    185   def test_wrong_cert(self):
    186     with Server(self.ca_cert_path, True) as server:
    187       c = Client(self.wrong_cert_path, self.verify_cb, server.server_port,
    188                  'foo.com')
    189       self.assertRaises(certutils.Error, c.run_request)
    190 
    191 
    192 if __name__ == '__main__':
    193   signal.signal(signal.SIGINT, signal.SIG_DFL)  # Exit on Ctrl-C
    194   unittest.main()
    195