Home | History | Annotate | Download | only in sax
      1 """\
      2 A library of useful helper classes to the SAX classes, for the
      3 convenience of application and driver writers.
      4 """
      5 
      6 import os, urllib.parse, urllib.request
      7 import io
      8 import codecs
      9 from . import handler
     10 from . import xmlreader
     11 
     12 def __dict_replace(s, d):
     13     """Replace substrings of a string using a dictionary."""
     14     for key, value in d.items():
     15         s = s.replace(key, value)
     16     return s
     17 
     18 def escape(data, entities={}):
     19     """Escape &, <, and > in a string of data.
     20 
     21     You can escape other strings of data by passing a dictionary as
     22     the optional entities parameter.  The keys and values must all be
     23     strings; each key will be replaced with its corresponding value.
     24     """
     25 
     26     # must do ampersand first
     27     data = data.replace("&", "&amp;")
     28     data = data.replace(">", "&gt;")
     29     data = data.replace("<", "&lt;")
     30     if entities:
     31         data = __dict_replace(data, entities)
     32     return data
     33 
     34 def unescape(data, entities={}):
     35     """Unescape &amp;, &lt;, and &gt; in a string of data.
     36 
     37     You can unescape other strings of data by passing a dictionary as
     38     the optional entities parameter.  The keys and values must all be
     39     strings; each key will be replaced with its corresponding value.
     40     """
     41     data = data.replace("&lt;", "<")
     42     data = data.replace("&gt;", ">")
     43     if entities:
     44         data = __dict_replace(data, entities)
     45     # must do ampersand last
     46     return data.replace("&amp;", "&")
     47 
     48 def quoteattr(data, entities={}):
     49     """Escape and quote an attribute value.
     50 
     51     Escape &, <, and > in a string of data, then quote it for use as
     52     an attribute value.  The \" character will be escaped as well, if
     53     necessary.
     54 
     55     You can escape other strings of data by passing a dictionary as
     56     the optional entities parameter.  The keys and values must all be
     57     strings; each key will be replaced with its corresponding value.
     58     """
     59     entities = entities.copy()
     60     entities.update({'\n': '&#10;', '\r': '&#13;', '\t':'&#9;'})
     61     data = escape(data, entities)
     62     if '"' in data:
     63         if "'" in data:
     64             data = '"%s"' % data.replace('"', "&quot;")
     65         else:
     66             data = "'%s'" % data
     67     else:
     68         data = '"%s"' % data
     69     return data
     70 
     71 
def _gettextwriter(out, encoding):
    """Return a text writer whose output ends up in *out*.

    - ``None`` selects ``sys.stdout``.
    - ``io.TextIOBase`` instances and ``codecs`` StreamWriter /
      StreamReaderWriter objects are returned unchanged; *encoding* is not
      applied to them.
    - Anything else is treated as a binary sink and wrapped in an
      ``io.TextIOWrapper`` that encodes with *encoding*, substitutes XML
      character references for unencodable characters
      (``errors='xmlcharrefreplace'``), uses ``newline='\\n'`` and is
      created with ``write_through=True``.

    In the wrapped cases *out* itself is never given to the TextIOWrapper,
    so closing the returned wrapper does not close *out*.
    """
    if out is None:
        import sys
        return sys.stdout

    if isinstance(out, io.TextIOBase):
        # use a text writer as is
        return out

    if isinstance(out, (codecs.StreamWriter, codecs.StreamReaderWriter)):
        # use a codecs stream writer as is
        return out

    # wrap a binary writer with TextIOWrapper
    if isinstance(out, io.RawIOBase):
        # Keep the original file open when the TextIOWrapper is
        # destroyed: the wrapper forwards every attribute lookup to *out*
        # except close, which is replaced by a no-op below.
        class _wrapper:
            # NOTE(review): presumably makes the proxy look like an instance
            # of out's type to isinstance() checks; confirm which consumer
            # relies on this.
            __class__ = out.__class__
            def __getattr__(self, name):
                return getattr(out, name)
        buffer = _wrapper()
        buffer.close = lambda: None
    else:
        # Reached for every remaining object: buffered binary streams
        # (e.g. io.BytesIO, io.BufferedWriter) as well as objects outside
        # the io hierarchy that only provide a write method.  A fresh
        # BufferedIOBase borrows out's write (and seekable/tell when
        # available); *out* itself is not handed to the TextIOWrapper.
        buffer = io.BufferedIOBase()
        buffer.writable = lambda: True
        buffer.write = out.write
        try:
            # TextIOWrapper uses these methods to determine
            # if a BOM (for UTF-16, etc) should be added.
            # If out has seekable but no tell, only seekable is copied.
            buffer.seekable = out.seekable
            buffer.tell = out.tell
        except AttributeError:
            pass
    return io.TextIOWrapper(buffer, encoding=encoding,
                            errors='xmlcharrefreplace',
                            newline='\n',
                            write_through=True)
    112 
    113 class XMLGenerator(handler.ContentHandler):
    114 
    115     def __init__(self, out=None, encoding="iso-8859-1", short_empty_elements=False):
    116         handler.ContentHandler.__init__(self)
    117         out = _gettextwriter(out, encoding)
    118         self._write = out.write
    119         self._flush = out.flush
    120         self._ns_contexts = [{}] # contains uri -> prefix dicts
    121         self._current_context = self._ns_contexts[-1]
    122         self._undeclared_ns_maps = []
    123         self._encoding = encoding
    124         self._short_empty_elements = short_empty_elements
    125         self._pending_start_element = False
    126 
    127     def _qname(self, name):
    128         """Builds a qualified name from a (ns_url, localname) pair"""
    129         if name[0]:
    130             # Per http://www.w3.org/XML/1998/namespace, The 'xml' prefix is
    131             # bound by definition to http://www.w3.org/XML/1998/namespace.  It
    132             # does not need to be declared and will not usually be found in
    133             # self._current_context.
    134             if 'http://www.w3.org/XML/1998/namespace' == name[0]:
    135                 return 'xml:' + name[1]
    136             # The name is in a non-empty namespace
    137             prefix = self._current_context[name[0]]
    138             if prefix:
    139                 # If it is not the default namespace, prepend the prefix
    140                 return prefix + ":" + name[1]
    141         # Return the unqualified name
    142         return name[1]
    143 
    144     def _finish_pending_start_element(self,endElement=False):
    145         if self._pending_start_element:
    146             self._write('>')
    147             self._pending_start_element = False
    148 
    149     # ContentHandler methods
    150 
    151     def startDocument(self):
    152         self._write('<?xml version="1.0" encoding="%s"?>\n' %
    153                         self._encoding)
    154 
    155     def endDocument(self):
    156         self._flush()
    157 
    158     def startPrefixMapping(self, prefix, uri):
    159         self._ns_contexts.append(self._current_context.copy())
    160         self._current_context[uri] = prefix
    161         self._undeclared_ns_maps.append((prefix, uri))
    162 
    163     def endPrefixMapping(self, prefix):
    164         self._current_context = self._ns_contexts[-1]
    165         del self._ns_contexts[-1]
    166 
    167     def startElement(self, name, attrs):
    168         self._finish_pending_start_element()
    169         self._write('<' + name)
    170         for (name, value) in attrs.items():
    171             self._write(' %s=%s' % (name, quoteattr(value)))
    172         if self._short_empty_elements:
    173             self._pending_start_element = True
    174         else:
    175             self._write(">")
    176 
    177     def endElement(self, name):
    178         if self._pending_start_element:
    179             self._write('/>')
    180             self._pending_start_element = False
    181         else:
    182             self._write('</%s>' % name)
    183 
    184     def startElementNS(self, name, qname, attrs):
    185         self._finish_pending_start_element()
    186         self._write('<' + self._qname(name))
    187 
    188         for prefix, uri in self._undeclared_ns_maps:
    189             if prefix:
    190                 self._write(' xmlns:%s="%s"' % (prefix, uri))
    191             else:
    192                 self._write(' xmlns="%s"' % uri)
    193         self._undeclared_ns_maps = []
    194 
    195         for (name, value) in attrs.items():
    196             self._write(' %s=%s' % (self._qname(name), quoteattr(value)))
    197         if self._short_empty_elements:
    198             self._pending_start_element = True
    199         else:
    200             self._write(">")
    201 
    202     def endElementNS(self, name, qname):
    203         if self._pending_start_element:
    204             self._write('/>')
    205             self._pending_start_element = False
    206         else:
    207             self._write('</%s>' % self._qname(name))
    208 
    209     def characters(self, content):
    210         if content:
    211             self._finish_pending_start_element()
    212             if not isinstance(content, str):
    213                 content = str(content, self._encoding)
    214             self._write(escape(content))
    215 
    216     def ignorableWhitespace(self, content):
    217         if content:
    218             self._finish_pending_start_element()
    219             if not isinstance(content, str):
    220                 content = str(content, self._encoding)
    221             self._write(content)
    222 
    223     def processingInstruction(self, target, data):
    224         self._finish_pending_start_element()
    225         self._write('<?%s %s?>' % (target, data))
    226 
    227 
    228 class XMLFilterBase(xmlreader.XMLReader):
    229     """This class is designed to sit between an XMLReader and the
    230     client application's event handlers.  By default, it does nothing
    231     but pass requests up to the reader and events on to the handlers
    232     unmodified, but subclasses can override specific methods to modify
    233     the event stream or the configuration requests as they pass
    234     through."""
    235 
    236     def __init__(self, parent = None):
    237         xmlreader.XMLReader.__init__(self)
    238         self._parent = parent
    239 
    240     # ErrorHandler methods
    241 
    242     def error(self, exception):
    243         self._err_handler.error(exception)
    244 
    245     def fatalError(self, exception):
    246         self._err_handler.fatalError(exception)
    247 
    248     def warning(self, exception):
    249         self._err_handler.warning(exception)
    250 
    251     # ContentHandler methods
    252 
    253     def setDocumentLocator(self, locator):
    254         self._cont_handler.setDocumentLocator(locator)
    255 
    256     def startDocument(self):
    257         self._cont_handler.startDocument()
    258 
    259     def endDocument(self):
    260         self._cont_handler.endDocument()
    261 
    262     def startPrefixMapping(self, prefix, uri):
    263         self._cont_handler.startPrefixMapping(prefix, uri)
    264 
    265     def endPrefixMapping(self, prefix):
    266         self._cont_handler.endPrefixMapping(prefix)
    267 
    268     def startElement(self, name, attrs):
    269         self._cont_handler.startElement(name, attrs)
    270 
    271     def endElement(self, name):
    272         self._cont_handler.endElement(name)
    273 
    274     def startElementNS(self, name, qname, attrs):
    275         self._cont_handler.startElementNS(name, qname, attrs)
    276 
    277     def endElementNS(self, name, qname):
    278         self._cont_handler.endElementNS(name, qname)
    279 
    280     def characters(self, content):
    281         self._cont_handler.characters(content)
    282 
    283     def ignorableWhitespace(self, chars):
    284         self._cont_handler.ignorableWhitespace(chars)
    285 
    286     def processingInstruction(self, target, data):
    287         self._cont_handler.processingInstruction(target, data)
    288 
    289     def skippedEntity(self, name):
    290         self._cont_handler.skippedEntity(name)
    291 
    292     # DTDHandler methods
    293 
    294     def notationDecl(self, name, publicId, systemId):
    295         self._dtd_handler.notationDecl(name, publicId, systemId)
    296 
    297     def unparsedEntityDecl(self, name, publicId, systemId, ndata):
    298         self._dtd_handler.unparsedEntityDecl(name, publicId, systemId, ndata)
    299 
    300     # EntityResolver methods
    301 
    302     def resolveEntity(self, publicId, systemId):
    303         return self._ent_handler.resolveEntity(publicId, systemId)
    304 
    305     # XMLReader methods
    306 
    307     def parse(self, source):
    308         self._parent.setContentHandler(self)
    309         self._parent.setErrorHandler(self)
    310         self._parent.setEntityResolver(self)
    311         self._parent.setDTDHandler(self)
    312         self._parent.parse(source)
    313 
    314     def setLocale(self, locale):
    315         self._parent.setLocale(locale)
    316 
    317     def getFeature(self, name):
    318         return self._parent.getFeature(name)
    319 
    320     def setFeature(self, name, state):
    321         self._parent.setFeature(name, state)
    322 
    323     def getProperty(self, name):
    324         return self._parent.getProperty(name)
    325 
    326     def setProperty(self, name, value):
    327         self._parent.setProperty(name, value)
    328 
    329     # XMLFilter methods
    330 
    331     def getParent(self):
    332         return self._parent
    333 
    334     def setParent(self, parent):
    335         self._parent = parent
    336 
    337 # --- Utility functions
    338 
    339 def prepare_input_source(source, base=""):
    340     """This function takes an InputSource and an optional base URL and
    341     returns a fully resolved InputSource object ready for reading."""
    342 
    343     if isinstance(source, str):
    344         source = xmlreader.InputSource(source)
    345     elif hasattr(source, "read"):
    346         f = source
    347         source = xmlreader.InputSource()
    348         if isinstance(f.read(0), str):
    349             source.setCharacterStream(f)
    350         else:
    351             source.setByteStream(f)
    352         if hasattr(f, "name") and isinstance(f.name, str):
    353             source.setSystemId(f.name)
    354 
    355     if source.getCharacterStream() is None and source.getByteStream() is None:
    356         sysid = source.getSystemId()
    357         basehead = os.path.dirname(os.path.normpath(base))
    358         sysidfilename = os.path.join(basehead, sysid)
    359         if os.path.isfile(sysidfilename):
    360             source.setSystemId(sysidfilename)
    361             f = open(sysidfilename, "rb")
    362         else:
    363             source.setSystemId(urllib.parse.urljoin(base, sysid))
    364             f = urllib.request.urlopen(source.getSystemId())
    365 
    366         source.setByteStream(f)
    367 
    368     return source
    369