Home | History | Annotate | Download | only in py
      1 #!/usr/bin/env python
      2 #
      3 # Copyright 2015 Google Inc.
      4 #
      5 # Licensed under the Apache License, Version 2.0 (the "License");
      6 # you may not use this file except in compliance with the License.
      7 # You may obtain a copy of the License at
      8 #
      9 #     http://www.apache.org/licenses/LICENSE-2.0
     10 #
     11 # Unless required by applicable law or agreed to in writing, software
     12 # distributed under the License is distributed on an "AS IS" BASIS,
     13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14 # See the License for the specific language governing permissions and
     15 # limitations under the License.
     16 
     17 """Extra types understood by apitools."""
     18 
import collections
import datetime
import json
import numbers

import six

from apitools.base.protorpclite import message_types
from apitools.base.protorpclite import messages
from apitools.base.protorpclite import protojson
from apitools.base.py import encoding
from apitools.base.py import exceptions
from apitools.base.py import util

try:
    from collections import abc as collections_abc
except ImportError:  # Python 2: the ABCs live directly in `collections`.
    collections_abc = collections
     32 
# Public names exported by this module.
__all__ = [
    'DateField',
    'DateTimeMessage',
    'JsonArray',
    'JsonObject',
    'JsonValue',
    'JsonProtoEncoder',
    'JsonProtoDecoder',
]

# Re-export protorpclite's DateTimeMessage from this module's namespace.
# pylint:disable=invalid-name
DateTimeMessage = message_types.DateTimeMessage
# pylint:enable=invalid-name
     46 
     47 
# We insert our own metaclass here to avoid letting ProtoRPC
# register this as the default field type for strings.
#  * since ProtoRPC does this via metaclasses, we don't have any
#    choice but to use one ourselves
#  * since a subclass's metaclass must inherit from its superclass's
#    metaclass, we're forced to have this hard-to-read inheritance.
#
# pylint: disable=protected-access
class _FieldMeta(messages._FieldMeta):

    """Field metaclass that skips messages._FieldMeta's class setup.

    It inherits from messages._FieldMeta only to satisfy Python's
    metaclass-compatibility rule. Its __init__ deliberately does not run the
    parent metaclass's __init__.
    """

    def __init__(cls, name, bases, dct):  # pylint: disable=no-self-argument
        # Call type.__init__ directly instead of the parent metaclass's
        # __init__, so the parent's per-class bookkeeping (presumably the
        # default-field-type registration described above) never runs.
        # pylint: disable=super-init-not-called,non-parent-init-called
        type.__init__(cls, name, bases, dct)
# pylint: enable=protected-access
     62 
     63 
     64 class DateField(six.with_metaclass(_FieldMeta, messages.Field)):
     65 
     66     """Field definition for Date values."""
     67 
     68     VARIANTS = frozenset([messages.Variant.STRING])
     69     DEFAULT_VARIANT = messages.Variant.STRING
     70     type = datetime.date
     71 
     72 
     73 def _ValidateJsonValue(json_value):
     74     entries = [(f, json_value.get_assigned_value(f.name))
     75                for f in json_value.all_fields()]
     76     assigned_entries = [(f, value)
     77                         for f, value in entries if value is not None]
     78     if len(assigned_entries) != 1:
     79         raise exceptions.InvalidDataError(
     80             'Malformed JsonValue: %s' % json_value)
     81 
     82 
     83 def _JsonValueToPythonValue(json_value):
     84     """Convert the given JsonValue to a json string."""
     85     util.Typecheck(json_value, JsonValue)
     86     _ValidateJsonValue(json_value)
     87     if json_value.is_null:
     88         return None
     89     entries = [(f, json_value.get_assigned_value(f.name))
     90                for f in json_value.all_fields()]
     91     assigned_entries = [(f, value)
     92                         for f, value in entries if value is not None]
     93     field, value = assigned_entries[0]
     94     if not isinstance(field, messages.MessageField):
     95         return value
     96     elif field.message_type is JsonObject:
     97         return _JsonObjectToPythonValue(value)
     98     elif field.message_type is JsonArray:
     99         return _JsonArrayToPythonValue(value)
    100 
    101 
    102 def _JsonObjectToPythonValue(json_value):
    103     util.Typecheck(json_value, JsonObject)
    104     return dict([(prop.key, _JsonValueToPythonValue(prop.value)) for prop
    105                  in json_value.properties])
    106 
    107 
    108 def _JsonArrayToPythonValue(json_value):
    109     util.Typecheck(json_value, JsonArray)
    110     return [_JsonValueToPythonValue(e) for e in json_value.entries]
    111 
    112 
    113 _MAXINT64 = 2 << 63 - 1
    114 _MININT64 = -(2 << 63)
    115 
    116 
    117 def _PythonValueToJsonValue(py_value):
    118     """Convert the given python value to a JsonValue."""
    119     if py_value is None:
    120         return JsonValue(is_null=True)
    121     if isinstance(py_value, bool):
    122         return JsonValue(boolean_value=py_value)
    123     if isinstance(py_value, six.string_types):
    124         return JsonValue(string_value=py_value)
    125     if isinstance(py_value, numbers.Number):
    126         if isinstance(py_value, six.integer_types):
    127             if _MININT64 < py_value < _MAXINT64:
    128                 return JsonValue(integer_value=py_value)
    129         return JsonValue(double_value=float(py_value))
    130     if isinstance(py_value, dict):
    131         return JsonValue(object_value=_PythonValueToJsonObject(py_value))
    132     if isinstance(py_value, collections.Iterable):
    133         return JsonValue(array_value=_PythonValueToJsonArray(py_value))
    134     raise exceptions.InvalidDataError(
    135         'Cannot convert "%s" to JsonValue' % py_value)
    136 
    137 
    138 def _PythonValueToJsonObject(py_value):
    139     util.Typecheck(py_value, dict)
    140     return JsonObject(
    141         properties=[
    142             JsonObject.Property(key=key, value=_PythonValueToJsonValue(value))
    143             for key, value in py_value.items()])
    144 
    145 
    146 def _PythonValueToJsonArray(py_value):
    147     return JsonArray(entries=list(map(_PythonValueToJsonValue, py_value)))
    148 
    149 
class JsonValue(messages.Message):

    """Any valid JSON value.

    Exactly one field is expected to be assigned. _ValidateJsonValue
    enforces this when the value is converted back to Python.
    """
    # Is this JSON object `null`?
    is_null = messages.BooleanField(1, default=False)

    # Exactly one of the following is provided if is_null is False; none
    # should be provided if is_null is True.
    boolean_value = messages.BooleanField(2)
    string_value = messages.StringField(3)
    # We keep two numeric fields to keep int64 round-trips exact.
    double_value = messages.FloatField(4, variant=messages.Variant.DOUBLE)
    integer_value = messages.IntegerField(5, variant=messages.Variant.INT64)
    # Compound types. They are referenced by name because JsonObject and
    # JsonArray are defined after this class.
    object_value = messages.MessageField('JsonObject', 6)
    array_value = messages.MessageField('JsonArray', 7)
    165     array_value = messages.MessageField('JsonArray', 7)
    166 
    167 
class JsonObject(messages.Message):

    """A JSON object value.

    Messages:
      Property: A property of a JsonObject.

    Fields:
      properties: A list of properties of a JsonObject.
    """

    class Property(messages.Message):

        """A property of a JSON object.

        Fields:
          key: Name of the property.
          value: A JsonValue attribute.
        """
        key = messages.StringField(1)
        value = messages.MessageField(JsonValue, 2)

    # Stored as a repeated list of key/value pairs rather than a map.
    properties = messages.MessageField(Property, 1, repeated=True)
    191 
    192 
class JsonArray(messages.Message):

    """A JSON array value; each element is a JsonValue."""
    # Array elements, in order.
    entries = messages.MessageField(JsonValue, 1, repeated=True)
    197 
    198 
    199 _JSON_PROTO_TO_PYTHON_MAP = {
    200     JsonArray: _JsonArrayToPythonValue,
    201     JsonObject: _JsonObjectToPythonValue,
    202     JsonValue: _JsonValueToPythonValue,
    203 }
    204 _JSON_PROTO_TYPES = tuple(_JSON_PROTO_TO_PYTHON_MAP.keys())
    205 
    206 
    207 def _JsonProtoToPythonValue(json_proto):
    208     util.Typecheck(json_proto, _JSON_PROTO_TYPES)
    209     return _JSON_PROTO_TO_PYTHON_MAP[type(json_proto)](json_proto)
    210 
    211 
    212 def _PythonValueToJsonProto(py_value):
    213     if isinstance(py_value, dict):
    214         return _PythonValueToJsonObject(py_value)
    215     if (isinstance(py_value, collections.Iterable) and
    216             not isinstance(py_value, six.string_types)):
    217         return _PythonValueToJsonArray(py_value)
    218     return _PythonValueToJsonValue(py_value)
    219 
    220 
    221 def _JsonProtoToJson(json_proto, unused_encoder=None):
    222     return json.dumps(_JsonProtoToPythonValue(json_proto))
    223 
    224 
    225 def _JsonToJsonProto(json_data, unused_decoder=None):
    226     return _PythonValueToJsonProto(json.loads(json_data))
    227 
    228 
    229 def _JsonToJsonValue(json_data, unused_decoder=None):
    230     result = _PythonValueToJsonProto(json.loads(json_data))
    231     if isinstance(result, JsonValue):
    232         return result
    233     elif isinstance(result, JsonObject):
    234         return JsonValue(object_value=result)
    235     elif isinstance(result, JsonArray):
    236         return JsonValue(array_value=result)
    237     else:
    238         raise exceptions.InvalidDataError(
    239             'Malformed JsonValue: %s' % json_data)
    240 
    241 
# Public names for the JSON proto codec functions.
# pylint:disable=invalid-name
JsonProtoEncoder = _JsonProtoToJson
JsonProtoDecoder = _JsonToJsonProto
# pylint:enable=invalid-name
# Register custom codecs so these messages are encoded as the JSON text of
# their plain-Python equivalent. JsonValue uses _JsonToJsonValue as its
# decoder so that decoding always yields a JsonValue, even for top-level
# objects and arrays.
encoding.RegisterCustomMessageCodec(
    encoder=JsonProtoEncoder, decoder=_JsonToJsonValue)(JsonValue)
encoding.RegisterCustomMessageCodec(
    encoder=JsonProtoEncoder, decoder=JsonProtoDecoder)(JsonObject)
encoding.RegisterCustomMessageCodec(
    encoder=JsonProtoEncoder, decoder=JsonProtoDecoder)(JsonArray)
    252 
    253 
    254 def _EncodeDateTimeField(field, value):
    255     result = protojson.ProtoJson().encode_field(field, value)
    256     return encoding.CodecResult(value=result, complete=True)
    257 
    258 
    259 def _DecodeDateTimeField(unused_field, value):
    260     result = protojson.ProtoJson().decode_field(
    261         message_types.DateTimeField(1), value)
    262     return encoding.CodecResult(value=result, complete=True)
    263 
    264 
# Route every DateTimeField through the protojson-based codecs above.
encoding.RegisterFieldTypeCodec(_EncodeDateTimeField, _DecodeDateTimeField)(
    message_types.DateTimeField)
    267 
    268 
    269 def _EncodeInt64Field(field, value):
    270     """Handle the special case of int64 as a string."""
    271     capabilities = [
    272         messages.Variant.INT64,
    273         messages.Variant.UINT64,
    274     ]
    275     if field.variant not in capabilities:
    276         return encoding.CodecResult(value=value, complete=False)
    277 
    278     if field.repeated:
    279         result = [str(x) for x in value]
    280     else:
    281         result = str(value)
    282     return encoding.CodecResult(value=result, complete=True)
    283 
    284 
    285 def _DecodeInt64Field(unused_field, value):
    286     # Don't need to do anything special, they're decoded just fine
    287     return encoding.CodecResult(value=value, complete=False)
    288 
# Register the int64 codecs for every IntegerField. _EncodeInt64Field itself
# decides which variants are encoded as strings.
encoding.RegisterFieldTypeCodec(_EncodeInt64Field, _DecodeInt64Field)(
    messages.IntegerField)
    291 
    292 
    293 def _EncodeDateField(field, value):
    294     """Encoder for datetime.date objects."""
    295     if field.repeated:
    296         result = [d.isoformat() for d in value]
    297     else:
    298         result = value.isoformat()
    299     return encoding.CodecResult(value=result, complete=True)
    300 
    301 
    302 def _DecodeDateField(unused_field, value):
    303     date = datetime.datetime.strptime(value, '%Y-%m-%d').date()
    304     return encoding.CodecResult(value=date, complete=True)
    305 
# Encode and decode DateField values as 'YYYY-MM-DD' strings.
encoding.RegisterFieldTypeCodec(_EncodeDateField, _DecodeDateField)(DateField)
    307