Home | History | Annotate | Download | only in testserver
      1 # Copyright 2013 The Chromium Authors. All rights reserved.
      2 # Use of this source code is governed by a BSD-style license that can be
      3 # found in the LICENSE file.
      4 
      5 """A bare-bones and non-compliant XMPP server.
      6 
      7 Just enough of the protocol is implemented to get it to work with
      8 Chrome's sync notification system.
      9 """
     10 
     11 import asynchat
     12 import asyncore
     13 import base64
     14 import re
     15 import socket
     16 from xml.dom import minidom
     17 
# pychecker complains about the use of fileno(), which is implemented
# by asyncore by forwarding to an internal object via __getattr__.
# pychecker reads this module-level setting; 'no-classattr' is meant to
# silence that (false-positive) class-attribute warning for this module.
__pychecker__ = 'no-classattr'
     21 
     22 
     23 class Error(Exception):
     24   """Error class for this module."""
     25   pass
     26 
     27 
     28 class UnexpectedXml(Error):
     29   """Raised when an unexpected XML element has been encountered."""
     30 
     31   def __init__(self, xml_element):
     32     xml_text = xml_element.toxml()
     33     Error.__init__(self, 'Unexpected XML element', xml_text)
     34 
     35 
     36 def ParseXml(xml_string):
     37   """Parses the given string as XML and returns a minidom element
     38   object.
     39   """
     40   dom = minidom.parseString(xml_string)
     41 
     42   # minidom handles xmlns specially, but there's a bug where it sets
     43   # the attribute value to None, which causes toxml() or toprettyxml()
     44   # to break.
     45   def FixMinidomXmlnsBug(xml_element):
     46     if xml_element.getAttribute('xmlns') is None:
     47       xml_element.setAttribute('xmlns', '')
     48 
     49   def ApplyToAllDescendantElements(xml_element, fn):
     50     fn(xml_element)
     51     for node in xml_element.childNodes:
     52       if node.nodeType == node.ELEMENT_NODE:
     53         ApplyToAllDescendantElements(node, fn)
     54 
     55   root = dom.documentElement
     56   ApplyToAllDescendantElements(root, FixMinidomXmlnsBug)
     57   return root
     58 
     59 
     60 def CloneXml(xml):
     61   """Returns a deep copy of the given XML element.
     62 
     63   Args:
     64     xml: The XML element, which should be something returned from
     65          ParseXml() (i.e., a root element).
     66   """
     67   return xml.ownerDocument.cloneNode(True).documentElement
     68 
     69 
     70 class StanzaParser(object):
     71   """A hacky incremental XML parser.
     72 
     73   StanzaParser consumes data incrementally via FeedString() and feeds
     74   its delegate complete parsed stanzas (i.e., XML documents) via
     75   FeedStanza().  Any stanzas passed to FeedStanza() are unlinked after
     76   the callback is done.
     77 
     78   Use like so:
     79 
     80   class MyClass(object):
     81     ...
     82     def __init__(self, ...):
     83       ...
     84       self._parser = StanzaParser(self)
     85       ...
     86 
     87     def SomeFunction(self, ...):
     88       ...
     89       self._parser.FeedString(some_data)
     90       ...
     91 
     92     def FeedStanza(self, stanza):
     93       ...
     94       print stanza.toprettyxml()
     95       ...
     96   """
     97 
     98   # NOTE(akalin): The following regexps are naive, but necessary since
     99   # none of the existing Python 2.4/2.5 XML libraries support
    100   # incremental parsing.  This works well enough for our purposes.
    101   #
    102   # The regexps below assume that any present XML element starts at
    103   # the beginning of the string, but there may be trailing whitespace.
    104 
    105   # Matches an opening stream tag (e.g., '<stream:stream foo="bar">')
    106   # (assumes that the stream XML namespace is defined in the tag).
    107   _stream_re = re.compile(r'^(<stream:stream [^>]*>)\s*')
    108 
    109   # Matches an empty element tag (e.g., '<foo bar="baz"/>').
    110   _empty_element_re = re.compile(r'^(<[^>]*/>)\s*')
    111 
    112   # Matches a non-empty element (e.g., '<foo bar="baz">quux</foo>').
    113   # Does *not* handle nested elements.
    114   _non_empty_element_re = re.compile(r'^(<([^ >]*)[^>]*>.*?</\2>)\s*')
    115 
    116   # The closing tag for a stream tag.  We have to insert this
    117   # ourselves since all XML stanzas are children of the stream tag,
    118   # which is never closed until the connection is closed.
    119   _stream_suffix = '</stream:stream>'
    120 
    121   def __init__(self, delegate):
    122     self._buffer = ''
    123     self._delegate = delegate
    124 
    125   def FeedString(self, data):
    126     """Consumes the given string data, possibly feeding one or more
    127     stanzas to the delegate.
    128     """
    129     self._buffer += data
    130     while (self._ProcessBuffer(self._stream_re, self._stream_suffix) or
    131            self._ProcessBuffer(self._empty_element_re) or
    132            self._ProcessBuffer(self._non_empty_element_re)):
    133       pass
    134 
    135   def _ProcessBuffer(self, regexp, xml_suffix=''):
    136     """If the buffer matches the given regexp, removes the match from
    137     the buffer, appends the given suffix, parses it, and feeds it to
    138     the delegate.
    139 
    140     Returns:
    141       Whether or not the buffer matched the given regexp.
    142     """
    143     results = regexp.match(self._buffer)
    144     if not results:
    145       return False
    146     xml_text = self._buffer[:results.end()] + xml_suffix
    147     self._buffer = self._buffer[results.end():]
    148     stanza = ParseXml(xml_text)
    149     self._delegate.FeedStanza(stanza)
    150     # Needed because stanza may have cycles.
    151     stanza.unlink()
    152     return True
    153 
    154 
    155 class Jid(object):
    156   """Simple struct for an XMPP jid (essentially an e-mail address with
    157   an optional resource string).
    158   """
    159 
    160   def __init__(self, username, domain, resource=''):
    161     self.username = username
    162     self.domain = domain
    163     self.resource = resource
    164 
    165   def __str__(self):
    166     jid_str = "%s@%s" % (self.username, self.domain)
    167     if self.resource:
    168       jid_str += '/' + self.resource
    169     return jid_str
    170 
    171   def GetBareJid(self):
    172     return Jid(self.username, self.domain)
    173 
    174 
    175 class IdGenerator(object):
    176   """Simple class to generate unique IDs for XMPP messages."""
    177 
    178   def __init__(self, prefix):
    179     self._prefix = prefix
    180     self._id = 0
    181 
    182   def GetNextId(self):
    183     next_id = "%s.%s" % (self._prefix, self._id)
    184     self._id += 1
    185     return next_id
    186 
    187 
    188 class HandshakeTask(object):
    189   """Class to handle the initial handshake with a connected XMPP
    190   client.
    191   """
    192 
    193   # The handshake states in order.
    194   (_INITIAL_STREAM_NEEDED,
    195    _AUTH_NEEDED,
    196    _AUTH_STREAM_NEEDED,
    197    _BIND_NEEDED,
    198    _SESSION_NEEDED,
    199    _FINISHED) = range(6)
    200 
    201   # Used when in the _INITIAL_STREAM_NEEDED and _AUTH_STREAM_NEEDED
    202   # states.  Not an XML object as it's only the opening tag.
    203   #
    204   # The from and id attributes are filled in later.
    205   _STREAM_DATA = (
    206     '<stream:stream from="%s" id="%s" '
    207     'version="1.0" xmlns:stream="http://etherx.jabber.org/streams" '
    208     'xmlns="jabber:client">')
    209 
    210   # Used when in the _INITIAL_STREAM_NEEDED state.
    211   _AUTH_STANZA = ParseXml(
    212     '<stream:features xmlns:stream="http://etherx.jabber.org/streams">'
    213     '  <mechanisms xmlns="urn:ietf:params:xml:ns:xmpp-sasl">'
    214     '    <mechanism>PLAIN</mechanism>'
    215     '    <mechanism>X-GOOGLE-TOKEN</mechanism>'
    216     '    <mechanism>X-OAUTH2</mechanism>'
    217     '  </mechanisms>'
    218     '</stream:features>')
    219 
    220   # Used when in the _AUTH_NEEDED state.
    221   _AUTH_SUCCESS_STANZA = ParseXml(
    222     '<success xmlns="urn:ietf:params:xml:ns:xmpp-sasl"/>')
    223 
    224   # Used when in the _AUTH_NEEDED state.
    225   _AUTH_FAILURE_STANZA = ParseXml(
    226     '<failure xmlns="urn:ietf:params:xml:ns:xmpp-sasl"/>')
    227 
    228   # Used when in the _AUTH_STREAM_NEEDED state.
    229   _BIND_STANZA = ParseXml(
    230     '<stream:features xmlns:stream="http://etherx.jabber.org/streams">'
    231     '  <bind xmlns="urn:ietf:params:xml:ns:xmpp-bind"/>'
    232     '  <session xmlns="urn:ietf:params:xml:ns:xmpp-session"/>'
    233     '</stream:features>')
    234 
    235   # Used when in the _BIND_NEEDED state.
    236   #
    237   # The id and jid attributes are filled in later.
    238   _BIND_RESULT_STANZA = ParseXml(
    239     '<iq id="" type="result">'
    240     '  <bind xmlns="urn:ietf:params:xml:ns:xmpp-bind">'
    241     '    <jid/>'
    242     '  </bind>'
    243     '</iq>')
    244 
    245   # Used when in the _SESSION_NEEDED state.
    246   #
    247   # The id attribute is filled in later.
    248   _IQ_RESPONSE_STANZA = ParseXml('<iq id="" type="result"/>')
    249 
    250   def __init__(self, connection, resource_prefix, authenticated):
    251     self._connection = connection
    252     self._id_generator = IdGenerator(resource_prefix)
    253     self._username = ''
    254     self._domain = ''
    255     self._jid = None
    256     self._authenticated = authenticated
    257     self._resource_prefix = resource_prefix
    258     self._state = self._INITIAL_STREAM_NEEDED
    259 
    260   def FeedStanza(self, stanza):
    261     """Inspects the given stanza and changes the handshake state if needed.
    262 
    263     Called when a stanza is received from the client.  Inspects the
    264     stanza to make sure it has the expected attributes given the
    265     current state, advances the state if needed, and sends a reply to
    266     the client if needed.
    267     """
    268     def ExpectStanza(stanza, name):
    269       if stanza.tagName != name:
    270         raise UnexpectedXml(stanza)
    271 
    272     def ExpectIq(stanza, type, name):
    273       ExpectStanza(stanza, 'iq')
    274       if (stanza.getAttribute('type') != type or
    275           stanza.firstChild.tagName != name):
    276         raise UnexpectedXml(stanza)
    277 
    278     def GetStanzaId(stanza):
    279       return stanza.getAttribute('id')
    280 
    281     def HandleStream(stanza):
    282       ExpectStanza(stanza, 'stream:stream')
    283       domain = stanza.getAttribute('to')
    284       if domain:
    285         self._domain = domain
    286       SendStreamData()
    287 
    288     def SendStreamData():
    289       next_id = self._id_generator.GetNextId()
    290       stream_data = self._STREAM_DATA % (self._domain, next_id)
    291       self._connection.SendData(stream_data)
    292 
    293     def GetUserDomain(stanza):
    294       encoded_username_password = stanza.firstChild.data
    295       username_password = base64.b64decode(encoded_username_password)
    296       (_, username_domain, _) = username_password.split('\0')
    297       # The domain may be omitted.
    298       #
    299       # If we were using python 2.5, we'd be able to do:
    300       #
    301       #   username, _, domain = username_domain.partition('@')
    302       #   if not domain:
    303       #     domain = self._domain
    304       at_pos = username_domain.find('@')
    305       if at_pos != -1:
    306         username = username_domain[:at_pos]
    307         domain = username_domain[at_pos+1:]
    308       else:
    309         username = username_domain
    310         domain = self._domain
    311       return (username, domain)
    312 
    313     def Finish():
    314       self._state = self._FINISHED
    315       self._connection.HandshakeDone(self._jid)
    316 
    317     if self._state == self._INITIAL_STREAM_NEEDED:
    318       HandleStream(stanza)
    319       self._connection.SendStanza(self._AUTH_STANZA, False)
    320       self._state = self._AUTH_NEEDED
    321 
    322     elif self._state == self._AUTH_NEEDED:
    323       ExpectStanza(stanza, 'auth')
    324       (self._username, self._domain) = GetUserDomain(stanza)
    325       if self._authenticated:
    326         self._connection.SendStanza(self._AUTH_SUCCESS_STANZA, False)
    327         self._state = self._AUTH_STREAM_NEEDED
    328       else:
    329         self._connection.SendStanza(self._AUTH_FAILURE_STANZA, False)
    330         Finish()
    331 
    332     elif self._state == self._AUTH_STREAM_NEEDED:
    333       HandleStream(stanza)
    334       self._connection.SendStanza(self._BIND_STANZA, False)
    335       self._state = self._BIND_NEEDED
    336 
    337     elif self._state == self._BIND_NEEDED:
    338       ExpectIq(stanza, 'set', 'bind')
    339       stanza_id = GetStanzaId(stanza)
    340       resource_element = stanza.getElementsByTagName('resource')[0]
    341       resource = resource_element.firstChild.data
    342       full_resource = '%s.%s' % (self._resource_prefix, resource)
    343       response = CloneXml(self._BIND_RESULT_STANZA)
    344       response.setAttribute('id', stanza_id)
    345       self._jid = Jid(self._username, self._domain, full_resource)
    346       jid_text = response.parentNode.createTextNode(str(self._jid))
    347       response.getElementsByTagName('jid')[0].appendChild(jid_text)
    348       self._connection.SendStanza(response)
    349       self._state = self._SESSION_NEEDED
    350 
    351     elif self._state == self._SESSION_NEEDED:
    352       ExpectIq(stanza, 'set', 'session')
    353       stanza_id = GetStanzaId(stanza)
    354       xml = CloneXml(self._IQ_RESPONSE_STANZA)
    355       xml.setAttribute('id', stanza_id)
    356       self._connection.SendStanza(xml)
    357       Finish()
    358 
    359 
    360 def AddrString(addr):
    361   return '%s:%d' % addr
    362 
    363 
    364 class XmppConnection(asynchat.async_chat):
    365   """A single XMPP client connection.
    366 
    367   This class handles the connection to a single XMPP client (via a
    368   socket).  It does the XMPP handshake and also implements the (old)
    369   Google notification protocol.
    370   """
    371 
    372   # Used for acknowledgements to the client.
    373   #
    374   # The from and id attributes are filled in later.
    375   _IQ_RESPONSE_STANZA = ParseXml('<iq from="" id="" type="result"/>')
    376 
    377   def __init__(self, sock, socket_map, delegate, addr, authenticated):
    378     """Starts up the xmpp connection.
    379 
    380     Args:
    381       sock: The socket to the client.
    382       socket_map: A map from sockets to their owning objects.
    383       delegate: The delegate, which is notified when the XMPP
    384         handshake is successful, when the connection is closed, and
    385         when a notification has to be broadcast.
    386       addr: The host/port of the client.
    387     """
    388     # We do this because in versions of python < 2.6,
    389     # async_chat.__init__ doesn't take a map argument nor pass it to
    390     # dispatcher.__init__.  We rely on the fact that
    391     # async_chat.__init__ calls dispatcher.__init__ as the last thing
    392     # it does, and that calling dispatcher.__init__ with socket=None
    393     # and map=None is essentially a no-op.
    394     asynchat.async_chat.__init__(self)
    395     asyncore.dispatcher.__init__(self, sock, socket_map)
    396 
    397     self.set_terminator(None)
    398 
    399     self._delegate = delegate
    400     self._parser = StanzaParser(self)
    401     self._jid = None
    402 
    403     self._addr = addr
    404     addr_str = AddrString(self._addr)
    405     self._handshake_task = HandshakeTask(self, addr_str, authenticated)
    406     print 'Starting connection to %s' % self
    407 
    408   def __str__(self):
    409     if self._jid:
    410       return str(self._jid)
    411     else:
    412       return AddrString(self._addr)
    413 
    414   # async_chat implementation.
    415 
    416   def collect_incoming_data(self, data):
    417     self._parser.FeedString(data)
    418 
    419   # This is only here to make pychecker happy.
    420   def found_terminator(self):
    421     asynchat.async_chat.found_terminator(self)
    422 
    423   def close(self):
    424     print "Closing connection to %s" % self
    425     self._delegate.OnXmppConnectionClosed(self)
    426     asynchat.async_chat.close(self)
    427 
    428   # Called by self._parser.FeedString().
    429   def FeedStanza(self, stanza):
    430     if self._handshake_task:
    431       self._handshake_task.FeedStanza(stanza)
    432     elif stanza.tagName == 'iq' and stanza.getAttribute('type') == 'result':
    433       # Ignore all client acks.
    434       pass
    435     elif (stanza.firstChild and
    436           stanza.firstChild.namespaceURI == 'google:push'):
    437       self._HandlePushCommand(stanza)
    438     else:
    439       raise UnexpectedXml(stanza)
    440 
    441   # Called by self._handshake_task.
    442   def HandshakeDone(self, jid):
    443     if jid:
    444       self._jid = jid
    445       self._handshake_task = None
    446       self._delegate.OnXmppHandshakeDone(self)
    447       print "Handshake done for %s" % self
    448     else:
    449       print "Handshake failed for %s" % self
    450       self.close()
    451 
    452   def _HandlePushCommand(self, stanza):
    453     if stanza.tagName == 'iq' and stanza.firstChild.tagName == 'subscribe':
    454       # Subscription request.
    455       self._SendIqResponseStanza(stanza)
    456     elif stanza.tagName == 'message' and stanza.firstChild.tagName == 'push':
    457       # Send notification request.
    458       self._delegate.ForwardNotification(self, stanza)
    459     else:
    460       raise UnexpectedXml(command_xml)
    461 
    462   def _SendIqResponseStanza(self, iq):
    463     stanza = CloneXml(self._IQ_RESPONSE_STANZA)
    464     stanza.setAttribute('from', str(self._jid.GetBareJid()))
    465     stanza.setAttribute('id', iq.getAttribute('id'))
    466     self.SendStanza(stanza)
    467 
    468   def SendStanza(self, stanza, unlink=True):
    469     """Sends a stanza to the client.
    470 
    471     Args:
    472       stanza: The stanza to send.
    473       unlink: Whether to unlink stanza after sending it. (Pass in
    474       False if stanza is a constant.)
    475     """
    476     self.SendData(stanza.toxml())
    477     if unlink:
    478       stanza.unlink()
    479 
    480   def SendData(self, data):
    481     """Sends raw data to the client.
    482     """
    483     # We explicitly encode to ascii as that is what the client expects
    484     # (some minidom library functions return unicode strings).
    485     self.push(data.encode('ascii'))
    486 
    487   def ForwardNotification(self, notification_stanza):
    488     """Forwards a notification to the client."""
    489     notification_stanza.setAttribute('from', str(self._jid.GetBareJid()))
    490     notification_stanza.setAttribute('to', str(self._jid))
    491     self.SendStanza(notification_stanza, False)
    492 
    493 
    494 class XmppServer(asyncore.dispatcher):
    495   """The main XMPP server class.
    496 
    497   The XMPP server starts accepting connections on the given address
    498   and spawns off XmppConnection objects for each one.
    499 
    500   Use like so:
    501 
    502     socket_map = {}
    503     xmpp_server = xmppserver.XmppServer(socket_map, ('127.0.0.1', 5222))
    504     asyncore.loop(30.0, False, socket_map)
    505   """
    506 
    507   # Used when sending a notification.
    508   _NOTIFICATION_STANZA = ParseXml(
    509     '<message>'
    510     '  <push xmlns="google:push">'
    511     '    <data/>'
    512     '  </push>'
    513     '</message>')
    514 
    515   def __init__(self, socket_map, addr):
    516     asyncore.dispatcher.__init__(self, None, socket_map)
    517     self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
    518     self.set_reuse_addr()
    519     self.bind(addr)
    520     self.listen(5)
    521     self._socket_map = socket_map
    522     self._connections = set()
    523     self._handshake_done_connections = set()
    524     self._notifications_enabled = True
    525     self._authenticated = True
    526 
    527   def handle_accept(self):
    528     (sock, addr) = self.accept()
    529     xmpp_connection = XmppConnection(
    530       sock, self._socket_map, self, addr, self._authenticated)
    531     self._connections.add(xmpp_connection)
    532     # Return the new XmppConnection for testing.
    533     return xmpp_connection
    534 
    535   def close(self):
    536     # A copy is necessary since calling close on each connection
    537     # removes it from self._connections.
    538     for connection in self._connections.copy():
    539       connection.close()
    540     asyncore.dispatcher.close(self)
    541 
    542   def EnableNotifications(self):
    543     self._notifications_enabled = True
    544 
    545   def DisableNotifications(self):
    546     self._notifications_enabled = False
    547 
    548   def MakeNotification(self, channel, data):
    549     """Makes a notification from the given channel and encoded data.
    550 
    551     Args:
    552       channel: The channel on which to send the notification.
    553       data: The notification payload.
    554     """
    555     notification_stanza = CloneXml(self._NOTIFICATION_STANZA)
    556     push_element = notification_stanza.getElementsByTagName('push')[0]
    557     push_element.setAttribute('channel', channel)
    558     data_element = push_element.getElementsByTagName('data')[0]
    559     encoded_data = base64.b64encode(data)
    560     data_text = notification_stanza.parentNode.createTextNode(encoded_data)
    561     data_element.appendChild(data_text)
    562     return notification_stanza
    563 
    564   def SendNotification(self, channel, data):
    565     """Sends a notification to all connections.
    566 
    567     Args:
    568       channel: The channel on which to send the notification.
    569       data: The notification payload.
    570     """
    571     notification_stanza = self.MakeNotification(channel, data)
    572     self.ForwardNotification(None, notification_stanza)
    573     notification_stanza.unlink()
    574 
    575   def SetAuthenticated(self, auth_valid):
    576     self._authenticated = auth_valid
    577 
    578     # We check authentication only when establishing new connections.  We close
    579     # all existing connections here to make sure previously connected clients
    580     # pick up on the change.  It's a hack, but it works well enough for our
    581     # purposes.
    582     if not self._authenticated:
    583       for connection in self._handshake_done_connections:
    584         connection.close()
    585 
    586   def GetAuthenticated(self):
    587     return self._authenticated
    588 
    589   # XmppConnection delegate methods.
    590   def OnXmppHandshakeDone(self, xmpp_connection):
    591     self._handshake_done_connections.add(xmpp_connection)
    592 
    593   def OnXmppConnectionClosed(self, xmpp_connection):
    594     self._connections.discard(xmpp_connection)
    595     self._handshake_done_connections.discard(xmpp_connection)
    596 
    597   def ForwardNotification(self, unused_xmpp_connection, notification_stanza):
    598     if self._notifications_enabled:
    599       for connection in self._handshake_done_connections:
    600         print 'Sending notification to %s' % connection
    601         connection.ForwardNotification(notification_stanza)
    602     else:
    603       print 'Notifications disabled; dropping notification'
    604