Home | History | Annotate | Download | only in testmock
      1 import unittest
      2 from warnings import catch_warnings
      3 
      4 from unittest.test.testmock.support import is_instance
      5 from unittest.mock import MagicMock, Mock, patch, sentinel, mock_open, call
      6 
      7 
      8 
      9 something  = sentinel.Something
     10 something_else  = sentinel.SomethingElse
     11 
     12 
     13 
     14 class WithTest(unittest.TestCase):
     15 
     16     def test_with_statement(self):
     17         with patch('%s.something' % __name__, sentinel.Something2):
     18             self.assertEqual(something, sentinel.Something2, "unpatched")
     19         self.assertEqual(something, sentinel.Something)
     20 
     21 
     22     def test_with_statement_exception(self):
     23         try:
     24             with patch('%s.something' % __name__, sentinel.Something2):
     25                 self.assertEqual(something, sentinel.Something2, "unpatched")
     26                 raise Exception('pow')
     27         except Exception:
     28             pass
     29         else:
     30             self.fail("patch swallowed exception")
     31         self.assertEqual(something, sentinel.Something)
     32 
     33 
     34     def test_with_statement_as(self):
     35         with patch('%s.something' % __name__) as mock_something:
     36             self.assertEqual(something, mock_something, "unpatched")
     37             self.assertTrue(is_instance(mock_something, MagicMock),
     38                             "patching wrong type")
     39         self.assertEqual(something, sentinel.Something)
     40 
     41 
     42     def test_patch_object_with_statement(self):
     43         class Foo(object):
     44             something = 'foo'
     45         original = Foo.something
     46         with patch.object(Foo, 'something'):
     47             self.assertNotEqual(Foo.something, original, "unpatched")
     48         self.assertEqual(Foo.something, original)
     49 
     50 
     51     def test_with_statement_nested(self):
     52         with catch_warnings(record=True):
     53             with patch('%s.something' % __name__) as mock_something, patch('%s.something_else' % __name__) as mock_something_else:
     54                 self.assertEqual(something, mock_something, "unpatched")
     55                 self.assertEqual(something_else, mock_something_else,
     56                                  "unpatched")
     57 
     58         self.assertEqual(something, sentinel.Something)
     59         self.assertEqual(something_else, sentinel.SomethingElse)
     60 
     61 
     62     def test_with_statement_specified(self):
     63         with patch('%s.something' % __name__, sentinel.Patched) as mock_something:
     64             self.assertEqual(something, mock_something, "unpatched")
     65             self.assertEqual(mock_something, sentinel.Patched, "wrong patch")
     66         self.assertEqual(something, sentinel.Something)
     67 
     68 
     69     def testContextManagerMocking(self):
     70         mock = Mock()
     71         mock.__enter__ = Mock()
     72         mock.__exit__ = Mock()
     73         mock.__exit__.return_value = False
     74 
     75         with mock as m:
     76             self.assertEqual(m, mock.__enter__.return_value)
     77         mock.__enter__.assert_called_with()
     78         mock.__exit__.assert_called_with(None, None, None)
     79 
     80 
     81     def test_context_manager_with_magic_mock(self):
     82         mock = MagicMock()
     83 
     84         with self.assertRaises(TypeError):
     85             with mock:
     86                 'foo' + 3
     87         mock.__enter__.assert_called_with()
     88         self.assertTrue(mock.__exit__.called)
     89 
     90 
     91     def test_with_statement_same_attribute(self):
     92         with patch('%s.something' % __name__, sentinel.Patched) as mock_something:
     93             self.assertEqual(something, mock_something, "unpatched")
     94 
     95             with patch('%s.something' % __name__) as mock_again:
     96                 self.assertEqual(something, mock_again, "unpatched")
     97 
     98             self.assertEqual(something, mock_something,
     99                              "restored with wrong instance")
    100 
    101         self.assertEqual(something, sentinel.Something, "not restored")
    102 
    103 
    104     def test_with_statement_imbricated(self):
    105         with patch('%s.something' % __name__) as mock_something:
    106             self.assertEqual(something, mock_something, "unpatched")
    107 
    108             with patch('%s.something_else' % __name__) as mock_something_else:
    109                 self.assertEqual(something_else, mock_something_else,
    110                                  "unpatched")
    111 
    112         self.assertEqual(something, sentinel.Something)
    113         self.assertEqual(something_else, sentinel.SomethingElse)
    114 
    115 
    116     def test_dict_context_manager(self):
    117         foo = {}
    118         with patch.dict(foo, {'a': 'b'}):
    119             self.assertEqual(foo, {'a': 'b'})
    120         self.assertEqual(foo, {})
    121 
    122         with self.assertRaises(NameError):
    123             with patch.dict(foo, {'a': 'b'}):
    124                 self.assertEqual(foo, {'a': 'b'})
    125                 raise NameError('Konrad')
    126 
    127         self.assertEqual(foo, {})
    128 
    129 
    130 
    131 class TestMockOpen(unittest.TestCase):
    132 
    133     def test_mock_open(self):
    134         mock = mock_open()
    135         with patch('%s.open' % __name__, mock, create=True) as patched:
    136             self.assertIs(patched, mock)
    137             open('foo')
    138 
    139         mock.assert_called_once_with('foo')
    140 
    141 
    142     def test_mock_open_context_manager(self):
    143         mock = mock_open()
    144         handle = mock.return_value
    145         with patch('%s.open' % __name__, mock, create=True):
    146             with open('foo') as f:
    147                 f.read()
    148 
    149         expected_calls = [call('foo'), call().__enter__(), call().read(),
    150                           call().__exit__(None, None, None)]
    151         self.assertEqual(mock.mock_calls, expected_calls)
    152         self.assertIs(f, handle)
    153 
    154     def test_mock_open_context_manager_multiple_times(self):
    155         mock = mock_open()
    156         with patch('%s.open' % __name__, mock, create=True):
    157             with open('foo') as f:
    158                 f.read()
    159             with open('bar') as f:
    160                 f.read()
    161 
    162         expected_calls = [
    163             call('foo'), call().__enter__(), call().read(),
    164             call().__exit__(None, None, None),
    165             call('bar'), call().__enter__(), call().read(),
    166             call().__exit__(None, None, None)]
    167         self.assertEqual(mock.mock_calls, expected_calls)
    168 
    169     def test_explicit_mock(self):
    170         mock = MagicMock()
    171         mock_open(mock)
    172 
    173         with patch('%s.open' % __name__, mock, create=True) as patched:
    174             self.assertIs(patched, mock)
    175             open('foo')
    176 
    177         mock.assert_called_once_with('foo')
    178 
    179 
    180     def test_read_data(self):
    181         mock = mock_open(read_data='foo')
    182         with patch('%s.open' % __name__, mock, create=True):
    183             h = open('bar')
    184             result = h.read()
    185 
    186         self.assertEqual(result, 'foo')
    187 
    188 
    189     def test_readline_data(self):
    190         # Check that readline will return all the lines from the fake file
    191         mock = mock_open(read_data='foo\nbar\nbaz\n')
    192         with patch('%s.open' % __name__, mock, create=True):
    193             h = open('bar')
    194             line1 = h.readline()
    195             line2 = h.readline()
    196             line3 = h.readline()
    197         self.assertEqual(line1, 'foo\n')
    198         self.assertEqual(line2, 'bar\n')
    199         self.assertEqual(line3, 'baz\n')
    200 
    201         # Check that we properly emulate a file that doesn't end in a newline
    202         mock = mock_open(read_data='foo')
    203         with patch('%s.open' % __name__, mock, create=True):
    204             h = open('bar')
    205             result = h.readline()
    206         self.assertEqual(result, 'foo')
    207 
    208 
    209     def test_readlines_data(self):
    210         # Test that emulating a file that ends in a newline character works
    211         mock = mock_open(read_data='foo\nbar\nbaz\n')
    212         with patch('%s.open' % __name__, mock, create=True):
    213             h = open('bar')
    214             result = h.readlines()
    215         self.assertEqual(result, ['foo\n', 'bar\n', 'baz\n'])
    216 
    217         # Test that files without a final newline will also be correctly
    218         # emulated
    219         mock = mock_open(read_data='foo\nbar\nbaz')
    220         with patch('%s.open' % __name__, mock, create=True):
    221             h = open('bar')
    222             result = h.readlines()
    223 
    224         self.assertEqual(result, ['foo\n', 'bar\n', 'baz'])
    225 
    226 
    227     def test_read_bytes(self):
    228         mock = mock_open(read_data=b'\xc6')
    229         with patch('%s.open' % __name__, mock, create=True):
    230             with open('abc', 'rb') as f:
    231                 result = f.read()
    232         self.assertEqual(result, b'\xc6')
    233 
    234 
    235     def test_readline_bytes(self):
    236         m = mock_open(read_data=b'abc\ndef\nghi\n')
    237         with patch('%s.open' % __name__, m, create=True):
    238             with open('abc', 'rb') as f:
    239                 line1 = f.readline()
    240                 line2 = f.readline()
    241                 line3 = f.readline()
    242         self.assertEqual(line1, b'abc\n')
    243         self.assertEqual(line2, b'def\n')
    244         self.assertEqual(line3, b'ghi\n')
    245 
    246 
    247     def test_readlines_bytes(self):
    248         m = mock_open(read_data=b'abc\ndef\nghi\n')
    249         with patch('%s.open' % __name__, m, create=True):
    250             with open('abc', 'rb') as f:
    251                 result = f.readlines()
    252         self.assertEqual(result, [b'abc\n', b'def\n', b'ghi\n'])
    253 
    254 
    255     def test_mock_open_read_with_argument(self):
    256         # At one point calling read with an argument was broken
    257         # for mocks returned by mock_open
    258         some_data = 'foo\nbar\nbaz'
    259         mock = mock_open(read_data=some_data)
    260         self.assertEqual(mock().read(10), some_data)
    261 
    262 
    263     def test_interleaved_reads(self):
    264         # Test that calling read, readline, and readlines pulls data
    265         # sequentially from the data we preload with
    266         mock = mock_open(read_data='foo\nbar\nbaz\n')
    267         with patch('%s.open' % __name__, mock, create=True):
    268             h = open('bar')
    269             line1 = h.readline()
    270             rest = h.readlines()
    271         self.assertEqual(line1, 'foo\n')
    272         self.assertEqual(rest, ['bar\n', 'baz\n'])
    273 
    274         mock = mock_open(read_data='foo\nbar\nbaz\n')
    275         with patch('%s.open' % __name__, mock, create=True):
    276             h = open('bar')
    277             line1 = h.readline()
    278             rest = h.read()
    279         self.assertEqual(line1, 'foo\n')
    280         self.assertEqual(rest, 'bar\nbaz\n')
    281 
    282 
    283     def test_overriding_return_values(self):
    284         mock = mock_open(read_data='foo')
    285         handle = mock()
    286 
    287         handle.read.return_value = 'bar'
    288         handle.readline.return_value = 'bar'
    289         handle.readlines.return_value = ['bar']
    290 
    291         self.assertEqual(handle.read(), 'bar')
    292         self.assertEqual(handle.readline(), 'bar')
    293         self.assertEqual(handle.readlines(), ['bar'])
    294 
    295         # call repeatedly to check that a StopIteration is not propagated
    296         self.assertEqual(handle.readline(), 'bar')
    297         self.assertEqual(handle.readline(), 'bar')
    298 
    299 
    300 if __name__ == '__main__':
    301     unittest.main()
    302