1 __author__ = "raphtee (at] google.com (Travis Miller)" 2 3 4 import re, collections, StringIO, sys, unittest 5 6 7 class StubNotFoundError(Exception): 8 'Raised when god is asked to unstub an attribute that was not stubbed' 9 pass 10 11 12 class CheckPlaybackError(Exception): 13 'Raised when mock playback does not match recorded calls.' 14 pass 15 16 17 class SaveDataAfterCloseStringIO(StringIO.StringIO): 18 """Saves the contents in a final_data property when close() is called. 19 20 Useful as a mock output file object to test both that the file was 21 closed and what was written. 22 23 Properties: 24 final_data: Set to the StringIO's getvalue() data when close() is 25 called. None if close() has not been called. 26 """ 27 final_data = None 28 29 def close(self): 30 self.final_data = self.getvalue() 31 StringIO.StringIO.close(self) 32 33 34 35 class argument_comparator(object): 36 def is_satisfied_by(self, parameter): 37 raise NotImplementedError 38 39 40 class equality_comparator(argument_comparator): 41 def __init__(self, value): 42 self.value = value 43 44 45 @staticmethod 46 def _types_match(arg1, arg2): 47 if isinstance(arg1, basestring) and isinstance(arg2, basestring): 48 return True 49 return type(arg1) == type(arg2) 50 51 52 @classmethod 53 def _compare(cls, actual_arg, expected_arg): 54 if isinstance(expected_arg, argument_comparator): 55 return expected_arg.is_satisfied_by(actual_arg) 56 if not cls._types_match(expected_arg, actual_arg): 57 return False 58 59 if isinstance(expected_arg, list) or isinstance(expected_arg, tuple): 60 # recurse on lists/tuples 61 if len(actual_arg) != len(expected_arg): 62 return False 63 for actual_item, expected_item in zip(actual_arg, expected_arg): 64 if not cls._compare(actual_item, expected_item): 65 return False 66 elif isinstance(expected_arg, dict): 67 # recurse on dicts 68 if not cls._compare(sorted(actual_arg.keys()), 69 sorted(expected_arg.keys())): 70 return False 71 for key, value in actual_arg.iteritems(): 72 if not cls._compare(value, expected_arg[key]): 73 return False 74 elif actual_arg != expected_arg: 75 return False 76 77 return True 78 79 80 def is_satisfied_by(self, parameter): 81 return self._compare(parameter, self.value) 82 83 84 def __str__(self): 85 if isinstance(self.value, argument_comparator): 86 return str(self.value) 87 return repr(self.value) 88 89 90 class regex_comparator(argument_comparator): 91 def __init__(self, pattern, flags=0): 92 self.regex = re.compile(pattern, flags) 93 94 95 def is_satisfied_by(self, parameter): 96 return self.regex.search(parameter) is not None 97 98 99 def __str__(self): 100 return self.regex.pattern 101 102 103 class is_string_comparator(argument_comparator): 104 def is_satisfied_by(self, parameter): 105 return isinstance(parameter, basestring) 106 107 108 def __str__(self): 109 return "a string" 110 111 112 class is_instance_comparator(argument_comparator): 113 def __init__(self, cls): 114 self.cls = cls 115 116 117 def is_satisfied_by(self, parameter): 118 return isinstance(parameter, self.cls) 119 120 121 def __str__(self): 122 return "is a %s" % self.cls 123 124 125 class anything_comparator(argument_comparator): 126 def is_satisfied_by(self, parameter): 127 return True 128 129 130 def __str__(self): 131 return 'anything' 132 133 134 class base_mapping(object): 135 def __init__(self, symbol, return_obj, *args, **dargs): 136 self.return_obj = return_obj 137 self.symbol = symbol 138 self.args = [equality_comparator(arg) for arg in args] 139 self.dargs = dict((key, equality_comparator(value)) 140 for key, value in dargs.iteritems()) 141 self.error = None 142 143 144 def match(self, *args, **dargs): 145 if len(args) != len(self.args) or len(dargs) != len(self.dargs): 146 return False 147 148 for i, expected_arg in enumerate(self.args): 149 if not expected_arg.is_satisfied_by(args[i]): 150 return False 151 152 # check for incorrect dargs 153 for key, value in dargs.iteritems(): 154 if key not in self.dargs: 155 return False 156 if not self.dargs[key].is_satisfied_by(value): 157 return False 158 159 # check for missing dargs 160 for key in self.dargs.iterkeys(): 161 if key not in dargs: 162 return False 163 164 return True 165 166 167 def __str__(self): 168 return _dump_function_call(self.symbol, self.args, self.dargs) 169 170 171 class function_mapping(base_mapping): 172 def __init__(self, symbol, return_val, *args, **dargs): 173 super(function_mapping, self).__init__(symbol, return_val, *args, 174 **dargs) 175 176 177 def and_return(self, return_obj): 178 self.return_obj = return_obj 179 180 181 def and_raises(self, error): 182 self.error = error 183 184 185 class function_any_args_mapping(function_mapping): 186 """A mock function mapping that doesn't verify its arguments.""" 187 def match(self, *args, **dargs): 188 return True 189 190 191 class mock_function(object): 192 def __init__(self, symbol, default_return_val=None, 193 record=None, playback=None): 194 self.default_return_val = default_return_val 195 self.num_calls = 0 196 self.args = [] 197 self.dargs = [] 198 self.symbol = symbol 199 self.record = record 200 self.playback = playback 201 self.__name__ = symbol 202 203 204 def __call__(self, *args, **dargs): 205 self.num_calls += 1 206 self.args.append(args) 207 self.dargs.append(dargs) 208 if self.playback: 209 return self.playback(self.symbol, *args, **dargs) 210 else: 211 return self.default_return_val 212 213 214 def expect_call(self, *args, **dargs): 215 mapping = function_mapping(self.symbol, None, *args, **dargs) 216 if self.record: 217 self.record(mapping) 218 219 return mapping 220 221 222 def expect_any_call(self): 223 """Like expect_call but don't give a hoot what arguments are passed.""" 224 mapping = function_any_args_mapping(self.symbol, None) 225 if self.record: 226 self.record(mapping) 227 228 return mapping 229 230 231 class mask_function(mock_function): 232 def __init__(self, symbol, original_function, default_return_val=None, 233 record=None, playback=None): 234 super(mask_function, self).__init__(symbol, 235 default_return_val, 236 record, playback) 237 self.original_function = original_function 238 239 240 def run_original_function(self, *args, **dargs): 241 return self.original_function(*args, **dargs) 242 243 244 class mock_class(object): 245 def __init__(self, cls, name, default_ret_val=None, 246 record=None, playback=None): 247 self.__name = name 248 self.__record = record 249 self.__playback = playback 250 251 for symbol in dir(cls): 252 if symbol.startswith("_"): 253 continue 254 255 orig_symbol = getattr(cls, symbol) 256 if callable(orig_symbol): 257 f_name = "%s.%s" % (self.__name, symbol) 258 func = mock_function(f_name, default_ret_val, 259 self.__record, self.__playback) 260 setattr(self, symbol, func) 261 else: 262 setattr(self, symbol, orig_symbol) 263 264 265 def __repr__(self): 266 return '<mock_class: %s>' % self.__name 267 268 269 class mock_god(object): 270 NONEXISTENT_ATTRIBUTE = object() 271 272 def __init__(self, debug=False, fail_fast=True, ut=None): 273 """ 274 With debug=True, all recorded method calls will be printed as 275 they happen. 276 With fail_fast=True, unexpected calls will immediately cause an 277 exception to be raised. With False, they will be silently recorded and 278 only reported when check_playback() is called. 279 """ 280 self.recording = collections.deque() 281 self.errors = [] 282 self._stubs = [] 283 self._debug = debug 284 self._fail_fast = fail_fast 285 self._ut = ut 286 287 288 def set_fail_fast(self, fail_fast): 289 self._fail_fast = fail_fast 290 291 292 def create_mock_class_obj(self, cls, name, default_ret_val=None): 293 record = self.__record_call 294 playback = self.__method_playback 295 errors = self.errors 296 297 class cls_sub(cls): 298 cls_count = 0 299 300 # overwrite the initializer 301 def __init__(self, *args, **dargs): 302 pass 303 304 305 @classmethod 306 def expect_new(typ, *args, **dargs): 307 obj = typ.make_new(*args, **dargs) 308 mapping = base_mapping(name, obj, *args, **dargs) 309 record(mapping) 310 return obj 311 312 313 def __new__(typ, *args, **dargs): 314 return playback(name, *args, **dargs) 315 316 317 @classmethod 318 def make_new(typ, *args, **dargs): 319 obj = super(cls_sub, typ).__new__(typ, *args, 320 **dargs) 321 322 typ.cls_count += 1 323 obj_name = "%s_%s" % (name, typ.cls_count) 324 for symbol in dir(obj): 325 if (symbol.startswith("__") and 326 symbol.endswith("__")): 327 continue 328 329 if isinstance(getattr(typ, symbol, None), property): 330 continue 331 332 orig_symbol = getattr(obj, symbol) 333 if callable(orig_symbol): 334 f_name = ("%s.%s" % 335 (obj_name, symbol)) 336 func = mock_function(f_name, 337 default_ret_val, 338 record, 339 playback) 340 setattr(obj, symbol, func) 341 else: 342 setattr(obj, symbol, 343 orig_symbol) 344 345 return obj 346 347 return cls_sub 348 349 350 def create_mock_class(self, cls, name, default_ret_val=None): 351 """ 352 Given something that defines a namespace cls (class, object, 353 module), and a (hopefully unique) name, will create a 354 mock_class object with that name and that possessess all 355 the public attributes of cls. default_ret_val sets the 356 default_ret_val on all methods of the cls mock. 357 """ 358 return mock_class(cls, name, default_ret_val, 359 self.__record_call, self.__method_playback) 360 361 362 def create_mock_function(self, symbol, default_return_val=None): 363 """ 364 create a mock_function with name symbol and default return 365 value of default_ret_val. 366 """ 367 return mock_function(symbol, default_return_val, 368 self.__record_call, self.__method_playback) 369 370 371 def mock_up(self, obj, name, default_ret_val=None): 372 """ 373 Given an object (class instance or module) and a registration 374 name, then replace all its methods with mock function objects 375 (passing the orignal functions to the mock functions). 376 """ 377 for symbol in dir(obj): 378 if symbol.startswith("__"): 379 continue 380 381 orig_symbol = getattr(obj, symbol) 382 if callable(orig_symbol): 383 f_name = "%s.%s" % (name, symbol) 384 func = mask_function(f_name, orig_symbol, 385 default_ret_val, 386 self.__record_call, 387 self.__method_playback) 388 setattr(obj, symbol, func) 389 390 391 def stub_with(self, namespace, symbol, new_attribute): 392 original_attribute = getattr(namespace, symbol, 393 self.NONEXISTENT_ATTRIBUTE) 394 395 # You only want to save the original attribute in cases where it is 396 # directly associated with the object in question. In cases where 397 # the attribute is actually inherited via some sort of hierarchy 398 # you want to delete the stub (restoring the original structure) 399 attribute_is_inherited = (hasattr(namespace, '__dict__') and 400 symbol not in namespace.__dict__) 401 if attribute_is_inherited: 402 original_attribute = self.NONEXISTENT_ATTRIBUTE 403 404 newstub = (namespace, symbol, original_attribute, new_attribute) 405 self._stubs.append(newstub) 406 setattr(namespace, symbol, new_attribute) 407 408 409 def stub_function(self, namespace, symbol): 410 mock_attribute = self.create_mock_function(symbol) 411 self.stub_with(namespace, symbol, mock_attribute) 412 413 414 def stub_class_method(self, cls, symbol): 415 mock_attribute = self.create_mock_function(symbol) 416 self.stub_with(cls, symbol, staticmethod(mock_attribute)) 417 418 419 def stub_class(self, namespace, symbol): 420 attr = getattr(namespace, symbol) 421 mock_class = self.create_mock_class_obj(attr, symbol) 422 self.stub_with(namespace, symbol, mock_class) 423 424 425 def stub_function_to_return(self, namespace, symbol, object_to_return): 426 """Stub out a function with one that always returns a fixed value. 427 428 @param namespace The namespace containing the function to stub out. 429 @param symbol The attribute within the namespace to stub out. 430 @param object_to_return The value that the stub should return whenever 431 it is called. 432 """ 433 self.stub_with(namespace, symbol, 434 lambda *args, **dargs: object_to_return) 435 436 437 def _perform_unstub(self, stub): 438 namespace, symbol, orig_attr, new_attr = stub 439 if orig_attr == self.NONEXISTENT_ATTRIBUTE: 440 delattr(namespace, symbol) 441 else: 442 setattr(namespace, symbol, orig_attr) 443 444 445 def unstub(self, namespace, symbol): 446 for stub in reversed(self._stubs): 447 if (namespace, symbol) == (stub[0], stub[1]): 448 self._perform_unstub(stub) 449 self._stubs.remove(stub) 450 return 451 452 raise StubNotFoundError() 453 454 455 def unstub_all(self): 456 self._stubs.reverse() 457 for stub in self._stubs: 458 self._perform_unstub(stub) 459 self._stubs = [] 460 461 462 def __method_playback(self, symbol, *args, **dargs): 463 if self._debug: 464 print >> sys.__stdout__, (' * Mock call: ' + 465 _dump_function_call(symbol, args, dargs)) 466 467 if len(self.recording) != 0: 468 func_call = self.recording[0] 469 if func_call.symbol != symbol: 470 msg = ("Unexpected call: %s\nExpected: %s" 471 % (_dump_function_call(symbol, args, dargs), 472 func_call)) 473 self._append_error(msg) 474 return None 475 476 if not func_call.match(*args, **dargs): 477 msg = ("Incorrect call: %s\nExpected: %s" 478 % (_dump_function_call(symbol, args, dargs), 479 func_call)) 480 self._append_error(msg) 481 return None 482 483 # this is the expected call so pop it and return 484 self.recording.popleft() 485 if func_call.error: 486 raise func_call.error 487 else: 488 return func_call.return_obj 489 else: 490 msg = ("unexpected call: %s" 491 % (_dump_function_call(symbol, args, dargs))) 492 self._append_error(msg) 493 return None 494 495 496 def __record_call(self, mapping): 497 self.recording.append(mapping) 498 499 500 def _append_error(self, error): 501 if self._debug: 502 print >> sys.__stdout__, ' *** ' + error 503 if self._fail_fast: 504 raise CheckPlaybackError(error) 505 self.errors.append(error) 506 507 508 def check_playback(self): 509 """ 510 Report any errors that were encounterd during calls 511 to __method_playback(). 512 """ 513 if len(self.errors) > 0: 514 if self._debug: 515 print '\nPlayback errors:' 516 for error in self.errors: 517 print >> sys.__stdout__, error 518 519 if self._ut: 520 self._ut.fail('\n'.join(self.errors)) 521 522 raise CheckPlaybackError 523 elif len(self.recording) != 0: 524 errors = [] 525 for func_call in self.recording: 526 error = "%s not called" % (func_call,) 527 errors.append(error) 528 print >> sys.__stdout__, error 529 530 if self._ut: 531 self._ut.fail('\n'.join(errors)) 532 533 raise CheckPlaybackError 534 self.recording.clear() 535 536 537 def mock_io(self): 538 """Mocks and saves the stdout & stderr output""" 539 self.orig_stdout = sys.stdout 540 self.orig_stderr = sys.stderr 541 542 self.mock_streams_stdout = StringIO.StringIO('') 543 self.mock_streams_stderr = StringIO.StringIO('') 544 545 sys.stdout = self.mock_streams_stdout 546 sys.stderr = self.mock_streams_stderr 547 548 549 def unmock_io(self): 550 """Restores the stdout & stderr, and returns both 551 output strings""" 552 sys.stdout = self.orig_stdout 553 sys.stderr = self.orig_stderr 554 values = (self.mock_streams_stdout.getvalue(), 555 self.mock_streams_stderr.getvalue()) 556 557 self.mock_streams_stdout.close() 558 self.mock_streams_stderr.close() 559 return values 560 561 562 def _arg_to_str(arg): 563 if isinstance(arg, argument_comparator): 564 return str(arg) 565 return repr(arg) 566 567 568 def _dump_function_call(symbol, args, dargs): 569 arg_vec = [] 570 for arg in args: 571 arg_vec.append(_arg_to_str(arg)) 572 for key, val in dargs.iteritems(): 573 arg_vec.append("%s=%s" % (key, _arg_to_str(val))) 574 return "%s(%s)" % (symbol, ', '.join(arg_vec)) 575