import errno
import os
import selectors
import signal
import socket
import struct
import sys
import threading

from . import connection
from . import process
from .context import reduction
from . import semaphore_tracker
from . import spawn
from . import util

__all__ = ['ensure_running', 'get_inherited_fds', 'connect_to_new_process',
           'set_forkserver_preload']

#
#
#

MAXFDS_TO_SEND = 256
UNSIGNED_STRUCT = struct.Struct('Q')     # large enough for pid_t

#
# Forkserver class
#

class ForkServer(object):
    '''Client-side handle used to start and talk to the fork server process.'''

    def __init__(self):
        # Address of the AF_UNIX socket the fork server listens on.
        self._forkserver_address = None
        # Write end of the "alive" pipe; None until a server has been started.
        self._forkserver_alive_fd = None
        # fds received from the parent; stays None unless _serve_one sets it.
        self._inherited_fds = None
        # Serialises ensure_running() so only one server gets spawned.
        self._lock = threading.Lock()
        self._preload_modules = ['__main__']

    def set_forkserver_preload(self, modules_names):
        '''Set list of module names to try to load in forkserver process.

        Raises TypeError if any element of modules_names is not a str.
        '''
        # Validate the *argument*: checking the previously stored list would
        # let any new value through unchecked.  The names are later formatted
        # with %r into the server's command line, so the exact-type check is
        # kept on purpose.
        if not all(type(mod) is str for mod in modules_names):
            raise TypeError('module_names must be a list of strings')
        self._preload_modules = modules_names

    def get_inherited_fds(self):
        '''Return list of fds inherited from parent process.

        This returns None if the current process was not started by fork
        server.
        '''
        return self._inherited_fds

    def connect_to_new_process(self, fds):
        '''Request forkserver to create a child process.

        Returns a pair of fds (status_r, data_w).  The calling process can read
        the child process's pid and (eventually) its returncode from status_r.
        The calling process should write to data_w the pickled preparation and
        process data.

        Raises ValueError if fds, together with the four fds always sent,
        reaches MAXFDS_TO_SEND.
        '''
        self.ensure_running()
        # Four extra fds are always sent ahead of the caller's ones.
        if len(fds) + 4 >= MAXFDS_TO_SEND:
            raise ValueError('too many fds')
        with socket.socket(socket.AF_UNIX) as client:
            client.connect(self._forkserver_address)
            parent_r, child_w = os.pipe()
            child_r, parent_w = os.pipe()
            allfds = [child_r, child_w, self._forkserver_alive_fd,
                      semaphore_tracker.getfd()]
            allfds += fds
            try:
                reduction.sendfds(client, allfds)
                return parent_r, parent_w
            except:
                # Deliberately bare: the parent ends must be closed on any
                # failure before the exception is re-raised.
                os.close(parent_r)
                os.close(parent_w)
                raise
            finally:
                # The child ends are never needed in this process.
                os.close(child_r)
                os.close(child_w)

    def ensure_running(self):
        '''Make sure that a fork server is running.

        This can be called from any process.  Note that usually a child
        process will just reuse the forkserver started by its parent, so
        ensure_running() will do nothing.
        '''
        with self._lock:
            semaphore_tracker.ensure_running()
            if self._forkserver_alive_fd is not None:
                return

            cmd = ('from multiprocessing.forkserver import main; ' +
                   'main(%d, %d, %r, **%r)')

            if self._preload_modules:
                # Only these entries are forwarded as keyword arguments
                # to main().
                desired_keys = {'main_path', 'sys_path'}
                data = spawn.get_preparation_data('ignore')
                data = {key: value for key, value in data.items()
                        if key in desired_keys}
            else:
                data = {}

            with socket.socket(socket.AF_UNIX) as listener:
                address = connection.arbitrary_address('AF_UNIX')
                listener.bind(address)
                os.chmod(address, 0o600)
                listener.listen()

                # all client processes own the write end of the "alive" pipe;
                # when they all terminate the read end becomes ready.
                alive_r, alive_w = os.pipe()
                try:
                    fds_to_pass = [listener.fileno(), alive_r]
                    cmd %= (listener.fileno(), alive_r, self._preload_modules,
                            data)
                    exe = spawn.get_executable()
                    args = [exe] + util._args_from_interpreter_flags()
                    args += ['-c', cmd]
                    pid = util.spawnv_passfds(exe, args, fds_to_pass)
                except:
                    # Deliberately bare: do not leak the write end if
                    # spawning fails for any reason.
                    os.close(alive_w)
                    raise
                finally:
                    # The read end is only used by the server process.
                    os.close(alive_r)
                self._forkserver_address = address
                self._forkserver_alive_fd = alive_w

#
#
#

def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
    '''Run forkserver.

    listener_fd is the fd of the listening AF_UNIX socket and alive_r the
    read end of the "alive" pipe.  preload is a list of module names to try
    to import before serving; main_path, when given and '__main__' is in
    preload, is handed to spawn.import_main_path().  sys_path is accepted
    but not used here.
    '''
    if preload:
        if '__main__' in preload and main_path is not None:
            process.current_process()._inheriting = True
            try:
                spawn.import_main_path(main_path)
            finally:
                del process.current_process()._inheriting
        for modname in preload:
            # Preloading is best effort: a module that cannot be imported
            # is skipped.
            try:
                __import__(modname)
            except ImportError:
                pass

    util._close_stdin()

    # ignoring SIGCHLD means no need to reap zombie processes
    handler = signal.signal(signal.SIGCHLD, signal.SIG_IGN)
    with socket.socket(socket.AF_UNIX, fileno=listener_fd) as listener, \
         selectors.DefaultSelector() as selector:
        _forkserver._forkserver_address = listener.getsockname()

        selector.register(listener, selectors.EVENT_READ)
        selector.register(alive_r, selectors.EVENT_READ)

        while True:
            try:
                # Wait until at least one registered object is readable.
                while True:
                    rfds = [key.fileobj for (key, events) in selector.select()]
                    if rfds:
                        break

                if alive_r in rfds:
                    # EOF because no more client processes left
                    assert os.read(alive_r, 1) == b''
                    raise SystemExit

                assert listener in rfds
                with listener.accept()[0] as s:
                    code = 1
                    if os.fork() == 0:
                        # Forked child: serve this one request, then leave
                        # via os._exit() so it never returns to the loop.
                        try:
                            _serve_one(s, listener, alive_r, handler)
                        except Exception:
                            sys.excepthook(*sys.exc_info())
                            sys.stderr.flush()
                        finally:
                            os._exit(code)

            except OSError as e:
                # Only ECONNABORTED is tolerated; the loop then continues.
                if e.errno != errno.ECONNABORTED:
                    raise

def _serve_one(s, listener, alive_r, handler):
    '''Handle one request inside a freshly forked child of the fork server.

    s is the accepted client connection, listener and alive_r are the
    server's own resources (closed here), and handler is the SIGCHLD
    handler to reinstate.
    '''
    # drop what only the server needs and put the SIGCHLD handler back
    listener.close()
    os.close(alive_r)
    signal.signal(signal.SIGCHLD, handler)

    # collect the fds sent by the requesting process
    fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1)
    s.close()
    assert len(fds) <= MAXFDS_TO_SEND
    child_r, child_w, alive_fd, tracker_fd, *inherited = fds
    _forkserver._forkserver_alive_fd = alive_fd
    _forkserver._inherited_fds = inherited
    semaphore_tracker._semaphore_tracker._fd = tracker_fd

    # report our pid to the requesting process
    write_unsigned(child_w, os.getpid())

    # reseed the random number generator if the module is already loaded
    if 'random' in sys.modules:
        import random
        random.seed()

    # run the process object received over the pipe
    exitcode = spawn._main(child_r)

    # pass the exit code back through the pipe
    write_unsigned(child_w, exitcode)

#
# Read and write unsigned numbers
#

def read_unsigned(fd):
    '''Read one UNSIGNED_STRUCT-encoded number from fd and return it.

    Raises EOFError if fd reaches end of file before enough bytes arrive.
    '''
    remaining = UNSIGNED_STRUCT.size
    buf = b''
    while remaining > 0:
        chunk = os.read(fd, remaining)
        if not chunk:
            raise EOFError('unexpected EOF')
        buf += chunk
        remaining -= len(chunk)
    (value,) = UNSIGNED_STRUCT.unpack(buf)
    return value

def write_unsigned(fd, n):
    '''Write n to fd in UNSIGNED_STRUCT encoding, looping over short writes.'''
    payload = UNSIGNED_STRUCT.pack(n)
    sent = 0
    while sent < len(payload):
        count = os.write(fd, payload[sent:])
        if count == 0:
            raise RuntimeError('should not get here')
        sent += count

#
#
#

_forkserver = ForkServer()
ensure_running = _forkserver.ensure_running
get_inherited_fds = _forkserver.get_inherited_fds
connect_to_new_process = _forkserver.connect_to_new_process
set_forkserver_preload = _forkserver.set_forkserver_preload