Home | History | Annotate | Download | only in beta
      1 # Copyright 2015 gRPC authors.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 """Tests Face interface compliance of the gRPC Python Beta API."""
     15 
     16 import threading
     17 import unittest
     18 
     19 from grpc.beta import implementations
     20 from grpc.beta import interfaces
     21 from grpc.framework.common import cardinality
     22 from grpc.framework.interfaces.face import utilities
     23 from tests.unit import resources
     24 from tests.unit.beta import test_utilities
     25 from tests.unit.framework.common import test_constants
     26 
     27 _SERVER_HOST_OVERRIDE = 'foo.test.google.fr'
     28 
     29 _PER_RPC_CREDENTIALS_METADATA_KEY = b'my-call-credentials-metadata-key'
     30 _PER_RPC_CREDENTIALS_METADATA_VALUE = b'my-call-credentials-metadata-value'
     31 
     32 _GROUP = 'group'
     33 _UNARY_UNARY = 'unary-unary'
     34 _UNARY_STREAM = 'unary-stream'
     35 _STREAM_UNARY = 'stream-unary'
     36 _STREAM_STREAM = 'stream-stream'
     37 
     38 _REQUEST = b'abc'
     39 _RESPONSE = b'123'
     40 
     41 
     42 class _Servicer(object):
     43 
     44     def __init__(self):
     45         self._condition = threading.Condition()
     46         self._peer = None
     47         self._serviced = False
     48 
     49     def unary_unary(self, request, context):
     50         with self._condition:
     51             self._request = request
     52             self._peer = context.protocol_context().peer()
     53             self._invocation_metadata = context.invocation_metadata()
     54             context.protocol_context().disable_next_response_compression()
     55             self._serviced = True
     56             self._condition.notify_all()
     57             return _RESPONSE
     58 
     59     def unary_stream(self, request, context):
     60         with self._condition:
     61             self._request = request
     62             self._peer = context.protocol_context().peer()
     63             self._invocation_metadata = context.invocation_metadata()
     64             context.protocol_context().disable_next_response_compression()
     65             self._serviced = True
     66             self._condition.notify_all()
     67             return
     68             yield  # pylint: disable=unreachable
     69 
     70     def stream_unary(self, request_iterator, context):
     71         for request in request_iterator:
     72             self._request = request
     73         with self._condition:
     74             self._peer = context.protocol_context().peer()
     75             self._invocation_metadata = context.invocation_metadata()
     76             context.protocol_context().disable_next_response_compression()
     77             self._serviced = True
     78             self._condition.notify_all()
     79             return _RESPONSE
     80 
     81     def stream_stream(self, request_iterator, context):
     82         for request in request_iterator:
     83             with self._condition:
     84                 self._peer = context.protocol_context().peer()
     85                 context.protocol_context().disable_next_response_compression()
     86                 yield _RESPONSE
     87         with self._condition:
     88             self._invocation_metadata = context.invocation_metadata()
     89             self._serviced = True
     90             self._condition.notify_all()
     91 
     92     def peer(self):
     93         with self._condition:
     94             return self._peer
     95 
     96     def block_until_serviced(self):
     97         with self._condition:
     98             while not self._serviced:
     99                 self._condition.wait()
    100 
    101 
    102 class _BlockingIterator(object):
    103 
    104     def __init__(self, upstream):
    105         self._condition = threading.Condition()
    106         self._upstream = upstream
    107         self._allowed = []
    108 
    109     def __iter__(self):
    110         return self
    111 
    112     def __next__(self):
    113         return self.next()
    114 
    115     def next(self):
    116         with self._condition:
    117             while True:
    118                 if self._allowed is None:
    119                     raise StopIteration()
    120                 elif self._allowed:
    121                     return self._allowed.pop(0)
    122                 else:
    123                     self._condition.wait()
    124 
    125     def allow(self):
    126         with self._condition:
    127             try:
    128                 self._allowed.append(next(self._upstream))
    129             except StopIteration:
    130                 self._allowed = None
    131             self._condition.notify_all()
    132 
    133 
    134 def _metadata_plugin(context, callback):
    135     callback([(_PER_RPC_CREDENTIALS_METADATA_KEY,
    136                _PER_RPC_CREDENTIALS_METADATA_VALUE)], None)
    137 
    138 
    139 class BetaFeaturesTest(unittest.TestCase):
    140 
    141     def setUp(self):
    142         self._servicer = _Servicer()
    143         method_implementations = {
    144             (_GROUP, _UNARY_UNARY):
    145             utilities.unary_unary_inline(self._servicer.unary_unary),
    146             (_GROUP, _UNARY_STREAM):
    147             utilities.unary_stream_inline(self._servicer.unary_stream),
    148             (_GROUP, _STREAM_UNARY):
    149             utilities.stream_unary_inline(self._servicer.stream_unary),
    150             (_GROUP, _STREAM_STREAM):
    151             utilities.stream_stream_inline(self._servicer.stream_stream),
    152         }
    153 
    154         cardinalities = {
    155             _UNARY_UNARY: cardinality.Cardinality.UNARY_UNARY,
    156             _UNARY_STREAM: cardinality.Cardinality.UNARY_STREAM,
    157             _STREAM_UNARY: cardinality.Cardinality.STREAM_UNARY,
    158             _STREAM_STREAM: cardinality.Cardinality.STREAM_STREAM,
    159         }
    160 
    161         server_options = implementations.server_options(
    162             thread_pool_size=test_constants.POOL_SIZE)
    163         self._server = implementations.server(
    164             method_implementations, options=server_options)
    165         server_credentials = implementations.ssl_server_credentials([
    166             (
    167                 resources.private_key(),
    168                 resources.certificate_chain(),
    169             ),
    170         ])
    171         port = self._server.add_secure_port('[::]:0', server_credentials)
    172         self._server.start()
    173         self._channel_credentials = implementations.ssl_channel_credentials(
    174             resources.test_root_certificates())
    175         self._call_credentials = implementations.metadata_call_credentials(
    176             _metadata_plugin)
    177         channel = test_utilities.not_really_secure_channel(
    178             'localhost', port, self._channel_credentials, _SERVER_HOST_OVERRIDE)
    179         stub_options = implementations.stub_options(
    180             thread_pool_size=test_constants.POOL_SIZE)
    181         self._dynamic_stub = implementations.dynamic_stub(
    182             channel, _GROUP, cardinalities, options=stub_options)
    183 
    184     def tearDown(self):
    185         self._dynamic_stub = None
    186         self._server.stop(test_constants.SHORT_TIMEOUT).wait()
    187 
    188     def test_unary_unary(self):
    189         call_options = interfaces.grpc_call_options(
    190             disable_compression=True, credentials=self._call_credentials)
    191         response = getattr(self._dynamic_stub, _UNARY_UNARY)(
    192             _REQUEST,
    193             test_constants.LONG_TIMEOUT,
    194             protocol_options=call_options)
    195         self.assertEqual(_RESPONSE, response)
    196         self.assertIsNotNone(self._servicer.peer())
    197         invocation_metadata = [
    198             (metadatum.key, metadatum.value)
    199             for metadatum in self._servicer._invocation_metadata
    200         ]
    201         self.assertIn((_PER_RPC_CREDENTIALS_METADATA_KEY,
    202                        _PER_RPC_CREDENTIALS_METADATA_VALUE),
    203                       invocation_metadata)
    204 
    205     def test_unary_stream(self):
    206         call_options = interfaces.grpc_call_options(
    207             disable_compression=True, credentials=self._call_credentials)
    208         response_iterator = getattr(self._dynamic_stub, _UNARY_STREAM)(
    209             _REQUEST,
    210             test_constants.LONG_TIMEOUT,
    211             protocol_options=call_options)
    212         self._servicer.block_until_serviced()
    213         self.assertIsNotNone(self._servicer.peer())
    214         invocation_metadata = [
    215             (metadatum.key, metadatum.value)
    216             for metadatum in self._servicer._invocation_metadata
    217         ]
    218         self.assertIn((_PER_RPC_CREDENTIALS_METADATA_KEY,
    219                        _PER_RPC_CREDENTIALS_METADATA_VALUE),
    220                       invocation_metadata)
    221 
    222     def test_stream_unary(self):
    223         call_options = interfaces.grpc_call_options(
    224             credentials=self._call_credentials)
    225         request_iterator = _BlockingIterator(iter((_REQUEST,)))
    226         response_future = getattr(self._dynamic_stub, _STREAM_UNARY).future(
    227             request_iterator,
    228             test_constants.LONG_TIMEOUT,
    229             protocol_options=call_options)
    230         response_future.protocol_context().disable_next_request_compression()
    231         request_iterator.allow()
    232         response_future.protocol_context().disable_next_request_compression()
    233         request_iterator.allow()
    234         self._servicer.block_until_serviced()
    235         self.assertIsNotNone(self._servicer.peer())
    236         self.assertEqual(_RESPONSE, response_future.result())
    237         invocation_metadata = [
    238             (metadatum.key, metadatum.value)
    239             for metadatum in self._servicer._invocation_metadata
    240         ]
    241         self.assertIn((_PER_RPC_CREDENTIALS_METADATA_KEY,
    242                        _PER_RPC_CREDENTIALS_METADATA_VALUE),
    243                       invocation_metadata)
    244 
    245     def test_stream_stream(self):
    246         call_options = interfaces.grpc_call_options(
    247             credentials=self._call_credentials)
    248         request_iterator = _BlockingIterator(iter((_REQUEST,)))
    249         response_iterator = getattr(self._dynamic_stub, _STREAM_STREAM)(
    250             request_iterator,
    251             test_constants.SHORT_TIMEOUT,
    252             protocol_options=call_options)
    253         response_iterator.protocol_context().disable_next_request_compression()
    254         request_iterator.allow()
    255         response = next(response_iterator)
    256         response_iterator.protocol_context().disable_next_request_compression()
    257         request_iterator.allow()
    258         self._servicer.block_until_serviced()
    259         self.assertIsNotNone(self._servicer.peer())
    260         self.assertEqual(_RESPONSE, response)
    261         invocation_metadata = [
    262             (metadatum.key, metadatum.value)
    263             for metadatum in self._servicer._invocation_metadata
    264         ]
    265         self.assertIn((_PER_RPC_CREDENTIALS_METADATA_KEY,
    266                        _PER_RPC_CREDENTIALS_METADATA_VALUE),
    267                       invocation_metadata)
    268 
    269 
    270 class ContextManagementAndLifecycleTest(unittest.TestCase):
    271 
    272     def setUp(self):
    273         self._servicer = _Servicer()
    274         self._method_implementations = {
    275             (_GROUP, _UNARY_UNARY):
    276             utilities.unary_unary_inline(self._servicer.unary_unary),
    277             (_GROUP, _UNARY_STREAM):
    278             utilities.unary_stream_inline(self._servicer.unary_stream),
    279             (_GROUP, _STREAM_UNARY):
    280             utilities.stream_unary_inline(self._servicer.stream_unary),
    281             (_GROUP, _STREAM_STREAM):
    282             utilities.stream_stream_inline(self._servicer.stream_stream),
    283         }
    284 
    285         self._cardinalities = {
    286             _UNARY_UNARY: cardinality.Cardinality.UNARY_UNARY,
    287             _UNARY_STREAM: cardinality.Cardinality.UNARY_STREAM,
    288             _STREAM_UNARY: cardinality.Cardinality.STREAM_UNARY,
    289             _STREAM_STREAM: cardinality.Cardinality.STREAM_STREAM,
    290         }
    291 
    292         self._server_options = implementations.server_options(
    293             thread_pool_size=test_constants.POOL_SIZE)
    294         self._server_credentials = implementations.ssl_server_credentials([
    295             (
    296                 resources.private_key(),
    297                 resources.certificate_chain(),
    298             ),
    299         ])
    300         self._channel_credentials = implementations.ssl_channel_credentials(
    301             resources.test_root_certificates())
    302         self._stub_options = implementations.stub_options(
    303             thread_pool_size=test_constants.POOL_SIZE)
    304 
    305     def test_stub_context(self):
    306         server = implementations.server(
    307             self._method_implementations, options=self._server_options)
    308         port = server.add_secure_port('[::]:0', self._server_credentials)
    309         server.start()
    310 
    311         channel = test_utilities.not_really_secure_channel(
    312             'localhost', port, self._channel_credentials, _SERVER_HOST_OVERRIDE)
    313         dynamic_stub = implementations.dynamic_stub(
    314             channel, _GROUP, self._cardinalities, options=self._stub_options)
    315         for _ in range(100):
    316             with dynamic_stub:
    317                 pass
    318         for _ in range(10):
    319             with dynamic_stub:
    320                 call_options = interfaces.grpc_call_options(
    321                     disable_compression=True)
    322                 response = getattr(dynamic_stub, _UNARY_UNARY)(
    323                     _REQUEST,
    324                     test_constants.LONG_TIMEOUT,
    325                     protocol_options=call_options)
    326                 self.assertEqual(_RESPONSE, response)
    327                 self.assertIsNotNone(self._servicer.peer())
    328 
    329         server.stop(test_constants.SHORT_TIMEOUT).wait()
    330 
    331     def test_server_lifecycle(self):
    332         for _ in range(100):
    333             server = implementations.server(
    334                 self._method_implementations, options=self._server_options)
    335             port = server.add_secure_port('[::]:0', self._server_credentials)
    336             server.start()
    337             server.stop(test_constants.SHORT_TIMEOUT).wait()
    338         for _ in range(100):
    339             server = implementations.server(
    340                 self._method_implementations, options=self._server_options)
    341             server.add_secure_port('[::]:0', self._server_credentials)
    342             server.add_insecure_port('[::]:0')
    343             with server:
    344                 server.stop(test_constants.SHORT_TIMEOUT)
    345             server.stop(test_constants.SHORT_TIMEOUT)
    346 
    347 
    348 if __name__ == '__main__':
    349     unittest.main(verbosity=2)
    350