Home | History | Annotate | Download | only in multiprocessing
      1 import errno
      2 import os
      3 import selectors
      4 import signal
      5 import socket
      6 import struct
      7 import sys
      8 import threading
      9 import warnings
     10 
     11 from . import connection
     12 from . import process
     13 from .context import reduction
     14 from . import semaphore_tracker
     15 from . import spawn
     16 from . import util
     17 
     18 __all__ = ['ensure_running', 'get_inherited_fds', 'connect_to_new_process',
     19            'set_forkserver_preload']
     20 
     21 #
     22 #
     23 #
     24 
     25 MAXFDS_TO_SEND = 256
     26 SIGNED_STRUCT = struct.Struct('q')     # large enough for pid_t
     27 
     28 #
     29 # Forkserver class
     30 #
     31 
     32 class ForkServer(object):
     33 
     34     def __init__(self):
     35         self._forkserver_address = None
     36         self._forkserver_alive_fd = None
     37         self._forkserver_pid = None
     38         self._inherited_fds = None
     39         self._lock = threading.Lock()
     40         self._preload_modules = ['__main__']
     41 
     42     def set_forkserver_preload(self, modules_names):
     43         '''Set list of module names to try to load in forkserver process.'''
     44         if not all(type(mod) is str for mod in self._preload_modules):
     45             raise TypeError('module_names must be a list of strings')
     46         self._preload_modules = modules_names
     47 
     48     def get_inherited_fds(self):
     49         '''Return list of fds inherited from parent process.
     50 
     51         This returns None if the current process was not started by fork
     52         server.
     53         '''
     54         return self._inherited_fds
     55 
     56     def connect_to_new_process(self, fds):
     57         '''Request forkserver to create a child process.
     58 
     59         Returns a pair of fds (status_r, data_w).  The calling process can read
     60         the child process's pid and (eventually) its returncode from status_r.
     61         The calling process should write to data_w the pickled preparation and
     62         process data.
     63         '''
     64         self.ensure_running()
     65         if len(fds) + 4 >= MAXFDS_TO_SEND:
     66             raise ValueError('too many fds')
     67         with socket.socket(socket.AF_UNIX) as client:
     68             client.connect(self._forkserver_address)
     69             parent_r, child_w = os.pipe()
     70             child_r, parent_w = os.pipe()
     71             allfds = [child_r, child_w, self._forkserver_alive_fd,
     72                       semaphore_tracker.getfd()]
     73             allfds += fds
     74             try:
     75                 reduction.sendfds(client, allfds)
     76                 return parent_r, parent_w
     77             except:
     78                 os.close(parent_r)
     79                 os.close(parent_w)
     80                 raise
     81             finally:
     82                 os.close(child_r)
     83                 os.close(child_w)
     84 
     85     def ensure_running(self):
     86         '''Make sure that a fork server is running.
     87 
     88         This can be called from any process.  Note that usually a child
     89         process will just reuse the forkserver started by its parent, so
     90         ensure_running() will do nothing.
     91         '''
     92         with self._lock:
     93             semaphore_tracker.ensure_running()
     94             if self._forkserver_pid is not None:
     95                 # forkserver was launched before, is it still running?
     96                 pid, status = os.waitpid(self._forkserver_pid, os.WNOHANG)
     97                 if not pid:
     98                     # still alive
     99                     return
    100                 # dead, launch it again
    101                 os.close(self._forkserver_alive_fd)
    102                 self._forkserver_address = None
    103                 self._forkserver_alive_fd = None
    104                 self._forkserver_pid = None
    105 
    106             cmd = ('from multiprocessing.forkserver import main; ' +
    107                    'main(%d, %d, %r, **%r)')
    108 
    109             if self._preload_modules:
    110                 desired_keys = {'main_path', 'sys_path'}
    111                 data = spawn.get_preparation_data('ignore')
    112                 data = {x: y for x, y in data.items() if x in desired_keys}
    113             else:
    114                 data = {}
    115 
    116             with socket.socket(socket.AF_UNIX) as listener:
    117                 address = connection.arbitrary_address('AF_UNIX')
    118                 listener.bind(address)
    119                 os.chmod(address, 0o600)
    120                 listener.listen()
    121 
    122                 # all client processes own the write end of the "alive" pipe;
    123                 # when they all terminate the read end becomes ready.
    124                 alive_r, alive_w = os.pipe()
    125                 try:
    126                     fds_to_pass = [listener.fileno(), alive_r]
    127                     cmd %= (listener.fileno(), alive_r, self._preload_modules,
    128                             data)
    129                     exe = spawn.get_executable()
    130                     args = [exe] + util._args_from_interpreter_flags()
    131                     args += ['-c', cmd]
    132                     pid = util.spawnv_passfds(exe, args, fds_to_pass)
    133                 except:
    134                     os.close(alive_w)
    135                     raise
    136                 finally:
    137                     os.close(alive_r)
    138                 self._forkserver_address = address
    139                 self._forkserver_alive_fd = alive_w
    140                 self._forkserver_pid = pid
    141 
    142 #
    143 #
    144 #
    145 
    146 def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
    147     '''Run forkserver.'''
    148     if preload:
    149         if '__main__' in preload and main_path is not None:
    150             process.current_process()._inheriting = True
    151             try:
    152                 spawn.import_main_path(main_path)
    153             finally:
    154                 del process.current_process()._inheriting
    155         for modname in preload:
    156             try:
    157                 __import__(modname)
    158             except ImportError:
    159                 pass
    160 
    161     util._close_stdin()
    162 
    163     sig_r, sig_w = os.pipe()
    164     os.set_blocking(sig_r, False)
    165     os.set_blocking(sig_w, False)
    166 
    167     def sigchld_handler(*_unused):
    168         # Dummy signal handler, doesn't do anything
    169         pass
    170 
    171     handlers = {
    172         # unblocking SIGCHLD allows the wakeup fd to notify our event loop
    173         signal.SIGCHLD: sigchld_handler,
    174         # protect the process from ^C
    175         signal.SIGINT: signal.SIG_IGN,
    176         }
    177     old_handlers = {sig: signal.signal(sig, val)
    178                     for (sig, val) in handlers.items()}
    179 
    180     # calling os.write() in the Python signal handler is racy
    181     signal.set_wakeup_fd(sig_w)
    182 
    183     # map child pids to client fds
    184     pid_to_fd = {}
    185 
    186     with socket.socket(socket.AF_UNIX, fileno=listener_fd) as listener, \
    187          selectors.DefaultSelector() as selector:
    188         _forkserver._forkserver_address = listener.getsockname()
    189 
    190         selector.register(listener, selectors.EVENT_READ)
    191         selector.register(alive_r, selectors.EVENT_READ)
    192         selector.register(sig_r, selectors.EVENT_READ)
    193 
    194         while True:
    195             try:
    196                 while True:
    197                     rfds = [key.fileobj for (key, events) in selector.select()]
    198                     if rfds:
    199                         break
    200 
    201                 if alive_r in rfds:
    202                     # EOF because no more client processes left
    203                     assert os.read(alive_r, 1) == b'', "Not at EOF?"
    204                     raise SystemExit
    205 
    206                 if sig_r in rfds:
    207                     # Got SIGCHLD
    208                     os.read(sig_r, 65536)  # exhaust
    209                     while True:
    210                         # Scan for child processes
    211                         try:
    212                             pid, sts = os.waitpid(-1, os.WNOHANG)
    213                         except ChildProcessError:
    214                             break
    215                         if pid == 0:
    216                             break
    217                         child_w = pid_to_fd.pop(pid, None)
    218                         if child_w is not None:
    219                             if os.WIFSIGNALED(sts):
    220                                 returncode = -os.WTERMSIG(sts)
    221                             else:
    222                                 if not os.WIFEXITED(sts):
    223                                     raise AssertionError(
    224                                         "Child {0:n} status is {1:n}".format(
    225                                             pid,sts))
    226                                 returncode = os.WEXITSTATUS(sts)
    227                             # Send exit code to client process
    228                             try:
    229                                 write_signed(child_w, returncode)
    230                             except BrokenPipeError:
    231                                 # client vanished
    232                                 pass
    233                             os.close(child_w)
    234                         else:
    235                             # This shouldn't happen really
    236                             warnings.warn('forkserver: waitpid returned '
    237                                           'unexpected pid %d' % pid)
    238 
    239                 if listener in rfds:
    240                     # Incoming fork request
    241                     with listener.accept()[0] as s:
    242                         # Receive fds from client
    243                         fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1)
    244                         if len(fds) > MAXFDS_TO_SEND:
    245                             raise RuntimeError(
    246                                 "Too many ({0:n}) fds to send".format(
    247                                     len(fds)))
    248                         child_r, child_w, *fds = fds
    249                         s.close()
    250                         pid = os.fork()
    251                         if pid == 0:
    252                             # Child
    253                             code = 1
    254                             try:
    255                                 listener.close()
    256                                 selector.close()
    257                                 unused_fds = [alive_r, child_w, sig_r, sig_w]
    258                                 unused_fds.extend(pid_to_fd.values())
    259                                 code = _serve_one(child_r, fds,
    260                                                   unused_fds,
    261                                                   old_handlers)
    262                             except Exception:
    263                                 sys.excepthook(*sys.exc_info())
    264                                 sys.stderr.flush()
    265                             finally:
    266                                 os._exit(code)
    267                         else:
    268                             # Send pid to client process
    269                             try:
    270                                 write_signed(child_w, pid)
    271                             except BrokenPipeError:
    272                                 # client vanished
    273                                 pass
    274                             pid_to_fd[pid] = child_w
    275                             os.close(child_r)
    276                             for fd in fds:
    277                                 os.close(fd)
    278 
    279             except OSError as e:
    280                 if e.errno != errno.ECONNABORTED:
    281                     raise
    282 
    283 
    284 def _serve_one(child_r, fds, unused_fds, handlers):
    285     # close unnecessary stuff and reset signal handlers
    286     signal.set_wakeup_fd(-1)
    287     for sig, val in handlers.items():
    288         signal.signal(sig, val)
    289     for fd in unused_fds:
    290         os.close(fd)
    291 
    292     (_forkserver._forkserver_alive_fd,
    293      semaphore_tracker._semaphore_tracker._fd,
    294      *_forkserver._inherited_fds) = fds
    295 
    296     # Run process object received over pipe
    297     code = spawn._main(child_r)
    298 
    299     return code
    300 
    301 
    302 #
    303 # Read and write signed numbers
    304 #
    305 
    306 def read_signed(fd):
    307     data = b''
    308     length = SIGNED_STRUCT.size
    309     while len(data) < length:
    310         s = os.read(fd, length - len(data))
    311         if not s:
    312             raise EOFError('unexpected EOF')
    313         data += s
    314     return SIGNED_STRUCT.unpack(data)[0]
    315 
    316 def write_signed(fd, n):
    317     msg = SIGNED_STRUCT.pack(n)
    318     while msg:
    319         nbytes = os.write(fd, msg)
    320         if nbytes == 0:
    321             raise RuntimeError('should not get here')
    322         msg = msg[nbytes:]
    323 
    324 #
    325 #
    326 #
    327 
    328 _forkserver = ForkServer()
    329 ensure_running = _forkserver.ensure_running
    330 get_inherited_fds = _forkserver.get_inherited_fds
    331 connect_to_new_process = _forkserver.connect_to_new_process
    332 set_forkserver_preload = _forkserver.set_forkserver_preload
    333