Home | History | Annotate | Download | only in python
      1 #!/usr/bin/python2.4
      2 #
      3 # Copyright 2008 Google Inc.
      4 #
      5 # Licensed under the Apache License, Version 2.0 (the "License");
      6 # you may not use this file except in compliance with the License.
      7 # You may obtain a copy of the License at
      8 #
      9 #      http://www.apache.org/licenses/LICENSE-2.0
     10 #
     11 # Unless required by applicable law or agreed to in writing, software
     12 # distributed under the License is distributed on an "AS IS" BASIS,
     13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14 # See the License for the specific language governing permissions and
     15 # limitations under the License.
     16 
     17 # This file is used for testing.  The original is at:
     18 #   http://code.google.com/p/pymox/
     19 
     20 """Mox, an object-mocking framework for Python.
     21 
     22 Mox works in the record-replay-verify paradigm.  When you first create
     23 a mock object, it is in record mode.  You then programmatically set
     24 the expected behavior of the mock object (what methods are to be
     25 called on it, with what parameters, what they should return, and in
     26 what order).
     27 
     28 Once you have set up the expected mock behavior, you put it in replay
     29 mode.  Now the mock responds to method calls just as you told it to.
     30 If an unexpected method (or an expected method with unexpected
     31 parameters) is called, then an exception will be raised.
     32 
     33 Once you are done interacting with the mock, you need to verify that
     34 all the expected interactions occurred.  (Maybe your code exited
     35 prematurely without calling some cleanup method!)  The verify phase
     36 ensures that every expected method was called; otherwise, an exception
     37 will be raised.
     38 
     39 Suggested usage / workflow:
     40 
     41   # Create Mox factory
     42   my_mox = Mox()
     43 
     44   # Create a mock data access object
     45   mock_dao = my_mox.CreateMock(DAOClass)
     46 
     47   # Set up expected behavior
     48   mock_dao.RetrievePersonWithIdentifier('1').AndReturn(person)
     49   mock_dao.DeletePerson(person)
     50 
     51   # Put mocks in replay mode
     52   my_mox.ReplayAll()
     53 
     54   # Inject mock object and run test
     55   controller.SetDao(mock_dao)
     56   controller.DeletePersonById('1')
     57 
     58   # Verify all methods were called as expected
     59   my_mox.VerifyAll()
     60 """
     61 
     62 from collections import deque
     63 import re
     64 import types
     65 import unittest
     66 
     67 import stubout
     68 
     69 class Error(AssertionError):
     70   """Base exception for this module."""
     71 
     72   pass
     73 
     74 
     75 class ExpectedMethodCallsError(Error):
     76   """Raised when Verify() is called before all expected methods have been called
     77   """
     78 
     79   def __init__(self, expected_methods):
     80     """Init exception.
     81 
     82     Args:
     83       # expected_methods: A sequence of MockMethod objects that should have been
     84       #   called.
     85       expected_methods: [MockMethod]
     86 
     87     Raises:
     88       ValueError: if expected_methods contains no methods.
     89     """
     90 
     91     if not expected_methods:
     92       raise ValueError("There must be at least one expected method")
     93     Error.__init__(self)
     94     self._expected_methods = expected_methods
     95 
     96   def __str__(self):
     97     calls = "\n".join(["%3d.  %s" % (i, m)
     98                        for i, m in enumerate(self._expected_methods)])
     99     return "Verify: Expected methods never called:\n%s" % (calls,)
    100 
    101 
    102 class UnexpectedMethodCallError(Error):
    103   """Raised when an unexpected method is called.
    104 
    105   This can occur if a method is called with incorrect parameters, or out of the
    106   specified order.
    107   """
    108 
    109   def __init__(self, unexpected_method, expected):
    110     """Init exception.
    111 
    112     Args:
    113       # unexpected_method: MockMethod that was called but was not at the head of
    114       #   the expected_method queue.
    115       # expected: MockMethod or UnorderedGroup the method should have
    116       #   been in.
    117       unexpected_method: MockMethod
    118       expected: MockMethod or UnorderedGroup
    119     """
    120 
    121     Error.__init__(self)
    122     self._unexpected_method = unexpected_method
    123     self._expected = expected
    124 
    125   def __str__(self):
    126     return "Unexpected method call: %s.  Expecting: %s" % \
    127       (self._unexpected_method, self._expected)
    128 
    129 
    130 class UnknownMethodCallError(Error):
    131   """Raised if an unknown method is requested of the mock object."""
    132 
    133   def __init__(self, unknown_method_name):
    134     """Init exception.
    135 
    136     Args:
    137       # unknown_method_name: Method call that is not part of the mocked class's
    138       #   public interface.
    139       unknown_method_name: str
    140     """
    141 
    142     Error.__init__(self)
    143     self._unknown_method_name = unknown_method_name
    144 
    145   def __str__(self):
    146     return "Method called is not a member of the object: %s" % \
    147       self._unknown_method_name
    148 
    149 
    150 class Mox(object):
    151   """Mox: a factory for creating mock objects."""
    152 
    153   # A list of types that should be stubbed out with MockObjects (as
    154   # opposed to MockAnythings).
    155   _USE_MOCK_OBJECT = [types.ClassType, types.InstanceType, types.ModuleType,
    156                       types.ObjectType, types.TypeType]
    157 
    158   def __init__(self):
    159     """Initialize a new Mox."""
    160 
    161     self._mock_objects = []
    162     self.stubs = stubout.StubOutForTesting()
    163 
    164   def CreateMock(self, class_to_mock):
    165     """Create a new mock object.
    166 
    167     Args:
    168       # class_to_mock: the class to be mocked
    169       class_to_mock: class
    170 
    171     Returns:
    172       MockObject that can be used as the class_to_mock would be.
    173     """
    174 
    175     new_mock = MockObject(class_to_mock)
    176     self._mock_objects.append(new_mock)
    177     return new_mock
    178 
    179   def CreateMockAnything(self):
    180     """Create a mock that will accept any method calls.
    181 
    182     This does not enforce an interface.
    183     """
    184 
    185     new_mock = MockAnything()
    186     self._mock_objects.append(new_mock)
    187     return new_mock
    188 
    189   def ReplayAll(self):
    190     """Set all mock objects to replay mode."""
    191 
    192     for mock_obj in self._mock_objects:
    193       mock_obj._Replay()
    194 
    195 
    196   def VerifyAll(self):
    197     """Call verify on all mock objects created."""
    198 
    199     for mock_obj in self._mock_objects:
    200       mock_obj._Verify()
    201 
    202   def ResetAll(self):
    203     """Call reset on all mock objects.  This does not unset stubs."""
    204 
    205     for mock_obj in self._mock_objects:
    206       mock_obj._Reset()
    207 
    208   def StubOutWithMock(self, obj, attr_name, use_mock_anything=False):
    209     """Replace a method, attribute, etc. with a Mock.
    210 
    211     This will replace a class or module with a MockObject, and everything else
    212     (method, function, etc) with a MockAnything.  This can be overridden to
    213     always use a MockAnything by setting use_mock_anything to True.
    214 
    215     Args:
    216       obj: A Python object (class, module, instance, callable).
    217       attr_name: str.  The name of the attribute to replace with a mock.
    218       use_mock_anything: bool. True if a MockAnything should be used regardless
    219         of the type of attribute.
    220     """
    221 
    222     attr_to_replace = getattr(obj, attr_name)
    223     if type(attr_to_replace) in self._USE_MOCK_OBJECT and not use_mock_anything:
    224       stub = self.CreateMock(attr_to_replace)
    225     else:
    226       stub = self.CreateMockAnything()
    227 
    228     self.stubs.Set(obj, attr_name, stub)
    229 
    230   def UnsetStubs(self):
    231     """Restore stubs to their original state."""
    232 
    233     self.stubs.UnsetAll()
    234 
    235 def Replay(*args):
    236   """Put mocks into Replay mode.
    237 
    238   Args:
    239     # args is any number of mocks to put into replay mode.
    240   """
    241 
    242   for mock in args:
    243     mock._Replay()
    244 
    245 
    246 def Verify(*args):
    247   """Verify mocks.
    248 
    249   Args:
    250     # args is any number of mocks to be verified.
    251   """
    252 
    253   for mock in args:
    254     mock._Verify()
    255 
    256 
    257 def Reset(*args):
    258   """Reset mocks.
    259 
    260   Args:
    261     # args is any number of mocks to be reset.
    262   """
    263 
    264   for mock in args:
    265     mock._Reset()
    266 
    267 
    268 class MockAnything:
    269   """A mock that can be used to mock anything.
    270 
    271   This is helpful for mocking classes that do not provide a public interface.
    272   """
    273 
    274   def __init__(self):
    275     """ """
    276     self._Reset()
    277 
    278   def __getattr__(self, method_name):
    279     """Intercept method calls on this object.
    280 
    281      A new MockMethod is returned that is aware of the MockAnything's
    282      state (record or replay).  The call will be recorded or replayed
    283      by the MockMethod's __call__.
    284 
    285     Args:
    286       # method name: the name of the method being called.
    287       method_name: str
    288 
    289     Returns:
    290       A new MockMethod aware of MockAnything's state (record or replay).
    291     """
    292 
    293     return self._CreateMockMethod(method_name)
    294 
    295   def _CreateMockMethod(self, method_name):
    296     """Create a new mock method call and return it.
    297 
    298     Args:
    299       # method name: the name of the method being called.
    300       method_name: str
    301 
    302     Returns:
    303       A new MockMethod aware of MockAnything's state (record or replay).
    304     """
    305 
    306     return MockMethod(method_name, self._expected_calls_queue,
    307                       self._replay_mode)
    308 
    309   def __nonzero__(self):
    310     """Return 1 for nonzero so the mock can be used as a conditional."""
    311 
    312     return 1
    313 
    314   def __eq__(self, rhs):
    315     """Provide custom logic to compare objects."""
    316 
    317     return (isinstance(rhs, MockAnything) and
    318             self._replay_mode == rhs._replay_mode and
    319             self._expected_calls_queue == rhs._expected_calls_queue)
    320 
    321   def __ne__(self, rhs):
    322     """Provide custom logic to compare objects."""
    323 
    324     return not self == rhs
    325 
    326   def _Replay(self):
    327     """Start replaying expected method calls."""
    328 
    329     self._replay_mode = True
    330 
    331   def _Verify(self):
    332     """Verify that all of the expected calls have been made.
    333 
    334     Raises:
    335       ExpectedMethodCallsError: if there are still more method calls in the
    336         expected queue.
    337     """
    338 
    339     # If the list of expected calls is not empty, raise an exception
    340     if self._expected_calls_queue:
    341       # The last MultipleTimesGroup is not popped from the queue.
    342       if (len(self._expected_calls_queue) == 1 and
    343           isinstance(self._expected_calls_queue[0], MultipleTimesGroup) and
    344           self._expected_calls_queue[0].IsSatisfied()):
    345         pass
    346       else:
    347         raise ExpectedMethodCallsError(self._expected_calls_queue)
    348 
    349   def _Reset(self):
    350     """Reset the state of this mock to record mode with an empty queue."""
    351 
    352     # Maintain a list of method calls we are expecting
    353     self._expected_calls_queue = deque()
    354 
    355     # Make sure we are in setup mode, not replay mode
    356     self._replay_mode = False
    357 
    358 
    359 class MockObject(MockAnything, object):
    360   """A mock object that simulates the public/protected interface of a class."""
    361 
    362   def __init__(self, class_to_mock):
    363     """Initialize a mock object.
    364 
    365     This determines the methods and properties of the class and stores them.
    366 
    367     Args:
    368       # class_to_mock: class to be mocked
    369       class_to_mock: class
    370     """
    371 
    372     # This is used to hack around the mixin/inheritance of MockAnything, which
    373     # is not a proper object (it can be anything. :-)
    374     MockAnything.__dict__['__init__'](self)
    375 
    376     # Get a list of all the public and special methods we should mock.
    377     self._known_methods = set()
    378     self._known_vars = set()
    379     self._class_to_mock = class_to_mock
    380     for method in dir(class_to_mock):
    381       if callable(getattr(class_to_mock, method)):
    382         self._known_methods.add(method)
    383       else:
    384         self._known_vars.add(method)
    385 
    386   def __getattr__(self, name):
    387     """Intercept attribute request on this object.
    388 
    389     If the attribute is a public class variable, it will be returned and not
    390     recorded as a call.
    391 
    392     If the attribute is not a variable, it is handled like a method
    393     call. The method name is checked against the set of mockable
    394     methods, and a new MockMethod is returned that is aware of the
    395     MockObject's state (record or replay).  The call will be recorded
    396     or replayed by the MockMethod's __call__.
    397 
    398     Args:
    399       # name: the name of the attribute being requested.
    400       name: str
    401 
    402     Returns:
    403       Either a class variable or a new MockMethod that is aware of the state
    404       of the mock (record or replay).
    405 
    406     Raises:
    407       UnknownMethodCallError if the MockObject does not mock the requested
    408           method.
    409     """
    410 
    411     if name in self._known_vars:
    412       return getattr(self._class_to_mock, name)
    413 
    414     if name in self._known_methods:
    415       return self._CreateMockMethod(name)
    416 
    417     raise UnknownMethodCallError(name)
    418 
    419   def __eq__(self, rhs):
    420     """Provide custom logic to compare objects."""
    421 
    422     return (isinstance(rhs, MockObject) and
    423             self._class_to_mock == rhs._class_to_mock and
    424             self._replay_mode == rhs._replay_mode and
    425             self._expected_calls_queue == rhs._expected_calls_queue)
    426 
    427   def __setitem__(self, key, value):
    428     """Provide custom logic for mocking classes that support item assignment.
    429 
    430     Args:
    431       key: Key to set the value for.
    432       value: Value to set.
    433 
    434     Returns:
    435       Expected return value in replay mode.  A MockMethod object for the
    436       __setitem__ method that has already been called if not in replay mode.
    437 
    438     Raises:
    439       TypeError if the underlying class does not support item assignment.
    440       UnexpectedMethodCallError if the object does not expect the call to
    441         __setitem__.
    442 
    443     """
    444     setitem = self._class_to_mock.__dict__.get('__setitem__', None)
    445 
    446     # Verify the class supports item assignment.
    447     if setitem is None:
    448       raise TypeError('object does not support item assignment')
    449 
    450     # If we are in replay mode then simply call the mock __setitem__ method.
    451     if self._replay_mode:
    452       return MockMethod('__setitem__', self._expected_calls_queue,
    453                         self._replay_mode)(key, value)
    454 
    455 
    456     # Otherwise, create a mock method __setitem__.
    457     return self._CreateMockMethod('__setitem__')(key, value)
    458 
    459   def __getitem__(self, key):
    460     """Provide custom logic for mocking classes that are subscriptable.
    461 
    462     Args:
    463       key: Key to return the value for.
    464 
    465     Returns:
    466       Expected return value in replay mode.  A MockMethod object for the
    467       __getitem__ method that has already been called if not in replay mode.
    468 
    469     Raises:
    470       TypeError if the underlying class is not subscriptable.
    471       UnexpectedMethodCallError if the object does not expect the call to
    472         __setitem__.
    473 
    474     """
    475     getitem = self._class_to_mock.__dict__.get('__getitem__', None)
    476 
    477     # Verify the class supports item assignment.
    478     if getitem is None:
    479       raise TypeError('unsubscriptable object')
    480 
    481     # If we are in replay mode then simply call the mock __getitem__ method.
    482     if self._replay_mode:
    483       return MockMethod('__getitem__', self._expected_calls_queue,
    484                         self._replay_mode)(key)
    485 
    486 
    487     # Otherwise, create a mock method __getitem__.
    488     return self._CreateMockMethod('__getitem__')(key)
    489 
    490   def __call__(self, *params, **named_params):
    491     """Provide custom logic for mocking classes that are callable."""
    492 
    493     # Verify the class we are mocking is callable
    494     callable = self._class_to_mock.__dict__.get('__call__', None)
    495     if callable is None:
    496       raise TypeError('Not callable')
    497 
    498     # Because the call is happening directly on this object instead of a method,
    499     # the call on the mock method is made right here
    500     mock_method = self._CreateMockMethod('__call__')
    501     return mock_method(*params, **named_params)
    502 
    503   @property
    504   def __class__(self):
    505     """Return the class that is being mocked."""
    506 
    507     return self._class_to_mock
    508 
    509 
    510 class MockMethod(object):
    511   """Callable mock method.
    512 
    513   A MockMethod should act exactly like the method it mocks, accepting parameters
    514   and returning a value, or throwing an exception (as specified).  When this
    515   method is called, it can optionally verify whether the called method (name and
    516   signature) matches the expected method.
    517   """
    518 
    519   def __init__(self, method_name, call_queue, replay_mode):
    520     """Construct a new mock method.
    521 
    522     Args:
    523       # method_name: the name of the method
    524       # call_queue: deque of calls, verify this call against the head, or add
    525       #     this call to the queue.
    526       # replay_mode: False if we are recording, True if we are verifying calls
    527       #     against the call queue.
    528       method_name: str
    529       call_queue: list or deque
    530       replay_mode: bool
    531     """
    532 
    533     self._name = method_name
    534     self._call_queue = call_queue
    535     if not isinstance(call_queue, deque):
    536       self._call_queue = deque(self._call_queue)
    537     self._replay_mode = replay_mode
    538 
    539     self._params = None
    540     self._named_params = None
    541     self._return_value = None
    542     self._exception = None
    543     self._side_effects = None
    544 
    545   def __call__(self, *params, **named_params):
    546     """Log parameters and return the specified return value.
    547 
    548     If the Mock(Anything/Object) associated with this call is in record mode,
    549     this MockMethod will be pushed onto the expected call queue.  If the mock
    550     is in replay mode, this will pop a MockMethod off the top of the queue and
    551     verify this call is equal to the expected call.
    552 
    553     Raises:
    554       UnexpectedMethodCall if this call is supposed to match an expected method
    555         call and it does not.
    556     """
    557 
    558     self._params = params
    559     self._named_params = named_params
    560 
    561     if not self._replay_mode:
    562       self._call_queue.append(self)
    563       return self
    564 
    565     expected_method = self._VerifyMethodCall()
    566 
    567     if expected_method._side_effects:
    568       expected_method._side_effects(*params, **named_params)
    569 
    570     if expected_method._exception:
    571       raise expected_method._exception
    572 
    573     return expected_method._return_value
    574 
    575   def __getattr__(self, name):
    576     """Raise an AttributeError with a helpful message."""
    577 
    578     raise AttributeError('MockMethod has no attribute "%s". '
    579         'Did you remember to put your mocks in replay mode?' % name)
    580 
    581   def _PopNextMethod(self):
    582     """Pop the next method from our call queue."""
    583     try:
    584       return self._call_queue.popleft()
    585     except IndexError:
    586       raise UnexpectedMethodCallError(self, None)
    587 
    588   def _VerifyMethodCall(self):
    589     """Verify the called method is expected.
    590 
    591     This can be an ordered method, or part of an unordered set.
    592 
    593     Returns:
    594       The expected mock method.
    595 
    596     Raises:
    597       UnexpectedMethodCall if the method called was not expected.
    598     """
    599 
    600     expected = self._PopNextMethod()
    601 
    602     # Loop here, because we might have a MethodGroup followed by another
    603     # group.
    604     while isinstance(expected, MethodGroup):
    605       expected, method = expected.MethodCalled(self)
    606       if method is not None:
    607         return method
    608 
    609     # This is a mock method, so just check equality.
    610     if expected != self:
    611       raise UnexpectedMethodCallError(self, expected)
    612 
    613     return expected
    614 
    615   def __str__(self):
    616     params = ', '.join(
    617         [repr(p) for p in self._params or []] +
    618         ['%s=%r' % x for x in sorted((self._named_params or {}).items())])
    619     desc = "%s(%s) -> %r" % (self._name, params, self._return_value)
    620     return desc
    621 
    622   def __eq__(self, rhs):
    623     """Test whether this MockMethod is equivalent to another MockMethod.
    624 
    625     Args:
    626       # rhs: the right hand side of the test
    627       rhs: MockMethod
    628     """
    629 
    630     return (isinstance(rhs, MockMethod) and
    631             self._name == rhs._name and
    632             self._params == rhs._params and
    633             self._named_params == rhs._named_params)
    634 
    635   def __ne__(self, rhs):
    636     """Test whether this MockMethod is not equivalent to another MockMethod.
    637 
    638     Args:
    639       # rhs: the right hand side of the test
    640       rhs: MockMethod
    641     """
    642 
    643     return not self == rhs
    644 
    645   def GetPossibleGroup(self):
    646     """Returns a possible group from the end of the call queue or None if no
    647     other methods are on the stack.
    648     """
    649 
    650     # Remove this method from the tail of the queue so we can add it to a group.
    651     this_method = self._call_queue.pop()
    652     assert this_method == self
    653 
    654     # Determine if the tail of the queue is a group, or just a regular ordered
    655     # mock method.
    656     group = None
    657     try:
    658       group = self._call_queue[-1]
    659     except IndexError:
    660       pass
    661 
    662     return group
    663 
    664   def _CheckAndCreateNewGroup(self, group_name, group_class):
    665     """Checks if the last method (a possible group) is an instance of our
    666     group_class. Adds the current method to this group or creates a new one.
    667 
    668     Args:
    669 
    670       group_name: the name of the group.
    671       group_class: the class used to create instance of this new group
    672     """
    673     group = self.GetPossibleGroup()
    674 
    675     # If this is a group, and it is the correct group, add the method.
    676     if isinstance(group, group_class) and group.group_name() == group_name:
    677       group.AddMethod(self)
    678       return self
    679 
    680     # Create a new group and add the method.
    681     new_group = group_class(group_name)
    682     new_group.AddMethod(self)
    683     self._call_queue.append(new_group)
    684     return self
    685 
    686   def InAnyOrder(self, group_name="default"):
    687     """Move this method into a group of unordered calls.
    688 
    689     A group of unordered calls must be defined together, and must be executed
    690     in full before the next expected method can be called.  There can be
    691     multiple groups that are expected serially, if they are given
    692     different group names.  The same group name can be reused if there is a
    693     standard method call, or a group with a different name, spliced between
    694     usages.
    695 
    696     Args:
    697       group_name: the name of the unordered group.
    698 
    699     Returns:
    700       self
    701     """
    702     return self._CheckAndCreateNewGroup(group_name, UnorderedGroup)
    703 
    704   def MultipleTimes(self, group_name="default"):
    705     """Move this method into group of calls which may be called multiple times.
    706 
    707     A group of repeating calls must be defined together, and must be executed in
    708     full before the next expected mehtod can be called.
    709 
    710     Args:
    711       group_name: the name of the unordered group.
    712 
    713     Returns:
    714       self
    715     """
    716     return self._CheckAndCreateNewGroup(group_name, MultipleTimesGroup)
    717 
    718   def AndReturn(self, return_value):
    719     """Set the value to return when this method is called.
    720 
    721     Args:
    722       # return_value can be anything.
    723     """
    724 
    725     self._return_value = return_value
    726     return return_value
    727 
    728   def AndRaise(self, exception):
    729     """Set the exception to raise when this method is called.
    730 
    731     Args:
    732       # exception: the exception to raise when this method is called.
    733       exception: Exception
    734     """
    735 
    736     self._exception = exception
    737 
    738   def WithSideEffects(self, side_effects):
    739     """Set the side effects that are simulated when this method is called.
    740 
    741     Args:
    742       side_effects: A callable which modifies the parameters or other relevant
    743         state which a given test case depends on.
    744 
    745     Returns:
    746       Self for chaining with AndReturn and AndRaise.
    747     """
    748     self._side_effects = side_effects
    749     return self
    750 
    751 class Comparator:
    752   """Base class for all Mox comparators.
    753 
    754   A Comparator can be used as a parameter to a mocked method when the exact
    755   value is not known.  For example, the code you are testing might build up a
    756   long SQL string that is passed to your mock DAO. You're only interested that
    757   the IN clause contains the proper primary keys, so you can set your mock
    758   up as follows:
    759 
    760   mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)
    761 
    762   Now whatever query is passed in must contain the string 'IN (1, 2, 4, 5)'.
    763 
    764   A Comparator may replace one or more parameters, for example:
    765   # return at most 10 rows
    766   mock_dao.RunQuery(StrContains('SELECT'), 10)
    767 
    768   or
    769 
    770   # Return some non-deterministic number of rows
    771   mock_dao.RunQuery(StrContains('SELECT'), IsA(int))
    772   """
    773 
    774   def equals(self, rhs):
    775     """Special equals method that all comparators must implement.
    776 
    777     Args:
    778       rhs: any python object
    779     """
    780 
    781     raise NotImplementedError, 'method must be implemented by a subclass.'
    782 
    783   def __eq__(self, rhs):
    784     return self.equals(rhs)
    785 
    786   def __ne__(self, rhs):
    787     return not self.equals(rhs)
    788 
    789 
    790 class IsA(Comparator):
    791   """This class wraps a basic Python type or class.  It is used to verify
    792   that a parameter is of the given type or class.
    793 
    794   Example:
    795   mock_dao.Connect(IsA(DbConnectInfo))
    796   """
    797 
    798   def __init__(self, class_name):
    799     """Initialize IsA
    800 
    801     Args:
    802       class_name: basic python type or a class
    803     """
    804 
    805     self._class_name = class_name
    806 
    807   def equals(self, rhs):
    808     """Check to see if the RHS is an instance of class_name.
    809 
    810     Args:
    811       # rhs: the right hand side of the test
    812       rhs: object
    813 
    814     Returns:
    815       bool
    816     """
    817 
    818     try:
    819       return isinstance(rhs, self._class_name)
    820     except TypeError:
    821       # Check raw types if there was a type error.  This is helpful for
    822       # things like cStringIO.StringIO.
    823       return type(rhs) == type(self._class_name)
    824 
    825   def __repr__(self):
    826     return str(self._class_name)
    827 
    828 class IsAlmost(Comparator):
    829   """Comparison class used to check whether a parameter is nearly equal
    830   to a given value.  Generally useful for floating point numbers.
    831 
    832   Example mock_dao.SetTimeout((IsAlmost(3.9)))
    833   """
    834 
    835   def __init__(self, float_value, places=7):
    836     """Initialize IsAlmost.
    837 
    838     Args:
    839       float_value: The value for making the comparison.
    840       places: The number of decimal places to round to.
    841     """
    842 
    843     self._float_value = float_value
    844     self._places = places
    845 
    846   def equals(self, rhs):
    847     """Check to see if RHS is almost equal to float_value
    848 
    849     Args:
    850       rhs: the value to compare to float_value
    851 
    852     Returns:
    853       bool
    854     """
    855 
    856     try:
    857       return round(rhs-self._float_value, self._places) == 0
    858     except TypeError:
    859       # This is probably because either float_value or rhs is not a number.
    860       return False
    861 
    862   def __repr__(self):
    863     return str(self._float_value)
    864 
    865 class StrContains(Comparator):
    866   """Comparison class used to check whether a substring exists in a
    867   string parameter.  This can be useful in mocking a database with SQL
    868   passed in as a string parameter, for example.
    869 
    870   Example:
    871   mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)
    872   """
    873 
    874   def __init__(self, search_string):
    875     """Initialize.
    876 
    877     Args:
    878       # search_string: the string you are searching for
    879       search_string: str
    880     """
    881 
    882     self._search_string = search_string
    883 
    884   def equals(self, rhs):
    885     """Check to see if the search_string is contained in the rhs string.
    886 
    887     Args:
    888       # rhs: the right hand side of the test
    889       rhs: object
    890 
    891     Returns:
    892       bool
    893     """
    894 
    895     try:
    896       return rhs.find(self._search_string) > -1
    897     except Exception:
    898       return False
    899 
    900   def __repr__(self):
    901     return '<str containing \'%s\'>' % self._search_string
    902 
    903 
    904 class Regex(Comparator):
    905   """Checks if a string matches a regular expression.
    906 
    907   This uses a given regular expression to determine equality.
    908   """
    909 
    910   def __init__(self, pattern, flags=0):
    911     """Initialize.
    912 
    913     Args:
    914       # pattern is the regular expression to search for
    915       pattern: str
    916       # flags passed to re.compile function as the second argument
    917       flags: int
    918     """
    919 
    920     self.regex = re.compile(pattern, flags=flags)
    921 
    922   def equals(self, rhs):
    923     """Check to see if rhs matches regular expression pattern.
    924 
    925     Returns:
    926       bool
    927     """
    928 
    929     return self.regex.search(rhs) is not None
    930 
    931   def __repr__(self):
    932     s = '<regular expression \'%s\'' % self.regex.pattern
    933     if self.regex.flags:
    934       s += ', flags=%d' % self.regex.flags
    935     s += '>'
    936     return s
    937 
    938 
    939 class In(Comparator):
    940   """Checks whether an item (or key) is in a list (or dict) parameter.
    941 
    942   Example:
    943   mock_dao.GetUsersInfo(In('expectedUserName')).AndReturn(mock_result)
    944   """
    945 
    946   def __init__(self, key):
    947     """Initialize.
    948 
    949     Args:
    950       # key is any thing that could be in a list or a key in a dict
    951     """
    952 
    953     self._key = key
    954 
    955   def equals(self, rhs):
    956     """Check to see whether key is in rhs.
    957 
    958     Args:
    959       rhs: dict
    960 
    961     Returns:
    962       bool
    963     """
    964 
    965     return self._key in rhs
    966 
    967   def __repr__(self):
    968     return '<sequence or map containing \'%s\'>' % self._key
    969 
    970 
    971 class ContainsKeyValue(Comparator):
    972   """Checks whether a key/value pair is in a dict parameter.
    973 
    974   Example:
    975   mock_dao.UpdateUsers(ContainsKeyValue('stevepm', stevepm_user_info))
    976   """
    977 
    978   def __init__(self, key, value):
    979     """Initialize.
    980 
    981     Args:
    982       # key: a key in a dict
    983       # value: the corresponding value
    984     """
    985 
    986     self._key = key
    987     self._value = value
    988 
    989   def equals(self, rhs):
    990     """Check whether the given key/value pair is in the rhs dict.
    991 
    992     Returns:
    993       bool
    994     """
    995 
    996     try:
    997       return rhs[self._key] == self._value
    998     except Exception:
    999       return False
   1000 
   1001   def __repr__(self):
   1002     return '<map containing the entry \'%s: %s\'>' % (self._key, self._value)
   1003 
   1004 
   1005 class SameElementsAs(Comparator):
   1006   """Checks whether iterables contain the same elements (ignoring order).
   1007 
   1008   Example:
   1009   mock_dao.ProcessUsers(SameElementsAs('stevepm', 'salomaki'))
   1010   """
   1011 
   1012   def __init__(self, expected_seq):
   1013     """Initialize.
   1014 
   1015     Args:
   1016       expected_seq: a sequence
   1017     """
   1018 
   1019     self._expected_seq = expected_seq
   1020 
   1021   def equals(self, actual_seq):
   1022     """Check to see whether actual_seq has same elements as expected_seq.
   1023 
   1024     Args:
   1025       actual_seq: sequence
   1026 
   1027     Returns:
   1028       bool
   1029     """
   1030 
   1031     try:
   1032       expected = dict([(element, None) for element in self._expected_seq])
   1033       actual = dict([(element, None) for element in actual_seq])
   1034     except TypeError:
   1035       # Fall back to slower list-compare if any of the objects are unhashable.
   1036       expected = list(self._expected_seq)
   1037       actual = list(actual_seq)
   1038       expected.sort()
   1039       actual.sort()
   1040     return expected == actual
   1041 
   1042   def __repr__(self):
   1043     return '<sequence with same elements as \'%s\'>' % self._expected_seq
   1044 
   1045 
   1046 class And(Comparator):
   1047   """Evaluates one or more Comparators on RHS and returns an AND of the results.
   1048   """
   1049 
   1050   def __init__(self, *args):
   1051     """Initialize.
   1052 
   1053     Args:
   1054       *args: One or more Comparator
   1055     """
   1056 
   1057     self._comparators = args
   1058 
   1059   def equals(self, rhs):
   1060     """Checks whether all Comparators are equal to rhs.
   1061 
   1062     Args:
   1063       # rhs: can be anything
   1064 
   1065     Returns:
   1066       bool
   1067     """
   1068 
   1069     for comparator in self._comparators:
   1070       if not comparator.equals(rhs):
   1071         return False
   1072 
   1073     return True
   1074 
   1075   def __repr__(self):
   1076     return '<AND %s>' % str(self._comparators)
   1077 
   1078 
   1079 class Or(Comparator):
   1080   """Evaluates one or more Comparators on RHS and returns an OR of the results.
   1081   """
   1082 
   1083   def __init__(self, *args):
   1084     """Initialize.
   1085 
   1086     Args:
   1087       *args: One or more Mox comparators
   1088     """
   1089 
   1090     self._comparators = args
   1091 
   1092   def equals(self, rhs):
   1093     """Checks whether any Comparator is equal to rhs.
   1094 
   1095     Args:
   1096       # rhs: can be anything
   1097 
   1098     Returns:
   1099       bool
   1100     """
   1101 
   1102     for comparator in self._comparators:
   1103       if comparator.equals(rhs):
   1104         return True
   1105 
   1106     return False
   1107 
   1108   def __repr__(self):
   1109     return '<OR %s>' % str(self._comparators)
   1110 
   1111 
   1112 class Func(Comparator):
   1113   """Call a function that should verify the parameter passed in is correct.
   1114 
   1115   You may need the ability to perform more advanced operations on the parameter
   1116   in order to validate it.  You can use this to have a callable validate any
   1117   parameter. The callable should return either True or False.
   1118 
   1119 
   1120   Example:
   1121 
   1122   def myParamValidator(param):
   1123     # Advanced logic here
   1124     return True
   1125 
   1126   mock_dao.DoSomething(Func(myParamValidator), true)
   1127   """
   1128 
   1129   def __init__(self, func):
   1130     """Initialize.
   1131 
   1132     Args:
   1133       func: callable that takes one parameter and returns a bool
   1134     """
   1135 
   1136     self._func = func
   1137 
   1138   def equals(self, rhs):
   1139     """Test whether rhs passes the function test.
   1140 
   1141     rhs is passed into func.
   1142 
   1143     Args:
   1144       rhs: any python object
   1145 
   1146     Returns:
   1147       the result of func(rhs)
   1148     """
   1149 
   1150     return self._func(rhs)
   1151 
   1152   def __repr__(self):
   1153     return str(self._func)
   1154 
   1155 
   1156 class IgnoreArg(Comparator):
   1157   """Ignore an argument.
   1158 
   1159   This can be used when we don't care about an argument of a method call.
   1160 
   1161   Example:
   1162   # Check if CastMagic is called with 3 as first arg and 'disappear' as third.
   1163   mymock.CastMagic(3, IgnoreArg(), 'disappear')
   1164   """
   1165 
   1166   def equals(self, unused_rhs):
   1167     """Ignores arguments and returns True.
   1168 
   1169     Args:
   1170       unused_rhs: any python object
   1171 
   1172     Returns:
   1173       always returns True
   1174     """
   1175 
   1176     return True
   1177 
   1178   def __repr__(self):
   1179     return '<IgnoreArg>'
   1180 
   1181 
   1182 class MethodGroup(object):
   1183   """Base class containing common behaviour for MethodGroups."""
   1184 
   1185   def __init__(self, group_name):
   1186     self._group_name = group_name
   1187 
   1188   def group_name(self):
   1189     return self._group_name
   1190 
   1191   def __str__(self):
   1192     return '<%s "%s">' % (self.__class__.__name__, self._group_name)
   1193 
   1194   def AddMethod(self, mock_method):
   1195     raise NotImplementedError
   1196 
   1197   def MethodCalled(self, mock_method):
   1198     raise NotImplementedError
   1199 
   1200   def IsSatisfied(self):
   1201     raise NotImplementedError
   1202 
   1203 class UnorderedGroup(MethodGroup):
   1204   """UnorderedGroup holds a set of method calls that may occur in any order.
   1205 
   1206   This construct is helpful for non-deterministic events, such as iterating
   1207   over the keys of a dict.
   1208   """
   1209 
   1210   def __init__(self, group_name):
   1211     super(UnorderedGroup, self).__init__(group_name)
   1212     self._methods = []
   1213 
   1214   def AddMethod(self, mock_method):
   1215     """Add a method to this group.
   1216 
   1217     Args:
   1218       mock_method: A mock method to be added to this group.
   1219     """
   1220 
   1221     self._methods.append(mock_method)
   1222 
   1223   def MethodCalled(self, mock_method):
   1224     """Remove a method call from the group.
   1225 
   1226     If the method is not in the set, an UnexpectedMethodCallError will be
   1227     raised.
   1228 
   1229     Args:
   1230       mock_method: a mock method that should be equal to a method in the group.
   1231 
   1232     Returns:
   1233       The mock method from the group
   1234 
   1235     Raises:
   1236       UnexpectedMethodCallError if the mock_method was not in the group.
   1237     """
   1238 
   1239     # Check to see if this method exists, and if so, remove it from the set
   1240     # and return it.
   1241     for method in self._methods:
   1242       if method == mock_method:
   1243         # Remove the called mock_method instead of the method in the group.
   1244         # The called method will match any comparators when equality is checked
   1245         # during removal.  The method in the group could pass a comparator to
   1246         # another comparator during the equality check.
   1247         self._methods.remove(mock_method)
   1248 
   1249         # If this group is not empty, put it back at the head of the queue.
   1250         if not self.IsSatisfied():
   1251           mock_method._call_queue.appendleft(self)
   1252 
   1253         return self, method
   1254 
   1255     raise UnexpectedMethodCallError(mock_method, self)
   1256 
   1257   def IsSatisfied(self):
   1258     """Return True if there are not any methods in this group."""
   1259 
   1260     return len(self._methods) == 0
   1261 
   1262 
   1263 class MultipleTimesGroup(MethodGroup):
   1264   """MultipleTimesGroup holds methods that may be called any number of times.
   1265 
   1266   Note: Each method must be called at least once.
   1267 
   1268   This is helpful, if you don't know or care how many times a method is called.
   1269   """
   1270 
   1271   def __init__(self, group_name):
   1272     super(MultipleTimesGroup, self).__init__(group_name)
   1273     self._methods = set()
   1274     self._methods_called = set()
   1275 
   1276   def AddMethod(self, mock_method):
   1277     """Add a method to this group.
   1278 
   1279     Args:
   1280       mock_method: A mock method to be added to this group.
   1281     """
   1282 
   1283     self._methods.add(mock_method)
   1284 
   1285   def MethodCalled(self, mock_method):
   1286     """Remove a method call from the group.
   1287 
   1288     If the method is not in the set, an UnexpectedMethodCallError will be
   1289     raised.
   1290 
   1291     Args:
   1292       mock_method: a mock method that should be equal to a method in the group.
   1293 
   1294     Returns:
   1295       The mock method from the group
   1296 
   1297     Raises:
   1298       UnexpectedMethodCallError if the mock_method was not in the group.
   1299     """
   1300 
   1301     # Check to see if this method exists, and if so add it to the set of
   1302     # called methods.
   1303 
   1304     for method in self._methods:
   1305       if method == mock_method:
   1306         self._methods_called.add(mock_method)
   1307         # Always put this group back on top of the queue, because we don't know
   1308         # when we are done.
   1309         mock_method._call_queue.appendleft(self)
   1310         return self, method
   1311 
   1312     if self.IsSatisfied():
   1313       next_method = mock_method._PopNextMethod();
   1314       return next_method, None
   1315     else:
   1316       raise UnexpectedMethodCallError(mock_method, self)
   1317 
   1318   def IsSatisfied(self):
   1319     """Return True if all methods in this group are called at least once."""
   1320     # NOTE(psycho): We can't use the simple set difference here because we want
   1321     # to match different parameters which are considered the same e.g. IsA(str)
   1322     # and some string. This solution is O(n^2) but n should be small.
   1323     tmp = self._methods.copy()
   1324     for called in self._methods_called:
   1325       for expected in tmp:
   1326         if called == expected:
   1327           tmp.remove(expected)
   1328           if not tmp:
   1329             return True
   1330           break
   1331     return False
   1332 
   1333 
   1334 class MoxMetaTestBase(type):
   1335   """Metaclass to add mox cleanup and verification to every test.
   1336 
   1337   As the mox unit testing class is being constructed (MoxTestBase or a
   1338   subclass), this metaclass will modify all test functions to call the
   1339   CleanUpMox method of the test class after they finish. This means that
   1340   unstubbing and verifying will happen for every test with no additional code,
   1341   and any failures will result in test failures as opposed to errors.
   1342   """
   1343 
   1344   def __init__(cls, name, bases, d):
   1345     type.__init__(cls, name, bases, d)
   1346 
   1347     # also get all the attributes from the base classes to account
   1348     # for a case when test class is not the immediate child of MoxTestBase
   1349     for base in bases:
   1350       for attr_name in dir(base):
   1351         d[attr_name] = getattr(base, attr_name)
   1352 
   1353     for func_name, func in d.items():
   1354       if func_name.startswith('test') and callable(func):
   1355         setattr(cls, func_name, MoxMetaTestBase.CleanUpTest(cls, func))
   1356 
   1357   @staticmethod
   1358   def CleanUpTest(cls, func):
   1359     """Adds Mox cleanup code to any MoxTestBase method.
   1360 
   1361     Always unsets stubs after a test. Will verify all mocks for tests that
   1362     otherwise pass.
   1363 
   1364     Args:
   1365       cls: MoxTestBase or subclass; the class whose test method we are altering.
   1366       func: method; the method of the MoxTestBase test class we wish to alter.
   1367 
   1368     Returns:
   1369       The modified method.
   1370     """
   1371     def new_method(self, *args, **kwargs):
   1372       mox_obj = getattr(self, 'mox', None)
   1373       cleanup_mox = False
   1374       if mox_obj and isinstance(mox_obj, Mox):
   1375         cleanup_mox = True
   1376       try:
   1377         func(self, *args, **kwargs)
   1378       finally:
   1379         if cleanup_mox:
   1380           mox_obj.UnsetStubs()
   1381       if cleanup_mox:
   1382         mox_obj.VerifyAll()
   1383     new_method.__name__ = func.__name__
   1384     new_method.__doc__ = func.__doc__
   1385     new_method.__module__ = func.__module__
   1386     return new_method
   1387 
   1388 
   1389 class MoxTestBase(unittest.TestCase):
   1390   """Convenience test class to make stubbing easier.
   1391 
   1392   Sets up a "mox" attribute which is an instance of Mox - any mox tests will
   1393   want this. Also automatically unsets any stubs and verifies that all mock
   1394   methods have been called at the end of each test, eliminating boilerplate
   1395   code.
   1396   """
   1397 
   1398   __metaclass__ = MoxMetaTestBase
   1399 
   1400   def setUp(self):
   1401     self.mox = Mox()
   1402