Home | History | Annotate | Download | only in testserver
      1 #!/usr/bin/python2.4
      2 # Copyright (c) 2010 The Chromium Authors. All rights reserved.
      3 # Use of this source code is governed by a BSD-style license that can be
      4 # found in the LICENSE file.
      5 
      6 """Tests exercising the various classes in xmppserver.py."""
      7 
      8 import unittest
      9 
     10 import base64
     11 import xmppserver
     12 
     13 class XmlUtilsTest(unittest.TestCase):
     14 
     15   def testParseXml(self):
     16     xml_text = """<foo xmlns=""><bar xmlns=""><baz/></bar></foo>"""
     17     xml = xmppserver.ParseXml(xml_text)
     18     self.assertEqual(xml.toxml(), xml_text)
     19 
     20   def testCloneXml(self):
     21     xml = xmppserver.ParseXml('<foo/>')
     22     xml_clone = xmppserver.CloneXml(xml)
     23     xml_clone.setAttribute('bar', 'baz')
     24     self.assertEqual(xml, xml)
     25     self.assertEqual(xml_clone, xml_clone)
     26     self.assertNotEqual(xml, xml_clone)
     27 
     28   def testCloneXmlUnlink(self):
     29     xml_text = '<foo/>'
     30     xml = xmppserver.ParseXml(xml_text)
     31     xml_clone = xmppserver.CloneXml(xml)
     32     xml.unlink()
     33     self.assertEqual(xml.parentNode, None)
     34     self.assertNotEqual(xml_clone.parentNode, None)
     35     self.assertEqual(xml_clone.toxml(), xml_text)
     36 
     37 class StanzaParserTest(unittest.TestCase):
     38 
     39   def setUp(self):
     40     self.stanzas = []
     41 
     42   def FeedStanza(self, stanza):
     43     # We can't append stanza directly because it is unlinked after
     44     # this callback.
     45     self.stanzas.append(stanza.toxml())
     46 
     47   def testBasic(self):
     48     parser = xmppserver.StanzaParser(self)
     49     parser.FeedString('<foo')
     50     self.assertEqual(len(self.stanzas), 0)
     51     parser.FeedString('/><bar></bar>')
     52     self.assertEqual(self.stanzas[0], '<foo/>')
     53     self.assertEqual(self.stanzas[1], '<bar/>')
     54 
     55   def testStream(self):
     56     parser = xmppserver.StanzaParser(self)
     57     parser.FeedString('<stream')
     58     self.assertEqual(len(self.stanzas), 0)
     59     parser.FeedString(':stream foo="bar" xmlns:stream="baz">')
     60     self.assertEqual(self.stanzas[0],
     61                      '<stream:stream foo="bar" xmlns:stream="baz"/>')
     62 
     63   def testNested(self):
     64     parser = xmppserver.StanzaParser(self)
     65     parser.FeedString('<foo')
     66     self.assertEqual(len(self.stanzas), 0)
     67     parser.FeedString(' bar="baz"')
     68     parser.FeedString('><baz/><blah>meh</blah></foo>')
     69     self.assertEqual(self.stanzas[0],
     70                      '<foo bar="baz"><baz/><blah>meh</blah></foo>')
     71 
     72 
     73 class JidTest(unittest.TestCase):
     74 
     75   def testBasic(self):
     76     jid = xmppserver.Jid('foo', 'bar.com')
     77     self.assertEqual(str(jid), 'foo (at] bar.com')
     78 
     79   def testResource(self):
     80     jid = xmppserver.Jid('foo', 'bar.com', 'resource')
     81     self.assertEqual(str(jid), 'foo (at] bar.com/resource')
     82 
     83   def testGetBareJid(self):
     84     jid = xmppserver.Jid('foo', 'bar.com', 'resource')
     85     self.assertEqual(str(jid.GetBareJid()), 'foo (at] bar.com')
     86 
     87 
     88 class IdGeneratorTest(unittest.TestCase):
     89 
     90   def testBasic(self):
     91     id_generator = xmppserver.IdGenerator('foo')
     92     for i in xrange(0, 100):
     93       self.assertEqual('foo.%d' % i, id_generator.GetNextId())
     94 
     95 
     96 class HandshakeTaskTest(unittest.TestCase):
     97 
     98   def setUp(self):
     99     self.data_received = 0
    100 
    101   def SendData(self, _):
    102     self.data_received += 1
    103 
    104   def SendStanza(self, _, unused=True):
    105     self.data_received += 1
    106 
    107   def HandshakeDone(self, jid):
    108     self.jid = jid
    109 
    110   def DoHandshake(self, resource_prefix, resource, username,
    111                   initial_stream_domain, auth_domain, auth_stream_domain):
    112     self.data_received = 0
    113     handshake_task = (
    114       xmppserver.HandshakeTask(self, resource_prefix))
    115     stream_xml = xmppserver.ParseXml('<stream:stream xmlns:stream="foo"/>')
    116     stream_xml.setAttribute('to', initial_stream_domain)
    117     self.assertEqual(self.data_received, 0)
    118     handshake_task.FeedStanza(stream_xml)
    119     self.assertEqual(self.data_received, 2)
    120 
    121     if auth_domain:
    122       username_domain = '%s@%s' % (username, auth_domain)
    123     else:
    124       username_domain = username
    125     auth_string = base64.b64encode('\0%s\0bar' % username_domain)
    126     auth_xml = xmppserver.ParseXml('<auth>%s</auth>'% auth_string)
    127     handshake_task.FeedStanza(auth_xml)
    128     self.assertEqual(self.data_received, 3)
    129 
    130     stream_xml = xmppserver.ParseXml('<stream:stream xmlns:stream="foo"/>')
    131     stream_xml.setAttribute('to', auth_stream_domain)
    132     handshake_task.FeedStanza(stream_xml)
    133     self.assertEqual(self.data_received, 5)
    134 
    135     bind_xml = xmppserver.ParseXml(
    136       '<iq type="set"><bind><resource>%s</resource></bind></iq>' % resource)
    137     handshake_task.FeedStanza(bind_xml)
    138     self.assertEqual(self.data_received, 6)
    139 
    140     session_xml = xmppserver.ParseXml(
    141       '<iq type="set"><session></session></iq>')
    142     handshake_task.FeedStanza(session_xml)
    143     self.assertEqual(self.data_received, 7)
    144 
    145     self.assertEqual(self.jid.username, username)
    146     self.assertEqual(self.jid.domain,
    147                      auth_stream_domain or auth_domain or
    148                      initial_stream_domain)
    149     self.assertEqual(self.jid.resource,
    150                      '%s.%s' % (resource_prefix, resource))
    151 
    152   def testBasic(self):
    153     self.DoHandshake('resource_prefix', 'resource',
    154                      'foo', 'bar.com', 'baz.com', 'quux.com')
    155 
    156   def testDomainBehavior(self):
    157     self.DoHandshake('resource_prefix', 'resource',
    158                      'foo', 'bar.com', 'baz.com', 'quux.com')
    159     self.DoHandshake('resource_prefix', 'resource',
    160                      'foo', 'bar.com', 'baz.com', '')
    161     self.DoHandshake('resource_prefix', 'resource',
    162                      'foo', 'bar.com', '', '')
    163     self.DoHandshake('resource_prefix', 'resource',
    164                      'foo', '', '', '')
    165 
    166 
    167 class XmppConnectionTest(unittest.TestCase):
    168 
    169   def setUp(self):
    170     self.connections = set()
    171     self.data = []
    172 
    173   # socket-like methods.
    174   def fileno(self):
    175     return 0
    176 
    177   def setblocking(self, int):
    178     pass
    179 
    180   def getpeername(self):
    181     return ('', 0)
    182 
    183   def send(self, data):
    184     self.data.append(data)
    185     pass
    186 
    187   def close(self):
    188     pass
    189 
    190   # XmppConnection delegate methods.
    191   def OnXmppHandshakeDone(self, xmpp_connection):
    192     self.connections.add(xmpp_connection)
    193 
    194   def OnXmppConnectionClosed(self, xmpp_connection):
    195     self.connections.discard(xmpp_connection)
    196 
    197   def ForwardNotification(self, unused_xmpp_connection, notification_stanza):
    198     for connection in self.connections:
    199       connection.ForwardNotification(notification_stanza)
    200 
    201   def testBasic(self):
    202     socket_map = {}
    203     xmpp_connection = xmppserver.XmppConnection(
    204       self, socket_map, self, ('', 0))
    205     self.assertEqual(len(socket_map), 1)
    206     self.assertEqual(len(self.connections), 0)
    207     xmpp_connection.HandshakeDone(xmppserver.Jid('foo', 'bar'))
    208     self.assertEqual(len(socket_map), 1)
    209     self.assertEqual(len(self.connections), 1)
    210 
    211     # Test subscription request.
    212     self.assertEqual(len(self.data), 0)
    213     xmpp_connection.collect_incoming_data(
    214       '<iq><subscribe xmlns="google:push"></subscribe></iq>')
    215     self.assertEqual(len(self.data), 1)
    216 
    217     # Test acks.
    218     xmpp_connection.collect_incoming_data('<iq type="result"/>')
    219     self.assertEqual(len(self.data), 1)
    220 
    221     # Test notification.
    222     xmpp_connection.collect_incoming_data(
    223       '<message><push xmlns="google:push"/></message>')
    224     self.assertEqual(len(self.data), 2)
    225 
    226     # Test unexpected stanza.
    227     def SendUnexpectedStanza():
    228       xmpp_connection.collect_incoming_data('<foo/>')
    229     self.assertRaises(xmppserver.UnexpectedXml, SendUnexpectedStanza)
    230 
    231     # Test unexpected notifier command.
    232     def SendUnexpectedNotifierCommand():
    233       xmpp_connection.collect_incoming_data(
    234         '<iq><foo xmlns="google:notifier"/></iq>')
    235     self.assertRaises(xmppserver.UnexpectedXml,
    236                       SendUnexpectedNotifierCommand)
    237 
    238     # Test close
    239     xmpp_connection.close()
    240     self.assertEqual(len(socket_map), 0)
    241     self.assertEqual(len(self.connections), 0)
    242 
    243 class XmppServerTest(unittest.TestCase):
    244 
    245   # socket-like methods.
    246   def fileno(self):
    247     return 0
    248 
    249   def setblocking(self, int):
    250     pass
    251 
    252   def getpeername(self):
    253     return ('', 0)
    254 
    255   def close(self):
    256     pass
    257 
    258   def testBasic(self):
    259     class FakeXmppServer(xmppserver.XmppServer):
    260       def accept(self2):
    261         return (self, ('', 0))
    262 
    263     socket_map = {}
    264     self.assertEqual(len(socket_map), 0)
    265     xmpp_server = FakeXmppServer(socket_map, ('', 0))
    266     self.assertEqual(len(socket_map), 1)
    267     xmpp_server.handle_accept()
    268     self.assertEqual(len(socket_map), 2)
    269     xmpp_server.close()
    270     self.assertEqual(len(socket_map), 0)
    271 
    272 
    273 if __name__ == '__main__':
    274   unittest.main()
    275