Home | History | Annotate | Download | only in sax
      1 """\
      2 A library of useful helper classes to the SAX classes, for the
      3 convenience of application and driver writers.
      4 """
      5 
      6 import os, urlparse, urllib, types
      7 import io
      8 import sys
      9 import handler
     10 import xmlreader
     11 
     12 try:
     13     _StringTypes = [types.StringType, types.UnicodeType]
     14 except AttributeError:
     15     _StringTypes = [types.StringType]
     16 
     17 def __dict_replace(s, d):
     18     """Replace substrings of a string using a dictionary."""
     19     for key, value in d.items():
     20         s = s.replace(key, value)
     21     return s
     22 
     23 def escape(data, entities={}):
     24     """Escape &, <, and > in a string of data.
     25 
     26     You can escape other strings of data by passing a dictionary as
     27     the optional entities parameter.  The keys and values must all be
     28     strings; each key will be replaced with its corresponding value.
     29     """
     30 
     31     # must do ampersand first

     32     data = data.replace("&", "&amp;")
     33     data = data.replace(">", "&gt;")
     34     data = data.replace("<", "&lt;")
     35     if entities:
     36         data = __dict_replace(data, entities)
     37     return data
     38 
     39 def unescape(data, entities={}):
     40     """Unescape &amp;, &lt;, and &gt; in a string of data.
     41 
     42     You can unescape other strings of data by passing a dictionary as
     43     the optional entities parameter.  The keys and values must all be
     44     strings; each key will be replaced with its corresponding value.
     45     """
     46     data = data.replace("&lt;", "<")
     47     data = data.replace("&gt;", ">")
     48     if entities:
     49         data = __dict_replace(data, entities)
     50     # must do ampersand last

     51     return data.replace("&amp;", "&")
     52 
     53 def quoteattr(data, entities={}):
     54     """Escape and quote an attribute value.
     55 
     56     Escape &, <, and > in a string of data, then quote it for use as
     57     an attribute value.  The \" character will be escaped as well, if
     58     necessary.
     59 
     60     You can escape other strings of data by passing a dictionary as
     61     the optional entities parameter.  The keys and values must all be
     62     strings; each key will be replaced with its corresponding value.
     63     """
     64     entities = entities.copy()
     65     entities.update({'\n': '&#10;', '\r': '&#13;', '\t':'&#9;'})
     66     data = escape(data, entities)
     67     if '"' in data:
     68         if "'" in data:
     69             data = '"%s"' % data.replace('"', "&quot;")
     70         else:
     71             data = "'%s'" % data
     72     else:
     73         data = '"%s"' % data
     74     return data
     75 
     76 
     77 def _gettextwriter(out, encoding):
     78     if out is None:
     79         import sys
     80         out = sys.stdout
     81 
     82     if isinstance(out, io.RawIOBase):
     83         buffer = io.BufferedIOBase(out)
     84         # Keep the original file open when the TextIOWrapper is

     85         # destroyed

     86         buffer.close = lambda: None
     87     else:
     88         # This is to handle passed objects that aren't in the

     89         # IOBase hierarchy, but just have a write method

     90         buffer = io.BufferedIOBase()
     91         buffer.writable = lambda: True
     92         buffer.write = out.write
     93         try:
     94             # TextIOWrapper uses this methods to determine

     95             # if BOM (for UTF-16, etc) should be added

     96             buffer.seekable = out.seekable
     97             buffer.tell = out.tell
     98         except AttributeError:
     99             pass
    100     # wrap a binary writer with TextIOWrapper

    101     return _UnbufferedTextIOWrapper(buffer, encoding=encoding,
    102                                    errors='xmlcharrefreplace',
    103                                    newline='\n')
    104 
    105 
    106 class _UnbufferedTextIOWrapper(io.TextIOWrapper):
    107     def write(self, s):
    108         super(_UnbufferedTextIOWrapper, self).write(s)
    109         self.flush()
    110 
    111 
    112 class XMLGenerator(handler.ContentHandler):
    113 
    114     def __init__(self, out=None, encoding="iso-8859-1"):
    115         handler.ContentHandler.__init__(self)
    116         out = _gettextwriter(out, encoding)
    117         self._write = out.write
    118         self._flush = out.flush
    119         self._ns_contexts = [{}] # contains uri -> prefix dicts

    120         self._current_context = self._ns_contexts[-1]
    121         self._undeclared_ns_maps = []
    122         self._encoding = encoding
    123 
    124     def _qname(self, name):
    125         """Builds a qualified name from a (ns_url, localname) pair"""
    126         if name[0]:
    127             # Per http://www.w3.org/XML/1998/namespace, The 'xml' prefix is

    128             # bound by definition to http://www.w3.org/XML/1998/namespace.  It

    129             # does not need to be declared and will not usually be found in

    130             # self._current_context.

    131             if 'http://www.w3.org/XML/1998/namespace' == name[0]:
    132                 return 'xml:' + name[1]
    133             # The name is in a non-empty namespace

    134             prefix = self._current_context[name[0]]
    135             if prefix:
    136                 # If it is not the default namespace, prepend the prefix

    137                 return prefix + ":" + name[1]
    138         # Return the unqualified name

    139         return name[1]
    140 
    141     # ContentHandler methods

    142 
    143     def startDocument(self):
    144         self._write(u'<?xml version="1.0" encoding="%s"?>\n' %
    145                         self._encoding)
    146 
    147     def endDocument(self):
    148         self._flush()
    149 
    150     def startPrefixMapping(self, prefix, uri):
    151         self._ns_contexts.append(self._current_context.copy())
    152         self._current_context[uri] = prefix
    153         self._undeclared_ns_maps.append((prefix, uri))
    154 
    155     def endPrefixMapping(self, prefix):
    156         self._current_context = self._ns_contexts[-1]
    157         del self._ns_contexts[-1]
    158 
    159     def startElement(self, name, attrs):
    160         self._write(u'<' + name)
    161         for (name, value) in attrs.items():
    162             self._write(u' %s=%s' % (name, quoteattr(value)))
    163         self._write(u'>')
    164 
    165     def endElement(self, name):
    166         self._write(u'</%s>' % name)
    167 
    168     def startElementNS(self, name, qname, attrs):
    169         self._write(u'<' + self._qname(name))
    170 
    171         for prefix, uri in self._undeclared_ns_maps:
    172             if prefix:
    173                 self._write(u' xmlns:%s="%s"' % (prefix, uri))
    174             else:
    175                 self._write(u' xmlns="%s"' % uri)
    176         self._undeclared_ns_maps = []
    177 
    178         for (name, value) in attrs.items():
    179             self._write(u' %s=%s' % (self._qname(name), quoteattr(value)))
    180         self._write(u'>')
    181 
    182     def endElementNS(self, name, qname):
    183         self._write(u'</%s>' % self._qname(name))
    184 
    185     def characters(self, content):
    186         if not isinstance(content, unicode):
    187             content = unicode(content, self._encoding)
    188         self._write(escape(content))
    189 
    190     def ignorableWhitespace(self, content):
    191         if not isinstance(content, unicode):
    192             content = unicode(content, self._encoding)
    193         self._write(content)
    194 
    195     def processingInstruction(self, target, data):
    196         self._write(u'<?%s %s?>' % (target, data))
    197 
    198 
    199 class XMLFilterBase(xmlreader.XMLReader):
    200     """This class is designed to sit between an XMLReader and the
    201     client application's event handlers.  By default, it does nothing
    202     but pass requests up to the reader and events on to the handlers
    203     unmodified, but subclasses can override specific methods to modify
    204     the event stream or the configuration requests as they pass
    205     through."""
    206 
    207     def __init__(self, parent = None):
    208         xmlreader.XMLReader.__init__(self)
    209         self._parent = parent
    210 
    211     # ErrorHandler methods

    212 
    213     def error(self, exception):
    214         self._err_handler.error(exception)
    215 
    216     def fatalError(self, exception):
    217         self._err_handler.fatalError(exception)
    218 
    219     def warning(self, exception):
    220         self._err_handler.warning(exception)
    221 
    222     # ContentHandler methods

    223 
    224     def setDocumentLocator(self, locator):
    225         self._cont_handler.setDocumentLocator(locator)
    226 
    227     def startDocument(self):
    228         self._cont_handler.startDocument()
    229 
    230     def endDocument(self):
    231         self._cont_handler.endDocument()
    232 
    233     def startPrefixMapping(self, prefix, uri):
    234         self._cont_handler.startPrefixMapping(prefix, uri)
    235 
    236     def endPrefixMapping(self, prefix):
    237         self._cont_handler.endPrefixMapping(prefix)
    238 
    239     def startElement(self, name, attrs):
    240         self._cont_handler.startElement(name, attrs)
    241 
    242     def endElement(self, name):
    243         self._cont_handler.endElement(name)
    244 
    245     def startElementNS(self, name, qname, attrs):
    246         self._cont_handler.startElementNS(name, qname, attrs)
    247 
    248     def endElementNS(self, name, qname):
    249         self._cont_handler.endElementNS(name, qname)
    250 
    251     def characters(self, content):
    252         self._cont_handler.characters(content)
    253 
    254     def ignorableWhitespace(self, chars):
    255         self._cont_handler.ignorableWhitespace(chars)
    256 
    257     def processingInstruction(self, target, data):
    258         self._cont_handler.processingInstruction(target, data)
    259 
    260     def skippedEntity(self, name):
    261         self._cont_handler.skippedEntity(name)
    262 
    263     # DTDHandler methods

    264 
    265     def notationDecl(self, name, publicId, systemId):
    266         self._dtd_handler.notationDecl(name, publicId, systemId)
    267 
    268     def unparsedEntityDecl(self, name, publicId, systemId, ndata):
    269         self._dtd_handler.unparsedEntityDecl(name, publicId, systemId, ndata)
    270 
    271     # EntityResolver methods

    272 
    273     def resolveEntity(self, publicId, systemId):
    274         return self._ent_handler.resolveEntity(publicId, systemId)
    275 
    276     # XMLReader methods

    277 
    278     def parse(self, source):
    279         self._parent.setContentHandler(self)
    280         self._parent.setErrorHandler(self)
    281         self._parent.setEntityResolver(self)
    282         self._parent.setDTDHandler(self)
    283         self._parent.parse(source)
    284 
    285     def setLocale(self, locale):
    286         self._parent.setLocale(locale)
    287 
    288     def getFeature(self, name):
    289         return self._parent.getFeature(name)
    290 
    291     def setFeature(self, name, state):
    292         self._parent.setFeature(name, state)
    293 
    294     def getProperty(self, name):
    295         return self._parent.getProperty(name)
    296 
    297     def setProperty(self, name, value):
    298         self._parent.setProperty(name, value)
    299 
    300     # XMLFilter methods

    301 
    302     def getParent(self):
    303         return self._parent
    304 
    305     def setParent(self, parent):
    306         self._parent = parent
    307 
    308 # --- Utility functions

    309 
    310 def prepare_input_source(source, base = ""):
    311     """This function takes an InputSource and an optional base URL and
    312     returns a fully resolved InputSource object ready for reading."""
    313 
    314     if type(source) in _StringTypes:
    315         source = xmlreader.InputSource(source)
    316     elif hasattr(source, "read"):
    317         f = source
    318         source = xmlreader.InputSource()
    319         source.setByteStream(f)
    320         if hasattr(f, "name"):
    321             source.setSystemId(f.name)
    322 
    323     if source.getByteStream() is None:
    324         try:
    325             sysid = source.getSystemId()
    326             basehead = os.path.dirname(os.path.normpath(base))
    327             encoding = sys.getfilesystemencoding()
    328             if isinstance(sysid, unicode):
    329                 if not isinstance(basehead, unicode):
    330                     try:
    331                         basehead = basehead.decode(encoding)
    332                     except UnicodeDecodeError:
    333                         sysid = sysid.encode(encoding)
    334             else:
    335                 if isinstance(basehead, unicode):
    336                     try:
    337                         sysid = sysid.decode(encoding)
    338                     except UnicodeDecodeError:
    339                         basehead = basehead.encode(encoding)
    340             sysidfilename = os.path.join(basehead, sysid)
    341             isfile = os.path.isfile(sysidfilename)
    342         except UnicodeError:
    343             isfile = False
    344         if isfile:
    345             source.setSystemId(sysidfilename)
    346             f = open(sysidfilename, "rb")
    347         else:
    348             source.setSystemId(urlparse.urljoin(base, source.getSystemId()))
    349             f = urllib.urlopen(source.getSystemId())
    350 
    351         source.setByteStream(f)
    352 
    353     return source
    354