Home | History | Annotate | Download | only in unittest
      1 # Copyright 2012 The Chromium Authors. All rights reserved.
      2 # Use of this source code is governed by a BSD-style license that can be
      3 # found in the LICENSE file.
      4 """A very very simple mock object harness."""
      5 
      6 DONT_CARE = ''
      7 
      8 class MockFunctionCall(object):
      9   def __init__(self, name):
     10     self.name = name
     11     self.args = tuple()
     12     self.return_value = None
     13     self.when_called_handlers = []
     14 
     15   def WithArgs(self, *args):
     16     self.args = args
     17     return self
     18 
     19   def WillReturn(self, value):
     20     self.return_value = value
     21     return self
     22 
     23   def WhenCalled(self, handler):
     24     self.when_called_handlers.append(handler)
     25 
     26   def VerifyEquals(self, got):
     27     if self.name != got.name:
     28       raise Exception('Self %s, got %s' % (repr(self), repr(got)))
     29     if len(self.args) != len(got.args):
     30       raise Exception('Self %s, got %s' % (repr(self), repr(got)))
     31     for i in range(len(self.args)):
     32       self_a = self.args[i]
     33       got_a = got.args[i]
     34       if self_a == DONT_CARE:
     35         continue
     36       if self_a != got_a:
     37         raise Exception('Self %s, got %s' % (repr(self), repr(got)))
     38 
     39   def __repr__(self):
     40     def arg_to_text(a):
     41       if a == DONT_CARE:
     42         return '_'
     43       return repr(a)
     44     args_text = ', '.join([arg_to_text(a) for a in self.args])
     45     if self.return_value in (None, DONT_CARE):
     46       return '%s(%s)' % (self.name, args_text)
     47     return '%s(%s)->%s' % (self.name, args_text, repr(self.return_value))
     48 
     49 class MockTrace(object):
     50   def __init__(self):
     51     self.expected_calls = []
     52     self.next_call_index = 0
     53 
     54 class MockObject(object):
     55   def __init__(self, parent_mock = None):
     56     if parent_mock:
     57       self._trace = parent_mock._trace # pylint: disable=W0212
     58     else:
     59       self._trace = MockTrace()
     60 
     61   def __setattr__(self, name, value):
     62     if (not hasattr(self, '_trace') or
     63         hasattr(value, 'is_hook')):
     64       object.__setattr__(self, name, value)
     65       return
     66     assert isinstance(value, MockObject)
     67     object.__setattr__(self, name, value)
     68 
     69   def SetAttribute(self, name, value):
     70     setattr(self, name, value)
     71 
     72   def ExpectCall(self, func_name, *args):
     73     assert self._trace.next_call_index == 0
     74     if not hasattr(self, func_name):
     75       self._install_hook(func_name)
     76 
     77     call = MockFunctionCall(func_name)
     78     self._trace.expected_calls.append(call)
     79     call.WithArgs(*args)
     80     return call
     81 
     82   def _install_hook(self, func_name):
     83     def handler(*args, **_):
     84       got_call = MockFunctionCall(
     85         func_name).WithArgs(*args).WillReturn(DONT_CARE)
     86       if self._trace.next_call_index >= len(self._trace.expected_calls):
     87         raise Exception(
     88           'Call to %s was not expected, at end of programmed trace.' %
     89           repr(got_call))
     90       expected_call = self._trace.expected_calls[
     91         self._trace.next_call_index]
     92       expected_call.VerifyEquals(got_call)
     93       self._trace.next_call_index += 1
     94       for h in expected_call.when_called_handlers:
     95         h(*args)
     96       return expected_call.return_value
     97     handler.is_hook = True
     98     setattr(self, func_name, handler)
     99 
    100 
    101 class MockTimer(object):
    102   def __init__(self):
    103     self._elapsed_time = 0
    104 
    105   def Sleep(self, time):
    106     self._elapsed_time += time
    107 
    108   def GetTime(self):
    109     return self._elapsed_time
    110 
    111   def SetTime(self, time):
    112     self._elapsed_time = time
    113