Home | History | Annotate | Download | only in dns
      1 # Copyright (C) 2001-2007, 2009, 2010 Nominum, Inc.
      2 #
      3 # Permission to use, copy, modify, and distribute this software and its
      4 # documentation for any purpose with or without fee is hereby granted,
      5 # provided that the above copyright notice and this permission notice
      6 # appear in all copies.
      7 #
      8 # THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
      9 # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
     10 # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
     11 # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
     12 # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
     13 # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
     14 # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
     15 
     16 """DNS Names.
     17 
     18 @var root: The DNS root name.
     19 @type root: dns.name.Name object
     20 @var empty: The empty DNS name.
     21 @type empty: dns.name.Name object
     22 """
     23 
     24 import cStringIO
     25 import struct
     26 import sys
     27 
     28 if sys.hexversion >= 0x02030000:
     29     import encodings.idna
     30 
     31 import dns.exception
     32 
     33 NAMERELN_NONE = 0
     34 NAMERELN_SUPERDOMAIN = 1
     35 NAMERELN_SUBDOMAIN = 2
     36 NAMERELN_EQUAL = 3
     37 NAMERELN_COMMONANCESTOR = 4
     38 
     39 class EmptyLabel(dns.exception.SyntaxError):
     40     """Raised if a label is empty."""
     41     pass
     42 
     43 class BadEscape(dns.exception.SyntaxError):
     44     """Raised if an escaped code in a text format name is invalid."""
     45     pass
     46 
     47 class BadPointer(dns.exception.FormError):
     48     """Raised if a compression pointer points forward instead of backward."""
     49     pass
     50 
     51 class BadLabelType(dns.exception.FormError):
     52     """Raised if the label type of a wire format name is unknown."""
     53     pass
     54 
     55 class NeedAbsoluteNameOrOrigin(dns.exception.DNSException):
     56     """Raised if an attempt is made to convert a non-absolute name to
     57     wire when there is also a non-absolute (or missing) origin."""
     58     pass
     59 
     60 class NameTooLong(dns.exception.FormError):
     61     """Raised if a name is > 255 octets long."""
     62     pass
     63 
     64 class LabelTooLong(dns.exception.SyntaxError):
     65     """Raised if a label is > 63 octets long."""
     66     pass
     67 
     68 class AbsoluteConcatenation(dns.exception.DNSException):
     69     """Raised if an attempt is made to append anything other than the
     70     empty name to an absolute name."""
     71     pass
     72 
     73 class NoParent(dns.exception.DNSException):
     74     """Raised if an attempt is made to get the parent of the root name
     75     or the empty name."""
     76     pass
     77 
     78 _escaped = {
     79     '"' : True,
     80     '(' : True,
     81     ')' : True,
     82     '.' : True,
     83     ';' : True,
     84     '\\' : True,
     85     '@' : True,
     86     '$' : True
     87     }
     88 
     89 def _escapify(label):
     90     """Escape the characters in label which need it.
     91     @returns: the escaped string
     92     @rtype: string"""
     93     text = ''
     94     for c in label:
     95         if c in _escaped:
     96             text += '\\' + c
     97         elif ord(c) > 0x20 and ord(c) < 0x7F:
     98             text += c
     99         else:
    100             text += '\\%03d' % ord(c)
    101     return text
    102 
    103 def _validate_labels(labels):
    104     """Check for empty labels in the middle of a label sequence,
    105     labels that are too long, and for too many labels.
    106     @raises NameTooLong: the name as a whole is too long
    107     @raises LabelTooLong: an individual label is too long
    108     @raises EmptyLabel: a label is empty (i.e. the root label) and appears
    109     in a position other than the end of the label sequence"""
    110 
    111     l = len(labels)
    112     total = 0
    113     i = -1
    114     j = 0
    115     for label in labels:
    116         ll = len(label)
    117         total += ll + 1
    118         if ll > 63:
    119             raise LabelTooLong
    120         if i < 0 and label == '':
    121             i = j
    122         j += 1
    123     if total > 255:
    124         raise NameTooLong
    125     if i >= 0 and i != l - 1:
    126         raise EmptyLabel
    127 
    128 class Name(object):
    129     """A DNS name.
    130 
    131     The dns.name.Name class represents a DNS name as a tuple of labels.
    132     Instances of the class are immutable.
    133 
    134     @ivar labels: The tuple of labels in the name. Each label is a string of
    135     up to 63 octets."""
    136 
    137     __slots__ = ['labels']
    138 
    139     def __init__(self, labels):
    140         """Initialize a domain name from a list of labels.
    141         @param labels: the labels
    142         @type labels: any iterable whose values are strings
    143         """
    144 
    145         super(Name, self).__setattr__('labels', tuple(labels))
    146         _validate_labels(self.labels)
    147 
    148     def __setattr__(self, name, value):
    149         raise TypeError("object doesn't support attribute assignment")
    150 
    151     def is_absolute(self):
    152         """Is the most significant label of this name the root label?
    153         @rtype: bool
    154         """
    155 
    156         return len(self.labels) > 0 and self.labels[-1] == ''
    157 
    158     def is_wild(self):
    159         """Is this name wild?  (I.e. Is the least significant label '*'?)
    160         @rtype: bool
    161         """
    162 
    163         return len(self.labels) > 0 and self.labels[0] == '*'
    164 
    165     def __hash__(self):
    166         """Return a case-insensitive hash of the name.
    167         @rtype: int
    168         """
    169 
    170         h = 0L
    171         for label in self.labels:
    172             for c in label:
    173                 h += ( h << 3 ) + ord(c.lower())
    174         return int(h % sys.maxint)
    175 
    176     def fullcompare(self, other):
    177         """Compare two names, returning a 3-tuple (relation, order, nlabels).
    178 
    179         I{relation} describes the relation ship between the names,
    180         and is one of: dns.name.NAMERELN_NONE,
    181         dns.name.NAMERELN_SUPERDOMAIN, dns.name.NAMERELN_SUBDOMAIN,
    182         dns.name.NAMERELN_EQUAL, or dns.name.NAMERELN_COMMONANCESTOR
    183 
    184         I{order} is < 0 if self < other, > 0 if self > other, and ==
    185         0 if self == other.  A relative name is always less than an
    186         absolute name.  If both names have the same relativity, then
    187         the DNSSEC order relation is used to order them.
    188 
    189         I{nlabels} is the number of significant labels that the two names
    190         have in common.
    191         """
    192 
    193         sabs = self.is_absolute()
    194         oabs = other.is_absolute()
    195         if sabs != oabs:
    196             if sabs:
    197                 return (NAMERELN_NONE, 1, 0)
    198             else:
    199                 return (NAMERELN_NONE, -1, 0)
    200         l1 = len(self.labels)
    201         l2 = len(other.labels)
    202         ldiff = l1 - l2
    203         if ldiff < 0:
    204             l = l1
    205         else:
    206             l = l2
    207 
    208         order = 0
    209         nlabels = 0
    210         namereln = NAMERELN_NONE
    211         while l > 0:
    212             l -= 1
    213             l1 -= 1
    214             l2 -= 1
    215             label1 = self.labels[l1].lower()
    216             label2 = other.labels[l2].lower()
    217             if label1 < label2:
    218                 order = -1
    219                 if nlabels > 0:
    220                     namereln = NAMERELN_COMMONANCESTOR
    221                 return (namereln, order, nlabels)
    222             elif label1 > label2:
    223                 order = 1
    224                 if nlabels > 0:
    225                     namereln = NAMERELN_COMMONANCESTOR
    226                 return (namereln, order, nlabels)
    227             nlabels += 1
    228         order = ldiff
    229         if ldiff < 0:
    230             namereln = NAMERELN_SUPERDOMAIN
    231         elif ldiff > 0:
    232             namereln = NAMERELN_SUBDOMAIN
    233         else:
    234             namereln = NAMERELN_EQUAL
    235         return (namereln, order, nlabels)
    236 
    237     def is_subdomain(self, other):
    238         """Is self a subdomain of other?
    239 
    240         The notion of subdomain includes equality.
    241         @rtype: bool
    242         """
    243 
    244         (nr, o, nl) = self.fullcompare(other)
    245         if nr == NAMERELN_SUBDOMAIN or nr == NAMERELN_EQUAL:
    246             return True
    247         return False
    248 
    249     def is_superdomain(self, other):
    250         """Is self a superdomain of other?
    251 
    252         The notion of subdomain includes equality.
    253         @rtype: bool
    254         """
    255 
    256         (nr, o, nl) = self.fullcompare(other)
    257         if nr == NAMERELN_SUPERDOMAIN or nr == NAMERELN_EQUAL:
    258             return True
    259         return False
    260 
    261     def canonicalize(self):
    262         """Return a name which is equal to the current name, but is in
    263         DNSSEC canonical form.
    264         @rtype: dns.name.Name object
    265         """
    266 
    267         return Name([x.lower() for x in self.labels])
    268 
    269     def __eq__(self, other):
    270         if isinstance(other, Name):
    271             return self.fullcompare(other)[1] == 0
    272         else:
    273             return False
    274 
    275     def __ne__(self, other):
    276         if isinstance(other, Name):
    277             return self.fullcompare(other)[1] != 0
    278         else:
    279             return True
    280 
    281     def __lt__(self, other):
    282         if isinstance(other, Name):
    283             return self.fullcompare(other)[1] < 0
    284         else:
    285             return NotImplemented
    286 
    287     def __le__(self, other):
    288         if isinstance(other, Name):
    289             return self.fullcompare(other)[1] <= 0
    290         else:
    291             return NotImplemented
    292 
    293     def __ge__(self, other):
    294         if isinstance(other, Name):
    295             return self.fullcompare(other)[1] >= 0
    296         else:
    297             return NotImplemented
    298 
    299     def __gt__(self, other):
    300         if isinstance(other, Name):
    301             return self.fullcompare(other)[1] > 0
    302         else:
    303             return NotImplemented
    304 
    305     def __repr__(self):
    306         return '<DNS name ' + self.__str__() + '>'
    307 
    308     def __str__(self):
    309         return self.to_text(False)
    310 
    311     def to_text(self, omit_final_dot = False):
    312         """Convert name to text format.
    313         @param omit_final_dot: If True, don't emit the final dot (denoting the
    314         root label) for absolute names.  The default is False.
    315         @rtype: string
    316         """
    317 
    318         if len(self.labels) == 0:
    319             return '@'
    320         if len(self.labels) == 1 and self.labels[0] == '':
    321             return '.'
    322         if omit_final_dot and self.is_absolute():
    323             l = self.labels[:-1]
    324         else:
    325             l = self.labels
    326         s = '.'.join(map(_escapify, l))
    327         return s
    328 
    329     def to_unicode(self, omit_final_dot = False):
    330         """Convert name to Unicode text format.
    331 
    332         IDN ACE lables are converted to Unicode.
    333 
    334         @param omit_final_dot: If True, don't emit the final dot (denoting the
    335         root label) for absolute names.  The default is False.
    336         @rtype: string
    337         """
    338 
    339         if len(self.labels) == 0:
    340             return u'@'
    341         if len(self.labels) == 1 and self.labels[0] == '':
    342             return u'.'
    343         if omit_final_dot and self.is_absolute():
    344             l = self.labels[:-1]
    345         else:
    346             l = self.labels
    347         s = u'.'.join([encodings.idna.ToUnicode(_escapify(x)) for x in l])
    348         return s
    349 
    350     def to_digestable(self, origin=None):
    351         """Convert name to a format suitable for digesting in hashes.
    352 
    353         The name is canonicalized and converted to uncompressed wire format.
    354 
    355         @param origin: If the name is relative and origin is not None, then
    356         origin will be appended to it.
    357         @type origin: dns.name.Name object
    358         @raises NeedAbsoluteNameOrOrigin: All names in wire format are
    359         absolute.  If self is a relative name, then an origin must be supplied;
    360         if it is missing, then this exception is raised
    361         @rtype: string
    362         """
    363 
    364         if not self.is_absolute():
    365             if origin is None or not origin.is_absolute():
    366                 raise NeedAbsoluteNameOrOrigin
    367             labels = list(self.labels)
    368             labels.extend(list(origin.labels))
    369         else:
    370             labels = self.labels
    371         dlabels = ["%s%s" % (chr(len(x)), x.lower()) for x in labels]
    372         return ''.join(dlabels)
    373 
    374     def to_wire(self, file = None, compress = None, origin = None):
    375         """Convert name to wire format, possibly compressing it.
    376 
    377         @param file: the file where the name is emitted (typically
    378         a cStringIO file).  If None, a string containing the wire name
    379         will be returned.
    380         @type file: file or None
    381         @param compress: The compression table.  If None (the default) names
    382         will not be compressed.
    383         @type compress: dict
    384         @param origin: If the name is relative and origin is not None, then
    385         origin will be appended to it.
    386         @type origin: dns.name.Name object
    387         @raises NeedAbsoluteNameOrOrigin: All names in wire format are
    388         absolute.  If self is a relative name, then an origin must be supplied;
    389         if it is missing, then this exception is raised
    390         """
    391 
    392         if file is None:
    393             file = cStringIO.StringIO()
    394             want_return = True
    395         else:
    396             want_return = False
    397 
    398         if not self.is_absolute():
    399             if origin is None or not origin.is_absolute():
    400                 raise NeedAbsoluteNameOrOrigin
    401             labels = list(self.labels)
    402             labels.extend(list(origin.labels))
    403         else:
    404             labels = self.labels
    405         i = 0
    406         for label in labels:
    407             n = Name(labels[i:])
    408             i += 1
    409             if not compress is None:
    410                 pos = compress.get(n)
    411             else:
    412                 pos = None
    413             if not pos is None:
    414                 value = 0xc000 + pos
    415                 s = struct.pack('!H', value)
    416                 file.write(s)
    417                 break
    418             else:
    419                 if not compress is None and len(n) > 1:
    420                     pos = file.tell()
    421                     if pos < 0xc000:
    422                         compress[n] = pos
    423                 l = len(label)
    424                 file.write(chr(l))
    425                 if l > 0:
    426                     file.write(label)
    427         if want_return:
    428             return file.getvalue()
    429 
    430     def __len__(self):
    431         """The length of the name (in labels).
    432         @rtype: int
    433         """
    434 
    435         return len(self.labels)
    436 
    437     def __getitem__(self, index):
    438         return self.labels[index]
    439 
    440     def __getslice__(self, start, stop):
    441         return self.labels[start:stop]
    442 
    443     def __add__(self, other):
    444         return self.concatenate(other)
    445 
    446     def __sub__(self, other):
    447         return self.relativize(other)
    448 
    449     def split(self, depth):
    450         """Split a name into a prefix and suffix at depth.
    451 
    452         @param depth: the number of labels in the suffix
    453         @type depth: int
    454         @raises ValueError: the depth was not >= 0 and <= the length of the
    455         name.
    456         @returns: the tuple (prefix, suffix)
    457         @rtype: tuple
    458         """
    459 
    460         l = len(self.labels)
    461         if depth == 0:
    462             return (self, dns.name.empty)
    463         elif depth == l:
    464             return (dns.name.empty, self)
    465         elif depth < 0 or depth > l:
    466             raise ValueError('depth must be >= 0 and <= the length of the name')
    467         return (Name(self[: -depth]), Name(self[-depth :]))
    468 
    469     def concatenate(self, other):
    470         """Return a new name which is the concatenation of self and other.
    471         @rtype: dns.name.Name object
    472         @raises AbsoluteConcatenation: self is absolute and other is
    473         not the empty name
    474         """
    475 
    476         if self.is_absolute() and len(other) > 0:
    477             raise AbsoluteConcatenation
    478         labels = list(self.labels)
    479         labels.extend(list(other.labels))
    480         return Name(labels)
    481 
    482     def relativize(self, origin):
    483         """If self is a subdomain of origin, return a new name which is self
    484         relative to origin.  Otherwise return self.
    485         @rtype: dns.name.Name object
    486         """
    487 
    488         if not origin is None and self.is_subdomain(origin):
    489             return Name(self[: -len(origin)])
    490         else:
    491             return self
    492 
    493     def derelativize(self, origin):
    494         """If self is a relative name, return a new name which is the
    495         concatenation of self and origin.  Otherwise return self.
    496         @rtype: dns.name.Name object
    497         """
    498 
    499         if not self.is_absolute():
    500             return self.concatenate(origin)
    501         else:
    502             return self
    503 
    504     def choose_relativity(self, origin=None, relativize=True):
    505         """Return a name with the relativity desired by the caller.  If
    506         origin is None, then self is returned.  Otherwise, if
    507         relativize is true the name is relativized, and if relativize is
    508         false the name is derelativized.
    509         @rtype: dns.name.Name object
    510         """
    511 
    512         if origin:
    513             if relativize:
    514                 return self.relativize(origin)
    515             else:
    516                 return self.derelativize(origin)
    517         else:
    518             return self
    519 
    520     def parent(self):
    521         """Return the parent of the name.
    522         @rtype: dns.name.Name object
    523         @raises NoParent: the name is either the root name or the empty name,
    524         and thus has no parent.
    525         """
    526         if self == root or self == empty:
    527             raise NoParent
    528         return Name(self.labels[1:])
    529 
    530 root = Name([''])
    531 empty = Name([])
    532 
    533 def from_unicode(text, origin = root):
    534     """Convert unicode text into a Name object.
    535 
    536     Lables are encoded in IDN ACE form.
    537 
    538     @rtype: dns.name.Name object
    539     """
    540 
    541     if not isinstance(text, unicode):
    542         raise ValueError("input to from_unicode() must be a unicode string")
    543     if not (origin is None or isinstance(origin, Name)):
    544         raise ValueError("origin must be a Name or None")
    545     labels = []
    546     label = u''
    547     escaping = False
    548     edigits = 0
    549     total = 0
    550     if text == u'@':
    551         text = u''
    552     if text:
    553         if text == u'.':
    554             return Name([''])	# no Unicode "u" on this constant!
    555         for c in text:
    556             if escaping:
    557                 if edigits == 0:
    558                     if c.isdigit():
    559                         total = int(c)
    560                         edigits += 1
    561                     else:
    562                         label += c
    563                         escaping = False
    564                 else:
    565                     if not c.isdigit():
    566                         raise BadEscape
    567                     total *= 10
    568                     total += int(c)
    569                     edigits += 1
    570                     if edigits == 3:
    571                         escaping = False
    572                         label += chr(total)
    573             elif c == u'.' or c == u'\u3002' or \
    574                  c == u'\uff0e' or c == u'\uff61':
    575                 if len(label) == 0:
    576                     raise EmptyLabel
    577                 labels.append(encodings.idna.ToASCII(label))
    578                 label = u''
    579             elif c == u'\\':
    580                 escaping = True
    581                 edigits = 0
    582                 total = 0
    583             else:
    584                 label += c
    585         if escaping:
    586             raise BadEscape
    587         if len(label) > 0:
    588             labels.append(encodings.idna.ToASCII(label))
    589         else:
    590             labels.append('')
    591     if (len(labels) == 0 or labels[-1] != '') and not origin is None:
    592         labels.extend(list(origin.labels))
    593     return Name(labels)
    594 
    595 def from_text(text, origin = root):
    596     """Convert text into a Name object.
    597     @rtype: dns.name.Name object
    598     """
    599 
    600     if not isinstance(text, str):
    601         if isinstance(text, unicode) and sys.hexversion >= 0x02030000:
    602             return from_unicode(text, origin)
    603         else:
    604             raise ValueError("input to from_text() must be a string")
    605     if not (origin is None or isinstance(origin, Name)):
    606         raise ValueError("origin must be a Name or None")
    607     labels = []
    608     label = ''
    609     escaping = False
    610     edigits = 0
    611     total = 0
    612     if text == '@':
    613         text = ''
    614     if text:
    615         if text == '.':
    616             return Name([''])
    617         for c in text:
    618             if escaping:
    619                 if edigits == 0:
    620                     if c.isdigit():
    621                         total = int(c)
    622                         edigits += 1
    623                     else:
    624                         label += c
    625                         escaping = False
    626                 else:
    627                     if not c.isdigit():
    628                         raise BadEscape
    629                     total *= 10
    630                     total += int(c)
    631                     edigits += 1
    632                     if edigits == 3:
    633                         escaping = False
    634                         label += chr(total)
    635             elif c == '.':
    636                 if len(label) == 0:
    637                     raise EmptyLabel
    638                 labels.append(label)
    639                 label = ''
    640             elif c == '\\':
    641                 escaping = True
    642                 edigits = 0
    643                 total = 0
    644             else:
    645                 label += c
    646         if escaping:
    647             raise BadEscape
    648         if len(label) > 0:
    649             labels.append(label)
    650         else:
    651             labels.append('')
    652     if (len(labels) == 0 or labels[-1] != '') and not origin is None:
    653         labels.extend(list(origin.labels))
    654     return Name(labels)
    655 
    656 def from_wire(message, current):
    657     """Convert possibly compressed wire format into a Name.
    658     @param message: the entire DNS message
    659     @type message: string
    660     @param current: the offset of the beginning of the name from the start
    661     of the message
    662     @type current: int
    663     @raises dns.name.BadPointer: a compression pointer did not point backwards
    664     in the message
    665     @raises dns.name.BadLabelType: an invalid label type was encountered.
    666     @returns: a tuple consisting of the name that was read and the number
    667     of bytes of the wire format message which were consumed reading it
    668     @rtype: (dns.name.Name object, int) tuple
    669     """
    670 
    671     if not isinstance(message, str):
    672         raise ValueError("input to from_wire() must be a byte string")
    673     labels = []
    674     biggest_pointer = current
    675     hops = 0
    676     count = ord(message[current])
    677     current += 1
    678     cused = 1
    679     while count != 0:
    680         if count < 64:
    681             labels.append(message[current : current + count])
    682             current += count
    683             if hops == 0:
    684                 cused += count
    685         elif count >= 192:
    686             current = (count & 0x3f) * 256 + ord(message[current])
    687             if hops == 0:
    688                 cused += 1
    689             if current >= biggest_pointer:
    690                 raise BadPointer
    691             biggest_pointer = current
    692             hops += 1
    693         else:
    694             raise BadLabelType
    695         count = ord(message[current])
    696         current += 1
    697         if hops == 0:
    698             cused += 1
    699     labels.append('')
    700     return (Name(labels), cused)
    701