Home | History | Annotate | Download | only in multiprocessing
      1 #
      2 # A higher level module for using sockets (or Windows named pipes)
      3 #
      4 # multiprocessing/connection.py
      5 #
      6 # Copyright (c) 2006-2008, R Oudkerk
      7 # All rights reserved.
      8 #
      9 # Redistribution and use in source and binary forms, with or without
     10 # modification, are permitted provided that the following conditions
     11 # are met:
     12 #
     13 # 1. Redistributions of source code must retain the above copyright
     14 #    notice, this list of conditions and the following disclaimer.
     15 # 2. Redistributions in binary form must reproduce the above copyright
     16 #    notice, this list of conditions and the following disclaimer in the
     17 #    documentation and/or other materials provided with the distribution.
     18 # 3. Neither the name of author nor the names of any contributors may be
     19 #    used to endorse or promote products derived from this software
     20 #    without specific prior written permission.
     21 #
     22 # THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS "AS IS" AND
     23 # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
     24 # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
     25 # ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
     26 # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
     27 # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
     28 # OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
     29 # HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
     30 # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
     31 # OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
     32 # SUCH DAMAGE.
     33 #
     34 
     35 __all__ = [ 'Client', 'Listener', 'Pipe' ]
     36 
     37 import os
     38 import sys
     39 import socket
     40 import errno
     41 import time
     42 import tempfile
     43 import itertools
     44 
     45 import _multiprocessing
     46 from multiprocessing import current_process, AuthenticationError
     47 from multiprocessing.util import get_temp_dir, Finalize, sub_debug, debug
     48 from multiprocessing.forking import duplicate, close
     49 
     50 
     51 #
     52 #
     53 #
     54 
     55 BUFSIZE = 8192
     56 # A very generous timeout when it comes to local connections...
     57 CONNECTION_TIMEOUT = 20.
     58 
     59 _mmap_counter = itertools.count()
     60 
     61 default_family = 'AF_INET'
     62 families = ['AF_INET']
     63 
     64 if hasattr(socket, 'AF_UNIX'):
     65     default_family = 'AF_UNIX'
     66     families += ['AF_UNIX']
     67 
     68 if sys.platform == 'win32':
     69     default_family = 'AF_PIPE'
     70     families += ['AF_PIPE']
     71 
     72 
     73 def _init_timeout(timeout=CONNECTION_TIMEOUT):
     74     return time.time() + timeout
     75 
     76 def _check_timeout(t):
     77     return time.time() > t
     78 
     79 #
     80 #
     81 #
     82 
     83 def arbitrary_address(family):
     84     '''
     85     Return an arbitrary free address for the given family
     86     '''
     87     if family == 'AF_INET':
     88         return ('localhost', 0)
     89     elif family == 'AF_UNIX':
     90         return tempfile.mktemp(prefix='listener-', dir=get_temp_dir())
     91     elif family == 'AF_PIPE':
     92         return tempfile.mktemp(prefix=r'\\.\pipe\pyc-%d-%d-' %
     93                                (os.getpid(), _mmap_counter.next()), dir="")
     94     else:
     95         raise ValueError('unrecognized family')
     96 
     97 
     98 def address_type(address):
     99     '''
    100     Return the types of the address
    101 
    102     This can be 'AF_INET', 'AF_UNIX', or 'AF_PIPE'
    103     '''
    104     if type(address) == tuple:
    105         return 'AF_INET'
    106     elif type(address) is str and address.startswith('\\\\'):
    107         return 'AF_PIPE'
    108     elif type(address) is str:
    109         return 'AF_UNIX'
    110     else:
    111         raise ValueError('address type of %r unrecognized' % address)
    112 
    113 #
    114 # Public functions
    115 #
    116 
    117 class Listener(object):
    118     '''
    119     Returns a listener object.
    120 
    121     This is a wrapper for a bound socket which is 'listening' for
    122     connections, or for a Windows named pipe.
    123     '''
    124     def __init__(self, address=None, family=None, backlog=1, authkey=None):
    125         family = family or (address and address_type(address)) \
    126                  or default_family
    127         address = address or arbitrary_address(family)
    128 
    129         if family == 'AF_PIPE':
    130             self._listener = PipeListener(address, backlog)
    131         else:
    132             self._listener = SocketListener(address, family, backlog)
    133 
    134         if authkey is not None and not isinstance(authkey, bytes):
    135             raise TypeError, 'authkey should be a byte string'
    136 
    137         self._authkey = authkey
    138 
    139     def accept(self):
    140         '''
    141         Accept a connection on the bound socket or named pipe of `self`.
    142 
    143         Returns a `Connection` object.
    144         '''
    145         c = self._listener.accept()
    146         if self._authkey:
    147             deliver_challenge(c, self._authkey)
    148             answer_challenge(c, self._authkey)
    149         return c
    150 
    151     def close(self):
    152         '''
    153         Close the bound socket or named pipe of `self`.
    154         '''
    155         return self._listener.close()
    156 
    157     address = property(lambda self: self._listener._address)
    158     last_accepted = property(lambda self: self._listener._last_accepted)
    159 
    160 
    161 def Client(address, family=None, authkey=None):
    162     '''
    163     Returns a connection to the address of a `Listener`
    164     '''
    165     family = family or address_type(address)
    166     if family == 'AF_PIPE':
    167         c = PipeClient(address)
    168     else:
    169         c = SocketClient(address)
    170 
    171     if authkey is not None and not isinstance(authkey, bytes):
    172         raise TypeError, 'authkey should be a byte string'
    173 
    174     if authkey is not None:
    175         answer_challenge(c, authkey)
    176         deliver_challenge(c, authkey)
    177 
    178     return c
    179 
    180 
    181 if sys.platform != 'win32':
    182 
    183     def Pipe(duplex=True):
    184         '''
    185         Returns pair of connection objects at either end of a pipe
    186         '''
    187         if duplex:
    188             s1, s2 = socket.socketpair()
    189             s1.setblocking(True)
    190             s2.setblocking(True)
    191             c1 = _multiprocessing.Connection(os.dup(s1.fileno()))
    192             c2 = _multiprocessing.Connection(os.dup(s2.fileno()))
    193             s1.close()
    194             s2.close()
    195         else:
    196             fd1, fd2 = os.pipe()
    197             c1 = _multiprocessing.Connection(fd1, writable=False)
    198             c2 = _multiprocessing.Connection(fd2, readable=False)
    199 
    200         return c1, c2
    201 
    202 else:
    203     from _multiprocessing import win32
    204 
    205     def Pipe(duplex=True):
    206         '''
    207         Returns pair of connection objects at either end of a pipe
    208         '''
    209         address = arbitrary_address('AF_PIPE')
    210         if duplex:
    211             openmode = win32.PIPE_ACCESS_DUPLEX
    212             access = win32.GENERIC_READ | win32.GENERIC_WRITE
    213             obsize, ibsize = BUFSIZE, BUFSIZE
    214         else:
    215             openmode = win32.PIPE_ACCESS_INBOUND
    216             access = win32.GENERIC_WRITE
    217             obsize, ibsize = 0, BUFSIZE
    218 
    219         h1 = win32.CreateNamedPipe(
    220             address, openmode,
    221             win32.PIPE_TYPE_MESSAGE | win32.PIPE_READMODE_MESSAGE |
    222             win32.PIPE_WAIT,
    223             1, obsize, ibsize, win32.NMPWAIT_WAIT_FOREVER, win32.NULL
    224             )
    225         h2 = win32.CreateFile(
    226             address, access, 0, win32.NULL, win32.OPEN_EXISTING, 0, win32.NULL
    227             )
    228         win32.SetNamedPipeHandleState(
    229             h2, win32.PIPE_READMODE_MESSAGE, None, None
    230             )
    231 
    232         try:
    233             win32.ConnectNamedPipe(h1, win32.NULL)
    234         except WindowsError, e:
    235             if e.args[0] != win32.ERROR_PIPE_CONNECTED:
    236                 raise
    237 
    238         c1 = _multiprocessing.PipeConnection(h1, writable=duplex)
    239         c2 = _multiprocessing.PipeConnection(h2, readable=duplex)
    240 
    241         return c1, c2
    242 
    243 #
    244 # Definitions for connections based on sockets
    245 #
    246 
    247 class SocketListener(object):
    248     '''
    249     Representation of a socket which is bound to an address and listening
    250     '''
    251     def __init__(self, address, family, backlog=1):
    252         self._socket = socket.socket(getattr(socket, family))
    253         try:
    254             self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    255             self._socket.setblocking(True)
    256             self._socket.bind(address)
    257             self._socket.listen(backlog)
    258             self._address = self._socket.getsockname()
    259         except socket.error:
    260             self._socket.close()
    261             raise
    262         self._family = family
    263         self._last_accepted = None
    264 
    265         if family == 'AF_UNIX':
    266             self._unlink = Finalize(
    267                 self, os.unlink, args=(address,), exitpriority=0
    268                 )
    269         else:
    270             self._unlink = None
    271 
    272     def accept(self):
    273         while True:
    274             try:
    275                 s, self._last_accepted = self._socket.accept()
    276             except socket.error as e:
    277                 if e.args[0] != errno.EINTR:
    278                     raise
    279             else:
    280                 break
    281         s.setblocking(True)
    282         fd = duplicate(s.fileno())
    283         conn = _multiprocessing.Connection(fd)
    284         s.close()
    285         return conn
    286 
    287     def close(self):
    288         try:
    289             self._socket.close()
    290         finally:
    291             unlink = self._unlink
    292             if unlink is not None:
    293                 self._unlink = None
    294                 unlink()
    295 
    296 
    297 def SocketClient(address):
    298     '''
    299     Return a connection object connected to the socket given by `address`
    300     '''
    301     family = getattr(socket, address_type(address))
    302     t = _init_timeout()
    303 
    304     while 1:
    305         s = socket.socket(family)
    306         s.setblocking(True)
    307         try:
    308             s.connect(address)
    309         except socket.error, e:
    310             s.close()
    311             if e.args[0] != errno.ECONNREFUSED or _check_timeout(t):
    312                 debug('failed to connect to address %s', address)
    313                 raise
    314             time.sleep(0.01)
    315         else:
    316             break
    317     else:
    318         raise
    319 
    320     fd = duplicate(s.fileno())
    321     conn = _multiprocessing.Connection(fd)
    322     s.close()
    323     return conn
    324 
    325 #
    326 # Definitions for connections based on named pipes
    327 #
    328 
    329 if sys.platform == 'win32':
    330 
    331     class PipeListener(object):
    332         '''
    333         Representation of a named pipe
    334         '''
    335         def __init__(self, address, backlog=None):
    336             self._address = address
    337             handle = win32.CreateNamedPipe(
    338                 address, win32.PIPE_ACCESS_DUPLEX,
    339                 win32.PIPE_TYPE_MESSAGE | win32.PIPE_READMODE_MESSAGE |
    340                 win32.PIPE_WAIT,
    341                 win32.PIPE_UNLIMITED_INSTANCES, BUFSIZE, BUFSIZE,
    342                 win32.NMPWAIT_WAIT_FOREVER, win32.NULL
    343                 )
    344             self._handle_queue = [handle]
    345             self._last_accepted = None
    346 
    347             sub_debug('listener created with address=%r', self._address)
    348 
    349             self.close = Finalize(
    350                 self, PipeListener._finalize_pipe_listener,
    351                 args=(self._handle_queue, self._address), exitpriority=0
    352                 )
    353 
    354         def accept(self):
    355             newhandle = win32.CreateNamedPipe(
    356                 self._address, win32.PIPE_ACCESS_DUPLEX,
    357                 win32.PIPE_TYPE_MESSAGE | win32.PIPE_READMODE_MESSAGE |
    358                 win32.PIPE_WAIT,
    359                 win32.PIPE_UNLIMITED_INSTANCES, BUFSIZE, BUFSIZE,
    360                 win32.NMPWAIT_WAIT_FOREVER, win32.NULL
    361                 )
    362             self._handle_queue.append(newhandle)
    363             handle = self._handle_queue.pop(0)
    364             try:
    365                 win32.ConnectNamedPipe(handle, win32.NULL)
    366             except WindowsError, e:
    367                 # ERROR_NO_DATA can occur if a client has already connected,
    368                 # written data and then disconnected -- see Issue 14725.
    369                 if e.args[0] not in (win32.ERROR_PIPE_CONNECTED,
    370                                      win32.ERROR_NO_DATA):
    371                     raise
    372             return _multiprocessing.PipeConnection(handle)
    373 
    374         @staticmethod
    375         def _finalize_pipe_listener(queue, address):
    376             sub_debug('closing listener with address=%r', address)
    377             for handle in queue:
    378                 close(handle)
    379 
    380     def PipeClient(address):
    381         '''
    382         Return a connection object connected to the pipe given by `address`
    383         '''
    384         t = _init_timeout()
    385         while 1:
    386             try:
    387                 win32.WaitNamedPipe(address, 1000)
    388                 h = win32.CreateFile(
    389                     address, win32.GENERIC_READ | win32.GENERIC_WRITE,
    390                     0, win32.NULL, win32.OPEN_EXISTING, 0, win32.NULL
    391                     )
    392             except WindowsError, e:
    393                 if e.args[0] not in (win32.ERROR_SEM_TIMEOUT,
    394                                      win32.ERROR_PIPE_BUSY) or _check_timeout(t):
    395                     raise
    396             else:
    397                 break
    398         else:
    399             raise
    400 
    401         win32.SetNamedPipeHandleState(
    402             h, win32.PIPE_READMODE_MESSAGE, None, None
    403             )
    404         return _multiprocessing.PipeConnection(h)
    405 
    406 #
    407 # Authentication stuff
    408 #
    409 
    410 MESSAGE_LENGTH = 20
    411 
    412 CHALLENGE = b'#CHALLENGE#'
    413 WELCOME = b'#WELCOME#'
    414 FAILURE = b'#FAILURE#'
    415 
    416 def deliver_challenge(connection, authkey):
    417     import hmac
    418     assert isinstance(authkey, bytes)
    419     message = os.urandom(MESSAGE_LENGTH)
    420     connection.send_bytes(CHALLENGE + message)
    421     digest = hmac.new(authkey, message).digest()
    422     response = connection.recv_bytes(256)        # reject large message
    423     if response == digest:
    424         connection.send_bytes(WELCOME)
    425     else:
    426         connection.send_bytes(FAILURE)
    427         raise AuthenticationError('digest received was wrong')
    428 
    429 def answer_challenge(connection, authkey):
    430     import hmac
    431     assert isinstance(authkey, bytes)
    432     message = connection.recv_bytes(256)         # reject large message
    433     assert message[:len(CHALLENGE)] == CHALLENGE, 'message = %r' % message
    434     message = message[len(CHALLENGE):]
    435     digest = hmac.new(authkey, message).digest()
    436     connection.send_bytes(digest)
    437     response = connection.recv_bytes(256)        # reject large message
    438     if response != WELCOME:
    439         raise AuthenticationError('digest sent was rejected')
    440 
    441 #
    442 # Support for using xmlrpclib for serialization
    443 #
    444 
    445 class ConnectionWrapper(object):
    446     def __init__(self, conn, dumps, loads):
    447         self._conn = conn
    448         self._dumps = dumps
    449         self._loads = loads
    450         for attr in ('fileno', 'close', 'poll', 'recv_bytes', 'send_bytes'):
    451             obj = getattr(conn, attr)
    452             setattr(self, attr, obj)
    453     def send(self, obj):
    454         s = self._dumps(obj)
    455         self._conn.send_bytes(s)
    456     def recv(self):
    457         s = self._conn.recv_bytes()
    458         return self._loads(s)
    459 
    460 def _xml_dumps(obj):
    461     return xmlrpclib.dumps((obj,), None, None, None, 1)
    462 
    463 def _xml_loads(s):
    464     (obj,), method = xmlrpclib.loads(s)
    465     return obj
    466 
    467 class XmlListener(Listener):
    468     def accept(self):
    469         global xmlrpclib
    470         import xmlrpclib
    471         obj = Listener.accept(self)
    472         return ConnectionWrapper(obj, _xml_dumps, _xml_loads)
    473 
    474 def XmlClient(*args, **kwds):
    475     global xmlrpclib
    476     import xmlrpclib
    477     return ConnectionWrapper(Client(*args, **kwds), _xml_dumps, _xml_loads)
    478