'''Fork server support for multiprocessing.

ForkServer.ensure_running() launches a server process that executes main()
(via "from multiprocessing.forkserver import main") and listens on an
AF_UNIX socket.  Each client request carries file descriptors sent with
reduction.sendfds(); the server forks one child per request and the child
runs spawn._main() on the pipe it was given.
'''

import errno
import os
import selectors
import signal
import socket
import struct
import sys
import threading
import warnings

from . import connection
from . import process
from .context import reduction
from . import semaphore_tracker
from . import spawn
from . import util

__all__ = ['ensure_running', 'get_inherited_fds', 'connect_to_new_process',
           'set_forkserver_preload']

#
#
#

# Upper bound on fds exchanged per fork request.
MAXFDS_TO_SEND = 256
SIGNED_STRUCT = struct.Struct('q')     # large enough for pid_t

#
# Forkserver class
#

class ForkServer(object):
    '''Handle to a fork server process.

    Tracks the server's listening address, the write end of the "alive"
    pipe, the server's pid, and -- inside children created by the server --
    the fds inherited from the requesting process.
    '''

    def __init__(self):
        self._forkserver_address = None   # AF_UNIX address the server binds
        self._forkserver_alive_fd = None  # write end of the "alive" pipe
        self._forkserver_pid = None
        self._inherited_fds = None        # filled in by _serve_one()
        self._lock = threading.Lock()     # serializes ensure_running()
        self._preload_modules = ['__main__']

    def set_forkserver_preload(self, modules_names):
        '''Set list of module names to try to load in forkserver process.

        Raises TypeError if any element of *modules_names* is not a str.
        '''
        # Validate the *new* list.  Checking self._preload_modules would only
        # re-check the previously accepted value and let bad input through.
        if not all(type(mod) is str for mod in modules_names):
            raise TypeError('module_names must be a list of strings')
        self._preload_modules = modules_names

    def get_inherited_fds(self):
        '''Return list of fds inherited from parent process.

        This returns None if the current process was not started by fork
        server.
        '''
        return self._inherited_fds

    def connect_to_new_process(self, fds):
        '''Request forkserver to create a child process.

        Returns a pair of fds (status_r, data_w).  The calling process can read
        the child process's pid and (eventually) its returncode from status_r.
        The calling process should write to data_w the pickled preparation and
        process data.

        Raises ValueError if *fds* would exceed MAXFDS_TO_SEND once the four
        fds this method adds are included.
        '''
        self.ensure_running()
        if len(fds) + 4 >= MAXFDS_TO_SEND:
            raise ValueError('too many fds')
        with socket.socket(socket.AF_UNIX) as client:
            client.connect(self._forkserver_address)
            parent_r, child_w = os.pipe()
            try:
                child_r, parent_w = os.pipe()
            except BaseException:
                # Don't leak the first pipe if the second cannot be created.
                os.close(parent_r)
                os.close(child_w)
                raise
            try:
                # Order matters: main() unpacks child_r, child_w first and
                # _serve_one() unpacks alive fd, semaphore tracker fd next.
                allfds = [child_r, child_w, self._forkserver_alive_fd,
                          semaphore_tracker.getfd()]
                allfds += fds
                reduction.sendfds(client, allfds)
                return parent_r, parent_w
            except BaseException:
                os.close(parent_r)
                os.close(parent_w)
                raise
            finally:
                # The child ends now belong to the server; drop our copies.
                os.close(child_r)
                os.close(child_w)

    def ensure_running(self):
        '''Make sure that a fork server is running.

        This can be called from any process.  Note that usually a child
        process will just reuse the forkserver started by its parent, so
        ensure_running() will do nothing.
        '''
        with self._lock:
            semaphore_tracker.ensure_running()
            if self._forkserver_pid is not None:
                # forkserver was launched before, is it still running?
                pid, _status = os.waitpid(self._forkserver_pid, os.WNOHANG)
                if not pid:
                    # still alive
                    return
                # dead, launch it again
                os.close(self._forkserver_alive_fd)
                self._forkserver_address = None
                self._forkserver_alive_fd = None
                self._forkserver_pid = None

            cmd = ('from multiprocessing.forkserver import main; ' +
                   'main(%d, %d, %r, **%r)')

            if self._preload_modules:
                # Forwarded to main() as keyword arguments.
                desired_keys = {'main_path', 'sys_path'}
                data = spawn.get_preparation_data('ignore')
                data = {x: y for x, y in data.items() if x in desired_keys}
            else:
                data = {}

            with socket.socket(socket.AF_UNIX) as listener:
                address = connection.arbitrary_address('AF_UNIX')
                listener.bind(address)
                # Restrict the socket file to its owner (mode 0o600).
                os.chmod(address, 0o600)
                listener.listen()

                # all client processes own the write end of the "alive" pipe;
                # when they all terminate the read end becomes ready.
                alive_r, alive_w = os.pipe()
                try:
                    fds_to_pass = [listener.fileno(), alive_r]
                    cmd %= (listener.fileno(), alive_r, self._preload_modules,
                            data)
                    exe = spawn.get_executable()
                    args = [exe] + util._args_from_interpreter_flags()
                    args += ['-c', cmd]
                    pid = util.spawnv_passfds(exe, args, fds_to_pass)
                except BaseException:
                    os.close(alive_w)
                    raise
                finally:
                    # Only the server needs the read end.
                    os.close(alive_r)
                self._forkserver_address = address
                self._forkserver_alive_fd = alive_w
                self._forkserver_pid = pid

#
#
#

def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
    '''Run forkserver.

    listener_fd -- fd of the bound, listening AF_UNIX socket
    alive_r     -- read end of the "alive" pipe; EOF means no client holding
                   the write end is left, and the server exits
    preload     -- module names to import before serving requests
                   (ImportError is ignored)
    main_path, sys_path -- taken from the parent's preparation data

    Loops forever, forking one child per request received on the listener
    and reporting each child's pid and exit code back to the requester.
    '''
    if preload:
        if sys_path is not None:
            # The parent forwards its sys.path so that preloaded modules
            # resolve the same way they would in the parent.
            sys.path[:] = sys_path
        if '__main__' in preload and main_path is not None:
            process.current_process()._inheriting = True
            try:
                spawn.import_main_path(main_path)
            finally:
                del process.current_process()._inheriting
        for modname in preload:
            try:
                __import__(modname)
            except ImportError:
                pass

    util._close_stdin()

    sig_r, sig_w = os.pipe()
    os.set_blocking(sig_r, False)
    os.set_blocking(sig_w, False)

    def sigchld_handler(*_unused):
        # Dummy signal handler, doesn't do anything
        pass

    handlers = {
        # unblocking SIGCHLD allows the wakeup fd to notify our event loop
        signal.SIGCHLD: sigchld_handler,
        # protect the process from ^C
        signal.SIGINT: signal.SIG_IGN,
        }
    old_handlers = {sig: signal.signal(sig, val)
                    for (sig, val) in handlers.items()}

    # calling os.write() in the Python signal handler is racy
    signal.set_wakeup_fd(sig_w)

    # map child pids to client fds
    pid_to_fd = {}

    with socket.socket(socket.AF_UNIX, fileno=listener_fd) as listener, \
         selectors.DefaultSelector() as selector:
        _forkserver._forkserver_address = listener.getsockname()

        selector.register(listener, selectors.EVENT_READ)
        selector.register(alive_r, selectors.EVENT_READ)
        selector.register(sig_r, selectors.EVENT_READ)

        while True:
            try:
                # Keep selecting until at least one registered fd is ready.
                while True:
                    rfds = [key.fileobj
                            for (key, _events) in selector.select()]
                    if rfds:
                        break

                if alive_r in rfds:
                    # EOF because no more client processes left
                    assert os.read(alive_r, 1) == b'', "Not at EOF?"
                    raise SystemExit

                if sig_r in rfds:
                    # Got SIGCHLD
                    os.read(sig_r, 65536)  # exhaust
                    while True:
                        # Scan for child processes
                        try:
                            pid, sts = os.waitpid(-1, os.WNOHANG)
                        except ChildProcessError:
                            break
                        if pid == 0:
                            break
                        child_w = pid_to_fd.pop(pid, None)
                        if child_w is not None:
                            if os.WIFSIGNALED(sts):
                                returncode = -os.WTERMSIG(sts)
                            else:
                                if not os.WIFEXITED(sts):
                                    raise AssertionError(
                                        "Child {0:n} status is {1:n}".format(
                                            pid, sts))
                                returncode = os.WEXITSTATUS(sts)
                            # Send exit code to client process
                            try:
                                write_signed(child_w, returncode)
                            except BrokenPipeError:
                                # client vanished
                                pass
                            os.close(child_w)
                        else:
                            # This shouldn't happen really
                            warnings.warn('forkserver: waitpid returned '
                                          'unexpected pid %d' % pid)

                if listener in rfds:
                    # Incoming fork request
                    with listener.accept()[0] as s:
                        # Receive fds from client; ask for one extra so an
                        # oversized request can be detected.
                        fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1)
                        if len(fds) > MAXFDS_TO_SEND:
                            raise RuntimeError(
                                "Too many ({0:n}) fds to send".format(
                                    len(fds)))
                        child_r, child_w, *fds = fds
                        # Close the connection before fork() so the child
                        # does not inherit it.
                        s.close()
                        pid = os.fork()
                        if pid == 0:
                            # Child
                            code = 1
                            try:
                                listener.close()
                                selector.close()
                                unused_fds = [alive_r, child_w, sig_r, sig_w]
                                unused_fds.extend(pid_to_fd.values())
                                code = _serve_one(child_r, fds,
                                                  unused_fds,
                                                  old_handlers)
                            except Exception:
                                sys.excepthook(*sys.exc_info())
                                sys.stderr.flush()
                            finally:
                                os._exit(code)
                        else:
                            # Send pid to client process
                            try:
                                write_signed(child_w, pid)
                            except BrokenPipeError:
                                # client vanished
                                pass
                            # child_w is kept to report the exit code later.
                            pid_to_fd[pid] = child_w
                            os.close(child_r)
                            for fd in fds:
                                os.close(fd)

            except OSError as e:
                if e.errno != errno.ECONNABORTED:
                    raise


def _serve_one(child_r, fds, unused_fds, handlers):
    '''Body of a freshly forked child of the fork server.

    Restores the signal handlers saved by main(), closes fds that belong to
    the server, installs the alive fd, semaphore tracker fd and remaining
    inherited fds from *fds*, then runs spawn._main() on *child_r*.
    Returns the exit code produced by spawn._main().
    '''
    # close unnecessary stuff and reset signal handlers
    signal.set_wakeup_fd(-1)
    for sig, val in handlers.items():
        signal.signal(sig, val)
    for fd in unused_fds:
        os.close(fd)

    (_forkserver._forkserver_alive_fd,
     semaphore_tracker._semaphore_tracker._fd,
     *_forkserver._inherited_fds) = fds

    # Run process object received over pipe
    code = spawn._main(child_r)

    return code


#
# Read and write signed numbers
#

def read_signed(fd):
    '''Read one SIGNED_STRUCT-encoded integer from *fd*.

    Raises EOFError if the fd hits EOF before a full value is read.
    '''
    data = b''
    length = SIGNED_STRUCT.size
    while len(data) < length:
        s = os.read(fd, length - len(data))
        if not s:
            raise EOFError('unexpected EOF')
        data += s
    return SIGNED_STRUCT.unpack(data)[0]

def write_signed(fd, n):
    '''Write integer *n* to *fd* as SIGNED_STRUCT, retrying short writes.'''
    msg = SIGNED_STRUCT.pack(n)
    while msg:
        nbytes = os.write(fd, msg)
        if nbytes == 0:
            raise RuntimeError('should not get here')
        msg = msg[nbytes:]

#
#
#

_forkserver = ForkServer()
ensure_running = _forkserver.ensure_running
get_inherited_fds = _forkserver.get_inherited_fds
connect_to_new_process = _forkserver.connect_to_new_process
set_forkserver_preload = _forkserver.set_forkserver_preload