Home | History | Annotate | Download | only in idlelib
      1 """RPC Implementation, originally written for the Python Idle IDE
      2 
      3 For security reasons, GvR requested that Idle's Python execution server process
      4 connect to the Idle process, which listens for the connection.  Since Idle has
      5 only one client per server, this was not a limitation.
      6 
      7    +---------------------------------+ +-------------+
      8    | SocketServer.BaseRequestHandler | | SocketIO    |
      9    +---------------------------------+ +-------------+
     10                    ^                   | register()  |
     11                    |                   | unregister()|
     12                    |                   +-------------+
     13                    |                      ^  ^
     14                    |                      |  |
     15                    | + -------------------+  |
     16                    | |                       |
     17    +-------------------------+        +-----------------+
     18    | RPCHandler              |        | RPCClient       |
     19    | [attribute of RPCServer]|        |                 |
     20    +-------------------------+        +-----------------+
     21 
     22 The RPCServer handler class is expected to provide register/unregister methods.
     23 RPCHandler inherits the mix-in class SocketIO, which provides these methods.
     24 
     25 See the Idle run.main() docstring for further information on how this was
     26 accomplished in Idle.
     27 
     28 """
     29 
     30 import sys
     31 import os
     32 import socket
     33 import select
     34 import SocketServer
     35 import struct
     36 import cPickle as pickle
     37 import threading
     38 import Queue
     39 import traceback
     40 import copy_reg
     41 import types
     42 import marshal
     43 
     44 
     45 def unpickle_code(ms):
     46     co = marshal.loads(ms)
     47     assert isinstance(co, types.CodeType)
     48     return co
     49 
     50 def pickle_code(co):
     51     assert isinstance(co, types.CodeType)
     52     ms = marshal.dumps(co)
     53     return unpickle_code, (ms,)
     54 
# XXX KBK 24Aug02 function pickling capability not used in Idle
#  def unpickle_function(ms):
#      return ms

#  def pickle_function(fn):
#      assert isinstance(fn, types.FunctionType)
#      return repr(fn)

# Register pickle_code/unpickle_code with copy_reg so that code objects in
# RPC messages are pickled as their marshal serialization and rebuilt by
# unpickle_code() on the receiving side.
copy_reg.pickle(types.CodeType, pickle_code, unpickle_code)
# Disabled companion registration for functions (see the XXX note above).
# copy_reg.pickle(types.FunctionType, pickle_function, unpickle_function)
     65 
     66 BUFSIZE = 8*1024
     67 LOCALHOST = '127.0.0.1'
     68 
     69 class RPCServer(SocketServer.TCPServer):
     70 
     71     def __init__(self, addr, handlerclass=None):
     72         if handlerclass is None:
     73             handlerclass = RPCHandler
     74         SocketServer.TCPServer.__init__(self, addr, handlerclass)
     75 
     76     def server_bind(self):
     77         "Override TCPServer method, no bind() phase for connecting entity"
     78         pass
     79 
     80     def server_activate(self):
     81         """Override TCPServer method, connect() instead of listen()
     82 
     83         Due to the reversed connection, self.server_address is actually the
     84         address of the Idle Client to which we are connecting.
     85 
     86         """
     87         self.socket.connect(self.server_address)
     88 
     89     def get_request(self):
     90         "Override TCPServer method, return already connected socket"
     91         return self.socket, self.server_address
     92 
     93     def handle_error(self, request, client_address):
     94         """Override TCPServer method
     95 
     96         Error message goes to __stderr__.  No error message if exiting
     97         normally or socket raised EOF.  Other exceptions not handled in
     98         server code will cause os._exit.
     99 
    100         """
    101         try:
    102             raise
    103         except SystemExit:
    104             raise
    105         except:
    106             erf = sys.__stderr__
    107             print>>erf, '\n' + '-'*40
    108             print>>erf, 'Unhandled server exception!'
    109             print>>erf, 'Thread: %s' % threading.currentThread().getName()
    110             print>>erf, 'Client Address: ', client_address
    111             print>>erf, 'Request: ', repr(request)
    112             traceback.print_exc(file=erf)
    113             print>>erf, '\n*** Unrecoverable, server exiting!'
    114             print>>erf, '-'*40
    115             os._exit(0)
    116 
    117 #----------------- end class RPCServer --------------------
    118 
    119 objecttable = {}
    120 request_queue = Queue.Queue(0)
    121 response_queue = Queue.Queue(0)
    122 
    123 
    124 class SocketIO(object):
    125 
    126     nextseq = 0
    127 
    128     def __init__(self, sock, objtable=None, debugging=None):
    129         self.sockthread = threading.currentThread()
    130         if debugging is not None:
    131             self.debugging = debugging
    132         self.sock = sock
    133         if objtable is None:
    134             objtable = objecttable
    135         self.objtable = objtable
    136         self.responses = {}
    137         self.cvars = {}
    138 
    139     def close(self):
    140         sock = self.sock
    141         self.sock = None
    142         if sock is not None:
    143             sock.close()
    144 
    145     def exithook(self):
    146         "override for specific exit action"
    147         os._exit(0)
    148 
    149     def debug(self, *args):
    150         if not self.debugging:
    151             return
    152         s = self.location + " " + str(threading.currentThread().getName())
    153         for a in args:
    154             s = s + " " + str(a)
    155         print>>sys.__stderr__, s
    156 
    157     def register(self, oid, object):
    158         self.objtable[oid] = object
    159 
    160     def unregister(self, oid):
    161         try:
    162             del self.objtable[oid]
    163         except KeyError:
    164             pass
    165 
    166     def localcall(self, seq, request):
    167         self.debug("localcall:", request)
    168         try:
    169             how, (oid, methodname, args, kwargs) = request
    170         except TypeError:
    171             return ("ERROR", "Bad request format")
    172         if oid not in self.objtable:
    173             return ("ERROR", "Unknown object id: %r" % (oid,))
    174         obj = self.objtable[oid]
    175         if methodname == "__methods__":
    176             methods = {}
    177             _getmethods(obj, methods)
    178             return ("OK", methods)
    179         if methodname == "__attributes__":
    180             attributes = {}
    181             _getattributes(obj, attributes)
    182             return ("OK", attributes)
    183         if not hasattr(obj, methodname):
    184             return ("ERROR", "Unsupported method name: %r" % (methodname,))
    185         method = getattr(obj, methodname)
    186         try:
    187             if how == 'CALL':
    188                 ret = method(*args, **kwargs)
    189                 if isinstance(ret, RemoteObject):
    190                     ret = remoteref(ret)
    191                 return ("OK", ret)
    192             elif how == 'QUEUE':
    193                 request_queue.put((seq, (method, args, kwargs)))
    194                 return("QUEUED", None)
    195             else:
    196                 return ("ERROR", "Unsupported message type: %s" % how)
    197         except SystemExit:
    198             raise
    199         except socket.error:
    200             raise
    201         except:
    202             msg = "*** Internal Error: rpc.py:SocketIO.localcall()\n\n"\
    203                   " Object: %s \n Method: %s \n Args: %s\n"
    204             print>>sys.__stderr__, msg % (oid, method, args)
    205             traceback.print_exc(file=sys.__stderr__)
    206             return ("EXCEPTION", None)
    207 
    208     def remotecall(self, oid, methodname, args, kwargs):
    209         self.debug("remotecall:asynccall: ", oid, methodname)
    210         seq = self.asynccall(oid, methodname, args, kwargs)
    211         return self.asyncreturn(seq)
    212 
    213     def remotequeue(self, oid, methodname, args, kwargs):
    214         self.debug("remotequeue:asyncqueue: ", oid, methodname)
    215         seq = self.asyncqueue(oid, methodname, args, kwargs)
    216         return self.asyncreturn(seq)
    217 
    218     def asynccall(self, oid, methodname, args, kwargs):
    219         request = ("CALL", (oid, methodname, args, kwargs))
    220         seq = self.newseq()
    221         if threading.currentThread() != self.sockthread:
    222             cvar = threading.Condition()
    223             self.cvars[seq] = cvar
    224         self.debug(("asynccall:%d:" % seq), oid, methodname, args, kwargs)
    225         self.putmessage((seq, request))
    226         return seq
    227 
    228     def asyncqueue(self, oid, methodname, args, kwargs):
    229         request = ("QUEUE", (oid, methodname, args, kwargs))
    230         seq = self.newseq()
    231         if threading.currentThread() != self.sockthread:
    232             cvar = threading.Condition()
    233             self.cvars[seq] = cvar
    234         self.debug(("asyncqueue:%d:" % seq), oid, methodname, args, kwargs)
    235         self.putmessage((seq, request))
    236         return seq
    237 
    238     def asyncreturn(self, seq):
    239         self.debug("asyncreturn:%d:call getresponse(): " % seq)
    240         response = self.getresponse(seq, wait=0.05)
    241         self.debug(("asyncreturn:%d:response: " % seq), response)
    242         return self.decoderesponse(response)
    243 
    244     def decoderesponse(self, response):
    245         how, what = response
    246         if how == "OK":
    247             return what
    248         if how == "QUEUED":
    249             return None
    250         if how == "EXCEPTION":
    251             self.debug("decoderesponse: EXCEPTION")
    252             return None
    253         if how == "EOF":
    254             self.debug("decoderesponse: EOF")
    255             self.decode_interrupthook()
    256             return None
    257         if how == "ERROR":
    258             self.debug("decoderesponse: Internal ERROR:", what)
    259             raise RuntimeError, what
    260         raise SystemError, (how, what)
    261 
    262     def decode_interrupthook(self):
    263         ""
    264         raise EOFError
    265 
    266     def mainloop(self):
    267         """Listen on socket until I/O not ready or EOF
    268 
    269         pollresponse() will loop looking for seq number None, which
    270         never comes, and exit on EOFError.
    271 
    272         """
    273         try:
    274             self.getresponse(myseq=None, wait=0.05)
    275         except EOFError:
    276             self.debug("mainloop:return")
    277             return
    278 
    279     def getresponse(self, myseq, wait):
    280         response = self._getresponse(myseq, wait)
    281         if response is not None:
    282             how, what = response
    283             if how == "OK":
    284                 response = how, self._proxify(what)
    285         return response
    286 
    287     def _proxify(self, obj):
    288         if isinstance(obj, RemoteProxy):
    289             return RPCProxy(self, obj.oid)
    290         if isinstance(obj, types.ListType):
    291             return map(self._proxify, obj)
    292         # XXX Check for other types -- not currently needed
    293         return obj
    294 
    295     def _getresponse(self, myseq, wait):
    296         self.debug("_getresponse:myseq:", myseq)
    297         if threading.currentThread() is self.sockthread:
    298             # this thread does all reading of requests or responses
    299             while 1:
    300                 response = self.pollresponse(myseq, wait)
    301                 if response is not None:
    302                     return response
    303         else:
    304             # wait for notification from socket handling thread
    305             cvar = self.cvars[myseq]
    306             cvar.acquire()
    307             while myseq not in self.responses:
    308                 cvar.wait()
    309             response = self.responses[myseq]
    310             self.debug("_getresponse:%s: thread woke up: response: %s" %
    311                        (myseq, response))
    312             del self.responses[myseq]
    313             del self.cvars[myseq]
    314             cvar.release()
    315             return response
    316 
    317     def newseq(self):
    318         self.nextseq = seq = self.nextseq + 2
    319         return seq
    320 
    321     def putmessage(self, message):
    322         self.debug("putmessage:%d:" % message[0])
    323         try:
    324             s = pickle.dumps(message)
    325         except pickle.PicklingError:
    326             print >>sys.__stderr__, "Cannot pickle:", repr(message)
    327             raise
    328         s = struct.pack("<i", len(s)) + s
    329         while len(s) > 0:
    330             try:
    331                 r, w, x = select.select([], [self.sock], [])
    332                 n = self.sock.send(s[:BUFSIZE])
    333             except (AttributeError, TypeError):
    334                 raise IOError, "socket no longer exists"
    335             except socket.error:
    336                 raise
    337             else:
    338                 s = s[n:]
    339 
    340     buffer = ""
    341     bufneed = 4
    342     bufstate = 0 # meaning: 0 => reading count; 1 => reading data
    343 
    344     def pollpacket(self, wait):
    345         self._stage0()
    346         if len(self.buffer) < self.bufneed:
    347             r, w, x = select.select([self.sock.fileno()], [], [], wait)
    348             if len(r) == 0:
    349                 return None
    350             try:
    351                 s = self.sock.recv(BUFSIZE)
    352             except socket.error:
    353                 raise EOFError
    354             if len(s) == 0:
    355                 raise EOFError
    356             self.buffer += s
    357             self._stage0()
    358         return self._stage1()
    359 
    360     def _stage0(self):
    361         if self.bufstate == 0 and len(self.buffer) >= 4:
    362             s = self.buffer[:4]
    363             self.buffer = self.buffer[4:]
    364             self.bufneed = struct.unpack("<i", s)[0]
    365             self.bufstate = 1
    366 
    367     def _stage1(self):
    368         if self.bufstate == 1 and len(self.buffer) >= self.bufneed:
    369             packet = self.buffer[:self.bufneed]
    370             self.buffer = self.buffer[self.bufneed:]
    371             self.bufneed = 4
    372             self.bufstate = 0
    373             return packet
    374 
    375     def pollmessage(self, wait):
    376         packet = self.pollpacket(wait)
    377         if packet is None:
    378             return None
    379         try:
    380             message = pickle.loads(packet)
    381         except pickle.UnpicklingError:
    382             print >>sys.__stderr__, "-----------------------"
    383             print >>sys.__stderr__, "cannot unpickle packet:", repr(packet)
    384             traceback.print_stack(file=sys.__stderr__)
    385             print >>sys.__stderr__, "-----------------------"
    386             raise
    387         return message
    388 
    389     def pollresponse(self, myseq, wait):
    390         """Handle messages received on the socket.
    391 
    392         Some messages received may be asynchronous 'call' or 'queue' requests,
    393         and some may be responses for other threads.
    394 
    395         'call' requests are passed to self.localcall() with the expectation of
    396         immediate execution, during which time the socket is not serviced.
    397 
    398         'queue' requests are used for tasks (which may block or hang) to be
    399         processed in a different thread.  These requests are fed into
    400         request_queue by self.localcall().  Responses to queued requests are
    401         taken from response_queue and sent across the link with the associated
    402         sequence numbers.  Messages in the queues are (sequence_number,
    403         request/response) tuples and code using this module removing messages
    404         from the request_queue is responsible for returning the correct
    405         sequence number in the response_queue.
    406 
    407         pollresponse() will loop until a response message with the myseq
    408         sequence number is received, and will save other responses in
    409         self.responses and notify the owning thread.
    410 
    411         """
    412         while 1:
    413             # send queued response if there is one available
    414             try:
    415                 qmsg = response_queue.get(0)
    416             except Queue.Empty:
    417                 pass
    418             else:
    419                 seq, response = qmsg
    420                 message = (seq, ('OK', response))
    421                 self.putmessage(message)
    422             # poll for message on link
    423             try:
    424                 message = self.pollmessage(wait)
    425                 if message is None:  # socket not ready
    426                     return None
    427             except EOFError:
    428                 self.handle_EOF()
    429                 return None
    430             except AttributeError:
    431                 return None
    432             seq, resq = message
    433             how = resq[0]
    434             self.debug("pollresponse:%d:myseq:%s" % (seq, myseq))
    435             # process or queue a request
    436             if how in ("CALL", "QUEUE"):
    437                 self.debug("pollresponse:%d:localcall:call:" % seq)
    438                 response = self.localcall(seq, resq)
    439                 self.debug("pollresponse:%d:localcall:response:%s"
    440                            % (seq, response))
    441                 if how == "CALL":
    442                     self.putmessage((seq, response))
    443                 elif how == "QUEUE":
    444                     # don't acknowledge the 'queue' request!
    445                     pass
    446                 continue
    447             # return if completed message transaction
    448             elif seq == myseq:
    449                 return resq
    450             # must be a response for a different thread:
    451             else:
    452                 cv = self.cvars.get(seq, None)
    453                 # response involving unknown sequence number is discarded,
    454                 # probably intended for prior incarnation of server
    455                 if cv is not None:
    456                     cv.acquire()
    457                     self.responses[seq] = resq
    458                     cv.notify()
    459                     cv.release()
    460                 continue
    461 
    462     def handle_EOF(self):
    463         "action taken upon link being closed by peer"
    464         self.EOFhook()
    465         self.debug("handle_EOF")
    466         for key in self.cvars:
    467             cv = self.cvars[key]
    468             cv.acquire()
    469             self.responses[key] = ('EOF', None)
    470             cv.notify()
    471             cv.release()
    472         # call our (possibly overridden) exit function
    473         self.exithook()
    474 
    475     def EOFhook(self):
    476         "Classes using rpc client/server can override to augment EOF action"
    477         pass
    478 
    479 #----------------- end class SocketIO --------------------
    480 
    481 class RemoteObject(object):
    482     # Token mix-in class
    483     pass
    484 
    485 def remoteref(obj):
    486     oid = id(obj)
    487     objecttable[oid] = obj
    488     return RemoteProxy(oid)
    489 
    490 class RemoteProxy(object):
    491 
    492     def __init__(self, oid):
    493         self.oid = oid
    494 
    495 class RPCHandler(SocketServer.BaseRequestHandler, SocketIO):
    496 
    497     debugging = False
    498     location = "#S"  # Server
    499 
    500     def __init__(self, sock, addr, svr):
    501         svr.current_handler = self ## cgt xxx
    502         SocketIO.__init__(self, sock)
    503         SocketServer.BaseRequestHandler.__init__(self, sock, addr, svr)
    504 
    505     def handle(self):
    506         "handle() method required by SocketServer"
    507         self.mainloop()
    508 
    509     def get_remote_proxy(self, oid):
    510         return RPCProxy(self, oid)
    511 
    512 class RPCClient(SocketIO):
    513 
    514     debugging = False
    515     location = "#C"  # Client
    516 
    517     nextseq = 1 # Requests coming from the client are odd numbered
    518 
    519     def __init__(self, address, family=socket.AF_INET, type=socket.SOCK_STREAM):
    520         self.listening_sock = socket.socket(family, type)
    521         self.listening_sock.bind(address)
    522         self.listening_sock.listen(1)
    523 
    524     def accept(self):
    525         working_sock, address = self.listening_sock.accept()
    526         if self.debugging:
    527             print>>sys.__stderr__, "****** Connection request from ", address
    528         if address[0] == LOCALHOST:
    529             SocketIO.__init__(self, working_sock)
    530         else:
    531             print>>sys.__stderr__, "** Invalid host: ", address
    532             raise socket.error
    533 
    534     def get_remote_proxy(self, oid):
    535         return RPCProxy(self, oid)
    536 
    537 class RPCProxy(object):
    538 
    539     __methods = None
    540     __attributes = None
    541 
    542     def __init__(self, sockio, oid):
    543         self.sockio = sockio
    544         self.oid = oid
    545 
    546     def __getattr__(self, name):
    547         if self.__methods is None:
    548             self.__getmethods()
    549         if self.__methods.get(name):
    550             return MethodProxy(self.sockio, self.oid, name)
    551         if self.__attributes is None:
    552             self.__getattributes()
    553         if name in self.__attributes:
    554             value = self.sockio.remotecall(self.oid, '__getattribute__',
    555                                            (name,), {})
    556             return value
    557         else:
    558             raise AttributeError, name
    559 
    560     def __getattributes(self):
    561         self.__attributes = self.sockio.remotecall(self.oid,
    562                                                 "__attributes__", (), {})
    563 
    564     def __getmethods(self):
    565         self.__methods = self.sockio.remotecall(self.oid,
    566                                                 "__methods__", (), {})
    567 
    568 def _getmethods(obj, methods):
    569     # Helper to get a list of methods from an object
    570     # Adds names to dictionary argument 'methods'
    571     for name in dir(obj):
    572         attr = getattr(obj, name)
    573         if hasattr(attr, '__call__'):
    574             methods[name] = 1
    575     if type(obj) == types.InstanceType:
    576         _getmethods(obj.__class__, methods)
    577     if type(obj) == types.ClassType:
    578         for super in obj.__bases__:
    579             _getmethods(super, methods)
    580 
    581 def _getattributes(obj, attributes):
    582     for name in dir(obj):
    583         attr = getattr(obj, name)
    584         if not hasattr(attr, '__call__'):
    585             attributes[name] = 1
    586 
    587 class MethodProxy(object):
    588 
    589     def __init__(self, sockio, oid, name):
    590         self.sockio = sockio
    591         self.oid = oid
    592         self.name = name
    593 
    594     def __call__(self, *args, **kwargs):
    595         value = self.sockio.remotecall(self.oid, self.name, args, kwargs)
    596         return value
    597 
    598 
    599 # XXX KBK 09Sep03  We need a proper unit test for this module.  Previously
    600 #                  existing test code was removed at Rev 1.27 (r34098).
    601