Home | History | Annotate | Download | only in beta
      1 # Copyright 2016 gRPC authors.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 """Translates gRPC's client-side API into gRPC's client-side Beta API."""
     15 
     16 import grpc
     17 from grpc import _common
     18 from grpc.beta import _metadata
     19 from grpc.beta import interfaces
     20 from grpc.framework.common import cardinality
     21 from grpc.framework.foundation import future
     22 from grpc.framework.interfaces.face import face
     23 
     24 # pylint: disable=too-many-arguments,too-many-locals,unused-argument
     25 
     26 _STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS = {
     27     grpc.StatusCode.CANCELLED: (face.Abortion.Kind.CANCELLED,
     28                                 face.CancellationError),
     29     grpc.StatusCode.UNKNOWN: (face.Abortion.Kind.REMOTE_FAILURE,
     30                               face.RemoteError),
     31     grpc.StatusCode.DEADLINE_EXCEEDED: (face.Abortion.Kind.EXPIRED,
     32                                         face.ExpirationError),
     33     grpc.StatusCode.UNIMPLEMENTED: (face.Abortion.Kind.LOCAL_FAILURE,
     34                                     face.LocalError),
     35 }
     36 
     37 
     38 def _effective_metadata(metadata, metadata_transformer):
     39     non_none_metadata = () if metadata is None else metadata
     40     if metadata_transformer is None:
     41         return non_none_metadata
     42     else:
     43         return metadata_transformer(non_none_metadata)
     44 
     45 
     46 def _credentials(grpc_call_options):
     47     return None if grpc_call_options is None else grpc_call_options.credentials
     48 
     49 
     50 def _abortion(rpc_error_call):
     51     code = rpc_error_call.code()
     52     pair = _STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS.get(code)
     53     error_kind = face.Abortion.Kind.LOCAL_FAILURE if pair is None else pair[0]
     54     return face.Abortion(error_kind, rpc_error_call.initial_metadata(),
     55                          rpc_error_call.trailing_metadata(), code,
     56                          rpc_error_call.details())
     57 
     58 
     59 def _abortion_error(rpc_error_call):
     60     code = rpc_error_call.code()
     61     pair = _STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS.get(code)
     62     exception_class = face.AbortionError if pair is None else pair[1]
     63     return exception_class(rpc_error_call.initial_metadata(),
     64                            rpc_error_call.trailing_metadata(), code,
     65                            rpc_error_call.details())
     66 
     67 
     68 class _InvocationProtocolContext(interfaces.GRPCInvocationContext):
     69 
     70     def disable_next_request_compression(self):
     71         pass  # TODO(https://github.com/grpc/grpc/issues/4078): design, implement.
     72 
     73 
     74 class _Rendezvous(future.Future, face.Call):
     75 
     76     def __init__(self, response_future, response_iterator, call):
     77         self._future = response_future
     78         self._iterator = response_iterator
     79         self._call = call
     80 
     81     def cancel(self):
     82         return self._call.cancel()
     83 
     84     def cancelled(self):
     85         return self._future.cancelled()
     86 
     87     def running(self):
     88         return self._future.running()
     89 
     90     def done(self):
     91         return self._future.done()
     92 
     93     def result(self, timeout=None):
     94         try:
     95             return self._future.result(timeout=timeout)
     96         except grpc.RpcError as rpc_error_call:
     97             raise _abortion_error(rpc_error_call)
     98         except grpc.FutureTimeoutError:
     99             raise future.TimeoutError()
    100         except grpc.FutureCancelledError:
    101             raise future.CancelledError()
    102 
    103     def exception(self, timeout=None):
    104         try:
    105             rpc_error_call = self._future.exception(timeout=timeout)
    106             if rpc_error_call is None:
    107                 return None
    108             else:
    109                 return _abortion_error(rpc_error_call)
    110         except grpc.FutureTimeoutError:
    111             raise future.TimeoutError()
    112         except grpc.FutureCancelledError:
    113             raise future.CancelledError()
    114 
    115     def traceback(self, timeout=None):
    116         try:
    117             return self._future.traceback(timeout=timeout)
    118         except grpc.FutureTimeoutError:
    119             raise future.TimeoutError()
    120         except grpc.FutureCancelledError:
    121             raise future.CancelledError()
    122 
    123     def add_done_callback(self, fn):
    124         self._future.add_done_callback(lambda ignored_callback: fn(self))
    125 
    126     def __iter__(self):
    127         return self
    128 
    129     def _next(self):
    130         try:
    131             return next(self._iterator)
    132         except grpc.RpcError as rpc_error_call:
    133             raise _abortion_error(rpc_error_call)
    134 
    135     def __next__(self):
    136         return self._next()
    137 
    138     def next(self):
    139         return self._next()
    140 
    141     def is_active(self):
    142         return self._call.is_active()
    143 
    144     def time_remaining(self):
    145         return self._call.time_remaining()
    146 
    147     def add_abortion_callback(self, abortion_callback):
    148 
    149         def done_callback():
    150             if self.code() is not grpc.StatusCode.OK:
    151                 abortion_callback(_abortion(self._call))
    152 
    153         registered = self._call.add_callback(done_callback)
    154         return None if registered else done_callback()
    155 
    156     def protocol_context(self):
    157         return _InvocationProtocolContext()
    158 
    159     def initial_metadata(self):
    160         return _metadata.beta(self._call.initial_metadata())
    161 
    162     def terminal_metadata(self):
    163         return _metadata.beta(self._call.terminal_metadata())
    164 
    165     def code(self):
    166         return self._call.code()
    167 
    168     def details(self):
    169         return self._call.details()
    170 
    171 
    172 def _blocking_unary_unary(channel, group, method, timeout, with_call,
    173                           protocol_options, metadata, metadata_transformer,
    174                           request, request_serializer, response_deserializer):
    175     try:
    176         multi_callable = channel.unary_unary(
    177             _common.fully_qualified_method(group, method),
    178             request_serializer=request_serializer,
    179             response_deserializer=response_deserializer)
    180         effective_metadata = _effective_metadata(metadata, metadata_transformer)
    181         if with_call:
    182             response, call = multi_callable.with_call(
    183                 request,
    184                 timeout=timeout,
    185                 metadata=_metadata.unbeta(effective_metadata),
    186                 credentials=_credentials(protocol_options))
    187             return response, _Rendezvous(None, None, call)
    188         else:
    189             return multi_callable(
    190                 request,
    191                 timeout=timeout,
    192                 metadata=_metadata.unbeta(effective_metadata),
    193                 credentials=_credentials(protocol_options))
    194     except grpc.RpcError as rpc_error_call:
    195         raise _abortion_error(rpc_error_call)
    196 
    197 
    198 def _future_unary_unary(channel, group, method, timeout, protocol_options,
    199                         metadata, metadata_transformer, request,
    200                         request_serializer, response_deserializer):
    201     multi_callable = channel.unary_unary(
    202         _common.fully_qualified_method(group, method),
    203         request_serializer=request_serializer,
    204         response_deserializer=response_deserializer)
    205     effective_metadata = _effective_metadata(metadata, metadata_transformer)
    206     response_future = multi_callable.future(
    207         request,
    208         timeout=timeout,
    209         metadata=_metadata.unbeta(effective_metadata),
    210         credentials=_credentials(protocol_options))
    211     return _Rendezvous(response_future, None, response_future)
    212 
    213 
    214 def _unary_stream(channel, group, method, timeout, protocol_options, metadata,
    215                   metadata_transformer, request, request_serializer,
    216                   response_deserializer):
    217     multi_callable = channel.unary_stream(
    218         _common.fully_qualified_method(group, method),
    219         request_serializer=request_serializer,
    220         response_deserializer=response_deserializer)
    221     effective_metadata = _effective_metadata(metadata, metadata_transformer)
    222     response_iterator = multi_callable(
    223         request,
    224         timeout=timeout,
    225         metadata=_metadata.unbeta(effective_metadata),
    226         credentials=_credentials(protocol_options))
    227     return _Rendezvous(None, response_iterator, response_iterator)
    228 
    229 
    230 def _blocking_stream_unary(channel, group, method, timeout, with_call,
    231                            protocol_options, metadata, metadata_transformer,
    232                            request_iterator, request_serializer,
    233                            response_deserializer):
    234     try:
    235         multi_callable = channel.stream_unary(
    236             _common.fully_qualified_method(group, method),
    237             request_serializer=request_serializer,
    238             response_deserializer=response_deserializer)
    239         effective_metadata = _effective_metadata(metadata, metadata_transformer)
    240         if with_call:
    241             response, call = multi_callable.with_call(
    242                 request_iterator,
    243                 timeout=timeout,
    244                 metadata=_metadata.unbeta(effective_metadata),
    245                 credentials=_credentials(protocol_options))
    246             return response, _Rendezvous(None, None, call)
    247         else:
    248             return multi_callable(
    249                 request_iterator,
    250                 timeout=timeout,
    251                 metadata=_metadata.unbeta(effective_metadata),
    252                 credentials=_credentials(protocol_options))
    253     except grpc.RpcError as rpc_error_call:
    254         raise _abortion_error(rpc_error_call)
    255 
    256 
    257 def _future_stream_unary(channel, group, method, timeout, protocol_options,
    258                          metadata, metadata_transformer, request_iterator,
    259                          request_serializer, response_deserializer):
    260     multi_callable = channel.stream_unary(
    261         _common.fully_qualified_method(group, method),
    262         request_serializer=request_serializer,
    263         response_deserializer=response_deserializer)
    264     effective_metadata = _effective_metadata(metadata, metadata_transformer)
    265     response_future = multi_callable.future(
    266         request_iterator,
    267         timeout=timeout,
    268         metadata=_metadata.unbeta(effective_metadata),
    269         credentials=_credentials(protocol_options))
    270     return _Rendezvous(response_future, None, response_future)
    271 
    272 
    273 def _stream_stream(channel, group, method, timeout, protocol_options, metadata,
    274                    metadata_transformer, request_iterator, request_serializer,
    275                    response_deserializer):
    276     multi_callable = channel.stream_stream(
    277         _common.fully_qualified_method(group, method),
    278         request_serializer=request_serializer,
    279         response_deserializer=response_deserializer)
    280     effective_metadata = _effective_metadata(metadata, metadata_transformer)
    281     response_iterator = multi_callable(
    282         request_iterator,
    283         timeout=timeout,
    284         metadata=_metadata.unbeta(effective_metadata),
    285         credentials=_credentials(protocol_options))
    286     return _Rendezvous(None, response_iterator, response_iterator)
    287 
    288 
    289 class _UnaryUnaryMultiCallable(face.UnaryUnaryMultiCallable):
    290 
    291     def __init__(self, channel, group, method, metadata_transformer,
    292                  request_serializer, response_deserializer):
    293         self._channel = channel
    294         self._group = group
    295         self._method = method
    296         self._metadata_transformer = metadata_transformer
    297         self._request_serializer = request_serializer
    298         self._response_deserializer = response_deserializer
    299 
    300     def __call__(self,
    301                  request,
    302                  timeout,
    303                  metadata=None,
    304                  with_call=False,
    305                  protocol_options=None):
    306         return _blocking_unary_unary(
    307             self._channel, self._group, self._method, timeout, with_call,
    308             protocol_options, metadata, self._metadata_transformer, request,
    309             self._request_serializer, self._response_deserializer)
    310 
    311     def future(self, request, timeout, metadata=None, protocol_options=None):
    312         return _future_unary_unary(
    313             self._channel, self._group, self._method, timeout, protocol_options,
    314             metadata, self._metadata_transformer, request,
    315             self._request_serializer, self._response_deserializer)
    316 
    317     def event(self,
    318               request,
    319               receiver,
    320               abortion_callback,
    321               timeout,
    322               metadata=None,
    323               protocol_options=None):
    324         raise NotImplementedError()
    325 
    326 
    327 class _UnaryStreamMultiCallable(face.UnaryStreamMultiCallable):
    328 
    329     def __init__(self, channel, group, method, metadata_transformer,
    330                  request_serializer, response_deserializer):
    331         self._channel = channel
    332         self._group = group
    333         self._method = method
    334         self._metadata_transformer = metadata_transformer
    335         self._request_serializer = request_serializer
    336         self._response_deserializer = response_deserializer
    337 
    338     def __call__(self, request, timeout, metadata=None, protocol_options=None):
    339         return _unary_stream(
    340             self._channel, self._group, self._method, timeout, protocol_options,
    341             metadata, self._metadata_transformer, request,
    342             self._request_serializer, self._response_deserializer)
    343 
    344     def event(self,
    345               request,
    346               receiver,
    347               abortion_callback,
    348               timeout,
    349               metadata=None,
    350               protocol_options=None):
    351         raise NotImplementedError()
    352 
    353 
    354 class _StreamUnaryMultiCallable(face.StreamUnaryMultiCallable):
    355 
    356     def __init__(self, channel, group, method, metadata_transformer,
    357                  request_serializer, response_deserializer):
    358         self._channel = channel
    359         self._group = group
    360         self._method = method
    361         self._metadata_transformer = metadata_transformer
    362         self._request_serializer = request_serializer
    363         self._response_deserializer = response_deserializer
    364 
    365     def __call__(self,
    366                  request_iterator,
    367                  timeout,
    368                  metadata=None,
    369                  with_call=False,
    370                  protocol_options=None):
    371         return _blocking_stream_unary(
    372             self._channel, self._group, self._method, timeout, with_call,
    373             protocol_options, metadata, self._metadata_transformer,
    374             request_iterator, self._request_serializer,
    375             self._response_deserializer)
    376 
    377     def future(self,
    378                request_iterator,
    379                timeout,
    380                metadata=None,
    381                protocol_options=None):
    382         return _future_stream_unary(
    383             self._channel, self._group, self._method, timeout, protocol_options,
    384             metadata, self._metadata_transformer, request_iterator,
    385             self._request_serializer, self._response_deserializer)
    386 
    387     def event(self,
    388               receiver,
    389               abortion_callback,
    390               timeout,
    391               metadata=None,
    392               protocol_options=None):
    393         raise NotImplementedError()
    394 
    395 
    396 class _StreamStreamMultiCallable(face.StreamStreamMultiCallable):
    397 
    398     def __init__(self, channel, group, method, metadata_transformer,
    399                  request_serializer, response_deserializer):
    400         self._channel = channel
    401         self._group = group
    402         self._method = method
    403         self._metadata_transformer = metadata_transformer
    404         self._request_serializer = request_serializer
    405         self._response_deserializer = response_deserializer
    406 
    407     def __call__(self,
    408                  request_iterator,
    409                  timeout,
    410                  metadata=None,
    411                  protocol_options=None):
    412         return _stream_stream(
    413             self._channel, self._group, self._method, timeout, protocol_options,
    414             metadata, self._metadata_transformer, request_iterator,
    415             self._request_serializer, self._response_deserializer)
    416 
    417     def event(self,
    418               receiver,
    419               abortion_callback,
    420               timeout,
    421               metadata=None,
    422               protocol_options=None):
    423         raise NotImplementedError()
    424 
    425 
    426 class _GenericStub(face.GenericStub):
    427 
    428     def __init__(self, channel, metadata_transformer, request_serializers,
    429                  response_deserializers):
    430         self._channel = channel
    431         self._metadata_transformer = metadata_transformer
    432         self._request_serializers = request_serializers or {}
    433         self._response_deserializers = response_deserializers or {}
    434 
    435     def blocking_unary_unary(self,
    436                              group,
    437                              method,
    438                              request,
    439                              timeout,
    440                              metadata=None,
    441                              with_call=None,
    442                              protocol_options=None):
    443         request_serializer = self._request_serializers.get((
    444             group,
    445             method,
    446         ))
    447         response_deserializer = self._response_deserializers.get((
    448             group,
    449             method,
    450         ))
    451         return _blocking_unary_unary(self._channel, group, method, timeout,
    452                                      with_call, protocol_options, metadata,
    453                                      self._metadata_transformer, request,
    454                                      request_serializer, response_deserializer)
    455 
    456     def future_unary_unary(self,
    457                            group,
    458                            method,
    459                            request,
    460                            timeout,
    461                            metadata=None,
    462                            protocol_options=None):
    463         request_serializer = self._request_serializers.get((
    464             group,
    465             method,
    466         ))
    467         response_deserializer = self._response_deserializers.get((
    468             group,
    469             method,
    470         ))
    471         return _future_unary_unary(self._channel, group, method, timeout,
    472                                    protocol_options, metadata,
    473                                    self._metadata_transformer, request,
    474                                    request_serializer, response_deserializer)
    475 
    476     def inline_unary_stream(self,
    477                             group,
    478                             method,
    479                             request,
    480                             timeout,
    481                             metadata=None,
    482                             protocol_options=None):
    483         request_serializer = self._request_serializers.get((
    484             group,
    485             method,
    486         ))
    487         response_deserializer = self._response_deserializers.get((
    488             group,
    489             method,
    490         ))
    491         return _unary_stream(self._channel, group, method, timeout,
    492                              protocol_options, metadata,
    493                              self._metadata_transformer, request,
    494                              request_serializer, response_deserializer)
    495 
    496     def blocking_stream_unary(self,
    497                               group,
    498                               method,
    499                               request_iterator,
    500                               timeout,
    501                               metadata=None,
    502                               with_call=None,
    503                               protocol_options=None):
    504         request_serializer = self._request_serializers.get((
    505             group,
    506             method,
    507         ))
    508         response_deserializer = self._response_deserializers.get((
    509             group,
    510             method,
    511         ))
    512         return _blocking_stream_unary(
    513             self._channel, group, method, timeout, with_call, protocol_options,
    514             metadata, self._metadata_transformer, request_iterator,
    515             request_serializer, response_deserializer)
    516 
    517     def future_stream_unary(self,
    518                             group,
    519                             method,
    520                             request_iterator,
    521                             timeout,
    522                             metadata=None,
    523                             protocol_options=None):
    524         request_serializer = self._request_serializers.get((
    525             group,
    526             method,
    527         ))
    528         response_deserializer = self._response_deserializers.get((
    529             group,
    530             method,
    531         ))
    532         return _future_stream_unary(
    533             self._channel, group, method, timeout, protocol_options, metadata,
    534             self._metadata_transformer, request_iterator, request_serializer,
    535             response_deserializer)
    536 
    537     def inline_stream_stream(self,
    538                              group,
    539                              method,
    540                              request_iterator,
    541                              timeout,
    542                              metadata=None,
    543                              protocol_options=None):
    544         request_serializer = self._request_serializers.get((
    545             group,
    546             method,
    547         ))
    548         response_deserializer = self._response_deserializers.get((
    549             group,
    550             method,
    551         ))
    552         return _stream_stream(self._channel, group, method, timeout,
    553                               protocol_options, metadata,
    554                               self._metadata_transformer, request_iterator,
    555                               request_serializer, response_deserializer)
    556 
    557     def event_unary_unary(self,
    558                           group,
    559                           method,
    560                           request,
    561                           receiver,
    562                           abortion_callback,
    563                           timeout,
    564                           metadata=None,
    565                           protocol_options=None):
    566         raise NotImplementedError()
    567 
    568     def event_unary_stream(self,
    569                            group,
    570                            method,
    571                            request,
    572                            receiver,
    573                            abortion_callback,
    574                            timeout,
    575                            metadata=None,
    576                            protocol_options=None):
    577         raise NotImplementedError()
    578 
    579     def event_stream_unary(self,
    580                            group,
    581                            method,
    582                            receiver,
    583                            abortion_callback,
    584                            timeout,
    585                            metadata=None,
    586                            protocol_options=None):
    587         raise NotImplementedError()
    588 
    589     def event_stream_stream(self,
    590                             group,
    591                             method,
    592                             receiver,
    593                             abortion_callback,
    594                             timeout,
    595                             metadata=None,
    596                             protocol_options=None):
    597         raise NotImplementedError()
    598 
    599     def unary_unary(self, group, method):
    600         request_serializer = self._request_serializers.get((
    601             group,
    602             method,
    603         ))
    604         response_deserializer = self._response_deserializers.get((
    605             group,
    606             method,
    607         ))
    608         return _UnaryUnaryMultiCallable(
    609             self._channel, group, method, self._metadata_transformer,
    610             request_serializer, response_deserializer)
    611 
    612     def unary_stream(self, group, method):
    613         request_serializer = self._request_serializers.get((
    614             group,
    615             method,
    616         ))
    617         response_deserializer = self._response_deserializers.get((
    618             group,
    619             method,
    620         ))
    621         return _UnaryStreamMultiCallable(
    622             self._channel, group, method, self._metadata_transformer,
    623             request_serializer, response_deserializer)
    624 
    625     def stream_unary(self, group, method):
    626         request_serializer = self._request_serializers.get((
    627             group,
    628             method,
    629         ))
    630         response_deserializer = self._response_deserializers.get((
    631             group,
    632             method,
    633         ))
    634         return _StreamUnaryMultiCallable(
    635             self._channel, group, method, self._metadata_transformer,
    636             request_serializer, response_deserializer)
    637 
    638     def stream_stream(self, group, method):
    639         request_serializer = self._request_serializers.get((
    640             group,
    641             method,
    642         ))
    643         response_deserializer = self._response_deserializers.get((
    644             group,
    645             method,
    646         ))
    647         return _StreamStreamMultiCallable(
    648             self._channel, group, method, self._metadata_transformer,
    649             request_serializer, response_deserializer)
    650 
    651     def __enter__(self):
    652         return self
    653 
    654     def __exit__(self, exc_type, exc_val, exc_tb):
    655         return False
    656 
    657 
    658 class _DynamicStub(face.DynamicStub):
    659 
    660     def __init__(self, backing_generic_stub, group, cardinalities):
    661         self._generic_stub = backing_generic_stub
    662         self._group = group
    663         self._cardinalities = cardinalities
    664 
    665     def __getattr__(self, attr):
    666         method_cardinality = self._cardinalities.get(attr)
    667         if method_cardinality is cardinality.Cardinality.UNARY_UNARY:
    668             return self._generic_stub.unary_unary(self._group, attr)
    669         elif method_cardinality is cardinality.Cardinality.UNARY_STREAM:
    670             return self._generic_stub.unary_stream(self._group, attr)
    671         elif method_cardinality is cardinality.Cardinality.STREAM_UNARY:
    672             return self._generic_stub.stream_unary(self._group, attr)
    673         elif method_cardinality is cardinality.Cardinality.STREAM_STREAM:
    674             return self._generic_stub.stream_stream(self._group, attr)
    675         else:
    676             raise AttributeError(
    677                 '_DynamicStub object has no attribute "%s"!' % attr)
    678 
    679     def __enter__(self):
    680         return self
    681 
    682     def __exit__(self, exc_type, exc_val, exc_tb):
    683         return False
    684 
    685 
    686 def generic_stub(channel, host, metadata_transformer, request_serializers,
    687                  response_deserializers):
    688     return _GenericStub(channel, metadata_transformer, request_serializers,
    689                         response_deserializers)
    690 
    691 
    692 def dynamic_stub(channel, service, cardinalities, host, metadata_transformer,
    693                  request_serializers, response_deserializers):
    694     return _DynamicStub(
    695         _GenericStub(channel, metadata_transformer, request_serializers,
    696                      response_deserializers), service, cardinalities)
    697