#!/usr/bin/env python
"""Extra types understood by apitools.

This file will be replaced by a .proto file when we switch to proto2
from protorpc.
"""

import collections
import datetime
import json
import numbers

from protorpc import message_types
from protorpc import messages
from protorpc import protojson
import six

from apitools.base.py import encoding
from apitools.base.py import exceptions
from apitools.base.py import util

try:
    # The container ABCs live in collections.abc on Python 3; the aliases
    # on the top-level collections module were removed in Python 3.10.
    import collections.abc as _collections_abc
except ImportError:
    # Python 2 only exposes the ABCs on the collections module itself.
    _collections_abc = collections

__all__ = [
    'DateField',
    'DateTimeMessage',
    'JsonArray',
    'JsonObject',
    'JsonValue',
    'JsonProtoEncoder',
    'JsonProtoDecoder',
]

# We import from protorpc.
# pylint:disable=invalid-name
DateTimeMessage = message_types.DateTimeMessage
# pylint:enable=invalid-name


class DateField(messages.Field):

    """Field definition for Date values."""

    # We insert our own metaclass here to avoid letting ProtoRPC
    # register this as the default field type for strings.
    #  * since ProtoRPC does this via metaclasses, we don't have any
    #    choice but to use one ourselves
    #  * since a subclass's metaclass must inherit from its superclass's
    #    metaclass, we're forced to have this hard-to-read inheritance.
    #
    # pylint: disable=invalid-name
    class __metaclass__(messages.Field.__metaclass__):

        def __init__(cls, name, bases, dct):
            # Start the lookup *after* Field's metaclass so that its own
            # __init__ is skipped on purpose.
            super(messages.Field.__metaclass__, cls).__init__(name, bases, dct)
    # pylint: enable=invalid-name

    VARIANTS = frozenset([messages.Variant.STRING])
    DEFAULT_VARIANT = messages.Variant.STRING
    type = datetime.date


def _AssignedEntries(json_value):
    """Return (field, value) pairs for every field assigned a non-None value.

    Args:
      json_value: A message exposing all_fields() and get_assigned_value().

    Returns:
      A list of (field, value) tuples, in all_fields() order.
    """
    assigned = []
    for field in json_value.all_fields():
        value = json_value.get_assigned_value(field.name)
        if value is not None:
            assigned.append((field, value))
    return assigned


def _ValidateJsonValue(json_value):
    """Check that exactly one field of the given JsonValue is assigned.

    Args:
      json_value: The JsonValue to check.

    Raises:
      exceptions.InvalidDataError: if zero or several fields are assigned.
    """
    if len(_AssignedEntries(json_value)) != 1:
        raise exceptions.InvalidDataError(
            'Malformed JsonValue: %s' % json_value)


def _JsonValueToPythonValue(json_value):
    """Convert the given JsonValue to the equivalent python value.

    Args:
      json_value: The JsonValue to convert.

    Returns:
      None when is_null is set; a dict or list for object/array values;
      otherwise the single assigned scalar value.

    Raises:
      exceptions.InvalidDataError: if json_value does not have exactly one
        assigned field.
    """
    util.Typecheck(json_value, JsonValue)
    _ValidateJsonValue(json_value)
    if json_value.is_null:
        return None
    # Validation above guarantees exactly one entry.
    field, value = _AssignedEntries(json_value)[0]
    if not isinstance(field, messages.MessageField):
        return value
    if field.message_type is JsonObject:
        return _JsonObjectToPythonValue(value)
    if field.message_type is JsonArray:
        return _JsonArrayToPythonValue(value)
    # Any other message-typed field converts to None.
    return None


def _JsonObjectToPythonValue(json_value):
    """Convert the given JsonObject to a dict keyed by property key."""
    util.Typecheck(json_value, JsonObject)
    return {prop.key: _JsonValueToPythonValue(prop.value)
            for prop in json_value.properties}


def _JsonArrayToPythonValue(json_value):
    """Convert the given JsonArray to a list of python values."""
    util.Typecheck(json_value, JsonArray)
    return [_JsonValueToPythonValue(e) for e in json_value.entries]


# Inclusive bounds of a signed 64-bit integer. The parentheses matter:
# `-` binds tighter than `<<`, so `2 << 63 - 1` would be `2 << 62`.
_MAXINT64 = (1 << 63) - 1
_MININT64 = -(1 << 63)


def _PythonValueToJsonValue(py_value):
    """Convert the given python value to a JsonValue.

    Args:
      py_value: None, a bool, a string, a number, a dict, or any other
        iterable.

    Returns:
      A JsonValue with exactly one field set.

    Raises:
      exceptions.InvalidDataError: if py_value is none of the types above.
    """
    if py_value is None:
        return JsonValue(is_null=True)
    # bool must be tested before numbers: bool is a subclass of int.
    if isinstance(py_value, bool):
        return JsonValue(boolean_value=py_value)
    # Strings must be tested before the generic iterable case below.
    if isinstance(py_value, six.string_types):
        return JsonValue(string_value=py_value)
    if isinstance(py_value, numbers.Number):
        if isinstance(py_value, six.integer_types):
            if _MININT64 <= py_value <= _MAXINT64:
                return JsonValue(integer_value=py_value)
        # Non-integers, and integers that do not fit in an int64, are
        # stored as doubles.
        return JsonValue(double_value=float(py_value))
    if isinstance(py_value, dict):
        return JsonValue(object_value=_PythonValueToJsonObject(py_value))
    if isinstance(py_value, _collections_abc.Iterable):
        return JsonValue(array_value=_PythonValueToJsonArray(py_value))
    raise exceptions.InvalidDataError(
        'Cannot convert "%s" to JsonValue' % py_value)


def _PythonValueToJsonObject(py_value):
    """Convert the given dict to a JsonObject, one Property per item."""
    util.Typecheck(py_value, dict)
    return JsonObject(
        properties=[
            JsonObject.Property(key=key, value=_PythonValueToJsonValue(value))
            for key, value in py_value.items()])


def _PythonValueToJsonArray(py_value):
    """Convert the given iterable to a JsonArray, one entry per element."""
    return JsonArray(
        entries=[_PythonValueToJsonValue(entry) for entry in py_value])


class JsonValue(messages.Message):

    """Any valid JSON value."""
    # Is this JSON object `null`?
    is_null = messages.BooleanField(1, default=False)

    # Exactly one of the following is provided if is_null is False; none
    # should be provided if is_null is True.
    boolean_value = messages.BooleanField(2)
    string_value = messages.StringField(3)
    # We keep two numeric fields to keep int64 round-trips exact.
    double_value = messages.FloatField(4, variant=messages.Variant.DOUBLE)
    integer_value = messages.IntegerField(5, variant=messages.Variant.INT64)
    # Compound types
    object_value = messages.MessageField('JsonObject', 6)
    array_value = messages.MessageField('JsonArray', 7)


class JsonObject(messages.Message):

    """A JSON object value.

    Messages:
      Property: A property of a JsonObject.

    Fields:
      properties: A list of properties of a JsonObject.
    """

    class Property(messages.Message):

        """A property of a JSON object.

        Fields:
          key: Name of the property.
          value: A JsonValue attribute.
        """
        key = messages.StringField(1)
        value = messages.MessageField(JsonValue, 2)

    properties = messages.MessageField(Property, 1, repeated=True)


class JsonArray(messages.Message):

    """A JSON array value."""
    entries = messages.MessageField(JsonValue, 1, repeated=True)


# Dispatch table from JSON proto class to its python-value converter.
_JSON_PROTO_TO_PYTHON_MAP = {
    JsonArray: _JsonArrayToPythonValue,
    JsonObject: _JsonObjectToPythonValue,
    JsonValue: _JsonValueToPythonValue,
}
_JSON_PROTO_TYPES = tuple(_JSON_PROTO_TO_PYTHON_MAP.keys())


def _JsonProtoToPythonValue(json_proto):
    """Convert a JsonValue, JsonObject or JsonArray to a python value."""
    util.Typecheck(json_proto, _JSON_PROTO_TYPES)
    return _JSON_PROTO_TO_PYTHON_MAP[type(json_proto)](json_proto)


def _PythonValueToJsonProto(py_value):
    """Convert a python value to the most specific JSON proto.

    Returns:
      A JsonObject for dicts, a JsonArray for non-string iterables, and a
      JsonValue for everything else.
    """
    if isinstance(py_value, dict):
        return _PythonValueToJsonObject(py_value)
    if (isinstance(py_value, _collections_abc.Iterable) and
            not isinstance(py_value, six.string_types)):
        return _PythonValueToJsonArray(py_value)
    return _PythonValueToJsonValue(py_value)


def _JsonProtoToJson(json_proto, unused_encoder=None):
    """Serialize a JSON proto to a JSON string via json.dumps."""
    return json.dumps(_JsonProtoToPythonValue(json_proto))


def _JsonToJsonProto(json_data, unused_decoder=None):
    """Parse a JSON string into a JsonObject, JsonArray or JsonValue."""
    return _PythonValueToJsonProto(json.loads(json_data))


def _JsonToJsonValue(json_data, unused_decoder=None):
    """Parse a JSON string into a JsonValue.

    Objects and arrays are wrapped in a JsonValue so that the result is
    always a JsonValue.

    Raises:
      exceptions.InvalidDataError: if the parsed data maps to none of the
        JSON proto types.
    """
    result = _PythonValueToJsonProto(json.loads(json_data))
    if isinstance(result, JsonValue):
        return result
    elif isinstance(result, JsonObject):
        return JsonValue(object_value=result)
    elif isinstance(result, JsonArray):
        return JsonValue(array_value=result)
    else:
        raise exceptions.InvalidDataError(
            'Malformed JsonValue: %s' % json_data)


# pylint:disable=invalid-name
JsonProtoEncoder = _JsonProtoToJson
JsonProtoDecoder = _JsonToJsonProto
# pylint:enable=invalid-name
encoding.RegisterCustomMessageCodec(
    encoder=JsonProtoEncoder, decoder=_JsonToJsonValue)(JsonValue)
encoding.RegisterCustomMessageCodec(
    encoder=JsonProtoEncoder, decoder=JsonProtoDecoder)(JsonObject)
encoding.RegisterCustomMessageCodec(
    encoder=JsonProtoEncoder, decoder=JsonProtoDecoder)(JsonArray)


def _EncodeDateTimeField(field, value):
    """Encoder for DateTimeField values; delegates to ProtoJson."""
    result = protojson.ProtoJson().encode_field(field, value)
    return encoding.CodecResult(value=result, complete=True)


def _DecodeDateTimeField(unused_field, value):
    """Decoder for DateTimeField values; delegates to ProtoJson."""
    # A fresh DateTimeField(1) is used for decoding instead of the field
    # that was passed in.
    result = protojson.ProtoJson().decode_field(
        message_types.DateTimeField(1), value)
    return encoding.CodecResult(value=result, complete=True)


encoding.RegisterFieldTypeCodec(_EncodeDateTimeField, _DecodeDateTimeField)(
    message_types.DateTimeField)


def _EncodeInt64Field(field, value):
    """Handle the special case of int64 as a string."""
    capabilities = [
        messages.Variant.INT64,
        messages.Variant.UINT64,
    ]
    if field.variant not in capabilities:
        # Other integer variants are passed through untouched.
        # NOTE(review): complete=False presumably lets the default
        # encoding continue -- confirm against apitools encoding.
        return encoding.CodecResult(value=value, complete=False)

    if field.repeated:
        result = [str(x) for x in value]
    else:
        result = str(value)
    return encoding.CodecResult(value=result, complete=True)


def _DecodeInt64Field(unused_field, value):
    """Decoder counterpart of _EncodeInt64Field; passes value through."""
    # Don't need to do anything special, they're decoded just fine
    return encoding.CodecResult(value=value, complete=False)


encoding.RegisterFieldTypeCodec(_EncodeInt64Field, _DecodeInt64Field)(
    messages.IntegerField)


def _EncodeDateField(field, value):
    """Encoder for datetime.date objects."""
    if field.repeated:
        result = [d.isoformat() for d in value]
    else:
        result = value.isoformat()
    return encoding.CodecResult(value=result, complete=True)


def _DecodeDateField(unused_field, value):
    """Decoder for 'YYYY-MM-DD' strings into datetime.date objects."""
    date = datetime.datetime.strptime(value, '%Y-%m-%d').date()
    return encoding.CodecResult(value=date, complete=True)


encoding.RegisterFieldTypeCodec(_EncodeDateField, _DecodeDateField)(DateField)