# Copyright (c) 2009, Google Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#     * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Class for unittest support. Used for capturing stderr/stdout.

import logging
import sys
import unittest

try:
    # Python 2: prefer StringIO.StringIO, which accepts byte strings as well
    # as unicode (io.StringIO on Python 2 accepts only unicode).
    from StringIO import StringIO
except ImportError:
    # Python 3: the StringIO module no longer exists; io.StringIO replaces it.
    from io import StringIO


class OutputCapture(object):
    """Temporarily redirect sys.stdout, sys.stderr and root-logger output.

    Call capture_output() to start capturing and restore_output() to put the
    original streams back and collect everything that was written meanwhile.
    """

    # By default we capture the output to a stream. Other modules may override
    # this function in order to do things like pass through the output. See
    # webkitpy.test.main for an example.
    @staticmethod
    def stream_wrapper(stream):
        return StringIO()

    def __init__(self):
        # Maps "stdout"/"stderr" to the stream that was installed on sys
        # before capture started, so it can be put back afterwards.
        self.saved_outputs = dict()
        self._log_level = logging.INFO

    def set_log_level(self, log_level):
        """Set the minimum level of log records that are captured.

        If a capture is in progress the new level takes effect immediately.
        """
        self._log_level = log_level
        if hasattr(self, '_logs_handler'):
            self._logs_handler.setLevel(self._log_level)
            # Keep the root logger's own threshold in step with the handler,
            # computed exactly as capture_output() does. Otherwise lowering
            # the level mid-capture would be filtered out by the logger
            # before the records ever reached the capturing handler.
            self._logger.setLevel(min(self._log_level, self._orig_log_level))

    def _capture_output_with_name(self, output_name):
        """Install a capture stream as sys.<output_name> and return it."""
        stream = getattr(sys, output_name)
        captured_output = self.stream_wrapper(stream)
        self.saved_outputs[output_name] = stream
        setattr(sys, output_name, captured_output)
        return captured_output

    def _restore_output_with_name(self, output_name):
        """Reinstall the saved sys.<output_name> and return the captured text.

        Assumes sys.<output_name> is still the capture stream installed by
        _capture_output_with_name (it must provide getvalue()).
        """
        captured_output = getattr(sys, output_name).getvalue()
        setattr(sys, output_name, self.saved_outputs[output_name])
        del self.saved_outputs[output_name]
        return captured_output

    def capture_output(self):
        """Start capturing stdout, stderr and root-logger output.

        Returns:
            A (stdout_stream, stderr_stream) tuple: the objects now installed
            as sys.stdout and sys.stderr.
        """
        self._logs = StringIO()
        self._logs_handler = logging.StreamHandler(self._logs)
        self._logs_handler.setLevel(self._log_level)
        self._logger = logging.getLogger()
        self._orig_log_level = self._logger.level
        self._logger.addHandler(self._logs_handler)
        # Lower (never raise) the root logger's threshold so records at the
        # requested level reach the capturing handler.
        self._logger.setLevel(min(self._log_level, self._orig_log_level))
        return (self._capture_output_with_name("stdout"), self._capture_output_with_name("stderr"))

    def restore_output(self):
        """Stop capturing and restore the original streams and log level.

        Returns:
            A (stdout_text, stderr_text, logs_text) tuple of everything
            captured since capture_output().
        """
        self._logger.removeHandler(self._logs_handler)
        self._logger.setLevel(self._orig_log_level)
        self._logs_handler.flush()
        self._logs.flush()
        logs_string = self._logs.getvalue()
        # The handler was created for this capture only; release it now that
        # it is detached and its text has been read.
        self._logs_handler.close()
        delattr(self, '_logs_handler')
        delattr(self, '_logs')
        stdout_string = self._restore_output_with_name("stdout")
        stderr_string = self._restore_output_with_name("stderr")
        return (stdout_string, stderr_string, logs_string)

    def assert_outputs(self, testcase, function, args=None, kwargs=None, expected_stdout="", expected_stderr="", expected_exception=None, expected_logs=None):
        """Call function under capture and assert on what it wrote.

        Args:
            testcase: the unittest.TestCase used to perform the assertions.
            function: the callable to run.
            args: positional arguments for function (default: none).
            kwargs: keyword arguments for function (default: none).
            expected_stdout: exact text expected on stdout.
            expected_stderr: exact text expected on stderr.
            expected_exception: if set, function must raise this exception.
            expected_logs: if not None, exact text expected in the logs.

        Returns:
            function's return value, or whatever testcase.assertRaises returns
            when expected_exception is given.
        """
        # None sentinels instead of shared mutable default arguments.
        if args is None:
            args = []
        if kwargs is None:
            kwargs = {}
        self.capture_output()
        try:
            if expected_exception:
                return_value = testcase.assertRaises(expected_exception, function, *args, **kwargs)
            else:
                return_value = function(*args, **kwargs)
        finally:
            # Always restore the real streams, even if function raised.
            (stdout_string, stderr_string, logs_string) = self.restore_output()

        # assertMultiLineEqual gives readable diffs; fall back when missing.
        if hasattr(testcase, 'assertMultiLineEqual'):
            testassert = testcase.assertMultiLineEqual
        else:
            testassert = testcase.assertEqual

        testassert(stdout_string, expected_stdout)
        testassert(stderr_string, expected_stderr)
        if expected_logs is not None:
            testassert(logs_string, expected_logs)
        # This is a little strange, but I don't know where else to return this information.
        return return_value


class OutputCaptureTestCaseBase(unittest.TestCase):
    """TestCase base that captures stdout/stderr/logs for every test.

    setUp() starts a capture and tearDown() restores the real streams;
    assertStdout()/assertStderr() compare against all text captured so far.
    """
    maxDiff = None

    def setUp(self):
        unittest.TestCase.setUp(self)
        self.output_capture = OutputCapture()
        (self.__captured_stdout, self.__captured_stderr) = self.output_capture.capture_output()

    def tearDown(self):
        del self.__captured_stdout
        del self.__captured_stderr
        self.output_capture.restore_output()
        unittest.TestCase.tearDown(self)

    def assertStdout(self, expected_stdout):
        self.assertEqual(expected_stdout, self.__captured_stdout.getvalue())

    def assertStderr(self, expected_stderr):
        self.assertEqual(expected_stderr, self.__captured_stderr.getvalue())