# Copyright 2017 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests exposure of SSL auth context"""

import pickle
import unittest

import grpc
from grpc import _channel
from grpc.experimental import session_cache
import six

from tests.unit import test_common
from tests.unit import resources

_REQUEST = b'\x00\x00\x00'
_RESPONSE = b'\x00\x00\x00'

_UNARY_UNARY = '/test/UnaryUnary'

_SERVER_HOST_OVERRIDE = 'foo.test.google.fr'
_CLIENT_IDS = (
    b'*.test.google.fr',
    b'waterzooi.test.google.be',
    b'*.test.youtube.com',
    b'192.168.1.3',
)
_ID = 'id'
_ID_KEY = 'id_key'
_AUTH_CTX = 'auth_ctx'

_PRIVATE_KEY = resources.private_key()
_CERTIFICATE_CHAIN = resources.certificate_chain()
_TEST_ROOT_CERTIFICATES = resources.test_root_certificates()
_SERVER_CERTS = ((_PRIVATE_KEY, _CERTIFICATE_CHAIN),)
_PROPERTY_OPTIONS = ((
    'grpc.ssl_target_name_override',
    _SERVER_HOST_OVERRIDE,
),)


def handle_unary_unary(request, servicer_context):
    """Reply with a pickled snapshot of what the servicer context exposes.

    Args:
      request: The request payload; ignored.
      servicer_context: The context of the RPC being served.

    Returns:
      A pickled dict holding the context's peer identities (under _ID), its
      peer identity key (under _ID_KEY) and its auth context (under
      _AUTH_CTX).
    """
    return pickle.dumps({
        _ID: servicer_context.peer_identities(),
        _ID_KEY: servicer_context.peer_identity_key(),
        _AUTH_CTX: servicer_context.auth_context()
    })


class AuthContextTest(unittest.TestCase):
    """Checks the auth data a servicer observes for various channel setups."""

    def _start_server(self, server_credentials=None):
        """Start a test server serving handle_unary_unary and return its port.

        The server is stopped through a registered cleanup, so it is shut down
        even when the RPC or a later assertion fails.

        Args:
          server_credentials: Credentials for a secure port, or None to open
            an insecure port instead.

        Returns:
          The port number the server was bound to.
        """
        handler = grpc.method_handlers_generic_handler('test', {
            'UnaryUnary':
            grpc.unary_unary_rpc_method_handler(handle_unary_unary)
        })
        server = test_common.test_server()
        server.add_generic_rpc_handlers((handler,))
        if server_credentials is None:
            port = server.add_insecure_port('[::]:0')
        else:
            port = server.add_secure_port('[::]:0', server_credentials)
        server.start()
        self.addCleanup(server.stop, None)
        return port

    def _call_and_decode(self, channel):
        """Invoke the unary RPC on channel and return the decoded auth data.

        The channel is always closed afterwards, whether or not the RPC
        succeeded.

        Args:
          channel: The channel to issue the RPC on; closed by this method.

        Returns:
          The dict that handle_unary_unary pickled on the server side.
        """
        try:
            response = channel.unary_unary(_UNARY_UNARY)(_REQUEST)
        finally:
            channel.close()
        # NOTE: unpickling is acceptable only because the payload is produced
        # by the server this test started itself; never do this with data
        # from an untrusted peer.
        return pickle.loads(response)

    def testInsecure(self):
        port = self._start_server()

        channel = grpc.insecure_channel('localhost:%d' % port)
        auth_data = self._call_and_decode(channel)

        self.assertIsNone(auth_data[_ID])
        self.assertIsNone(auth_data[_ID_KEY])
        self.assertDictEqual({}, auth_data[_AUTH_CTX])

    def testSecureNoCert(self):
        port = self._start_server(grpc.ssl_server_credentials(_SERVER_CERTS))

        channel_creds = grpc.ssl_channel_credentials(
            root_certificates=_TEST_ROOT_CERTIFICATES)
        channel = grpc.secure_channel(
            'localhost:{}'.format(port),
            channel_creds,
            options=_PROPERTY_OPTIONS)
        auth_data = self._call_and_decode(channel)

        self.assertIsNone(auth_data[_ID])
        self.assertIsNone(auth_data[_ID_KEY])
        self.assertDictEqual({
            'transport_security_type': [b'ssl'],
            'ssl_session_reused': [b'false'],
        }, auth_data[_AUTH_CTX])

    def testSecureClientCert(self):
        server_cred = grpc.ssl_server_credentials(
            _SERVER_CERTS,
            root_certificates=_TEST_ROOT_CERTIFICATES,
            require_client_auth=True)
        port = self._start_server(server_cred)

        channel_creds = grpc.ssl_channel_credentials(
            root_certificates=_TEST_ROOT_CERTIFICATES,
            private_key=_PRIVATE_KEY,
            certificate_chain=_CERTIFICATE_CHAIN)
        channel = grpc.secure_channel(
            'localhost:{}'.format(port),
            channel_creds,
            options=_PROPERTY_OPTIONS)
        auth_data = self._call_and_decode(channel)

        auth_ctx = auth_data[_AUTH_CTX]
        six.assertCountEqual(self, _CLIENT_IDS, auth_data[_ID])
        self.assertEqual('x509_subject_alternative_name', auth_data[_ID_KEY])
        self.assertSequenceEqual([b'ssl'], auth_ctx['transport_security_type'])
        self.assertSequenceEqual([b'*.test.google.com'],
                                 auth_ctx['x509_common_name'])

    def _do_one_shot_client_rpc(self, channel_creds, channel_options, port,
                                expect_ssl_session_reused):
        """Run one RPC over a fresh secure channel and check session reuse.

        Args:
          channel_creds: Channel credentials for the new channel.
          channel_options: Channel options for the new channel.
          port: Port of the server on localhost.
          expect_ssl_session_reused: Expected value of the server-side
            'ssl_session_reused' auth context property.
        """
        channel = grpc.secure_channel(
            'localhost:{}'.format(port), channel_creds, options=channel_options)
        auth_data = self._call_and_decode(channel)
        self.assertEqual(expect_ssl_session_reused,
                         auth_data[_AUTH_CTX]['ssl_session_reused'])

    def testSessionResumption(self):
        # Set up a secure server
        port = self._start_server(grpc.ssl_server_credentials(_SERVER_CERTS))

        # Create a cache for TLS session tickets
        cache = session_cache.ssl_session_cache_lru(1)
        channel_creds = grpc.ssl_channel_credentials(
            root_certificates=_TEST_ROOT_CERTIFICATES)
        channel_options = _PROPERTY_OPTIONS + (
            ('grpc.ssl_session_cache', cache),)

        # Initial connection has no session to resume
        self._do_one_shot_client_rpc(
            channel_creds,
            channel_options,
            port,
            expect_ssl_session_reused=[b'false'])

        # Subsequent connections resume sessions
        self._do_one_shot_client_rpc(
            channel_creds,
            channel_options,
            port,
            expect_ssl_session_reused=[b'true'])


if __name__ == '__main__':
    unittest.main(verbosity=2)