Home | History | Annotate | Download | only in webob
      1 # (c) 2005 Ian Bicking and contributors; written for Paste
      2 # (http://pythonpaste.org) Licensed under the MIT license:
      3 # http://www.opensource.org/licenses/mit-license.php
      4 """
      5 Gives a multi-value dictionary object (MultiDict) plus several wrappers
      6 """
import binascii
import warnings

# The collections ABCs moved to collections.abc in Python 3.3 and the old
# aliases in collections were removed in Python 3.10.
try:
    from collections.abc import MutableMapping
except ImportError:  # pragma: no cover - Python 2 has no collections.abc
    from collections import MutableMapping

from webob.compat import (
    PY3,
    iteritems_,
    itervalues_,
    url_encode,
    )
     18 
# Public names exported by ``from ... import *``.
__all__ = ['MultiDict', 'NestedMultiDict', 'NoVars', 'GetDict']
     20 
     21 class MultiDict(MutableMapping):
     22     """
     23         An ordered dictionary that can have multiple values for each key.
     24         Adds the methods getall, getone, mixed and extend and add to the normal
     25         dictionary interface.
     26     """
     27 
     28     def __init__(self, *args, **kw):
     29         if len(args) > 1:
     30             raise TypeError("MultiDict can only be called with one positional "
     31                             "argument")
     32         if args:
     33             if hasattr(args[0], 'iteritems'):
     34                 items = list(args[0].iteritems())
     35             elif hasattr(args[0], 'items'):
     36                 items = list(args[0].items())
     37             else:
     38                 items = list(args[0])
     39             self._items = items
     40         else:
     41             self._items = []
     42         if kw:
     43             self._items.extend(kw.items())
     44 
     45     @classmethod
     46     def view_list(cls, lst):
     47         """
     48         Create a dict that is a view on the given list
     49         """
     50         if not isinstance(lst, list):
     51             raise TypeError(
     52                 "%s.view_list(obj) takes only actual list objects, not %r"
     53                 % (cls.__name__, lst))
     54         obj = cls()
     55         obj._items = lst
     56         return obj
     57 
     58     @classmethod
     59     def from_fieldstorage(cls, fs):
     60         """
     61         Create a dict from a cgi.FieldStorage instance
     62         """
     63         obj = cls()
     64         # fs.list can be None when there's nothing to parse
     65         for field in fs.list or ():
     66             charset = field.type_options.get('charset', 'utf8')
     67             transfer_encoding = field.headers.get('Content-Transfer-Encoding', None)
     68             supported_transfer_encoding = {
     69                 'base64' : binascii.a2b_base64,
     70                 'quoted-printable' : binascii.a2b_qp
     71                 }
     72             if PY3: # pragma: no cover
     73                 if charset == 'utf8':
     74                     decode = lambda b: b
     75                 else:
     76                     decode = lambda b: b.encode('utf8').decode(charset)
     77             else:
     78                 decode = lambda b: b.decode(charset)
     79             if field.filename:
     80                 field.filename = decode(field.filename)
     81                 obj.add(field.name, field)
     82             else:
     83                 value = field.value
     84                 if transfer_encoding in supported_transfer_encoding:
     85                     if PY3: # pragma: no cover
     86                         # binascii accepts bytes
     87                         value = value.encode('utf8')
     88                     value = supported_transfer_encoding[transfer_encoding](value)
     89                     if PY3: # pragma: no cover
     90                         # binascii returns bytes
     91                         value = value.decode('utf8')
     92                 obj.add(field.name, decode(value))
     93         return obj
     94 
     95     def __getitem__(self, key):
     96         for k, v in reversed(self._items):
     97             if k == key:
     98                 return v
     99         raise KeyError(key)
    100 
    101     def __setitem__(self, key, value):
    102         try:
    103             del self[key]
    104         except KeyError:
    105             pass
    106         self._items.append((key, value))
    107 
    108     def add(self, key, value):
    109         """
    110         Add the key and value, not overwriting any previous value.
    111         """
    112         self._items.append((key, value))
    113 
    114     def getall(self, key):
    115         """
    116         Return a list of all values matching the key (may be an empty list)
    117         """
    118         return [v for k, v in self._items if k == key]
    119 
    120     def getone(self, key):
    121         """
    122         Get one value matching the key, raising a KeyError if multiple
    123         values were found.
    124         """
    125         v = self.getall(key)
    126         if not v:
    127             raise KeyError('Key not found: %r' % key)
    128         if len(v) > 1:
    129             raise KeyError('Multiple values match %r: %r' % (key, v))
    130         return v[0]
    131 
    132     def mixed(self):
    133         """
    134         Returns a dictionary where the values are either single
    135         values, or a list of values when a key/value appears more than
    136         once in this dictionary.  This is similar to the kind of
    137         dictionary often used to represent the variables in a web
    138         request.
    139         """
    140         result = {}
    141         multi = {}
    142         for key, value in self.items():
    143             if key in result:
    144                 # We do this to not clobber any lists that are
    145                 # *actual* values in this dictionary:
    146                 if key in multi:
    147                     result[key].append(value)
    148                 else:
    149                     result[key] = [result[key], value]
    150                     multi[key] = None
    151             else:
    152                 result[key] = value
    153         return result
    154 
    155     def dict_of_lists(self):
    156         """
    157         Returns a dictionary where each key is associated with a list of values.
    158         """
    159         r = {}
    160         for key, val in self.items():
    161             r.setdefault(key, []).append(val)
    162         return r
    163 
    164     def __delitem__(self, key):
    165         items = self._items
    166         found = False
    167         for i in range(len(items)-1, -1, -1):
    168             if items[i][0] == key:
    169                 del items[i]
    170                 found = True
    171         if not found:
    172             raise KeyError(key)
    173 
    174     def __contains__(self, key):
    175         for k, v in self._items:
    176             if k == key:
    177                 return True
    178         return False
    179 
    180     has_key = __contains__
    181 
    182     def clear(self):
    183         del self._items[:]
    184 
    185     def copy(self):
    186         return self.__class__(self)
    187 
    188     def setdefault(self, key, default=None):
    189         for k, v in self._items:
    190             if key == k:
    191                 return v
    192         self._items.append((key, default))
    193         return default
    194 
    195     def pop(self, key, *args):
    196         if len(args) > 1:
    197             raise TypeError("pop expected at most 2 arguments, got %s"
    198                              % repr(1 + len(args)))
    199         for i in range(len(self._items)):
    200             if self._items[i][0] == key:
    201                 v = self._items[i][1]
    202                 del self._items[i]
    203                 return v
    204         if args:
    205             return args[0]
    206         else:
    207             raise KeyError(key)
    208 
    209     def popitem(self):
    210         return self._items.pop()
    211 
    212     def update(self, *args, **kw):
    213         if args:
    214             lst = args[0]
    215             if len(lst) != len(dict(lst)):
    216                 # this does not catch the cases where we overwrite existing
    217                 # keys, but those would produce too many warning
    218                 msg = ("Behavior of MultiDict.update() has changed "
    219                     "and overwrites duplicate keys. Consider using .extend()"
    220                 )
    221                 warnings.warn(msg, UserWarning, stacklevel=2)
    222         MutableMapping.update(self, *args, **kw)
    223 
    224     def extend(self, other=None, **kwargs):
    225         if other is None:
    226             pass
    227         elif hasattr(other, 'items'):
    228             self._items.extend(other.items())
    229         elif hasattr(other, 'keys'):
    230             for k in other.keys():
    231                 self._items.append((k, other[k]))
    232         else:
    233             for k, v in other:
    234                 self._items.append((k, v))
    235         if kwargs:
    236             self.update(kwargs)
    237 
    238     def __repr__(self):
    239         items = map('(%r, %r)'.__mod__, _hide_passwd(self.items()))
    240         return '%s([%s])' % (self.__class__.__name__, ', '.join(items))
    241 
    242     def __len__(self):
    243         return len(self._items)
    244 
    245     ##
    246     ## All the iteration:
    247     ##
    248 
    249     def iterkeys(self):
    250         for k, v in self._items:
    251             yield k
    252     if PY3: # pragma: no cover
    253         keys = iterkeys
    254     else:
    255         def keys(self):
    256             return [k for k, v in self._items]
    257 
    258     __iter__ = iterkeys
    259 
    260     def iteritems(self):
    261         return iter(self._items)
    262 
    263     if PY3: # pragma: no cover
    264         items = iteritems
    265     else:
    266         def items(self):
    267             return self._items[:]
    268 
    269     def itervalues(self):
    270         for k, v in self._items:
    271             yield v
    272 
    273     if PY3: # pragma: no cover
    274         values = itervalues
    275     else:
    276         def values(self):
    277             return [v for k, v in self._items]
    278 
# Unique sentinel: NestedMultiDict.__getitem__ passes it as the default to
# d.get() so a stored value of None is not mistaken for a missing key.
_dummy = object()
    280 
    281 class GetDict(MultiDict):
    282 #     def __init__(self, data, tracker, encoding, errors):
    283 #         d = lambda b: b.decode(encoding, errors)
    284 #         data = [(d(k), d(v)) for k,v in data]
    285     def __init__(self, data, env):
    286         self.env = env
    287         MultiDict.__init__(self, data)
    288     def on_change(self):
    289         e = lambda t: t.encode('utf8')
    290         data = [(e(k), e(v)) for k,v in self.items()]
    291         qs = url_encode(data)
    292         self.env['QUERY_STRING'] = qs
    293         self.env['webob._parsed_query_vars'] = (self, qs)
    294     def __setitem__(self, key, value):
    295         MultiDict.__setitem__(self, key, value)
    296         self.on_change()
    297     def add(self, key, value):
    298         MultiDict.add(self, key, value)
    299         self.on_change()
    300     def __delitem__(self, key):
    301         MultiDict.__delitem__(self, key)
    302         self.on_change()
    303     def clear(self):
    304         MultiDict.clear(self)
    305         self.on_change()
    306     def setdefault(self, key, default=None):
    307         result = MultiDict.setdefault(self, key, default)
    308         self.on_change()
    309         return result
    310     def pop(self, key, *args):
    311         result = MultiDict.pop(self, key, *args)
    312         self.on_change()
    313         return result
    314     def popitem(self):
    315         result = MultiDict.popitem(self)
    316         self.on_change()
    317         return result
    318     def update(self, *args, **kwargs):
    319         MultiDict.update(self, *args, **kwargs)
    320         self.on_change()
    321     def extend(self, *args, **kwargs):
    322         MultiDict.extend(self, *args, **kwargs)
    323         self.on_change()
    324     def __repr__(self):
    325         items = map('(%r, %r)'.__mod__, _hide_passwd(self.items()))
    326         # TODO: GET -> GetDict
    327         return 'GET([%s])' % (', '.join(items))
    328     def copy(self):
    329         # Copies shouldn't be tracked
    330         return MultiDict(self)
    331 
    332 class NestedMultiDict(MultiDict):
    333     """
    334     Wraps several MultiDict objects, treating it as one large MultiDict
    335     """
    336 
    337     def __init__(self, *dicts):
    338         self.dicts = dicts
    339 
    340     def __getitem__(self, key):
    341         for d in self.dicts:
    342             value = d.get(key, _dummy)
    343             if value is not _dummy:
    344                 return value
    345         raise KeyError(key)
    346 
    347     def _readonly(self, *args, **kw):
    348         raise KeyError("NestedMultiDict objects are read-only")
    349     __setitem__ = _readonly
    350     add = _readonly
    351     __delitem__ = _readonly
    352     clear = _readonly
    353     setdefault = _readonly
    354     pop = _readonly
    355     popitem = _readonly
    356     update = _readonly
    357 
    358     def getall(self, key):
    359         result = []
    360         for d in self.dicts:
    361             result.extend(d.getall(key))
    362         return result
    363 
    364     # Inherited:
    365     # getone
    366     # mixed
    367     # dict_of_lists
    368 
    369     def copy(self):
    370         return MultiDict(self)
    371 
    372     def __contains__(self, key):
    373         for d in self.dicts:
    374             if key in d:
    375                 return True
    376         return False
    377 
    378     has_key = __contains__
    379 
    380     def __len__(self):
    381         v = 0
    382         for d in self.dicts:
    383             v += len(d)
    384         return v
    385 
    386     def __nonzero__(self):
    387         for d in self.dicts:
    388             if d:
    389                 return True
    390         return False
    391 
    392     def iteritems(self):
    393         for d in self.dicts:
    394             for item in iteritems_(d):
    395                 yield item
    396     if PY3: # pragma: no cover
    397         items = iteritems
    398     else:
    399         def items(self):
    400             return list(self.iteritems())
    401 
    402     def itervalues(self):
    403         for d in self.dicts:
    404             for value in itervalues_(d):
    405                 yield value
    406     if PY3: # pragma: no cover
    407         values = itervalues
    408     else:
    409         def values(self):
    410             return list(self.itervalues())
    411 
    412     def __iter__(self):
    413         for d in self.dicts:
    414             for key in d:
    415                 yield key
    416 
    417     iterkeys = __iter__
    418 
    419     if PY3: # pragma: no cover
    420         keys = iterkeys
    421     else:
    422         def keys(self):
    423             return list(self.iterkeys())
    424 
    425 class NoVars(object):
    426     """
    427     Represents no variables; used when no variables
    428     are applicable.
    429 
    430     This is read-only
    431     """
    432 
    433     def __init__(self, reason=None):
    434         self.reason = reason or 'N/A'
    435 
    436     def __getitem__(self, key):
    437         raise KeyError("No key %r: %s" % (key, self.reason))
    438 
    439     def __setitem__(self, *args, **kw):
    440         raise KeyError("Cannot add variables: %s" % self.reason)
    441 
    442     add = __setitem__
    443     setdefault = __setitem__
    444     update = __setitem__
    445 
    446     def __delitem__(self, *args, **kw):
    447         raise KeyError("No keys to delete: %s" % self.reason)
    448     clear = __delitem__
    449     pop = __delitem__
    450     popitem = __delitem__
    451 
    452     def get(self, key, default=None):
    453         return default
    454 
    455     def getall(self, key):
    456         return []
    457 
    458     def getone(self, key):
    459         return self[key]
    460 
    461     def mixed(self):
    462         return {}
    463     dict_of_lists = mixed
    464 
    465     def __contains__(self, key):
    466         return False
    467     has_key = __contains__
    468 
    469     def copy(self):
    470         return self
    471 
    472     def __repr__(self):
    473         return '<%s: %s>' % (self.__class__.__name__,
    474                              self.reason)
    475 
    476     def __len__(self):
    477         return 0
    478 
    479     def __cmp__(self, other):
    480         return cmp({}, other)
    481 
    482     def iterkeys(self):
    483         return iter([])
    484 
    485     if PY3: # pragma: no cover
    486         keys = iterkeys
    487         items = iterkeys
    488         values = iterkeys
    489     else:
    490         def keys(self):
    491             return []
    492         items = keys
    493         values = keys
    494         itervalues = iterkeys
    495         iteritems = iterkeys
    496 
    497     __iter__ = iterkeys
    498 
    499 def _hide_passwd(items):
    500     for k, v in items:
    501         if ('password' in k
    502             or 'passwd' in k
    503             or 'pwd' in k
    504         ):
    505             yield k, '******'
    506         else:
    507             yield k, v
    508