Home | History | Annotate | Download | only in multiprocessing
      1 import errno
      2 import os
      3 import selectors
      4 import signal
      5 import socket
      6 import struct
      7 import sys
      8 import threading
      9 
     10 from . import connection
     11 from . import process
     12 from .context import reduction
     13 from . import semaphore_tracker
     14 from . import spawn
     15 from . import util
     16 
     17 __all__ = ['ensure_running', 'get_inherited_fds', 'connect_to_new_process',
     18            'set_forkserver_preload']
     19 
     20 #
     21 #
     22 #
     23 
     24 MAXFDS_TO_SEND = 256
     25 UNSIGNED_STRUCT = struct.Struct('Q')     # large enough for pid_t
     26 
     27 #
     28 # Forkserver class
     29 #
     30 
     31 class ForkServer(object):
     32 
     33     def __init__(self):
     34         self._forkserver_address = None
     35         self._forkserver_alive_fd = None
     36         self._inherited_fds = None
     37         self._lock = threading.Lock()
     38         self._preload_modules = ['__main__']
     39 
     40     def set_forkserver_preload(self, modules_names):
     41         '''Set list of module names to try to load in forkserver process.'''
     42         if not all(type(mod) is str for mod in self._preload_modules):
     43             raise TypeError('module_names must be a list of strings')
     44         self._preload_modules = modules_names
     45 
     46     def get_inherited_fds(self):
     47         '''Return list of fds inherited from parent process.
     48 
     49         This returns None if the current process was not started by fork
     50         server.
     51         '''
     52         return self._inherited_fds
     53 
     54     def connect_to_new_process(self, fds):
     55         '''Request forkserver to create a child process.
     56 
     57         Returns a pair of fds (status_r, data_w).  The calling process can read
     58         the child process's pid and (eventually) its returncode from status_r.
     59         The calling process should write to data_w the pickled preparation and
     60         process data.
     61         '''
     62         self.ensure_running()
     63         if len(fds) + 4 >= MAXFDS_TO_SEND:
     64             raise ValueError('too many fds')
     65         with socket.socket(socket.AF_UNIX) as client:
     66             client.connect(self._forkserver_address)
     67             parent_r, child_w = os.pipe()
     68             child_r, parent_w = os.pipe()
     69             allfds = [child_r, child_w, self._forkserver_alive_fd,
     70                       semaphore_tracker.getfd()]
     71             allfds += fds
     72             try:
     73                 reduction.sendfds(client, allfds)
     74                 return parent_r, parent_w
     75             except:
     76                 os.close(parent_r)
     77                 os.close(parent_w)
     78                 raise
     79             finally:
     80                 os.close(child_r)
     81                 os.close(child_w)
     82 
     83     def ensure_running(self):
     84         '''Make sure that a fork server is running.
     85 
     86         This can be called from any process.  Note that usually a child
     87         process will just reuse the forkserver started by its parent, so
     88         ensure_running() will do nothing.
     89         '''
     90         with self._lock:
     91             semaphore_tracker.ensure_running()
     92             if self._forkserver_alive_fd is not None:
     93                 return
     94 
     95             cmd = ('from multiprocessing.forkserver import main; ' +
     96                    'main(%d, %d, %r, **%r)')
     97 
     98             if self._preload_modules:
     99                 desired_keys = {'main_path', 'sys_path'}
    100                 data = spawn.get_preparation_data('ignore')
    101                 data = dict((x,y) for (x,y) in data.items()
    102                             if x in desired_keys)
    103             else:
    104                 data = {}
    105 
    106             with socket.socket(socket.AF_UNIX) as listener:
    107                 address = connection.arbitrary_address('AF_UNIX')
    108                 listener.bind(address)
    109                 os.chmod(address, 0o600)
    110                 listener.listen()
    111 
    112                 # all client processes own the write end of the "alive" pipe;
    113                 # when they all terminate the read end becomes ready.
    114                 alive_r, alive_w = os.pipe()
    115                 try:
    116                     fds_to_pass = [listener.fileno(), alive_r]
    117                     cmd %= (listener.fileno(), alive_r, self._preload_modules,
    118                             data)
    119                     exe = spawn.get_executable()
    120                     args = [exe] + util._args_from_interpreter_flags()
    121                     args += ['-c', cmd]
    122                     pid = util.spawnv_passfds(exe, args, fds_to_pass)
    123                 except:
    124                     os.close(alive_w)
    125                     raise
    126                 finally:
    127                     os.close(alive_r)
    128                 self._forkserver_address = address
    129                 self._forkserver_alive_fd = alive_w
    130 
    131 #
    132 #
    133 #
    134 
def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
    '''Run forkserver.

    This is the body of the fork server process: ForkServer.ensure_running()
    starts a new interpreter whose ``-c`` code calls it.  After optional
    preloading it forks one child per accepted connection, until every
    client has gone away.

    listener_fd -- fd of the bound, listening AF_UNIX socket to accept on
    alive_r     -- read end of the "alive" pipe; EOF on it means no client
                   process holds the write end any more
    preload     -- module names to import before serving; '__main__' in
                   the list also triggers importing *main_path*
    main_path   -- path of the parent's __main__ module, or None
    sys_path    -- accepted (ensure_running() may forward it) but not used
                   by this function
    '''
    if preload:
        if '__main__' in preload and main_path is not None:
            # Mark the current process as inheriting while __main__ is
            # imported; the attribute is removed again afterwards.
            process.current_process()._inheriting = True
            try:
                spawn.import_main_path(main_path)
            finally:
                del process.current_process()._inheriting
        # Best effort: modules that fail to import are silently skipped.
        for modname in preload:
            try:
                __import__(modname)
            except ImportError:
                pass

    util._close_stdin()

    # ignoring SIGCHLD means no need to reap zombie processes
    handler = signal.signal(signal.SIGCHLD, signal.SIG_IGN)
    with socket.socket(socket.AF_UNIX, fileno=listener_fd) as listener, \
         selectors.DefaultSelector() as selector:
        # Record the address on the module-level ForkServer so that children
        # forked below inherit it.
        _forkserver._forkserver_address = listener.getsockname()

        selector.register(listener, selectors.EVENT_READ)
        selector.register(alive_r, selectors.EVENT_READ)

        while True:
            try:
                # Wait until at least one registered object is ready.
                while True:
                    rfds = [key.fileobj for (key, events) in selector.select()]
                    if rfds:
                        break

                if alive_r in rfds:
                    # EOF because no more client processes left
                    # NOTE: under -O the assert (and its read) is skipped;
                    # SystemExit is raised either way.
                    assert os.read(alive_r, 1) == b''
                    raise SystemExit

                assert listener in rfds
                with listener.accept()[0] as s:
                    # Exit status of the forked child.  It is never updated
                    # below, so the child always exits with 1; the process's
                    # real return code is sent to the client by _serve_one().
                    code = 1
                    if os.fork() == 0:
                        try:
                            _serve_one(s, listener, alive_r, handler)
                        except Exception:
                            sys.excepthook(*sys.exc_info())
                            sys.stderr.flush()
                        finally:
                            # Never let the child fall back into this loop.
                            os._exit(code)
                    # In the server, the with-statement closes its copy of
                    # the accepted socket; only the child uses it.

            except OSError as e:
                # ECONNABORTED (presumably from accept()) means a client went
                # away mid-connect; keep serving.  Other OSErrors are fatal.
                if e.errno != errno.ECONNABORTED:
                    raise
    188 
    189 def _serve_one(s, listener, alive_r, handler):
    190     # close unnecessary stuff and reset SIGCHLD handler
    191     listener.close()
    192     os.close(alive_r)
    193     signal.signal(signal.SIGCHLD, handler)
    194 
    195     # receive fds from parent process
    196     fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1)
    197     s.close()
    198     assert len(fds) <= MAXFDS_TO_SEND
    199     (child_r, child_w, _forkserver._forkserver_alive_fd,
    200      stfd, *_forkserver._inherited_fds) = fds
    201     semaphore_tracker._semaphore_tracker._fd = stfd
    202 
    203     # send pid to client processes
    204     write_unsigned(child_w, os.getpid())
    205 
    206     # reseed random number generator
    207     if 'random' in sys.modules:
    208         import random
    209         random.seed()
    210 
    211     # run process object received over pipe
    212     code = spawn._main(child_r)
    213 
    214     # write the exit code to the pipe
    215     write_unsigned(child_w, code)
    216 
    217 #
    218 # Read and write unsigned numbers
    219 #
    220 
    221 def read_unsigned(fd):
    222     data = b''
    223     length = UNSIGNED_STRUCT.size
    224     while len(data) < length:
    225         s = os.read(fd, length - len(data))
    226         if not s:
    227             raise EOFError('unexpected EOF')
    228         data += s
    229     return UNSIGNED_STRUCT.unpack(data)[0]
    230 
    231 def write_unsigned(fd, n):
    232     msg = UNSIGNED_STRUCT.pack(n)
    233     while msg:
    234         nbytes = os.write(fd, msg)
    235         if nbytes == 0:
    236             raise RuntimeError('should not get here')
    237         msg = msg[nbytes:]
    238 
    239 #
    240 #
    241 #
    242 
    243 _forkserver = ForkServer()
    244 ensure_running = _forkserver.ensure_running
    245 get_inherited_fds = _forkserver.get_inherited_fds
    246 connect_to_new_process = _forkserver.connect_to_new_process
    247 set_forkserver_preload = _forkserver.set_forkserver_preload
    248