Home | History | Annotate | Download | only in test_asyncio
      1 """Utilities shared by tests."""
      2 
      3 import collections
      4 import contextlib
      5 import io
      6 import logging
      7 import os
      8 import re
      9 import selectors
     10 import socket
     11 import socketserver
     12 import sys
     13 import tempfile
     14 import threading
     15 import time
     16 import unittest
     17 import weakref
     18 
     19 from unittest import mock
     20 
     21 from http.server import HTTPServer
     22 from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
     23 
     24 try:
     25     import ssl
     26 except ImportError:  # pragma: no cover
     27     ssl = None
     28 
     29 from asyncio import base_events
     30 from asyncio import events
     31 from asyncio import format_helpers
     32 from asyncio import futures
     33 from asyncio import tasks
     34 from asyncio.log import logger
     35 from test import support
     36 
     37 
     38 def data_file(filename):
     39     if hasattr(support, 'TEST_HOME_DIR'):
     40         fullname = os.path.join(support.TEST_HOME_DIR, filename)
     41         if os.path.isfile(fullname):
     42             return fullname
     43     fullname = os.path.join(os.path.dirname(__file__), '..', filename)
     44     if os.path.isfile(fullname):
     45         return fullname
     46     raise FileNotFoundError(filename)
     47 
     48 
# Paths to the certificate/key fixtures.  They are resolved once, at import
# time, through data_file(), so a missing fixture raises FileNotFoundError
# while this module is being imported.
ONLYCERT = data_file('ssl_cert.pem')
ONLYKEY = data_file('ssl_key.pem')
SIGNED_CERTFILE = data_file('keycert3.pem')
SIGNING_CA = data_file('pycacert.pem')
# NOTE(review): presumably the decoded peer-certificate dictionary that tests
# expect for the SIGNED_CERTFILE certificate -- confirm against the callers.
PEERCERT = {
    'OCSP': ('http://testca.pythontest.net/testca/ocsp/',),
    'caIssuers': ('http://testca.pythontest.net/testca/pycacert.cer',),
    'crlDistributionPoints': ('http://testca.pythontest.net/testca/revocation.crl',),
    'issuer': ((('countryName', 'XY'),),
            (('organizationName', 'Python Software Foundation CA'),),
            (('commonName', 'our-ca-server'),)),
    'notAfter': 'Jul  7 14:23:16 2028 GMT',
    'notBefore': 'Aug 29 14:23:16 2018 GMT',
    'serialNumber': 'CB2D80995A69525C',
    'subject': ((('countryName', 'XY'),),
             (('localityName', 'Castle Anthrax'),),
             (('organizationName', 'Python Software Foundation'),),
             (('commonName', 'localhost'),)),
    'subjectAltName': (('DNS', 'localhost'),),
    'version': 3
}
     70 
     71 
     72 def simple_server_sslcontext():
     73     server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
     74     server_context.load_cert_chain(ONLYCERT, ONLYKEY)
     75     server_context.check_hostname = False
     76     server_context.verify_mode = ssl.CERT_NONE
     77     return server_context
     78 
     79 
     80 def simple_client_sslcontext(*, disable_verify=True):
     81     client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
     82     client_context.check_hostname = False
     83     if disable_verify:
     84         client_context.verify_mode = ssl.CERT_NONE
     85     return client_context
     86 
     87 
     88 def dummy_ssl_context():
     89     if ssl is None:
     90         return None
     91     else:
     92         return ssl.SSLContext(ssl.PROTOCOL_TLS)
     93 
     94 
     95 def run_briefly(loop):
     96     async def once():
     97         pass
     98     gen = once()
     99     t = loop.create_task(gen)
    100     # Don't log a warning if the task is not done after run_until_complete().
    101     # It occurs if the loop is stopped or if a task raises a BaseException.
    102     t._log_destroy_pending = False
    103     try:
    104         loop.run_until_complete(t)
    105     finally:
    106         gen.close()
    107 
    108 
    109 def run_until(loop, pred, timeout=30):
    110     deadline = time.monotonic() + timeout
    111     while not pred():
    112         if timeout is not None:
    113             timeout = deadline - time.monotonic()
    114             if timeout <= 0:
    115                 raise futures.TimeoutError()
    116         loop.run_until_complete(tasks.sleep(0.001, loop=loop))
    117 
    118 
    119 def run_once(loop):
    120     """Legacy API to run once through the event loop.
    121 
    122     This is the recommended pattern for test code.  It will poll the
    123     selector once and run all callbacks scheduled in response to I/O
    124     events.
    125     """
    126     loop.call_soon(loop.stop)
    127     loop.run_forever()
    128 
    129 
    130 class SilentWSGIRequestHandler(WSGIRequestHandler):
    131 
    132     def get_stderr(self):
    133         return io.StringIO()
    134 
    135     def log_message(self, format, *args):
    136         pass
    137 
    138 
    139 class SilentWSGIServer(WSGIServer):
    140 
    141     request_timeout = 2
    142 
    143     def get_request(self):
    144         request, client_addr = super().get_request()
    145         request.settimeout(self.request_timeout)
    146         return request, client_addr
    147 
    148     def handle_error(self, request, client_address):
    149         pass
    150 
    151 
    152 class SSLWSGIServerMixin:
    153 
    154     def finish_request(self, request, client_address):
    155         # The relative location of our test directory (which
    156         # contains the ssl key and certificate files) differs
    157         # between the stdlib and stand-alone asyncio.
    158         # Prefer our own if we can find it.
    159         context = ssl.SSLContext()
    160         context.load_cert_chain(ONLYCERT, ONLYKEY)
    161 
    162         ssock = context.wrap_socket(request, server_side=True)
    163         try:
    164             self.RequestHandlerClass(ssock, client_address, self)
    165             ssock.close()
    166         except OSError:
    167             # maybe socket has been closed by peer
    168             pass
    169 
    170 
    171 class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
    172     pass
    173 
    174 
    175 def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
    176 
    177     def app(environ, start_response):
    178         status = '200 OK'
    179         headers = [('Content-type', 'text/plain')]
    180         start_response(status, headers)
    181         return [b'Test message']
    182 
    183     # Run the test WSGI server in a separate thread in order not to
    184     # interfere with event handling in the main thread
    185     server_class = server_ssl_cls if use_ssl else server_cls
    186     httpd = server_class(address, SilentWSGIRequestHandler)
    187     httpd.set_app(app)
    188     httpd.address = httpd.server_address
    189     server_thread = threading.Thread(
    190         target=lambda: httpd.serve_forever(poll_interval=0.05))
    191     server_thread.start()
    192     try:
    193         yield httpd
    194     finally:
    195         httpd.shutdown()
    196         httpd.server_close()
    197         server_thread.join()
    198 
    199 
    200 if hasattr(socket, 'AF_UNIX'):
    201 
    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
        """HTTP server bound to a UNIX stream socket."""

        def server_bind(self):
            """Bind the UNIX socket and set a fixed server name and port."""
            socketserver.UnixStreamServer.server_bind(self)
            self.server_name, self.server_port = '127.0.0.1', 80
    208 
    209 
    class UnixWSGIServer(UnixHTTPServer, WSGIServer):
        """WSGI server listening on a UNIX stream socket."""

        # Value passed to settimeout() on every accepted request.
        request_timeout = 2

        def server_bind(self):
            """Bind through UnixHTTPServer, then set up the WSGI environ."""
            UnixHTTPServer.server_bind(self)
            self.setup_environ()

        def get_request(self):
            """Accept a request and return it with a fake client address."""
            conn, _peer_path = super().get_request()
            conn.settimeout(self.request_timeout)
            # Code in the stdlib expects get_request() to return a socket
            # and a (host, port) tuple.  For UNIX sockets the second value
            # is a path instead, so substitute fake data that is sufficient
            # to get the tests going.
            fake_peer = ('127.0.0.1', '')
            return conn, fake_peer
    228 
    229 
    class SilentUnixWSGIServer(UnixWSGIServer):
        """UnixWSGIServer that ignores request-handling errors."""

        def handle_error(self, request, client_address):
            """Ignore errors raised while handling a request."""
            return None
    234 
    235 
    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
        """SilentUnixWSGIServer whose requests go through SSLWSGIServerMixin."""
    238 
    239 
    def gen_unix_socket_path():
        """Return the name of a temporary file that has just been closed.

        The name is taken from a NamedTemporaryFile whose context has
        already exited by the time this function returns.
        """
        with tempfile.NamedTemporaryFile() as tmp:
            socket_path = tmp.name
        return socket_path
    243 
    244 
    @contextlib.contextmanager
    def unix_socket_path():
        """Context manager yielding a path from gen_unix_socket_path().

        On exit the path is unlinked; any OSError from the unlink is
        ignored.
        """
        sock_path = gen_unix_socket_path()
        try:
            yield sock_path
        finally:
            with contextlib.suppress(OSError):
                os.unlink(sock_path)
    254                 pass
    255 
    256 
    @contextlib.contextmanager
    def run_test_unix_server(*, use_ssl=False):
        """Context manager running the test WSGI server on a UNIX socket.

        The socket path comes from unix_socket_path(); UnixSSLWSGIServer is
        used when *use_ssl* is true and SilentUnixWSGIServer otherwise.
        """
        with unix_socket_path() as sock_path:
            server_gen = _run_test_server(
                address=sock_path,
                use_ssl=use_ssl,
                server_cls=SilentUnixWSGIServer,
                server_ssl_cls=UnixSSLWSGIServer,
            )
            yield from server_gen
    263 
    264 
    265 @contextlib.contextmanager
    266 def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    267     yield from _run_test_server(address=(host, port), use_ssl=use_ssl,
    268                                 server_cls=SilentWSGIServer,
    269                                 server_ssl_cls=SSLWSGIServer)
    270 
    271 
    272 def make_test_protocol(base):
    273     dct = {}
    274     for name in dir(base):
    275         if name.startswith('__') and name.endswith('__'):
    276             # skip magic names
    277             continue
    278         dct[name] = MockCallback(return_value=None)
    279     return type('TestProtocol', (base,) + base.__bases__, dct)()
    280 
    281 
    282 class TestSelector(selectors.BaseSelector):
    283 
    284     def __init__(self):
    285         self.keys = {}
    286 
    287     def register(self, fileobj, events, data=None):
    288         key = selectors.SelectorKey(fileobj, 0, events, data)
    289         self.keys[fileobj] = key
    290         return key
    291 
    292     def unregister(self, fileobj):
    293         return self.keys.pop(fileobj)
    294 
    295     def select(self, timeout):
    296         return []
    297 
    298     def get_map(self):
    299         return self.keys
    300 
    301 
    302 class TestLoop(base_events.BaseEventLoop):
    303     """Loop for unittests.
    304 
    305     It manages self time directly.
    306     If something scheduled to be executed later then
    307     on next loop iteration after all ready handlers done
    308     generator passed to __init__ is calling.
    309 
    310     Generator should be like this:
    311 
    312         def gen():
    313             ...
    314             when = yield ...
    315             ... = yield time_advance
    316 
    317     Value returned by yield is absolute time of next scheduled handler.
    318     Value passed to yield is time advance to move loop's time forward.
    319     """
    320 
    321     def __init__(self, gen=None):
    322         super().__init__()
    323 
    324         if gen is None:
    325             def gen():
    326                 yield
    327             self._check_on_close = False
    328         else:
    329             self._check_on_close = True
    330 
    331         self._gen = gen()
    332         next(self._gen)
    333         self._time = 0
    334         self._clock_resolution = 1e-9
    335         self._timers = []
    336         self._selector = TestSelector()
    337 
    338         self.readers = {}
    339         self.writers = {}
    340         self.reset_counters()
    341 
    342         self._transports = weakref.WeakValueDictionary()
    343 
    344     def time(self):
    345         return self._time
    346 
    347     def advance_time(self, advance):
    348         """Move test time forward."""
    349         if advance:
    350             self._time += advance
    351 
    352     def close(self):
    353         super().close()
    354         if self._check_on_close:
    355             try:
    356                 self._gen.send(0)
    357             except StopIteration:
    358                 pass
    359             else:  # pragma: no cover
    360                 raise AssertionError("Time generator is not finished")
    361 
    362     def _add_reader(self, fd, callback, *args):
    363         self.readers[fd] = events.Handle(callback, args, self, None)
    364 
    365     def _remove_reader(self, fd):
    366         self.remove_reader_count[fd] += 1
    367         if fd in self.readers:
    368             del self.readers[fd]
    369             return True
    370         else:
    371             return False
    372 
    373     def assert_reader(self, fd, callback, *args):
    374         if fd not in self.readers:
    375             raise AssertionError(f'fd {fd} is not registered')
    376         handle = self.readers[fd]
    377         if handle._callback != callback:
    378             raise AssertionError(
    379                 f'unexpected callback: {handle._callback} != {callback}')
    380         if handle._args != args:
    381             raise AssertionError(
    382                 f'unexpected callback args: {handle._args} != {args}')
    383 
    384     def assert_no_reader(self, fd):
    385         if fd in self.readers:
    386             raise AssertionError(f'fd {fd} is registered')
    387 
    388     def _add_writer(self, fd, callback, *args):
    389         self.writers[fd] = events.Handle(callback, args, self, None)
    390 
    391     def _remove_writer(self, fd):
    392         self.remove_writer_count[fd] += 1
    393         if fd in self.writers:
    394             del self.writers[fd]
    395             return True
    396         else:
    397             return False
    398 
    399     def assert_writer(self, fd, callback, *args):
    400         assert fd in self.writers, 'fd {} is not registered'.format(fd)
    401         handle = self.writers[fd]
    402         assert handle._callback == callback, '{!r} != {!r}'.format(
    403             handle._callback, callback)
    404         assert handle._args == args, '{!r} != {!r}'.format(
    405             handle._args, args)
    406 
    407     def _ensure_fd_no_transport(self, fd):
    408         if not isinstance(fd, int):
    409             try:
    410                 fd = int(fd.fileno())
    411             except (AttributeError, TypeError, ValueError):
    412                 # This code matches selectors._fileobj_to_fd function.
    413                 raise ValueError("Invalid file object: "
    414                                  "{!r}".format(fd)) from None
    415         try:
    416             transport = self._transports[fd]
    417         except KeyError:
    418             pass
    419         else:
    420             raise RuntimeError(
    421                 'File descriptor {!r} is used by transport {!r}'.format(
    422                     fd, transport))
    423 
    424     def add_reader(self, fd, callback, *args):
    425         """Add a reader callback."""
    426         self._ensure_fd_no_transport(fd)
    427         return self._add_reader(fd, callback, *args)
    428 
    429     def remove_reader(self, fd):
    430         """Remove a reader callback."""
    431         self._ensure_fd_no_transport(fd)
    432         return self._remove_reader(fd)
    433 
    434     def add_writer(self, fd, callback, *args):
    435         """Add a writer callback.."""
    436         self._ensure_fd_no_transport(fd)
    437         return self._add_writer(fd, callback, *args)
    438 
    439     def remove_writer(self, fd):
    440         """Remove a writer callback."""
    441         self._ensure_fd_no_transport(fd)
    442         return self._remove_writer(fd)
    443 
    444     def reset_counters(self):
    445         self.remove_reader_count = collections.defaultdict(int)
    446         self.remove_writer_count = collections.defaultdict(int)
    447 
    448     def _run_once(self):
    449         super()._run_once()
    450         for when in self._timers:
    451             advance = self._gen.send(when)
    452             self.advance_time(advance)
    453         self._timers = []
    454 
    455     def call_at(self, when, callback, *args, context=None):
    456         self._timers.append(when)
    457         return super().call_at(when, callback, *args, context=context)
    458 
    459     def _process_events(self, event_list):
    460         return
    461 
    462     def _write_to_self(self):
    463         pass
    464 
    465 
    466 def MockCallback(**kwargs):
    467     return mock.Mock(spec=['__call__'], **kwargs)
    468 
    469 
    470 class MockPattern(str):
    471     """A regex based str with a fuzzy __eq__.
    472 
    473     Use this helper with 'mock.assert_called_with', or anywhere
    474     where a regex comparison between strings is needed.
    475 
    476     For instance:
    477        mock_call.assert_called_with(MockPattern('spam.*ham'))
    478     """
    479     def __eq__(self, other):
    480         return bool(re.search(str(self), other, re.S))
    481 
    482 
    483 class MockInstanceOf:
    484     def __init__(self, type):
    485         self._type = type
    486 
    487     def __eq__(self, other):
    488         return isinstance(other, self._type)
    489 
    490 
    491 def get_function_source(func):
    492     source = format_helpers._get_function_source(func)
    493     if source is None:
    494         raise ValueError("unable to get the source of %r" % (func,))
    495     return source
    496 
    497 
    498 class TestCase(unittest.TestCase):
    499     @staticmethod
    500     def close_loop(loop):
    501         executor = loop._default_executor
    502         if executor is not None:
    503             executor.shutdown(wait=True)
    504         loop.close()
    505 
    506     def set_event_loop(self, loop, *, cleanup=True):
    507         assert loop is not None
    508         # ensure that the event loop is passed explicitly in asyncio
    509         events.set_event_loop(None)
    510         if cleanup:
    511             self.addCleanup(self.close_loop, loop)
    512 
    513     def new_test_loop(self, gen=None):
    514         loop = TestLoop(gen)
    515         self.set_event_loop(loop)
    516         return loop
    517 
    518     def unpatch_get_running_loop(self):
    519         events._get_running_loop = self._get_running_loop
    520 
    521     def setUp(self):
    522         self._get_running_loop = events._get_running_loop
    523         events._get_running_loop = lambda: None
    524         self._thread_cleanup = support.threading_setup()
    525 
    526     def tearDown(self):
    527         self.unpatch_get_running_loop()
    528 
    529         events.set_event_loop(None)
    530 
    531         # Detect CPython bug #23353: ensure that yield/yield-from is not used
    532         # in an except block of a generator
    533         self.assertEqual(sys.exc_info(), (None, None, None))
    534 
    535         self.doCleanups()
    536         support.threading_cleanup(*self._thread_cleanup)
    537         support.reap_children()
    538 
    539 
    540 @contextlib.contextmanager
    541 def disable_logger():
    542     """Context manager to disable asyncio logger.
    543 
    544     For example, it can be used to ignore warnings in debug mode.
    545     """
    546     old_level = logger.level
    547     try:
    548         logger.setLevel(logging.CRITICAL+1)
    549         yield
    550     finally:
    551         logger.setLevel(old_level)
    552 
    553 
    554 def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
    555                             family=socket.AF_INET):
    556     """Create a mock of a non-blocking socket."""
    557     sock = mock.MagicMock(socket.socket)
    558     sock.proto = proto
    559     sock.type = type
    560     sock.family = family
    561     sock.gettimeout.return_value = 0.0
    562     return sock
    563