# Copyright 2014 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for sslproxy, using dummy certificates generated by certutils."""

import BaseHTTPServer
import os
import shutil
import signal
import socket
import tempfile
import threading
import time
import unittest

import certutils
import sslproxy


class Client(object):
  """Minimal SSL client that sends one blank request to a test server."""

  def __init__(self, ca_cert_path, verify_cb, port, host_name='foo.com',
               host='localhost'):
    """Stores connection parameters; no connection is opened here.

    Args:
      ca_cert_path: certificate file used both as the client certificate and
          as the trusted verify location.
      verify_cb: verification callback handed to context.set_verify().
      port: server port to connect to.
      host_name: TLS SNI host name sent to the server.
      host: address to connect to.
    """
    self.ca_cert_path = ca_cert_path
    self.verify_cb = verify_cb
    self.port = port
    self.host_name = host_name
    self.host = host
    self.connection = None

  def run_request(self):
    """Connects to the server, sends a blank request, then closes the socket.

    Raises:
      Whatever the SSL layer raises on handshake/verification failure; the
      tests in this file expect certutils.Error in the failing cases.
    """
    context = certutils.get_ssl_context()
    context.set_verify(certutils.VERIFY_PEER, self.verify_cb)  # Demand a cert
    context.use_certificate_file(self.ca_cert_path)
    context.load_verify_locations(self.ca_cert_path)

    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    self.connection = certutils.get_ssl_connection(context, s)
    self.connection.connect((self.host, self.port))
    self.connection.set_tlsext_host_name(self.host_name)

    try:
      self.connection.send('\r\n\r\n')
    finally:
      # close() must run even if shutdown() raises (e.g. after a failed
      # handshake); otherwise the socket leaks.
      try:
        self.connection.shutdown()
      finally:
        self.connection.close()


class Handler(BaseHTTPServer.BaseHTTPRequestHandler):
  """Request handler that only consumes the request line."""
  protocol_version = 'HTTP/1.1'  # override BaseHTTPServer setting

  def handle_one_request(self):
    """Handle a single HTTP request.

    Only the request line is read; no response is written.
    """
    self.raw_requestline = self.rfile.readline(65537)


class WrappedErrorHandler(Handler):
  """Wraps handler to verify expected sslproxy errors are being raised."""

  def setup(self):
    """Sets up the handler, then tries to wrap it with a dummy certificate."""
    Handler.setup(self)
    try:
      sslproxy._SetUpUsingDummyCert(self)
    except certutils.Error:
      # Record on the server that the expected error was raised.
      self.server.error_function = certutils.Error

  def finish(self):
    """Finishes the handler and always closes the connection."""
    Handler.finish(self)
    try:
      self.connection.shutdown()
    finally:
      # Close even if shutdown() raises so the socket is not leaked.
      self.connection.close()


class DummyArchive(object):
  """Empty placeholder archive."""

  def __init__(self):
    pass


class DummyFetch(object):
  """Placeholder assigned to Server.http_archive_fetch."""

  def __init__(self):
    self.http_archive = DummyArchive()


class Server(BaseHTTPServer.HTTPServer):
  """SSL server."""

  def __init__(self, ca_cert_path, use_error_handler=False, port=0,
               host='localhost'):
    """Binds the server; serving starts when used as a context manager.

    Args:
      ca_cert_path: CA certificate file; its contents feed get_certificate().
      use_error_handler: if True, use WrappedErrorHandler instead of the
          sslproxy-wrapped Handler.
      port: port to bind (0 picks a free port).
      host: address to bind.

    Raises:
      RuntimeError: if the underlying HTTPServer cannot be started.
    """
    self.ca_cert_path = ca_cert_path
    with open(ca_cert_path, 'r') as ca_file:
      self.ca_cert_str = ca_file.read()
    self.http_archive_fetch = DummyFetch()
    if use_error_handler:
      self.HANDLER = WrappedErrorHandler
    else:
      self.HANDLER = sslproxy.wrap_handler(Handler)
    try:
      BaseHTTPServer.HTTPServer.__init__(self, (host, port), self.HANDLER)
    except Exception as e:
      raise RuntimeError('Could not start HTTPSServer on port %d: %s'
                         % (port, e))

  def __enter__(self):
    """Starts serve_forever() on a daemon thread and returns the server."""
    thread = threading.Thread(target=self.serve_forever)
    thread.daemon = True
    thread.start()
    return self

  def cleanup(self):
    """Stops the serve loop and closes the listening socket."""
    try:
      self.shutdown()
    except KeyboardInterrupt:
      pass
    # shutdown() only stops serve_forever(); release the bound socket too.
    self.server_close()

  def __exit__(self, type_, value_, traceback_):
    self.cleanup()

  def get_certificate(self, host):
    """Returns a certificate for |host| generated from this server's CA."""
    return certutils.generate_cert(self.ca_cert_str, '', host)


class TestClient(unittest.TestCase):
  _temp_dir = None

  def setUp(self):
    self._temp_dir = tempfile.mkdtemp(prefix='sslproxy_', dir='/tmp')
    # Join with a separator so every file lives inside _temp_dir and is
    # removed by tearDown().
    self.ca_cert_path = os.path.join(self._temp_dir, 'testCA.pem')
    self.cert_path = os.path.join(self._temp_dir, 'testCA-cert.cer')
    self.wrong_ca_cert_path = os.path.join(self._temp_dir, 'wrong.pem')
    self.wrong_cert_path = os.path.join(self._temp_dir, 'wrong-cert.cer')

    # Write both pem and cer files for certificates: one "good" CA and one
    # unrelated "wrong" CA.
    certutils.write_dummy_ca_cert(*certutils.generate_dummy_ca_cert(),
                                  cert_path=self.ca_cert_path)
    certutils.write_dummy_ca_cert(*certutils.generate_dummy_ca_cert(),
                                  cert_path=self.wrong_ca_cert_path)

  def tearDown(self):
    if self._temp_dir:
      shutil.rmtree(self._temp_dir)

  def verify_cb(self, conn, cert, errnum, depth, ok):
    """A callback that verifies the certificate authentication worked.

    Args:
      conn: Connection object
      cert: x509 object
      errnum: possible error number
      depth: error depth
      ok: 1 if the authentication worked, 0 if it didn't.
    Returns:
      1 or 0 depending on if the verification worked
    """
    self.assertFalse(cert.has_expired())
    self.assertGreater(time.strftime('%Y%m%d%H%M%SZ', time.gmtime()),
                       cert.get_notBefore())
    return ok

  def test_no_host(self):
    with Server(self.ca_cert_path) as server:
      c = Client(self.cert_path, self.verify_cb, server.server_port, '')
      self.assertRaises(certutils.Error, c.run_request)

  def test_client_connection(self):
    with Server(self.ca_cert_path) as server:
      c = Client(self.cert_path, self.verify_cb, server.server_port, 'foo.com')
      c.run_request()

      c = Client(self.cert_path, self.verify_cb, server.server_port,
                 'random.host')
      c.run_request()

  def test_wrong_cert(self):
    with Server(self.ca_cert_path, True) as server:
      c = Client(self.wrong_cert_path, self.verify_cb, server.server_port,
                 'foo.com')
      self.assertRaises(certutils.Error, c.run_request)


if __name__ == '__main__':
  signal.signal(signal.SIGINT, signal.SIG_DFL)  # Exit on Ctrl-C
  unittest.main()