"""Tests for proactor_events.py"""

import socket
import unittest
from unittest import mock

import asyncio
from asyncio.proactor_events import BaseProactorEventLoop
from asyncio.proactor_events import _ProactorSocketTransport
from asyncio.proactor_events import _ProactorWritePipeTransport
from asyncio.proactor_events import _ProactorDuplexPipeTransport
from asyncio import test_utils


def close_transport(transport):
    """Close *transport*'s socket directly and detach it from the transport.

    Intended as a test cleanup callback; it is a no-op when the transport's
    socket has already been detached (``transport._sock is None``).
    """
    # Don't call transport.close() because the event loop and the IOCP proactor
    # are mocked
    if transport._sock is None:
        return
    transport._sock.close()
    transport._sock = None


class ProactorSocketTransportTests(test_utils.TestCase):
    """Tests for _ProactorSocketTransport using a mocked proactor and socket."""

    def setUp(self):
        super().setUp()
        self.loop = self.new_test_loop()
        self.addCleanup(self.loop.close)
        self.proactor = mock.Mock()
        self.loop._proactor = self.proactor
        self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
        self.sock = mock.Mock(socket.socket)

    def socket_transport(self, waiter=None):
        """Build a transport over the mock socket and register its cleanup."""
        transport = _ProactorSocketTransport(self.loop, self.sock,
                                             self.protocol, waiter=waiter)
        self.addCleanup(close_transport, transport)
        return transport

    def test_ctor(self):
        fut = asyncio.Future(loop=self.loop)
        tr = self.socket_transport(waiter=fut)
        test_utils.run_briefly(self.loop)
        self.assertIsNone(fut.result())
        # The protocol is a mock: assert the transport invoked the callback,
        # rather than invoking it ourselves (which would verify nothing).
        self.protocol.connection_made.assert_called_with(tr)
        self.proactor.recv.assert_called_with(self.sock, 4096)

    def test_loop_reading(self):
        tr = self.socket_transport()
        tr._loop_reading()
        self.loop._proactor.recv.assert_called_with(self.sock, 4096)
        self.assertFalse(self.protocol.data_received.called)
        self.assertFalse(self.protocol.eof_received.called)

    def test_loop_reading_data(self):
        res = asyncio.Future(loop=self.loop)
        res.set_result(b'data')

        tr = self.socket_transport()
        tr._read_fut = res
        tr._loop_reading(res)
        self.loop._proactor.recv.assert_called_with(self.sock, 4096)
        self.protocol.data_received.assert_called_with(b'data')

    def test_loop_reading_no_data(self):
        res = asyncio.Future(loop=self.loop)
        res.set_result(b'')

        tr = self.socket_transport()
        # While res is not the transport's _read_fut, the call must fail.
        self.assertRaises(AssertionError, tr._loop_reading, res)

        tr.close = mock.Mock()
        tr._read_fut = res
        tr._loop_reading(res)
        self.assertFalse(self.loop._proactor.recv.called)
        self.assertTrue(self.protocol.eof_received.called)
        self.assertTrue(tr.close.called)

    def test_loop_reading_aborted(self):
        err = self.loop._proactor.recv.side_effect = ConnectionAbortedError()

        tr = self.socket_transport()
        tr._fatal_error = mock.Mock()
        tr._loop_reading()
        tr._fatal_error.assert_called_with(
            err,
            'Fatal read error on pipe transport')

    def test_loop_reading_aborted_closing(self):
        self.loop._proactor.recv.side_effect = ConnectionAbortedError()

        tr = self.socket_transport()
        tr._closing = True
        tr._fatal_error = mock.Mock()
        tr._loop_reading()
        self.assertFalse(tr._fatal_error.called)

    def test_loop_reading_aborted_is_fatal(self):
        self.loop._proactor.recv.side_effect = ConnectionAbortedError()
        tr = self.socket_transport()
        tr._closing = False
        tr._fatal_error = mock.Mock()
        tr._loop_reading()
        self.assertTrue(tr._fatal_error.called)

    def test_loop_reading_conn_reset_lost(self):
        err = self.loop._proactor.recv.side_effect = ConnectionResetError()

        tr = self.socket_transport()
        tr._closing = False
        tr._fatal_error = mock.Mock()
        tr._force_close = mock.Mock()
        tr._loop_reading()
        self.assertFalse(tr._fatal_error.called)
        tr._force_close.assert_called_with(err)

    def test_loop_reading_exception(self):
        err = self.loop._proactor.recv.side_effect = OSError()

        tr = self.socket_transport()
        tr._fatal_error = mock.Mock()
        tr._loop_reading()
        tr._fatal_error.assert_called_with(
            err,
            'Fatal read error on pipe transport')

    def test_write(self):
        tr = self.socket_transport()
        tr._loop_writing = mock.Mock()
        tr.write(b'data')
        self.assertEqual(tr._buffer, None)
        tr._loop_writing.assert_called_with(data=b'data')

    def test_write_no_data(self):
        tr = self.socket_transport()
        tr.write(b'')
        self.assertFalse(tr._buffer)

    def test_write_more(self):
        tr = self.socket_transport()
        tr._write_fut = mock.Mock()
        tr._loop_writing = mock.Mock()
        tr.write(b'data')
        self.assertEqual(tr._buffer, b'data')
        self.assertFalse(tr._loop_writing.called)

    def test_loop_writing(self):
        tr = self.socket_transport()
        tr._buffer = bytearray(b'data')
        tr._loop_writing()
        self.loop._proactor.send.assert_called_with(self.sock, b'data')
        send_fut = self.loop._proactor.send.return_value
        send_fut.add_done_callback.assert_called_with(tr._loop_writing)

    @mock.patch('asyncio.proactor_events.logger')
    def test_loop_writing_err(self, m_log):
        err = self.loop._proactor.send.side_effect = OSError()
        tr = self.socket_transport()
        tr._fatal_error = mock.Mock()
        tr._buffer = [b'da', b'ta']
        tr._loop_writing()
        tr._fatal_error.assert_called_with(
            err,
            'Fatal write error on pipe transport')
        tr._conn_lost = 1

        # Writes after the connection is lost are dropped and logged.
        tr.write(b'data')
        tr.write(b'data')
        tr.write(b'data')
        tr.write(b'data')
        tr.write(b'data')
        self.assertEqual(tr._buffer, None)
        m_log.warning.assert_called_with('socket.send() raised exception.')

    def test_loop_writing_stop(self):
        fut = asyncio.Future(loop=self.loop)
        fut.set_result(b'data')

        tr = self.socket_transport()
        tr._write_fut = fut
        tr._loop_writing(fut)
        self.assertIsNone(tr._write_fut)

    def test_loop_writing_closing(self):
        fut = asyncio.Future(loop=self.loop)
        fut.set_result(1)

        tr = self.socket_transport()
        tr._write_fut = fut
        tr.close()
        tr._loop_writing(fut)
        self.assertIsNone(tr._write_fut)
        test_utils.run_briefly(self.loop)
        self.protocol.connection_lost.assert_called_with(None)

    def test_abort(self):
        tr = self.socket_transport()
        tr._force_close = mock.Mock()
        tr.abort()
        tr._force_close.assert_called_with(None)

    def test_close(self):
        tr = self.socket_transport()
        tr.close()
        test_utils.run_briefly(self.loop)
        self.protocol.connection_lost.assert_called_with(None)
        self.assertTrue(tr.is_closing())
        self.assertEqual(tr._conn_lost, 1)

        # A second close() must not report the connection loss again.
        self.protocol.connection_lost.reset_mock()
        tr.close()
        test_utils.run_briefly(self.loop)
        self.assertFalse(self.protocol.connection_lost.called)

    def test_close_write_fut(self):
        tr = self.socket_transport()
        tr._write_fut = mock.Mock()
        tr.close()
        test_utils.run_briefly(self.loop)
        self.assertFalse(self.protocol.connection_lost.called)

    def test_close_buffer(self):
        tr = self.socket_transport()
        tr._buffer = [b'data']
        tr.close()
        test_utils.run_briefly(self.loop)
        self.assertFalse(self.protocol.connection_lost.called)

    @mock.patch('asyncio.base_events.logger')
    def test_fatal_error(self, m_logging):
        tr = self.socket_transport()
        tr._force_close = mock.Mock()
        tr._fatal_error(None)
        self.assertTrue(tr._force_close.called)
        self.assertTrue(m_logging.error.called)

    def test_force_close(self):
        tr = self.socket_transport()
        tr._buffer = [b'data']
        read_fut = tr._read_fut = mock.Mock()
        write_fut = tr._write_fut = mock.Mock()
        tr._force_close(None)

        read_fut.cancel.assert_called_with()
        write_fut.cancel.assert_called_with()
        test_utils.run_briefly(self.loop)
        self.protocol.connection_lost.assert_called_with(None)
        self.assertEqual(None, tr._buffer)
        self.assertEqual(tr._conn_lost, 1)

    def test_force_close_idempotent(self):
        tr = self.socket_transport()
        tr._closing = True
        tr._force_close(None)
        test_utils.run_briefly(self.loop)
        self.assertFalse(self.protocol.connection_lost.called)

    def test_fatal_error_2(self):
        tr = self.socket_transport()
        tr._buffer = [b'data']
        tr._force_close(None)

        test_utils.run_briefly(self.loop)
        self.protocol.connection_lost.assert_called_with(None)
        self.assertEqual(None, tr._buffer)

    def test_call_connection_lost(self):
        tr = self.socket_transport()
        tr._call_connection_lost(None)
        self.assertTrue(self.protocol.connection_lost.called)
        self.assertTrue(self.sock.close.called)

    def test_write_eof(self):
        tr = self.socket_transport()
        self.assertTrue(tr.can_write_eof())
        tr.write_eof()
        self.sock.shutdown.assert_called_with(socket.SHUT_WR)
        # A second write_eof() must not shut the socket down again.
        tr.write_eof()
        self.assertEqual(self.sock.shutdown.call_count, 1)
        tr.close()

    def test_write_eof_buffer(self):
        tr = self.socket_transport()
        f = asyncio.Future(loop=self.loop)
        tr._loop._proactor.send.return_value = f
        tr.write(b'data')
        tr.write_eof()
        self.assertTrue(tr._eof_written)
        self.assertFalse(self.sock.shutdown.called)
        tr._loop._proactor.send.assert_called_with(self.sock, b'data')
        f.set_result(4)
        self.loop._run_once()
        self.sock.shutdown.assert_called_with(socket.SHUT_WR)
        tr.close()

    def test_write_eof_write_pipe(self):
        tr = _ProactorWritePipeTransport(
            self.loop, self.sock, self.protocol)
        self.assertTrue(tr.can_write_eof())
        tr.write_eof()
        self.assertTrue(tr.is_closing())
        self.loop._run_once()
        self.assertTrue(self.sock.close.called)
        tr.close()

    def test_write_eof_buffer_write_pipe(self):
        tr = _ProactorWritePipeTransport(self.loop, self.sock, self.protocol)
        f = asyncio.Future(loop=self.loop)
        tr._loop._proactor.send.return_value = f
        tr.write(b'data')
        tr.write_eof()
        self.assertTrue(tr.is_closing())
        self.assertFalse(self.sock.shutdown.called)
        tr._loop._proactor.send.assert_called_with(self.sock, b'data')
        f.set_result(4)
        self.loop._run_once()
        self.loop._run_once()
        self.assertTrue(self.sock.close.called)
        tr.close()

    def test_write_eof_duplex_pipe(self):
        tr = _ProactorDuplexPipeTransport(
            self.loop, self.sock, self.protocol)
        self.assertFalse(tr.can_write_eof())
        with self.assertRaises(NotImplementedError):
            tr.write_eof()
        close_transport(tr)

    def test_pause_resume_reading(self):
        tr = self.socket_transport()
        futures = []
        for msg in [b'data1', b'data2', b'data3', b'data4', b'']:
            f = asyncio.Future(loop=self.loop)
            f.set_result(msg)
            futures.append(f)
        self.loop._proactor.recv.side_effect = futures
        self.loop._run_once()
        self.assertFalse(tr._paused)
        self.loop._run_once()
        self.protocol.data_received.assert_called_with(b'data1')
        self.loop._run_once()
        self.protocol.data_received.assert_called_with(b'data2')
        tr.pause_reading()
        self.assertTrue(tr._paused)
        # While paused, no further data may be delivered.
        for _ in range(10):
            self.loop._run_once()
        self.protocol.data_received.assert_called_with(b'data2')
        tr.resume_reading()
        self.assertFalse(tr._paused)
        self.loop._run_once()
        self.protocol.data_received.assert_called_with(b'data3')
        self.loop._run_once()
        self.protocol.data_received.assert_called_with(b'data4')
        tr.close()

    def pause_writing_transport(self, high):
        """Create a transport with write high-water mark *high*.

        Also checks that its write buffer starts empty and that the protocol
        has not been paused or resumed yet.
        """
        tr = self.socket_transport()
        tr.set_write_buffer_limits(high=high)

        self.assertEqual(tr.get_write_buffer_size(), 0)
        self.assertFalse(self.protocol.pause_writing.called)
        self.assertFalse(self.protocol.resume_writing.called)
        return tr

    def test_pause_resume_writing(self):
        tr = self.pause_writing_transport(high=4)

        # write a large chunk, must pause writing
        fut = asyncio.Future(loop=self.loop)
        self.loop._proactor.send.return_value = fut
        tr.write(b'large data')
        self.loop._run_once()
        self.assertTrue(self.protocol.pause_writing.called)

        # flush the buffer
        fut.set_result(None)
        self.loop._run_once()
        self.assertEqual(tr.get_write_buffer_size(), 0)
        self.assertTrue(self.protocol.resume_writing.called)

    def test_pause_writing_2write(self):
        tr = self.pause_writing_transport(high=4)

        # first short write, the buffer is not full (3 <= 4)
        fut1 = asyncio.Future(loop=self.loop)
        self.loop._proactor.send.return_value = fut1
        tr.write(b'123')
        self.loop._run_once()
        self.assertEqual(tr.get_write_buffer_size(), 3)
        self.assertFalse(self.protocol.pause_writing.called)

        # fill the buffer, must pause writing (6 > 4)
        tr.write(b'abc')
        self.loop._run_once()
        self.assertEqual(tr.get_write_buffer_size(), 6)
        self.assertTrue(self.protocol.pause_writing.called)

    def test_pause_writing_3write(self):
        tr = self.pause_writing_transport(high=4)

        # first short write, the buffer is not full (1 <= 4)
        fut = asyncio.Future(loop=self.loop)
        self.loop._proactor.send.return_value = fut
        tr.write(b'1')
        self.loop._run_once()
        self.assertEqual(tr.get_write_buffer_size(), 1)
        self.assertFalse(self.protocol.pause_writing.called)

        # second short write, the buffer is not full (3 <= 4)
        tr.write(b'23')
        self.loop._run_once()
        self.assertEqual(tr.get_write_buffer_size(), 3)
        self.assertFalse(self.protocol.pause_writing.called)

        # fill the buffer, must pause writing (6 > 4)
        tr.write(b'abc')
        self.loop._run_once()
        self.assertEqual(tr.get_write_buffer_size(), 6)
        self.assertTrue(self.protocol.pause_writing.called)

    def test_dont_pause_writing(self):
        tr = self.pause_writing_transport(high=4)

        # write a large chunk which completes immediately,
        # it should not pause writing
        fut = asyncio.Future(loop=self.loop)
        fut.set_result(None)
        self.loop._proactor.send.return_value = fut
        tr.write(b'very large data')
        self.loop._run_once()
        self.assertEqual(tr.get_write_buffer_size(), 0)
        self.assertFalse(self.protocol.pause_writing.called)


class BaseProactorEventLoopTests(test_utils.TestCase):
    """Tests for BaseProactorEventLoop with a mocked proactor and self-pipe."""

    def setUp(self):
        super().setUp()

        self.sock = test_utils.mock_nonblocking_socket()
        self.proactor = mock.Mock()

        self.ssock, self.csock = mock.Mock(), mock.Mock()

        class EventLoop(BaseProactorEventLoop):
            # The loop instance is named ``s`` so that ``self`` still refers
            # to the enclosing test case, whose mock sockets are returned.
            def _socketpair(s):
                return (self.ssock, self.csock)

        self.loop = EventLoop(self.proactor)
        self.set_event_loop(self.loop)

    @mock.patch.object(BaseProactorEventLoop, 'call_soon')
    @mock.patch.object(BaseProactorEventLoop, '_socketpair')
    def test_ctor(self, socketpair, call_soon):
        ssock, csock = socketpair.return_value = (
            mock.Mock(), mock.Mock())
        loop = BaseProactorEventLoop(self.proactor)
        self.assertIs(loop._ssock, ssock)
        self.assertIs(loop._csock, csock)
        self.assertEqual(loop._internal_fds, 1)
        call_soon.assert_called_with(loop._loop_self_reading)
        loop.close()

    def test_close_self_pipe(self):
        self.loop._close_self_pipe()
        self.assertEqual(self.loop._internal_fds, 0)
        self.assertTrue(self.ssock.close.called)
        self.assertTrue(self.csock.close.called)
        self.assertIsNone(self.loop._ssock)
        self.assertIsNone(self.loop._csock)

        # Don't call close(): _close_self_pipe() cannot be called twice
        self.loop._closed = True

    def test_close(self):
        self.loop._close_self_pipe = mock.Mock()
        self.loop.close()
        self.assertTrue(self.loop._close_self_pipe.called)
        self.assertTrue(self.proactor.close.called)
        self.assertIsNone(self.loop._proactor)

        # Closing an already closed loop must not touch the self-pipe again.
        self.loop._close_self_pipe.reset_mock()
        self.loop.close()
        self.assertFalse(self.loop._close_self_pipe.called)

    def test_sock_recv(self):
        self.loop.sock_recv(self.sock, 1024)
        self.proactor.recv.assert_called_with(self.sock, 1024)

    def test_sock_sendall(self):
        self.loop.sock_sendall(self.sock, b'data')
        self.proactor.send.assert_called_with(self.sock, b'data')

    def test_sock_connect(self):
        self.loop.sock_connect(self.sock, ('1.2.3.4', 123))
        self.proactor.connect.assert_called_with(self.sock, ('1.2.3.4', 123))

    def test_sock_accept(self):
        self.loop.sock_accept(self.sock)
        self.proactor.accept.assert_called_with(self.sock)

    def test_socketpair(self):
        class EventLoop(BaseProactorEventLoop):
            # override the destructor to not log a ResourceWarning
            def __del__(self):
                pass
        self.assertRaises(
            NotImplementedError, EventLoop, self.proactor)

    def test_make_socket_transport(self):
        tr = self.loop._make_socket_transport(self.sock, asyncio.Protocol())
        self.assertIsInstance(tr, _ProactorSocketTransport)
        close_transport(tr)

    def test_loop_self_reading(self):
        self.loop._loop_self_reading()
        self.proactor.recv.assert_called_with(self.ssock, 4096)
        self.proactor.recv.return_value.add_done_callback.assert_called_with(
            self.loop._loop_self_reading)

    def test_loop_self_reading_fut(self):
        fut = mock.Mock()
        self.loop._loop_self_reading(fut)
        self.assertTrue(fut.result.called)
        self.proactor.recv.assert_called_with(self.ssock, 4096)
        self.proactor.recv.return_value.add_done_callback.assert_called_with(
            self.loop._loop_self_reading)

    def test_loop_self_reading_exception(self):
        self.loop.close = mock.Mock()
        self.loop.call_exception_handler = mock.Mock()
        self.proactor.recv.side_effect = OSError()
        self.loop._loop_self_reading()
        self.assertTrue(self.loop.call_exception_handler.called)

    def test_write_to_self(self):
        self.loop._write_to_self()
        self.csock.send.assert_called_with(b'\0')

    def test_process_events(self):
        self.loop._process_events([])

    @mock.patch('asyncio.base_events.logger')
    def test_create_server(self, m_log):
        pf = mock.Mock()
        call_soon = self.loop.call_soon = mock.Mock()

        self.loop._start_serving(pf, self.sock)
        self.assertTrue(call_soon.called)

        # callback
        accept_loop = call_soon.call_args[0][0]
        accept_loop()
        self.proactor.accept.assert_called_with(self.sock)

        # conn
        fut = mock.Mock()
        fut.result.return_value = (mock.Mock(), mock.Mock())

        make_tr = self.loop._make_socket_transport = mock.Mock()
        accept_loop(fut)
        self.assertTrue(fut.result.called)
        self.assertTrue(make_tr.called)

        # exception
        fut.result.side_effect = OSError()
        accept_loop(fut)
        self.assertTrue(self.sock.close.called)
        self.assertTrue(m_log.error.called)

    def test_create_server_cancel(self):
        pf = mock.Mock()
        call_soon = self.loop.call_soon = mock.Mock()

        self.loop._start_serving(pf, self.sock)
        accept_loop = call_soon.call_args[0][0]

        # cancelled
        fut = asyncio.Future(loop=self.loop)
        fut.cancel()
        accept_loop(fut)
        self.assertTrue(self.sock.close.called)

    def test_stop_serving(self):
        sock = mock.Mock()
        self.loop._stop_serving(sock)
        self.assertTrue(sock.close.called)
        self.proactor._stop_serving.assert_called_with(sock)


if __name__ == '__main__':
    unittest.main()