Home | History | Annotate | Download | only in idlelib
      1 """RPC Implementation, originally written for the Python Idle IDE
      2 
      3 For security reasons, GvR requested that Idle's Python execution server process
      4 connect to the Idle process, which listens for the connection.  Since Idle has
      5 only one client per server, this was not a limitation.
      6 
      7    +---------------------------------+ +-------------+
      8    | SocketServer.BaseRequestHandler | | SocketIO    |
      9    +---------------------------------+ +-------------+
     10                    ^                   | register()  |
     11                    |                   | unregister()|
     12                    |                   +-------------+
     13                    |                      ^  ^
     14                    |                      |  |
     15                    | + -------------------+  |
     16                    | |                       |
     17    +-------------------------+        +-----------------+
     18    | RPCHandler              |        | RPCClient       |
     19    | [attribute of RPCServer]|        |                 |
     20    +-------------------------+        +-----------------+
     21 
     22 The RPCServer handler class is expected to provide register/unregister methods.
     23 RPCHandler inherits the mix-in class SocketIO, which provides these methods.
     24 
     25 See the Idle run.main() docstring for further information on how this was
     26 accomplished in Idle.
     27 
     28 """
     29 
     30 import sys
     31 import os
     32 import socket
     33 import select
     34 import SocketServer
     35 import struct
     36 import cPickle as pickle
     37 import threading
     38 import Queue
     39 import traceback
     40 import copy_reg
     41 import types
     42 import marshal
     43 
     44 
     45 def unpickle_code(ms):
     46     co = marshal.loads(ms)
     47     assert isinstance(co, types.CodeType)
     48     return co
     49 
     50 def pickle_code(co):
     51     assert isinstance(co, types.CodeType)
     52     ms = marshal.dumps(co)
     53     return unpickle_code, (ms,)
     54 
     55 # XXX KBK 24Aug02 function pickling capability not used in Idle
     56 #  def unpickle_function(ms):
     57 #      return ms
     58 
     59 #  def pickle_function(fn):
#      assert isinstance(fn, types.FunctionType)
     61 #      return repr(fn)
     62 
# Teach pickle how to handle code objects: pickle_code() reduces one to its
# marshalled form and unpickle_code() rebuilds it when the pickle is loaded.
copy_reg.pickle(types.CodeType, pickle_code, unpickle_code)
     64 # copy_reg.pickle(types.FunctionType, pickle_function, unpickle_function)
     65 
# Largest chunk passed to a single send() or requested from a single recv()
# on the RPC socket (see SocketIO.putmessage() and SocketIO.pollpacket()).
BUFSIZE = 8*1024
# The only peer host that RPCClient.accept() will accept a connection from.
LOCALHOST = '127.0.0.1'
     68 
     69 class RPCServer(SocketServer.TCPServer):
     70 
     71     def __init__(self, addr, handlerclass=None):
     72         if handlerclass is None:
     73             handlerclass = RPCHandler
     74         SocketServer.TCPServer.__init__(self, addr, handlerclass)
     75 
     76     def server_bind(self):
     77         "Override TCPServer method, no bind() phase for connecting entity"
     78         pass
     79 
     80     def server_activate(self):
     81         """Override TCPServer method, connect() instead of listen()
     82 
     83         Due to the reversed connection, self.server_address is actually the
     84         address of the Idle Client to which we are connecting.
     85 
     86         """
     87         self.socket.connect(self.server_address)
     88 
     89     def get_request(self):
     90         "Override TCPServer method, return already connected socket"
     91         return self.socket, self.server_address
     92 
     93     def handle_error(self, request, client_address):
     94         """Override TCPServer method
     95 
     96         Error message goes to __stderr__.  No error message if exiting
     97         normally or socket raised EOF.  Other exceptions not handled in
     98         server code will cause os._exit.
     99 
    100         """
    101         try:
    102             raise
    103         except SystemExit:
    104             raise
    105         except:
    106             erf = sys.__stderr__
    107             print>>erf, '\n' + '-'*40
    108             print>>erf, 'Unhandled server exception!'
    109             print>>erf, 'Thread: %s' % threading.currentThread().getName()
    110             print>>erf, 'Client Address: ', client_address
    111             print>>erf, 'Request: ', repr(request)
    112             traceback.print_exc(file=erf)
    113             print>>erf, '\n*** Unrecoverable, server exiting!'
    114             print>>erf, '-'*40
    115             os._exit(0)
    116 
    117 #----------------- end class RPCServer --------------------
    118 
# Default object-id -> object registry, used by every SocketIO that is not
# given its own objtable; remoteref() also records objects here.
objecttable = {}
# 'QUEUE' requests are placed here by SocketIO.localcall() as
# (seq, (method, args, kwargs)) tuples.  A maxsize of 0 means unbounded.
request_queue = Queue.Queue(0)
# (seq, response) tuples placed here are sent across the link by
# SocketIO.pollresponse().  A maxsize of 0 means unbounded.
response_queue = Queue.Queue(0)
    122 
    123 
    124 class SocketIO(object):
    125 
    126     nextseq = 0
    127 
    128     def __init__(self, sock, objtable=None, debugging=None):
    129         self.sockthread = threading.currentThread()
    130         if debugging is not None:
    131             self.debugging = debugging
    132         self.sock = sock
    133         if objtable is None:
    134             objtable = objecttable
    135         self.objtable = objtable
    136         self.responses = {}
    137         self.cvars = {}
    138 
    139     def close(self):
    140         sock = self.sock
    141         self.sock = None
    142         if sock is not None:
    143             sock.close()
    144 
    145     def exithook(self):
    146         "override for specific exit action"
    147         os._exit(0)
    148 
    149     def debug(self, *args):
    150         if not self.debugging:
    151             return
    152         s = self.location + " " + str(threading.currentThread().getName())
    153         for a in args:
    154             s = s + " " + str(a)
    155         print>>sys.__stderr__, s
    156 
    157     def register(self, oid, object):
    158         self.objtable[oid] = object
    159 
    160     def unregister(self, oid):
    161         try:
    162             del self.objtable[oid]
    163         except KeyError:
    164             pass
    165 
    166     def localcall(self, seq, request):
    167         self.debug("localcall:", request)
    168         try:
    169             how, (oid, methodname, args, kwargs) = request
    170         except TypeError:
    171             return ("ERROR", "Bad request format")
    172         if oid not in self.objtable:
    173             return ("ERROR", "Unknown object id: %r" % (oid,))
    174         obj = self.objtable[oid]
    175         if methodname == "__methods__":
    176             methods = {}
    177             _getmethods(obj, methods)
    178             return ("OK", methods)
    179         if methodname == "__attributes__":
    180             attributes = {}
    181             _getattributes(obj, attributes)
    182             return ("OK", attributes)
    183         if not hasattr(obj, methodname):
    184             return ("ERROR", "Unsupported method name: %r" % (methodname,))
    185         method = getattr(obj, methodname)
    186         try:
    187             if how == 'CALL':
    188                 ret = method(*args, **kwargs)
    189                 if isinstance(ret, RemoteObject):
    190                     ret = remoteref(ret)
    191                 return ("OK", ret)
    192             elif how == 'QUEUE':
    193                 request_queue.put((seq, (method, args, kwargs)))
    194                 return("QUEUED", None)
    195             else:
    196                 return ("ERROR", "Unsupported message type: %s" % how)
    197         except SystemExit:
    198             raise
    199         except socket.error:
    200             raise
    201         except:
    202             msg = "*** Internal Error: rpc.py:SocketIO.localcall()\n\n"\
    203                   " Object: %s \n Method: %s \n Args: %s\n"
    204             print>>sys.__stderr__, msg % (oid, method, args)
    205             traceback.print_exc(file=sys.__stderr__)
    206             return ("EXCEPTION", None)
    207 
    208     def remotecall(self, oid, methodname, args, kwargs):
    209         self.debug("remotecall:asynccall: ", oid, methodname)
    210         seq = self.asynccall(oid, methodname, args, kwargs)
    211         return self.asyncreturn(seq)
    212 
    213     def remotequeue(self, oid, methodname, args, kwargs):
    214         self.debug("remotequeue:asyncqueue: ", oid, methodname)
    215         seq = self.asyncqueue(oid, methodname, args, kwargs)
    216         return self.asyncreturn(seq)
    217 
    218     def asynccall(self, oid, methodname, args, kwargs):
    219         request = ("CALL", (oid, methodname, args, kwargs))
    220         seq = self.newseq()
    221         if threading.currentThread() != self.sockthread:
    222             cvar = threading.Condition()
    223             self.cvars[seq] = cvar
    224         self.debug(("asynccall:%d:" % seq), oid, methodname, args, kwargs)
    225         self.putmessage((seq, request))
    226         return seq
    227 
    228     def asyncqueue(self, oid, methodname, args, kwargs):
    229         request = ("QUEUE", (oid, methodname, args, kwargs))
    230         seq = self.newseq()
    231         if threading.currentThread() != self.sockthread:
    232             cvar = threading.Condition()
    233             self.cvars[seq] = cvar
    234         self.debug(("asyncqueue:%d:" % seq), oid, methodname, args, kwargs)
    235         self.putmessage((seq, request))
    236         return seq
    237 
    238     def asyncreturn(self, seq):
    239         self.debug("asyncreturn:%d:call getresponse(): " % seq)
    240         response = self.getresponse(seq, wait=0.05)
    241         self.debug(("asyncreturn:%d:response: " % seq), response)
    242         return self.decoderesponse(response)
    243 
    244     def decoderesponse(self, response):
    245         how, what = response
    246         if how == "OK":
    247             return what
    248         if how == "QUEUED":
    249             return None
    250         if how == "EXCEPTION":
    251             self.debug("decoderesponse: EXCEPTION")
    252             return None
    253         if how == "EOF":
    254             self.debug("decoderesponse: EOF")
    255             self.decode_interrupthook()
    256             return None
    257         if how == "ERROR":
    258             self.debug("decoderesponse: Internal ERROR:", what)
    259             raise RuntimeError, what
    260         raise SystemError, (how, what)
    261 
    262     def decode_interrupthook(self):
    263         ""
    264         raise EOFError
    265 
    266     def mainloop(self):
    267         """Listen on socket until I/O not ready or EOF
    268 
    269         pollresponse() will loop looking for seq number None, which
    270         never comes, and exit on EOFError.
    271 
    272         """
    273         try:
    274             self.getresponse(myseq=None, wait=0.05)
    275         except EOFError:
    276             self.debug("mainloop:return")
    277             return
    278 
    279     def getresponse(self, myseq, wait):
    280         response = self._getresponse(myseq, wait)
    281         if response is not None:
    282             how, what = response
    283             if how == "OK":
    284                 response = how, self._proxify(what)
    285         return response
    286 
    287     def _proxify(self, obj):
    288         if isinstance(obj, RemoteProxy):
    289             return RPCProxy(self, obj.oid)
    290         if isinstance(obj, types.ListType):
    291             return map(self._proxify, obj)
    292         # XXX Check for other types -- not currently needed
    293         return obj
    294 
    295     def _getresponse(self, myseq, wait):
    296         self.debug("_getresponse:myseq:", myseq)
    297         if threading.currentThread() is self.sockthread:
    298             # this thread does all reading of requests or responses
    299             while 1:
    300                 response = self.pollresponse(myseq, wait)
    301                 if response is not None:
    302                     return response
    303         else:
    304             # wait for notification from socket handling thread
    305             cvar = self.cvars[myseq]
    306             cvar.acquire()
    307             while myseq not in self.responses:
    308                 cvar.wait()
    309             response = self.responses[myseq]
    310             self.debug("_getresponse:%s: thread woke up: response: %s" %
    311                        (myseq, response))
    312             del self.responses[myseq]
    313             del self.cvars[myseq]
    314             cvar.release()
    315             return response
    316 
    317     def newseq(self):
    318         self.nextseq = seq = self.nextseq + 2
    319         return seq
    320 
    321     def putmessage(self, message):
    322         self.debug("putmessage:%d:" % message[0])
    323         try:
    324             s = pickle.dumps(message)
    325         except pickle.PicklingError:
    326             print >>sys.__stderr__, "Cannot pickle:", repr(message)
    327             raise
    328         s = struct.pack("<i", len(s)) + s
    329         while len(s) > 0:
    330             try:
    331                 r, w, x = select.select([], [self.sock], [])
    332                 n = self.sock.send(s[:BUFSIZE])
    333             except (AttributeError, TypeError):
    334                 raise IOError, "socket no longer exists"
    335             s = s[n:]
    336 
    337     buffer = ""
    338     bufneed = 4
    339     bufstate = 0 # meaning: 0 => reading count; 1 => reading data
    340 
    341     def pollpacket(self, wait):
    342         self._stage0()
    343         if len(self.buffer) < self.bufneed:
    344             r, w, x = select.select([self.sock.fileno()], [], [], wait)
    345             if len(r) == 0:
    346                 return None
    347             try:
    348                 s = self.sock.recv(BUFSIZE)
    349             except socket.error:
    350                 raise EOFError
    351             if len(s) == 0:
    352                 raise EOFError
    353             self.buffer += s
    354             self._stage0()
    355         return self._stage1()
    356 
    357     def _stage0(self):
    358         if self.bufstate == 0 and len(self.buffer) >= 4:
    359             s = self.buffer[:4]
    360             self.buffer = self.buffer[4:]
    361             self.bufneed = struct.unpack("<i", s)[0]
    362             self.bufstate = 1
    363 
    364     def _stage1(self):
    365         if self.bufstate == 1 and len(self.buffer) >= self.bufneed:
    366             packet = self.buffer[:self.bufneed]
    367             self.buffer = self.buffer[self.bufneed:]
    368             self.bufneed = 4
    369             self.bufstate = 0
    370             return packet
    371 
    372     def pollmessage(self, wait):
    373         packet = self.pollpacket(wait)
    374         if packet is None:
    375             return None
    376         try:
    377             message = pickle.loads(packet)
    378         except pickle.UnpicklingError:
    379             print >>sys.__stderr__, "-----------------------"
    380             print >>sys.__stderr__, "cannot unpickle packet:", repr(packet)
    381             traceback.print_stack(file=sys.__stderr__)
    382             print >>sys.__stderr__, "-----------------------"
    383             raise
    384         return message
    385 
    386     def pollresponse(self, myseq, wait):
    387         """Handle messages received on the socket.
    388 
    389         Some messages received may be asynchronous 'call' or 'queue' requests,
    390         and some may be responses for other threads.
    391 
    392         'call' requests are passed to self.localcall() with the expectation of
    393         immediate execution, during which time the socket is not serviced.
    394 
    395         'queue' requests are used for tasks (which may block or hang) to be
    396         processed in a different thread.  These requests are fed into
    397         request_queue by self.localcall().  Responses to queued requests are
    398         taken from response_queue and sent across the link with the associated
    399         sequence numbers.  Messages in the queues are (sequence_number,
    400         request/response) tuples and code using this module removing messages
    401         from the request_queue is responsible for returning the correct
    402         sequence number in the response_queue.
    403 
    404         pollresponse() will loop until a response message with the myseq
    405         sequence number is received, and will save other responses in
    406         self.responses and notify the owning thread.
    407 
    408         """
    409         while 1:
    410             # send queued response if there is one available
    411             try:
    412                 qmsg = response_queue.get(0)
    413             except Queue.Empty:
    414                 pass
    415             else:
    416                 seq, response = qmsg
    417                 message = (seq, ('OK', response))
    418                 self.putmessage(message)
    419             # poll for message on link
    420             try:
    421                 message = self.pollmessage(wait)
    422                 if message is None:  # socket not ready
    423                     return None
    424             except EOFError:
    425                 self.handle_EOF()
    426                 return None
    427             except AttributeError:
    428                 return None
    429             seq, resq = message
    430             how = resq[0]
    431             self.debug("pollresponse:%d:myseq:%s" % (seq, myseq))
    432             # process or queue a request
    433             if how in ("CALL", "QUEUE"):
    434                 self.debug("pollresponse:%d:localcall:call:" % seq)
    435                 response = self.localcall(seq, resq)
    436                 self.debug("pollresponse:%d:localcall:response:%s"
    437                            % (seq, response))
    438                 if how == "CALL":
    439                     self.putmessage((seq, response))
    440                 elif how == "QUEUE":
    441                     # don't acknowledge the 'queue' request!
    442                     pass
    443                 continue
    444             # return if completed message transaction
    445             elif seq == myseq:
    446                 return resq
    447             # must be a response for a different thread:
    448             else:
    449                 cv = self.cvars.get(seq, None)
    450                 # response involving unknown sequence number is discarded,
    451                 # probably intended for prior incarnation of server
    452                 if cv is not None:
    453                     cv.acquire()
    454                     self.responses[seq] = resq
    455                     cv.notify()
    456                     cv.release()
    457                 continue
    458 
    459     def handle_EOF(self):
    460         "action taken upon link being closed by peer"
    461         self.EOFhook()
    462         self.debug("handle_EOF")
    463         for key in self.cvars:
    464             cv = self.cvars[key]
    465             cv.acquire()
    466             self.responses[key] = ('EOF', None)
    467             cv.notify()
    468             cv.release()
    469         # call our (possibly overridden) exit function
    470         self.exithook()
    471 
    472     def EOFhook(self):
    473         "Classes using rpc client/server can override to augment EOF action"
    474         pass
    475 
    476 #----------------- end class SocketIO --------------------
    477 
    478 class RemoteObject(object):
    479     # Token mix-in class
    480     pass
    481 
    482 def remoteref(obj):
    483     oid = id(obj)
    484     objecttable[oid] = obj
    485     return RemoteProxy(oid)
    486 
    487 class RemoteProxy(object):
    488 
    489     def __init__(self, oid):
    490         self.oid = oid
    491 
class RPCHandler(SocketServer.BaseRequestHandler, SocketIO):
    """Request handler used by RPCServer; also the server-side SocketIO.

    handle() runs SocketIO.mainloop() on the connected socket.
    """

    debugging = False
    location = "#S"  # Server

    def __init__(self, sock, addr, svr):
        # Publish this handler on the server object.
        svr.current_handler = self ## cgt xxx
        # SocketIO state must exist before the base-class constructor runs:
        # per the SocketServer documentation, BaseRequestHandler.__init__()
        # itself invokes handle(), which enters mainloop().
        SocketIO.__init__(self, sock)
        SocketServer.BaseRequestHandler.__init__(self, sock, addr, svr)

    def handle(self):
        "handle() method required by SocketServer"
        self.mainloop()

    def get_remote_proxy(self, oid):
        "Return an RPCProxy for the peer's object registered as oid"
        return RPCProxy(self, oid)
    508 
    509 class RPCClient(SocketIO):
    510 
    511     debugging = False
    512     location = "#C"  # Client
    513 
    514     nextseq = 1 # Requests coming from the client are odd numbered
    515 
    516     def __init__(self, address, family=socket.AF_INET, type=socket.SOCK_STREAM):
    517         self.listening_sock = socket.socket(family, type)
    518         self.listening_sock.bind(address)
    519         self.listening_sock.listen(1)
    520 
    521     def accept(self):
    522         working_sock, address = self.listening_sock.accept()
    523         if self.debugging:
    524             print>>sys.__stderr__, "****** Connection request from ", address
    525         if address[0] == LOCALHOST:
    526             SocketIO.__init__(self, working_sock)
    527         else:
    528             print>>sys.__stderr__, "** Invalid host: ", address
    529             raise socket.error
    530 
    531     def get_remote_proxy(self, oid):
    532         return RPCProxy(self, oid)
    533 
    534 class RPCProxy(object):
    535 
    536     __methods = None
    537     __attributes = None
    538 
    539     def __init__(self, sockio, oid):
    540         self.sockio = sockio
    541         self.oid = oid
    542 
    543     def __getattr__(self, name):
    544         if self.__methods is None:
    545             self.__getmethods()
    546         if self.__methods.get(name):
    547             return MethodProxy(self.sockio, self.oid, name)
    548         if self.__attributes is None:
    549             self.__getattributes()
    550         if name in self.__attributes:
    551             value = self.sockio.remotecall(self.oid, '__getattribute__',
    552                                            (name,), {})
    553             return value
    554         else:
    555             raise AttributeError, name
    556 
    557     def __getattributes(self):
    558         self.__attributes = self.sockio.remotecall(self.oid,
    559                                                 "__attributes__", (), {})
    560 
    561     def __getmethods(self):
    562         self.__methods = self.sockio.remotecall(self.oid,
    563                                                 "__methods__", (), {})
    564 
    565 def _getmethods(obj, methods):
    566     # Helper to get a list of methods from an object
    567     # Adds names to dictionary argument 'methods'
    568     for name in dir(obj):
    569         attr = getattr(obj, name)
    570         if hasattr(attr, '__call__'):
    571             methods[name] = 1
    572     if type(obj) == types.InstanceType:
    573         _getmethods(obj.__class__, methods)
    574     if type(obj) == types.ClassType:
    575         for super in obj.__bases__:
    576             _getmethods(super, methods)
    577 
    578 def _getattributes(obj, attributes):
    579     for name in dir(obj):
    580         attr = getattr(obj, name)
    581         if not hasattr(attr, '__call__'):
    582             attributes[name] = 1
    583 
    584 class MethodProxy(object):
    585 
    586     def __init__(self, sockio, oid, name):
    587         self.sockio = sockio
    588         self.oid = oid
    589         self.name = name
    590 
    591     def __call__(self, *args, **kwargs):
    592         value = self.sockio.remotecall(self.oid, self.name, args, kwargs)
    593         return value
    594 
    595 
    596 # XXX KBK 09Sep03  We need a proper unit test for this module.  Previously
    597 #                  existing test code was removed at Rev 1.27 (r34098).
    598