Home | History | Annotate | Download | only in _cython
      1 # Copyright 2016 gRPC authors.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 """Test a corner-case at the level of the Cython API."""
     15 
     16 import threading
     17 import unittest
     18 
     19 from grpc._cython import cygrpc
     20 from tests.unit._cython import test_utilities
     21 
     22 _EMPTY_FLAGS = 0
     23 _EMPTY_METADATA = ()
     24 
     25 
     26 class _ServerDriver(object):
     27 
     28     def __init__(self, completion_queue, shutdown_tag):
     29         self._condition = threading.Condition()
     30         self._completion_queue = completion_queue
     31         self._shutdown_tag = shutdown_tag
     32         self._events = []
     33         self._saw_shutdown_tag = False
     34 
     35     def start(self):
     36 
     37         def in_thread():
     38             while True:
     39                 event = self._completion_queue.poll()
     40                 with self._condition:
     41                     self._events.append(event)
     42                     self._condition.notify()
     43                     if event.tag is self._shutdown_tag:
     44                         self._saw_shutdown_tag = True
     45                         break
     46 
     47         thread = threading.Thread(target=in_thread)
     48         thread.start()
     49 
     50     def done(self):
     51         with self._condition:
     52             return self._saw_shutdown_tag
     53 
     54     def first_event(self):
     55         with self._condition:
     56             while not self._events:
     57                 self._condition.wait()
     58             return self._events[0]
     59 
     60     def events(self):
     61         with self._condition:
     62             while not self._saw_shutdown_tag:
     63                 self._condition.wait()
     64             return tuple(self._events)
     65 
     66 
     67 class _QueueDriver(object):
     68 
     69     def __init__(self, condition, completion_queue, due):
     70         self._condition = condition
     71         self._completion_queue = completion_queue
     72         self._due = due
     73         self._events = []
     74         self._returned = False
     75 
     76     def start(self):
     77 
     78         def in_thread():
     79             while True:
     80                 event = self._completion_queue.poll()
     81                 with self._condition:
     82                     self._events.append(event)
     83                     self._due.remove(event.tag)
     84                     self._condition.notify_all()
     85                     if not self._due:
     86                         self._returned = True
     87                         return
     88 
     89         thread = threading.Thread(target=in_thread)
     90         thread.start()
     91 
     92     def done(self):
     93         with self._condition:
     94             return self._returned
     95 
     96     def event_with_tag(self, tag):
     97         with self._condition:
     98             while True:
     99                 for event in self._events:
    100                     if event.tag is tag:
    101                         return event
    102                 self._condition.wait()
    103 
    104     def events(self):
    105         with self._condition:
    106             while not self._returned:
    107                 self._condition.wait()
    108             return tuple(self._events)
    109 
    110 
    111 class ReadSomeButNotAllResponsesTest(unittest.TestCase):
    112 
    113     def testReadSomeButNotAllResponses(self):
    114         server_completion_queue = cygrpc.CompletionQueue()
    115         server = cygrpc.Server([(
    116             b'grpc.so_reuseport',
    117             0,
    118         )])
    119         server.register_completion_queue(server_completion_queue)
    120         port = server.add_http2_port(b'[::]:0')
    121         server.start()
    122         channel = cygrpc.Channel('localhost:{}'.format(port).encode(), set(),
    123                                  None)
    124 
    125         server_shutdown_tag = 'server_shutdown_tag'
    126         server_driver = _ServerDriver(server_completion_queue,
    127                                       server_shutdown_tag)
    128         server_driver.start()
    129 
    130         client_condition = threading.Condition()
    131         client_due = set()
    132 
    133         server_call_condition = threading.Condition()
    134         server_send_initial_metadata_tag = 'server_send_initial_metadata_tag'
    135         server_send_first_message_tag = 'server_send_first_message_tag'
    136         server_send_second_message_tag = 'server_send_second_message_tag'
    137         server_complete_rpc_tag = 'server_complete_rpc_tag'
    138         server_call_due = set((
    139             server_send_initial_metadata_tag,
    140             server_send_first_message_tag,
    141             server_send_second_message_tag,
    142             server_complete_rpc_tag,
    143         ))
    144         server_call_completion_queue = cygrpc.CompletionQueue()
    145         server_call_driver = _QueueDriver(server_call_condition,
    146                                           server_call_completion_queue,
    147                                           server_call_due)
    148         server_call_driver.start()
    149 
    150         server_rpc_tag = 'server_rpc_tag'
    151         request_call_result = server.request_call(server_call_completion_queue,
    152                                                   server_completion_queue,
    153                                                   server_rpc_tag)
    154 
    155         client_receive_initial_metadata_tag = 'client_receive_initial_metadata_tag'
    156         client_complete_rpc_tag = 'client_complete_rpc_tag'
    157         client_call = channel.segregated_call(
    158             _EMPTY_FLAGS, b'/twinkies', None, None, _EMPTY_METADATA, None, (
    159                 (
    160                     [
    161                         cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
    162                     ],
    163                     client_receive_initial_metadata_tag,
    164                 ),
    165                 (
    166                     [
    167                         cygrpc.SendInitialMetadataOperation(
    168                             _EMPTY_METADATA, _EMPTY_FLAGS),
    169                         cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
    170                         cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
    171                     ],
    172                     client_complete_rpc_tag,
    173                 ),
    174             ))
    175         client_receive_initial_metadata_event_future = test_utilities.SimpleFuture(
    176             client_call.next_event)
    177 
    178         server_rpc_event = server_driver.first_event()
    179 
    180         with server_call_condition:
    181             server_send_initial_metadata_start_batch_result = (
    182                 server_rpc_event.call.start_server_batch([
    183                     cygrpc.SendInitialMetadataOperation(_EMPTY_METADATA,
    184                                                         _EMPTY_FLAGS),
    185                 ], server_send_initial_metadata_tag))
    186             server_send_first_message_start_batch_result = (
    187                 server_rpc_event.call.start_server_batch([
    188                     cygrpc.SendMessageOperation(b'\x07', _EMPTY_FLAGS),
    189                 ], server_send_first_message_tag))
    190         server_send_initial_metadata_event = server_call_driver.event_with_tag(
    191             server_send_initial_metadata_tag)
    192         server_send_first_message_event = server_call_driver.event_with_tag(
    193             server_send_first_message_tag)
    194         with server_call_condition:
    195             server_send_second_message_start_batch_result = (
    196                 server_rpc_event.call.start_server_batch([
    197                     cygrpc.SendMessageOperation(b'\x07', _EMPTY_FLAGS),
    198                 ], server_send_second_message_tag))
    199             server_complete_rpc_start_batch_result = (
    200                 server_rpc_event.call.start_server_batch([
    201                     cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
    202                     cygrpc.SendStatusFromServerOperation(
    203                         (), cygrpc.StatusCode.ok, b'test details',
    204                         _EMPTY_FLAGS),
    205                 ], server_complete_rpc_tag))
    206         server_send_second_message_event = server_call_driver.event_with_tag(
    207             server_send_second_message_tag)
    208         server_complete_rpc_event = server_call_driver.event_with_tag(
    209             server_complete_rpc_tag)
    210         server_call_driver.events()
    211 
    212         client_recieve_initial_metadata_event = client_receive_initial_metadata_event_future.result(
    213         )
    214 
    215         client_receive_first_message_tag = 'client_receive_first_message_tag'
    216         client_call.operate([
    217             cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
    218         ], client_receive_first_message_tag)
    219         client_receive_first_message_event = client_call.next_event()
    220 
    221         client_call_cancel_result = client_call.cancel(
    222             cygrpc.StatusCode.cancelled, 'Cancelled during test!')
    223         client_complete_rpc_event = client_call.next_event()
    224 
    225         channel.close(cygrpc.StatusCode.unknown, 'Channel closed!')
    226         server.shutdown(server_completion_queue, server_shutdown_tag)
    227         server.cancel_all_calls()
    228         server_driver.events()
    229 
    230         self.assertEqual(cygrpc.CallError.ok, request_call_result)
    231         self.assertEqual(cygrpc.CallError.ok,
    232                          server_send_initial_metadata_start_batch_result)
    233         self.assertIs(server_rpc_tag, server_rpc_event.tag)
    234         self.assertEqual(cygrpc.CompletionType.operation_complete,
    235                          server_rpc_event.completion_type)
    236         self.assertIsInstance(server_rpc_event.call, cygrpc.Call)
    237 
    238 
    239 if __name__ == '__main__':
    240     unittest.main(verbosity=2)
    241