# Web source-viewer navigation header captured with this listing:
# Home | History | Annotate | Download | only in test_asyncio
      1 """Tests for proactor_events.py"""
      2 
      3 import socket
      4 import unittest
      5 from unittest import mock
      6 
      7 import asyncio
      8 from asyncio.proactor_events import BaseProactorEventLoop
      9 from asyncio.proactor_events import _ProactorSocketTransport
     10 from asyncio.proactor_events import _ProactorWritePipeTransport
     11 from asyncio.proactor_events import _ProactorDuplexPipeTransport
     12 from asyncio import test_utils
     13 
     14 
     15 def close_transport(transport):
     16     # Don't call transport.close() because the event loop and the IOCP proactor
     17     # are mocked
     18     if transport._sock is None:
     19         return
     20     transport._sock.close()
     21     transport._sock = None
     22 
     23 
     24 class ProactorSocketTransportTests(test_utils.TestCase):
     25 
     26     def setUp(self):
     27         super().setUp()
     28         self.loop = self.new_test_loop()
     29         self.addCleanup(self.loop.close)
     30         self.proactor = mock.Mock()
     31         self.loop._proactor = self.proactor
     32         self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
     33         self.sock = mock.Mock(socket.socket)
     34 
     35     def socket_transport(self, waiter=None):
     36         transport = _ProactorSocketTransport(self.loop, self.sock,
     37                                              self.protocol, waiter=waiter)
     38         self.addCleanup(close_transport, transport)
     39         return transport
     40 
     41     def test_ctor(self):
     42         fut = asyncio.Future(loop=self.loop)
     43         tr = self.socket_transport(waiter=fut)
     44         test_utils.run_briefly(self.loop)
     45         self.assertIsNone(fut.result())
     46         self.protocol.connection_made(tr)
     47         self.proactor.recv.assert_called_with(self.sock, 4096)
     48 
     49     def test_loop_reading(self):
     50         tr = self.socket_transport()
     51         tr._loop_reading()
     52         self.loop._proactor.recv.assert_called_with(self.sock, 4096)
     53         self.assertFalse(self.protocol.data_received.called)
     54         self.assertFalse(self.protocol.eof_received.called)
     55 
     56     def test_loop_reading_data(self):
     57         res = asyncio.Future(loop=self.loop)
     58         res.set_result(b'data')
     59 
     60         tr = self.socket_transport()
     61         tr._read_fut = res
     62         tr._loop_reading(res)
     63         self.loop._proactor.recv.assert_called_with(self.sock, 4096)
     64         self.protocol.data_received.assert_called_with(b'data')
     65 
     66     def test_loop_reading_no_data(self):
     67         res = asyncio.Future(loop=self.loop)
     68         res.set_result(b'')
     69 
     70         tr = self.socket_transport()
     71         self.assertRaises(AssertionError, tr._loop_reading, res)
     72 
     73         tr.close = mock.Mock()
     74         tr._read_fut = res
     75         tr._loop_reading(res)
     76         self.assertFalse(self.loop._proactor.recv.called)
     77         self.assertTrue(self.protocol.eof_received.called)
     78         self.assertTrue(tr.close.called)
     79 
     80     def test_loop_reading_aborted(self):
     81         err = self.loop._proactor.recv.side_effect = ConnectionAbortedError()
     82 
     83         tr = self.socket_transport()
     84         tr._fatal_error = mock.Mock()
     85         tr._loop_reading()
     86         tr._fatal_error.assert_called_with(
     87                             err,
     88                             'Fatal read error on pipe transport')
     89 
     90     def test_loop_reading_aborted_closing(self):
     91         self.loop._proactor.recv.side_effect = ConnectionAbortedError()
     92 
     93         tr = self.socket_transport()
     94         tr._closing = True
     95         tr._fatal_error = mock.Mock()
     96         tr._loop_reading()
     97         self.assertFalse(tr._fatal_error.called)
     98 
     99     def test_loop_reading_aborted_is_fatal(self):
    100         self.loop._proactor.recv.side_effect = ConnectionAbortedError()
    101         tr = self.socket_transport()
    102         tr._closing = False
    103         tr._fatal_error = mock.Mock()
    104         tr._loop_reading()
    105         self.assertTrue(tr._fatal_error.called)
    106 
    107     def test_loop_reading_conn_reset_lost(self):
    108         err = self.loop._proactor.recv.side_effect = ConnectionResetError()
    109 
    110         tr = self.socket_transport()
    111         tr._closing = False
    112         tr._fatal_error = mock.Mock()
    113         tr._force_close = mock.Mock()
    114         tr._loop_reading()
    115         self.assertFalse(tr._fatal_error.called)
    116         tr._force_close.assert_called_with(err)
    117 
    118     def test_loop_reading_exception(self):
    119         err = self.loop._proactor.recv.side_effect = (OSError())
    120 
    121         tr = self.socket_transport()
    122         tr._fatal_error = mock.Mock()
    123         tr._loop_reading()
    124         tr._fatal_error.assert_called_with(
    125                             err,
    126                             'Fatal read error on pipe transport')
    127 
    128     def test_write(self):
    129         tr = self.socket_transport()
    130         tr._loop_writing = mock.Mock()
    131         tr.write(b'data')
    132         self.assertEqual(tr._buffer, None)
    133         tr._loop_writing.assert_called_with(data=b'data')
    134 
    135     def test_write_no_data(self):
    136         tr = self.socket_transport()
    137         tr.write(b'')
    138         self.assertFalse(tr._buffer)
    139 
    140     def test_write_more(self):
    141         tr = self.socket_transport()
    142         tr._write_fut = mock.Mock()
    143         tr._loop_writing = mock.Mock()
    144         tr.write(b'data')
    145         self.assertEqual(tr._buffer, b'data')
    146         self.assertFalse(tr._loop_writing.called)
    147 
    148     def test_loop_writing(self):
    149         tr = self.socket_transport()
    150         tr._buffer = bytearray(b'data')
    151         tr._loop_writing()
    152         self.loop._proactor.send.assert_called_with(self.sock, b'data')
    153         self.loop._proactor.send.return_value.add_done_callback.\
    154             assert_called_with(tr._loop_writing)
    155 
    156     @mock.patch('asyncio.proactor_events.logger')
    157     def test_loop_writing_err(self, m_log):
    158         err = self.loop._proactor.send.side_effect = OSError()
    159         tr = self.socket_transport()
    160         tr._fatal_error = mock.Mock()
    161         tr._buffer = [b'da', b'ta']
    162         tr._loop_writing()
    163         tr._fatal_error.assert_called_with(
    164                             err,
    165                             'Fatal write error on pipe transport')
    166         tr._conn_lost = 1
    167 
    168         tr.write(b'data')
    169         tr.write(b'data')
    170         tr.write(b'data')
    171         tr.write(b'data')
    172         tr.write(b'data')
    173         self.assertEqual(tr._buffer, None)
    174         m_log.warning.assert_called_with('socket.send() raised exception.')
    175 
    176     def test_loop_writing_stop(self):
    177         fut = asyncio.Future(loop=self.loop)
    178         fut.set_result(b'data')
    179 
    180         tr = self.socket_transport()
    181         tr._write_fut = fut
    182         tr._loop_writing(fut)
    183         self.assertIsNone(tr._write_fut)
    184 
    185     def test_loop_writing_closing(self):
    186         fut = asyncio.Future(loop=self.loop)
    187         fut.set_result(1)
    188 
    189         tr = self.socket_transport()
    190         tr._write_fut = fut
    191         tr.close()
    192         tr._loop_writing(fut)
    193         self.assertIsNone(tr._write_fut)
    194         test_utils.run_briefly(self.loop)
    195         self.protocol.connection_lost.assert_called_with(None)
    196 
    197     def test_abort(self):
    198         tr = self.socket_transport()
    199         tr._force_close = mock.Mock()
    200         tr.abort()
    201         tr._force_close.assert_called_with(None)
    202 
    203     def test_close(self):
    204         tr = self.socket_transport()
    205         tr.close()
    206         test_utils.run_briefly(self.loop)
    207         self.protocol.connection_lost.assert_called_with(None)
    208         self.assertTrue(tr.is_closing())
    209         self.assertEqual(tr._conn_lost, 1)
    210 
    211         self.protocol.connection_lost.reset_mock()
    212         tr.close()
    213         test_utils.run_briefly(self.loop)
    214         self.assertFalse(self.protocol.connection_lost.called)
    215 
    216     def test_close_write_fut(self):
    217         tr = self.socket_transport()
    218         tr._write_fut = mock.Mock()
    219         tr.close()
    220         test_utils.run_briefly(self.loop)
    221         self.assertFalse(self.protocol.connection_lost.called)
    222 
    223     def test_close_buffer(self):
    224         tr = self.socket_transport()
    225         tr._buffer = [b'data']
    226         tr.close()
    227         test_utils.run_briefly(self.loop)
    228         self.assertFalse(self.protocol.connection_lost.called)
    229 
    230     @mock.patch('asyncio.base_events.logger')
    231     def test_fatal_error(self, m_logging):
    232         tr = self.socket_transport()
    233         tr._force_close = mock.Mock()
    234         tr._fatal_error(None)
    235         self.assertTrue(tr._force_close.called)
    236         self.assertTrue(m_logging.error.called)
    237 
    238     def test_force_close(self):
    239         tr = self.socket_transport()
    240         tr._buffer = [b'data']
    241         read_fut = tr._read_fut = mock.Mock()
    242         write_fut = tr._write_fut = mock.Mock()
    243         tr._force_close(None)
    244 
    245         read_fut.cancel.assert_called_with()
    246         write_fut.cancel.assert_called_with()
    247         test_utils.run_briefly(self.loop)
    248         self.protocol.connection_lost.assert_called_with(None)
    249         self.assertEqual(None, tr._buffer)
    250         self.assertEqual(tr._conn_lost, 1)
    251 
    252     def test_force_close_idempotent(self):
    253         tr = self.socket_transport()
    254         tr._closing = True
    255         tr._force_close(None)
    256         test_utils.run_briefly(self.loop)
    257         self.assertFalse(self.protocol.connection_lost.called)
    258 
    259     def test_fatal_error_2(self):
    260         tr = self.socket_transport()
    261         tr._buffer = [b'data']
    262         tr._force_close(None)
    263 
    264         test_utils.run_briefly(self.loop)
    265         self.protocol.connection_lost.assert_called_with(None)
    266         self.assertEqual(None, tr._buffer)
    267 
    268     def test_call_connection_lost(self):
    269         tr = self.socket_transport()
    270         tr._call_connection_lost(None)
    271         self.assertTrue(self.protocol.connection_lost.called)
    272         self.assertTrue(self.sock.close.called)
    273 
    274     def test_write_eof(self):
    275         tr = self.socket_transport()
    276         self.assertTrue(tr.can_write_eof())
    277         tr.write_eof()
    278         self.sock.shutdown.assert_called_with(socket.SHUT_WR)
    279         tr.write_eof()
    280         self.assertEqual(self.sock.shutdown.call_count, 1)
    281         tr.close()
    282 
    283     def test_write_eof_buffer(self):
    284         tr = self.socket_transport()
    285         f = asyncio.Future(loop=self.loop)
    286         tr._loop._proactor.send.return_value = f
    287         tr.write(b'data')
    288         tr.write_eof()
    289         self.assertTrue(tr._eof_written)
    290         self.assertFalse(self.sock.shutdown.called)
    291         tr._loop._proactor.send.assert_called_with(self.sock, b'data')
    292         f.set_result(4)
    293         self.loop._run_once()
    294         self.sock.shutdown.assert_called_with(socket.SHUT_WR)
    295         tr.close()
    296 
    297     def test_write_eof_write_pipe(self):
    298         tr = _ProactorWritePipeTransport(
    299             self.loop, self.sock, self.protocol)
    300         self.assertTrue(tr.can_write_eof())
    301         tr.write_eof()
    302         self.assertTrue(tr.is_closing())
    303         self.loop._run_once()
    304         self.assertTrue(self.sock.close.called)
    305         tr.close()
    306 
    307     def test_write_eof_buffer_write_pipe(self):
    308         tr = _ProactorWritePipeTransport(self.loop, self.sock, self.protocol)
    309         f = asyncio.Future(loop=self.loop)
    310         tr._loop._proactor.send.return_value = f
    311         tr.write(b'data')
    312         tr.write_eof()
    313         self.assertTrue(tr.is_closing())
    314         self.assertFalse(self.sock.shutdown.called)
    315         tr._loop._proactor.send.assert_called_with(self.sock, b'data')
    316         f.set_result(4)
    317         self.loop._run_once()
    318         self.loop._run_once()
    319         self.assertTrue(self.sock.close.called)
    320         tr.close()
    321 
    322     def test_write_eof_duplex_pipe(self):
    323         tr = _ProactorDuplexPipeTransport(
    324             self.loop, self.sock, self.protocol)
    325         self.assertFalse(tr.can_write_eof())
    326         with self.assertRaises(NotImplementedError):
    327             tr.write_eof()
    328         close_transport(tr)
    329 
    330     def test_pause_resume_reading(self):
    331         tr = self.socket_transport()
    332         futures = []
    333         for msg in [b'data1', b'data2', b'data3', b'data4', b'']:
    334             f = asyncio.Future(loop=self.loop)
    335             f.set_result(msg)
    336             futures.append(f)
    337         self.loop._proactor.recv.side_effect = futures
    338         self.loop._run_once()
    339         self.assertFalse(tr._paused)
    340         self.loop._run_once()
    341         self.protocol.data_received.assert_called_with(b'data1')
    342         self.loop._run_once()
    343         self.protocol.data_received.assert_called_with(b'data2')
    344         tr.pause_reading()
    345         self.assertTrue(tr._paused)
    346         for i in range(10):
    347             self.loop._run_once()
    348         self.protocol.data_received.assert_called_with(b'data2')
    349         tr.resume_reading()
    350         self.assertFalse(tr._paused)
    351         self.loop._run_once()
    352         self.protocol.data_received.assert_called_with(b'data3')
    353         self.loop._run_once()
    354         self.protocol.data_received.assert_called_with(b'data4')
    355         tr.close()
    356 
    357 
    358     def pause_writing_transport(self, high):
    359         tr = self.socket_transport()
    360         tr.set_write_buffer_limits(high=high)
    361 
    362         self.assertEqual(tr.get_write_buffer_size(), 0)
    363         self.assertFalse(self.protocol.pause_writing.called)
    364         self.assertFalse(self.protocol.resume_writing.called)
    365         return tr
    366 
    367     def test_pause_resume_writing(self):
    368         tr = self.pause_writing_transport(high=4)
    369 
    370         # write a large chunk, must pause writing
    371         fut = asyncio.Future(loop=self.loop)
    372         self.loop._proactor.send.return_value = fut
    373         tr.write(b'large data')
    374         self.loop._run_once()
    375         self.assertTrue(self.protocol.pause_writing.called)
    376 
    377         # flush the buffer
    378         fut.set_result(None)
    379         self.loop._run_once()
    380         self.assertEqual(tr.get_write_buffer_size(), 0)
    381         self.assertTrue(self.protocol.resume_writing.called)
    382 
    383     def test_pause_writing_2write(self):
    384         tr = self.pause_writing_transport(high=4)
    385 
    386         # first short write, the buffer is not full (3 <= 4)
    387         fut1 = asyncio.Future(loop=self.loop)
    388         self.loop._proactor.send.return_value = fut1
    389         tr.write(b'123')
    390         self.loop._run_once()
    391         self.assertEqual(tr.get_write_buffer_size(), 3)
    392         self.assertFalse(self.protocol.pause_writing.called)
    393 
    394         # fill the buffer, must pause writing (6 > 4)
    395         tr.write(b'abc')
    396         self.loop._run_once()
    397         self.assertEqual(tr.get_write_buffer_size(), 6)
    398         self.assertTrue(self.protocol.pause_writing.called)
    399 
    400     def test_pause_writing_3write(self):
    401         tr = self.pause_writing_transport(high=4)
    402 
    403         # first short write, the buffer is not full (1 <= 4)
    404         fut = asyncio.Future(loop=self.loop)
    405         self.loop._proactor.send.return_value = fut
    406         tr.write(b'1')
    407         self.loop._run_once()
    408         self.assertEqual(tr.get_write_buffer_size(), 1)
    409         self.assertFalse(self.protocol.pause_writing.called)
    410 
    411         # second short write, the buffer is not full (3 <= 4)
    412         tr.write(b'23')
    413         self.loop._run_once()
    414         self.assertEqual(tr.get_write_buffer_size(), 3)
    415         self.assertFalse(self.protocol.pause_writing.called)
    416 
    417         # fill the buffer, must pause writing (6 > 4)
    418         tr.write(b'abc')
    419         self.loop._run_once()
    420         self.assertEqual(tr.get_write_buffer_size(), 6)
    421         self.assertTrue(self.protocol.pause_writing.called)
    422 
    423     def test_dont_pause_writing(self):
    424         tr = self.pause_writing_transport(high=4)
    425 
    426         # write a large chunk which completes immedialty,
    427         # it should not pause writing
    428         fut = asyncio.Future(loop=self.loop)
    429         fut.set_result(None)
    430         self.loop._proactor.send.return_value = fut
    431         tr.write(b'very large data')
    432         self.loop._run_once()
    433         self.assertEqual(tr.get_write_buffer_size(), 0)
    434         self.assertFalse(self.protocol.pause_writing.called)
    435 
    436 
    437 class BaseProactorEventLoopTests(test_utils.TestCase):
    438 
    439     def setUp(self):
    440         super().setUp()
    441 
    442         self.sock = test_utils.mock_nonblocking_socket()
    443         self.proactor = mock.Mock()
    444 
    445         self.ssock, self.csock = mock.Mock(), mock.Mock()
    446 
    447         class EventLoop(BaseProactorEventLoop):
    448             def _socketpair(s):
    449                 return (self.ssock, self.csock)
    450 
    451         self.loop = EventLoop(self.proactor)
    452         self.set_event_loop(self.loop)
    453 
    454     @mock.patch.object(BaseProactorEventLoop, 'call_soon')
    455     @mock.patch.object(BaseProactorEventLoop, '_socketpair')
    456     def test_ctor(self, socketpair, call_soon):
    457         ssock, csock = socketpair.return_value = (
    458             mock.Mock(), mock.Mock())
    459         loop = BaseProactorEventLoop(self.proactor)
    460         self.assertIs(loop._ssock, ssock)
    461         self.assertIs(loop._csock, csock)
    462         self.assertEqual(loop._internal_fds, 1)
    463         call_soon.assert_called_with(loop._loop_self_reading)
    464         loop.close()
    465 
    466     def test_close_self_pipe(self):
    467         self.loop._close_self_pipe()
    468         self.assertEqual(self.loop._internal_fds, 0)
    469         self.assertTrue(self.ssock.close.called)
    470         self.assertTrue(self.csock.close.called)
    471         self.assertIsNone(self.loop._ssock)
    472         self.assertIsNone(self.loop._csock)
    473 
    474         # Don't call close(): _close_self_pipe() cannot be called twice
    475         self.loop._closed = True
    476 
    477     def test_close(self):
    478         self.loop._close_self_pipe = mock.Mock()
    479         self.loop.close()
    480         self.assertTrue(self.loop._close_self_pipe.called)
    481         self.assertTrue(self.proactor.close.called)
    482         self.assertIsNone(self.loop._proactor)
    483 
    484         self.loop._close_self_pipe.reset_mock()
    485         self.loop.close()
    486         self.assertFalse(self.loop._close_self_pipe.called)
    487 
    488     def test_sock_recv(self):
    489         self.loop.sock_recv(self.sock, 1024)
    490         self.proactor.recv.assert_called_with(self.sock, 1024)
    491 
    492     def test_sock_sendall(self):
    493         self.loop.sock_sendall(self.sock, b'data')
    494         self.proactor.send.assert_called_with(self.sock, b'data')
    495 
    496     def test_sock_connect(self):
    497         self.loop.sock_connect(self.sock, ('1.2.3.4', 123))
    498         self.proactor.connect.assert_called_with(self.sock, ('1.2.3.4', 123))
    499 
    500     def test_sock_accept(self):
    501         self.loop.sock_accept(self.sock)
    502         self.proactor.accept.assert_called_with(self.sock)
    503 
    504     def test_socketpair(self):
    505         class EventLoop(BaseProactorEventLoop):
    506             # override the destructor to not log a ResourceWarning
    507             def __del__(self):
    508                 pass
    509         self.assertRaises(
    510             NotImplementedError, EventLoop, self.proactor)
    511 
    512     def test_make_socket_transport(self):
    513         tr = self.loop._make_socket_transport(self.sock, asyncio.Protocol())
    514         self.assertIsInstance(tr, _ProactorSocketTransport)
    515         close_transport(tr)
    516 
    517     def test_loop_self_reading(self):
    518         self.loop._loop_self_reading()
    519         self.proactor.recv.assert_called_with(self.ssock, 4096)
    520         self.proactor.recv.return_value.add_done_callback.assert_called_with(
    521             self.loop._loop_self_reading)
    522 
    523     def test_loop_self_reading_fut(self):
    524         fut = mock.Mock()
    525         self.loop._loop_self_reading(fut)
    526         self.assertTrue(fut.result.called)
    527         self.proactor.recv.assert_called_with(self.ssock, 4096)
    528         self.proactor.recv.return_value.add_done_callback.assert_called_with(
    529             self.loop._loop_self_reading)
    530 
    531     def test_loop_self_reading_exception(self):
    532         self.loop.close = mock.Mock()
    533         self.loop.call_exception_handler = mock.Mock()
    534         self.proactor.recv.side_effect = OSError()
    535         self.loop._loop_self_reading()
    536         self.assertTrue(self.loop.call_exception_handler.called)
    537 
    538     def test_write_to_self(self):
    539         self.loop._write_to_self()
    540         self.csock.send.assert_called_with(b'\0')
    541 
    542     def test_process_events(self):
    543         self.loop._process_events([])
    544 
    545     @mock.patch('asyncio.base_events.logger')
    546     def test_create_server(self, m_log):
    547         pf = mock.Mock()
    548         call_soon = self.loop.call_soon = mock.Mock()
    549 
    550         self.loop._start_serving(pf, self.sock)
    551         self.assertTrue(call_soon.called)
    552 
    553         # callback
    554         loop = call_soon.call_args[0][0]
    555         loop()
    556         self.proactor.accept.assert_called_with(self.sock)
    557 
    558         # conn
    559         fut = mock.Mock()
    560         fut.result.return_value = (mock.Mock(), mock.Mock())
    561 
    562         make_tr = self.loop._make_socket_transport = mock.Mock()
    563         loop(fut)
    564         self.assertTrue(fut.result.called)
    565         self.assertTrue(make_tr.called)
    566 
    567         # exception
    568         fut.result.side_effect = OSError()
    569         loop(fut)
    570         self.assertTrue(self.sock.close.called)
    571         self.assertTrue(m_log.error.called)
    572 
    573     def test_create_server_cancel(self):
    574         pf = mock.Mock()
    575         call_soon = self.loop.call_soon = mock.Mock()
    576 
    577         self.loop._start_serving(pf, self.sock)
    578         loop = call_soon.call_args[0][0]
    579 
    580         # cancelled
    581         fut = asyncio.Future(loop=self.loop)
    582         fut.cancel()
    583         loop(fut)
    584         self.assertTrue(self.sock.close.called)
    585 
    586     def test_stop_serving(self):
    587         sock = mock.Mock()
    588         self.loop._stop_serving(sock)
    589         self.assertTrue(sock.close.called)
    590         self.proactor._stop_serving.assert_called_with(sock)
    591 
    592 
if __name__ == '__main__':
    # Run the TestCase classes defined in this module when executed directly.
    unittest.main()
    595