Home | History | Annotate | Download | only in hosts
      1 # Copyright 2016 The Chromium OS Authors. All rights reserved.
      2 # Use of this source code is governed by a BSD-style license that can be
      3 # found in the LICENSE file.
      4 
      5 import cStringIO
      6 import inspect
      7 import json
      8 import unittest
      9 
     10 import common
     11 from autotest_lib.server.hosts import host_info
     12 
     13 
     14 class HostInfoTest(unittest.TestCase):
     15     """Tests the non-trivial attributes of HostInfo."""
     16 
     17     def setUp(self):
     18         self.info = host_info.HostInfo()
     19 
     20     def test_info_comparison_to_wrong_type(self):
     21         """Comparing HostInfo to a different type always returns False."""
     22         self.assertNotEqual(host_info.HostInfo(), 42)
     23         self.assertNotEqual(host_info.HostInfo(), None)
     24         # equality and non-equality are unrelated by the data model.
     25         self.assertFalse(host_info.HostInfo() == 42)
     26         self.assertFalse(host_info.HostInfo() == None)
     27 
     28 
     29     def test_empty_infos_are_equal(self):
     30         """Tests that empty HostInfo objects are considered equal."""
     31         self.assertEqual(host_info.HostInfo(), host_info.HostInfo())
     32         # equality and non-equality are unrelated by the data model.
     33         self.assertFalse(host_info.HostInfo() != host_info.HostInfo())
     34 
     35 
     36     def test_non_trivial_infos_are_equal(self):
     37         """Tests that the most complicated infos are correctly stated equal."""
     38         info1 = host_info.HostInfo(
     39                 labels=['label1', 'label2', 'label1'],
     40                 attributes={'attrib1': None, 'attrib2': 'val2'})
     41         info2 = host_info.HostInfo(
     42                 labels=['label1', 'label2', 'label1'],
     43                 attributes={'attrib1': None, 'attrib2': 'val2'})
     44         self.assertEqual(info1, info2)
     45         # equality and non-equality are unrelated by the data model.
     46         self.assertFalse(info1 != info2)
     47 
     48 
     49     def test_non_equal_infos(self):
     50         """Tests that HostInfo objects with different information are unequal"""
     51         info1 = host_info.HostInfo(labels=['label'])
     52         info2 = host_info.HostInfo(attributes={'attrib': 'value'})
     53         self.assertNotEqual(info1, info2)
     54         # equality and non-equality are unrelated by the data model.
     55         self.assertFalse(info1 == info2)
     56 
     57 
     58     def test_build_needs_prefix(self):
     59         """The build prefix is of the form '<type>-version:'"""
     60         self.info.labels = ['cros-version', 'ab-version', 'testbed-version',
     61                             'fwrw-version', 'fwro-version']
     62         self.assertIsNone(self.info.build)
     63 
     64 
     65     def test_build_prefix_must_be_anchored(self):
     66         """Ensure that build ignores prefixes occuring mid-string."""
     67         self.info.labels = ['not-at-start-cros-version:cros1',
     68                             'not-at-start-ab-version:ab1',
     69                             'not-at-start-testbed-version:testbed1']
     70         self.assertIsNone(self.info.build)
     71 
     72 
     73     def test_build_ignores_firmware(self):
     74         """build attribute should ignore firmware versions."""
     75         self.info.labels = ['fwrw-version:fwrw1', 'fwro-version:fwro1']
     76         self.assertIsNone(self.info.build)
     77 
     78 
     79     def test_build_returns_first_match(self):
     80         """When multiple labels match, first one should be used as build."""
     81         self.info.labels = ['cros-version:cros1', 'cros-version:cros2']
     82         self.assertEqual(self.info.build, 'cros1')
     83         self.info.labels = ['ab-version:ab1', 'ab-version:ab2']
     84         self.assertEqual(self.info.build, 'ab1')
     85         self.info.labels = ['testbed-version:tb1', 'testbed-version:tb2']
     86         self.assertEqual(self.info.build, 'tb1')
     87 
     88 
     89     def test_build_prefer_cros_over_others(self):
     90         """When multiple versions are available, prefer cros."""
     91         self.info.labels = ['testbed-version:tb1', 'ab-version:ab1',
     92                             'cros-version:cros1']
     93         self.assertEqual(self.info.build, 'cros1')
     94         self.info.labels = ['cros-version:cros1', 'ab-version:ab1',
     95                             'testbed-version:tb1']
     96         self.assertEqual(self.info.build, 'cros1')
     97 
     98 
     99     def test_build_prefer_ab_over_testbed(self):
    100         """When multiple versions are available, prefer ab over testbed."""
    101         self.info.labels = ['testbed-version:tb1', 'ab-version:ab1']
    102         self.assertEqual(self.info.build, 'ab1')
    103         self.info.labels = ['ab-version:ab1', 'testbed-version:tb1']
    104         self.assertEqual(self.info.build, 'ab1')
    105 
    106 
    107     def test_os_no_match(self):
    108         """Use proper prefix to search for os information."""
    109         self.info.labels = ['something_else', 'cros-version:hana',
    110                             'os_without_colon']
    111         self.assertEqual(self.info.os, '')
    112 
    113 
    114     def test_os_returns_first_match(self):
    115         """Return the first matching os label."""
    116         self.info.labels = ['os:linux', 'os:windows', 'os_corrupted_label']
    117         self.assertEqual(self.info.os, 'linux')
    118 
    119 
    120     def test_board_no_match(self):
    121         """Use proper prefix to search for board information."""
    122         self.info.labels = ['something_else', 'cros-version:hana', 'os:blah',
    123                             'board_my_board_no_colon']
    124         self.assertEqual(self.info.board, '')
    125 
    126 
    127     def test_board_returns_first_match(self):
    128         """Return the first matching board label."""
    129         self.info.labels = ['board_corrupted', 'board:walk', 'board:bored']
    130         self.assertEqual(self.info.board, 'walk')
    131 
    132 
    133     def test_pools_no_match(self):
    134         """Use proper prefix to search for pool information."""
    135         self.info.labels = ['something_else', 'cros-version:hana', 'os:blah',
    136                             'board_my_board_no_colon', 'board:my_board']
    137         self.assertEqual(self.info.pools, set())
    138 
    139 
    140     def test_pools_returns_all_matches(self):
    141         """Return all matching pool labels."""
    142         self.info.labels = ['board_corrupted', 'board:walk', 'board:bored',
    143                             'pool:first_pool', 'pool:second_pool']
    144         self.assertEqual(self.info.pools, {'second_pool', 'first_pool'})
    145 
    146 
    147     def test_str(self):
    148         """Sanity checks the __str__ implementation."""
    149         info = host_info.HostInfo(labels=['a'], attributes={'b': 2})
    150         self.assertEqual(str(info),
    151                          "HostInfo[Labels: ['a'], Attributes: {'b': 2}]")
    152 
    153 
    154     def test_clear_version_labels_no_labels(self):
    155         """When no version labels exit, do nothing for clear_version_labels."""
    156         original_labels = ['board:something', 'os:something_else',
    157                            'pool:mypool', 'ab-version-corrupted:blah',
    158                            'cros-version']
    159         self.info.labels = list(original_labels)
    160         self.info.clear_version_labels()
    161         self.assertListEqual(self.info.labels, original_labels)
    162 
    163 
    164     def test_clear_all_version_labels(self):
    165         """Clear each recognized type of version label."""
    166         original_labels = ['extra_label', 'cros-version:cr1', 'ab-version:ab1',
    167                            'testbed-version:tb1']
    168         self.info.labels = list(original_labels)
    169         self.info.clear_version_labels()
    170         self.assertListEqual(self.info.labels, ['extra_label'])
    171 
    172     def test_clear_all_version_label_prefixes(self):
    173         """Clear each recognized type of version label with empty value."""
    174         original_labels = ['extra_label', 'cros-version:', 'ab-version:',
    175                            'testbed-version:']
    176         self.info.labels = list(original_labels)
    177         self.info.clear_version_labels()
    178         self.assertListEqual(self.info.labels, ['extra_label'])
    179 
    180 
    181     def test_set_version_labels_updates_in_place(self):
    182         """Update version label in place if prefix already exists."""
    183         self.info.labels = ['extra', 'cros-version:X', 'ab-version:Y']
    184         self.info.set_version_label('cros-version', 'Z')
    185         self.assertListEqual(self.info.labels, ['extra', 'cros-version:Z',
    186                                                 'ab-version:Y'])
    187 
    188     def test_set_version_labels_appends(self):
    189         """Append a new version label if the prefix doesn't exist."""
    190         self.info.labels = ['extra', 'ab-version:Y']
    191         self.info.set_version_label('cros-version', 'Z')
    192         self.assertListEqual(self.info.labels, ['extra', 'ab-version:Y',
    193                                                 'cros-version:Z'])
    194 
    195 
    196 class InMemoryHostInfoStoreTest(unittest.TestCase):
    197     """Basic tests for CachingHostInfoStore using InMemoryHostInfoStore."""
    198 
    199     def setUp(self):
    200         self.store = host_info.InMemoryHostInfoStore()
    201 
    202 
    203     def _verify_host_info_data(self, host_info, labels, attributes):
    204         """Verifies the data in the given host_info."""
    205         self.assertListEqual(host_info.labels, labels)
    206         self.assertDictEqual(host_info.attributes, attributes)
    207 
    208 
    209     def test_first_get_refreshes_cache(self):
    210         """Test that the first call to get gets the data from store."""
    211         self.store.info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
    212         got = self.store.get()
    213         self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})
    214 
    215 
    216     def test_repeated_get_returns_from_cache(self):
    217         """Tests that repeated calls to get do not refresh cache."""
    218         self.store.info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
    219         got = self.store.get()
    220         self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})
    221 
    222         self.store.info = host_info.HostInfo(['label1', 'label2'], {})
    223         got = self.store.get()
    224         self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})
    225 
    226 
    227     def test_get_uncached_always_refreshes_cache(self):
    228         """Tests that calling get_uncached always refreshes the cache."""
    229         self.store.info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
    230         got = self.store.get(force_refresh=True)
    231         self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})
    232 
    233         self.store.info = host_info.HostInfo(['label1', 'label2'], {})
    234         got = self.store.get(force_refresh=True)
    235         self._verify_host_info_data(got, ['label1', 'label2'], {})
    236 
    237 
    238     def test_commit(self):
    239         """Test that commit sends data to store."""
    240         info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
    241         self._verify_host_info_data(self.store.info, [], {})
    242         self.store.commit(info)
    243         self._verify_host_info_data(self.store.info, ['label1'],
    244                                     {'attrib1': 'val1'})
    245 
    246 
    247     def test_commit_then_get(self):
    248         """Test a commit-get roundtrip."""
    249         got = self.store.get()
    250         self._verify_host_info_data(got, [], {})
    251 
    252         info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
    253         self.store.commit(info)
    254         got = self.store.get()
    255         self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})
    256 
    257 
    258     def test_commit_then_get_uncached(self):
    259         """Test a commit-get_uncached roundtrip."""
    260         got = self.store.get()
    261         self._verify_host_info_data(got, [], {})
    262 
    263         info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
    264         self.store.commit(info)
    265         got = self.store.get(force_refresh=True)
    266         self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})
    267 
    268 
    269     def test_commit_deepcopies_data(self):
    270         """Once commited, changes to HostInfo don't corrupt the store."""
    271         info = host_info.HostInfo(['label1'], {'attrib1': {'key1': 'data1'}})
    272         self.store.commit(info)
    273         info.labels.append('label2')
    274         info.attributes['attrib1']['key1'] = 'data2'
    275         self._verify_host_info_data(self.store.info,
    276                                     ['label1'], {'attrib1': {'key1': 'data1'}})
    277 
    278 
    279     def test_get_returns_deepcopy(self):
    280         """The cached object is protected from |get| caller modifications."""
    281         self.store.info = host_info.HostInfo(['label1'],
    282                                              {'attrib1': {'key1': 'data1'}})
    283         got = self.store.get()
    284         self._verify_host_info_data(got,
    285                                     ['label1'], {'attrib1': {'key1': 'data1'}})
    286         got.labels.append('label2')
    287         got.attributes['attrib1']['key1'] = 'data2'
    288         got = self.store.get()
    289         self._verify_host_info_data(got,
    290                                     ['label1'], {'attrib1': {'key1': 'data1'}})
    291 
    292 
    293     def test_str(self):
    294         """Sanity tests __str__ implementation."""
    295         self.store.info = host_info.HostInfo(['label1'],
    296                                              {'attrib1': {'key1': 'data1'}})
    297         self.assertEqual(str(self.store),
    298                          'InMemoryHostInfoStore[%s]' % self.store.info)
    299 
    300 
    301 class ExceptionRaisingStore(host_info.CachingHostInfoStore):
    302     """A test class that always raises on refresh / commit."""
    303 
    304     def __init__(self):
    305         super(ExceptionRaisingStore, self).__init__()
    306         self.refresh_raises = True
    307         self.commit_raises = True
    308 
    309 
    310     def _refresh_impl(self):
    311         if self.refresh_raises:
    312             raise host_info.StoreError('no can do')
    313         return host_info.HostInfo()
    314 
    315     def _commit_impl(self, _):
    316         if self.commit_raises:
    317             raise host_info.StoreError('wont wont wont')
    318 
    319 
    320 class CachingHostInfoStoreErrorTest(unittest.TestCase):
    321     """Tests error behaviours of CachingHostInfoStore."""
    322 
    323     def setUp(self):
    324         self.store = ExceptionRaisingStore()
    325 
    326 
    327     def test_failed_refresh_cleans_cache(self):
    328         """Sanity checks return values when refresh raises."""
    329         with self.assertRaises(host_info.StoreError):
    330             self.store.get()
    331         # Since |get| hit an error, a subsequent get should again hit the store.
    332         with self.assertRaises(host_info.StoreError):
    333             self.store.get()
    334 
    335 
    336     def test_failed_commit_cleans_cache(self):
    337         """Check that a failed commit cleanes cache."""
    338         # Let's initialize the store without errors.
    339         self.store.refresh_raises = False
    340         self.store.get(force_refresh=True)
    341         self.store.refresh_raises = True
    342 
    343         with self.assertRaises(host_info.StoreError):
    344             self.store.commit(host_info.HostInfo())
    345         # Since |commit| hit an error, a subsequent get should again hit the
    346         # store.
    347         with self.assertRaises(host_info.StoreError):
    348             self.store.get()
    349 
    350 
    351 class GetStoreFromMachineTest(unittest.TestCase):
    352     """Tests the get_store_from_machine function."""
    353 
    354     def test_machine_is_dict(self):
    355         """We extract the store when machine is a dict."""
    356         machine = {
    357                 'something': 'else',
    358                 'host_info_store': 5
    359         }
    360         self.assertEqual(host_info.get_store_from_machine(machine), 5)
    361 
    362 
    363     def test_machine_is_string(self):
    364         """We return a trivial store when machine is a string."""
    365         machine = 'hostname'
    366         self.assertTrue(isinstance(host_info.get_store_from_machine(machine),
    367                                    host_info.InMemoryHostInfoStore))
    368 
    369 
    370 class HostInfoJsonSerializationTestCase(unittest.TestCase):
    371     """Tests the json_serialize and json_deserialize functions."""
    372 
    373     CURRENT_SERIALIZATION_VERSION = host_info._CURRENT_SERIALIZATION_VERSION
    374 
    375     def test_serialize_empty(self):
    376         """Serializing empty HostInfo results in the expected json."""
    377         info = host_info.HostInfo()
    378         file_obj = cStringIO.StringIO()
    379         host_info.json_serialize(info, file_obj)
    380         file_obj.seek(0)
    381         expected_dict = {
    382                 'serializer_version': self.CURRENT_SERIALIZATION_VERSION,
    383                 'attributes' : {},
    384                 'labels': [],
    385         }
    386         self.assertEqual(json.load(file_obj), expected_dict)
    387 
    388 
    389     def test_serialize_non_empty(self):
    390         """Serializing a populated HostInfo results in expected json."""
    391         info = host_info.HostInfo(labels=['label1'],
    392                                   attributes={'attrib': 'val'})
    393         file_obj = cStringIO.StringIO()
    394         host_info.json_serialize(info, file_obj)
    395         file_obj.seek(0)
    396         expected_dict = {
    397                 'serializer_version': self.CURRENT_SERIALIZATION_VERSION,
    398                 'attributes' : {'attrib': 'val'},
    399                 'labels': ['label1'],
    400         }
    401         self.assertEqual(json.load(file_obj), expected_dict)
    402 
    403 
    404     def test_round_trip_empty(self):
    405         """Serializing - deserializing empty HostInfo keeps it unchanged."""
    406         info = host_info.HostInfo()
    407         serialized_fp = cStringIO.StringIO()
    408         host_info.json_serialize(info, serialized_fp)
    409         serialized_fp.seek(0)
    410         got = host_info.json_deserialize(serialized_fp)
    411         self.assertEqual(got, info)
    412 
    413 
    414     def test_round_trip_non_empty(self):
    415         """Serializing - deserializing non-empty HostInfo keeps it unchanged."""
    416         info = host_info.HostInfo(
    417                 labels=['label1'],
    418                 attributes = {'attrib': 'val'})
    419         serialized_fp = cStringIO.StringIO()
    420         host_info.json_serialize(info, serialized_fp)
    421         serialized_fp.seek(0)
    422         got = host_info.json_deserialize(serialized_fp)
    423         self.assertEqual(got, info)
    424 
    425 
    426     def test_deserialize_malformed_json_raises(self):
    427         """Deserializing a malformed string raises."""
    428         with self.assertRaises(host_info.DeserializationError):
    429             host_info.json_deserialize(cStringIO.StringIO('{labels:['))
    430 
    431 
    432     def test_deserialize_no_version_raises(self):
    433         """Deserializing a string with no serializer version raises."""
    434         info = host_info.HostInfo()
    435         serialized_fp = cStringIO.StringIO()
    436         host_info.json_serialize(info, serialized_fp)
    437         serialized_fp.seek(0)
    438 
    439         serialized_dict = json.load(serialized_fp)
    440         del serialized_dict['serializer_version']
    441         serialized_no_version_str = json.dumps(serialized_dict)
    442 
    443         with self.assertRaises(host_info.DeserializationError):
    444             host_info.json_deserialize(
    445                     cStringIO.StringIO(serialized_no_version_str))
    446 
    447 
    448     def test_deserialize_malformed_host_info_raises(self):
    449         """Deserializing a malformed host_info raises."""
    450         info = host_info.HostInfo()
    451         serialized_fp = cStringIO.StringIO()
    452         host_info.json_serialize(info, serialized_fp)
    453         serialized_fp.seek(0)
    454 
    455         serialized_dict = json.load(serialized_fp)
    456         del serialized_dict['labels']
    457         serialized_no_version_str = json.dumps(serialized_dict)
    458 
    459         with self.assertRaises(host_info.DeserializationError):
    460             host_info.json_deserialize(
    461                     cStringIO.StringIO(serialized_no_version_str))
    462 
    463 
    464     def test_enforce_compatibility_version_1(self):
    465         """Tests that required fields are never dropped.
    466 
    467         Never change this test. If you must break compatibility, uprev the
    468         serializer version and add a new test for the newer version.
    469 
    470         Adding a field to compat_info_str means we're making the new field
    471         mandatory. This breaks backwards compatibility.
    472         Removing a field from compat_info_str means we're no longer requiring a
    473         field to be mandatory. This breaks forwards compatibility.
    474         """
    475         compat_dict = {
    476                 'serializer_version': 1,
    477                 'attributes': {},
    478                 'labels': []
    479         }
    480         serialized_str = json.dumps(compat_dict)
    481         serialized_fp = cStringIO.StringIO(serialized_str)
    482         host_info.json_deserialize(serialized_fp)
    483 
    484 
    485     def test_serialize_pretty_print(self):
    486         """Serializing a host_info dumps the json in human-friendly format"""
    487         info = host_info.HostInfo(labels=['label1'],
    488                                   attributes={'attrib': 'val'})
    489         serialized_fp = cStringIO.StringIO()
    490         host_info.json_serialize(info, serialized_fp)
    491         expected = """{
    492             "attributes": {
    493                 "attrib": "val"
    494             },
    495             "labels": [
    496                 "label1"
    497             ],
    498             "serializer_version": %d
    499         }""" % self.CURRENT_SERIALIZATION_VERSION
    500         self.assertEqual(serialized_fp.getvalue(), inspect.cleandoc(expected))
    501 
    502 
    503 if __name__ == '__main__':
    504     unittest.main()
    505