Home | History | Annotate | Download | only in netprotos
      1 # Copyright (c) 2013 The Chromium OS Authors. All rights reserved.
      2 # Use of this source code is governed by a BSD-style license that can be
      3 # found in the LICENSE file.
      4 
      5 import collections
      6 import dpkt
      7 import logging
      8 import socket
      9 import time
     10 
     11 
     12 DnsRecord = collections.namedtuple('DnsResult', ['rrname', 'rrtype', 'data', 'ts'])
     13 
     14 MDNS_IP_ADDR = '224.0.0.251'
     15 MDNS_PORT = 5353
     16 
     17 # Value to | to a class value to signal cache flush.
     18 DNS_CACHE_FLUSH = 0x8000
     19 
     20 # When considering SRV records, clients are supposed to unilaterally prefer
     21 # numerically lower priorities, then pick probabilistically by weight.
     22 # See RFC2782.
     23 # An arbitrary number that will fit in 16 bits.
     24 DEFAULT_PRIORITY = 500
     25 # An arbitrary number that will fit in 16 bits.
     26 DEFAULT_WEIGHT = 500
     27 
     28 def _RR_equals(rra, rrb):
     29     """Returns whether the two dpkt.dns.DNS.RR objects are equal."""
     30     # Compare all the members present in either object and on any RR object.
     31     keys = set(rra.__dict__.keys() + rrb.__dict__.keys() +
     32                dpkt.dns.DNS.RR.__slots__)
     33     # On RR objects, rdata is packed based on the other members and the final
     34     # packed string depends on other RR and Q elements on the same DNS/mDNS
     35     # packet.
     36     keys.discard('rdata')
     37     for key in keys:
     38         if hasattr(rra, key) != hasattr(rrb, key):
     39             return False
     40         if not hasattr(rra, key):
     41             continue
     42         if key == 'cls':
     43           # cls attribute should be masked for the cache flush bit.
     44           if (getattr(rra, key) & ~DNS_CACHE_FLUSH !=
     45                 getattr(rrb, key) & ~DNS_CACHE_FLUSH):
     46               return False
     47         else:
     48           if getattr(rra, key) != getattr(rrb, key):
     49               return False
     50     return True
     51 
     52 
     53 class ZeroconfDaemon(object):
     54     """Implements a simulated Zeroconf daemon running on the given host.
     55 
     56     This class implements part of the Multicast DNS RFC 6762 able to simulate
     57     a host exposing services or consuming services over mDNS.
     58     """
     59     def __init__(self, host, hostname, domain='local'):
     60         """Initializes the ZeroconfDameon running on the given host.
     61 
     62         For the purposes of the Zeroconf implementation, a host must have a
     63         hostname and a domain that defaults to 'local'. The ZeroconfDaemon will
     64         by default advertise the host address it is running on, which is
     65         required by some services.
     66 
     67         @param host: The Host instance where this daemon runs on.
     68         @param hostname: A string representing the hostname
     69         """
     70         self._host = host
     71         self._hostname = hostname
     72         self._domain = domain
     73         self._response_ttl = 60 # Default TTL in seconds.
     74 
     75         self._a_records = {} # Local A records.
     76         self._srv_records = {} # Local SRV records.
     77         self._ptr_records = {} # Local PTR records.
     78         self._txt_records = {} # Local TXT records.
     79 
     80         # dict() of name --> (dict() of type --> (dict() of data --> timeout))
     81         # For example: _peer_records['somehost.local'][dpkt.dns.DNS_A] \
     82         #     ['192.168.0.1'] = time.time() + 3600
     83         self._peer_records = {}
     84 
     85         # Register the host address locally.
     86         self.register_A(self.full_hostname, host.ip_addr)
     87 
     88         # Attend all the traffic to the mDNS port (unicast, multicast or
     89         # broadcast).
     90         self._sock = host.socket(socket.AF_INET, socket.SOCK_DGRAM)
     91         self._sock.listen(MDNS_IP_ADDR, MDNS_PORT, self._mdns_request)
     92 
     93         # Observer list for new responses.
     94         self._answer_callbacks = []
     95 
     96 
     97     def __del__(self):
     98         self._sock.close()
     99 
    100 
    101     @property
    102     def host(self):
    103         """The Host object where this daemon is running."""
    104         return self._host
    105 
    106 
    107     @property
    108     def hostname(self):
    109         """The hostname part within a domain."""
    110         return self._hostname
    111 
    112 
    113     @property
    114     def domain(self):
    115         """The domain where the given hostname is running."""
    116         return self._domain
    117 
    118 
    119     @property
    120     def full_hostname(self):
    121         """The full hostname designation including host and domain name."""
    122         return self._hostname + '.' + self._domain
    123 
    124 
    125     def _mdns_request(self, data, addr, port):
    126         """Handles a mDNS multicast packet.
    127 
    128         This method will generate and send a mDNS response to any query
    129         for which it has new authoritative information. Called by the Simulator
    130         as a callback for every mDNS received packet.
    131 
    132         @param data: The string contained on the UDP message.
    133         @param addr: The address where the message comes from.
    134         @param port: The port number where the message comes from.
    135         """
    136         # Parse the mDNS request using dpkt's DNS module.
    137         mdns = dpkt.dns.DNS(data)
    138         if mdns.op == 0x0000: # Query
    139             QUERY_HANDLERS = {
    140                 dpkt.dns.DNS_A: self._process_A,
    141                 dpkt.dns.DNS_PTR: self._process_PTR,
    142                 dpkt.dns.DNS_TXT: self._process_TXT,
    143                 dpkt.dns.DNS_SRV: self._process_SRV,
    144             }
    145 
    146             answers = []
    147             for q in mdns.qd: # Query entries
    148                 if q.type in QUERY_HANDLERS:
    149                     answers += QUERY_HANDLERS[q.type](q)
    150                 elif q.type == dpkt.dns.DNS_ANY:
    151                     # Special type matching any known type.
    152                     for _, handler in QUERY_HANDLERS.iteritems():
    153                         answers += handler(q)
    154             # Remove all the already known answers from the list.
    155             answers = [ans for ans in answers if not any(True
    156                 for known_ans in mdns.an if _RR_equals(known_ans, ans))]
    157 
    158             self._send_answers(answers)
    159 
    160         # Always process the received authoritative answers.
    161         answers = mdns.ns
    162 
    163         # Process the answers for response packets.
    164         if mdns.op == 0x8400: # Standard response
    165             answers.extend(mdns.an)
    166 
    167         if answers:
    168             cur_time = time.time()
    169             new_answers = []
    170             for rr in answers: # Answers RRs
    171                 # dpkt decodes the information on different fields depending on
    172                 # the response type.
    173                 if rr.type == dpkt.dns.DNS_A:
    174                     data = socket.inet_ntoa(rr.ip)
    175                 elif rr.type == dpkt.dns.DNS_PTR:
    176                     data = rr.ptrname
    177                 elif rr.type == dpkt.dns.DNS_TXT:
    178                     data = tuple(rr.text) # Convert the list to a hashable tuple
    179                 elif rr.type == dpkt.dns.DNS_SRV:
    180                     data = rr.srvname, rr.priority, rr.weight, rr.port
    181                 else:
    182                     continue # Ignore unsupported records.
    183                 if not rr.name in self._peer_records:
    184                     self._peer_records[rr.name] = {}
    185                 # Start a new cache or clear the existing if required.
    186                 if not rr.type in self._peer_records[rr.name] or (
    187                         rr.cls & DNS_CACHE_FLUSH):
    188                     self._peer_records[rr.name][rr.type] = {}
    189 
    190                 new_answers.append((rr.type, rr.name, data))
    191                 cached_ans = self._peer_records[rr.name][rr.type]
    192                 rr_timeout = cur_time + rr.ttl
    193                 # Update the answer timeout if already cached.
    194                 if data in cached_ans:
    195                     cached_ans[data] = max(cached_ans[data], rr_timeout)
    196                 else:
    197                     cached_ans[data] = rr_timeout
    198             if new_answers:
    199                 for cbk in self._answer_callbacks:
    200                     cbk(new_answers)
    201 
    202 
    203     def clear_cache(self):
    204         """Discards all the cached records."""
    205         self._peer_records = {}
    206 
    207 
    208     def _send_answers(self, answers):
    209         """Send a mDNS reply with the provided answers.
    210 
    211         This method uses the undelying Host to send an IP packet with a mDNS
    212         response containing the list of answers of the type dpkt.dns.DNS.RR.
    213         If the list is empty, no packet is sent.
    214 
    215         @param answers: The list of answers to send.
    216         """
    217         if not answers:
    218             return
    219         logging.debug('Sending response with answers: %r.', answers)
    220         resp_dns = dpkt.dns.DNS(
    221             op = dpkt.dns.DNS_AA, # Authoritative Answer.
    222             rcode = dpkt.dns.DNS_RCODE_NOERR,
    223             an = answers)
    224         # This property modifies the "op" field:
    225         resp_dns.qr = dpkt.dns.DNS_R, # Response.
    226         self._sock.send(str(resp_dns), MDNS_IP_ADDR, MDNS_PORT)
    227 
    228 
    229     ### RFC 2782 - RR for specifying the location of services (DNS SRV).
    230     def register_SRV(self, service, proto, priority, weight, port):
    231         """Publishes the SRV specified record.
    232 
    233         A SRV record defines a service on a port of a host with given properties
    234         like priority and weight. The service has a name of the form
    235         "service.proto.domain". The target host, this is, the host where the
    236         announced service is running on is set to the host where this zeroconf
    237         daemon is running, "hostname.domain".
    238 
    239         @param service: A string with the service name.
    240         @param proto: A string with the protocol name, for example "_tcp".
    241         @param priority: The service priority number as defined by RFC2782.
    242         @param weight: The service weight number as defined by RFC2782.
    243         @param port: The port number where the service is running on.
    244         """
    245         srvname = service + '.' + proto + '.' + self._domain
    246         self._srv_records[srvname] = priority, weight, port
    247 
    248 
    249     def _process_SRV(self, q):
    250         """Process a SRV query provided in |q|.
    251 
    252         @param q: The dns.DNS.Q query object with type dpkt.dns.DNS_SRV.
    253         @return: A list of dns.DNS.RR responses to the provided query that can
    254         be empty.
    255         """
    256         if not q.name in self._srv_records:
    257             return []
    258         priority, weight, port = self._srv_records[q.name]
    259         full_hostname = self._hostname + '.' + self._domain
    260         ans = dpkt.dns.DNS.RR(
    261             type = dpkt.dns.DNS_SRV,
    262             cls = dpkt.dns.DNS_IN | DNS_CACHE_FLUSH,
    263             ttl = self._response_ttl,
    264             name = q.name,
    265             srvname = full_hostname,
    266             priority = priority,
    267             weight = weight,
    268             port = port)
    269         # The target host (srvname) requires to send an A record with its IP
    270         # address. We do this as if a query for it was sent.
    271         a_qry = dpkt.dns.DNS.Q(name=full_hostname, type=dpkt.dns.DNS_A)
    272         return [ans] + self._process_A(a_qry)
    273 
    274 
    275     ### RFC 1035 - 3.4.1, Domains Names - A (IPv4 address).
    276     def register_A(self, hostname, ip_addr):
    277         """Registers an Address record (A) pointing to the given IP addres.
    278 
    279         Records registered with method are assumed authoritative.
    280 
    281         @param hostname: The full host name, for example, "somehost.local".
    282         @param ip_addr: The IPv4 address of the host, for example, "192.0.1.1".
    283         """
    284         if not hostname in self._a_records:
    285             self._a_records[hostname] = []
    286         self._a_records[hostname].append(socket.inet_aton(ip_addr))
    287 
    288 
    289     def _process_A(self, q):
    290         """Process an A query provided in |q|.
    291 
    292         @param q: The dns.DNS.Q query object with type dpkt.dns.DNS_A.
    293         @return: A list of dns.DNS.RR responses to the provided query that can
    294         be empty.
    295         """
    296         if not q.name in self._a_records:
    297             return []
    298         answers = []
    299         for ip_addr in self._a_records[q.name]:
    300             answers.append(dpkt.dns.DNS.RR(
    301                 type = dpkt.dns.DNS_A,
    302                 cls = dpkt.dns.DNS_IN | DNS_CACHE_FLUSH,
    303                 ttl = self._response_ttl,
    304                 name = q.name,
    305                 ip = ip_addr))
    306         return answers
    307 
    308 
    309     ### RFC 1035 - 3.3.12, Domain names - PTR (domain name pointer).
    310     def register_PTR(self, domain, destination):
    311         """Register a domain pointer record.
    312 
    313         A domain pointer record is simply a pointer to a hostname on the domain.
    314 
    315         @param domain: A domain name that can include a proto name, for
    316         example, "_workstation._tcp.local".
    317         @param destination: The hostname inside the given domain, for example,
    318         "my-desktop".
    319         """
    320         if not domain in self._ptr_records:
    321             self._ptr_records[domain] = []
    322         self._ptr_records[domain].append(destination)
    323 
    324 
    325     def _process_PTR(self, q):
    326         """Process a PTR query provided in |q|.
    327 
    328         @param q: The dns.DNS.Q query object with type dpkt.dns.DNS_PTR.
    329         @return: A list of dns.DNS.RR responses to the provided query that can
    330         be empty.
    331         """
    332         if not q.name in self._ptr_records:
    333             return []
    334         answers = []
    335         for dest in self._ptr_records[q.name]:
    336             answers.append(dpkt.dns.DNS.RR(
    337                 type = dpkt.dns.DNS_PTR,
    338                 cls = dpkt.dns.DNS_IN, # Don't cache flush for PTR records.
    339                 ttl = self._response_ttl,
    340                 name = q.name,
    341                 ptrname = dest + '.' + q.name))
    342         return answers
    343 
    344 
    345     ### RFC 1035 - 3.3.14, Domain names - TXT (descriptive text).
    346     def register_TXT(self, domain, txt_list, announce=False):
    347         """Register a TXT record on a domain with given list of strings.
    348 
    349         A TXT record can hold any list of text entries whos format depends on
    350         the domain. This method replaces any previous TXT record previously
    351         registered for the given domain.
    352 
    353         @param domain: A domain name that normally can include a proto name and
    354         a service or host name.
    355         @param txt_list: A list of strings.
    356         @param announce: If True, the method will also announce the changes
    357         on the network.
    358         """
    359         self._txt_records[domain] = txt_list
    360         if announce:
    361             self._send_answers(self._process_TXT(dpkt.dns.DNS.Q(name=domain)))
    362 
    363 
    364     def _process_TXT(self, q):
    365         """Process a TXT query provided in |q|.
    366 
    367         @param q: The dns.DNS.Q query object with type dpkt.dns.DNS_TXT.
    368         @return: A list of dns.DNS.RR responses to the provided query that can
    369         be empty.
    370         """
    371         if not q.name in self._txt_records:
    372             return []
    373         text_list = self._txt_records[q.name]
    374         answer = dpkt.dns.DNS.RR(
    375             type = dpkt.dns.DNS_TXT,
    376             cls = dpkt.dns.DNS_IN | DNS_CACHE_FLUSH,
    377             ttl = self._response_ttl,
    378             name = q.name,
    379             text = text_list)
    380         return [answer]
    381 
    382 
    383     def register_service(self, unique_prefix, service_type,
    384                          protocol, port, txt_list):
    385         """Register a service in the Avahi style.
    386 
    387         Avahi exposes a convenient set of methods for manipulating "services"
    388         which are a trio of PTR, SRV, and TXT records.  This is a similar
    389         helper method for our daemon.
    390 
    391         @param unique_prefix: string unique prefix of service (part of the
    392                               canonical name).
    393         @param service_type: string type of service (e.g. '_privet').
    394         @param protocol: string protocol to use for service (e.g. '_tcp').
    395         @param port: IP port of service (e.g. 53).
    396         @param txt_list: list of txt records (e.g. ['vers=1.0', 'foo']).
    397         """
    398         service_name = '.'.join([unique_prefix, service_type])
    399         fq_service_name = '.'.join([service_name, protocol, self._domain])
    400         logging.debug('Registering service=%s on port=%d with txt records=%r',
    401                       fq_service_name, port, txt_list)
    402         self.register_SRV(
    403                 service_name, protocol, DEFAULT_PRIORITY, DEFAULT_WEIGHT, port)
    404         self.register_PTR('.'.join([service_type, protocol, self._domain]),
    405                           unique_prefix)
    406         self.register_TXT(fq_service_name, txt_list)
    407 
    408 
    409     def cached_results(self, rrname, rrtype, timestamp=None):
    410         """Return all the cached results for the requested rrname and rrtype.
    411 
    412         This method is used to request all the received mDNS answers present
    413         on the cache that were valid at the provided timestamp or later.
    414         Answers received before this timestamp whose TTL isn't long enough to
    415         make them valid at the timestamp aren't returned. On the other hand,
    416         answers received *after* the provided timestamp will always be
    417         considered, even if they weren't known at the provided timestamp point.
    418         A timestamp of None will return them all.
    419 
    420         This method allows to retrieve "volatile" answers with a TTL of zero.
    421         According to the RFC, these answers should be only considered for the
    422         "ongoing" request. To do this, call this method after a few seconds (the
    423         request timeout) after calling the send_request() method, passing to
    424         this method the returned timestamp.
    425 
    426         @param rrname: The requested domain name.
    427         @param rrtype: The DNS record type. For example, dpkt.dns.DNS_TXT.
    428         @param timestamp: The request timestamp. See description.
    429         @return: The list of matching records of the form (rrname, rrtype, data,
    430                  timeout).
    431         """
    432         if timestamp is None:
    433             timestamp = 0
    434         if not rrname in self._peer_records:
    435             return []
    436         if not rrtype in self._peer_records[rrname]:
    437             return []
    438         res = []
    439         for data, data_ts in self._peer_records[rrname][rrtype].iteritems():
    440             if data_ts >= timestamp:
    441                 res.append(DnsRecord(rrname, rrtype, data, data_ts))
    442         return res
    443 
    444 
    445     def send_request(self, queries):
    446         """Sends a request for the provided rrname and rrtype.
    447 
    448         All the known and valid answers for this request will be included in the
    449         non authoritative list of known answers together with the request. This
    450         is recommended by the RFC and avoid unnecessary responses.
    451 
    452         @param queries: A list of pairs (rrname, rrtype) where rrname is the
    453         domain name you are requesting for and the rrtype is the DNS record
    454         type. For example, ('somehost.local', dpkt.dns.DNS_ANY).
    455         @return: The timestamp where this request is sent. See cached_results().
    456         """
    457         queries = [dpkt.dns.DNS.Q(name=rrname, type=rrtype)
    458                 for rrname, rrtype in queries]
    459         # TODO(deymo): Inlcude the already known answers on the request.
    460         answers = []
    461         mdns = dpkt.dns.DNS(
    462             op = dpkt.dns.DNS_QUERY,
    463             qd = queries,
    464             an = answers)
    465         self._sock.send(str(mdns), MDNS_IP_ADDR, MDNS_PORT)
    466         return time.time()
    467 
    468 
    469     def add_answer_observer(self, callback):
    470         """Adds the callback to the list of observers for new answers.
    471 
    472         @param callback: A callable object accepting a list of tuples (rrname,
    473         rrtype, data) where rrname is the domain name, rrtype the DNS record
    474         type and data is the information associated with the answers, similar to
    475         what cached_results() returns.
    476         """
    477         self._answer_callbacks.append(callback)
    478