1 """ 2 Test suite for SocketServer.py. 3 """ 4 5 import contextlib 6 import imp 7 import os 8 import select 9 import signal 10 import socket 11 import select 12 import errno 13 import tempfile 14 import unittest 15 import SocketServer 16 17 import test.test_support 18 from test.test_support import reap_children, reap_threads, verbose 19 try: 20 import threading 21 except ImportError: 22 threading = None 23 24 test.test_support.requires("network") 25 26 TEST_STR = "hello world\n" 27 HOST = test.test_support.HOST 28 29 HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX") 30 HAVE_FORKING = hasattr(os, "fork") and os.name != "os2" 31 32 def signal_alarm(n): 33 """Call signal.alarm when it exists (i.e. not on Windows).""" 34 if hasattr(signal, 'alarm'): 35 signal.alarm(n) 36 37 # Remember real select() to avoid interferences with mocking 38 _real_select = select.select 39 40 def receive(sock, n, timeout=20): 41 r, w, x = _real_select([sock], [], [], timeout) 42 if sock in r: 43 return sock.recv(n) 44 else: 45 raise RuntimeError, "timed out on %r" % (sock,) 46 47 if HAVE_UNIX_SOCKETS: 48 class ForkingUnixStreamServer(SocketServer.ForkingMixIn, 49 SocketServer.UnixStreamServer): 50 pass 51 52 class ForkingUnixDatagramServer(SocketServer.ForkingMixIn, 53 SocketServer.UnixDatagramServer): 54 pass 55 56 57 @contextlib.contextmanager 58 def simple_subprocess(testcase): 59 pid = os.fork() 60 if pid == 0: 61 # Don't raise an exception; it would be caught by the test harness. 62 os._exit(72) 63 yield None 64 pid2, status = os.waitpid(pid, 0) 65 testcase.assertEqual(pid2, pid) 66 testcase.assertEqual(72 << 8, status) 67 68 69 @unittest.skipUnless(threading, 'Threading required for this test.') 70 class SocketServerTest(unittest.TestCase): 71 """Test all socket servers.""" 72 73 def setUp(self): 74 signal_alarm(60) # Kill deadlocks after 60 seconds. 75 self.port_seed = 0 76 self.test_files = [] 77 78 def tearDown(self): 79 signal_alarm(0) # Didn't deadlock. 80 reap_children() 81 82 for fn in self.test_files: 83 try: 84 os.remove(fn) 85 except os.error: 86 pass 87 self.test_files[:] = [] 88 89 def pickaddr(self, proto): 90 if proto == socket.AF_INET: 91 return (HOST, 0) 92 else: 93 # XXX: We need a way to tell AF_UNIX to pick its own name 94 # like AF_INET provides port==0. 95 dir = None 96 if os.name == 'os2': 97 dir = '\socket' 98 fn = tempfile.mktemp(prefix='unix_socket.', dir=dir) 99 if os.name == 'os2': 100 # AF_UNIX socket names on OS/2 require a specific prefix 101 # which can't include a drive letter and must also use 102 # backslashes as directory separators 103 if fn[1] == ':': 104 fn = fn[2:] 105 if fn[0] in (os.sep, os.altsep): 106 fn = fn[1:] 107 if os.sep == '/': 108 fn = fn.replace(os.sep, os.altsep) 109 else: 110 fn = fn.replace(os.altsep, os.sep) 111 self.test_files.append(fn) 112 return fn 113 114 def make_server(self, addr, svrcls, hdlrbase): 115 class MyServer(svrcls): 116 def handle_error(self, request, client_address): 117 self.close_request(request) 118 self.server_close() 119 raise 120 121 class MyHandler(hdlrbase): 122 def handle(self): 123 line = self.rfile.readline() 124 self.wfile.write(line) 125 126 if verbose: print "creating server" 127 server = MyServer(addr, MyHandler) 128 self.assertEqual(server.server_address, server.socket.getsockname()) 129 return server 130 131 @reap_threads 132 def run_server(self, svrcls, hdlrbase, testfunc): 133 server = self.make_server(self.pickaddr(svrcls.address_family), 134 svrcls, hdlrbase) 135 # We had the OS pick a port, so pull the real address out of 136 # the server. 137 addr = server.server_address 138 if verbose: 139 print "server created" 140 print "ADDR =", addr 141 print "CLASS =", svrcls 142 t = threading.Thread( 143 name='%s serving' % svrcls, 144 target=server.serve_forever, 145 # Short poll interval to make the test finish quickly. 146 # Time between requests is short enough that we won't wake 147 # up spuriously too many times. 148 kwargs={'poll_interval':0.01}) 149 t.daemon = True # In case this function raises. 150 t.start() 151 if verbose: print "server running" 152 for i in range(3): 153 if verbose: print "test client", i 154 testfunc(svrcls.address_family, addr) 155 if verbose: print "waiting for server" 156 server.shutdown() 157 t.join() 158 if verbose: print "done" 159 160 def stream_examine(self, proto, addr): 161 s = socket.socket(proto, socket.SOCK_STREAM) 162 s.connect(addr) 163 s.sendall(TEST_STR) 164 buf = data = receive(s, 100) 165 while data and '\n' not in buf: 166 data = receive(s, 100) 167 buf += data 168 self.assertEqual(buf, TEST_STR) 169 s.close() 170 171 def dgram_examine(self, proto, addr): 172 s = socket.socket(proto, socket.SOCK_DGRAM) 173 s.sendto(TEST_STR, addr) 174 buf = data = receive(s, 100) 175 while data and '\n' not in buf: 176 data = receive(s, 100) 177 buf += data 178 self.assertEqual(buf, TEST_STR) 179 s.close() 180 181 def test_TCPServer(self): 182 self.run_server(SocketServer.TCPServer, 183 SocketServer.StreamRequestHandler, 184 self.stream_examine) 185 186 def test_ThreadingTCPServer(self): 187 self.run_server(SocketServer.ThreadingTCPServer, 188 SocketServer.StreamRequestHandler, 189 self.stream_examine) 190 191 if HAVE_FORKING: 192 def test_ForkingTCPServer(self): 193 with simple_subprocess(self): 194 self.run_server(SocketServer.ForkingTCPServer, 195 SocketServer.StreamRequestHandler, 196 self.stream_examine) 197 198 if HAVE_UNIX_SOCKETS: 199 def test_UnixStreamServer(self): 200 self.run_server(SocketServer.UnixStreamServer, 201 SocketServer.StreamRequestHandler, 202 self.stream_examine) 203 204 def test_ThreadingUnixStreamServer(self): 205 self.run_server(SocketServer.ThreadingUnixStreamServer, 206 SocketServer.StreamRequestHandler, 207 self.stream_examine) 208 209 if HAVE_FORKING: 210 def test_ForkingUnixStreamServer(self): 211 with simple_subprocess(self): 212 self.run_server(ForkingUnixStreamServer, 213 SocketServer.StreamRequestHandler, 214 self.stream_examine) 215 216 def test_UDPServer(self): 217 self.run_server(SocketServer.UDPServer, 218 SocketServer.DatagramRequestHandler, 219 self.dgram_examine) 220 221 def test_ThreadingUDPServer(self): 222 self.run_server(SocketServer.ThreadingUDPServer, 223 SocketServer.DatagramRequestHandler, 224 self.dgram_examine) 225 226 if HAVE_FORKING: 227 def test_ForkingUDPServer(self): 228 with simple_subprocess(self): 229 self.run_server(SocketServer.ForkingUDPServer, 230 SocketServer.DatagramRequestHandler, 231 self.dgram_examine) 232 233 @contextlib.contextmanager 234 def mocked_select_module(self): 235 """Mocks the select.select() call to raise EINTR for first call""" 236 old_select = select.select 237 238 class MockSelect: 239 def __init__(self): 240 self.called = 0 241 242 def __call__(self, *args): 243 self.called += 1 244 if self.called == 1: 245 # raise the exception on first call 246 raise select.error(errno.EINTR, os.strerror(errno.EINTR)) 247 else: 248 # Return real select value for consecutive calls 249 return old_select(*args) 250 251 select.select = MockSelect() 252 try: 253 yield select.select 254 finally: 255 select.select = old_select 256 257 def test_InterruptServerSelectCall(self): 258 with self.mocked_select_module() as mock_select: 259 pid = self.run_server(SocketServer.TCPServer, 260 SocketServer.StreamRequestHandler, 261 self.stream_examine) 262 # Make sure select was called again: 263 self.assertGreater(mock_select.called, 1) 264 265 # Alas, on Linux (at least) recvfrom() doesn't return a meaningful 266 # client address so this cannot work: 267 268 # if HAVE_UNIX_SOCKETS: 269 # def test_UnixDatagramServer(self): 270 # self.run_server(SocketServer.UnixDatagramServer, 271 # SocketServer.DatagramRequestHandler, 272 # self.dgram_examine) 273 # 274 # def test_ThreadingUnixDatagramServer(self): 275 # self.run_server(SocketServer.ThreadingUnixDatagramServer, 276 # SocketServer.DatagramRequestHandler, 277 # self.dgram_examine) 278 # 279 # if HAVE_FORKING: 280 # def test_ForkingUnixDatagramServer(self): 281 # self.run_server(SocketServer.ForkingUnixDatagramServer, 282 # SocketServer.DatagramRequestHandler, 283 # self.dgram_examine) 284 285 @reap_threads 286 def test_shutdown(self): 287 # Issue #2302: shutdown() should always succeed in making an 288 # other thread leave serve_forever(). 289 class MyServer(SocketServer.TCPServer): 290 pass 291 292 class MyHandler(SocketServer.StreamRequestHandler): 293 pass 294 295 threads = [] 296 for i in range(20): 297 s = MyServer((HOST, 0), MyHandler) 298 t = threading.Thread( 299 name='MyServer serving', 300 target=s.serve_forever, 301 kwargs={'poll_interval':0.01}) 302 t.daemon = True # In case this function raises. 303 threads.append((t, s)) 304 for t, s in threads: 305 t.start() 306 s.shutdown() 307 for t, s in threads: 308 t.join() 309 310 311 def test_main(): 312 if imp.lock_held(): 313 # If the import lock is held, the threads will hang 314 raise unittest.SkipTest("can't run when import lock is held") 315 316 test.test_support.run_unittest(SocketServerTest) 317 318 if __name__ == "__main__": 319 test_main() 320