Home | History | Annotate | Download | only in test
      1 """
      2 Test suite for SocketServer.py.
      3 """
      4 
      5 import contextlib
      6 import imp
      7 import os
      8 import select
      9 import signal
     10 import socket
     11 import select
     12 import errno
     13 import tempfile
     14 import unittest
     15 import SocketServer
     16 
     17 import test.test_support
     18 from test.test_support import reap_children, reap_threads, verbose
     19 try:
     20     import threading
     21 except ImportError:
     22     threading = None
     23 
     24 test.test_support.requires("network")
     25 
     26 TEST_STR = "hello world\n"
     27 HOST = test.test_support.HOST
     28 
     29 HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
     30 HAVE_FORKING = hasattr(os, "fork") and os.name != "os2"
     31 
     32 def signal_alarm(n):
     33     """Call signal.alarm when it exists (i.e. not on Windows)."""
     34     if hasattr(signal, 'alarm'):
     35         signal.alarm(n)
     36 
     37 # Remember real select() to avoid interferences with mocking
     38 _real_select = select.select
     39 
     40 def receive(sock, n, timeout=20):
     41     r, w, x = _real_select([sock], [], [], timeout)
     42     if sock in r:
     43         return sock.recv(n)
     44     else:
     45         raise RuntimeError, "timed out on %r" % (sock,)
     46 
     47 if HAVE_UNIX_SOCKETS:
     48     class ForkingUnixStreamServer(SocketServer.ForkingMixIn,
     49                                   SocketServer.UnixStreamServer):
     50         pass
     51 
     52     class ForkingUnixDatagramServer(SocketServer.ForkingMixIn,
     53                                     SocketServer.UnixDatagramServer):
     54         pass
     55 
     56 
     57 @contextlib.contextmanager
     58 def simple_subprocess(testcase):
     59     pid = os.fork()
     60     if pid == 0:
     61         # Don't raise an exception; it would be caught by the test harness.
     62         os._exit(72)
     63     yield None
     64     pid2, status = os.waitpid(pid, 0)
     65     testcase.assertEqual(pid2, pid)
     66     testcase.assertEqual(72 << 8, status)
     67 
     68 
     69 @unittest.skipUnless(threading, 'Threading required for this test.')
     70 class SocketServerTest(unittest.TestCase):
     71     """Test all socket servers."""
     72 
     73     def setUp(self):
     74         signal_alarm(60)  # Kill deadlocks after 60 seconds.
     75         self.port_seed = 0
     76         self.test_files = []
     77 
     78     def tearDown(self):
     79         signal_alarm(0)  # Didn't deadlock.
     80         reap_children()
     81 
     82         for fn in self.test_files:
     83             try:
     84                 os.remove(fn)
     85             except os.error:
     86                 pass
     87         self.test_files[:] = []
     88 
     89     def pickaddr(self, proto):
     90         if proto == socket.AF_INET:
     91             return (HOST, 0)
     92         else:
     93             # XXX: We need a way to tell AF_UNIX to pick its own name
     94             # like AF_INET provides port==0.
     95             dir = None
     96             if os.name == 'os2':
     97                 dir = '\socket'
     98             fn = tempfile.mktemp(prefix='unix_socket.', dir=dir)
     99             if os.name == 'os2':
    100                 # AF_UNIX socket names on OS/2 require a specific prefix
    101                 # which can't include a drive letter and must also use
    102                 # backslashes as directory separators
    103                 if fn[1] == ':':
    104                     fn = fn[2:]
    105                 if fn[0] in (os.sep, os.altsep):
    106                     fn = fn[1:]
    107                 if os.sep == '/':
    108                     fn = fn.replace(os.sep, os.altsep)
    109                 else:
    110                     fn = fn.replace(os.altsep, os.sep)
    111             self.test_files.append(fn)
    112             return fn
    113 
    114     def make_server(self, addr, svrcls, hdlrbase):
    115         class MyServer(svrcls):
    116             def handle_error(self, request, client_address):
    117                 self.close_request(request)
    118                 self.server_close()
    119                 raise
    120 
    121         class MyHandler(hdlrbase):
    122             def handle(self):
    123                 line = self.rfile.readline()
    124                 self.wfile.write(line)
    125 
    126         if verbose: print "creating server"
    127         server = MyServer(addr, MyHandler)
    128         self.assertEqual(server.server_address, server.socket.getsockname())
    129         return server
    130 
    131     @reap_threads
    132     def run_server(self, svrcls, hdlrbase, testfunc):
    133         server = self.make_server(self.pickaddr(svrcls.address_family),
    134                                   svrcls, hdlrbase)
    135         # We had the OS pick a port, so pull the real address out of
    136         # the server.
    137         addr = server.server_address
    138         if verbose:
    139             print "server created"
    140             print "ADDR =", addr
    141             print "CLASS =", svrcls
    142         t = threading.Thread(
    143             name='%s serving' % svrcls,
    144             target=server.serve_forever,
    145             # Short poll interval to make the test finish quickly.
    146             # Time between requests is short enough that we won't wake
    147             # up spuriously too many times.
    148             kwargs={'poll_interval':0.01})
    149         t.daemon = True  # In case this function raises.
    150         t.start()
    151         if verbose: print "server running"
    152         for i in range(3):
    153             if verbose: print "test client", i
    154             testfunc(svrcls.address_family, addr)
    155         if verbose: print "waiting for server"
    156         server.shutdown()
    157         t.join()
    158         if verbose: print "done"
    159 
    160     def stream_examine(self, proto, addr):
    161         s = socket.socket(proto, socket.SOCK_STREAM)
    162         s.connect(addr)
    163         s.sendall(TEST_STR)
    164         buf = data = receive(s, 100)
    165         while data and '\n' not in buf:
    166             data = receive(s, 100)
    167             buf += data
    168         self.assertEqual(buf, TEST_STR)
    169         s.close()
    170 
    171     def dgram_examine(self, proto, addr):
    172         s = socket.socket(proto, socket.SOCK_DGRAM)
    173         s.sendto(TEST_STR, addr)
    174         buf = data = receive(s, 100)
    175         while data and '\n' not in buf:
    176             data = receive(s, 100)
    177             buf += data
    178         self.assertEqual(buf, TEST_STR)
    179         s.close()
    180 
    181     def test_TCPServer(self):
    182         self.run_server(SocketServer.TCPServer,
    183                         SocketServer.StreamRequestHandler,
    184                         self.stream_examine)
    185 
    186     def test_ThreadingTCPServer(self):
    187         self.run_server(SocketServer.ThreadingTCPServer,
    188                         SocketServer.StreamRequestHandler,
    189                         self.stream_examine)
    190 
    191     if HAVE_FORKING:
    192         def test_ForkingTCPServer(self):
    193             with simple_subprocess(self):
    194                 self.run_server(SocketServer.ForkingTCPServer,
    195                                 SocketServer.StreamRequestHandler,
    196                                 self.stream_examine)
    197 
    198     if HAVE_UNIX_SOCKETS:
    199         def test_UnixStreamServer(self):
    200             self.run_server(SocketServer.UnixStreamServer,
    201                             SocketServer.StreamRequestHandler,
    202                             self.stream_examine)
    203 
    204         def test_ThreadingUnixStreamServer(self):
    205             self.run_server(SocketServer.ThreadingUnixStreamServer,
    206                             SocketServer.StreamRequestHandler,
    207                             self.stream_examine)
    208 
    209         if HAVE_FORKING:
    210             def test_ForkingUnixStreamServer(self):
    211                 with simple_subprocess(self):
    212                     self.run_server(ForkingUnixStreamServer,
    213                                     SocketServer.StreamRequestHandler,
    214                                     self.stream_examine)
    215 
    216     def test_UDPServer(self):
    217         self.run_server(SocketServer.UDPServer,
    218                         SocketServer.DatagramRequestHandler,
    219                         self.dgram_examine)
    220 
    221     def test_ThreadingUDPServer(self):
    222         self.run_server(SocketServer.ThreadingUDPServer,
    223                         SocketServer.DatagramRequestHandler,
    224                         self.dgram_examine)
    225 
    226     if HAVE_FORKING:
    227         def test_ForkingUDPServer(self):
    228             with simple_subprocess(self):
    229                 self.run_server(SocketServer.ForkingUDPServer,
    230                                 SocketServer.DatagramRequestHandler,
    231                                 self.dgram_examine)
    232 
    233     @contextlib.contextmanager
    234     def mocked_select_module(self):
    235         """Mocks the select.select() call to raise EINTR for first call"""
    236         old_select = select.select
    237 
    238         class MockSelect:
    239             def __init__(self):
    240                 self.called = 0
    241 
    242             def __call__(self, *args):
    243                 self.called += 1
    244                 if self.called == 1:
    245                     # raise the exception on first call
    246                     raise select.error(errno.EINTR, os.strerror(errno.EINTR))
    247                 else:
    248                     # Return real select value for consecutive calls
    249                     return old_select(*args)
    250 
    251         select.select = MockSelect()
    252         try:
    253             yield select.select
    254         finally:
    255             select.select = old_select
    256 
    257     def test_InterruptServerSelectCall(self):
    258         with self.mocked_select_module() as mock_select:
    259             pid = self.run_server(SocketServer.TCPServer,
    260                                   SocketServer.StreamRequestHandler,
    261                                   self.stream_examine)
    262             # Make sure select was called again:
    263             self.assertGreater(mock_select.called, 1)
    264 
    265     # Alas, on Linux (at least) recvfrom() doesn't return a meaningful
    266     # client address so this cannot work:
    267 
    268     # if HAVE_UNIX_SOCKETS:
    269     #     def test_UnixDatagramServer(self):
    270     #         self.run_server(SocketServer.UnixDatagramServer,
    271     #                         SocketServer.DatagramRequestHandler,
    272     #                         self.dgram_examine)
    273     #
    274     #     def test_ThreadingUnixDatagramServer(self):
    275     #         self.run_server(SocketServer.ThreadingUnixDatagramServer,
    276     #                         SocketServer.DatagramRequestHandler,
    277     #                         self.dgram_examine)
    278     #
    279     #     if HAVE_FORKING:
    280     #         def test_ForkingUnixDatagramServer(self):
    281     #             self.run_server(SocketServer.ForkingUnixDatagramServer,
    282     #                             SocketServer.DatagramRequestHandler,
    283     #                             self.dgram_examine)
    284 
    285     @reap_threads
    286     def test_shutdown(self):
    287         # Issue #2302: shutdown() should always succeed in making an
    288         # other thread leave serve_forever().
    289         class MyServer(SocketServer.TCPServer):
    290             pass
    291 
    292         class MyHandler(SocketServer.StreamRequestHandler):
    293             pass
    294 
    295         threads = []
    296         for i in range(20):
    297             s = MyServer((HOST, 0), MyHandler)
    298             t = threading.Thread(
    299                 name='MyServer serving',
    300                 target=s.serve_forever,
    301                 kwargs={'poll_interval':0.01})
    302             t.daemon = True  # In case this function raises.
    303             threads.append((t, s))
    304         for t, s in threads:
    305             t.start()
    306             s.shutdown()
    307         for t, s in threads:
    308             t.join()
    309 
    310 
    311 def test_main():
    312     if imp.lock_held():
    313         # If the import lock is held, the threads will hang
    314         raise unittest.SkipTest("can't run when import lock is held")
    315 
    316     test.test_support.run_unittest(SocketServerTest)
    317 
    318 if __name__ == "__main__":
    319     test_main()
    320