Home | History | Annotate | Download | only in tests
      1 from __future__ import absolute_import, division, unicode_literals
      2 
      3 import os
      4 import sys
      5 import codecs
      6 import glob
      7 import xml.sax.handler
      8 
      9 base_path = os.path.split(__file__)[0]
     10 
     11 test_dir = os.path.join(base_path, 'testdata')
     12 sys.path.insert(0, os.path.abspath(os.path.join(base_path,
     13                                                 os.path.pardir,
     14                                                 os.path.pardir)))
     15 
     16 from html5lib import treebuilders
     17 del base_path
     18 
     19 # Build a dict of avaliable trees
     20 treeTypes = {"DOM": treebuilders.getTreeBuilder("dom")}
     21 
     22 # Try whatever etree implementations are avaliable from a list that are
     23 #"supposed" to work
     24 try:
     25     import xml.etree.ElementTree as ElementTree
     26     treeTypes['ElementTree'] = treebuilders.getTreeBuilder("etree", ElementTree, fullTree=True)
     27 except ImportError:
     28     try:
     29         import elementtree.ElementTree as ElementTree
     30         treeTypes['ElementTree'] = treebuilders.getTreeBuilder("etree", ElementTree, fullTree=True)
     31     except ImportError:
     32         pass
     33 
     34 try:
     35     import xml.etree.cElementTree as cElementTree
     36     treeTypes['cElementTree'] = treebuilders.getTreeBuilder("etree", cElementTree, fullTree=True)
     37 except ImportError:
     38     try:
     39         import cElementTree
     40         treeTypes['cElementTree'] = treebuilders.getTreeBuilder("etree", cElementTree, fullTree=True)
     41     except ImportError:
     42         pass
     43 
     44 try:
     45     import lxml.etree as lxml  # flake8: noqa
     46 except ImportError:
     47     pass
     48 else:
     49     treeTypes['lxml'] = treebuilders.getTreeBuilder("lxml")
     50 
     51 
     52 def get_data_files(subdirectory, files='*.dat'):
     53     return glob.glob(os.path.join(test_dir, subdirectory, files))
     54 
     55 
     56 class DefaultDict(dict):
     57     def __init__(self, default, *args, **kwargs):
     58         self.default = default
     59         dict.__init__(self, *args, **kwargs)
     60 
     61     def __getitem__(self, key):
     62         return dict.get(self, key, self.default)
     63 
     64 
     65 class TestData(object):
     66     def __init__(self, filename, newTestHeading="data", encoding="utf8"):
     67         if encoding is None:
     68             self.f = open(filename, mode="rb")
     69         else:
     70             self.f = codecs.open(filename, encoding=encoding)
     71         self.encoding = encoding
     72         self.newTestHeading = newTestHeading
     73 
     74     def __del__(self):
     75         self.f.close()
     76 
     77     def __iter__(self):
     78         data = DefaultDict(None)
     79         key = None
     80         for line in self.f:
     81             heading = self.isSectionHeading(line)
     82             if heading:
     83                 if data and heading == self.newTestHeading:
     84                     # Remove trailing newline
     85                     data[key] = data[key][:-1]
     86                     yield self.normaliseOutput(data)
     87                     data = DefaultDict(None)
     88                 key = heading
     89                 data[key] = "" if self.encoding else b""
     90             elif key is not None:
     91                 data[key] += line
     92         if data:
     93             yield self.normaliseOutput(data)
     94 
     95     def isSectionHeading(self, line):
     96         """If the current heading is a test section heading return the heading,
     97         otherwise return False"""
     98         # print(line)
     99         if line.startswith("#" if self.encoding else b"#"):
    100             return line[1:].strip()
    101         else:
    102             return False
    103 
    104     def normaliseOutput(self, data):
    105         # Remove trailing newlines
    106         for key, value in data.items():
    107             if value.endswith("\n" if self.encoding else b"\n"):
    108                 data[key] = value[:-1]
    109         return data
    110 
    111 
    112 def convert(stripChars):
    113     def convertData(data):
    114         """convert the output of str(document) to the format used in the testcases"""
    115         data = data.split("\n")
    116         rv = []
    117         for line in data:
    118             if line.startswith("|"):
    119                 rv.append(line[stripChars:])
    120             else:
    121                 rv.append(line)
    122         return "\n".join(rv)
    123     return convertData
    124 
    125 convertExpected = convert(2)
    126 
    127 
    128 def errorMessage(input, expected, actual):
    129     msg = ("Input:\n%s\nExpected:\n%s\nRecieved\n%s\n" %
    130            (repr(input), repr(expected), repr(actual)))
    131     if sys.version_info.major == 2:
    132         msg = msg.encode("ascii", "backslashreplace")
    133     return msg
    134 
    135 
    136 class TracingSaxHandler(xml.sax.handler.ContentHandler):
    137     def __init__(self):
    138         xml.sax.handler.ContentHandler.__init__(self)
    139         self.visited = []
    140 
    141     def startDocument(self):
    142         self.visited.append('startDocument')
    143 
    144     def endDocument(self):
    145         self.visited.append('endDocument')
    146 
    147     def startPrefixMapping(self, prefix, uri):
    148         # These are ignored as their order is not guaranteed
    149         pass
    150 
    151     def endPrefixMapping(self, prefix):
    152         # These are ignored as their order is not guaranteed
    153         pass
    154 
    155     def startElement(self, name, attrs):
    156         self.visited.append(('startElement', name, attrs))
    157 
    158     def endElement(self, name):
    159         self.visited.append(('endElement', name))
    160 
    161     def startElementNS(self, name, qname, attrs):
    162         self.visited.append(('startElementNS', name, qname, dict(attrs)))
    163 
    164     def endElementNS(self, name, qname):
    165         self.visited.append(('endElementNS', name, qname))
    166 
    167     def characters(self, content):
    168         self.visited.append(('characters', content))
    169 
    170     def ignorableWhitespace(self, whitespace):
    171         self.visited.append(('ignorableWhitespace', whitespace))
    172 
    173     def processingInstruction(self, target, data):
    174         self.visited.append(('processingInstruction', target, data))
    175 
    176     def skippedEntity(self, name):
    177         self.visited.append(('skippedEntity', name))
    178