Home | History | Annotate | Download | only in test_utils
      1 __author__ = "raphtee (at] google.com (Travis Miller)"
      2 
      3 
      4 import re, collections, StringIO, sys, unittest
      5 
      6 
      7 class StubNotFoundError(Exception):
      8     'Raised when god is asked to unstub an attribute that was not stubbed'
      9     pass
     10 
     11 
     12 class CheckPlaybackError(Exception):
     13     'Raised when mock playback does not match recorded calls.'
     14     pass
     15 
     16 
     17 class SaveDataAfterCloseStringIO(StringIO.StringIO):
     18     """Saves the contents in a final_data property when close() is called.
     19 
     20     Useful as a mock output file object to test both that the file was
     21     closed and what was written.
     22 
     23     Properties:
     24       final_data: Set to the StringIO's getvalue() data when close() is
     25           called.  None if close() has not been called.
     26     """
     27     final_data = None
     28 
     29     def close(self):
     30         self.final_data = self.getvalue()
     31         StringIO.StringIO.close(self)
     32 
     33 
     34 
     35 class argument_comparator(object):
     36     def is_satisfied_by(self, parameter):
     37         raise NotImplementedError
     38 
     39 
     40 class equality_comparator(argument_comparator):
     41     def __init__(self, value):
     42         self.value = value
     43 
     44 
     45     @staticmethod
     46     def _types_match(arg1, arg2):
     47         if isinstance(arg1, basestring) and isinstance(arg2, basestring):
     48             return True
     49         return type(arg1) == type(arg2)
     50 
     51 
     52     @classmethod
     53     def _compare(cls, actual_arg, expected_arg):
     54         if isinstance(expected_arg, argument_comparator):
     55             return expected_arg.is_satisfied_by(actual_arg)
     56         if not cls._types_match(expected_arg, actual_arg):
     57             return False
     58 
     59         if isinstance(expected_arg, list) or isinstance(expected_arg, tuple):
     60             # recurse on lists/tuples
     61             if len(actual_arg) != len(expected_arg):
     62                 return False
     63             for actual_item, expected_item in zip(actual_arg, expected_arg):
     64                 if not cls._compare(actual_item, expected_item):
     65                     return False
     66         elif isinstance(expected_arg, dict):
     67             # recurse on dicts
     68             if not cls._compare(sorted(actual_arg.keys()),
     69                                 sorted(expected_arg.keys())):
     70                 return False
     71             for key, value in actual_arg.iteritems():
     72                 if not cls._compare(value, expected_arg[key]):
     73                     return False
     74         elif actual_arg != expected_arg:
     75             return False
     76 
     77         return True
     78 
     79 
     80     def is_satisfied_by(self, parameter):
     81         return self._compare(parameter, self.value)
     82 
     83 
     84     def __str__(self):
     85         if isinstance(self.value, argument_comparator):
     86             return str(self.value)
     87         return repr(self.value)
     88 
     89 
     90 class regex_comparator(argument_comparator):
     91     def __init__(self, pattern, flags=0):
     92         self.regex = re.compile(pattern, flags)
     93 
     94 
     95     def is_satisfied_by(self, parameter):
     96         return self.regex.search(parameter) is not None
     97 
     98 
     99     def __str__(self):
    100         return self.regex.pattern
    101 
    102 
    103 class is_string_comparator(argument_comparator):
    104     def is_satisfied_by(self, parameter):
    105         return isinstance(parameter, basestring)
    106 
    107 
    108     def __str__(self):
    109         return "a string"
    110 
    111 
    112 class is_instance_comparator(argument_comparator):
    113     def __init__(self, cls):
    114         self.cls = cls
    115 
    116 
    117     def is_satisfied_by(self, parameter):
    118         return isinstance(parameter, self.cls)
    119 
    120 
    121     def __str__(self):
    122         return "is a %s" % self.cls
    123 
    124 
    125 class anything_comparator(argument_comparator):
    126     def is_satisfied_by(self, parameter):
    127         return True
    128 
    129 
    130     def __str__(self):
    131         return 'anything'
    132 
    133 
    134 class base_mapping(object):
    135     def __init__(self, symbol, return_obj, *args, **dargs):
    136         self.return_obj = return_obj
    137         self.symbol = symbol
    138         self.args = [equality_comparator(arg) for arg in args]
    139         self.dargs = dict((key, equality_comparator(value))
    140                           for key, value in dargs.iteritems())
    141         self.error = None
    142 
    143 
    144     def match(self, *args, **dargs):
    145         if len(args) != len(self.args) or len(dargs) != len(self.dargs):
    146             return False
    147 
    148         for i, expected_arg in enumerate(self.args):
    149             if not expected_arg.is_satisfied_by(args[i]):
    150                 return False
    151 
    152         # check for incorrect dargs
    153         for key, value in dargs.iteritems():
    154             if key not in self.dargs:
    155                 return False
    156             if not self.dargs[key].is_satisfied_by(value):
    157                 return False
    158 
    159         # check for missing dargs
    160         for key in self.dargs.iterkeys():
    161             if key not in dargs:
    162                 return False
    163 
    164         return True
    165 
    166 
    167     def __str__(self):
    168         return _dump_function_call(self.symbol, self.args, self.dargs)
    169 
    170 
    171 class function_mapping(base_mapping):
    172     def __init__(self, symbol, return_val, *args, **dargs):
    173         super(function_mapping, self).__init__(symbol, return_val, *args,
    174                                                **dargs)
    175 
    176 
    177     def and_return(self, return_obj):
    178         self.return_obj = return_obj
    179 
    180 
    181     def and_raises(self, error):
    182         self.error = error
    183 
    184 
    185 class function_any_args_mapping(function_mapping):
    186     """A mock function mapping that doesn't verify its arguments."""
    187     def match(self, *args, **dargs):
    188         return True
    189 
    190 
    191 class mock_function(object):
    192     def __init__(self, symbol, default_return_val=None,
    193                  record=None, playback=None):
    194         self.default_return_val = default_return_val
    195         self.num_calls = 0
    196         self.args = []
    197         self.dargs = []
    198         self.symbol = symbol
    199         self.record = record
    200         self.playback = playback
    201         self.__name__ = symbol
    202 
    203 
    204     def __call__(self, *args, **dargs):
    205         self.num_calls += 1
    206         self.args.append(args)
    207         self.dargs.append(dargs)
    208         if self.playback:
    209             return self.playback(self.symbol, *args, **dargs)
    210         else:
    211             return self.default_return_val
    212 
    213 
    214     def expect_call(self, *args, **dargs):
    215         mapping = function_mapping(self.symbol, None, *args, **dargs)
    216         if self.record:
    217             self.record(mapping)
    218 
    219         return mapping
    220 
    221 
    222     def expect_any_call(self):
    223         """Like expect_call but don't give a hoot what arguments are passed."""
    224         mapping = function_any_args_mapping(self.symbol, None)
    225         if self.record:
    226             self.record(mapping)
    227 
    228         return mapping
    229 
    230 
    231 class mask_function(mock_function):
    232     def __init__(self, symbol, original_function, default_return_val=None,
    233                  record=None, playback=None):
    234         super(mask_function, self).__init__(symbol,
    235                                             default_return_val,
    236                                             record, playback)
    237         self.original_function = original_function
    238 
    239 
    240     def run_original_function(self, *args, **dargs):
    241         return self.original_function(*args, **dargs)
    242 
    243 
    244 class mock_class(object):
    245     def __init__(self, cls, name, default_ret_val=None,
    246                  record=None, playback=None):
    247         self.__name = name
    248         self.__record = record
    249         self.__playback = playback
    250 
    251         for symbol in dir(cls):
    252             if symbol.startswith("_"):
    253                 continue
    254 
    255             orig_symbol = getattr(cls, symbol)
    256             if callable(orig_symbol):
    257                 f_name = "%s.%s" % (self.__name, symbol)
    258                 func = mock_function(f_name, default_ret_val,
    259                                      self.__record, self.__playback)
    260                 setattr(self, symbol, func)
    261             else:
    262                 setattr(self, symbol, orig_symbol)
    263 
    264 
    265     def __repr__(self):
    266         return '<mock_class: %s>' % self.__name
    267 
    268 
    269 class mock_god(object):
    270     NONEXISTENT_ATTRIBUTE = object()
    271 
    272     def __init__(self, debug=False, fail_fast=True, ut=None):
    273         """
    274         With debug=True, all recorded method calls will be printed as
    275         they happen.
    276         With fail_fast=True, unexpected calls will immediately cause an
    277         exception to be raised.  With False, they will be silently recorded and
    278         only reported when check_playback() is called.
    279         """
    280         self.recording = collections.deque()
    281         self.errors = []
    282         self._stubs = []
    283         self._debug = debug
    284         self._fail_fast = fail_fast
    285         self._ut = ut
    286 
    287 
    288     def set_fail_fast(self, fail_fast):
    289         self._fail_fast = fail_fast
    290 
    291 
    292     def create_mock_class_obj(self, cls, name, default_ret_val=None):
    293         record = self.__record_call
    294         playback = self.__method_playback
    295         errors = self.errors
    296 
    297         class cls_sub(cls):
    298             cls_count = 0
    299 
    300             # overwrite the initializer
    301             def __init__(self, *args, **dargs):
    302                 pass
    303 
    304 
    305             @classmethod
    306             def expect_new(typ, *args, **dargs):
    307                 obj = typ.make_new(*args, **dargs)
    308                 mapping = base_mapping(name, obj, *args, **dargs)
    309                 record(mapping)
    310                 return obj
    311 
    312 
    313             def __new__(typ, *args, **dargs):
    314                 return playback(name, *args, **dargs)
    315 
    316 
    317             @classmethod
    318             def make_new(typ, *args, **dargs):
    319                 obj = super(cls_sub, typ).__new__(typ, *args,
    320                                                   **dargs)
    321 
    322                 typ.cls_count += 1
    323                 obj_name = "%s_%s" % (name, typ.cls_count)
    324                 for symbol in dir(obj):
    325                     if (symbol.startswith("__") and
    326                         symbol.endswith("__")):
    327                         continue
    328 
    329                     if isinstance(getattr(typ, symbol, None), property):
    330                         continue
    331 
    332                     orig_symbol = getattr(obj, symbol)
    333                     if callable(orig_symbol):
    334                         f_name = ("%s.%s" %
    335                                   (obj_name, symbol))
    336                         func = mock_function(f_name,
    337                                         default_ret_val,
    338                                         record,
    339                                         playback)
    340                         setattr(obj, symbol, func)
    341                     else:
    342                         setattr(obj, symbol,
    343                                 orig_symbol)
    344 
    345                 return obj
    346 
    347         return cls_sub
    348 
    349 
    350     def create_mock_class(self, cls, name, default_ret_val=None):
    351         """
    352         Given something that defines a namespace cls (class, object,
    353         module), and a (hopefully unique) name, will create a
    354         mock_class object with that name and that possessess all
    355         the public attributes of cls.  default_ret_val sets the
    356         default_ret_val on all methods of the cls mock.
    357         """
    358         return mock_class(cls, name, default_ret_val,
    359                           self.__record_call, self.__method_playback)
    360 
    361 
    362     def create_mock_function(self, symbol, default_return_val=None):
    363         """
    364         create a mock_function with name symbol and default return
    365         value of default_ret_val.
    366         """
    367         return mock_function(symbol, default_return_val,
    368                              self.__record_call, self.__method_playback)
    369 
    370 
    371     def mock_up(self, obj, name, default_ret_val=None):
    372         """
    373         Given an object (class instance or module) and a registration
    374         name, then replace all its methods with mock function objects
    375         (passing the orignal functions to the mock functions).
    376         """
    377         for symbol in dir(obj):
    378             if symbol.startswith("__"):
    379                 continue
    380 
    381             orig_symbol = getattr(obj, symbol)
    382             if callable(orig_symbol):
    383                 f_name = "%s.%s" % (name, symbol)
    384                 func = mask_function(f_name, orig_symbol,
    385                                      default_ret_val,
    386                                      self.__record_call,
    387                                      self.__method_playback)
    388                 setattr(obj, symbol, func)
    389 
    390 
    391     def stub_with(self, namespace, symbol, new_attribute):
    392         original_attribute = getattr(namespace, symbol,
    393                                      self.NONEXISTENT_ATTRIBUTE)
    394 
    395         # You only want to save the original attribute in cases where it is
    396         # directly associated with the object in question. In cases where
    397         # the attribute is actually inherited via some sort of hierarchy
    398         # you want to delete the stub (restoring the original structure)
    399         attribute_is_inherited = (hasattr(namespace, '__dict__') and
    400                                   symbol not in namespace.__dict__)
    401         if attribute_is_inherited:
    402             original_attribute = self.NONEXISTENT_ATTRIBUTE
    403 
    404         newstub = (namespace, symbol, original_attribute, new_attribute)
    405         self._stubs.append(newstub)
    406         setattr(namespace, symbol, new_attribute)
    407 
    408 
    409     def stub_function(self, namespace, symbol):
    410         mock_attribute = self.create_mock_function(symbol)
    411         self.stub_with(namespace, symbol, mock_attribute)
    412 
    413 
    414     def stub_class_method(self, cls, symbol):
    415         mock_attribute = self.create_mock_function(symbol)
    416         self.stub_with(cls, symbol, staticmethod(mock_attribute))
    417 
    418 
    419     def stub_class(self, namespace, symbol):
    420         attr = getattr(namespace, symbol)
    421         mock_class = self.create_mock_class_obj(attr, symbol)
    422         self.stub_with(namespace, symbol, mock_class)
    423 
    424 
    425     def stub_function_to_return(self, namespace, symbol, object_to_return):
    426         """Stub out a function with one that always returns a fixed value.
    427 
    428         @param namespace The namespace containing the function to stub out.
    429         @param symbol The attribute within the namespace to stub out.
    430         @param object_to_return The value that the stub should return whenever
    431             it is called.
    432         """
    433         self.stub_with(namespace, symbol,
    434                        lambda *args, **dargs: object_to_return)
    435 
    436 
    437     def _perform_unstub(self, stub):
    438         namespace, symbol, orig_attr, new_attr = stub
    439         if orig_attr == self.NONEXISTENT_ATTRIBUTE:
    440             delattr(namespace, symbol)
    441         else:
    442             setattr(namespace, symbol, orig_attr)
    443 
    444 
    445     def unstub(self, namespace, symbol):
    446         for stub in reversed(self._stubs):
    447             if (namespace, symbol) == (stub[0], stub[1]):
    448                 self._perform_unstub(stub)
    449                 self._stubs.remove(stub)
    450                 return
    451 
    452         raise StubNotFoundError()
    453 
    454 
    455     def unstub_all(self):
    456         self._stubs.reverse()
    457         for stub in self._stubs:
    458             self._perform_unstub(stub)
    459         self._stubs = []
    460 
    461 
    462     def __method_playback(self, symbol, *args, **dargs):
    463         if self._debug:
    464             print >> sys.__stdout__, (' * Mock call: ' +
    465                                       _dump_function_call(symbol, args, dargs))
    466 
    467         if len(self.recording) != 0:
    468             func_call = self.recording[0]
    469             if func_call.symbol != symbol:
    470                 msg = ("Unexpected call: %s\nExpected: %s"
    471                     % (_dump_function_call(symbol, args, dargs),
    472                        func_call))
    473                 self._append_error(msg)
    474                 return None
    475 
    476             if not func_call.match(*args, **dargs):
    477                 msg = ("Incorrect call: %s\nExpected: %s"
    478                     % (_dump_function_call(symbol, args, dargs),
    479                       func_call))
    480                 self._append_error(msg)
    481                 return None
    482 
    483             # this is the expected call so pop it and return
    484             self.recording.popleft()
    485             if func_call.error:
    486                 raise func_call.error
    487             else:
    488                 return func_call.return_obj
    489         else:
    490             msg = ("unexpected call: %s"
    491                    % (_dump_function_call(symbol, args, dargs)))
    492             self._append_error(msg)
    493             return None
    494 
    495 
    496     def __record_call(self, mapping):
    497         self.recording.append(mapping)
    498 
    499 
    500     def _append_error(self, error):
    501         if self._debug:
    502             print >> sys.__stdout__, ' *** ' + error
    503         if self._fail_fast:
    504             raise CheckPlaybackError(error)
    505         self.errors.append(error)
    506 
    507 
    508     def check_playback(self):
    509         """
    510         Report any errors that were encounterd during calls
    511         to __method_playback().
    512         """
    513         if len(self.errors) > 0:
    514             if self._debug:
    515                 print '\nPlayback errors:'
    516             for error in self.errors:
    517                 print >> sys.__stdout__, error
    518 
    519             if self._ut:
    520                 self._ut.fail('\n'.join(self.errors))
    521 
    522             raise CheckPlaybackError
    523         elif len(self.recording) != 0:
    524             errors = []
    525             for func_call in self.recording:
    526                 error = "%s not called" % (func_call,)
    527                 errors.append(error)
    528                 print >> sys.__stdout__, error
    529 
    530             if self._ut:
    531                 self._ut.fail('\n'.join(errors))
    532 
    533             raise CheckPlaybackError
    534         self.recording.clear()
    535 
    536 
    537     def mock_io(self):
    538         """Mocks and saves the stdout & stderr output"""
    539         self.orig_stdout = sys.stdout
    540         self.orig_stderr = sys.stderr
    541 
    542         self.mock_streams_stdout = StringIO.StringIO('')
    543         self.mock_streams_stderr = StringIO.StringIO('')
    544 
    545         sys.stdout = self.mock_streams_stdout
    546         sys.stderr = self.mock_streams_stderr
    547 
    548 
    549     def unmock_io(self):
    550         """Restores the stdout & stderr, and returns both
    551         output strings"""
    552         sys.stdout = self.orig_stdout
    553         sys.stderr = self.orig_stderr
    554         values = (self.mock_streams_stdout.getvalue(),
    555                   self.mock_streams_stderr.getvalue())
    556 
    557         self.mock_streams_stdout.close()
    558         self.mock_streams_stderr.close()
    559         return values
    560 
    561 
    562 def _arg_to_str(arg):
    563     if isinstance(arg, argument_comparator):
    564         return str(arg)
    565     return repr(arg)
    566 
    567 
    568 def _dump_function_call(symbol, args, dargs):
    569     arg_vec = []
    570     for arg in args:
    571         arg_vec.append(_arg_to_str(arg))
    572     for key, val in dargs.iteritems():
    573         arg_vec.append("%s=%s" % (key, _arg_to_str(val)))
    574     return "%s(%s)" % (symbol, ', '.join(arg_vec))
    575