Home | History | Annotate | Download | only in protobuf
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 """Utility functions for comparing proto2 messages in Python.
     17 
     18 ProtoEq() compares two proto2 messages for equality.
     19 
     20 ClearDefaultValuedFields() recursively clears the fields that are set to their
     21 default values. This is useful for comparing protocol buffers where the
     22 semantics of unset fields and default valued fields are the same.
     23 
     24 assertProtoEqual() is useful for unit tests.  It produces much more helpful
     25 output than assertEqual() for proto2 messages, e.g. this:
     26 
     27   outer {
     28     inner {
     29 -     strings: "x"
     30 ?               ^
     31 +     strings: "y"
     32 ?               ^
     33     }
     34   }
     35 
     36 ...compared to the default output from assertEqual() that looks like this:
     37 
     38 AssertionError: <my.Msg object at 0x9fb353c> != <my.Msg object at 0x9fb35cc>
     39 
     40 Call it inside your unit test's googletest.TestCase subclasses like this:
     41 
     42   from tensorflow.python.util.protobuf import compare
     43 
     44   class MyTest(googletest.TestCase):
     45     ...
     46     def testXXX(self):
     47       ...
     48       compare.assertProtoEqual(self, a, b)
     49 
     50 Alternatively:
     51 
     52   from tensorflow.python.util.protobuf import compare
     53 
     54   class MyTest(compare.ProtoAssertions, googletest.TestCase):
     55     ...
     56     def testXXX(self):
     57       ...
     58       self.assertProtoEqual(a, b)
     59 """
     60 
     61 from __future__ import absolute_import
     62 from __future__ import division
     63 from __future__ import print_function
     64 
     65 import collections
     66 import difflib
     67 
     68 import six
     69 
     70 from google.protobuf import descriptor
     71 from google.protobuf import descriptor_pool
     72 from google.protobuf import message
     73 from google.protobuf import text_format
     74 
     75 
     76 def assertProtoEqual(self, a, b, check_initialized=True,  # pylint: disable=invalid-name
     77                      normalize_numbers=False, msg=None):
     78   """Fails with a useful error if a and b aren't equal.
     79 
     80   Comparison of repeated fields matches the semantics of
     81   unittest.TestCase.assertEqual(), ie order and extra duplicates fields matter.
     82 
     83   Args:
     84     self: googletest.TestCase
     85     a: proto2 PB instance, or text string representing one.
     86     b: proto2 PB instance -- message.Message or subclass thereof.
     87     check_initialized: boolean, whether to fail if either a or b isn't
     88       initialized.
     89     normalize_numbers: boolean, whether to normalize types and precision of
     90       numbers before comparison.
     91     msg: if specified, is used as the error message on failure.
     92   """
     93   pool = descriptor_pool.Default()
     94   if isinstance(a, six.string_types):
     95     a = text_format.Merge(a, b.__class__(), descriptor_pool=pool)
     96 
     97   for pb in a, b:
     98     if check_initialized:
     99       errors = pb.FindInitializationErrors()
    100       if errors:
    101         self.fail('Initialization errors: %s\n%s' % (errors, pb))
    102     if normalize_numbers:
    103       NormalizeNumberFields(pb)
    104 
    105   a_str = text_format.MessageToString(a, descriptor_pool=pool)
    106   b_str = text_format.MessageToString(b, descriptor_pool=pool)
    107 
    108   # Some Python versions would perform regular diff instead of multi-line
    109   # diff if string is longer than 2**16. We substitute this behavior
    110   # with a call to unified_diff instead to have easier-to-read diffs.
    111   # For context, see: https://bugs.python.org/issue11763.
    112   if len(a_str) < 2**16 and len(b_str) < 2**16:
    113     self.assertMultiLineEqual(a_str, b_str, msg=msg)
    114   else:
    115     diff = '\n' + ''.join(difflib.unified_diff(a_str.splitlines(True),
    116                                                b_str.splitlines(True)))
    117     self.fail('%s : %s' % (msg, diff))
    118 
    119 
    120 def NormalizeNumberFields(pb):
    121   """Normalizes types and precisions of number fields in a protocol buffer.
    122 
    123   Due to subtleties in the python protocol buffer implementation, it is possible
    124   for values to have different types and precision depending on whether they
    125   were set and retrieved directly or deserialized from a protobuf. This function
    126   normalizes integer values to ints and longs based on width, 32-bit floats to
    127   five digits of precision to account for python always storing them as 64-bit,
    128   and ensures doubles are floating point for when they're set to integers.
    129 
    130   Modifies pb in place. Recurses into nested objects.
    131 
    132   Args:
    133     pb: proto2 message.
    134 
    135   Returns:
    136     the given pb, modified in place.
    137   """
    138   for desc, values in pb.ListFields():
    139     is_repeated = True
    140     if desc.label is not descriptor.FieldDescriptor.LABEL_REPEATED:
    141       is_repeated = False
    142       values = [values]
    143 
    144     normalized_values = None
    145 
    146     # We force 32-bit values to int and 64-bit values to long to make
    147     # alternate implementations where the distinction is more significant
    148     # (e.g. the C++ implementation) simpler.
    149     if desc.type in (descriptor.FieldDescriptor.TYPE_INT64,
    150                      descriptor.FieldDescriptor.TYPE_UINT64,
    151                      descriptor.FieldDescriptor.TYPE_SINT64):
    152       normalized_values = [int(x) for x in values]
    153     elif desc.type in (descriptor.FieldDescriptor.TYPE_INT32,
    154                        descriptor.FieldDescriptor.TYPE_UINT32,
    155                        descriptor.FieldDescriptor.TYPE_SINT32,
    156                        descriptor.FieldDescriptor.TYPE_ENUM):
    157       normalized_values = [int(x) for x in values]
    158     elif desc.type == descriptor.FieldDescriptor.TYPE_FLOAT:
    159       normalized_values = [round(x, 6) for x in values]
    160     elif desc.type == descriptor.FieldDescriptor.TYPE_DOUBLE:
    161       normalized_values = [round(float(x), 7) for x in values]
    162 
    163     if normalized_values is not None:
    164       if is_repeated:
    165         pb.ClearField(desc.name)
    166         getattr(pb, desc.name).extend(normalized_values)
    167       else:
    168         setattr(pb, desc.name, normalized_values[0])
    169 
    170     if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE or
    171         desc.type == descriptor.FieldDescriptor.TYPE_GROUP):
    172       if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
    173           desc.message_type.has_options and
    174           desc.message_type.GetOptions().map_entry):
    175         # This is a map, only recurse if the values have a message type.
    176         if (desc.message_type.fields_by_number[2].type ==
    177             descriptor.FieldDescriptor.TYPE_MESSAGE):
    178           for v in six.itervalues(values):
    179             NormalizeNumberFields(v)
    180       else:
    181         for v in values:
    182           # recursive step
    183           NormalizeNumberFields(v)
    184 
    185   return pb
    186 
    187 
    188 def _IsMap(value):
    189   return isinstance(value, collections.Mapping)
    190 
    191 
    192 def _IsRepeatedContainer(value):
    193   if isinstance(value, six.string_types):
    194     return False
    195   try:
    196     iter(value)
    197     return True
    198   except TypeError:
    199     return False
    200 
    201 
    202 def ProtoEq(a, b):
    203   """Compares two proto2 objects for equality.
    204 
    205   Recurses into nested messages. Uses list (not set) semantics for comparing
    206   repeated fields, ie duplicates and order matter.
    207 
    208   Args:
    209     a: A proto2 message or a primitive.
    210     b: A proto2 message or a primitive.
    211 
    212   Returns:
    213     `True` if the messages are equal.
    214   """
    215   def Format(pb):
    216     """Returns a dictionary or unchanged pb bases on its type.
    217 
    218     Specifically, this function returns a dictionary that maps tag
    219     number (for messages) or element index (for repeated fields) to
    220     value, or just pb unchanged if it's neither.
    221 
    222     Args:
    223       pb: A proto2 message or a primitive.
    224     Returns:
    225       A dict or unchanged pb.
    226     """
    227     if isinstance(pb, message.Message):
    228       return dict((desc.number, value) for desc, value in pb.ListFields())
    229     elif _IsMap(pb):
    230       return dict(pb.items())
    231     elif _IsRepeatedContainer(pb):
    232       return dict(enumerate(list(pb)))
    233     else:
    234       return pb
    235 
    236   a, b = Format(a), Format(b)
    237 
    238   # Base case
    239   if not isinstance(a, dict) or not isinstance(b, dict):
    240     return a == b
    241 
    242   # This list performs double duty: it compares two messages by tag value *or*
    243   # two repeated fields by element, in order. the magic is in the format()
    244   # function, which converts them both to the same easily comparable format.
    245   for tag in sorted(set(a.keys()) | set(b.keys())):
    246     if tag not in a or tag not in b:
    247       return False
    248     else:
    249       # Recursive step
    250       if not ProtoEq(a[tag], b[tag]):
    251         return False
    252 
    253   # Didn't find any values that differed, so they're equal!
    254   return True
    255 
    256 
    257 class ProtoAssertions(object):
    258   """Mix this into a googletest.TestCase class to get proto2 assertions.
    259 
    260   Usage:
    261 
    262   class SomeTestCase(compare.ProtoAssertions, googletest.TestCase):
    263     ...
    264     def testSomething(self):
    265       ...
    266       self.assertProtoEqual(a, b)
    267 
    268   See module-level definitions for method documentation.
    269   """
    270 
    271   # pylint: disable=invalid-name
    272   def assertProtoEqual(self, *args, **kwargs):
    273     return assertProtoEqual(self, *args, **kwargs)
    274