# Home | History | Annotate | Download | only in asyncio  (code-viewer header; kept as a comment so the file parses)
      1 __all__ = (
      2     'StreamReader', 'StreamWriter', 'StreamReaderProtocol',
      3     'open_connection', 'start_server',
      4     'IncompleteReadError', 'LimitOverrunError',
      5 )
      6 
      7 import socket
      8 
      9 if hasattr(socket, 'AF_UNIX'):
     10     __all__ += ('open_unix_connection', 'start_unix_server')
     11 
     12 from . import coroutines
     13 from . import events
     14 from . import protocols
     15 from .log import logger
     16 from .tasks import sleep
     17 
     18 
     19 _DEFAULT_LIMIT = 2 ** 16  # 64 KiB
     20 
     21 
     22 class IncompleteReadError(EOFError):
     23     """
     24     Incomplete read error. Attributes:
     25 
     26     - partial: read bytes string before the end of stream was reached
     27     - expected: total number of expected bytes (or None if unknown)
     28     """
     29     def __init__(self, partial, expected):
     30         super().__init__(f'{len(partial)} bytes read on a total of '
     31                          f'{expected!r} expected bytes')
     32         self.partial = partial
     33         self.expected = expected
     34 
     35     def __reduce__(self):
     36         return type(self), (self.partial, self.expected)
     37 
     38 
     39 class LimitOverrunError(Exception):
     40     """Reached the buffer limit while looking for a separator.
     41 
     42     Attributes:
     43     - consumed: total number of to be consumed bytes.
     44     """
     45     def __init__(self, message, consumed):
     46         super().__init__(message)
     47         self.consumed = consumed
     48 
     49     def __reduce__(self):
     50         return type(self), (self.args[0], self.consumed)
     51 
     52 
     53 async def open_connection(host=None, port=None, *,
     54                           loop=None, limit=_DEFAULT_LIMIT, **kwds):
     55     """A wrapper for create_connection() returning a (reader, writer) pair.
     56 
     57     The reader returned is a StreamReader instance; the writer is a
     58     StreamWriter instance.
     59 
     60     The arguments are all the usual arguments to create_connection()
     61     except protocol_factory; most common are positional host and port,
     62     with various optional keyword arguments following.
     63 
     64     Additional optional keyword arguments are loop (to set the event loop
     65     instance to use) and limit (to set the buffer limit passed to the
     66     StreamReader).
     67 
     68     (If you want to customize the StreamReader and/or
     69     StreamReaderProtocol classes, just copy the code -- there's
     70     really nothing special here except some convenience.)
     71     """
     72     if loop is None:
     73         loop = events.get_event_loop()
     74     reader = StreamReader(limit=limit, loop=loop)
     75     protocol = StreamReaderProtocol(reader, loop=loop)
     76     transport, _ = await loop.create_connection(
     77         lambda: protocol, host, port, **kwds)
     78     writer = StreamWriter(transport, protocol, reader, loop)
     79     return reader, writer
     80 
     81 
     82 async def start_server(client_connected_cb, host=None, port=None, *,
     83                        loop=None, limit=_DEFAULT_LIMIT, **kwds):
     84     """Start a socket server, call back for each client connected.
     85 
     86     The first parameter, `client_connected_cb`, takes two parameters:
     87     client_reader, client_writer.  client_reader is a StreamReader
     88     object, while client_writer is a StreamWriter object.  This
     89     parameter can either be a plain callback function or a coroutine;
     90     if it is a coroutine, it will be automatically converted into a
     91     Task.
     92 
     93     The rest of the arguments are all the usual arguments to
     94     loop.create_server() except protocol_factory; most common are
     95     positional host and port, with various optional keyword arguments
     96     following.  The return value is the same as loop.create_server().
     97 
     98     Additional optional keyword arguments are loop (to set the event loop
     99     instance to use) and limit (to set the buffer limit passed to the
    100     StreamReader).
    101 
    102     The return value is the same as loop.create_server(), i.e. a
    103     Server object which can be used to stop the service.
    104     """
    105     if loop is None:
    106         loop = events.get_event_loop()
    107 
    108     def factory():
    109         reader = StreamReader(limit=limit, loop=loop)
    110         protocol = StreamReaderProtocol(reader, client_connected_cb,
    111                                         loop=loop)
    112         return protocol
    113 
    114     return await loop.create_server(factory, host, port, **kwds)
    115 
    116 
    117 if hasattr(socket, 'AF_UNIX'):
    118     # UNIX Domain Sockets are supported on this platform
    119 
    120     async def open_unix_connection(path=None, *,
    121                                    loop=None, limit=_DEFAULT_LIMIT, **kwds):
    122         """Similar to `open_connection` but works with UNIX Domain Sockets."""
    123         if loop is None:
    124             loop = events.get_event_loop()
    125         reader = StreamReader(limit=limit, loop=loop)
    126         protocol = StreamReaderProtocol(reader, loop=loop)
    127         transport, _ = await loop.create_unix_connection(
    128             lambda: protocol, path, **kwds)
    129         writer = StreamWriter(transport, protocol, reader, loop)
    130         return reader, writer
    131 
    132     async def start_unix_server(client_connected_cb, path=None, *,
    133                                 loop=None, limit=_DEFAULT_LIMIT, **kwds):
    134         """Similar to `start_server` but works with UNIX Domain Sockets."""
    135         if loop is None:
    136             loop = events.get_event_loop()
    137 
    138         def factory():
    139             reader = StreamReader(limit=limit, loop=loop)
    140             protocol = StreamReaderProtocol(reader, client_connected_cb,
    141                                             loop=loop)
    142             return protocol
    143 
    144         return await loop.create_unix_server(factory, path, **kwds)
    145 
    146 
    147 class FlowControlMixin(protocols.Protocol):
    148     """Reusable flow control logic for StreamWriter.drain().
    149 
    150     This implements the protocol methods pause_writing(),
    151     resume_writing() and connection_lost().  If the subclass overrides
    152     these it must call the super methods.
    153 
    154     StreamWriter.drain() must wait for _drain_helper() coroutine.
    155     """
    156 
    157     def __init__(self, loop=None):
    158         if loop is None:
    159             self._loop = events.get_event_loop()
    160         else:
    161             self._loop = loop
    162         self._paused = False
    163         self._drain_waiter = None
    164         self._connection_lost = False
    165 
    166     def pause_writing(self):
    167         assert not self._paused
    168         self._paused = True
    169         if self._loop.get_debug():
    170             logger.debug("%r pauses writing", self)
    171 
    172     def resume_writing(self):
    173         assert self._paused
    174         self._paused = False
    175         if self._loop.get_debug():
    176             logger.debug("%r resumes writing", self)
    177 
    178         waiter = self._drain_waiter
    179         if waiter is not None:
    180             self._drain_waiter = None
    181             if not waiter.done():
    182                 waiter.set_result(None)
    183 
    184     def connection_lost(self, exc):
    185         self._connection_lost = True
    186         # Wake up the writer if currently paused.
    187         if not self._paused:
    188             return
    189         waiter = self._drain_waiter
    190         if waiter is None:
    191             return
    192         self._drain_waiter = None
    193         if waiter.done():
    194             return
    195         if exc is None:
    196             waiter.set_result(None)
    197         else:
    198             waiter.set_exception(exc)
    199 
    200     async def _drain_helper(self):
    201         if self._connection_lost:
    202             raise ConnectionResetError('Connection lost')
    203         if not self._paused:
    204             return
    205         waiter = self._drain_waiter
    206         assert waiter is None or waiter.cancelled()
    207         waiter = self._loop.create_future()
    208         self._drain_waiter = waiter
    209         await waiter
    210 
    211 
class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
    """Helper class to adapt between Protocol and StreamReader.

    (This is a helper class instead of making StreamReader itself a
    Protocol subclass, because the StreamReader has other potential
    uses, and to prevent the user of the StreamReader to accidentally
    call inappropriate methods of the protocol.)
    """

    def __init__(self, stream_reader, client_connected_cb=None, loop=None):
        super().__init__(loop=loop)
        self._stream_reader = stream_reader
        self._stream_writer = None
        # Optional callable (or coroutine function) invoked with
        # (reader, writer) once the connection is made; used by servers.
        self._client_connected_cb = client_connected_cb
        self._over_ssl = False
        # Future resolved in connection_lost(); awaited by
        # StreamWriter.wait_closed().
        self._closed = self._loop.create_future()

    def connection_made(self, transport):
        """Hand the new transport to the reader and fire the callback."""
        self._stream_reader.set_transport(transport)
        self._over_ssl = transport.get_extra_info('sslcontext') is not None
        if self._client_connected_cb is not None:
            # Server-side mode: build the writer and notify the callback.
            self._stream_writer = StreamWriter(transport, self,
                                               self._stream_reader,
                                               self._loop)
            res = self._client_connected_cb(self._stream_reader,
                                            self._stream_writer)
            if coroutines.iscoroutine(res):
                # Coroutine callbacks are scheduled as tasks.
                self._loop.create_task(res)

    def connection_lost(self, exc):
        """Propagate EOF/error to the reader and resolve the closed future.

        Order matters here: feed the reader first, resolve _closed, then
        let FlowControlMixin wake any pending drain(), and only then drop
        the reader/writer references.
        """
        if self._stream_reader is not None:
            if exc is None:
                self._stream_reader.feed_eof()
            else:
                self._stream_reader.set_exception(exc)
        if not self._closed.done():
            if exc is None:
                self._closed.set_result(None)
            else:
                self._closed.set_exception(exc)
        super().connection_lost(exc)
        # Drop references to help break reference cycles after close.
        self._stream_reader = None
        self._stream_writer = None

    def data_received(self, data):
        # Forward incoming bytes straight into the reader's buffer.
        self._stream_reader.feed_data(data)

    def eof_received(self):
        self._stream_reader.feed_eof()
        if self._over_ssl:
            # Prevent a warning in SSLProtocol.eof_received:
            # "returning true from eof_received()
            # has no effect when using ssl"
            return False
        # Returning True keeps the transport open (half-closed support).
        return True

    def __del__(self):
        # Prevent reports about unhandled exceptions.
        # Better than self._closed._log_traceback = False hack
        closed = self._closed
        if closed.done() and not closed.cancelled():
            closed.exception()
    274 
    275 
    276 class StreamWriter:
    277     """Wraps a Transport.
    278 
    279     This exposes write(), writelines(), [can_]write_eof(),
    280     get_extra_info() and close().  It adds drain() which returns an
    281     optional Future on which you can wait for flow control.  It also
    282     adds a transport property which references the Transport
    283     directly.
    284     """
    285 
    286     def __init__(self, transport, protocol, reader, loop):
    287         self._transport = transport
    288         self._protocol = protocol
    289         # drain() expects that the reader has an exception() method
    290         assert reader is None or isinstance(reader, StreamReader)
    291         self._reader = reader
    292         self._loop = loop
    293 
    294     def __repr__(self):
    295         info = [self.__class__.__name__, f'transport={self._transport!r}']
    296         if self._reader is not None:
    297             info.append(f'reader={self._reader!r}')
    298         return '<{}>'.format(' '.join(info))
    299 
    300     @property
    301     def transport(self):
    302         return self._transport
    303 
    304     def write(self, data):
    305         self._transport.write(data)
    306 
    307     def writelines(self, data):
    308         self._transport.writelines(data)
    309 
    310     def write_eof(self):
    311         return self._transport.write_eof()
    312 
    313     def can_write_eof(self):
    314         return self._transport.can_write_eof()
    315 
    316     def close(self):
    317         return self._transport.close()
    318 
    319     def is_closing(self):
    320         return self._transport.is_closing()
    321 
    322     async def wait_closed(self):
    323         await self._protocol._closed
    324 
    325     def get_extra_info(self, name, default=None):
    326         return self._transport.get_extra_info(name, default)
    327 
    328     async def drain(self):
    329         """Flush the write buffer.
    330 
    331         The intended use is to write
    332 
    333           w.write(data)
    334           await w.drain()
    335         """
    336         if self._reader is not None:
    337             exc = self._reader.exception()
    338             if exc is not None:
    339                 raise exc
    340         if self._transport.is_closing():
    341             # Yield to the event loop so connection_lost() may be
    342             # called.  Without this, _drain_helper() would return
    343             # immediately, and code that calls
    344             #     write(...); await drain()
    345             # in a loop would never call connection_lost(), so it
    346             # would not see an error when the socket is closed.
    347             await sleep(0, loop=self._loop)
    348         await self._protocol._drain_helper()
    349 
    350 
class StreamReader:
    """Buffered stream reader fed by a protocol.

    A protocol pushes bytes in via feed_data() and signals end-of-stream
    via feed_eof(); the read*() coroutines consume from the internal
    buffer, waiting on a future when more data is needed.  The transport
    is paused/resumed to keep the buffer size near the configured limit.
    """

    def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
        # The line length limit is a security feature;
        # it also doubles as half the buffer limit.

        if limit <= 0:
            raise ValueError('Limit cannot be <= 0')

        self._limit = limit
        if loop is None:
            self._loop = events.get_event_loop()
        else:
            self._loop = loop
        self._buffer = bytearray()
        self._eof = False    # Whether we're done.
        self._waiter = None  # A future used by _wait_for_data()
        self._exception = None  # Set by set_exception(); re-raised by read*()
        self._transport = None  # Set later via set_transport()
        self._paused = False    # True while transport reading is paused

    def __repr__(self):
        # Show only the state that is actually set, to keep reprs short.
        info = ['StreamReader']
        if self._buffer:
            info.append(f'{len(self._buffer)} bytes')
        if self._eof:
            info.append('eof')
        if self._limit != _DEFAULT_LIMIT:
            info.append(f'limit={self._limit}')
        if self._waiter:
            info.append(f'waiter={self._waiter!r}')
        if self._exception:
            info.append(f'exception={self._exception!r}')
        if self._transport:
            info.append(f'transport={self._transport!r}')
        if self._paused:
            info.append('paused')
        return '<{}>'.format(' '.join(info))

    def exception(self):
        """Return the exception set on this stream, or None."""
        return self._exception

    def set_exception(self, exc):
        """Record *exc* and wake the pending read*() waiter with it."""
        self._exception = exc

        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_exception(exc)

    def _wakeup_waiter(self):
        """Wakeup read*() functions waiting for data or EOF."""
        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_result(None)

    def set_transport(self, transport):
        """Attach the transport (once) so reading can be paused/resumed."""
        assert self._transport is None, 'Transport already set'
        self._transport = transport

    def _maybe_resume_transport(self):
        # Resume reading once the buffer has drained back under the limit.
        if self._paused and len(self._buffer) <= self._limit:
            self._paused = False
            self._transport.resume_reading()

    def feed_eof(self):
        """Mark end-of-stream and wake any pending read*() call."""
        self._eof = True
        self._wakeup_waiter()

    def at_eof(self):
        """Return True if the buffer is empty and 'feed_eof' was called."""
        return self._eof and not self._buffer

    def feed_data(self, data):
        """Append *data* to the buffer and wake any pending read*() call.

        Pauses the transport when the buffer grows past twice the limit
        (readuntil() may legitimately hold up to 2*limit bytes).
        """
        assert not self._eof, 'feed_data after feed_eof'

        if not data:
            return

        self._buffer.extend(data)
        self._wakeup_waiter()

        if (self._transport is not None and
                not self._paused and
                len(self._buffer) > 2 * self._limit):
            try:
                self._transport.pause_reading()
            except NotImplementedError:
                # The transport can't be paused.
                # We'll just have to buffer all data.
                # Forget the transport so we don't keep trying.
                self._transport = None
            else:
                self._paused = True

    async def _wait_for_data(self, func_name):
        """Wait until feed_data() or feed_eof() is called.

        If stream was paused, automatically resume it.
        """
        # StreamReader uses a future to link the protocol feed_data() method
        # to a read coroutine. Running two read coroutines at the same time
        # would have an unexpected behaviour. It would not be possible to know
        # which coroutine would get the next data.
        if self._waiter is not None:
            raise RuntimeError(
                f'{func_name}() called while another coroutine is '
                f'already waiting for incoming data')

        assert not self._eof, '_wait_for_data after EOF'

        # Waiting for data while paused will make deadlock, so prevent it.
        # This is essential for readexactly(n) for case when n > self._limit.
        if self._paused:
            self._paused = False
            self._transport.resume_reading()

        self._waiter = self._loop.create_future()
        try:
            await self._waiter
        finally:
            # Always clear the waiter, even if the future was cancelled.
            self._waiter = None

    async def readline(self):
        """Read chunk of data from the stream until newline (b'\n') is found.

        On success, return chunk that ends with newline. If only partial
        line can be read due to EOF, return incomplete line without
        terminating newline. When EOF was reached while no bytes read, empty
        bytes object is returned.

        If limit is reached, ValueError will be raised. In that case, if
        newline was found, complete line including newline will be removed
        from internal buffer. Else, internal buffer will be cleared. Limit is
        compared against part of the line without newline.

        If stream was paused, this function will automatically resume it if
        needed.
        """
        sep = b'\n'
        seplen = len(sep)
        try:
            line = await self.readuntil(sep)
        except IncompleteReadError as e:
            # EOF before a newline: return whatever was read.
            return e.partial
        except LimitOverrunError as e:
            # Translate to ValueError, discarding the oversized data first.
            if self._buffer.startswith(sep, e.consumed):
                del self._buffer[:e.consumed + seplen]
            else:
                self._buffer.clear()
            self._maybe_resume_transport()
            raise ValueError(e.args[0])
        return line

    async def readuntil(self, separator=b'\n'):
        """Read data from the stream until ``separator`` is found.

        On success, the data and separator will be removed from the
        internal buffer (consumed). Returned data will include the
        separator at the end.

        Configured stream limit is used to check result. Limit sets the
        maximal length of data that can be returned, not counting the
        separator.

        If an EOF occurs and the complete separator is still not found,
        an IncompleteReadError exception will be raised, and the internal
        buffer will be reset.  The IncompleteReadError.partial attribute
        may contain the separator partially.

        If the data cannot be read because of over limit, a
        LimitOverrunError exception  will be raised, and the data
        will be left in the internal buffer, so it can be read again.
        """
        seplen = len(separator)
        if seplen == 0:
            raise ValueError('Separator should be at least one-byte string')

        if self._exception is not None:
            raise self._exception

        # Consume whole buffer except last bytes, which length is
        # one less than seplen. Let's check corner cases with
        # separator='SEPARATOR':
        # * we have received almost complete separator (without last
        #   byte). i.e buffer='some textSEPARATO'. In this case we
        #   can safely consume len(separator) - 1 bytes.
        # * last byte of buffer is first byte of separator, i.e.
        #   buffer='abcdefghijklmnopqrS'. We may safely consume
        #   everything except that last byte, but this require to
        #   analyze bytes of buffer that match partial separator.
        #   This is slow and/or require FSM. For this case our
        #   implementation is not optimal, since require rescanning
        #   of data that is known to not belong to separator. In
        #   real world, separator will not be so long to notice
        #   performance problems. Even when reading MIME-encoded
        #   messages :)

        # `offset` is the number of bytes from the beginning of the buffer
        # where there is no occurrence of `separator`.
        offset = 0

        # Loop until we find `separator` in the buffer, exceed the buffer size,
        # or an EOF has happened.
        while True:
            buflen = len(self._buffer)

            # Check if we now have enough data in the buffer for `separator` to
            # fit.
            if buflen - offset >= seplen:
                isep = self._buffer.find(separator, offset)

                if isep != -1:
                    # `separator` is in the buffer. `isep` will be used later
                    # to retrieve the data.
                    break

                # see upper comment for explanation.
                offset = buflen + 1 - seplen
                if offset > self._limit:
                    raise LimitOverrunError(
                        'Separator is not found, and chunk exceed the limit',
                        offset)

            # Complete message (with full separator) may be present in buffer
            # even when EOF flag is set. This may happen when the last chunk
            # adds data which makes separator be found. That's why we check for
            # EOF *after* inspecting the buffer.
            if self._eof:
                chunk = bytes(self._buffer)
                self._buffer.clear()
                raise IncompleteReadError(chunk, None)

            # _wait_for_data() will resume reading if stream was paused.
            await self._wait_for_data('readuntil')

        if isep > self._limit:
            raise LimitOverrunError(
                'Separator is found, but chunk is longer than limit', isep)

        chunk = self._buffer[:isep + seplen]
        del self._buffer[:isep + seplen]
        self._maybe_resume_transport()
        return bytes(chunk)

    async def read(self, n=-1):
        """Read up to `n` bytes from the stream.

        If n is not provided, or set to -1, read until EOF and return all read
        bytes. If the EOF was received and the internal buffer is empty, return
        an empty bytes object.

        If n is zero, return empty bytes object immediately.

        If n is positive, this function try to read `n` bytes, and may return
        less or equal bytes than requested, but at least one byte. If EOF was
        received before any byte is read, this function returns empty byte
        object.

        Returned value is not limited with limit, configured at stream
        creation.

        If stream was paused, this function will automatically resume it if
        needed.
        """

        if self._exception is not None:
            raise self._exception

        if n == 0:
            return b''

        if n < 0:
            # This used to just loop creating a new waiter hoping to
            # collect everything in self._buffer, but that would
            # deadlock if the subprocess sends more than self.limit
            # bytes.  So just call self.read(self._limit) until EOF.
            blocks = []
            while True:
                block = await self.read(self._limit)
                if not block:
                    break
                blocks.append(block)
            return b''.join(blocks)

        if not self._buffer and not self._eof:
            await self._wait_for_data('read')

        # This will work right even if buffer is less than n bytes
        data = bytes(self._buffer[:n])
        del self._buffer[:n]

        self._maybe_resume_transport()
        return data

    async def readexactly(self, n):
        """Read exactly `n` bytes.

        Raise an IncompleteReadError if EOF is reached before `n` bytes can be
        read. The IncompleteReadError.partial attribute of the exception will
        contain the partial read bytes.

        if n is zero, return empty bytes object.

        Returned value is not limited with limit, configured at stream
        creation.

        If stream was paused, this function will automatically resume it if
        needed.
        """
        if n < 0:
            raise ValueError('readexactly size can not be less than zero')

        if self._exception is not None:
            raise self._exception

        if n == 0:
            return b''

        while len(self._buffer) < n:
            if self._eof:
                incomplete = bytes(self._buffer)
                self._buffer.clear()
                raise IncompleteReadError(incomplete, n)

            await self._wait_for_data('readexactly')

        # Take exactly n bytes; clear() avoids a copy when the buffer
        # holds exactly n.
        if len(self._buffer) == n:
            data = bytes(self._buffer)
            self._buffer.clear()
        else:
            data = bytes(self._buffer[:n])
            del self._buffer[:n]
        self._maybe_resume_transport()
        return data

    def __aiter__(self):
        # Support `async for line in reader:`.
        return self

    async def __anext__(self):
        # Iteration yields lines; an empty line means EOF was reached.
        val = await self.readline()
        if val == b'':
            raise StopAsyncIteration
        return val
    698