Home | History | Annotate | Download | only in unittest
      1 # Copyright 2012 The Chromium Authors. All rights reserved.
      2 # Use of this source code is governed by a BSD-style license that can be
      3 # found in the LICENSE file.
      4 """A very very simple mock object harness."""
      5 from types import ModuleType
      6 
      7 DONT_CARE = ''
      8 
      9 class MockFunctionCall(object):
     10   def __init__(self, name):
     11     self.name = name
     12     self.args = tuple()
     13     self.return_value = None
     14     self.when_called_handlers = []
     15 
     16   def WithArgs(self, *args):
     17     self.args = args
     18     return self
     19 
     20   def WillReturn(self, value):
     21     self.return_value = value
     22     return self
     23 
     24   def WhenCalled(self, handler):
     25     self.when_called_handlers.append(handler)
     26 
     27   def VerifyEquals(self, got):
     28     if self.name != got.name:
     29       raise Exception('Self %s, got %s' % (repr(self), repr(got)))
     30     if len(self.args) != len(got.args):
     31       raise Exception('Self %s, got %s' % (repr(self), repr(got)))
     32     for i in range(len(self.args)):
     33       self_a = self.args[i]
     34       got_a = got.args[i]
     35       if self_a == DONT_CARE:
     36         continue
     37       if self_a != got_a:
     38         raise Exception('Self %s, got %s' % (repr(self), repr(got)))
     39 
     40   def __repr__(self):
     41     def arg_to_text(a):
     42       if a == DONT_CARE:
     43         return '_'
     44       return repr(a)
     45     args_text = ', '.join([arg_to_text(a) for a in self.args])
     46     if self.return_value in (None, DONT_CARE):
     47       return '%s(%s)' % (self.name, args_text)
     48     return '%s(%s)->%s' % (self.name, args_text, repr(self.return_value))
     49 
     50 class MockTrace(object):
     51   def __init__(self):
     52     self.expected_calls = []
     53     self.next_call_index = 0
     54 
     55 class MockObject(object):
     56   def __init__(self, parent_mock = None):
     57     if parent_mock:
     58       self._trace = parent_mock._trace # pylint: disable=W0212
     59     else:
     60       self._trace = MockTrace()
     61 
     62   def __setattr__(self, name, value):
     63     if (not hasattr(self, '_trace') or
     64         hasattr(value, 'is_hook')):
     65       object.__setattr__(self, name, value)
     66       return
     67     assert isinstance(value, MockObject)
     68     object.__setattr__(self, name, value)
     69 
     70   def SetAttribute(self, name, value):
     71     setattr(self, name, value)
     72 
     73   def ExpectCall(self, func_name, *args):
     74     assert self._trace.next_call_index == 0
     75     if not hasattr(self, func_name):
     76       self._install_hook(func_name)
     77 
     78     call = MockFunctionCall(func_name)
     79     self._trace.expected_calls.append(call)
     80     call.WithArgs(*args)
     81     return call
     82 
     83   def _install_hook(self, func_name):
     84     def handler(*args, **_):
     85       got_call = MockFunctionCall(
     86         func_name).WithArgs(*args).WillReturn(DONT_CARE)
     87       if self._trace.next_call_index >= len(self._trace.expected_calls):
     88         raise Exception(
     89           'Call to %s was not expected, at end of programmed trace.' %
     90           repr(got_call))
     91       expected_call = self._trace.expected_calls[
     92         self._trace.next_call_index]
     93       expected_call.VerifyEquals(got_call)
     94       self._trace.next_call_index += 1
     95       for h in expected_call.when_called_handlers:
     96         h(*args)
     97       return expected_call.return_value
     98     handler.is_hook = True
     99     setattr(self, func_name, handler)
    100 
    101 
    102 class MockTimer(object):
    103   """ A mock timer to fake out the timing for a module.
    104     Args:
    105       module: module to fake out the time
    106   """
    107   def __init__(self, module=None):
    108     self._elapsed_time = 0
    109     self._module = module
    110     self._actual_time = None
    111     if module:
    112       assert isinstance(module, ModuleType)
    113       self._actual_time = module.time
    114       self._module.time = self
    115 
    116   def sleep(self, time):
    117     self._elapsed_time += time
    118 
    119   def time(self):
    120     return self._elapsed_time
    121 
    122   def SetTime(self, time):
    123     self._elapsed_time = time
    124 
    125   def __del__(self):
    126     self.Release()
    127 
    128   def Restore(self):
    129     if self._module:
    130       self._module.time = self._actual_time
    131       self._module = None
    132       self._actual_time = None
    133