Home | History | Annotate | Download | only in atest
      1 #!/usr/bin/env python
      2 #
      3 # Copyright 2017, The Android Open Source Project
      4 #
      5 # Licensed under the Apache License, Version 2.0 (the "License");
      6 # you may not use this file except in compliance with the License.
      7 # You may obtain a copy of the License at
      8 #
      9 #     http://www.apache.org/licenses/LICENSE-2.0
     10 #
     11 # Unless required by applicable law or agreed to in writing, software
     12 # distributed under the License is distributed on an "AS IS" BASIS,
     13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14 # See the License for the specific language governing permissions and
     15 # limitations under the License.
     16 
     17 """Utility functions for unit tests."""
     18 
     19 import os
     20 
     21 import constants
     22 import unittest_constants as uc
     23 
     24 def assert_strict_equal(test_class, first, second):
     25     """Check for strict equality and strict equality of nametuple elements.
     26 
     27     assertEqual considers types equal to their subtypes, but we want to
     28     not consider set() and frozenset() equal for testing.
     29     """
     30     test_class.assertEqual(first, second)
     31     # allow byte and unicode string equality.
     32     if not (isinstance(first, basestring) and
     33             isinstance(second, basestring)):
     34         test_class.assertIsInstance(first, type(second))
     35         test_class.assertIsInstance(second, type(first))
     36     # Recursively check elements of namedtuples for strict equals.
     37     if isinstance(first, tuple) and hasattr(first, '_fields'):
     38         # pylint: disable=invalid-name
     39         for f in first._fields:
     40             assert_strict_equal(test_class, getattr(first, f),
     41                                 getattr(second, f))
     42 
     43 def assert_equal_testinfos(test_class, test_info_a, test_info_b):
     44     """Check that the passed in TestInfos are equal."""
     45     # Use unittest.assertEqual to do checks when None is involved.
     46     if test_info_a is None or test_info_b is None:
     47         test_class.assertEqual(test_info_a, test_info_b)
     48         return
     49 
     50     for attr in test_info_a.__dict__:
     51         test_info_a_attr = getattr(test_info_a, attr)
     52         test_info_b_attr = getattr(test_info_b, attr)
     53         test_class.assertEqual(test_info_a_attr, test_info_b_attr,
     54                                msg=('TestInfo.%s mismatch: %s != %s' %
     55                                     (attr, test_info_a_attr, test_info_b_attr)))
     56 
     57 def assert_equal_testinfo_sets(test_class, test_info_set_a, test_info_set_b):
     58     """Check that the sets of TestInfos are equal."""
     59     test_class.assertEqual(len(test_info_set_a), len(test_info_set_b),
     60                            msg=('mismatch # of TestInfos: %d != %d' %
     61                                 (len(test_info_set_a), len(test_info_set_b))))
     62     # Iterate over a set and pop them out as you compare them.
     63     while test_info_set_a:
     64         test_info_a = test_info_set_a.pop()
     65         test_info_b_to_remove = None
     66         for test_info_b in test_info_set_b:
     67             try:
     68                 assert_equal_testinfos(test_class, test_info_a, test_info_b)
     69                 test_info_b_to_remove = test_info_b
     70                 break
     71             except AssertionError:
     72                 pass
     73         if test_info_b_to_remove:
     74             test_info_set_b.remove(test_info_b_to_remove)
     75         else:
     76             # We haven't found a match, raise an assertion error.
     77             raise AssertionError('No matching TestInfo (%s) in [%s]' %
     78                                  (test_info_a, ';'.join([str(t) for t in test_info_set_b])))
     79 
     80 
     81 def isfile_side_effect(value):
     82     """Mock return values for os.path.isfile."""
     83     if value == '/%s/%s' % (uc.MODULE_DIR, constants.MODULE_CONFIG):
     84         return True
     85     if value.endswith('.java'):
     86         return True
     87     if value.endswith(uc.INT_NAME + '.xml'):
     88         return True
     89     if value.endswith(uc.GTF_INT_NAME + '.xml'):
     90         return True
     91     return False
     92 
     93 
     94 def realpath_side_effect(path):
     95     """Mock return values for os.path.realpath."""
     96     return os.path.join(uc.ROOT, path)
     97