Home | History | Annotate | Download | only in _cython
      1 # Copyright 2016 gRPC authors.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 """Test making many calls and immediately cancelling most of them."""
     15 
     16 import threading
     17 import unittest
     18 
     19 from grpc._cython import cygrpc
     20 from grpc.framework.foundation import logging_pool
     21 from tests.unit.framework.common import test_constants
     22 from tests.unit._cython import test_utilities
     23 
     24 _EMPTY_FLAGS = 0
     25 _EMPTY_METADATA = ()
     26 
     27 _SERVER_SHUTDOWN_TAG = 'server_shutdown'
     28 _REQUEST_CALL_TAG = 'request_call'
     29 _RECEIVE_CLOSE_ON_SERVER_TAG = 'receive_close_on_server'
     30 _RECEIVE_MESSAGE_TAG = 'receive_message'
     31 _SERVER_COMPLETE_CALL_TAG = 'server_complete_call'
     32 
     33 _SUCCESS_CALL_FRACTION = 1.0 / 8.0
     34 _SUCCESSFUL_CALLS = int(test_constants.RPC_CONCURRENCY * _SUCCESS_CALL_FRACTION)
     35 _UNSUCCESSFUL_CALLS = test_constants.RPC_CONCURRENCY - _SUCCESSFUL_CALLS
     36 
     37 
     38 class _State(object):
     39 
     40     def __init__(self):
     41         self.condition = threading.Condition()
     42         self.handlers_released = False
     43         self.parked_handlers = 0
     44         self.handled_rpcs = 0
     45 
     46 
     47 def _is_cancellation_event(event):
     48     return (event.tag is _RECEIVE_CLOSE_ON_SERVER_TAG and
     49             event.batch_operations[0].cancelled())
     50 
     51 
     52 class _Handler(object):
     53 
     54     def __init__(self, state, completion_queue, rpc_event):
     55         self._state = state
     56         self._lock = threading.Lock()
     57         self._completion_queue = completion_queue
     58         self._call = rpc_event.call
     59 
     60     def __call__(self):
     61         with self._state.condition:
     62             self._state.parked_handlers += 1
     63             if self._state.parked_handlers == test_constants.THREAD_CONCURRENCY:
     64                 self._state.condition.notify_all()
     65             while not self._state.handlers_released:
     66                 self._state.condition.wait()
     67 
     68         with self._lock:
     69             self._call.start_server_batch(
     70                 (cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),),
     71                 _RECEIVE_CLOSE_ON_SERVER_TAG)
     72             self._call.start_server_batch(
     73                 (cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),),
     74                 _RECEIVE_MESSAGE_TAG)
     75         first_event = self._completion_queue.poll()
     76         if _is_cancellation_event(first_event):
     77             self._completion_queue.poll()
     78         else:
     79             with self._lock:
     80                 operations = (
     81                     cygrpc.SendInitialMetadataOperation(_EMPTY_METADATA,
     82                                                         _EMPTY_FLAGS),
     83                     cygrpc.SendMessageOperation(b'\x79\x57', _EMPTY_FLAGS),
     84                     cygrpc.SendStatusFromServerOperation(
     85                         _EMPTY_METADATA, cygrpc.StatusCode.ok, b'test details!',
     86                         _EMPTY_FLAGS),
     87                 )
     88                 self._call.start_server_batch(operations,
     89                                               _SERVER_COMPLETE_CALL_TAG)
     90             self._completion_queue.poll()
     91             self._completion_queue.poll()
     92 
     93 
     94 def _serve(state, server, server_completion_queue, thread_pool):
     95     for _ in range(test_constants.RPC_CONCURRENCY):
     96         call_completion_queue = cygrpc.CompletionQueue()
     97         server.request_call(call_completion_queue, server_completion_queue,
     98                             _REQUEST_CALL_TAG)
     99         rpc_event = server_completion_queue.poll()
    100         thread_pool.submit(_Handler(state, call_completion_queue, rpc_event))
    101         with state.condition:
    102             state.handled_rpcs += 1
    103             if test_constants.RPC_CONCURRENCY <= state.handled_rpcs:
    104                 state.condition.notify_all()
    105     server_completion_queue.poll()
    106 
    107 
    108 class _QueueDriver(object):
    109 
    110     def __init__(self, condition, completion_queue, due):
    111         self._condition = condition
    112         self._completion_queue = completion_queue
    113         self._due = due
    114         self._events = []
    115         self._returned = False
    116 
    117     def start(self):
    118 
    119         def in_thread():
    120             while True:
    121                 event = self._completion_queue.poll()
    122                 with self._condition:
    123                     self._events.append(event)
    124                     self._due.remove(event.tag)
    125                     self._condition.notify_all()
    126                     if not self._due:
    127                         self._returned = True
    128                         return
    129 
    130         thread = threading.Thread(target=in_thread)
    131         thread.start()
    132 
    133     def events(self, at_least):
    134         with self._condition:
    135             while len(self._events) < at_least:
    136                 self._condition.wait()
    137             return tuple(self._events)
    138 
    139 
    140 class CancelManyCallsTest(unittest.TestCase):
    141 
    142     def testCancelManyCalls(self):
    143         server_thread_pool = logging_pool.pool(
    144             test_constants.THREAD_CONCURRENCY)
    145 
    146         server_completion_queue = cygrpc.CompletionQueue()
    147         server = cygrpc.Server([
    148             (
    149                 b'grpc.so_reuseport',
    150                 0,
    151             ),
    152         ])
    153         server.register_completion_queue(server_completion_queue)
    154         port = server.add_http2_port(b'[::]:0')
    155         server.start()
    156         channel = cygrpc.Channel('localhost:{}'.format(port).encode(), None,
    157                                  None)
    158 
    159         state = _State()
    160 
    161         server_thread_args = (
    162             state,
    163             server,
    164             server_completion_queue,
    165             server_thread_pool,
    166         )
    167         server_thread = threading.Thread(target=_serve, args=server_thread_args)
    168         server_thread.start()
    169 
    170         client_condition = threading.Condition()
    171         client_due = set()
    172 
    173         with client_condition:
    174             client_calls = []
    175             for index in range(test_constants.RPC_CONCURRENCY):
    176                 tag = 'client_complete_call_{0:04d}_tag'.format(index)
    177                 client_call = channel.integrated_call(
    178                     _EMPTY_FLAGS, b'/twinkies', None, None, _EMPTY_METADATA,
    179                     None, ((
    180                         (
    181                             cygrpc.SendInitialMetadataOperation(
    182                                 _EMPTY_METADATA, _EMPTY_FLAGS),
    183                             cygrpc.SendMessageOperation(b'\x45\x56',
    184                                                         _EMPTY_FLAGS),
    185                             cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
    186                             cygrpc.ReceiveInitialMetadataOperation(
    187                                 _EMPTY_FLAGS),
    188                             cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
    189                             cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
    190                         ),
    191                         tag,
    192                     ),))
    193                 client_due.add(tag)
    194                 client_calls.append(client_call)
    195 
    196         client_events_future = test_utilities.SimpleFuture(
    197             lambda: tuple(channel.next_call_event() for _ in range(_SUCCESSFUL_CALLS)))
    198 
    199         with state.condition:
    200             while True:
    201                 if state.parked_handlers < test_constants.THREAD_CONCURRENCY:
    202                     state.condition.wait()
    203                 elif state.handled_rpcs < test_constants.RPC_CONCURRENCY:
    204                     state.condition.wait()
    205                 else:
    206                     state.handlers_released = True
    207                     state.condition.notify_all()
    208                     break
    209 
    210         client_events_future.result()
    211         with client_condition:
    212             for client_call in client_calls:
    213                 client_call.cancel(cygrpc.StatusCode.cancelled, 'Cancelled!')
    214         for _ in range(_UNSUCCESSFUL_CALLS):
    215             channel.next_call_event()
    216 
    217         channel.close(cygrpc.StatusCode.unknown, 'Cancelled on channel close!')
    218         with state.condition:
    219             server.shutdown(server_completion_queue, _SERVER_SHUTDOWN_TAG)
    220 
    221 
    222 if __name__ == '__main__':
    223     unittest.main(verbosity=2)
    224