Home | History | Annotate | Download | only in asyncio
      1 """Stream-related things."""
      2 
# Names exported by `from <this module> import *`.  The UNIX domain socket
# helpers are appended further down, only when the platform supports them.
__all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol',
           'open_connection', 'start_server',
           'IncompleteReadError',
           'LimitOverrunError',
           ]
      8 
      9 import socket
     10 
# The UNIX domain socket helpers are only exported when the platform's
# socket module defines AF_UNIX.
if hasattr(socket, 'AF_UNIX'):
    __all__.extend(['open_unix_connection', 'start_unix_server'])
     13 
     14 from . import coroutines
     15 from . import compat
     16 from . import events
     17 from . import protocols
     18 from .coroutines import coroutine
     19 from .log import logger
     20 
     21 
# Default value of the `limit` argument (2 ** 16 == 65536 bytes) taken by
# StreamReader and by the open_*/start_* helper coroutines in this module.
_DEFAULT_LIMIT = 2 ** 16
     23 
     24 
     25 class IncompleteReadError(EOFError):
     26     """
     27     Incomplete read error. Attributes:
     28 
     29     - partial: read bytes string before the end of stream was reached
     30     - expected: total number of expected bytes (or None if unknown)
     31     """
     32     def __init__(self, partial, expected):
     33         super().__init__("%d bytes read on a total of %r expected bytes"
     34                          % (len(partial), expected))
     35         self.partial = partial
     36         self.expected = expected
     37 
     38 
     39 class LimitOverrunError(Exception):
     40     """Reached the buffer limit while looking for a separator.
     41 
     42     Attributes:
     43     - consumed: total number of to be consumed bytes.
     44     """
     45     def __init__(self, message, consumed):
     46         super().__init__(message)
     47         self.consumed = consumed
     48 
     49 
     50 @coroutine
     51 def open_connection(host=None, port=None, *,
     52                     loop=None, limit=_DEFAULT_LIMIT, **kwds):
     53     """A wrapper for create_connection() returning a (reader, writer) pair.
     54 
     55     The reader returned is a StreamReader instance; the writer is a
     56     StreamWriter instance.
     57 
     58     The arguments are all the usual arguments to create_connection()
     59     except protocol_factory; most common are positional host and port,
     60     with various optional keyword arguments following.
     61 
     62     Additional optional keyword arguments are loop (to set the event loop
     63     instance to use) and limit (to set the buffer limit passed to the
     64     StreamReader).
     65 
     66     (If you want to customize the StreamReader and/or
     67     StreamReaderProtocol classes, just copy the code -- there's
     68     really nothing special here except some convenience.)
     69     """
     70     if loop is None:
     71         loop = events.get_event_loop()
     72     reader = StreamReader(limit=limit, loop=loop)
     73     protocol = StreamReaderProtocol(reader, loop=loop)
     74     transport, _ = yield from loop.create_connection(
     75         lambda: protocol, host, port, **kwds)
     76     writer = StreamWriter(transport, protocol, reader, loop)
     77     return reader, writer
     78 
     79 
     80 @coroutine
     81 def start_server(client_connected_cb, host=None, port=None, *,
     82                  loop=None, limit=_DEFAULT_LIMIT, **kwds):
     83     """Start a socket server, call back for each client connected.
     84 
     85     The first parameter, `client_connected_cb`, takes two parameters:
     86     client_reader, client_writer.  client_reader is a StreamReader
     87     object, while client_writer is a StreamWriter object.  This
     88     parameter can either be a plain callback function or a coroutine;
     89     if it is a coroutine, it will be automatically converted into a
     90     Task.
     91 
     92     The rest of the arguments are all the usual arguments to
     93     loop.create_server() except protocol_factory; most common are
     94     positional host and port, with various optional keyword arguments
     95     following.  The return value is the same as loop.create_server().
     96 
     97     Additional optional keyword arguments are loop (to set the event loop
     98     instance to use) and limit (to set the buffer limit passed to the
     99     StreamReader).
    100 
    101     The return value is the same as loop.create_server(), i.e. a
    102     Server object which can be used to stop the service.
    103     """
    104     if loop is None:
    105         loop = events.get_event_loop()
    106 
    107     def factory():
    108         reader = StreamReader(limit=limit, loop=loop)
    109         protocol = StreamReaderProtocol(reader, client_connected_cb,
    110                                         loop=loop)
    111         return protocol
    112 
    113     return (yield from loop.create_server(factory, host, port, **kwds))
    114 
    115 
    116 if hasattr(socket, 'AF_UNIX'):
    117     # UNIX Domain Sockets are supported on this platform
    118 
    @coroutine
    def open_unix_connection(path=None, *,
                             loop=None, limit=_DEFAULT_LIMIT, **kwds):
        """Similar to `open_connection` but works with UNIX Domain Sockets.

        path and any extra keyword arguments are forwarded to
        loop.create_unix_connection(); the result is a (reader, writer)
        pair.
        """
        event_loop = events.get_event_loop() if loop is None else loop

        stream_reader = StreamReader(limit=limit, loop=event_loop)
        stream_protocol = StreamReaderProtocol(stream_reader,
                                               loop=event_loop)

        def protocol_factory():
            # Always hand back the single protocol instance created above.
            return stream_protocol

        transport, _ = yield from event_loop.create_unix_connection(
            protocol_factory, path, **kwds)

        stream_writer = StreamWriter(transport, stream_protocol,
                                     stream_reader, event_loop)
        return stream_reader, stream_writer
    131 
    @coroutine
    def start_unix_server(client_connected_cb, path=None, *,
                          loop=None, limit=_DEFAULT_LIMIT, **kwds):
        """Similar to `start_server` but works with UNIX Domain Sockets.

        path and any extra keyword arguments are forwarded to
        loop.create_unix_server(), whose result is returned.
        """
        event_loop = loop if loop is not None else events.get_event_loop()

        def make_protocol():
            # A fresh reader/protocol pair is built for every connection.
            return StreamReaderProtocol(
                StreamReader(limit=limit, loop=event_loop),
                client_connected_cb,
                loop=event_loop)

        server = yield from event_loop.create_unix_server(
            make_protocol, path, **kwds)
        return server
    146 
    147 
    148 class FlowControlMixin(protocols.Protocol):
    149     """Reusable flow control logic for StreamWriter.drain().
    150 
    151     This implements the protocol methods pause_writing(),
    152     resume_reading() and connection_lost().  If the subclass overrides
    153     these it must call the super methods.
    154 
    155     StreamWriter.drain() must wait for _drain_helper() coroutine.
    156     """
    157 
    158     def __init__(self, loop=None):
    159         if loop is None:
    160             self._loop = events.get_event_loop()
    161         else:
    162             self._loop = loop
    163         self._paused = False
    164         self._drain_waiter = None
    165         self._connection_lost = False
    166 
    167     def pause_writing(self):
    168         assert not self._paused
    169         self._paused = True
    170         if self._loop.get_debug():
    171             logger.debug("%r pauses writing", self)
    172 
    173     def resume_writing(self):
    174         assert self._paused
    175         self._paused = False
    176         if self._loop.get_debug():
    177             logger.debug("%r resumes writing", self)
    178 
    179         waiter = self._drain_waiter
    180         if waiter is not None:
    181             self._drain_waiter = None
    182             if not waiter.done():
    183                 waiter.set_result(None)
    184 
    185     def connection_lost(self, exc):
    186         self._connection_lost = True
    187         # Wake up the writer if currently paused.
    188         if not self._paused:
    189             return
    190         waiter = self._drain_waiter
    191         if waiter is None:
    192             return
    193         self._drain_waiter = None
    194         if waiter.done():
    195             return
    196         if exc is None:
    197             waiter.set_result(None)
    198         else:
    199             waiter.set_exception(exc)
    200 
    201     @coroutine
    202     def _drain_helper(self):
    203         if self._connection_lost:
    204             raise ConnectionResetError('Connection lost')
    205         if not self._paused:
    206             return
    207         waiter = self._drain_waiter
    208         assert waiter is None or waiter.cancelled()
    209         waiter = self._loop.create_future()
    210         self._drain_waiter = waiter
    211         yield from waiter
    212 
    213 
    214 class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
    215     """Helper class to adapt between Protocol and StreamReader.
    216 
    217     (This is a helper class instead of making StreamReader itself a
    218     Protocol subclass, because the StreamReader has other potential
    219     uses, and to prevent the user of the StreamReader to accidentally
    220     call inappropriate methods of the protocol.)
    221     """
    222 
    223     def __init__(self, stream_reader, client_connected_cb=None, loop=None):
    224         super().__init__(loop=loop)
    225         self._stream_reader = stream_reader
    226         self._stream_writer = None
    227         self._client_connected_cb = client_connected_cb
    228         self._over_ssl = False
    229 
    230     def connection_made(self, transport):
    231         self._stream_reader.set_transport(transport)
    232         self._over_ssl = transport.get_extra_info('sslcontext') is not None
    233         if self._client_connected_cb is not None:
    234             self._stream_writer = StreamWriter(transport, self,
    235                                                self._stream_reader,
    236                                                self._loop)
    237             res = self._client_connected_cb(self._stream_reader,
    238                                             self._stream_writer)
    239             if coroutines.iscoroutine(res):
    240                 self._loop.create_task(res)
    241 
    242     def connection_lost(self, exc):
    243         if self._stream_reader is not None:
    244             if exc is None:
    245                 self._stream_reader.feed_eof()
    246             else:
    247                 self._stream_reader.set_exception(exc)
    248         super().connection_lost(exc)
    249         self._stream_reader = None
    250         self._stream_writer = None
    251 
    252     def data_received(self, data):
    253         self._stream_reader.feed_data(data)
    254 
    255     def eof_received(self):
    256         self._stream_reader.feed_eof()
    257         if self._over_ssl:
    258             # Prevent a warning in SSLProtocol.eof_received:
    259             # "returning true from eof_received()
    260             # has no effect when using ssl"
    261             return False
    262         return True
    263 
    264 
    265 class StreamWriter:
    266     """Wraps a Transport.
    267 
    268     This exposes write(), writelines(), [can_]write_eof(),
    269     get_extra_info() and close().  It adds drain() which returns an
    270     optional Future on which you can wait for flow control.  It also
    271     adds a transport property which references the Transport
    272     directly.
    273     """
    274 
    275     def __init__(self, transport, protocol, reader, loop):
    276         self._transport = transport
    277         self._protocol = protocol
    278         # drain() expects that the reader has an exception() method
    279         assert reader is None or isinstance(reader, StreamReader)
    280         self._reader = reader
    281         self._loop = loop
    282 
    283     def __repr__(self):
    284         info = [self.__class__.__name__, 'transport=%r' % self._transport]
    285         if self._reader is not None:
    286             info.append('reader=%r' % self._reader)
    287         return '<%s>' % ' '.join(info)
    288 
    289     @property
    290     def transport(self):
    291         return self._transport
    292 
    293     def write(self, data):
    294         self._transport.write(data)
    295 
    296     def writelines(self, data):
    297         self._transport.writelines(data)
    298 
    299     def write_eof(self):
    300         return self._transport.write_eof()
    301 
    302     def can_write_eof(self):
    303         return self._transport.can_write_eof()
    304 
    305     def close(self):
    306         return self._transport.close()
    307 
    308     def get_extra_info(self, name, default=None):
    309         return self._transport.get_extra_info(name, default)
    310 
    311     @coroutine
    312     def drain(self):
    313         """Flush the write buffer.
    314 
    315         The intended use is to write
    316 
    317           w.write(data)
    318           yield from w.drain()
    319         """
    320         if self._reader is not None:
    321             exc = self._reader.exception()
    322             if exc is not None:
    323                 raise exc
    324         if self._transport is not None:
    325             if self._transport.is_closing():
    326                 # Yield to the event loop so connection_lost() may be
    327                 # called.  Without this, _drain_helper() would return
    328                 # immediately, and code that calls
    329                 #     write(...); yield from drain()
    330                 # in a loop would never call connection_lost(), so it
    331                 # would not see an error when the socket is closed.
    332                 yield
    333         yield from self._protocol._drain_helper()
    334 
    335 
    336 class StreamReader:
    337 
    338     def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
    339         # The line length limit is  a security feature;
    340         # it also doubles as half the buffer limit.
    341 
    342         if limit <= 0:
    343             raise ValueError('Limit cannot be <= 0')
    344 
    345         self._limit = limit
    346         if loop is None:
    347             self._loop = events.get_event_loop()
    348         else:
    349             self._loop = loop
    350         self._buffer = bytearray()
    351         self._eof = False    # Whether we're done.
    352         self._waiter = None  # A future used by _wait_for_data()
    353         self._exception = None
    354         self._transport = None
    355         self._paused = False
    356 
    357     def __repr__(self):
    358         info = ['StreamReader']
    359         if self._buffer:
    360             info.append('%d bytes' % len(self._buffer))
    361         if self._eof:
    362             info.append('eof')
    363         if self._limit != _DEFAULT_LIMIT:
    364             info.append('l=%d' % self._limit)
    365         if self._waiter:
    366             info.append('w=%r' % self._waiter)
    367         if self._exception:
    368             info.append('e=%r' % self._exception)
    369         if self._transport:
    370             info.append('t=%r' % self._transport)
    371         if self._paused:
    372             info.append('paused')
    373         return '<%s>' % ' '.join(info)
    374 
    375     def exception(self):
    376         return self._exception
    377 
    378     def set_exception(self, exc):
    379         self._exception = exc
    380 
    381         waiter = self._waiter
    382         if waiter is not None:
    383             self._waiter = None
    384             if not waiter.cancelled():
    385                 waiter.set_exception(exc)
    386 
    387     def _wakeup_waiter(self):
    388         """Wakeup read*() functions waiting for data or EOF."""
    389         waiter = self._waiter
    390         if waiter is not None:
    391             self._waiter = None
    392             if not waiter.cancelled():
    393                 waiter.set_result(None)
    394 
    395     def set_transport(self, transport):
    396         assert self._transport is None, 'Transport already set'
    397         self._transport = transport
    398 
    399     def _maybe_resume_transport(self):
    400         if self._paused and len(self._buffer) <= self._limit:
    401             self._paused = False
    402             self._transport.resume_reading()
    403 
    404     def feed_eof(self):
    405         self._eof = True
    406         self._wakeup_waiter()
    407 
    408     def at_eof(self):
    409         """Return True if the buffer is empty and 'feed_eof' was called."""
    410         return self._eof and not self._buffer
    411 
    412     def feed_data(self, data):
    413         assert not self._eof, 'feed_data after feed_eof'
    414 
    415         if not data:
    416             return
    417 
    418         self._buffer.extend(data)
    419         self._wakeup_waiter()
    420 
    421         if (self._transport is not None and
    422                 not self._paused and
    423                 len(self._buffer) > 2 * self._limit):
    424             try:
    425                 self._transport.pause_reading()
    426             except NotImplementedError:
    427                 # The transport can't be paused.
    428                 # We'll just have to buffer all data.
    429                 # Forget the transport so we don't keep trying.
    430                 self._transport = None
    431             else:
    432                 self._paused = True
    433 
    434     @coroutine
    435     def _wait_for_data(self, func_name):
    436         """Wait until feed_data() or feed_eof() is called.
    437 
    438         If stream was paused, automatically resume it.
    439         """
    440         # StreamReader uses a future to link the protocol feed_data() method
    441         # to a read coroutine. Running two read coroutines at the same time
    442         # would have an unexpected behaviour. It would not possible to know
    443         # which coroutine would get the next data.
    444         if self._waiter is not None:
    445             raise RuntimeError('%s() called while another coroutine is '
    446                                'already waiting for incoming data' % func_name)
    447 
    448         assert not self._eof, '_wait_for_data after EOF'
    449 
    450         # Waiting for data while paused will make deadlock, so prevent it.
    451         # This is essential for readexactly(n) for case when n > self._limit.
    452         if self._paused:
    453             self._paused = False
    454             self._transport.resume_reading()
    455 
    456         self._waiter = self._loop.create_future()
    457         try:
    458             yield from self._waiter
    459         finally:
    460             self._waiter = None
    461 
    462     @coroutine
    463     def readline(self):
    464         """Read chunk of data from the stream until newline (b'\n') is found.
    465 
    466         On success, return chunk that ends with newline. If only partial
    467         line can be read due to EOF, return incomplete line without
    468         terminating newline. When EOF was reached while no bytes read, empty
    469         bytes object is returned.
    470 
    471         If limit is reached, ValueError will be raised. In that case, if
    472         newline was found, complete line including newline will be removed
    473         from internal buffer. Else, internal buffer will be cleared. Limit is
    474         compared against part of the line without newline.
    475 
    476         If stream was paused, this function will automatically resume it if
    477         needed.
    478         """
    479         sep = b'\n'
    480         seplen = len(sep)
    481         try:
    482             line = yield from self.readuntil(sep)
    483         except IncompleteReadError as e:
    484             return e.partial
    485         except LimitOverrunError as e:
    486             if self._buffer.startswith(sep, e.consumed):
    487                 del self._buffer[:e.consumed + seplen]
    488             else:
    489                 self._buffer.clear()
    490             self._maybe_resume_transport()
    491             raise ValueError(e.args[0])
    492         return line
    493 
    494     @coroutine
    495     def readuntil(self, separator=b'\n'):
    496         """Read data from the stream until ``separator`` is found.
    497 
    498         On success, the data and separator will be removed from the
    499         internal buffer (consumed). Returned data will include the
    500         separator at the end.
    501 
    502         Configured stream limit is used to check result. Limit sets the
    503         maximal length of data that can be returned, not counting the
    504         separator.
    505 
    506         If an EOF occurs and the complete separator is still not found,
    507         an IncompleteReadError exception will be raised, and the internal
    508         buffer will be reset.  The IncompleteReadError.partial attribute
    509         may contain the separator partially.
    510 
    511         If the data cannot be read because of over limit, a
    512         LimitOverrunError exception  will be raised, and the data
    513         will be left in the internal buffer, so it can be read again.
    514         """
    515         seplen = len(separator)
    516         if seplen == 0:
    517             raise ValueError('Separator should be at least one-byte string')
    518 
    519         if self._exception is not None:
    520             raise self._exception
    521 
    522         # Consume whole buffer except last bytes, which length is
    523         # one less than seplen. Let's check corner cases with
    524         # separator='SEPARATOR':
    525         # * we have received almost complete separator (without last
    526         #   byte). i.e buffer='some textSEPARATO'. In this case we
    527         #   can safely consume len(separator) - 1 bytes.
    528         # * last byte of buffer is first byte of separator, i.e.
    529         #   buffer='abcdefghijklmnopqrS'. We may safely consume
    530         #   everything except that last byte, but this require to
    531         #   analyze bytes of buffer that match partial separator.
    532         #   This is slow and/or require FSM. For this case our
    533         #   implementation is not optimal, since require rescanning
    534         #   of data that is known to not belong to separator. In
    535         #   real world, separator will not be so long to notice
    536         #   performance problems. Even when reading MIME-encoded
    537         #   messages :)
    538 
    539         # `offset` is the number of bytes from the beginning of the buffer
    540         # where there is no occurrence of `separator`.
    541         offset = 0
    542 
    543         # Loop until we find `separator` in the buffer, exceed the buffer size,
    544         # or an EOF has happened.
    545         while True:
    546             buflen = len(self._buffer)
    547 
    548             # Check if we now have enough data in the buffer for `separator` to
    549             # fit.
    550             if buflen - offset >= seplen:
    551                 isep = self._buffer.find(separator, offset)
    552 
    553                 if isep != -1:
    554                     # `separator` is in the buffer. `isep` will be used later
    555                     # to retrieve the data.
    556                     break
    557 
    558                 # see upper comment for explanation.
    559                 offset = buflen + 1 - seplen
    560                 if offset > self._limit:
    561                     raise LimitOverrunError(
    562                         'Separator is not found, and chunk exceed the limit',
    563                         offset)
    564 
    565             # Complete message (with full separator) may be present in buffer
    566             # even when EOF flag is set. This may happen when the last chunk
    567             # adds data which makes separator be found. That's why we check for
    568             # EOF *ater* inspecting the buffer.
    569             if self._eof:
    570                 chunk = bytes(self._buffer)
    571                 self._buffer.clear()
    572                 raise IncompleteReadError(chunk, None)
    573 
    574             # _wait_for_data() will resume reading if stream was paused.
    575             yield from self._wait_for_data('readuntil')
    576 
    577         if isep > self._limit:
    578             raise LimitOverrunError(
    579                 'Separator is found, but chunk is longer than limit', isep)
    580 
    581         chunk = self._buffer[:isep + seplen]
    582         del self._buffer[:isep + seplen]
    583         self._maybe_resume_transport()
    584         return bytes(chunk)
    585 
    586     @coroutine
    587     def read(self, n=-1):
    588         """Read up to `n` bytes from the stream.
    589 
    590         If n is not provided, or set to -1, read until EOF and return all read
    591         bytes. If the EOF was received and the internal buffer is empty, return
    592         an empty bytes object.
    593 
    594         If n is zero, return empty bytes object immediately.
    595 
    596         If n is positive, this function try to read `n` bytes, and may return
    597         less or equal bytes than requested, but at least one byte. If EOF was
    598         received before any byte is read, this function returns empty byte
    599         object.
    600 
    601         Returned value is not limited with limit, configured at stream
    602         creation.
    603 
    604         If stream was paused, this function will automatically resume it if
    605         needed.
    606         """
    607 
    608         if self._exception is not None:
    609             raise self._exception
    610 
    611         if n == 0:
    612             return b''
    613 
    614         if n < 0:
    615             # This used to just loop creating a new waiter hoping to
    616             # collect everything in self._buffer, but that would
    617             # deadlock if the subprocess sends more than self.limit
    618             # bytes.  So just call self.read(self._limit) until EOF.
    619             blocks = []
    620             while True:
    621                 block = yield from self.read(self._limit)
    622                 if not block:
    623                     break
    624                 blocks.append(block)
    625             return b''.join(blocks)
    626 
    627         if not self._buffer and not self._eof:
    628             yield from self._wait_for_data('read')
    629 
    630         # This will work right even if buffer is less than n bytes
    631         data = bytes(self._buffer[:n])
    632         del self._buffer[:n]
    633 
    634         self._maybe_resume_transport()
    635         return data
    636 
    637     @coroutine
    638     def readexactly(self, n):
    639         """Read exactly `n` bytes.
    640 
    641         Raise an IncompleteReadError if EOF is reached before `n` bytes can be
    642         read. The IncompleteReadError.partial attribute of the exception will
    643         contain the partial read bytes.
    644 
    645         if n is zero, return empty bytes object.
    646 
    647         Returned value is not limited with limit, configured at stream
    648         creation.
    649 
    650         If stream was paused, this function will automatically resume it if
    651         needed.
    652         """
    653         if n < 0:
    654             raise ValueError('readexactly size can not be less than zero')
    655 
    656         if self._exception is not None:
    657             raise self._exception
    658 
    659         if n == 0:
    660             return b''
    661 
    662         while len(self._buffer) < n:
    663             if self._eof:
    664                 incomplete = bytes(self._buffer)
    665                 self._buffer.clear()
    666                 raise IncompleteReadError(incomplete, n)
    667 
    668             yield from self._wait_for_data('readexactly')
    669 
    670         if len(self._buffer) == n:
    671             data = bytes(self._buffer)
    672             self._buffer.clear()
    673         else:
    674             data = bytes(self._buffer[:n])
    675             del self._buffer[:n]
    676         self._maybe_resume_transport()
    677         return data
    678 
    679     if compat.PY35:
    680         @coroutine
    681         def __aiter__(self):
    682             return self
    683 
    684         @coroutine
    685         def __anext__(self):
    686             val = yield from self.readline()
    687             if val == b'':
    688                 raise StopAsyncIteration
    689             return val
    690 
    691     if compat.PY352:
    692         # In Python 3.5.2 and greater, __aiter__ should return
    693         # the asynchronous iterator directly.
    694         def __aiter__(self):
    695             return self
    696