Home | History | Annotate | Download | only in unit
      1 # Copyright 2017 gRPC authors.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 """Tests exposure of SSL auth context"""
     15 
     16 import pickle
     17 import unittest
     18 
     19 import grpc
     20 from grpc import _channel
     21 from grpc.experimental import session_cache
     22 import six
     23 
     24 from tests.unit import test_common
     25 from tests.unit import resources
     26 
     27 _REQUEST = b'\x00\x00\x00'
     28 _RESPONSE = b'\x00\x00\x00'
     29 
     30 _UNARY_UNARY = '/test/UnaryUnary'
     31 
     32 _SERVER_HOST_OVERRIDE = 'foo.test.google.fr'
     33 _CLIENT_IDS = (
     34     b'*.test.google.fr',
     35     b'waterzooi.test.google.be',
     36     b'*.test.youtube.com',
     37     b'192.168.1.3',
     38 )
     39 _ID = 'id'
     40 _ID_KEY = 'id_key'
     41 _AUTH_CTX = 'auth_ctx'
     42 
     43 _PRIVATE_KEY = resources.private_key()
     44 _CERTIFICATE_CHAIN = resources.certificate_chain()
     45 _TEST_ROOT_CERTIFICATES = resources.test_root_certificates()
     46 _SERVER_CERTS = ((_PRIVATE_KEY, _CERTIFICATE_CHAIN),)
     47 _PROPERTY_OPTIONS = ((
     48     'grpc.ssl_target_name_override',
     49     _SERVER_HOST_OVERRIDE,
     50 ),)
     51 
     52 
     53 def handle_unary_unary(request, servicer_context):
     54     return pickle.dumps({
     55         _ID: servicer_context.peer_identities(),
     56         _ID_KEY: servicer_context.peer_identity_key(),
     57         _AUTH_CTX: servicer_context.auth_context()
     58     })
     59 
     60 
     61 class AuthContextTest(unittest.TestCase):
     62 
     63     def testInsecure(self):
     64         handler = grpc.method_handlers_generic_handler('test', {
     65             'UnaryUnary':
     66             grpc.unary_unary_rpc_method_handler(handle_unary_unary)
     67         })
     68         server = test_common.test_server()
     69         server.add_generic_rpc_handlers((handler,))
     70         port = server.add_insecure_port('[::]:0')
     71         server.start()
     72 
     73         channel = grpc.insecure_channel('localhost:%d' % port)
     74         response = channel.unary_unary(_UNARY_UNARY)(_REQUEST)
     75         server.stop(None)
     76 
     77         auth_data = pickle.loads(response)
     78         self.assertIsNone(auth_data[_ID])
     79         self.assertIsNone(auth_data[_ID_KEY])
     80         self.assertDictEqual({}, auth_data[_AUTH_CTX])
     81 
     82     def testSecureNoCert(self):
     83         handler = grpc.method_handlers_generic_handler('test', {
     84             'UnaryUnary':
     85             grpc.unary_unary_rpc_method_handler(handle_unary_unary)
     86         })
     87         server = test_common.test_server()
     88         server.add_generic_rpc_handlers((handler,))
     89         server_cred = grpc.ssl_server_credentials(_SERVER_CERTS)
     90         port = server.add_secure_port('[::]:0', server_cred)
     91         server.start()
     92 
     93         channel_creds = grpc.ssl_channel_credentials(
     94             root_certificates=_TEST_ROOT_CERTIFICATES)
     95         channel = grpc.secure_channel(
     96             'localhost:{}'.format(port),
     97             channel_creds,
     98             options=_PROPERTY_OPTIONS)
     99         response = channel.unary_unary(_UNARY_UNARY)(_REQUEST)
    100         server.stop(None)
    101 
    102         auth_data = pickle.loads(response)
    103         self.assertIsNone(auth_data[_ID])
    104         self.assertIsNone(auth_data[_ID_KEY])
    105         self.assertDictEqual({
    106             'transport_security_type': [b'ssl'],
    107             'ssl_session_reused': [b'false'],
    108         }, auth_data[_AUTH_CTX])
    109 
    110     def testSecureClientCert(self):
    111         handler = grpc.method_handlers_generic_handler('test', {
    112             'UnaryUnary':
    113             grpc.unary_unary_rpc_method_handler(handle_unary_unary)
    114         })
    115         server = test_common.test_server()
    116         server.add_generic_rpc_handlers((handler,))
    117         server_cred = grpc.ssl_server_credentials(
    118             _SERVER_CERTS,
    119             root_certificates=_TEST_ROOT_CERTIFICATES,
    120             require_client_auth=True)
    121         port = server.add_secure_port('[::]:0', server_cred)
    122         server.start()
    123 
    124         channel_creds = grpc.ssl_channel_credentials(
    125             root_certificates=_TEST_ROOT_CERTIFICATES,
    126             private_key=_PRIVATE_KEY,
    127             certificate_chain=_CERTIFICATE_CHAIN)
    128         channel = grpc.secure_channel(
    129             'localhost:{}'.format(port),
    130             channel_creds,
    131             options=_PROPERTY_OPTIONS)
    132 
    133         response = channel.unary_unary(_UNARY_UNARY)(_REQUEST)
    134         server.stop(None)
    135 
    136         auth_data = pickle.loads(response)
    137         auth_ctx = auth_data[_AUTH_CTX]
    138         six.assertCountEqual(self, _CLIENT_IDS, auth_data[_ID])
    139         self.assertEqual('x509_subject_alternative_name', auth_data[_ID_KEY])
    140         self.assertSequenceEqual([b'ssl'], auth_ctx['transport_security_type'])
    141         self.assertSequenceEqual([b'*.test.google.com'],
    142                                  auth_ctx['x509_common_name'])
    143 
    144     def _do_one_shot_client_rpc(self, channel_creds, channel_options, port,
    145                                 expect_ssl_session_reused):
    146         channel = grpc.secure_channel(
    147             'localhost:{}'.format(port), channel_creds, options=channel_options)
    148         response = channel.unary_unary(_UNARY_UNARY)(_REQUEST)
    149         auth_data = pickle.loads(response)
    150         self.assertEqual(expect_ssl_session_reused,
    151                          auth_data[_AUTH_CTX]['ssl_session_reused'])
    152         channel.close()
    153 
    154     def testSessionResumption(self):
    155         # Set up a secure server
    156         handler = grpc.method_handlers_generic_handler('test', {
    157             'UnaryUnary':
    158             grpc.unary_unary_rpc_method_handler(handle_unary_unary)
    159         })
    160         server = test_common.test_server()
    161         server.add_generic_rpc_handlers((handler,))
    162         server_cred = grpc.ssl_server_credentials(_SERVER_CERTS)
    163         port = server.add_secure_port('[::]:0', server_cred)
    164         server.start()
    165 
    166         # Create a cache for TLS session tickets
    167         cache = session_cache.ssl_session_cache_lru(1)
    168         channel_creds = grpc.ssl_channel_credentials(
    169             root_certificates=_TEST_ROOT_CERTIFICATES)
    170         channel_options = _PROPERTY_OPTIONS + (
    171             ('grpc.ssl_session_cache', cache),)
    172 
    173         # Initial connection has no session to resume
    174         self._do_one_shot_client_rpc(
    175             channel_creds,
    176             channel_options,
    177             port,
    178             expect_ssl_session_reused=[b'false'])
    179 
    180         # Subsequent connections resume sessions
    181         self._do_one_shot_client_rpc(
    182             channel_creds,
    183             channel_options,
    184             port,
    185             expect_ssl_session_reused=[b'true'])
    186         server.stop(None)
    187 
    188 
    189 if __name__ == '__main__':
    190     unittest.main(verbosity=2)
    191