Home | History | Annotate | Download | only in test
      1 # regression test for SAX 2.0            -*- coding: utf-8 -*-
      2 # $Id$
      3 
      4 from xml.sax import make_parser, ContentHandler, \
      5                     SAXException, SAXReaderNotAvailable, SAXParseException
try:
    # Probe once for a usable SAX driver before any tests are defined.
    make_parser()
except SAXReaderNotAvailable:
    # don't try to test this module if we cannot create a parser.
    # NOTE(review): presumably the regression-test driver treats this
    # ImportError as "skip this module" -- confirm against the runner.
    raise ImportError("no XML parsers available")
     11 from xml.sax.saxutils import XMLGenerator, escape, unescape, quoteattr, \
     12                              XMLFilterBase
     13 from xml.sax.expatreader import create_parser
     14 from xml.sax.handler import feature_namespaces
     15 from xml.sax.xmlreader import InputSource, AttributesImpl, AttributesNSImpl
     16 from cStringIO import StringIO
     17 import io
     18 import os.path
     19 import shutil
     20 import test.test_support as support
     21 from test.test_support import findfile, run_unittest
     22 import unittest
     23 
     24 TEST_XMLFILE = findfile("test.xml", subdir="xmltestdata")
     25 TEST_XMLFILE_OUT = findfile("test.xml.out", subdir="xmltestdata")
     26 
     27 supports_unicode_filenames = True
     28 if not os.path.supports_unicode_filenames:
     29     try:
     30         support.TESTFN_UNICODE.encode(support.TESTFN_ENCODING)
     31     except (AttributeError, UnicodeError, TypeError):
     32         # Either the file system encoding is None, or the file name
     33         # cannot be encoded in the file system encoding.
     34         supports_unicode_filenames = False
     35 requires_unicode_filenames = unittest.skipUnless(
     36         supports_unicode_filenames,
     37         'Requires unicode filenames support')
     38 
     39 ns_uri = "http://www.python.org/xml-ns/saxtest/"
     40 
     41 class XmlTestBase(unittest.TestCase):
     42     def verify_empty_attrs(self, attrs):
     43         self.assertRaises(KeyError, attrs.getValue, "attr")
     44         self.assertRaises(KeyError, attrs.getValueByQName, "attr")
     45         self.assertRaises(KeyError, attrs.getNameByQName, "attr")
     46         self.assertRaises(KeyError, attrs.getQNameByName, "attr")
     47         self.assertRaises(KeyError, attrs.__getitem__, "attr")
     48         self.assertEqual(attrs.getLength(), 0)
     49         self.assertEqual(attrs.getNames(), [])
     50         self.assertEqual(attrs.getQNames(), [])
     51         self.assertEqual(len(attrs), 0)
     52         self.assertFalse(attrs.has_key("attr"))
     53         self.assertEqual(attrs.keys(), [])
     54         self.assertEqual(attrs.get("attrs"), None)
     55         self.assertEqual(attrs.get("attrs", 25), 25)
     56         self.assertEqual(attrs.items(), [])
     57         self.assertEqual(attrs.values(), [])
     58 
     59     def verify_empty_nsattrs(self, attrs):
     60         self.assertRaises(KeyError, attrs.getValue, (ns_uri, "attr"))
     61         self.assertRaises(KeyError, attrs.getValueByQName, "ns:attr")
     62         self.assertRaises(KeyError, attrs.getNameByQName, "ns:attr")
     63         self.assertRaises(KeyError, attrs.getQNameByName, (ns_uri, "attr"))
     64         self.assertRaises(KeyError, attrs.__getitem__, (ns_uri, "attr"))
     65         self.assertEqual(attrs.getLength(), 0)
     66         self.assertEqual(attrs.getNames(), [])
     67         self.assertEqual(attrs.getQNames(), [])
     68         self.assertEqual(len(attrs), 0)
     69         self.assertFalse(attrs.has_key((ns_uri, "attr")))
     70         self.assertEqual(attrs.keys(), [])
     71         self.assertEqual(attrs.get((ns_uri, "attr")), None)
     72         self.assertEqual(attrs.get((ns_uri, "attr"), 25), 25)
     73         self.assertEqual(attrs.items(), [])
     74         self.assertEqual(attrs.values(), [])
     75 
     76     def verify_attrs_wattr(self, attrs):
     77         self.assertEqual(attrs.getLength(), 1)
     78         self.assertEqual(attrs.getNames(), ["attr"])
     79         self.assertEqual(attrs.getQNames(), ["attr"])
     80         self.assertEqual(len(attrs), 1)
     81         self.assertTrue(attrs.has_key("attr"))
     82         self.assertEqual(attrs.keys(), ["attr"])
     83         self.assertEqual(attrs.get("attr"), "val")
     84         self.assertEqual(attrs.get("attr", 25), "val")
     85         self.assertEqual(attrs.items(), [("attr", "val")])
     86         self.assertEqual(attrs.values(), ["val"])
     87         self.assertEqual(attrs.getValue("attr"), "val")
     88         self.assertEqual(attrs.getValueByQName("attr"), "val")
     89         self.assertEqual(attrs.getNameByQName("attr"), "attr")
     90         self.assertEqual(attrs["attr"], "val")
     91         self.assertEqual(attrs.getQNameByName("attr"), "attr")
     92 
     93 class MakeParserTest(unittest.TestCase):
     94     def test_make_parser2(self):
     95         # Creating parsers several times in a row should succeed.
     96         # Testing this because there have been failures of this kind
     97         # before.
     98         from xml.sax import make_parser
     99         p = make_parser()
    100         from xml.sax import make_parser
    101         p = make_parser()
    102         from xml.sax import make_parser
    103         p = make_parser()
    104         from xml.sax import make_parser
    105         p = make_parser()
    106         from xml.sax import make_parser
    107         p = make_parser()
    108         from xml.sax import make_parser
    109         p = make_parser()
    110 
    111 
    112 # ===========================================================================
    113 #
    114 #   saxutils tests
    115 #
    116 # ===========================================================================
    117 
    118 class SaxutilsTest(unittest.TestCase):
    119     # ===== escape
    120     def test_escape_basic(self):
    121         self.assertEqual(escape("Donald Duck & Co"), "Donald Duck & Co")
    122 
    123     def test_escape_all(self):
    124         self.assertEqual(escape("<Donald Duck & Co>"),
    125                          "&lt;Donald Duck &amp; Co&gt;")
    126 
    127     def test_escape_extra(self):
    128         self.assertEqual(escape("Hei p deg", {"" : "&aring;"}),
    129                          "Hei p&aring; deg")
    130 
    131     # ===== unescape
    132     def test_unescape_basic(self):
    133         self.assertEqual(unescape("Donald Duck &amp; Co"), "Donald Duck & Co")
    134 
    135     def test_unescape_all(self):
    136         self.assertEqual(unescape("&lt;Donald Duck &amp; Co&gt;"),
    137                          "<Donald Duck & Co>")
    138 
    139     def test_unescape_extra(self):
    140         self.assertEqual(unescape("Hei p deg", {"" : "&aring;"}),
    141                          "Hei p&aring; deg")
    142 
    143     def test_unescape_amp_extra(self):
    144         self.assertEqual(unescape("&amp;foo;", {"&foo;": "splat"}), "&foo;")
    145 
    146     # ===== quoteattr
    147     def test_quoteattr_basic(self):
    148         self.assertEqual(quoteattr("Donald Duck & Co"),
    149                          '"Donald Duck &amp; Co"')
    150 
    151     def test_single_quoteattr(self):
    152         self.assertEqual(quoteattr('Includes "double" quotes'),
    153                          '\'Includes "double" quotes\'')
    154 
    155     def test_double_quoteattr(self):
    156         self.assertEqual(quoteattr("Includes 'single' quotes"),
    157                          "\"Includes 'single' quotes\"")
    158 
    159     def test_single_double_quoteattr(self):
    160         self.assertEqual(quoteattr("Includes 'single' and \"double\" quotes"),
    161                          "\"Includes 'single' and &quot;double&quot; quotes\"")
    162 
    163     # ===== make_parser
    164     def test_make_parser(self):
    165         # Creating a parser should succeed - it should fall back
    166         # to the expatreader
    167         p = make_parser(['xml.parsers.no_such_parser'])
    168 
    169 
    170 # ===== XMLGenerator
    171 
    172 start = '<?xml version="1.0" encoding="iso-8859-1"?>\n'
    173 
    174 class XmlgenTest:
    175     def test_xmlgen_basic(self):
    176         result = self.ioclass()
    177         gen = XMLGenerator(result)
    178         gen.startDocument()
    179         gen.startElement("doc", {})
    180         gen.endElement("doc")
    181         gen.endDocument()
    182 
    183         self.assertEqual(result.getvalue(), start + "<doc></doc>")
    184 
    185     def test_xmlgen_content(self):
    186         result = self.ioclass()
    187         gen = XMLGenerator(result)
    188 
    189         gen.startDocument()
    190         gen.startElement("doc", {})
    191         gen.characters("huhei")
    192         gen.endElement("doc")
    193         gen.endDocument()
    194 
    195         self.assertEqual(result.getvalue(), start + "<doc>huhei</doc>")
    196 
    197     def test_xmlgen_pi(self):
    198         result = self.ioclass()
    199         gen = XMLGenerator(result)
    200 
    201         gen.startDocument()
    202         gen.processingInstruction("test", "data")
    203         gen.startElement("doc", {})
    204         gen.endElement("doc")
    205         gen.endDocument()
    206 
    207         self.assertEqual(result.getvalue(), start + "<?test data?><doc></doc>")
    208 
    209     def test_xmlgen_content_escape(self):
    210         result = self.ioclass()
    211         gen = XMLGenerator(result)
    212 
    213         gen.startDocument()
    214         gen.startElement("doc", {})
    215         gen.characters("<huhei&")
    216         gen.endElement("doc")
    217         gen.endDocument()
    218 
    219         self.assertEqual(result.getvalue(),
    220             start + "<doc>&lt;huhei&amp;</doc>")
    221 
    222     def test_xmlgen_attr_escape(self):
    223         result = self.ioclass()
    224         gen = XMLGenerator(result)
    225 
    226         gen.startDocument()
    227         gen.startElement("doc", {"a": '"'})
    228         gen.startElement("e", {"a": "'"})
    229         gen.endElement("e")
    230         gen.startElement("e", {"a": "'\""})
    231         gen.endElement("e")
    232         gen.startElement("e", {"a": "\n\r\t"})
    233         gen.endElement("e")
    234         gen.endElement("doc")
    235         gen.endDocument()
    236 
    237         self.assertEqual(result.getvalue(), start +
    238             ("<doc a='\"'><e a=\"'\"></e>"
    239              "<e a=\"'&quot;\"></e>"
    240              "<e a=\"&#10;&#13;&#9;\"></e></doc>"))
    241 
    242     def test_xmlgen_encoding(self):
    243         encodings = ('iso-8859-15', 'utf-8',
    244                      'utf-16be', 'utf-16le',
    245                      'utf-32be', 'utf-32le')
    246         for encoding in encodings:
    247             result = self.ioclass()
    248             gen = XMLGenerator(result, encoding=encoding)
    249 
    250             gen.startDocument()
    251             gen.startElement("doc", {"a": u'\u20ac'})
    252             gen.characters(u"\u20ac")
    253             gen.endElement("doc")
    254             gen.endDocument()
    255 
    256             self.assertEqual(result.getvalue(), (
    257                 u'<?xml version="1.0" encoding="%s"?>\n'
    258                 u'<doc a="\u20ac">\u20ac</doc>' % encoding
    259                 ).encode(encoding, 'xmlcharrefreplace'))
    260 
    261     def test_xmlgen_unencodable(self):
    262         result = self.ioclass()
    263         gen = XMLGenerator(result, encoding='ascii')
    264 
    265         gen.startDocument()
    266         gen.startElement("doc", {"a": u'\u20ac'})
    267         gen.characters(u"\u20ac")
    268         gen.endElement("doc")
    269         gen.endDocument()
    270 
    271         self.assertEqual(result.getvalue(),
    272                 '<?xml version="1.0" encoding="ascii"?>\n'
    273                 '<doc a="&#8364;">&#8364;</doc>')
    274 
    275     def test_xmlgen_ignorable(self):
    276         result = self.ioclass()
    277         gen = XMLGenerator(result)
    278 
    279         gen.startDocument()
    280         gen.startElement("doc", {})
    281         gen.ignorableWhitespace(" ")
    282         gen.endElement("doc")
    283         gen.endDocument()
    284 
    285         self.assertEqual(result.getvalue(), start + "<doc> </doc>")
    286 
    287     def test_xmlgen_ns(self):
    288         result = self.ioclass()
    289         gen = XMLGenerator(result)
    290 
    291         gen.startDocument()
    292         gen.startPrefixMapping("ns1", ns_uri)
    293         gen.startElementNS((ns_uri, "doc"), "ns1:doc", {})
    294         # add an unqualified name
    295         gen.startElementNS((None, "udoc"), None, {})
    296         gen.endElementNS((None, "udoc"), None)
    297         gen.endElementNS((ns_uri, "doc"), "ns1:doc")
    298         gen.endPrefixMapping("ns1")
    299         gen.endDocument()
    300 
    301         self.assertEqual(result.getvalue(), start + \
    302            ('<ns1:doc xmlns:ns1="%s"><udoc></udoc></ns1:doc>' %
    303                                          ns_uri))
    304 
    305     def test_1463026_1(self):
    306         result = self.ioclass()
    307         gen = XMLGenerator(result)
    308 
    309         gen.startDocument()
    310         gen.startElementNS((None, 'a'), 'a', {(None, 'b'):'c'})
    311         gen.endElementNS((None, 'a'), 'a')
    312         gen.endDocument()
    313 
    314         self.assertEqual(result.getvalue(), start+'<a b="c"></a>')
    315 
    316     def test_1463026_2(self):
    317         result = self.ioclass()
    318         gen = XMLGenerator(result)
    319 
    320         gen.startDocument()
    321         gen.startPrefixMapping(None, 'qux')
    322         gen.startElementNS(('qux', 'a'), 'a', {})
    323         gen.endElementNS(('qux', 'a'), 'a')
    324         gen.endPrefixMapping(None)
    325         gen.endDocument()
    326 
    327         self.assertEqual(result.getvalue(), start+'<a xmlns="qux"></a>')
    328 
    329     def test_1463026_3(self):
    330         result = self.ioclass()
    331         gen = XMLGenerator(result)
    332 
    333         gen.startDocument()
    334         gen.startPrefixMapping('my', 'qux')
    335         gen.startElementNS(('qux', 'a'), 'a', {(None, 'b'):'c'})
    336         gen.endElementNS(('qux', 'a'), 'a')
    337         gen.endPrefixMapping('my')
    338         gen.endDocument()
    339 
    340         self.assertEqual(result.getvalue(),
    341             start+'<my:a xmlns:my="qux" b="c"></my:a>')
    342 
    343     def test_5027_1(self):
    344         # The xml prefix (as in xml:lang below) is reserved and bound by
    345         # definition to http://www.w3.org/XML/1998/namespace.  XMLGenerator had
    346         # a bug whereby a KeyError is raised because this namespace is missing
    347         # from a dictionary.
    348         #
    349         # This test demonstrates the bug by parsing a document.
    350         test_xml = StringIO(
    351             '<?xml version="1.0"?>'
    352             '<a:g1 xmlns:a="http://example.com/ns">'
    353              '<a:g2 xml:lang="en">Hello</a:g2>'
    354             '</a:g1>')
    355 
    356         parser = make_parser()
    357         parser.setFeature(feature_namespaces, True)
    358         result = self.ioclass()
    359         gen = XMLGenerator(result)
    360         parser.setContentHandler(gen)
    361         parser.parse(test_xml)
    362 
    363         self.assertEqual(result.getvalue(),
    364                          start + (
    365                          '<a:g1 xmlns:a="http://example.com/ns">'
    366                           '<a:g2 xml:lang="en">Hello</a:g2>'
    367                          '</a:g1>'))
    368 
    369     def test_5027_2(self):
    370         # The xml prefix (as in xml:lang below) is reserved and bound by
    371         # definition to http://www.w3.org/XML/1998/namespace.  XMLGenerator had
    372         # a bug whereby a KeyError is raised because this namespace is missing
    373         # from a dictionary.
    374         #
    375         # This test demonstrates the bug by direct manipulation of the
    376         # XMLGenerator.
    377         result = self.ioclass()
    378         gen = XMLGenerator(result)
    379 
    380         gen.startDocument()
    381         gen.startPrefixMapping('a', 'http://example.com/ns')
    382         gen.startElementNS(('http://example.com/ns', 'g1'), 'g1', {})
    383         lang_attr = {('http://www.w3.org/XML/1998/namespace', 'lang'): 'en'}
    384         gen.startElementNS(('http://example.com/ns', 'g2'), 'g2', lang_attr)
    385         gen.characters('Hello')
    386         gen.endElementNS(('http://example.com/ns', 'g2'), 'g2')
    387         gen.endElementNS(('http://example.com/ns', 'g1'), 'g1')
    388         gen.endPrefixMapping('a')
    389         gen.endDocument()
    390 
    391         self.assertEqual(result.getvalue(),
    392                          start + (
    393                          '<a:g1 xmlns:a="http://example.com/ns">'
    394                           '<a:g2 xml:lang="en">Hello</a:g2>'
    395                          '</a:g1>'))
    396 
    397     def test_no_close_file(self):
    398         result = self.ioclass()
    399         def func(out):
    400             gen = XMLGenerator(out)
    401             gen.startDocument()
    402             gen.startElement("doc", {})
    403         func(result)
    404         self.assertFalse(result.closed)
    405 
    406     def test_xmlgen_fragment(self):
    407         result = self.ioclass()
    408         gen = XMLGenerator(result)
    409 
    410         # Don't call gen.startDocument()
    411         gen.startElement("foo", {"a": "1.0"})
    412         gen.characters("Hello")
    413         gen.endElement("foo")
    414         gen.startElement("bar", {"b": "2.0"})
    415         gen.endElement("bar")
    416         # Don't call gen.endDocument()
    417 
    418         self.assertEqual(result.getvalue(),
    419                          '<foo a="1.0">Hello</foo><bar b="2.0"></bar>')
    420 
    421 class StringXmlgenTest(XmlgenTest, unittest.TestCase):
    422     ioclass = StringIO
    423 
    424 class BytesIOXmlgenTest(XmlgenTest, unittest.TestCase):
    425     ioclass = io.BytesIO
    426 
    427 class WriterXmlgenTest(XmlgenTest, unittest.TestCase):
    428     class ioclass(list):
    429         write = list.append
    430         closed = False
    431 
    432         def getvalue(self):
    433             return b''.join(self)
    434 
    435 
    436 class XMLFilterBaseTest(unittest.TestCase):
    437     def test_filter_basic(self):
    438         result = StringIO()
    439         gen = XMLGenerator(result)
    440         filter = XMLFilterBase()
    441         filter.setContentHandler(gen)
    442 
    443         filter.startDocument()
    444         filter.startElement("doc", {})
    445         filter.characters("content")
    446         filter.ignorableWhitespace(" ")
    447         filter.endElement("doc")
    448         filter.endDocument()
    449 
    450         self.assertEqual(result.getvalue(), start + "<doc>content </doc>")
    451 
    452 # ===========================================================================
    453 #
    454 #   expatreader tests
    455 #
    456 # ===========================================================================
    457 
    458 xml_test_out = open(TEST_XMLFILE_OUT).read()
    459 
    460 class ExpatReaderTest(XmlTestBase):
    461 
    462     # ===== XMLReader support
    463 
    464     def test_expat_file(self):
    465         parser = create_parser()
    466         result = StringIO()
    467         xmlgen = XMLGenerator(result)
    468 
    469         parser.setContentHandler(xmlgen)
    470         parser.parse(open(TEST_XMLFILE))
    471 
    472         self.assertEqual(result.getvalue(), xml_test_out)
    473 
    474     @requires_unicode_filenames
    475     def test_expat_file_unicode(self):
    476         fname = support.TESTFN_UNICODE
    477         shutil.copyfile(TEST_XMLFILE, fname)
    478         self.addCleanup(support.unlink, fname)
    479 
    480         parser = create_parser()
    481         result = StringIO()
    482         xmlgen = XMLGenerator(result)
    483 
    484         parser.setContentHandler(xmlgen)
    485         parser.parse(open(fname))
    486 
    487         self.assertEqual(result.getvalue(), xml_test_out)
    488 
    489     # ===== DTDHandler support
    490 
    491     class TestDTDHandler:
    492 
    493         def __init__(self):
    494             self._notations = []
    495             self._entities  = []
    496 
    497         def notationDecl(self, name, publicId, systemId):
    498             self._notations.append((name, publicId, systemId))
    499 
    500         def unparsedEntityDecl(self, name, publicId, systemId, ndata):
    501             self._entities.append((name, publicId, systemId, ndata))
    502 
    503     def test_expat_dtdhandler(self):
    504         parser = create_parser()
    505         handler = self.TestDTDHandler()
    506         parser.setDTDHandler(handler)
    507 
    508         parser.feed('<!DOCTYPE doc [\n')
    509         parser.feed('  <!ENTITY img SYSTEM "expat.gif" NDATA GIF>\n')
    510         parser.feed('  <!NOTATION GIF PUBLIC "-//CompuServe//NOTATION Graphics Interchange Format 89a//EN">\n')
    511         parser.feed(']>\n')
    512         parser.feed('<doc></doc>')
    513         parser.close()
    514 
    515         self.assertEqual(handler._notations,
    516             [("GIF", "-//CompuServe//NOTATION Graphics Interchange Format 89a//EN", None)])
    517         self.assertEqual(handler._entities, [("img", None, "expat.gif", "GIF")])
    518 
    519     # ===== EntityResolver support
    520 
    521     class TestEntityResolver:
    522 
    523         def resolveEntity(self, publicId, systemId):
    524             inpsrc = InputSource()
    525             inpsrc.setByteStream(StringIO("<entity/>"))
    526             return inpsrc
    527 
    528     def test_expat_entityresolver(self):
    529         parser = create_parser()
    530         parser.setEntityResolver(self.TestEntityResolver())
    531         result = StringIO()
    532         parser.setContentHandler(XMLGenerator(result))
    533 
    534         parser.feed('<!DOCTYPE doc [\n')
    535         parser.feed('  <!ENTITY test SYSTEM "whatever">\n')
    536         parser.feed(']>\n')
    537         parser.feed('<doc>&test;</doc>')
    538         parser.close()
    539 
    540         self.assertEqual(result.getvalue(), start +
    541                          "<doc><entity></entity></doc>")
    542 
    543     # ===== Attributes support
    544 
    545     class AttrGatherer(ContentHandler):
    546 
    547         def startElement(self, name, attrs):
    548             self._attrs = attrs
    549 
    550         def startElementNS(self, name, qname, attrs):
    551             self._attrs = attrs
    552 
    553     def test_expat_attrs_empty(self):
    554         parser = create_parser()
    555         gather = self.AttrGatherer()
    556         parser.setContentHandler(gather)
    557 
    558         parser.feed("<doc/>")
    559         parser.close()
    560 
    561         self.verify_empty_attrs(gather._attrs)
    562 
    563     def test_expat_attrs_wattr(self):
    564         parser = create_parser()
    565         gather = self.AttrGatherer()
    566         parser.setContentHandler(gather)
    567 
    568         parser.feed("<doc attr='val'/>")
    569         parser.close()
    570 
    571         self.verify_attrs_wattr(gather._attrs)
    572 
    573     def test_expat_nsattrs_empty(self):
    574         parser = create_parser(1)
    575         gather = self.AttrGatherer()
    576         parser.setContentHandler(gather)
    577 
    578         parser.feed("<doc/>")
    579         parser.close()
    580 
    581         self.verify_empty_nsattrs(gather._attrs)
    582 
    583     def test_expat_nsattrs_wattr(self):
    584         parser = create_parser(1)
    585         gather = self.AttrGatherer()
    586         parser.setContentHandler(gather)
    587 
    588         parser.feed("<doc xmlns:ns='%s' ns:attr='val'/>" % ns_uri)
    589         parser.close()
    590 
    591         attrs = gather._attrs
    592 
    593         self.assertEqual(attrs.getLength(), 1)
    594         self.assertEqual(attrs.getNames(), [(ns_uri, "attr")])
    595         self.assertTrue((attrs.getQNames() == [] or
    596                          attrs.getQNames() == ["ns:attr"]))
    597         self.assertEqual(len(attrs), 1)
    598         self.assertTrue(attrs.has_key((ns_uri, "attr")))
    599         self.assertEqual(attrs.get((ns_uri, "attr")), "val")
    600         self.assertEqual(attrs.get((ns_uri, "attr"), 25), "val")
    601         self.assertEqual(attrs.items(), [((ns_uri, "attr"), "val")])
    602         self.assertEqual(attrs.values(), ["val"])
    603         self.assertEqual(attrs.getValue((ns_uri, "attr")), "val")
    604         self.assertEqual(attrs[(ns_uri, "attr")], "val")
    605 
    606     # ===== InputSource support
    607 
    608     def test_expat_inpsource_filename(self):
    609         parser = create_parser()
    610         result = StringIO()
    611         xmlgen = XMLGenerator(result)
    612 
    613         parser.setContentHandler(xmlgen)
    614         parser.parse(TEST_XMLFILE)
    615 
    616         self.assertEqual(result.getvalue(), xml_test_out)
    617 
    618     def test_expat_inpsource_sysid(self):
    619         parser = create_parser()
    620         result = StringIO()
    621         xmlgen = XMLGenerator(result)
    622 
    623         parser.setContentHandler(xmlgen)
    624         parser.parse(InputSource(TEST_XMLFILE))
    625 
    626         self.assertEqual(result.getvalue(), xml_test_out)
    627 
    628     @requires_unicode_filenames
    629     def test_expat_inpsource_sysid_unicode(self):
    630         fname = support.TESTFN_UNICODE
    631         shutil.copyfile(TEST_XMLFILE, fname)
    632         self.addCleanup(support.unlink, fname)
    633 
    634         parser = create_parser()
    635         result = StringIO()
    636         xmlgen = XMLGenerator(result)
    637 
    638         parser.setContentHandler(xmlgen)
    639         parser.parse(InputSource(fname))
    640 
    641         self.assertEqual(result.getvalue(), xml_test_out)
    642 
    643     def test_expat_inpsource_stream(self):
    644         parser = create_parser()
    645         result = StringIO()
    646         xmlgen = XMLGenerator(result)
    647 
    648         parser.setContentHandler(xmlgen)
    649         inpsrc = InputSource()
    650         inpsrc.setByteStream(open(TEST_XMLFILE))
    651         parser.parse(inpsrc)
    652 
    653         self.assertEqual(result.getvalue(), xml_test_out)
    654 
    655     # ===== IncrementalParser support
    656 
    657     def test_expat_incremental(self):
    658         result = StringIO()
    659         xmlgen = XMLGenerator(result)
    660         parser = create_parser()
    661         parser.setContentHandler(xmlgen)
    662 
    663         parser.feed("<doc>")
    664         parser.feed("</doc>")
    665         parser.close()
    666 
    667         self.assertEqual(result.getvalue(), start + "<doc></doc>")
    668 
    669     def test_expat_incremental_reset(self):
    670         result = StringIO()
    671         xmlgen = XMLGenerator(result)
    672         parser = create_parser()
    673         parser.setContentHandler(xmlgen)
    674 
    675         parser.feed("<doc>")
    676         parser.feed("text")
    677 
    678         result = StringIO()
    679         xmlgen = XMLGenerator(result)
    680         parser.setContentHandler(xmlgen)
    681         parser.reset()
    682 
    683         parser.feed("<doc>")
    684         parser.feed("text")
    685         parser.feed("</doc>")
    686         parser.close()
    687 
    688         self.assertEqual(result.getvalue(), start + "<doc>text</doc>")
    689 
    690     # ===== Locator support
    691 
    692     def test_expat_locator_noinfo(self):
    693         result = StringIO()
    694         xmlgen = XMLGenerator(result)
    695         parser = create_parser()
    696         parser.setContentHandler(xmlgen)
    697 
    698         parser.feed("<doc>")
    699         parser.feed("</doc>")
    700         parser.close()
    701 
    702         self.assertEqual(parser.getSystemId(), None)
    703         self.assertEqual(parser.getPublicId(), None)
    704         self.assertEqual(parser.getLineNumber(), 1)
    705 
    706     def test_expat_locator_withinfo(self):
    707         result = StringIO()
    708         xmlgen = XMLGenerator(result)
    709         parser = create_parser()
    710         parser.setContentHandler(xmlgen)
    711         parser.parse(TEST_XMLFILE)
    712 
    713         self.assertEqual(parser.getSystemId(), TEST_XMLFILE)
    714         self.assertEqual(parser.getPublicId(), None)
    715 
    716     @requires_unicode_filenames
    717     def test_expat_locator_withinfo_unicode(self):
    718         fname = support.TESTFN_UNICODE
    719         shutil.copyfile(TEST_XMLFILE, fname)
    720         self.addCleanup(support.unlink, fname)
    721 
    722         result = StringIO()
    723         xmlgen = XMLGenerator(result)
    724         parser = create_parser()
    725         parser.setContentHandler(xmlgen)
    726         parser.parse(fname)
    727 
    728         self.assertEqual(parser.getSystemId(), fname)
    729         self.assertEqual(parser.getPublicId(), None)
    730 
    731 
    732 # ===========================================================================
    733 #
    734 #   error reporting
    735 #
    736 # ===========================================================================
    737 
    738 class ErrorReportingTest(unittest.TestCase):
    739     def test_expat_inpsource_location(self):
    740         parser = create_parser()
    741         parser.setContentHandler(ContentHandler()) # do nothing
    742         source = InputSource()
    743         source.setByteStream(StringIO("<foo bar foobar>"))   #ill-formed
    744         name = "a file name"
    745         source.setSystemId(name)
    746         try:
    747             parser.parse(source)
    748             self.fail()
    749         except SAXException, e:
    750             self.assertEqual(e.getSystemId(), name)
    751 
    752     def test_expat_incomplete(self):
    753         parser = create_parser()
    754         parser.setContentHandler(ContentHandler()) # do nothing
    755         self.assertRaises(SAXParseException, parser.parse, StringIO("<foo>"))
    756 
    757     def test_sax_parse_exception_str(self):
    758         # pass various values from a locator to the SAXParseException to
    759         # make sure that the __str__() doesn't fall apart when None is
    760         # passed instead of an integer line and column number
    761         #
    762         # use "normal" values for the locator:
    763         str(SAXParseException("message", None,
    764                               self.DummyLocator(1, 1)))
    765         # use None for the line number:
    766         str(SAXParseException("message", None,
    767                               self.DummyLocator(None, 1)))
    768         # use None for the column number:
    769         str(SAXParseException("message", None,
    770                               self.DummyLocator(1, None)))
    771         # use None for both:
    772         str(SAXParseException("message", None,
    773                               self.DummyLocator(None, None)))
    774 
    775     class DummyLocator:
    776         def __init__(self, lineno, colno):
    777             self._lineno = lineno
    778             self._colno = colno
    779 
    780         def getPublicId(self):
    781             return "pubid"
    782 
    783         def getSystemId(self):
    784             return "sysid"
    785 
    786         def getLineNumber(self):
    787             return self._lineno
    788 
    789         def getColumnNumber(self):
    790             return self._colno
    791 
    792 # ===========================================================================
    793 #
    794 #   xmlreader tests
    795 #
    796 # ===========================================================================
    797 
    798 class XmlReaderTest(XmlTestBase):
    799 
    800     # ===== AttributesImpl
    801     def test_attrs_empty(self):
    802         self.verify_empty_attrs(AttributesImpl({}))
    803 
    804     def test_attrs_wattr(self):
    805         self.verify_attrs_wattr(AttributesImpl({"attr" : "val"}))
    806 
    807     def test_nsattrs_empty(self):
    808         self.verify_empty_nsattrs(AttributesNSImpl({}, {}))
    809 
    810     def test_nsattrs_wattr(self):
    811         attrs = AttributesNSImpl({(ns_uri, "attr") : "val"},
    812                                  {(ns_uri, "attr") : "ns:attr"})
    813 
    814         self.assertEqual(attrs.getLength(), 1)
    815         self.assertEqual(attrs.getNames(), [(ns_uri, "attr")])
    816         self.assertEqual(attrs.getQNames(), ["ns:attr"])
    817         self.assertEqual(len(attrs), 1)
    818         self.assertTrue(attrs.has_key((ns_uri, "attr")))
    819         self.assertEqual(attrs.keys(), [(ns_uri, "attr")])
    820         self.assertEqual(attrs.get((ns_uri, "attr")), "val")
    821         self.assertEqual(attrs.get((ns_uri, "attr"), 25), "val")
    822         self.assertEqual(attrs.items(), [((ns_uri, "attr"), "val")])
    823         self.assertEqual(attrs.values(), ["val"])
    824         self.assertEqual(attrs.getValue((ns_uri, "attr")), "val")
    825         self.assertEqual(attrs.getValueByQName("ns:attr"), "val")
    826         self.assertEqual(attrs.getNameByQName("ns:attr"), (ns_uri, "attr"))
    827         self.assertEqual(attrs[(ns_uri, "attr")], "val")
    828         self.assertEqual(attrs.getQNameByName((ns_uri, "attr")), "ns:attr")
    829 
    830 
    831     # During the development of Python 2.5, an attempt to move the "xml"
    832     # package implementation to a new package ("xmlcore") proved painful.
    833     # The goal of this change was to allow applications to be able to
    834     # obtain and rely on behavior in the standard library implementation
    835     # of the XML support without needing to be concerned about the
    836     # availability of the PyXML implementation.
    837     #
    838     # While the existing import hackery in Lib/xml/__init__.py can cause
    839     # PyXML's _xmlpus package to supplant the "xml" package, that only
    840     # works because either implementation uses the "xml" package name for
    841     # imports.
    842     #
    843     # The move resulted in a number of problems related to the fact that
    844     # the import machinery's "package context" is based on the name that's
    845     # being imported rather than the __name__ of the actual package
    846     # containment; it wasn't possible for the "xml" package to be replaced
    847     # by a simple module that indirected imports to the "xmlcore" package.
    848     #
    849     # The following two tests exercised bugs that were introduced in that
    850     # attempt.  Keeping these tests around will help detect problems with
    851     # other attempts to provide reliable access to the standard library's
    852     # implementation of the XML support.
    853 
    854     def test_sf_1511497(self):
    855         # Bug report: http://www.python.org/sf/1511497
    856         import sys
    857         old_modules = sys.modules.copy()
    858         for modname in sys.modules.keys():
    859             if modname.startswith("xml."):
    860                 del sys.modules[modname]
    861         try:
    862             import xml.sax.expatreader
    863             module = xml.sax.expatreader
    864             self.assertEqual(module.__name__, "xml.sax.expatreader")
    865         finally:
    866             sys.modules.update(old_modules)
    867 
    868     def test_sf_1513611(self):
    869         # Bug report: http://www.python.org/sf/1513611
    870         sio = StringIO("invalid")
    871         parser = make_parser()
    872         from xml.sax import SAXParseException
    873         self.assertRaises(SAXParseException, parser.parse, sio)
    874 
    875 
    876 def test_main():
    877     run_unittest(MakeParserTest,
    878                  SaxutilsTest,
    879                  StringXmlgenTest,
    880                  BytesIOXmlgenTest,
    881                  WriterXmlgenTest,
    882                  ExpatReaderTest,
    883                  ErrorReportingTest,
    884                  XmlReaderTest)
    885 
    886 if __name__ == "__main__":
    887     test_main()
    888