Home | History | Annotate | Download | only in protorpclite
      1 #!/usr/bin/env python
      2 #
      3 # Copyright 2010 Google Inc.
      4 #
      5 # Licensed under the Apache License, Version 2.0 (the "License");
      6 # you may not use this file except in compliance with the License.
      7 # You may obtain a copy of the License at
      8 #
      9 #     http://www.apache.org/licenses/LICENSE-2.0
     10 #
     11 # Unless required by applicable law or agreed to in writing, software
     12 # distributed under the License is distributed on an "AS IS" BASIS,
     13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14 # See the License for the specific language governing permissions and
     15 # limitations under the License.
     16 #
     17 
     18 """Test utilities for message testing.
     19 
     20 Includes module interface test to ensure that public parts of module are
     21 correctly declared in __all__.
     22 
     23 Includes message types that correspond to those defined in
     24 services_test.proto.
     25 
     26 Includes additional test utilities to make sure encoding/decoding libraries
     27 conform.
     28 """
     29 import cgi
     30 import datetime
     31 import inspect
     32 import os
     33 import re
     34 import socket
     35 import types
     36 
     37 import six
     38 from six.moves import range  # pylint: disable=redefined-builtin
     39 import unittest2 as unittest
     40 
     41 from apitools.base.protorpclite import message_types
     42 from apitools.base.protorpclite import messages
     43 from apitools.base.protorpclite import util
     44 
     45 # Unicode of the word "Russian" in cyrillic.
     46 RUSSIAN = u'\u0440\u0443\u0441\u0441\u043a\u0438\u0439'
     47 
     48 # All characters binary value interspersed with nulls.
     49 BINARY = b''.join(six.int2byte(value) + b'\0' for value in range(256))
     50 
     51 
     52 class TestCase(unittest.TestCase):
     53 
     54     def assertRaisesWithRegexpMatch(self,
     55                                     exception,
     56                                     regexp,
     57                                     function,
     58                                     *params,
     59                                     **kwargs):
     60         """Check that exception is raised and text matches regular expression.
     61 
     62         Args:
     63           exception: Exception type that is expected.
     64           regexp: String regular expression that is expected in error message.
     65           function: Callable to test.
     66           params: Parameters to forward to function.
     67           kwargs: Keyword arguments to forward to function.
     68         """
     69         try:
     70             function(*params, **kwargs)
     71             self.fail('Expected exception %s was not raised' %
     72                       exception.__name__)
     73         except exception as err:
     74             match = bool(re.match(regexp, str(err)))
     75             self.assertTrue(match, 'Expected match "%s", found "%s"' % (regexp,
     76                                                                         err))
     77 
     78     def assertHeaderSame(self, header1, header2):
     79         """Check that two HTTP headers are the same.
     80 
     81         Args:
     82           header1: Header value string 1.
     83           header2: header value string 2.
     84         """
     85         value1, params1 = cgi.parse_header(header1)
     86         value2, params2 = cgi.parse_header(header2)
     87         self.assertEqual(value1, value2)
     88         self.assertEqual(params1, params2)
     89 
     90     def assertIterEqual(self, iter1, iter2):
     91         """Check that two iterators or iterables are equal independent of order.
     92 
     93         Similar to Python 2.7 assertItemsEqual.  Named differently in order to
     94         avoid potential conflict.
     95 
     96         Args:
     97           iter1: An iterator or iterable.
     98           iter2: An iterator or iterable.
     99         """
    100         list1 = list(iter1)
    101         list2 = list(iter2)
    102 
    103         unmatched1 = list()
    104 
    105         while list1:
    106             item1 = list1[0]
    107             del list1[0]
    108             for index in range(len(list2)):
    109                 if item1 == list2[index]:
    110                     del list2[index]
    111                     break
    112             else:
    113                 unmatched1.append(item1)
    114 
    115         error_message = []
    116         for item in unmatched1:
    117             error_message.append(
    118                 '  Item from iter1 not found in iter2: %r' % item)
    119         for item in list2:
    120             error_message.append(
    121                 '  Item from iter2 not found in iter1: %r' % item)
    122         if error_message:
    123             self.fail('Collections not equivalent:\n' +
    124                       '\n'.join(error_message))
    125 
    126 
    127 class ModuleInterfaceTest(object):
    128     """Test to ensure module interface is carefully constructed.
    129 
    130     A module interface is the set of public objects listed in the
    131     module __all__ attribute. Modules that that are considered public
    132     should have this interface carefully declared. At all times, the
    133     __all__ attribute should have objects intended to be publically
    134     used and all other objects in the module should be considered
    135     unused.
    136 
    137     Protected attributes (those beginning with '_') and other imported
    138     modules should not be part of this set of variables. An exception
    139     is for variables that begin and end with '__' which are implicitly
    140     part of the interface (eg. __name__, __file__, __all__ itself,
    141     etc.).
    142 
    143     Modules that are imported in to the tested modules are an
    144     exception and may be left out of the __all__ definition. The test
    145     is done by checking the value of what would otherwise be a public
    146     name and not allowing it to be exported if it is an instance of a
    147     module. Modules that are explicitly exported are for the time
    148     being not permitted.
    149 
    150     To use this test class a module should define a new class that
    151     inherits first from ModuleInterfaceTest and then from
    152     test_util.TestCase. No other tests should be added to this test
    153     case, making the order of inheritance less important, but if setUp
    154     for some reason is overidden, it is important that
    155     ModuleInterfaceTest is first in the list so that its setUp method
    156     is invoked.
    157 
    158     Multiple inheritance is required so that ModuleInterfaceTest is
    159     not itself a test, and is not itself executed as one.
    160 
    161     The test class is expected to have the following class attributes
    162     defined:
    163 
    164       MODULE: A reference to the module that is being validated for interface
    165         correctness.
    166 
    167     Example:
    168       Module definition (hello.py):
    169 
    170         import sys
    171 
    172         __all__ = ['hello']
    173 
    174         def _get_outputter():
    175           return sys.stdout
    176 
    177         def hello():
    178           _get_outputter().write('Hello\n')
    179 
    180       Test definition:
    181 
    182         import unittest
    183         from protorpc import test_util
    184 
    185         import hello
    186 
    187         class ModuleInterfaceTest(test_util.ModuleInterfaceTest,
    188                                   test_util.TestCase):
    189 
    190           MODULE = hello
    191 
    192 
    193         class HelloTest(test_util.TestCase):
    194           ... Test 'hello' module ...
    195 
    196 
    197         if __name__ == '__main__':
    198           unittest.main()
    199 
    200     """
    201 
    202     def setUp(self):
    203         """Set up makes sure that MODULE and IMPORTED_MODULES is defined.
    204 
    205         This is a basic configuration test for the test itself so does not
    206         get it's own test case.
    207         """
    208         if not hasattr(self, 'MODULE'):
    209             self.fail(
    210                 "You must define 'MODULE' on ModuleInterfaceTest sub-class "
    211                 "%s." % type(self).__name__)
    212 
    213     def testAllExist(self):
    214         """Test that all attributes defined in __all__ exist."""
    215         missing_attributes = []
    216         for attribute in self.MODULE.__all__:
    217             if not hasattr(self.MODULE, attribute):
    218                 missing_attributes.append(attribute)
    219         if missing_attributes:
    220             self.fail('%s of __all__ are not defined in module.' %
    221                       missing_attributes)
    222 
    223     def testAllExported(self):
    224         """Test that all public attributes not imported are in __all__."""
    225         missing_attributes = []
    226         for attribute in dir(self.MODULE):
    227             if not attribute.startswith('_'):
    228                 if (attribute not in self.MODULE.__all__ and
    229                         not isinstance(getattr(self.MODULE, attribute),
    230                                        types.ModuleType) and
    231                         attribute != 'with_statement'):
    232                     missing_attributes.append(attribute)
    233         if missing_attributes:
    234             self.fail('%s are not modules and not defined in __all__.' %
    235                       missing_attributes)
    236 
    237     def testNoExportedProtectedVariables(self):
    238         """Test that there are no protected variables listed in __all__."""
    239         protected_variables = []
    240         for attribute in self.MODULE.__all__:
    241             if attribute.startswith('_'):
    242                 protected_variables.append(attribute)
    243         if protected_variables:
    244             self.fail('%s are protected variables and may not be exported.' %
    245                       protected_variables)
    246 
    247     def testNoExportedModules(self):
    248         """Test that no modules exist in __all__."""
    249         exported_modules = []
    250         for attribute in self.MODULE.__all__:
    251             try:
    252                 value = getattr(self.MODULE, attribute)
    253             except AttributeError:
    254                 # This is a different error case tested for in testAllExist.
    255                 pass
    256             else:
    257                 if isinstance(value, types.ModuleType):
    258                     exported_modules.append(attribute)
    259         if exported_modules:
    260             self.fail('%s are modules and may not be exported.' %
    261                       exported_modules)
    262 
    263 
class NestedMessage(messages.Message):
    """Simple message that gets nested in another message.

    Its only field is required, so an instance with a_value unset is
    invalid; ProtoConformanceTestBase expects encoding such an instance to
    raise messages.ValidationError.
    """

    a_value = messages.StringField(1, required=True)
    268 
    269 
class HasNestedMessage(messages.Message):
    """Message that nests NestedMessage both singly and as a repeated field."""

    # A single optional sub-message.
    nested = messages.MessageField(NestedMessage, 1)
    # A list of sub-messages.
    repeated_nested = messages.MessageField(NestedMessage, 2, repeated=True)
    275 
    276 
class HasDefault(messages.Message):
    """Has a default value.

    Used to check that an unassigned default is not encoded, while the same
    value assigned explicitly is.
    """

    a_value = messages.StringField(1, default=u'a default')
    281 
    282 
class OptionalMessage(messages.Message):
    """Contains one optional field of each scalar variant plus an enum field.

    Field number 9 is not used.
    """

    class SimpleEnum(messages.Enum):
        """Simple enumeration type."""
        VAL1 = 1
        VAL2 = 2

    double_value = messages.FloatField(1, variant=messages.Variant.DOUBLE)
    float_value = messages.FloatField(2, variant=messages.Variant.FLOAT)
    int64_value = messages.IntegerField(3, variant=messages.Variant.INT64)
    uint64_value = messages.IntegerField(4, variant=messages.Variant.UINT64)
    int32_value = messages.IntegerField(5, variant=messages.Variant.INT32)
    bool_value = messages.BooleanField(6, variant=messages.Variant.BOOL)
    string_value = messages.StringField(7, variant=messages.Variant.STRING)
    bytes_value = messages.BytesField(8, variant=messages.Variant.BYTES)
    enum_value = messages.EnumField(SimpleEnum, 10)
    300 
    301 
class RepeatedMessage(messages.Message):
    """Contains all message types as repeated fields.

    Mirrors OptionalMessage field-for-field (same names, numbers and
    variants, number 9 unused), but every field is repeated.  It defines its
    own SimpleEnum with the same values.
    """

    class SimpleEnum(messages.Enum):
        """Simple enumeration type."""
        VAL1 = 1
        VAL2 = 2

    double_value = messages.FloatField(1,
                                       variant=messages.Variant.DOUBLE,
                                       repeated=True)
    float_value = messages.FloatField(2,
                                      variant=messages.Variant.FLOAT,
                                      repeated=True)
    int64_value = messages.IntegerField(3,
                                        variant=messages.Variant.INT64,
                                        repeated=True)
    uint64_value = messages.IntegerField(4,
                                         variant=messages.Variant.UINT64,
                                         repeated=True)
    int32_value = messages.IntegerField(5,
                                        variant=messages.Variant.INT32,
                                        repeated=True)
    bool_value = messages.BooleanField(6,
                                       variant=messages.Variant.BOOL,
                                       repeated=True)
    string_value = messages.StringField(7,
                                        variant=messages.Variant.STRING,
                                        repeated=True)
    bytes_value = messages.BytesField(8,
                                      variant=messages.Variant.BYTES,
                                      repeated=True)
    enum_value = messages.EnumField(SimpleEnum,
                                    10,
                                    repeated=True)
    337 
    338 
class HasOptionalNestedMessage(messages.Message):
    """Message that nests OptionalMessage singly and as a repeated field.

    Used to check encoding of nested messages that have no fields set.
    """

    nested = messages.MessageField(OptionalMessage, 1)
    repeated_nested = messages.MessageField(OptionalMessage, 2, repeated=True)
    343 
    344 
    345 # pylint:disable=anomalous-unicode-escape-in-string
    346 class ProtoConformanceTestBase(object):
    347     """Protocol conformance test base class.
    348 
    349     Each supported protocol should implement two methods that support encoding
    350     and decoding of Message objects in that format:
    351 
    352       encode_message(message) - Serialize to encoding.
    353       encode_message(message, encoded_message) - Deserialize from encoding.
    354 
    355     Tests for the modules where these functions are implemented should extend
    356     this class in order to support basic behavioral expectations.  This ensures
    357     that protocols correctly encode and decode message transparently to the
    358     caller.
    359 
    360     In order to support these test, the base class should also extend
    361     the TestCase class and implement the following class attributes
    362     which define the encoded version of certain protocol buffers:
    363 
    364       encoded_partial:
    365         <OptionalMessage
    366           double_value: 1.23
    367           int64_value: -100000000000
    368           string_value: u"a string"
    369           enum_value: OptionalMessage.SimpleEnum.VAL2
    370           >
    371 
    372       encoded_full:
    373         <OptionalMessage
    374           double_value: 1.23
    375           float_value: -2.5
    376           int64_value: -100000000000
    377           uint64_value: 102020202020
    378           int32_value: 1020
    379           bool_value: true
    380           string_value: u"a string\u044f"
    381           bytes_value: b"a bytes\xff\xfe"
    382           enum_value: OptionalMessage.SimpleEnum.VAL2
    383           >
    384 
    385       encoded_repeated:
    386         <RepeatedMessage
    387           double_value: [1.23, 2.3]
    388           float_value: [-2.5, 0.5]
    389           int64_value: [-100000000000, 20]
    390           uint64_value: [102020202020, 10]
    391           int32_value: [1020, 718]
    392           bool_value: [true, false]
    393           string_value: [u"a string\u044f", u"another string"]
    394           bytes_value: [b"a bytes\xff\xfe", b"another bytes"]
    395           enum_value: [OptionalMessage.SimpleEnum.VAL2,
    396                        OptionalMessage.SimpleEnum.VAL 1]
    397           >
    398 
    399       encoded_nested:
    400         <HasNestedMessage
    401           nested: <NestedMessage
    402             a_value: "a string"
    403             >
    404           >
    405 
    406       encoded_repeated_nested:
    407         <HasNestedMessage
    408           repeated_nested: [
    409               <NestedMessage a_value: "a string">,
    410               <NestedMessage a_value: "another string">
    411             ]
    412           >
    413 
    414       unexpected_tag_message:
    415         An encoded message that has an undefined tag or number in the stream.
    416 
    417       encoded_default_assigned:
    418         <HasDefault
    419           a_value: "a default"
    420           >
    421 
    422       encoded_nested_empty:
    423         <HasOptionalNestedMessage
    424           nested: <OptionalMessage>
    425           >
    426 
    427       encoded_invalid_enum:
    428         <OptionalMessage
    429           enum_value: (invalid value for serialization type)
    430           >
    431     """
    432 
    433     encoded_empty_message = ''
    434 
    435     def testEncodeInvalidMessage(self):
    436         message = NestedMessage()
    437         self.assertRaises(messages.ValidationError,
    438                           self.PROTOLIB.encode_message, message)
    439 
    440     def CompareEncoded(self, expected_encoded, actual_encoded):
    441         """Compare two encoded protocol values.
    442 
    443         Can be overridden by sub-classes to special case comparison.
    444         For example, to eliminate white space from output that is not
    445         relevant to encoding.
    446 
    447         Args:
    448           expected_encoded: Expected string encoded value.
    449           actual_encoded: Actual string encoded value.
    450         """
    451         self.assertEquals(expected_encoded, actual_encoded)
    452 
    453     def EncodeDecode(self, encoded, expected_message):
    454         message = self.PROTOLIB.decode_message(type(expected_message), encoded)
    455         self.assertEquals(expected_message, message)
    456         self.CompareEncoded(encoded, self.PROTOLIB.encode_message(message))
    457 
    458     def testEmptyMessage(self):
    459         self.EncodeDecode(self.encoded_empty_message, OptionalMessage())
    460 
    461     def testPartial(self):
    462         """Test message with a few values set."""
    463         message = OptionalMessage()
    464         message.double_value = 1.23
    465         message.int64_value = -100000000000
    466         message.int32_value = 1020
    467         message.string_value = u'a string'
    468         message.enum_value = OptionalMessage.SimpleEnum.VAL2
    469 
    470         self.EncodeDecode(self.encoded_partial, message)
    471 
    472     def testFull(self):
    473         """Test all types."""
    474         message = OptionalMessage()
    475         message.double_value = 1.23
    476         message.float_value = -2.5
    477         message.int64_value = -100000000000
    478         message.uint64_value = 102020202020
    479         message.int32_value = 1020
    480         message.bool_value = True
    481         message.string_value = u'a string\u044f'
    482         message.bytes_value = b'a bytes\xff\xfe'
    483         message.enum_value = OptionalMessage.SimpleEnum.VAL2
    484 
    485         self.EncodeDecode(self.encoded_full, message)
    486 
    487     def testRepeated(self):
    488         """Test repeated fields."""
    489         message = RepeatedMessage()
    490         message.double_value = [1.23, 2.3]
    491         message.float_value = [-2.5, 0.5]
    492         message.int64_value = [-100000000000, 20]
    493         message.uint64_value = [102020202020, 10]
    494         message.int32_value = [1020, 718]
    495         message.bool_value = [True, False]
    496         message.string_value = [u'a string\u044f', u'another string']
    497         message.bytes_value = [b'a bytes\xff\xfe', b'another bytes']
    498         message.enum_value = [RepeatedMessage.SimpleEnum.VAL2,
    499                               RepeatedMessage.SimpleEnum.VAL1]
    500 
    501         self.EncodeDecode(self.encoded_repeated, message)
    502 
    503     def testNested(self):
    504         """Test nested messages."""
    505         nested_message = NestedMessage()
    506         nested_message.a_value = u'a string'
    507 
    508         message = HasNestedMessage()
    509         message.nested = nested_message
    510 
    511         self.EncodeDecode(self.encoded_nested, message)
    512 
    513     def testRepeatedNested(self):
    514         """Test repeated nested messages."""
    515         nested_message1 = NestedMessage()
    516         nested_message1.a_value = u'a string'
    517         nested_message2 = NestedMessage()
    518         nested_message2.a_value = u'another string'
    519 
    520         message = HasNestedMessage()
    521         message.repeated_nested = [nested_message1, nested_message2]
    522 
    523         self.EncodeDecode(self.encoded_repeated_nested, message)
    524 
    525     def testStringTypes(self):
    526         """Test that encoding str on StringField works."""
    527         message = OptionalMessage()
    528         message.string_value = 'Latin'
    529         self.EncodeDecode(self.encoded_string_types, message)
    530 
    531     def testEncodeUninitialized(self):
    532         """Test that cannot encode uninitialized message."""
    533         required = NestedMessage()
    534         self.assertRaisesWithRegexpMatch(messages.ValidationError,
    535                                          "Message NestedMessage is missing "
    536                                          "required field a_value",
    537                                          self.PROTOLIB.encode_message,
    538                                          required)
    539 
    540     def testUnexpectedField(self):
    541         """Test decoding and encoding unexpected fields."""
    542         loaded_message = self.PROTOLIB.decode_message(
    543             OptionalMessage, self.unexpected_tag_message)
    544         # Message should be equal to an empty message, since unknown
    545         # values aren't included in equality.
    546         self.assertEquals(OptionalMessage(), loaded_message)
    547         # Verify that the encoded message matches the source, including the
    548         # unknown value.
    549         self.assertEquals(self.unexpected_tag_message,
    550                           self.PROTOLIB.encode_message(loaded_message))
    551 
    552     def testDoNotSendDefault(self):
    553         """Test that default is not sent when nothing is assigned."""
    554         self.EncodeDecode(self.encoded_empty_message, HasDefault())
    555 
    556     def testSendDefaultExplicitlyAssigned(self):
    557         """Test that default is sent when explcitly assigned."""
    558         message = HasDefault()
    559 
    560         message.a_value = HasDefault.a_value.default
    561 
    562         self.EncodeDecode(self.encoded_default_assigned, message)
    563 
    564     def testEncodingNestedEmptyMessage(self):
    565         """Test encoding a nested empty message."""
    566         message = HasOptionalNestedMessage()
    567         message.nested = OptionalMessage()
    568 
    569         self.EncodeDecode(self.encoded_nested_empty, message)
    570 
    571     def testEncodingRepeatedNestedEmptyMessage(self):
    572         """Test encoding a nested empty message."""
    573         message = HasOptionalNestedMessage()
    574         message.repeated_nested = [OptionalMessage(), OptionalMessage()]
    575 
    576         self.EncodeDecode(self.encoded_repeated_nested_empty, message)
    577 
    578     def testContentType(self):
    579         self.assertTrue(isinstance(self.PROTOLIB.CONTENT_TYPE, str))
    580 
    581     def testDecodeInvalidEnumType(self):
    582         self.assertRaisesWithRegexpMatch(messages.DecodeError,
    583                                          'Invalid enum value ',
    584                                          self.PROTOLIB.decode_message,
    585                                          OptionalMessage,
    586                                          self.encoded_invalid_enum)
    587 
    588     def testDateTimeNoTimeZone(self):
    589         """Test that DateTimeFields are encoded/decoded correctly."""
    590 
    591         class MyMessage(messages.Message):
    592             value = message_types.DateTimeField(1)
    593 
    594         value = datetime.datetime(2013, 1, 3, 11, 36, 30, 123000)
    595         message = MyMessage(value=value)
    596         decoded = self.PROTOLIB.decode_message(
    597             MyMessage, self.PROTOLIB.encode_message(message))
    598         self.assertEquals(decoded.value, value)
    599 
    600     def testDateTimeWithTimeZone(self):
    601         """Test DateTimeFields with time zones."""
    602 
    603         class MyMessage(messages.Message):
    604             value = message_types.DateTimeField(1)
    605 
    606         value = datetime.datetime(2013, 1, 3, 11, 36, 30, 123000,
    607                                   util.TimeZoneOffset(8 * 60))
    608         message = MyMessage(value=value)
    609         decoded = self.PROTOLIB.decode_message(
    610             MyMessage, self.PROTOLIB.encode_message(message))
    611         self.assertEquals(decoded.value, value)
    612 
    613 
    614 def pick_unused_port():
    615     """Find an unused port to use in tests.
    616 
    617       Derived from Damon Kohlers example:
    618 
    619         http://code.activestate.com/recipes/531822-pick-unused-port
    620     """
    621     temp = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    622     try:
    623         temp.bind(('localhost', 0))
    624         port = temp.getsockname()[1]
    625     finally:
    626         temp.close()
    627     return port
    628 
    629 
    630 def get_module_name(module_attribute):
    631     """Get the module name.
    632 
    633     Args:
    634       module_attribute: An attribute of the module.
    635 
    636     Returns:
    637       The fully qualified module name or simple module name where
    638       'module_attribute' is defined if the module name is "__main__".
    639     """
    640     if module_attribute.__module__ == '__main__':
    641         module_file = inspect.getfile(module_attribute)
    642         default = os.path.basename(module_file).split('.')[0]
    643         return default
    644     return module_attribute.__module__
    645