Home | History | Annotate | Download | only in utils
      1 #!/usr/bin/env python2.7
      2 # Copyright 2015 gRPC authors.
      3 #
      4 # Licensed under the Apache License, Version 2.0 (the "License");
      5 # you may not use this file except in compliance with the License.
      6 # You may obtain a copy of the License at
      7 #
      8 #     http://www.apache.org/licenses/LICENSE-2.0
      9 #
     10 # Unless required by applicable law or agreed to in writing, software
     11 # distributed under the License is distributed on an "AS IS" BASIS,
     12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 # See the License for the specific language governing permissions and
     14 # limitations under the License.
     15 
     16 """Starts a local DNS server for use in tests"""
     17 
     18 import argparse
     19 import sys
     20 import yaml
     21 import signal
     22 import os
     23 import threading
     24 import time
     25 
     26 import twisted
     27 import twisted.internet
     28 import twisted.internet.reactor
     29 import twisted.internet.threads
     30 import twisted.internet.defer
     31 import twisted.internet.protocol
     32 import twisted.names
     33 import twisted.names.client
     34 import twisted.names.dns
     35 import twisted.names.server
     36 from twisted.names import client, server, common, authority, dns
     37 import argparse
     38 import platform
     39 
     40 _SERVER_HEALTH_CHECK_RECORD_NAME = 'health-check-local-dns-server-is-alive.resolver-tests.grpctestingexp' # missing end '.' for twisted syntax
     41 _SERVER_HEALTH_CHECK_RECORD_DATA = '123.123.123.123'
     42 
     43 class NoFileAuthority(authority.FileAuthority):
     44   def __init__(self, soa, records):
     45     # skip FileAuthority
     46     common.ResolverBase.__init__(self)
     47     self.soa = soa
     48     self.records = records
     49 
     50 def start_local_dns_server(args):
     51   all_records = {}
     52   def _push_record(name, r):
     53     print('pushing record: |%s|' % name)
     54     if all_records.get(name) is not None:
     55       all_records[name].append(r)
     56       return
     57     all_records[name] = [r]
     58 
     59   def _maybe_split_up_txt_data(name, txt_data, r_ttl):
     60     start = 0
     61     txt_data_list = []
     62     while len(txt_data[start:]) > 0:
     63       next_read = len(txt_data[start:])
     64       if next_read > 255:
     65         next_read = 255
     66       txt_data_list.append(txt_data[start:start+next_read])
     67       start += next_read
     68     _push_record(name, dns.Record_TXT(*txt_data_list, ttl=r_ttl))
     69 
     70   with open(args.records_config_path) as config:
     71     test_records_config = yaml.load(config)
     72   common_zone_name = test_records_config['resolver_tests_common_zone_name']
     73   for group in test_records_config['resolver_component_tests']:
     74     for name in group['records'].keys():
     75       for record in group['records'][name]:
     76         r_type = record['type']
     77         r_data = record['data']
     78         r_ttl = int(record['TTL'])
     79         record_full_name = '%s.%s' % (name, common_zone_name)
     80         assert record_full_name[-1] == '.'
     81         record_full_name = record_full_name[:-1]
     82         if r_type == 'A':
     83           _push_record(record_full_name, dns.Record_A(r_data, ttl=r_ttl))
     84         if r_type == 'AAAA':
     85           _push_record(record_full_name, dns.Record_AAAA(r_data, ttl=r_ttl))
     86         if r_type == 'SRV':
     87           p, w, port, target = r_data.split(' ')
     88           p = int(p)
     89           w = int(w)
     90           port = int(port)
     91           target_full_name = '%s.%s' % (target, common_zone_name)
     92           r_data = '%s %s %s %s' % (p, w, port, target_full_name)
     93           _push_record(record_full_name, dns.Record_SRV(p, w, port, target_full_name, ttl=r_ttl))
     94         if r_type == 'TXT':
     95           _maybe_split_up_txt_data(record_full_name, r_data, r_ttl)
     96   # Server health check record
     97   _push_record(_SERVER_HEALTH_CHECK_RECORD_NAME, dns.Record_A(_SERVER_HEALTH_CHECK_RECORD_DATA, ttl=0))
     98   soa_record = dns.Record_SOA(mname = common_zone_name)
     99   test_domain_com = NoFileAuthority(
    100     soa = (common_zone_name, soa_record),
    101     records = all_records,
    102   )
    103   server = twisted.names.server.DNSServerFactory(
    104       authorities=[test_domain_com], verbose=2)
    105   server.noisy = 2
    106   twisted.internet.reactor.listenTCP(args.port, server)
    107   dns_proto = twisted.names.dns.DNSDatagramProtocol(server)
    108   dns_proto.noisy = 2
    109   twisted.internet.reactor.listenUDP(args.port, dns_proto)
    110   print('starting local dns server on 127.0.0.1:%s' % args.port)
    111   print('starting twisted.internet.reactor')
    112   twisted.internet.reactor.suggestThreadPoolSize(1)
    113   twisted.internet.reactor.run()
    114 
    115 def _quit_on_signal(signum, _frame):
    116   print('Received SIGNAL %d. Quitting with exit code 0' % signum)
    117   twisted.internet.reactor.stop()
    118   sys.stdout.flush()
    119   sys.exit(0)
    120 
    121 def flush_stdout_loop():
    122   num_timeouts_so_far = 0
    123   sleep_time = 1
    124   # Prevent zombies. Tests that use this server are short-lived.
    125   max_timeouts = 60 * 2
    126   while num_timeouts_so_far < max_timeouts:
    127     sys.stdout.flush()
    128     time.sleep(sleep_time)
    129     num_timeouts_so_far += 1
    130   print('Process timeout reached, or cancelled. Exitting 0.')
    131   os.kill(os.getpid(), signal.SIGTERM)
    132 
    133 def main():
    134   argp = argparse.ArgumentParser(description='Local DNS Server for resolver tests')
    135   argp.add_argument('-p', '--port', default=None, type=int,
    136                     help='Port for DNS server to listen on for TCP and UDP.')
    137   argp.add_argument('-r', '--records_config_path', default=None, type=str,
    138                     help=('Directory of resolver_test_record_groups.yaml file. '
    139                           'Defauls to path needed when the test is invoked as part of run_tests.py.'))
    140   args = argp.parse_args()
    141   signal.signal(signal.SIGTERM, _quit_on_signal)
    142   signal.signal(signal.SIGINT, _quit_on_signal)
    143   output_flush_thread = threading.Thread(target=flush_stdout_loop)
    144   output_flush_thread.setDaemon(True)
    145   output_flush_thread.start()
    146   start_local_dns_server(args)
    147 
# Script entry point: run the server only when executed directly, not on import.
if __name__ == '__main__':
  main()
    150