Home | History | Annotate | Download | only in py
      1 #!/usr/bin/env python
      2 """Extra types understood by apitools.
      3 
      4 This file will be replaced by a .proto file when we switch to proto2
      5 from protorpc.
      6 """
      7 
import collections
import datetime
import json
import numbers

from protorpc import message_types
from protorpc import messages
from protorpc import protojson
import six

from apitools.base.py import encoding
from apitools.base.py import exceptions
from apitools.base.py import util

try:
    # Python 3.3+. The ABC aliases in ``collections`` were removed in 3.10.
    from collections.abc import Iterable as _Iterable
except ImportError:  # Python 2
    from collections import Iterable as _Iterable
     21 
# Public names exported by this module (what ``from ... import *`` yields).
__all__ = [
    'DateField',
    'DateTimeMessage',
    'JsonArray',
    'JsonObject',
    'JsonValue',
    'JsonProtoEncoder',
    'JsonProtoDecoder',
]

# We import from protorpc.
# DateTimeMessage is protorpc's own message class, re-exported here under the
# same name so callers can reach it through this module.
# pylint:disable=invalid-name
DateTimeMessage = message_types.DateTimeMessage
# pylint:enable=invalid-name
     36 
     37 
     38 class DateField(messages.Field):
     39 
     40     """Field definition for Date values."""
     41 
     42     # We insert our own metaclass here to avoid letting ProtoRPC
     43     # register this as the default field type for strings.
     44     #  * since ProtoRPC does this via metaclasses, we don't have any
     45     #    choice but to use one ourselves
     46     #  * since a subclass's metaclass must inherit from its superclass's
     47     #    metaclass, we're forced to have this hard-to-read inheritance.
     48     #
     49     # pylint: disable=invalid-name
     50     class __metaclass__(messages.Field.__metaclass__):
     51 
     52         def __init__(cls, name, bases, dct):
     53             super(messages.Field.__metaclass__, cls).__init__(name, bases, dct)
     54     # pylint: enable=invalid-name
     55 
     56     VARIANTS = frozenset([messages.Variant.STRING])
     57     DEFAULT_VARIANT = messages.Variant.STRING
     58     type = datetime.date
     59 
     60 
     61 def _ValidateJsonValue(json_value):
     62     entries = [(f, json_value.get_assigned_value(f.name))
     63                for f in json_value.all_fields()]
     64     assigned_entries = [(f, value)
     65                         for f, value in entries if value is not None]
     66     if len(assigned_entries) != 1:
     67         raise exceptions.InvalidDataError(
     68             'Malformed JsonValue: %s' % json_value)
     69 
     70 
     71 def _JsonValueToPythonValue(json_value):
     72     """Convert the given JsonValue to a json string."""
     73     util.Typecheck(json_value, JsonValue)
     74     _ValidateJsonValue(json_value)
     75     if json_value.is_null:
     76         return None
     77     entries = [(f, json_value.get_assigned_value(f.name))
     78                for f in json_value.all_fields()]
     79     assigned_entries = [(f, value)
     80                         for f, value in entries if value is not None]
     81     field, value = assigned_entries[0]
     82     if not isinstance(field, messages.MessageField):
     83         return value
     84     elif field.message_type is JsonObject:
     85         return _JsonObjectToPythonValue(value)
     86     elif field.message_type is JsonArray:
     87         return _JsonArrayToPythonValue(value)
     88 
     89 
     90 def _JsonObjectToPythonValue(json_value):
     91     util.Typecheck(json_value, JsonObject)
     92     return dict([(prop.key, _JsonValueToPythonValue(prop.value)) for prop
     93                  in json_value.properties])
     94 
     95 
     96 def _JsonArrayToPythonValue(json_value):
     97     util.Typecheck(json_value, JsonArray)
     98     return [_JsonValueToPythonValue(e) for e in json_value.entries]
     99 
    100 
    101 _MAXINT64 = 2 << 63 - 1
    102 _MININT64 = -(2 << 63)
    103 
    104 
    105 def _PythonValueToJsonValue(py_value):
    106     """Convert the given python value to a JsonValue."""
    107     if py_value is None:
    108         return JsonValue(is_null=True)
    109     if isinstance(py_value, bool):
    110         return JsonValue(boolean_value=py_value)
    111     if isinstance(py_value, six.string_types):
    112         return JsonValue(string_value=py_value)
    113     if isinstance(py_value, numbers.Number):
    114         if isinstance(py_value, six.integer_types):
    115             if _MININT64 < py_value < _MAXINT64:
    116                 return JsonValue(integer_value=py_value)
    117         return JsonValue(double_value=float(py_value))
    118     if isinstance(py_value, dict):
    119         return JsonValue(object_value=_PythonValueToJsonObject(py_value))
    120     if isinstance(py_value, collections.Iterable):
    121         return JsonValue(array_value=_PythonValueToJsonArray(py_value))
    122     raise exceptions.InvalidDataError(
    123         'Cannot convert "%s" to JsonValue' % py_value)
    124 
    125 
    126 def _PythonValueToJsonObject(py_value):
    127     util.Typecheck(py_value, dict)
    128     return JsonObject(
    129         properties=[
    130             JsonObject.Property(key=key, value=_PythonValueToJsonValue(value))
    131             for key, value in py_value.items()])
    132 
    133 
    134 def _PythonValueToJsonArray(py_value):
    135     return JsonArray(entries=list(map(_PythonValueToJsonValue, py_value)))
    136 
    137 
class JsonValue(messages.Message):

    """Any valid JSON value.

    Holds either a null marker or a single payload field. The converters
    in this module reject (as malformed) a JsonValue that does not have
    exactly one field assigned.
    """
    # Is this JSON value `null`?
    is_null = messages.BooleanField(1, default=False)

    # Exactly one of the following is provided if is_null is False; none
    # should be provided if is_null is True.
    boolean_value = messages.BooleanField(2)
    string_value = messages.StringField(3)
    # We keep two numeric fields to keep int64 round-trips exact: Python ints
    # in the int64 range go to integer_value, other numbers to double_value.
    double_value = messages.FloatField(4, variant=messages.Variant.DOUBLE)
    integer_value = messages.IntegerField(5, variant=messages.Variant.INT64)
    # Compound types. Referenced by name because JsonObject and JsonArray
    # are defined later in this module.
    object_value = messages.MessageField('JsonObject', 6)
    array_value = messages.MessageField('JsonArray', 7)
    154 
    155 
class JsonObject(messages.Message):

    """A JSON object value.

    The object is stored as a list of key/value properties rather than as
    a map field.

    Messages:
      Property: A property of a JsonObject.

    Fields:
      properties: A list of properties of a JsonObject.
    """

    class Property(messages.Message):

        """A property of a JSON object.

        Fields:
          key: Name of the property.
          value: A JsonValue attribute (any JSON value, possibly compound).
        """
        key = messages.StringField(1)
        value = messages.MessageField(JsonValue, 2)

    properties = messages.MessageField(Property, 1, repeated=True)
    179 
    180 
class JsonArray(messages.Message):

    """A JSON array value: a list of JsonValue entries, in order."""
    entries = messages.MessageField(JsonValue, 1, repeated=True)
    185 
    186 
    187 _JSON_PROTO_TO_PYTHON_MAP = {
    188     JsonArray: _JsonArrayToPythonValue,
    189     JsonObject: _JsonObjectToPythonValue,
    190     JsonValue: _JsonValueToPythonValue,
    191 }
    192 _JSON_PROTO_TYPES = tuple(_JSON_PROTO_TO_PYTHON_MAP.keys())
    193 
    194 
    195 def _JsonProtoToPythonValue(json_proto):
    196     util.Typecheck(json_proto, _JSON_PROTO_TYPES)
    197     return _JSON_PROTO_TO_PYTHON_MAP[type(json_proto)](json_proto)
    198 
    199 
    200 def _PythonValueToJsonProto(py_value):
    201     if isinstance(py_value, dict):
    202         return _PythonValueToJsonObject(py_value)
    203     if (isinstance(py_value, collections.Iterable) and
    204             not isinstance(py_value, six.string_types)):
    205         return _PythonValueToJsonArray(py_value)
    206     return _PythonValueToJsonValue(py_value)
    207 
    208 
    209 def _JsonProtoToJson(json_proto, unused_encoder=None):
    210     return json.dumps(_JsonProtoToPythonValue(json_proto))
    211 
    212 
    213 def _JsonToJsonProto(json_data, unused_decoder=None):
    214     return _PythonValueToJsonProto(json.loads(json_data))
    215 
    216 
    217 def _JsonToJsonValue(json_data, unused_decoder=None):
    218     result = _PythonValueToJsonProto(json.loads(json_data))
    219     if isinstance(result, JsonValue):
    220         return result
    221     elif isinstance(result, JsonObject):
    222         return JsonValue(object_value=result)
    223     elif isinstance(result, JsonArray):
    224         return JsonValue(array_value=result)
    225     else:
    226         raise exceptions.InvalidDataError(
    227             'Malformed JsonValue: %s' % json_data)
    228 
    229 
    230 # pylint:disable=invalid-name
    231 JsonProtoEncoder = _JsonProtoToJson
    232 JsonProtoDecoder = _JsonToJsonProto
    233 # pylint:enable=invalid-name
    234 encoding.RegisterCustomMessageCodec(
    235     encoder=JsonProtoEncoder, decoder=_JsonToJsonValue)(JsonValue)
    236 encoding.RegisterCustomMessageCodec(
    237     encoder=JsonProtoEncoder, decoder=JsonProtoDecoder)(JsonObject)
    238 encoding.RegisterCustomMessageCodec(
    239     encoder=JsonProtoEncoder, decoder=JsonProtoDecoder)(JsonArray)
    240 
    241 
    242 def _EncodeDateTimeField(field, value):
    243     result = protojson.ProtoJson().encode_field(field, value)
    244     return encoding.CodecResult(value=result, complete=True)
    245 
    246 
    247 def _DecodeDateTimeField(unused_field, value):
    248     result = protojson.ProtoJson().decode_field(
    249         message_types.DateTimeField(1), value)
    250     return encoding.CodecResult(value=result, complete=True)
    251 
    252 
    253 encoding.RegisterFieldTypeCodec(_EncodeDateTimeField, _DecodeDateTimeField)(
    254     message_types.DateTimeField)
    255 
    256 
    257 def _EncodeInt64Field(field, value):
    258     """Handle the special case of int64 as a string."""
    259     capabilities = [
    260         messages.Variant.INT64,
    261         messages.Variant.UINT64,
    262     ]
    263     if field.variant not in capabilities:
    264         return encoding.CodecResult(value=value, complete=False)
    265 
    266     if field.repeated:
    267         result = [str(x) for x in value]
    268     else:
    269         result = str(value)
    270     return encoding.CodecResult(value=result, complete=True)
    271 
    272 
    273 def _DecodeInt64Field(unused_field, value):
    274     # Don't need to do anything special, they're decoded just fine
    275     return encoding.CodecResult(value=value, complete=False)
    276 
    277 encoding.RegisterFieldTypeCodec(_EncodeInt64Field, _DecodeInt64Field)(
    278     messages.IntegerField)
    279 
    280 
    281 def _EncodeDateField(field, value):
    282     """Encoder for datetime.date objects."""
    283     if field.repeated:
    284         result = [d.isoformat() for d in value]
    285     else:
    286         result = value.isoformat()
    287     return encoding.CodecResult(value=result, complete=True)
    288 
    289 
    290 def _DecodeDateField(unused_field, value):
    291     date = datetime.datetime.strptime(value, '%Y-%m-%d').date()
    292     return encoding.CodecResult(value=date, complete=True)
    293 
    294 encoding.RegisterFieldTypeCodec(_EncodeDateField, _DecodeDateField)(DateField)
    295