Home | History | Annotate | Download | only in utils
      1 # Copyright 2014 The Chromium Authors. All rights reserved.
      2 # Use of this source code is governed by a BSD-style license that can be
      3 # found in the LICENSE file.
      4 
      5 """
      6 A test facility to assert call sequences while mocking their behavior.
      7 """
      8 
      9 import unittest
     10 
     11 from devil import devil_env
     12 
     13 with devil_env.SysPath(devil_env.PYMOCK_PATH):
     14   import mock  # pylint: disable=import-error
     15 
     16 
     17 class TestCase(unittest.TestCase):
     18   """Adds assertCalls to TestCase objects."""
     19   class _AssertCalls(object):
     20 
     21     def __init__(self, test_case, expected_calls, watched):
     22       def call_action(pair):
     23         if isinstance(pair, type(mock.call)):
     24           return (pair, None)
     25         else:
     26           return pair
     27 
     28       def do_check(call):
     29         def side_effect(*args, **kwargs):
     30           received_call = call(*args, **kwargs)
     31           self._test_case.assertTrue(
     32               self._expected_calls,
     33               msg=('Unexpected call: %s' % str(received_call)))
     34           expected_call, action = self._expected_calls.pop(0)
     35           self._test_case.assertTrue(
     36               received_call == expected_call,
     37               msg=('Expected call mismatch:\n'
     38                    '  expected: %s\n'
     39                    '  received: %s\n'
     40                    % (str(expected_call), str(received_call))))
     41           if callable(action):
     42             return action(*args, **kwargs)
     43           else:
     44             return action
     45         return side_effect
     46 
     47       self._test_case = test_case
     48       self._expected_calls = [call_action(pair) for pair in expected_calls]
     49       watched = watched.copy()  # do not pollute the caller's dict
     50       watched.update((call.parent.name, call.parent)
     51                      for call, _ in self._expected_calls)
     52       self._patched = [test_case.patch_call(call, side_effect=do_check(call))
     53                        for call in watched.itervalues()]
     54 
     55     def __enter__(self):
     56       for patch in self._patched:
     57         patch.__enter__()
     58       return self
     59 
     60     def __exit__(self, exc_type, exc_val, exc_tb):
     61       for patch in self._patched:
     62         patch.__exit__(exc_type, exc_val, exc_tb)
     63       if exc_type is None:
     64         missing = ''.join('  expected: %s\n' % str(call)
     65                           for call, _ in self._expected_calls)
     66         self._test_case.assertFalse(
     67             missing,
     68             msg='Expected calls not found:\n' + missing)
     69 
     70   def __init__(self, *args, **kwargs):
     71     super(TestCase, self).__init__(*args, **kwargs)
     72     self.call = mock.call.self
     73     self._watched = {}
     74 
     75   def call_target(self, call):
     76     """Resolve a self.call instance to the target it represents.
     77 
     78     Args:
     79       call: a self.call instance, e.g. self.call.adb.Shell
     80 
     81     Returns:
     82       The target object represented by the call, e.g. self.adb.Shell
     83 
     84     Raises:
     85       ValueError if the path of the call does not start with "self", i.e. the
     86           target of the call is external to the self object.
     87       AttributeError if the path of the call does not specify a valid
     88           chain of attributes (without any calls) starting from "self".
     89     """
     90     path = call.name.split('.')
     91     if path.pop(0) != 'self':
     92       raise ValueError("Target %r outside of 'self' object" % call.name)
     93     target = self
     94     for attr in path:
     95       target = getattr(target, attr)
     96     return target
     97 
     98   def patch_call(self, call, **kwargs):
     99     """Patch the target of a mock.call instance.
    100 
    101     Args:
    102       call: a mock.call instance identifying a target to patch
    103       Extra keyword arguments are processed by mock.patch
    104 
    105     Returns:
    106       A context manager to mock/unmock the target of the call
    107     """
    108     if call.name.startswith('self.'):
    109       target = self.call_target(call.parent)
    110       _, attribute = call.name.rsplit('.', 1)
    111       if (hasattr(type(target), attribute)
    112           and isinstance(getattr(type(target), attribute), property)):
    113         return mock.patch.object(
    114             type(target), attribute, new_callable=mock.PropertyMock, **kwargs)
    115       else:
    116         return mock.patch.object(target, attribute, **kwargs)
    117     else:
    118       return mock.patch(call.name, **kwargs)
    119 
    120   def watchCalls(self, calls):
    121     """Add calls to the set of watched calls.
    122 
    123     Args:
    124       calls: a sequence of mock.call instances identifying targets to watch
    125     """
    126     self._watched.update((call.name, call) for call in calls)
    127 
    128   def watchMethodCalls(self, call, ignore=None):
    129     """Watch all public methods of the target identified by a self.call.
    130 
    131     Args:
    132       call: a self.call instance indetifying an object
    133       ignore: a list of public methods to ignore when watching for calls
    134     """
    135     target = self.call_target(call)
    136     if ignore is None:
    137       ignore = []
    138     self.watchCalls(getattr(call, method)
    139                     for method in dir(target.__class__)
    140                     if not method.startswith('_') and not method in ignore)
    141 
    142   def clearWatched(self):
    143     """Clear the set of watched calls."""
    144     self._watched = {}
    145 
    146   def assertCalls(self, *calls):
    147     """A context manager to assert that a sequence of calls is made.
    148 
    149     During the assertion, a number of functions and methods will be "watched",
    150     and any calls made to them is expected to appear---in the exact same order,
    151     and with the exact same arguments---as specified by the argument |calls|.
    152 
    153     By default, the targets of all expected calls are watched. Further targets
    154     to watch may be added using watchCalls and watchMethodCalls.
    155 
    156     Optionaly, each call may be accompanied by an action. If the action is a
    157     (non-callable) value, this value will be used as the return value given to
    158     the caller when the matching call is found. Alternatively, if the action is
    159     a callable, the action will be then called with the same arguments as the
    160     intercepted call, so that it can provide a return value or perform other
    161     side effects. If the action is missing, a return value of None is assumed.
    162 
    163     Note that mock.Mock objects are often convenient to use as a callable
    164     action, e.g. to raise exceptions or return other objects which are
    165     themselves callable.
    166 
    167     Args:
    168       calls: each argument is either a pair (expected_call, action) or just an
    169           expected_call, where expected_call is a mock.call instance.
    170 
    171     Raises:
    172       AssertionError if the watched targets do not receive the exact sequence
    173           of calls specified. Missing calls, extra calls, and calls with
    174           mismatching arguments, all cause the assertion to fail.
    175     """
    176     return self._AssertCalls(self, calls, self._watched)
    177 
    178   def assertCall(self, call, action=None):
    179     return self.assertCalls((call, action))
    180 
    181