Home | History | Annotate | Download | only in testserver
      1 #!/usr/bin/env python
      2 # Copyright 2013 The Chromium Authors. All rights reserved.
      3 # Use of this source code is governed by a BSD-style license that can be
      4 # found in the LICENSE file.
      5 
      6 """Tests exercising the various classes in xmppserver.py."""
      7 
      8 import unittest
      9 
     10 import base64
     11 import xmppserver
     12 
     13 class XmlUtilsTest(unittest.TestCase):
     14 
     15   def testParseXml(self):
     16     xml_text = """<foo xmlns=""><bar xmlns=""><baz/></bar></foo>"""
     17     xml = xmppserver.ParseXml(xml_text)
     18     self.assertEqual(xml.toxml(), xml_text)
     19 
     20   def testCloneXml(self):
     21     xml = xmppserver.ParseXml('<foo/>')
     22     xml_clone = xmppserver.CloneXml(xml)
     23     xml_clone.setAttribute('bar', 'baz')
     24     self.assertEqual(xml, xml)
     25     self.assertEqual(xml_clone, xml_clone)
     26     self.assertNotEqual(xml, xml_clone)
     27 
     28   def testCloneXmlUnlink(self):
     29     xml_text = '<foo/>'
     30     xml = xmppserver.ParseXml(xml_text)
     31     xml_clone = xmppserver.CloneXml(xml)
     32     xml.unlink()
     33     self.assertEqual(xml.parentNode, None)
     34     self.assertNotEqual(xml_clone.parentNode, None)
     35     self.assertEqual(xml_clone.toxml(), xml_text)
     36 
     37 class StanzaParserTest(unittest.TestCase):
     38 
     39   def setUp(self):
     40     self.stanzas = []
     41 
     42   def FeedStanza(self, stanza):
     43     # We can't append stanza directly because it is unlinked after
     44     # this callback.
     45     self.stanzas.append(stanza.toxml())
     46 
     47   def testBasic(self):
     48     parser = xmppserver.StanzaParser(self)
     49     parser.FeedString('<foo')
     50     self.assertEqual(len(self.stanzas), 0)
     51     parser.FeedString('/><bar></bar>')
     52     self.assertEqual(self.stanzas[0], '<foo/>')
     53     self.assertEqual(self.stanzas[1], '<bar/>')
     54 
     55   def testStream(self):
     56     parser = xmppserver.StanzaParser(self)
     57     parser.FeedString('<stream')
     58     self.assertEqual(len(self.stanzas), 0)
     59     parser.FeedString(':stream foo="bar" xmlns:stream="baz">')
     60     self.assertEqual(self.stanzas[0],
     61                      '<stream:stream foo="bar" xmlns:stream="baz"/>')
     62 
     63   def testNested(self):
     64     parser = xmppserver.StanzaParser(self)
     65     parser.FeedString('<foo')
     66     self.assertEqual(len(self.stanzas), 0)
     67     parser.FeedString(' bar="baz"')
     68     parser.FeedString('><baz/><blah>meh</blah></foo>')
     69     self.assertEqual(self.stanzas[0],
     70                      '<foo bar="baz"><baz/><blah>meh</blah></foo>')
     71 
     72 
     73 class JidTest(unittest.TestCase):
     74 
     75   def testBasic(self):
     76     jid = xmppserver.Jid('foo', 'bar.com')
     77     self.assertEqual(str(jid), 'foo (at] bar.com')
     78 
     79   def testResource(self):
     80     jid = xmppserver.Jid('foo', 'bar.com', 'resource')
     81     self.assertEqual(str(jid), 'foo (at] bar.com/resource')
     82 
     83   def testGetBareJid(self):
     84     jid = xmppserver.Jid('foo', 'bar.com', 'resource')
     85     self.assertEqual(str(jid.GetBareJid()), 'foo (at] bar.com')
     86 
     87 
     88 class IdGeneratorTest(unittest.TestCase):
     89 
     90   def testBasic(self):
     91     id_generator = xmppserver.IdGenerator('foo')
     92     for i in xrange(0, 100):
     93       self.assertEqual('foo.%d' % i, id_generator.GetNextId())
     94 
     95 
     96 class HandshakeTaskTest(unittest.TestCase):
     97 
     98   def setUp(self):
     99     self.Reset()
    100 
    101   def Reset(self):
    102     self.data_received = 0
    103     self.handshake_done = False
    104     self.jid = None
    105 
    106   def SendData(self, _):
    107     self.data_received += 1
    108 
    109   def SendStanza(self, _, unused=True):
    110     self.data_received += 1
    111 
    112   def HandshakeDone(self, jid):
    113     self.handshake_done = True
    114     self.jid = jid
    115 
    116   def DoHandshake(self, resource_prefix, resource, username,
    117                   initial_stream_domain, auth_domain, auth_stream_domain):
    118     self.Reset()
    119     handshake_task = (
    120       xmppserver.HandshakeTask(self, resource_prefix, True))
    121     stream_xml = xmppserver.ParseXml('<stream:stream xmlns:stream="foo"/>')
    122     stream_xml.setAttribute('to', initial_stream_domain)
    123     self.assertEqual(self.data_received, 0)
    124     handshake_task.FeedStanza(stream_xml)
    125     self.assertEqual(self.data_received, 2)
    126 
    127     if auth_domain:
    128       username_domain = '%s@%s' % (username, auth_domain)
    129     else:
    130       username_domain = username
    131     auth_string = base64.b64encode('\0%s\0bar' % username_domain)
    132     auth_xml = xmppserver.ParseXml('<auth>%s</auth>'% auth_string)
    133     handshake_task.FeedStanza(auth_xml)
    134     self.assertEqual(self.data_received, 3)
    135 
    136     stream_xml = xmppserver.ParseXml('<stream:stream xmlns:stream="foo"/>')
    137     stream_xml.setAttribute('to', auth_stream_domain)
    138     handshake_task.FeedStanza(stream_xml)
    139     self.assertEqual(self.data_received, 5)
    140 
    141     bind_xml = xmppserver.ParseXml(
    142       '<iq type="set"><bind><resource>%s</resource></bind></iq>' % resource)
    143     handshake_task.FeedStanza(bind_xml)
    144     self.assertEqual(self.data_received, 6)
    145 
    146     self.assertFalse(self.handshake_done)
    147 
    148     session_xml = xmppserver.ParseXml(
    149       '<iq type="set"><session></session></iq>')
    150     handshake_task.FeedStanza(session_xml)
    151     self.assertEqual(self.data_received, 7)
    152 
    153     self.assertTrue(self.handshake_done)
    154 
    155     self.assertEqual(self.jid.username, username)
    156     self.assertEqual(self.jid.domain,
    157                      auth_stream_domain or auth_domain or
    158                      initial_stream_domain)
    159     self.assertEqual(self.jid.resource,
    160                      '%s.%s' % (resource_prefix, resource))
    161 
    162     handshake_task.FeedStanza('<ignored/>')
    163     self.assertEqual(self.data_received, 7)
    164 
    165   def DoHandshakeUnauthenticated(self, resource_prefix, resource, username,
    166                                  initial_stream_domain):
    167     self.Reset()
    168     handshake_task = (
    169       xmppserver.HandshakeTask(self, resource_prefix, False))
    170     stream_xml = xmppserver.ParseXml('<stream:stream xmlns:stream="foo"/>')
    171     stream_xml.setAttribute('to', initial_stream_domain)
    172     self.assertEqual(self.data_received, 0)
    173     handshake_task.FeedStanza(stream_xml)
    174     self.assertEqual(self.data_received, 2)
    175 
    176     self.assertFalse(self.handshake_done)
    177 
    178     auth_string = base64.b64encode('\0%s\0bar' % username)
    179     auth_xml = xmppserver.ParseXml('<auth>%s</auth>'% auth_string)
    180     handshake_task.FeedStanza(auth_xml)
    181     self.assertEqual(self.data_received, 3)
    182 
    183     self.assertTrue(self.handshake_done)
    184 
    185     self.assertEqual(self.jid, None)
    186 
    187     handshake_task.FeedStanza('<ignored/>')
    188     self.assertEqual(self.data_received, 3)
    189 
    190   def testBasic(self):
    191     self.DoHandshake('resource_prefix', 'resource',
    192                      'foo', 'bar.com', 'baz.com', 'quux.com')
    193 
    194   def testDomainBehavior(self):
    195     self.DoHandshake('resource_prefix', 'resource',
    196                      'foo', 'bar.com', 'baz.com', 'quux.com')
    197     self.DoHandshake('resource_prefix', 'resource',
    198                      'foo', 'bar.com', 'baz.com', '')
    199     self.DoHandshake('resource_prefix', 'resource',
    200                      'foo', 'bar.com', '', '')
    201     self.DoHandshake('resource_prefix', 'resource',
    202                      'foo', '', '', '')
    203 
    204   def testBasicUnauthenticated(self):
    205     self.DoHandshakeUnauthenticated('resource_prefix', 'resource',
    206                                     'foo', 'bar.com')
    207 
    208 
    209 class FakeSocket(object):
    210   """A fake socket object used for testing.
    211   """
    212 
    213   def __init__(self):
    214     self._sent_data = []
    215 
    216   def GetSentData(self):
    217     return self._sent_data
    218 
    219   # socket-like methods.
    220   def fileno(self):
    221     return 0
    222 
    223   def setblocking(self, int):
    224     pass
    225 
    226   def getpeername(self):
    227     return ('', 0)
    228 
    229   def send(self, data):
    230     self._sent_data.append(data)
    231     pass
    232 
    233   def close(self):
    234     pass
    235 
    236 
    237 class XmppConnectionTest(unittest.TestCase):
    238 
    239   def setUp(self):
    240     self.connections = set()
    241     self.fake_socket = FakeSocket()
    242 
    243   # XmppConnection delegate methods.
    244   def OnXmppHandshakeDone(self, xmpp_connection):
    245     self.connections.add(xmpp_connection)
    246 
    247   def OnXmppConnectionClosed(self, xmpp_connection):
    248     self.connections.discard(xmpp_connection)
    249 
    250   def ForwardNotification(self, unused_xmpp_connection, notification_stanza):
    251     for connection in self.connections:
    252       connection.ForwardNotification(notification_stanza)
    253 
    254   def testBasic(self):
    255     socket_map = {}
    256     xmpp_connection = xmppserver.XmppConnection(
    257       self.fake_socket, socket_map, self, ('', 0), True)
    258     self.assertEqual(len(socket_map), 1)
    259     self.assertEqual(len(self.connections), 0)
    260     xmpp_connection.HandshakeDone(xmppserver.Jid('foo', 'bar'))
    261     self.assertEqual(len(socket_map), 1)
    262     self.assertEqual(len(self.connections), 1)
    263 
    264     sent_data = self.fake_socket.GetSentData()
    265 
    266     # Test subscription request.
    267     self.assertEqual(len(sent_data), 0)
    268     xmpp_connection.collect_incoming_data(
    269       '<iq><subscribe xmlns="google:push"></subscribe></iq>')
    270     self.assertEqual(len(sent_data), 1)
    271 
    272     # Test acks.
    273     xmpp_connection.collect_incoming_data('<iq type="result"/>')
    274     self.assertEqual(len(sent_data), 1)
    275 
    276     # Test notification.
    277     xmpp_connection.collect_incoming_data(
    278       '<message><push xmlns="google:push"/></message>')
    279     self.assertEqual(len(sent_data), 2)
    280 
    281     # Test unexpected stanza.
    282     def SendUnexpectedStanza():
    283       xmpp_connection.collect_incoming_data('<foo/>')
    284     self.assertRaises(xmppserver.UnexpectedXml, SendUnexpectedStanza)
    285 
    286     # Test unexpected notifier command.
    287     def SendUnexpectedNotifierCommand():
    288       xmpp_connection.collect_incoming_data(
    289         '<iq><foo xmlns="google:notifier"/></iq>')
    290     self.assertRaises(xmppserver.UnexpectedXml,
    291                       SendUnexpectedNotifierCommand)
    292 
    293     # Test close.
    294     xmpp_connection.close()
    295     self.assertEqual(len(socket_map), 0)
    296     self.assertEqual(len(self.connections), 0)
    297 
    298   def testBasicUnauthenticated(self):
    299     socket_map = {}
    300     xmpp_connection = xmppserver.XmppConnection(
    301       self.fake_socket, socket_map, self, ('', 0), False)
    302     self.assertEqual(len(socket_map), 1)
    303     self.assertEqual(len(self.connections), 0)
    304     xmpp_connection.HandshakeDone(None)
    305     self.assertEqual(len(socket_map), 0)
    306     self.assertEqual(len(self.connections), 0)
    307 
    308     # Test unexpected stanza.
    309     def SendUnexpectedStanza():
    310       xmpp_connection.collect_incoming_data('<foo/>')
    311     self.assertRaises(xmppserver.UnexpectedXml, SendUnexpectedStanza)
    312 
    313     # Test redundant close.
    314     xmpp_connection.close()
    315     self.assertEqual(len(socket_map), 0)
    316     self.assertEqual(len(self.connections), 0)
    317 
    318 
    319 class FakeXmppServer(xmppserver.XmppServer):
    320   """A fake XMPP server object used for testing.
    321   """
    322 
    323   def __init__(self):
    324     self._socket_map = {}
    325     self._fake_sockets = set()
    326     self._next_jid_suffix = 1
    327     xmppserver.XmppServer.__init__(self, self._socket_map, ('', 0))
    328 
    329   def GetSocketMap(self):
    330     return self._socket_map
    331 
    332   def GetFakeSockets(self):
    333     return self._fake_sockets
    334 
    335   def AddHandshakeCompletedConnection(self):
    336     """Creates a new XMPP connection and completes its handshake.
    337     """
    338     xmpp_connection = self.handle_accept()
    339     jid = xmppserver.Jid('user%s' % self._next_jid_suffix, 'domain.com')
    340     self._next_jid_suffix += 1
    341     xmpp_connection.HandshakeDone(jid)
    342 
    343   # XmppServer overrides.
    344   def accept(self):
    345     fake_socket = FakeSocket()
    346     self._fake_sockets.add(fake_socket)
    347     return (fake_socket, ('', 0))
    348 
    349   def close(self):
    350     self._fake_sockets.clear()
    351     xmppserver.XmppServer.close(self)
    352 
    353 
    354 class XmppServerTest(unittest.TestCase):
    355 
    356   def setUp(self):
    357     self.xmpp_server = FakeXmppServer()
    358 
    359   def AssertSentDataLength(self, expected_length):
    360     for fake_socket in self.xmpp_server.GetFakeSockets():
    361       self.assertEqual(len(fake_socket.GetSentData()), expected_length)
    362 
    363   def testBasic(self):
    364     socket_map = self.xmpp_server.GetSocketMap()
    365     self.assertEqual(len(socket_map), 1)
    366     self.xmpp_server.AddHandshakeCompletedConnection()
    367     self.assertEqual(len(socket_map), 2)
    368     self.xmpp_server.close()
    369     self.assertEqual(len(socket_map), 0)
    370 
    371   def testMakeNotification(self):
    372     notification = self.xmpp_server.MakeNotification('channel', 'data')
    373     expected_xml = (
    374       '<message>'
    375       '  <push channel="channel" xmlns="google:push">'
    376       '    <data>%s</data>'
    377       '  </push>'
    378       '</message>' % base64.b64encode('data'))
    379     self.assertEqual(notification.toxml(), expected_xml)
    380 
    381   def testSendNotification(self):
    382     # Add a few connections.
    383     for _ in xrange(0, 7):
    384       self.xmpp_server.AddHandshakeCompletedConnection()
    385 
    386     self.assertEqual(len(self.xmpp_server.GetFakeSockets()), 7)
    387 
    388     self.AssertSentDataLength(0)
    389     self.xmpp_server.SendNotification('channel', 'data')
    390     self.AssertSentDataLength(1)
    391 
    392   def testEnableDisableNotifications(self):
    393     # Add a few connections.
    394     for _ in xrange(0, 5):
    395       self.xmpp_server.AddHandshakeCompletedConnection()
    396 
    397     self.assertEqual(len(self.xmpp_server.GetFakeSockets()), 5)
    398 
    399     self.AssertSentDataLength(0)
    400     self.xmpp_server.SendNotification('channel', 'data')
    401     self.AssertSentDataLength(1)
    402 
    403     self.xmpp_server.EnableNotifications()
    404     self.xmpp_server.SendNotification('channel', 'data')
    405     self.AssertSentDataLength(2)
    406 
    407     self.xmpp_server.DisableNotifications()
    408     self.xmpp_server.SendNotification('channel', 'data')
    409     self.AssertSentDataLength(2)
    410 
    411     self.xmpp_server.DisableNotifications()
    412     self.xmpp_server.SendNotification('channel', 'data')
    413     self.AssertSentDataLength(2)
    414 
    415     self.xmpp_server.EnableNotifications()
    416     self.xmpp_server.SendNotification('channel', 'data')
    417     self.AssertSentDataLength(3)
    418 
    419 
    420 if __name__ == '__main__':
    421   unittest.main()
    422