Home | History | Annotate | Download | only in test_asyncio
      1 """Tests for streams.py."""
      2 
      3 import gc
      4 import os
      5 import queue
      6 import socket
      7 import sys
      8 import threading
      9 import unittest
     10 from unittest import mock
     11 try:
     12     import ssl
     13 except ImportError:
     14     ssl = None
     15 
     16 import asyncio
     17 from asyncio import test_utils
     18 
     19 
     20 class StreamReaderTests(test_utils.TestCase):
     21 
     22     DATA = b'line1\nline2\nline3\n'
     23 
     24     def setUp(self):
     25         super().setUp()
     26         self.loop = asyncio.new_event_loop()
     27         self.set_event_loop(self.loop)
     28 
     29     def tearDown(self):
     30         # just in case if we have transport close callbacks
     31         test_utils.run_briefly(self.loop)
     32 
     33         self.loop.close()
     34         gc.collect()
     35         super().tearDown()
     36 
     37     @mock.patch('asyncio.streams.events')
     38     def test_ctor_global_loop(self, m_events):
     39         stream = asyncio.StreamReader()
     40         self.assertIs(stream._loop, m_events.get_event_loop.return_value)
     41 
     42     def _basetest_open_connection(self, open_connection_fut):
     43         reader, writer = self.loop.run_until_complete(open_connection_fut)
     44         writer.write(b'GET / HTTP/1.0\r\n\r\n')
     45         f = reader.readline()
     46         data = self.loop.run_until_complete(f)
     47         self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
     48         f = reader.read()
     49         data = self.loop.run_until_complete(f)
     50         self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
     51         writer.close()
     52 
     53     def test_open_connection(self):
     54         with test_utils.run_test_server() as httpd:
     55             conn_fut = asyncio.open_connection(*httpd.address,
     56                                                loop=self.loop)
     57             self._basetest_open_connection(conn_fut)
     58 
     59     @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
     60     def test_open_unix_connection(self):
     61         with test_utils.run_test_unix_server() as httpd:
     62             conn_fut = asyncio.open_unix_connection(httpd.address,
     63                                                     loop=self.loop)
     64             self._basetest_open_connection(conn_fut)
     65 
     66     def _basetest_open_connection_no_loop_ssl(self, open_connection_fut):
     67         try:
     68             reader, writer = self.loop.run_until_complete(open_connection_fut)
     69         finally:
     70             asyncio.set_event_loop(None)
     71         writer.write(b'GET / HTTP/1.0\r\n\r\n')
     72         f = reader.read()
     73         data = self.loop.run_until_complete(f)
     74         self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
     75 
     76         writer.close()
     77 
     78     @unittest.skipIf(ssl is None, 'No ssl module')
     79     def test_open_connection_no_loop_ssl(self):
     80         with test_utils.run_test_server(use_ssl=True) as httpd:
     81             conn_fut = asyncio.open_connection(
     82                 *httpd.address,
     83                 ssl=test_utils.dummy_ssl_context(),
     84                 loop=self.loop)
     85 
     86             self._basetest_open_connection_no_loop_ssl(conn_fut)
     87 
     88     @unittest.skipIf(ssl is None, 'No ssl module')
     89     @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
     90     def test_open_unix_connection_no_loop_ssl(self):
     91         with test_utils.run_test_unix_server(use_ssl=True) as httpd:
     92             conn_fut = asyncio.open_unix_connection(
     93                 httpd.address,
     94                 ssl=test_utils.dummy_ssl_context(),
     95                 server_hostname='',
     96                 loop=self.loop)
     97 
     98             self._basetest_open_connection_no_loop_ssl(conn_fut)
     99 
    def _basetest_open_connection_error(self, open_connection_fut):
        """Check that an error passed to connection_lost() surfaces in read().

        *open_connection_fut* is the not-yet-run awaitable returned by
        ``asyncio.open_connection()`` / ``asyncio.open_unix_connection()``.
        The writer is closed (and the loop stepped once) in ``finally`` so
        that a failing assertion does not leak the transport.
        """
        reader, writer = self.loop.run_until_complete(open_connection_fut)
        try:
            # Simulate the connection dying with an arbitrary exception.
            writer._protocol.connection_lost(ZeroDivisionError())
            f = reader.read()
            with self.assertRaises(ZeroDivisionError):
                self.loop.run_until_complete(f)
        finally:
            writer.close()
            test_utils.run_briefly(self.loop)
    108 
    def test_open_connection_error(self):
        # Error propagation for a TCP stream pair.
        with test_utils.run_test_server() as httpd:
            self._basetest_open_connection_error(
                asyncio.open_connection(*httpd.address, loop=self.loop))
    114 
    @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
    def test_open_unix_connection_error(self):
        # Error propagation for a UNIX-socket stream pair.
        with test_utils.run_test_unix_server() as httpd:
            self._basetest_open_connection_error(
                asyncio.open_unix_connection(httpd.address, loop=self.loop))
    121 
    def test_feed_empty_data(self):
        # Feeding zero bytes leaves the internal buffer empty.
        reader = asyncio.StreamReader(loop=self.loop)
        reader.feed_data(b'')
        self.assertEqual(b'', reader._buffer)
    127 
    def test_feed_nonempty_data(self):
        # Fed bytes land verbatim in the internal buffer.
        reader = asyncio.StreamReader(loop=self.loop)
        reader.feed_data(self.DATA)
        self.assertEqual(self.DATA, reader._buffer)
    133 
    def test_read_zero(self):
        # read(0) yields b'' and consumes nothing from the buffer.
        reader = asyncio.StreamReader(loop=self.loop)
        reader.feed_data(self.DATA)

        got = self.loop.run_until_complete(reader.read(0))
        self.assertEqual(b'', got)
        self.assertEqual(self.DATA, reader._buffer)
    142 
    def test_read(self):
        # A pending read(30) returns all 18 available bytes once they are
        # fed from a loop callback.
        reader = asyncio.StreamReader(loop=self.loop)
        pending = asyncio.Task(reader.read(30), loop=self.loop)
        self.loop.call_soon(reader.feed_data, self.DATA)

        got = self.loop.run_until_complete(pending)
        self.assertEqual(self.DATA, got)
        self.assertEqual(b'', reader._buffer)
    155 
    def test_read_line_breaks(self):
        # read(n) on data without newlines takes exactly n bytes.
        reader = asyncio.StreamReader(loop=self.loop)
        for chunk in (b'line1', b'line2'):
            reader.feed_data(chunk)

        got = self.loop.run_until_complete(reader.read(5))

        self.assertEqual(b'line1', got)
        self.assertEqual(b'line2', reader._buffer)
    166 
    def test_read_eof(self):
        # A pending read() returns b'' when EOF arrives with no data.
        reader = asyncio.StreamReader(loop=self.loop)
        pending = asyncio.Task(reader.read(1024), loop=self.loop)
        self.loop.call_soon(reader.feed_eof)

        got = self.loop.run_until_complete(pending)
        self.assertEqual(b'', got)
        self.assertEqual(b'', reader._buffer)
    179 
    def test_read_until_eof(self):
        # read(-1) concatenates every chunk fed before EOF.
        reader = asyncio.StreamReader(loop=self.loop)
        pending = asyncio.Task(reader.read(-1), loop=self.loop)

        def feed_then_eof():
            for chunk in (b'chunk1\n', b'chunk2'):
                reader.feed_data(chunk)
            reader.feed_eof()
        self.loop.call_soon(feed_then_eof)

        got = self.loop.run_until_complete(pending)

        self.assertEqual(b'chunk1\nchunk2', got)
        self.assertEqual(b'', reader._buffer)
    195 
    def test_read_exception(self):
        # After set_exception(), read() raises even though b'ne\n' is
        # still sitting in the buffer.
        reader = asyncio.StreamReader(loop=self.loop)
        reader.feed_data(b'line\n')

        head = self.loop.run_until_complete(reader.read(2))
        self.assertEqual(b'li', head)

        reader.set_exception(ValueError())
        with self.assertRaises(ValueError):
            self.loop.run_until_complete(reader.read(2))
    206 
    def test_invalid_limit(self):
        # Zero and negative limits are rejected by the constructor.
        for bad_limit in (0, -1):
            with self.assertRaisesRegex(ValueError, 'imit'):
                asyncio.StreamReader(limit=bad_limit, loop=self.loop)
    213 
    def test_read_limit(self):
        # read(n) may return more than `limit` bytes.
        reader = asyncio.StreamReader(limit=3, loop=self.loop)
        reader.feed_data(b'chunk')
        got = self.loop.run_until_complete(reader.read(5))
        self.assertEqual(b'chunk', got)
        self.assertEqual(b'', reader._buffer)
    220 
    def test_readline(self):
        # readline() accumulates chunks (one pre-fed, the rest fed from a
        # loop callback) until it sees a newline.
        reader = asyncio.StreamReader(loop=self.loop)
        reader.feed_data(b'chunk1 ')
        pending = asyncio.Task(reader.readline(), loop=self.loop)

        def feed_rest():
            for chunk in (b'chunk2 ', b'chunk3 ', b'\n chunk4'):
                reader.feed_data(chunk)
        self.loop.call_soon(feed_rest)

        line = self.loop.run_until_complete(pending)
        self.assertEqual(b'chunk1 chunk2 chunk3 \n', line)
        # Whatever follows the newline stays buffered.
        self.assertEqual(b' chunk4', reader._buffer)
    237 
    def test_readline_limit_with_existing_data(self):
        # readline() limit handling when all data is buffered before the
        # event loop runs.

        # Case 1: an over-long first line followed by a short one.
        reader = asyncio.StreamReader(limit=3, loop=self.loop)
        for chunk in (b'li', b'ne1\nline2\n'):
            reader.feed_data(chunk)
        with self.assertRaises(ValueError):
            self.loop.run_until_complete(reader.readline())
        # Only the offending line is dropped; the next one remains.
        self.assertEqual(b'line2\n', reader._buffer)

        # Case 2: no newline at all.  With limit=3, readline() consumes
        # the whole buffer before it would wait for more data; since that
        # is longer than the limit it raises ValueError and the buffer is
        # left empty.
        reader = asyncio.StreamReader(limit=3, loop=self.loop)
        for chunk in (b'li', b'ne1', b'li'):
            reader.feed_data(chunk)
        with self.assertRaises(ValueError):
            self.loop.run_until_complete(reader.readline())
        self.assertEqual(b'', reader._buffer)
    264 
    def test_at_eof(self):
        # at_eof() stays false until EOF has been fed and the data read.
        reader = asyncio.StreamReader(loop=self.loop)
        self.assertFalse(reader.at_eof())

        reader.feed_data(b'some data\n')
        self.assertFalse(reader.at_eof())

        # An empty buffer without EOF is not "at EOF".
        self.loop.run_until_complete(reader.readline())
        self.assertFalse(reader.at_eof())

        reader.feed_data(b'some data\n')
        reader.feed_eof()
        self.loop.run_until_complete(reader.readline())
        self.assertTrue(reader.at_eof())
    279 
    def test_readline_limit(self):
        # readline() limit handling when the data is fed from a callback
        # scheduled on the loop, plus a check that the limit is strict.

        # Scenario 1: a single over-long line followed by EOF.
        stream = asyncio.StreamReader(limit=7, loop=self.loop)

        def feed_one_long_line():
            stream.feed_data(b'chunk1')
            stream.feed_data(b'chunk2')
            stream.feed_data(b'chunk3\n')
            stream.feed_eof()
        self.loop.call_soon(feed_one_long_line)

        with self.assertRaises(ValueError):
            self.loop.run_until_complete(stream.readline())
        # The buffer held just that one line; after the ValueError it is
        # empty.
        self.assertEqual(b'', stream._buffer)

        # Scenario 2: an over-long first line followed by a short one.
        stream = asyncio.StreamReader(limit=7, loop=self.loop)

        def feed_two_lines():
            stream.feed_data(b'chunk1')
            stream.feed_data(b'chunk2\n')
            stream.feed_data(b'chunk3\n')
            stream.feed_eof()
        self.loop.call_soon(feed_two_lines)

        with self.assertRaises(ValueError):
            self.loop.run_until_complete(stream.readline())
        # Only the offending line is dropped.
        self.assertEqual(b'chunk3\n', stream._buffer)

        # Strictness: with limit=7, seven bytes plus the newline are
        # accepted, eight bytes plus the newline are rejected.
        stream = asyncio.StreamReader(limit=7, loop=self.loop)
        stream.feed_data(b'1234567\n')
        line = self.loop.run_until_complete(stream.readline())
        self.assertEqual(b'1234567\n', line)
        self.assertEqual(b'', stream._buffer)

        stream.feed_data(b'12345678\n')
        with self.assertRaises(ValueError):
            self.loop.run_until_complete(stream.readline())
        self.assertEqual(b'', stream._buffer)

        # An unterminated line over the limit is rejected as well.
        stream.feed_data(b'12345678')
        with self.assertRaises(ValueError):
            self.loop.run_until_complete(stream.readline())
        self.assertEqual(b'', stream._buffer)
    326 
    def test_readline_nolimit_nowait(self):
        # The first line is already complete in the buffer, so readline()
        # can return without more data being fed.
        reader = asyncio.StreamReader(loop=self.loop)
        head, tail = self.DATA[:6], self.DATA[6:]
        reader.feed_data(head)
        reader.feed_data(tail)

        first = self.loop.run_until_complete(reader.readline())

        self.assertEqual(b'line1\n', first)
        self.assertEqual(b'line2\nline3\n', reader._buffer)
    338 
    def test_readline_eof(self):
        # At EOF, readline() returns the trailing data without a newline.
        reader = asyncio.StreamReader(loop=self.loop)
        reader.feed_data(b'some data')
        reader.feed_eof()

        got = self.loop.run_until_complete(reader.readline())
        self.assertEqual(b'some data', got)
    346 
    def test_readline_empty_eof(self):
        # readline() on an empty stream at EOF returns b''.
        reader = asyncio.StreamReader(loop=self.loop)
        reader.feed_eof()

        got = self.loop.run_until_complete(reader.readline())
        self.assertEqual(b'', got)
    353 
    def test_readline_read_byte_count(self):
        # A read(n) after readline() continues right after the newline.
        reader = asyncio.StreamReader(loop=self.loop)
        reader.feed_data(self.DATA)

        # Consume the first line; its value is not checked here.
        self.loop.run_until_complete(reader.readline())

        got = self.loop.run_until_complete(reader.read(7))

        self.assertEqual(b'line2\nl', got)
        self.assertEqual(b'ine3\n', reader._buffer)
    364 
    def test_readline_exception(self):
        # After set_exception(), readline() raises it.
        reader = asyncio.StreamReader(loop=self.loop)
        reader.feed_data(b'line\n')

        first = self.loop.run_until_complete(reader.readline())
        self.assertEqual(b'line\n', first)

        reader.set_exception(ValueError())
        with self.assertRaises(ValueError):
            self.loop.run_until_complete(reader.readline())
        self.assertEqual(b'', reader._buffer)
    376 
    def test_readuntil_separator(self):
        # An empty separator is rejected.
        reader = asyncio.StreamReader(loop=self.loop)
        empty_sep_read = reader.readuntil(separator=b'')
        with self.assertRaisesRegex(ValueError, 'Separator should be'):
            self.loop.run_until_complete(empty_sep_read)
    381 
    def test_readuntil_multi_chunks(self):
        # readuntil() returns data up to and including the separator.
        reader = asyncio.StreamReader(loop=self.loop)

        # Separator passed by keyword.
        reader.feed_data(b'lineAAA')
        got = self.loop.run_until_complete(reader.readuntil(separator=b'AAA'))
        self.assertEqual(b'lineAAA', got)
        self.assertEqual(b'', reader._buffer)

        # Separator passed positionally.
        reader.feed_data(b'lineAAA')
        got = self.loop.run_until_complete(reader.readuntil(b'AAA'))
        self.assertEqual(b'lineAAA', got)
        self.assertEqual(b'', reader._buffer)

        # Bytes after the separator remain buffered.
        reader.feed_data(b'lineAAAxxx')
        got = self.loop.run_until_complete(reader.readuntil(b'AAA'))
        self.assertEqual(b'lineAAA', got)
        self.assertEqual(b'xxx', reader._buffer)
    399 
    def test_readuntil_multi_chunks_1(self):
        # readuntil() finds a multi-byte separator even when it is split
        # across separately fed chunks.  Each case reuses the same reader
        # and must leave the buffer empty.
        reader = asyncio.StreamReader(loop=self.loop)
        cases = [
            ((b'QWEaa', b'XYaa', b'a'), b'QWEaaXYaaa'),
            ((b'QWEaa', b'XYa', b'aa'), b'QWEaaXYaaa'),
            ((b'aaa',), b'aaa'),
            ((b'Xaaa',), b'Xaaa'),
            ((b'XXX', b'a', b'a', b'a'), b'XXXaaa'),
        ]
        for chunks, expected in cases:
            for chunk in chunks:
                reader.feed_data(chunk)
            got = self.loop.run_until_complete(reader.readuntil(b'aaa'))
            self.assertEqual(expected, got)
            self.assertEqual(b'', reader._buffer)
    434 
    def test_readuntil_eof(self):
        # Hitting EOF before the separator raises IncompleteReadError that
        # carries everything read so far and drains the buffer.
        reader = asyncio.StreamReader(loop=self.loop)
        reader.feed_data(b'some dataAA')
        reader.feed_eof()

        with self.assertRaises(asyncio.IncompleteReadError) as ctx:
            self.loop.run_until_complete(reader.readuntil(b'AAA'))
        err = ctx.exception
        self.assertEqual(err.partial, b'some dataAA')
        self.assertIsNone(err.expected)
        self.assertEqual(b'', reader._buffer)
    445 
    def test_readuntil_limit_found_sep(self):
        # readuntil() raises LimitOverrunError past the limit, both with
        # and without the separator present, and leaves the buffer intact.
        stream = asyncio.StreamReader(loop=self.loop, limit=3)
        stream.feed_data(b'some dataAA')

        # Separator not yet in the buffer.
        with self.assertRaisesRegex(asyncio.LimitOverrunError, 'not found'):
            self.loop.run_until_complete(stream.readuntil(b'AAA'))

        self.assertEqual(b'some dataAA', stream._buffer)

        # Separator now present, but beyond the limit.
        stream.feed_data(b'A')
        with self.assertRaisesRegex(asyncio.LimitOverrunError, 'is found'):
            self.loop.run_until_complete(stream.readuntil(b'AAA'))

        self.assertEqual(b'some dataAAA', stream._buffer)
    462 
    def test_readexactly_zero_or_less(self):
        # readexactly(0) returns b''; a negative count is rejected.
        # Neither call touches the buffer.
        reader = asyncio.StreamReader(loop=self.loop)
        reader.feed_data(self.DATA)

        got = self.loop.run_until_complete(reader.readexactly(0))
        self.assertEqual(b'', got)
        self.assertEqual(self.DATA, reader._buffer)

        with self.assertRaisesRegex(ValueError, 'less than zero'):
            self.loop.run_until_complete(reader.readexactly(-1))
        self.assertEqual(self.DATA, reader._buffer)
    475 
    def test_readexactly(self):
        # readexactly(n) returns exactly n bytes out of a larger feed.
        reader = asyncio.StreamReader(loop=self.loop)

        wanted = 2 * len(self.DATA)
        pending = asyncio.Task(reader.readexactly(wanted), loop=self.loop)

        def feed_three_copies():
            for _ in range(3):
                reader.feed_data(self.DATA)
        self.loop.call_soon(feed_three_copies)

        got = self.loop.run_until_complete(pending)
        self.assertEqual(self.DATA * 2, got)
        # The third copy is left over.
        self.assertEqual(self.DATA, reader._buffer)
    492 
    def test_readexactly_limit(self):
        # readexactly(n) may return more than `limit` bytes.
        reader = asyncio.StreamReader(limit=3, loop=self.loop)
        reader.feed_data(b'chunk')
        got = self.loop.run_until_complete(reader.readexactly(5))
        self.assertEqual(b'chunk', got)
        self.assertEqual(b'', reader._buffer)
    499 
    def test_readexactly_eof(self):
        # EOF after only half the requested bytes raises
        # IncompleteReadError describing what was read and what was wanted.
        reader = asyncio.StreamReader(loop=self.loop)
        wanted = 2 * len(self.DATA)
        pending = asyncio.Task(reader.readexactly(wanted), loop=self.loop)

        def feed_half_then_eof():
            reader.feed_data(self.DATA)
            reader.feed_eof()
        self.loop.call_soon(feed_half_then_eof)

        with self.assertRaises(asyncio.IncompleteReadError) as ctx:
            self.loop.run_until_complete(pending)
        err = ctx.exception
        self.assertEqual(err.partial, self.DATA)
        self.assertEqual(err.expected, wanted)
        self.assertEqual(str(err),
                         '18 bytes read on a total of 36 expected bytes')
        self.assertEqual(b'', reader._buffer)
    518 
    def test_readexactly_exception(self):
        # After set_exception(), readexactly() raises it.
        reader = asyncio.StreamReader(loop=self.loop)
        reader.feed_data(b'line\n')

        head = self.loop.run_until_complete(reader.readexactly(2))
        self.assertEqual(b'li', head)

        reader.set_exception(ValueError())
        with self.assertRaises(ValueError):
            self.loop.run_until_complete(reader.readexactly(2))
    529 
    def test_exception(self):
        # exception() is None until set, then returns the same object.
        reader = asyncio.StreamReader(loop=self.loop)
        self.assertIsNone(reader.exception())

        error = ValueError()
        reader.set_exception(error)
        self.assertIs(reader.exception(), error)
    537 
    def test_exception_waiter(self):
        # An exception set from another task is delivered to a readline()
        # task running on the same stream.
        stream = asyncio.StreamReader(loop=self.loop)

        @asyncio.coroutine
        def fail_stream():
            stream.set_exception(ValueError())

        read_task = asyncio.Task(stream.readline(), loop=self.loop)
        fail_task = asyncio.Task(fail_stream(), loop=self.loop)

        both_done = asyncio.wait([read_task, fail_task], loop=self.loop)
        self.loop.run_until_complete(both_done)

        self.assertRaises(ValueError, read_task.result)
    551 
    def test_exception_cancel(self):
        # set_exception() must cope with a readline() task that has
        # already been cancelled, and end with no waiter left on the stream.
        stream = asyncio.StreamReader(loop=self.loop)

        read_task = asyncio.Task(stream.readline(), loop=self.loop)
        test_utils.run_briefly(self.loop)
        read_task.cancel()
        test_utils.run_briefly(self.loop)

        # This call is the one that breaks if set_exception() is careless
        # about a cancelled waiter.
        stream.set_exception(RuntimeError('message'))
        test_utils.run_briefly(self.loop)
        self.assertIs(stream._waiter, None)
    563 
    def test_start_server(self):
        # Echo one line through start_server(), first with a coroutine
        # client handler on a pre-bound socket, then with a plain callback
        # on host/port.  Each server is stopped in ``finally`` so a failing
        # client does not leave a listening socket behind.

        class MyServer:

            def __init__(self, loop):
                self.server = None
                self.loop = loop

            @asyncio.coroutine
            def handle_client(self, client_reader, client_writer):
                # Echo back one line, then close the connection.
                data = yield from client_reader.readline()
                client_writer.write(data)
                yield from client_writer.drain()
                client_writer.close()

            def start(self):
                # sock= variant: hand an already-bound socket to the server.
                sock = socket.socket()
                sock.bind(('127.0.0.1', 0))
                self.server = self.loop.run_until_complete(
                    asyncio.start_server(self.handle_client,
                                         sock=sock,
                                         loop=self.loop))
                return sock.getsockname()

            def handle_client_callback(self, client_reader, client_writer):
                # Non-coroutine handler: wrap the coroutine in a task.
                self.loop.create_task(self.handle_client(client_reader,
                                                         client_writer))

            def start_callback(self):
                # host/port variant.  A free port is discovered by binding
                # a throwaway socket; note another process could grab the
                # port before start_server() binds it again.
                with socket.socket() as probe:
                    probe.bind(('127.0.0.1', 0))
                    addr = probe.getsockname()
                self.server = self.loop.run_until_complete(
                    asyncio.start_server(self.handle_client_callback,
                                         host=addr[0], port=addr[1],
                                         loop=self.loop))
                return addr

            def stop(self):
                # Only the first call closes the server; later calls no-op.
                if self.server is not None:
                    self.server.close()
                    self.loop.run_until_complete(self.server.wait_closed())
                    self.server = None

        @asyncio.coroutine
        def client(addr):
            reader, writer = yield from asyncio.open_connection(
                *addr, loop=self.loop)
            try:
                # send a line
                writer.write(b"hello world!\n")
                # read it back
                msgback = yield from reader.readline()
            finally:
                writer.close()
            return msgback

        # test the server variant with a coroutine as client handler
        server = MyServer(self.loop)
        addr = server.start()
        try:
            msg = self.loop.run_until_complete(
                asyncio.Task(client(addr), loop=self.loop))
        finally:
            server.stop()
        self.assertEqual(msg, b"hello world!\n")

        # test the server variant with a callback as client handler
        server = MyServer(self.loop)
        addr = server.start_callback()
        try:
            msg = self.loop.run_until_complete(
                asyncio.Task(client(addr), loop=self.loop))
        finally:
            server.stop()
        self.assertEqual(msg, b"hello world!\n")
    635 
    @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
    def test_start_unix_server(self):
        # Echo one line through start_unix_server(), first with a coroutine
        # client handler and then with a plain callback.  Each server is
        # stopped in ``finally`` so a failing client does not leave a
        # listening socket behind.

        class MyServer:

            def __init__(self, loop, path):
                self.server = None
                self.loop = loop
                self.path = path

            @asyncio.coroutine
            def handle_client(self, client_reader, client_writer):
                # Echo back one line, then close the connection.
                data = yield from client_reader.readline()
                client_writer.write(data)
                yield from client_writer.drain()
                client_writer.close()

            def start(self):
                self.server = self.loop.run_until_complete(
                    asyncio.start_unix_server(self.handle_client,
                                              path=self.path,
                                              loop=self.loop))

            def handle_client_callback(self, client_reader, client_writer):
                # Non-coroutine handler: wrap the coroutine in a task.
                self.loop.create_task(self.handle_client(client_reader,
                                                         client_writer))

            def start_callback(self):
                start = asyncio.start_unix_server(self.handle_client_callback,
                                                  path=self.path,
                                                  loop=self.loop)
                self.server = self.loop.run_until_complete(start)

            def stop(self):
                # Only the first call closes the server; later calls no-op.
                if self.server is not None:
                    self.server.close()
                    self.loop.run_until_complete(self.server.wait_closed())
                    self.server = None

        @asyncio.coroutine
        def client(path):
            reader, writer = yield from asyncio.open_unix_connection(
                path, loop=self.loop)
            try:
                # send a line
                writer.write(b"hello world!\n")
                # read it back
                msgback = yield from reader.readline()
            finally:
                writer.close()
            return msgback

        # test the server variant with a coroutine as client handler
        with test_utils.unix_socket_path() as path:
            server = MyServer(self.loop, path)
            server.start()
            try:
                msg = self.loop.run_until_complete(
                    asyncio.Task(client(path), loop=self.loop))
            finally:
                server.stop()
            self.assertEqual(msg, b"hello world!\n")

        # test the server variant with a callback as client handler
        with test_utils.unix_socket_path() as path:
            server = MyServer(self.loop, path)
            server.start_callback()
            try:
                msg = self.loop.run_until_complete(
                    asyncio.Task(client(path), loop=self.loop))
            finally:
                server.stop()
            self.assertEqual(msg, b"hello world!\n")
    703 
    704     @unittest.skipIf(sys.platform == 'win32', "Don't have pipes")
    705     def test_read_all_from_pipe_reader(self):
    706         # See asyncio issue 168.  This test is derived from the example
    707         # subprocess_attach_read_pipe.py, but we configure the
    708         # StreamReader's limit so that twice it is less than the size
    709         # of the data writter.  Also we must explicitly attach a child
    710         # watcher to the event loop.
    711 
    712         code = """\
    713 import os, sys
    714 fd = int(sys.argv[1])
    715 os.write(fd, b'data')
    716 os.close(fd)
    717 """
    718         rfd, wfd = os.pipe()
    719         args = [sys.executable, '-c', code, str(wfd)]
    720 
    721         pipe = open(rfd, 'rb', 0)
    722         reader = asyncio.StreamReader(loop=self.loop, limit=1)
    723         protocol = asyncio.StreamReaderProtocol(reader, loop=self.loop)
    724         transport, _ = self.loop.run_until_complete(
    725             self.loop.connect_read_pipe(lambda: protocol, pipe))
    726 
    727         watcher = asyncio.SafeChildWatcher()
    728         watcher.attach_loop(self.loop)
    729         try:
    730             asyncio.set_child_watcher(watcher)
    731             create = asyncio.create_subprocess_exec(*args,
    732                                                     pass_fds={wfd},
    733                                                     loop=self.loop)
    734             proc = self.loop.run_until_complete(create)
    735             self.loop.run_until_complete(proc.wait())
    736         finally:
    737             asyncio.set_child_watcher(None)
    738 
    739         os.close(wfd)
    740         data = self.loop.run_until_complete(reader.read(-1))
    741         self.assertEqual(data, b'data')
    742 
    def test_streamreader_constructor(self):
        # Make self.loop the current event loop for the duration of this
        # test only; the cleanup restores "no current loop".
        self.addCleanup(asyncio.set_event_loop, None)
        asyncio.set_event_loop(self.loop)

        # asyncio issue #184: Ensure that the StreamReader constructor
        # retrieves the current loop if the loop parameter is not set
        reader = asyncio.StreamReader()
        self.assertIs(reader._loop, self.loop)
    751 
    def test_streamreaderprotocol_constructor(self):
        # Install self.loop as the current loop; undo that when the test ends.
        self.addCleanup(asyncio.set_event_loop, None)
        asyncio.set_event_loop(self.loop)

        # asyncio issue #184: built without an explicit loop argument, the
        # protocol must fall back to the current event loop.
        proto = asyncio.StreamReaderProtocol(mock.Mock())
        self.assertIs(proto._loop, self.loop)
    761 
    def test_drain_raises(self):
        # See http://bugs.python.org/issue25441

        # This test should not use asyncio for the mock server; the
        # whole point of the test is to test for a bug in drain()
        # where it never gives up the event loop but the socket is
        # closed on the server side.

        q = queue.Queue()

        def server():
            # Runs in a separate thread: publish the listening address,
            # accept a single connection and close it immediately.
            sock = socket.socket()
            with sock:
                sock.bind(('localhost', 0))
                sock.listen(1)
                addr = sock.getsockname()
                q.put(addr)
                clt, _ = sock.accept()
                clt.close()

        @asyncio.coroutine
        def client(host, port):
            _, writer = yield from asyncio.open_connection(
                host, port, loop=self.loop)

            # Write forever; the test expects drain() to raise once the
            # server has closed its end of the connection.
            while True:
                writer.write(b"foo\n")
                yield from writer.drain()

        # Start the server thread and wait for it to be listening.  The
        # timeout turns a server thread that died before publishing its
        # address into a failure (queue.Empty) instead of a hang.
        thread = threading.Thread(target=server, daemon=True)
        thread.start()
        addr = q.get(timeout=60)

        # Should not be stuck in an infinite loop.
        with self.assertRaises((ConnectionResetError, BrokenPipeError)):
            self.loop.run_until_complete(client(*addr))

        # Clean up the thread.  (Only on success; on failure, it may
        # be stuck in accept().)
        thread.join()
    805 
    def test___repr__(self):
        # A fresh reader with the default limit has a bare repr.
        reader = asyncio.StreamReader(loop=self.loop)
        self.assertEqual("<StreamReader>", repr(reader))
    809 
    def test___repr__nondefault_limit(self):
        # A non-default limit shows up in the repr with an "l=" prefix.
        reader = asyncio.StreamReader(limit=123, loop=self.loop)
        self.assertEqual("<StreamReader l=123>", repr(reader))
    813 
    def test___repr__eof(self):
        # After feed_eof() the repr carries an "eof" marker.
        reader = asyncio.StreamReader(loop=self.loop)
        reader.feed_eof()
        expected = "<StreamReader eof>"
        self.assertEqual(expected, repr(reader))
    818 
    def test___repr__data(self):
        reader = asyncio.StreamReader(loop=self.loop)
        reader.feed_data(b'data')
        # Buffered input is reported as a byte count, not its content.
        self.assertEqual("<StreamReader 4 bytes>", repr(reader))
    823 
    def test___repr__exception(self):
        # A stored exception is shown via its repr after an "e=" prefix.
        reader = asyncio.StreamReader(loop=self.loop)
        reader.set_exception(RuntimeError())
        self.assertEqual("<StreamReader e=RuntimeError()>", repr(reader))
    829 
    def test___repr__waiter(self):
        reader = asyncio.StreamReader(loop=self.loop)

        # While a pending waiter is attached, the repr embeds it after "w=".
        pending = asyncio.Future(loop=self.loop)
        reader._waiter = pending
        self.assertRegex(repr(reader),
                         r"<StreamReader w=<Future pending[\S ]*>>")

        # Resolve the future so it is not left pending, then detach it:
        # the repr must drop the waiter part entirely.
        pending.set_result(None)
        self.loop.run_until_complete(pending)
        reader._waiter = None
        self.assertEqual("<StreamReader>", repr(reader))
    840 
    def test___repr__transport(self):
        # The transport's own repr is embedded after a "t=" prefix.
        fake_transport = mock.Mock()
        fake_transport.__repr__ = mock.Mock(return_value="<Transport>")

        reader = asyncio.StreamReader(loop=self.loop)
        reader._transport = fake_transport
        self.assertEqual("<StreamReader t=<Transport>>", repr(reader))
    847 
    848 
    849 if __name__ == '__main__':
    850     unittest.main()
    851