Home | History | Annotate | Download | only in webpagereplay
      1 #!/usr/bin/env python
      2 # Copyright 2010 Google Inc. All Rights Reserved.
      3 #
      4 # Licensed under the Apache License, Version 2.0 (the "License");
      5 # you may not use this file except in compliance with the License.
      6 # You may obtain a copy of the License at
      7 #
      8 #      http://www.apache.org/licenses/LICENSE-2.0
      9 #
     10 # Unless required by applicable law or agreed to in writing, software
     11 # distributed under the License is distributed on an "AS IS" BASIS,
     12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 # See the License for the specific language governing permissions and
     14 # limitations under the License.
     15 
     16 import daemonserver
     17 import errno
     18 import logging
     19 import socket
     20 import SocketServer
     21 import threading
     22 import time
     23 
     24 from third_party.dns import flags
     25 from third_party.dns import message
     26 from third_party.dns import rcode
     27 from third_party.dns import resolver
     28 from third_party.dns import rdatatype
     29 from third_party import ipaddr
     30 
     31 
     32 
     33 class DnsProxyException(Exception):
     34   pass
     35 
     36 
     37 class RealDnsLookup(object):
     38   def __init__(self, name_servers):
     39     if '127.0.0.1' in name_servers:
     40       raise DnsProxyException(
     41           'Invalid nameserver: 127.0.0.1 (causes an infinte loop)')
     42     self.resolver = resolver.get_default_resolver()
     43     self.resolver.nameservers = name_servers
     44     self.dns_cache_lock = threading.Lock()
     45     self.dns_cache = {}
     46 
     47   @staticmethod
     48   def _IsIPAddress(hostname):
     49     try:
     50       socket.inet_aton(hostname)
     51       return True
     52     except socket.error:
     53       return False
     54 
     55   def __call__(self, hostname, rdtype=rdatatype.A):
     56     """Return real IP for a host.
     57 
     58     Args:
     59       host: a hostname ending with a period (e.g. "www.google.com.")
     60       rdtype: the query type (1 for 'A', 28 for 'AAAA')
     61     Returns:
     62       the IP address as a string (e.g. "192.168.25.2")
     63     """
     64     if self._IsIPAddress(hostname):
     65       return hostname
     66     self.dns_cache_lock.acquire()
     67     ip = self.dns_cache.get(hostname)
     68     self.dns_cache_lock.release()
     69     if ip:
     70       return ip
     71     try:
     72       answers = self.resolver.query(hostname, rdtype)
     73     except resolver.NXDOMAIN:
     74       return None
     75     except resolver.NoNameservers:
     76       logging.debug('_real_dns_lookup(%s) -> No nameserver.',
     77                     hostname)
     78       return None
     79     except (resolver.NoAnswer, resolver.Timeout) as ex:
     80       logging.debug('_real_dns_lookup(%s) -> None (%s)',
     81                     hostname, ex.__class__.__name__)
     82       return None
     83     if answers:
     84       ip = str(answers[0])
     85     self.dns_cache_lock.acquire()
     86     self.dns_cache[hostname] = ip
     87     self.dns_cache_lock.release()
     88     return ip
     89 
     90   def ClearCache(self):
     91     """Clear the dns cache."""
     92     self.dns_cache_lock.acquire()
     93     self.dns_cache.clear()
     94     self.dns_cache_lock.release()
     95 
     96 
     97 class ReplayDnsLookup(object):
     98   """Resolve DNS requests to replay host."""
     99   def __init__(self, replay_ip, filters=None):
    100     self.replay_ip = replay_ip
    101     self.filters = filters or []
    102 
    103   def __call__(self, hostname):
    104     ip = self.replay_ip
    105     for f in self.filters:
    106       ip = f(hostname, default_ip=ip)
    107     return ip
    108 
    109 
    110 class PrivateIpFilter(object):
    111   """Resolve private hosts to their real IPs and others to the Web proxy IP.
    112 
    113   Hosts in the given http_archive will resolve to the Web proxy IP without
    114   checking the real IP.
    115 
    116   This only supports IPv4 lookups.
    117   """
    118   def __init__(self, real_dns_lookup, http_archive):
    119     """Initialize PrivateIpDnsLookup.
    120 
    121     Args:
    122       real_dns_lookup: a function that resolves a host to an IP.
    123       http_archive: an instance of a HttpArchive
    124         Hosts is in the archive will always resolve to the web_proxy_ip
    125     """
    126     self.real_dns_lookup = real_dns_lookup
    127     self.http_archive = http_archive
    128     self.InitializeArchiveHosts()
    129 
    130   def __call__(self, host, default_ip):
    131     """Return real IPv4 for private hosts and Web proxy IP otherwise.
    132 
    133     Args:
    134       host: a hostname ending with a period (e.g. "www.google.com.")
    135     Returns:
    136       IP address as a string or None (if lookup fails)
    137     """
    138     ip = default_ip
    139     if host not in self.archive_hosts:
    140       real_ip = self.real_dns_lookup(host)
    141       if real_ip:
    142         if ipaddr.IPAddress(real_ip).is_private:
    143           ip = real_ip
    144       else:
    145         ip = None
    146     return ip
    147 
    148   def InitializeArchiveHosts(self):
    149     """Recompute the archive_hosts from the http_archive."""
    150     self.archive_hosts = set('%s.' % req.host.split(':')[0]
    151                              for req in self.http_archive)
    152 
    153 
    154 class DelayFilter(object):
    155   """Add a delay to replayed lookups."""
    156 
    157   def __init__(self, is_record_mode, delay_ms):
    158     self.is_record_mode = is_record_mode
    159     self.delay_ms = int(delay_ms)
    160 
    161   def __call__(self, host, default_ip):
    162     if not self.is_record_mode:
    163       time.sleep(self.delay_ms * 1000.0)
    164     return default_ip
    165 
    166   def SetRecordMode(self):
    167     self.is_record_mode = True
    168 
    169   def SetReplayMode(self):
    170     self.is_record_mode = False
    171 
    172 
    173 class UdpDnsHandler(SocketServer.DatagramRequestHandler):
    174   """Resolve DNS queries to localhost.
    175 
    176   Possible alternative implementation:
    177   http://howl.play-bow.org/pipermail/dnspython-users/2010-February/000119.html
    178   """
    179 
    180   STANDARD_QUERY_OPERATION_CODE = 0
    181 
    182   def handle(self):
    183     """Handle a DNS query.
    184 
    185     IPv6 requests (with rdtype AAAA) receive mismatched IPv4 responses
    186     (with rdtype A). To properly support IPv6, the http proxy would
    187     need both types of addresses. By default, Windows XP does not
    188     support IPv6.
    189     """
    190     self.data = self.rfile.read()
    191     self.transaction_id = self.data[0]
    192     self.flags = self.data[1]
    193     self.qa_counts = self.data[4:6]
    194     self.domain = ''
    195     operation_code = (ord(self.data[2]) >> 3) & 15
    196     if operation_code == self.STANDARD_QUERY_OPERATION_CODE:
    197       self.wire_domain = self.data[12:]
    198       self.domain = self._domain(self.wire_domain)
    199     else:
    200       logging.debug("DNS request with non-zero operation code: %s",
    201                     operation_code)
    202     ip = self.server.dns_lookup(self.domain)
    203     if ip is None:
    204       logging.debug('dnsproxy: %s -> NXDOMAIN', self.domain)
    205       response = self.get_dns_no_such_name_response()
    206     else:
    207       if ip == self.server.server_address[0]:
    208         logging.debug('dnsproxy: %s -> %s (replay web proxy)', self.domain, ip)
    209       else:
    210         logging.debug('dnsproxy: %s -> %s', self.domain, ip)
    211       response = self.get_dns_response(ip)
    212     self.wfile.write(response)
    213 
    214   @classmethod
    215   def _domain(cls, wire_domain):
    216     domain = ''
    217     index = 0
    218     length = ord(wire_domain[index])
    219     while length:
    220       domain += wire_domain[index + 1:index + length + 1] + '.'
    221       index += length + 1
    222       length = ord(wire_domain[index])
    223     return domain
    224 
    225   def get_dns_response(self, ip):
    226     packet = ''
    227     if self.domain:
    228       packet = (
    229           self.transaction_id +
    230           self.flags +
    231           '\x81\x80' +        # standard query response, no error
    232           self.qa_counts * 2 + '\x00\x00\x00\x00' +  # Q&A counts
    233           self.wire_domain +
    234           '\xc0\x0c'          # pointer to domain name
    235           '\x00\x01'          # resource record type ("A" host address)
    236           '\x00\x01'          # class of the data
    237           '\x00\x00\x00\x3c'  # ttl (seconds)
    238           '\x00\x04' +        # resource data length (4 bytes for ip)
    239           socket.inet_aton(ip)
    240           )
    241     return packet
    242 
    243   def get_dns_no_such_name_response(self):
    244     query_message = message.from_wire(self.data)
    245     response_message = message.make_response(query_message)
    246     response_message.flags |= flags.AA | flags.RA
    247     response_message.set_rcode(rcode.NXDOMAIN)
    248     return response_message.to_wire()
    249 
    250 
    251 class DnsProxyServer(SocketServer.ThreadingUDPServer,
    252                      daemonserver.DaemonServer):
    253   # Increase the request queue size. The default value, 5, is set in
    254   # SocketServer.TCPServer (the parent of BaseHTTPServer.HTTPServer).
    255   # Since we're intercepting many domains through this single server,
    256   # it is quite possible to get more than 5 concurrent requests.
    257   request_queue_size = 256
    258 
    259   # Allow sockets to be reused. See
    260   # http://svn.python.org/projects/python/trunk/Lib/SocketServer.py for more
    261   # details.
    262   allow_reuse_address = True
    263 
    264   # Don't prevent python from exiting when there is thread activity.
    265   daemon_threads = True
    266 
    267   def __init__(self, host='', port=53, dns_lookup=None):
    268     """Initialize DnsProxyServer.
    269 
    270     Args:
    271       host: a host string (name or IP) to bind the dns proxy and to which
    272         DNS requests will be resolved.
    273       port: an integer port on which to bind the proxy.
    274       dns_lookup: a list of filters to apply to lookup.
    275     """
    276     try:
    277       SocketServer.ThreadingUDPServer.__init__(
    278           self, (host, port), UdpDnsHandler)
    279     except socket.error, (error_number, msg):
    280       if error_number == errno.EACCES:
    281         raise DnsProxyException(
    282             'Unable to bind DNS server on (%s:%s)' % (host, port))
    283       raise
    284     self.dns_lookup = dns_lookup or (lambda host: self.server_address[0])
    285     self.server_port = self.server_address[1]
    286     logging.warning('DNS server started on %s:%d', self.server_address[0],
    287                                                    self.server_address[1])
    288 
    289   def cleanup(self):
    290     try:
    291       self.shutdown()
    292       self.server_close()
    293     except KeyboardInterrupt, e:
    294       pass
    295     logging.info('Stopped DNS server')
    296