#!/usr/bin/env python2.7
# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Starts a local DNS server for use in tests"""

import argparse
import sys
import yaml
import signal
import os
import threading
import time

import twisted
import twisted.internet
import twisted.internet.reactor
import twisted.internet.threads
import twisted.internet.defer
import twisted.internet.protocol
import twisted.names
import twisted.names.client
import twisted.names.dns
import twisted.names.server
from twisted.names import client, server, common, authority, dns
import argparse
import platform

# A record that is always served, so callers can probe whether the server is up.
_SERVER_HEALTH_CHECK_RECORD_NAME = 'health-check-local-dns-server-is-alive.resolver-tests.grpctestingexp'  # missing end '.' for twisted syntax
_SERVER_HEALTH_CHECK_RECORD_DATA = '123.123.123.123'

# A DNS <character-string> (the unit a TXT record is made of) carries a
# one-octet length prefix, so each chunk may hold at most 255 bytes.
_MAX_TXT_CHUNK_BYTES = 255


class NoFileAuthority(authority.FileAuthority):
    """A FileAuthority whose SOA and records are supplied in memory.

    FileAuthority.__init__ (which loads records from a file) is deliberately
    bypassed; only ResolverBase initialisation is performed before the SOA
    and record table are assigned directly.
    """

    def __init__(self, soa, records):
        # skip FileAuthority
        common.ResolverBase.__init__(self)
        self.soa = soa
        self.records = records


def _to_bytes(value):
    """Return *value* as a byte string, UTF-8 encoding it if it is text.

    Under Python 2 a plain ``str`` already is ``bytes`` and is returned
    unchanged. Twisted's DNS authority and TXT record code work on bytes
    under Python 3, so record names and TXT payloads are normalised here.
    """
    if isinstance(value, bytes):
        return value
    return value.encode('utf-8')


def _split_txt_data(txt_data):
    """Split *txt_data* into consecutive byte chunks of at most 255 bytes.

    The limit is applied after encoding, so it counts bytes, not characters.
    An empty input yields an empty list.
    """
    txt_bytes = _to_bytes(txt_data)
    return [txt_bytes[start:start + _MAX_TXT_CHUNK_BYTES]
            for start in range(0, len(txt_bytes), _MAX_TXT_CHUNK_BYTES)]


def _load_records(records_config_path):
    """Parse the resolver test records YAML file into Twisted DNS records.

    Args:
      records_config_path: path of the YAML records config file. It must
        contain 'resolver_tests_common_zone_name' and
        'resolver_component_tests' (a list of groups, each with a 'records'
        mapping of name -> list of {type, data, TTL}).

    Returns:
      A (common_zone_name, all_records) tuple. all_records maps each record
      name (bytes, trailing '.' removed) to the list of Twisted record
      objects served for it, and includes the server health check record.
      Entries whose 'type' is not A, AAAA, SRV or TXT are ignored.

    Raises:
      ValueError: if a fully-qualified record name does not end with '.'.
    """
    all_records = {}

    def _push_record(name, r):
        print('pushing record: |%s|' % name)
        all_records.setdefault(_to_bytes(name), []).append(r)

    with open(records_config_path) as config:
        # safe_load: the config is plain data. yaml.load without an explicit
        # Loader is unsafe and is rejected outright by PyYAML >= 6.
        test_records_config = yaml.safe_load(config)
    common_zone_name = test_records_config['resolver_tests_common_zone_name']
    for group in test_records_config['resolver_component_tests']:
        for name, records in group['records'].items():
            for record in records:
                r_type = record['type']
                r_data = record['data']
                r_ttl = int(record['TTL'])
                record_full_name = '%s.%s' % (name, common_zone_name)
                if not record_full_name.endswith('.'):
                    raise ValueError(
                        "record name %r does not end with '.'; expected "
                        "resolver_tests_common_zone_name to end with '.'"
                        % record_full_name)
                # Records are stored without the trailing root '.'.
                record_full_name = record_full_name[:-1]
                if r_type == 'A':
                    _push_record(record_full_name,
                                 dns.Record_A(r_data, ttl=r_ttl))
                elif r_type == 'AAAA':
                    _push_record(record_full_name,
                                 dns.Record_AAAA(r_data, ttl=r_ttl))
                elif r_type == 'SRV':
                    # SRV data is "<priority> <weight> <port> <target>".
                    p, w, port, target = r_data.split(' ')
                    target_full_name = _to_bytes(
                        '%s.%s' % (target, common_zone_name))
                    _push_record(
                        record_full_name,
                        dns.Record_SRV(int(p), int(w), int(port),
                                       target_full_name, ttl=r_ttl))
                elif r_type == 'TXT':
                    _push_record(
                        record_full_name,
                        dns.Record_TXT(*_split_txt_data(r_data), ttl=r_ttl))
    # Server health check record
    _push_record(_SERVER_HEALTH_CHECK_RECORD_NAME,
                 dns.Record_A(_SERVER_HEALTH_CHECK_RECORD_DATA, ttl=0))
    return common_zone_name, all_records


def start_local_dns_server(args):
    """Serve the configured test records over TCP and UDP on args.port.

    Args:
      args: parsed command line; uses args.port and args.records_config_path.

    Blocks inside twisted.internet.reactor.run() until the reactor stops.
    """
    common_zone_name, all_records = _load_records(args.records_config_path)
    zone_name = _to_bytes(common_zone_name)
    soa_record = dns.Record_SOA(mname=zone_name)
    test_domain_com = NoFileAuthority(
        soa=(zone_name, soa_record),
        records=all_records,
    )
    # Named "factory" (not "server") so the imported twisted.names.server
    # module is not shadowed.
    factory = twisted.names.server.DNSServerFactory(
        authorities=[test_domain_com], verbose=2)
    factory.noisy = 2
    twisted.internet.reactor.listenTCP(args.port, factory)
    dns_proto = twisted.names.dns.DNSDatagramProtocol(factory)
    dns_proto.noisy = 2
    twisted.internet.reactor.listenUDP(args.port, dns_proto)
    print('starting local dns server on 127.0.0.1:%s' % args.port)
    print('starting twisted.internet.reactor')
    twisted.internet.reactor.suggestThreadPoolSize(1)
    twisted.internet.reactor.run()


def _quit_on_signal(signum, _frame):
    """Signal handler: stop the reactor if it is running, then exit 0."""
    print('Received SIGNAL %d. Quitting with exit code 0' % signum)
    # The signal may arrive before reactor.run() is reached (e.g. while the
    # records file is still being parsed). reactor.stop() would then raise
    # ReactorNotRunning instead of letting the process exit cleanly.
    if twisted.internet.reactor.running:
        twisted.internet.reactor.stop()
    sys.stdout.flush()
    sys.exit(0)


def flush_stdout_loop():
    """Flush stdout once per second, then SIGTERM this process.

    The loop runs max_timeouts iterations of sleep_time seconds (about two
    minutes) before sending SIGTERM to the current process.
    """
    sleep_time = 1
    # Prevent zombies. Tests that use this server are short-lived.
    max_timeouts = 60 * 2
    for _ in range(max_timeouts):
        sys.stdout.flush()
        time.sleep(sleep_time)
    print('Process timeout reached, or cancelled. Exitting 0.')
    os.kill(os.getpid(), signal.SIGTERM)


def main():
    """Parse flags, install signal handlers and run the DNS server."""
    argp = argparse.ArgumentParser(
        description='Local DNS Server for resolver tests')
    argp.add_argument('-p', '--port', required=True, type=int,
                      help='Port for DNS server to listen on for TCP and UDP.')
    argp.add_argument('-r', '--records_config_path', required=True, type=str,
                      help='Path to the resolver_test_record_groups.yaml file.')
    args = argp.parse_args()
    signal.signal(signal.SIGTERM, _quit_on_signal)
    signal.signal(signal.SIGINT, _quit_on_signal)
    output_flush_thread = threading.Thread(target=flush_stdout_loop)
    # Daemon thread, so it never keeps the process alive on its own.
    # (Attribute form works on Python 2; setDaemon() is deprecated.)
    output_flush_thread.daemon = True
    output_flush_thread.start()
    start_local_dns_server(args)


if __name__ == '__main__':
    main()