Home | History | Annotate | Download | only in test
      1 import errno
      2 import os
      3 import random
      4 import selectors
      5 import signal
      6 import socket
      7 import sys
      8 from test import support
      9 from time import sleep
     10 import unittest
     11 import unittest.mock
     12 import tempfile
     13 from time import monotonic as time
     14 try:
     15     import resource
     16 except ImportError:
     17     resource = None
     18 
     19 
     20 if hasattr(socket, 'socketpair'):
     21     socketpair = socket.socketpair
     22 else:
     23     def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0):
     24         with socket.socket(family, type, proto) as l:
     25             l.bind((support.HOST, 0))
     26             l.listen()
     27             c = socket.socket(family, type, proto)
     28             try:
     29                 c.connect(l.getsockname())
     30                 caddr = c.getsockname()
     31                 while True:
     32                     a, addr = l.accept()
     33                     # check that we've got the correct client
     34                     if addr == caddr:
     35                         return c, a
     36                     a.close()
     37             except OSError:
     38                 c.close()
     39                 raise
     40 
     41 
     42 def find_ready_matching(ready, flag):
     43     match = []
     44     for key, events in ready:
     45         if events & flag:
     46             match.append(key.fileobj)
     47     return match
     48 
     49 
     50 class BaseSelectorTestCase(unittest.TestCase):
     51 
     52     def make_socketpair(self):
     53         rd, wr = socketpair()
     54         self.addCleanup(rd.close)
     55         self.addCleanup(wr.close)
     56         return rd, wr
     57 
     58     def test_register(self):
     59         s = self.SELECTOR()
     60         self.addCleanup(s.close)
     61 
     62         rd, wr = self.make_socketpair()
     63 
     64         key = s.register(rd, selectors.EVENT_READ, "data")
     65         self.assertIsInstance(key, selectors.SelectorKey)
     66         self.assertEqual(key.fileobj, rd)
     67         self.assertEqual(key.fd, rd.fileno())
     68         self.assertEqual(key.events, selectors.EVENT_READ)
     69         self.assertEqual(key.data, "data")
     70 
     71         # register an unknown event
     72         self.assertRaises(ValueError, s.register, 0, 999999)
     73 
     74         # register an invalid FD
     75         self.assertRaises(ValueError, s.register, -10, selectors.EVENT_READ)
     76 
     77         # register twice
     78         self.assertRaises(KeyError, s.register, rd, selectors.EVENT_READ)
     79 
     80         # register the same FD, but with a different object
     81         self.assertRaises(KeyError, s.register, rd.fileno(),
     82                           selectors.EVENT_READ)
     83 
     84     def test_unregister(self):
     85         s = self.SELECTOR()
     86         self.addCleanup(s.close)
     87 
     88         rd, wr = self.make_socketpair()
     89 
     90         s.register(rd, selectors.EVENT_READ)
     91         s.unregister(rd)
     92 
     93         # unregister an unknown file obj
     94         self.assertRaises(KeyError, s.unregister, 999999)
     95 
     96         # unregister twice
     97         self.assertRaises(KeyError, s.unregister, rd)
     98 
     99     def test_unregister_after_fd_close(self):
    100         s = self.SELECTOR()
    101         self.addCleanup(s.close)
    102         rd, wr = self.make_socketpair()
    103         r, w = rd.fileno(), wr.fileno()
    104         s.register(r, selectors.EVENT_READ)
    105         s.register(w, selectors.EVENT_WRITE)
    106         rd.close()
    107         wr.close()
    108         s.unregister(r)
    109         s.unregister(w)
    110 
    111     @unittest.skipUnless(os.name == 'posix', "requires posix")
    112     def test_unregister_after_fd_close_and_reuse(self):
    113         s = self.SELECTOR()
    114         self.addCleanup(s.close)
    115         rd, wr = self.make_socketpair()
    116         r, w = rd.fileno(), wr.fileno()
    117         s.register(r, selectors.EVENT_READ)
    118         s.register(w, selectors.EVENT_WRITE)
    119         rd2, wr2 = self.make_socketpair()
    120         rd.close()
    121         wr.close()
    122         os.dup2(rd2.fileno(), r)
    123         os.dup2(wr2.fileno(), w)
    124         self.addCleanup(os.close, r)
    125         self.addCleanup(os.close, w)
    126         s.unregister(r)
    127         s.unregister(w)
    128 
    129     def test_unregister_after_socket_close(self):
    130         s = self.SELECTOR()
    131         self.addCleanup(s.close)
    132         rd, wr = self.make_socketpair()
    133         s.register(rd, selectors.EVENT_READ)
    134         s.register(wr, selectors.EVENT_WRITE)
    135         rd.close()
    136         wr.close()
    137         s.unregister(rd)
    138         s.unregister(wr)
    139 
    140     def test_modify(self):
    141         s = self.SELECTOR()
    142         self.addCleanup(s.close)
    143 
    144         rd, wr = self.make_socketpair()
    145 
    146         key = s.register(rd, selectors.EVENT_READ)
    147 
    148         # modify events
    149         key2 = s.modify(rd, selectors.EVENT_WRITE)
    150         self.assertNotEqual(key.events, key2.events)
    151         self.assertEqual(key2, s.get_key(rd))
    152 
    153         s.unregister(rd)
    154 
    155         # modify data
    156         d1 = object()
    157         d2 = object()
    158 
    159         key = s.register(rd, selectors.EVENT_READ, d1)
    160         key2 = s.modify(rd, selectors.EVENT_READ, d2)
    161         self.assertEqual(key.events, key2.events)
    162         self.assertNotEqual(key.data, key2.data)
    163         self.assertEqual(key2, s.get_key(rd))
    164         self.assertEqual(key2.data, d2)
    165 
    166         # modify unknown file obj
    167         self.assertRaises(KeyError, s.modify, 999999, selectors.EVENT_READ)
    168 
    169         # modify use a shortcut
    170         d3 = object()
    171         s.register = unittest.mock.Mock()
    172         s.unregister = unittest.mock.Mock()
    173 
    174         s.modify(rd, selectors.EVENT_READ, d3)
    175         self.assertFalse(s.register.called)
    176         self.assertFalse(s.unregister.called)
    177 
    178     def test_close(self):
    179         s = self.SELECTOR()
    180         self.addCleanup(s.close)
    181 
    182         mapping = s.get_map()
    183         rd, wr = self.make_socketpair()
    184 
    185         s.register(rd, selectors.EVENT_READ)
    186         s.register(wr, selectors.EVENT_WRITE)
    187 
    188         s.close()
    189         self.assertRaises(RuntimeError, s.get_key, rd)
    190         self.assertRaises(RuntimeError, s.get_key, wr)
    191         self.assertRaises(KeyError, mapping.__getitem__, rd)
    192         self.assertRaises(KeyError, mapping.__getitem__, wr)
    193 
    194     def test_get_key(self):
    195         s = self.SELECTOR()
    196         self.addCleanup(s.close)
    197 
    198         rd, wr = self.make_socketpair()
    199 
    200         key = s.register(rd, selectors.EVENT_READ, "data")
    201         self.assertEqual(key, s.get_key(rd))
    202 
    203         # unknown file obj
    204         self.assertRaises(KeyError, s.get_key, 999999)
    205 
    206     def test_get_map(self):
    207         s = self.SELECTOR()
    208         self.addCleanup(s.close)
    209 
    210         rd, wr = self.make_socketpair()
    211 
    212         keys = s.get_map()
    213         self.assertFalse(keys)
    214         self.assertEqual(len(keys), 0)
    215         self.assertEqual(list(keys), [])
    216         key = s.register(rd, selectors.EVENT_READ, "data")
    217         self.assertIn(rd, keys)
    218         self.assertEqual(key, keys[rd])
    219         self.assertEqual(len(keys), 1)
    220         self.assertEqual(list(keys), [rd.fileno()])
    221         self.assertEqual(list(keys.values()), [key])
    222 
    223         # unknown file obj
    224         with self.assertRaises(KeyError):
    225             keys[999999]
    226 
    227         # Read-only mapping
    228         with self.assertRaises(TypeError):
    229             del keys[rd]
    230 
    231     def test_select(self):
    232         s = self.SELECTOR()
    233         self.addCleanup(s.close)
    234 
    235         rd, wr = self.make_socketpair()
    236 
    237         s.register(rd, selectors.EVENT_READ)
    238         wr_key = s.register(wr, selectors.EVENT_WRITE)
    239 
    240         result = s.select()
    241         for key, events in result:
    242             self.assertTrue(isinstance(key, selectors.SelectorKey))
    243             self.assertTrue(events)
    244             self.assertFalse(events & ~(selectors.EVENT_READ |
    245                                         selectors.EVENT_WRITE))
    246 
    247         self.assertEqual([(wr_key, selectors.EVENT_WRITE)], result)
    248 
    249     def test_context_manager(self):
    250         s = self.SELECTOR()
    251         self.addCleanup(s.close)
    252 
    253         rd, wr = self.make_socketpair()
    254 
    255         with s as sel:
    256             sel.register(rd, selectors.EVENT_READ)
    257             sel.register(wr, selectors.EVENT_WRITE)
    258 
    259         self.assertRaises(RuntimeError, s.get_key, rd)
    260         self.assertRaises(RuntimeError, s.get_key, wr)
    261 
    262     def test_fileno(self):
    263         s = self.SELECTOR()
    264         self.addCleanup(s.close)
    265 
    266         if hasattr(s, 'fileno'):
    267             fd = s.fileno()
    268             self.assertTrue(isinstance(fd, int))
    269             self.assertGreaterEqual(fd, 0)
    270 
    271     def test_selector(self):
    272         s = self.SELECTOR()
    273         self.addCleanup(s.close)
    274 
    275         NUM_SOCKETS = 12
    276         MSG = b" This is a test."
    277         MSG_LEN = len(MSG)
    278         readers = []
    279         writers = []
    280         r2w = {}
    281         w2r = {}
    282 
    283         for i in range(NUM_SOCKETS):
    284             rd, wr = self.make_socketpair()
    285             s.register(rd, selectors.EVENT_READ)
    286             s.register(wr, selectors.EVENT_WRITE)
    287             readers.append(rd)
    288             writers.append(wr)
    289             r2w[rd] = wr
    290             w2r[wr] = rd
    291 
    292         bufs = []
    293 
    294         while writers:
    295             ready = s.select()
    296             ready_writers = find_ready_matching(ready, selectors.EVENT_WRITE)
    297             if not ready_writers:
    298                 self.fail("no sockets ready for writing")
    299             wr = random.choice(ready_writers)
    300             wr.send(MSG)
    301 
    302             for i in range(10):
    303                 ready = s.select()
    304                 ready_readers = find_ready_matching(ready,
    305                                                     selectors.EVENT_READ)
    306                 if ready_readers:
    307                     break
    308                 # there might be a delay between the write to the write end and
    309                 # the read end is reported ready
    310                 sleep(0.1)
    311             else:
    312                 self.fail("no sockets ready for reading")
    313             self.assertEqual([w2r[wr]], ready_readers)
    314             rd = ready_readers[0]
    315             buf = rd.recv(MSG_LEN)
    316             self.assertEqual(len(buf), MSG_LEN)
    317             bufs.append(buf)
    318             s.unregister(r2w[rd])
    319             s.unregister(rd)
    320             writers.remove(r2w[rd])
    321 
    322         self.assertEqual(bufs, [MSG] * NUM_SOCKETS)
    323 
    324     @unittest.skipIf(sys.platform == 'win32',
    325                      'select.select() cannot be used with empty fd sets')
    326     def test_empty_select(self):
    327         # Issue #23009: Make sure EpollSelector.select() works when no FD is
    328         # registered.
    329         s = self.SELECTOR()
    330         self.addCleanup(s.close)
    331         self.assertEqual(s.select(timeout=0), [])
    332 
    333     def test_timeout(self):
    334         s = self.SELECTOR()
    335         self.addCleanup(s.close)
    336 
    337         rd, wr = self.make_socketpair()
    338 
    339         s.register(wr, selectors.EVENT_WRITE)
    340         t = time()
    341         self.assertEqual(1, len(s.select(0)))
    342         self.assertEqual(1, len(s.select(-1)))
    343         self.assertLess(time() - t, 0.5)
    344 
    345         s.unregister(wr)
    346         s.register(rd, selectors.EVENT_READ)
    347         t = time()
    348         self.assertFalse(s.select(0))
    349         self.assertFalse(s.select(-1))
    350         self.assertLess(time() - t, 0.5)
    351 
    352         t0 = time()
    353         self.assertFalse(s.select(1))
    354         t1 = time()
    355         dt = t1 - t0
    356         # Tolerate 2.0 seconds for very slow buildbots
    357         self.assertTrue(0.8 <= dt <= 2.0, dt)
    358 
    359     @unittest.skipUnless(hasattr(signal, "alarm"),
    360                          "signal.alarm() required for this test")
    361     def test_select_interrupt_exc(self):
    362         s = self.SELECTOR()
    363         self.addCleanup(s.close)
    364 
    365         rd, wr = self.make_socketpair()
    366 
    367         class InterruptSelect(Exception):
    368             pass
    369 
    370         def handler(*args):
    371             raise InterruptSelect
    372 
    373         orig_alrm_handler = signal.signal(signal.SIGALRM, handler)
    374         self.addCleanup(signal.signal, signal.SIGALRM, orig_alrm_handler)
    375         self.addCleanup(signal.alarm, 0)
    376 
    377         signal.alarm(1)
    378 
    379         s.register(rd, selectors.EVENT_READ)
    380         t = time()
    381         # select() is interrupted by a signal which raises an exception
    382         with self.assertRaises(InterruptSelect):
    383             s.select(30)
    384         # select() was interrupted before the timeout of 30 seconds
    385         self.assertLess(time() - t, 5.0)
    386 
    387     @unittest.skipUnless(hasattr(signal, "alarm"),
    388                          "signal.alarm() required for this test")
    389     def test_select_interrupt_noraise(self):
    390         s = self.SELECTOR()
    391         self.addCleanup(s.close)
    392 
    393         rd, wr = self.make_socketpair()
    394 
    395         orig_alrm_handler = signal.signal(signal.SIGALRM, lambda *args: None)
    396         self.addCleanup(signal.signal, signal.SIGALRM, orig_alrm_handler)
    397         self.addCleanup(signal.alarm, 0)
    398 
    399         signal.alarm(1)
    400 
    401         s.register(rd, selectors.EVENT_READ)
    402         t = time()
    403         # select() is interrupted by a signal, but the signal handler doesn't
    404         # raise an exception, so select() should by retries with a recomputed
    405         # timeout
    406         self.assertFalse(s.select(1.5))
    407         self.assertGreaterEqual(time() - t, 1.0)
    408 
    409 
    410 class ScalableSelectorMixIn:
    411 
    412     # see issue #18963 for why it's skipped on older OS X versions
    413     @support.requires_mac_ver(10, 5)
    414     @unittest.skipUnless(resource, "Test needs resource module")
    415     def test_above_fd_setsize(self):
    416         # A scalable implementation should have no problem with more than
    417         # FD_SETSIZE file descriptors. Since we don't know the value, we just
    418         # try to set the soft RLIMIT_NOFILE to the hard RLIMIT_NOFILE ceiling.
    419         soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    420         try:
    421             resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
    422             self.addCleanup(resource.setrlimit, resource.RLIMIT_NOFILE,
    423                             (soft, hard))
    424             NUM_FDS = min(hard, 2**16)
    425         except (OSError, ValueError):
    426             NUM_FDS = soft
    427 
    428         # guard for already allocated FDs (stdin, stdout...)
    429         NUM_FDS -= 32
    430 
    431         s = self.SELECTOR()
    432         self.addCleanup(s.close)
    433 
    434         for i in range(NUM_FDS // 2):
    435             try:
    436                 rd, wr = self.make_socketpair()
    437             except OSError:
    438                 # too many FDs, skip - note that we should only catch EMFILE
    439                 # here, but apparently *BSD and Solaris can fail upon connect()
    440                 # or bind() with EADDRNOTAVAIL, so let's be safe
    441                 self.skipTest("FD limit reached")
    442 
    443             try:
    444                 s.register(rd, selectors.EVENT_READ)
    445                 s.register(wr, selectors.EVENT_WRITE)
    446             except OSError as e:
    447                 if e.errno == errno.ENOSPC:
    448                     # this can be raised by epoll if we go over
    449                     # fs.epoll.max_user_watches sysctl
    450                     self.skipTest("FD limit reached")
    451                 raise
    452 
    453         self.assertEqual(NUM_FDS // 2, len(s.select()))
    454 
    455 
    456 class DefaultSelectorTestCase(BaseSelectorTestCase):
    457 
    458     SELECTOR = selectors.DefaultSelector
    459 
    460 
    461 class SelectSelectorTestCase(BaseSelectorTestCase):
    462 
    463     SELECTOR = selectors.SelectSelector
    464 
    465 
    466 @unittest.skipUnless(hasattr(selectors, 'PollSelector'),
    467                      "Test needs selectors.PollSelector")
    468 class PollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn):
    469 
    470     SELECTOR = getattr(selectors, 'PollSelector', None)
    471 
    472 
    473 @unittest.skipUnless(hasattr(selectors, 'EpollSelector'),
    474                      "Test needs selectors.EpollSelector")
    475 class EpollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn):
    476 
    477     SELECTOR = getattr(selectors, 'EpollSelector', None)
    478 
    479     def test_register_file(self):
    480         # epoll(7) returns EPERM when given a file to watch
    481         s = self.SELECTOR()
    482         with tempfile.NamedTemporaryFile() as f:
    483             with self.assertRaises(IOError):
    484                 s.register(f, selectors.EVENT_READ)
    485             # the SelectorKey has been removed
    486             with self.assertRaises(KeyError):
    487                 s.get_key(f)
    488 
    489 
    490 @unittest.skipUnless(hasattr(selectors, 'KqueueSelector'),
    491                      "Test needs selectors.KqueueSelector)")
    492 class KqueueSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn):
    493 
    494     SELECTOR = getattr(selectors, 'KqueueSelector', None)
    495 
    496     def test_register_bad_fd(self):
    497         # a file descriptor that's been closed should raise an OSError
    498         # with EBADF
    499         s = self.SELECTOR()
    500         bad_f = support.make_bad_fd()
    501         with self.assertRaises(OSError) as cm:
    502             s.register(bad_f, selectors.EVENT_READ)
    503         self.assertEqual(cm.exception.errno, errno.EBADF)
    504         # the SelectorKey has been removed
    505         with self.assertRaises(KeyError):
    506             s.get_key(bad_f)
    507 
    508 
    509 @unittest.skipUnless(hasattr(selectors, 'DevpollSelector'),
    510                      "Test needs selectors.DevpollSelector")
    511 class DevpollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn):
    512 
    513     SELECTOR = getattr(selectors, 'DevpollSelector', None)
    514 
    515 
    516 
    517 def test_main():
    518     tests = [DefaultSelectorTestCase, SelectSelectorTestCase,
    519              PollSelectorTestCase, EpollSelectorTestCase,
    520              KqueueSelectorTestCase, DevpollSelectorTestCase]
    521     support.run_unittest(*tests)
    522     support.reap_children()
    523 
    524 
    525 if __name__ == "__main__":
    526     test_main()
    527