Home | History | Annotate | Download | only in test
      1 """
      2 Test suite for SocketServer.py.
      3 """
      4 
      5 import contextlib
      6 import imp
      7 import os
      8 import select
      9 import signal
     10 import socket
     11 import select
     12 import errno
     13 import tempfile
     14 import unittest
     15 import SocketServer
     16 
     17 import test.test_support
     18 from test.test_support import reap_children, reap_threads, verbose
     19 try:
     20     import threading
     21 except ImportError:
     22     threading = None
     23 
     24 test.test_support.requires("network")
     25 
     26 TEST_STR = "hello world\n"
     27 HOST = test.test_support.HOST
     28 
     29 HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
     30 requires_unix_sockets = unittest.skipUnless(HAVE_UNIX_SOCKETS,
     31                                             'requires Unix sockets')
     32 HAVE_FORKING = hasattr(os, "fork") and os.name != "os2"
     33 requires_forking = unittest.skipUnless(HAVE_FORKING, 'requires forking')
     34 
     35 def signal_alarm(n):
     36     """Call signal.alarm when it exists (i.e. not on Windows)."""
     37     if hasattr(signal, 'alarm'):
     38         signal.alarm(n)
     39 
     40 # Remember real select() to avoid interferences with mocking
     41 _real_select = select.select
     42 
     43 def receive(sock, n, timeout=20):
     44     r, w, x = _real_select([sock], [], [], timeout)
     45     if sock in r:
     46         return sock.recv(n)
     47     else:
     48         raise RuntimeError, "timed out on %r" % (sock,)
     49 
     50 if HAVE_UNIX_SOCKETS:
     51     class ForkingUnixStreamServer(SocketServer.ForkingMixIn,
     52                                   SocketServer.UnixStreamServer):
     53         pass
     54 
     55     class ForkingUnixDatagramServer(SocketServer.ForkingMixIn,
     56                                     SocketServer.UnixDatagramServer):
     57         pass
     58 
     59 
     60 @contextlib.contextmanager
     61 def simple_subprocess(testcase):
     62     pid = os.fork()
     63     if pid == 0:
     64         # Don't raise an exception; it would be caught by the test harness.
     65         os._exit(72)
     66     yield None
     67     pid2, status = os.waitpid(pid, 0)
     68     testcase.assertEqual(pid2, pid)
     69     testcase.assertEqual(72 << 8, status)
     70 
     71 
     72 @unittest.skipUnless(threading, 'Threading required for this test.')
     73 class SocketServerTest(unittest.TestCase):
     74     """Test all socket servers."""
     75 
     76     def setUp(self):
     77         signal_alarm(60)  # Kill deadlocks after 60 seconds.
     78         self.port_seed = 0
     79         self.test_files = []
     80 
     81     def tearDown(self):
     82         signal_alarm(0)  # Didn't deadlock.
     83         reap_children()
     84 
     85         for fn in self.test_files:
     86             try:
     87                 os.remove(fn)
     88             except os.error:
     89                 pass
     90         self.test_files[:] = []
     91 
     92     def pickaddr(self, proto):
     93         if proto == socket.AF_INET:
     94             return (HOST, 0)
     95         else:
     96             # XXX: We need a way to tell AF_UNIX to pick its own name
     97             # like AF_INET provides port==0.
     98             dir = None
     99             if os.name == 'os2':
    100                 dir = '\socket'
    101             fn = tempfile.mktemp(prefix='unix_socket.', dir=dir)
    102             if os.name == 'os2':
    103                 # AF_UNIX socket names on OS/2 require a specific prefix
    104                 # which can't include a drive letter and must also use
    105                 # backslashes as directory separators
    106                 if fn[1] == ':':
    107                     fn = fn[2:]
    108                 if fn[0] in (os.sep, os.altsep):
    109                     fn = fn[1:]
    110                 if os.sep == '/':
    111                     fn = fn.replace(os.sep, os.altsep)
    112                 else:
    113                     fn = fn.replace(os.altsep, os.sep)
    114             self.test_files.append(fn)
    115             return fn
    116 
    117     def make_server(self, addr, svrcls, hdlrbase):
    118         class MyServer(svrcls):
    119             def handle_error(self, request, client_address):
    120                 self.close_request(request)
    121                 self.server_close()
    122                 raise
    123 
    124         class MyHandler(hdlrbase):
    125             def handle(self):
    126                 line = self.rfile.readline()
    127                 self.wfile.write(line)
    128 
    129         if verbose: print "creating server"
    130         server = MyServer(addr, MyHandler)
    131         self.assertEqual(server.server_address, server.socket.getsockname())
    132         return server
    133 
    134     @reap_threads
    135     def run_server(self, svrcls, hdlrbase, testfunc):
    136         server = self.make_server(self.pickaddr(svrcls.address_family),
    137                                   svrcls, hdlrbase)
    138         # We had the OS pick a port, so pull the real address out of
    139         # the server.
    140         addr = server.server_address
    141         if verbose:
    142             print "server created"
    143             print "ADDR =", addr
    144             print "CLASS =", svrcls
    145         t = threading.Thread(
    146             name='%s serving' % svrcls,
    147             target=server.serve_forever,
    148             # Short poll interval to make the test finish quickly.
    149             # Time between requests is short enough that we won't wake
    150             # up spuriously too many times.
    151             kwargs={'poll_interval':0.01})
    152         t.daemon = True  # In case this function raises.
    153         t.start()
    154         if verbose: print "server running"
    155         for i in range(3):
    156             if verbose: print "test client", i
    157             testfunc(svrcls.address_family, addr)
    158         if verbose: print "waiting for server"
    159         server.shutdown()
    160         t.join()
    161         server.server_close()
    162         self.assertRaises(socket.error, server.socket.fileno)
    163         if verbose: print "done"
    164 
    165     def stream_examine(self, proto, addr):
    166         s = socket.socket(proto, socket.SOCK_STREAM)
    167         s.connect(addr)
    168         s.sendall(TEST_STR)
    169         buf = data = receive(s, 100)
    170         while data and '\n' not in buf:
    171             data = receive(s, 100)
    172             buf += data
    173         self.assertEqual(buf, TEST_STR)
    174         s.close()
    175 
    176     def dgram_examine(self, proto, addr):
    177         s = socket.socket(proto, socket.SOCK_DGRAM)
    178         if HAVE_UNIX_SOCKETS and proto == socket.AF_UNIX:
    179             s.bind(self.pickaddr(proto))
    180         s.sendto(TEST_STR, addr)
    181         buf = data = receive(s, 100)
    182         while data and '\n' not in buf:
    183             data = receive(s, 100)
    184             buf += data
    185         self.assertEqual(buf, TEST_STR)
    186         s.close()
    187 
    188     def test_TCPServer(self):
    189         self.run_server(SocketServer.TCPServer,
    190                         SocketServer.StreamRequestHandler,
    191                         self.stream_examine)
    192 
    193     def test_ThreadingTCPServer(self):
    194         self.run_server(SocketServer.ThreadingTCPServer,
    195                         SocketServer.StreamRequestHandler,
    196                         self.stream_examine)
    197 
    198     @requires_forking
    199     def test_ForkingTCPServer(self):
    200         with simple_subprocess(self):
    201             self.run_server(SocketServer.ForkingTCPServer,
    202                             SocketServer.StreamRequestHandler,
    203                             self.stream_examine)
    204 
    205     @requires_unix_sockets
    206     def test_UnixStreamServer(self):
    207         self.run_server(SocketServer.UnixStreamServer,
    208                         SocketServer.StreamRequestHandler,
    209                         self.stream_examine)
    210 
    211     @requires_unix_sockets
    212     def test_ThreadingUnixStreamServer(self):
    213         self.run_server(SocketServer.ThreadingUnixStreamServer,
    214                         SocketServer.StreamRequestHandler,
    215                         self.stream_examine)
    216 
    217     @requires_unix_sockets
    218     @requires_forking
    219     def test_ForkingUnixStreamServer(self):
    220         with simple_subprocess(self):
    221             self.run_server(ForkingUnixStreamServer,
    222                             SocketServer.StreamRequestHandler,
    223                             self.stream_examine)
    224 
    225     def test_UDPServer(self):
    226         self.run_server(SocketServer.UDPServer,
    227                         SocketServer.DatagramRequestHandler,
    228                         self.dgram_examine)
    229 
    230     def test_ThreadingUDPServer(self):
    231         self.run_server(SocketServer.ThreadingUDPServer,
    232                         SocketServer.DatagramRequestHandler,
    233                         self.dgram_examine)
    234 
    235     @requires_forking
    236     def test_ForkingUDPServer(self):
    237         with simple_subprocess(self):
    238             self.run_server(SocketServer.ForkingUDPServer,
    239                             SocketServer.DatagramRequestHandler,
    240                             self.dgram_examine)
    241 
    242     @contextlib.contextmanager
    243     def mocked_select_module(self):
    244         """Mocks the select.select() call to raise EINTR for first call"""
    245         old_select = select.select
    246 
    247         class MockSelect:
    248             def __init__(self):
    249                 self.called = 0
    250 
    251             def __call__(self, *args):
    252                 self.called += 1
    253                 if self.called == 1:
    254                     # raise the exception on first call
    255                     raise select.error(errno.EINTR, os.strerror(errno.EINTR))
    256                 else:
    257                     # Return real select value for consecutive calls
    258                     return old_select(*args)
    259 
    260         select.select = MockSelect()
    261         try:
    262             yield select.select
    263         finally:
    264             select.select = old_select
    265 
    266     def test_InterruptServerSelectCall(self):
    267         with self.mocked_select_module() as mock_select:
    268             pid = self.run_server(SocketServer.TCPServer,
    269                                   SocketServer.StreamRequestHandler,
    270                                   self.stream_examine)
    271             # Make sure select was called again:
    272             self.assertGreater(mock_select.called, 1)
    273 
    274     @requires_unix_sockets
    275     def test_UnixDatagramServer(self):
    276         self.run_server(SocketServer.UnixDatagramServer,
    277                         SocketServer.DatagramRequestHandler,
    278                         self.dgram_examine)
    279 
    280     @requires_unix_sockets
    281     def test_ThreadingUnixDatagramServer(self):
    282         self.run_server(SocketServer.ThreadingUnixDatagramServer,
    283                         SocketServer.DatagramRequestHandler,
    284                         self.dgram_examine)
    285 
    286     @requires_unix_sockets
    287     @requires_forking
    288     def test_ForkingUnixDatagramServer(self):
    289         self.run_server(ForkingUnixDatagramServer,
    290                         SocketServer.DatagramRequestHandler,
    291                         self.dgram_examine)
    292 
    293     @reap_threads
    294     def test_shutdown(self):
    295         # Issue #2302: shutdown() should always succeed in making an
    296         # other thread leave serve_forever().
    297         class MyServer(SocketServer.TCPServer):
    298             pass
    299 
    300         class MyHandler(SocketServer.StreamRequestHandler):
    301             pass
    302 
    303         threads = []
    304         for i in range(20):
    305             s = MyServer((HOST, 0), MyHandler)
    306             t = threading.Thread(
    307                 name='MyServer serving',
    308                 target=s.serve_forever,
    309                 kwargs={'poll_interval':0.01})
    310             t.daemon = True  # In case this function raises.
    311             threads.append((t, s))
    312         for t, s in threads:
    313             t.start()
    314             s.shutdown()
    315         for t, s in threads:
    316             t.join()
    317 
    318     def test_tcpserver_bind_leak(self):
    319         # Issue #22435: the server socket wouldn't be closed if bind()/listen()
    320         # failed.
    321         # Create many servers for which bind() will fail, to see if this result
    322         # in FD exhaustion.
    323         for i in range(1024):
    324             with self.assertRaises(OverflowError):
    325                 SocketServer.TCPServer((HOST, -1),
    326                                        SocketServer.StreamRequestHandler)
    327 
    328 
    329 class MiscTestCase(unittest.TestCase):
    330 
    331     def test_shutdown_request_called_if_verify_request_false(self):
    332         # Issue #26309: BaseServer should call shutdown_request even if
    333         # verify_request is False
    334 
    335         class MyServer(SocketServer.TCPServer):
    336             def verify_request(self, request, client_address):
    337                 return False
    338 
    339             shutdown_called = 0
    340             def shutdown_request(self, request):
    341                 self.shutdown_called += 1
    342                 SocketServer.TCPServer.shutdown_request(self, request)
    343 
    344         server = MyServer((HOST, 0), SocketServer.StreamRequestHandler)
    345         s = socket.socket(server.address_family, socket.SOCK_STREAM)
    346         s.connect(server.server_address)
    347         s.close()
    348         server.handle_request()
    349         self.assertEqual(server.shutdown_called, 1)
    350         server.server_close()
    351 
    352 
    353 def test_main():
    354     if imp.lock_held():
    355         # If the import lock is held, the threads will hang
    356         raise unittest.SkipTest("can't run when import lock is held")
    357 
    358     test.test_support.run_unittest(SocketServerTest)
    359 
    360 if __name__ == "__main__":
    361     test_main()
    362