1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/dns/dns_test_util.h" 6 7 #include <string> 8 9 #include "base/bind.h" 10 #include "base/memory/weak_ptr.h" 11 #include "base/message_loop/message_loop.h" 12 #include "base/sys_byteorder.h" 13 #include "net/base/big_endian.h" 14 #include "net/base/dns_util.h" 15 #include "net/base/io_buffer.h" 16 #include "net/base/net_errors.h" 17 #include "net/dns/address_sorter.h" 18 #include "net/dns/dns_client.h" 19 #include "net/dns/dns_config_service.h" 20 #include "net/dns/dns_protocol.h" 21 #include "net/dns/dns_query.h" 22 #include "net/dns/dns_response.h" 23 #include "net/dns/dns_transaction.h" 24 #include "testing/gtest/include/gtest/gtest.h" 25 26 namespace net { 27 namespace { 28 29 // A DnsTransaction which uses MockDnsClientRuleList to determine the response. 30 class MockTransaction : public DnsTransaction, 31 public base::SupportsWeakPtr<MockTransaction> { 32 public: 33 MockTransaction(const MockDnsClientRuleList& rules, 34 const std::string& hostname, 35 uint16 qtype, 36 const DnsTransactionFactory::CallbackType& callback) 37 : result_(MockDnsClientRule::FAIL), 38 hostname_(hostname), 39 qtype_(qtype), 40 callback_(callback), 41 started_(false) { 42 // Find the relevant rule which matches |qtype| and prefix of |hostname|. 43 for (size_t i = 0; i < rules.size(); ++i) { 44 const std::string& prefix = rules[i].prefix; 45 if ((rules[i].qtype == qtype) && 46 (hostname.size() >= prefix.size()) && 47 (hostname.compare(0, prefix.size(), prefix) == 0)) { 48 result_ = rules[i].result; 49 break; 50 } 51 } 52 } 53 54 virtual const std::string& GetHostname() const OVERRIDE { 55 return hostname_; 56 } 57 58 virtual uint16 GetType() const OVERRIDE { 59 return qtype_; 60 } 61 62 virtual void Start() OVERRIDE { 63 EXPECT_FALSE(started_); 64 started_ = true; 65 // Using WeakPtr to cleanly cancel when transaction is destroyed. 66 base::MessageLoop::current()->PostTask( 67 FROM_HERE, base::Bind(&MockTransaction::Finish, AsWeakPtr())); 68 } 69 70 private: 71 void Finish() { 72 switch (result_) { 73 case MockDnsClientRule::EMPTY: 74 case MockDnsClientRule::OK: { 75 std::string qname; 76 DNSDomainFromDot(hostname_, &qname); 77 DnsQuery query(0, qname, qtype_); 78 79 DnsResponse response; 80 char* buffer = response.io_buffer()->data(); 81 int nbytes = query.io_buffer()->size(); 82 memcpy(buffer, query.io_buffer()->data(), nbytes); 83 dns_protocol::Header* header = 84 reinterpret_cast<dns_protocol::Header*>(buffer); 85 header->flags |= dns_protocol::kFlagResponse; 86 87 if (MockDnsClientRule::OK == result_) { 88 const uint16 kPointerToQueryName = 89 static_cast<uint16>(0xc000 | sizeof(*header)); 90 91 const uint32 kTTL = 86400; // One day. 92 93 // Size of RDATA which is a IPv4 or IPv6 address. 94 size_t rdata_size = qtype_ == net::dns_protocol::kTypeA ? 95 net::kIPv4AddressSize : net::kIPv6AddressSize; 96 97 // 12 is the sum of sizes of the compressed name reference, TYPE, 98 // CLASS, TTL and RDLENGTH. 99 size_t answer_size = 12 + rdata_size; 100 101 // Write answer with loopback IP address. 102 header->ancount = base::HostToNet16(1); 103 BigEndianWriter writer(buffer + nbytes, answer_size); 104 writer.WriteU16(kPointerToQueryName); 105 writer.WriteU16(qtype_); 106 writer.WriteU16(net::dns_protocol::kClassIN); 107 writer.WriteU32(kTTL); 108 writer.WriteU16(rdata_size); 109 if (qtype_ == net::dns_protocol::kTypeA) { 110 char kIPv4Loopback[] = { 0x7f, 0, 0, 1 }; 111 writer.WriteBytes(kIPv4Loopback, sizeof(kIPv4Loopback)); 112 } else { 113 char kIPv6Loopback[] = { 0, 0, 0, 0, 0, 0, 0, 0, 114 0, 0, 0, 0, 0, 0, 0, 1 }; 115 writer.WriteBytes(kIPv6Loopback, sizeof(kIPv6Loopback)); 116 } 117 nbytes += answer_size; 118 } 119 EXPECT_TRUE(response.InitParse(nbytes, query)); 120 callback_.Run(this, OK, &response); 121 } break; 122 case MockDnsClientRule::FAIL: 123 callback_.Run(this, ERR_NAME_NOT_RESOLVED, NULL); 124 break; 125 case MockDnsClientRule::TIMEOUT: 126 callback_.Run(this, ERR_DNS_TIMED_OUT, NULL); 127 break; 128 default: 129 NOTREACHED(); 130 break; 131 } 132 } 133 134 MockDnsClientRule::Result result_; 135 const std::string hostname_; 136 const uint16 qtype_; 137 DnsTransactionFactory::CallbackType callback_; 138 bool started_; 139 }; 140 141 142 // A DnsTransactionFactory which creates MockTransaction. 143 class MockTransactionFactory : public DnsTransactionFactory { 144 public: 145 explicit MockTransactionFactory(const MockDnsClientRuleList& rules) 146 : rules_(rules) {} 147 virtual ~MockTransactionFactory() {} 148 149 virtual scoped_ptr<DnsTransaction> CreateTransaction( 150 const std::string& hostname, 151 uint16 qtype, 152 const DnsTransactionFactory::CallbackType& callback, 153 const BoundNetLog&) OVERRIDE { 154 return scoped_ptr<DnsTransaction>( 155 new MockTransaction(rules_, hostname, qtype, callback)); 156 } 157 158 private: 159 MockDnsClientRuleList rules_; 160 }; 161 162 class MockAddressSorter : public AddressSorter { 163 public: 164 virtual ~MockAddressSorter() {} 165 virtual void Sort(const AddressList& list, 166 const CallbackType& callback) const OVERRIDE { 167 // Do nothing. 168 callback.Run(true, list); 169 } 170 }; 171 172 // MockDnsClient provides MockTransactionFactory. 173 class MockDnsClient : public DnsClient { 174 public: 175 MockDnsClient(const DnsConfig& config, 176 const MockDnsClientRuleList& rules) 177 : config_(config), factory_(rules) {} 178 virtual ~MockDnsClient() {} 179 180 virtual void SetConfig(const DnsConfig& config) OVERRIDE { 181 config_ = config; 182 } 183 184 virtual const DnsConfig* GetConfig() const OVERRIDE { 185 return config_.IsValid() ? &config_ : NULL; 186 } 187 188 virtual DnsTransactionFactory* GetTransactionFactory() OVERRIDE { 189 return config_.IsValid() ? &factory_ : NULL; 190 } 191 192 virtual AddressSorter* GetAddressSorter() OVERRIDE { 193 return &address_sorter_; 194 } 195 196 private: 197 DnsConfig config_; 198 MockTransactionFactory factory_; 199 MockAddressSorter address_sorter_; 200 }; 201 202 } // namespace 203 204 // static 205 scoped_ptr<DnsClient> CreateMockDnsClient(const DnsConfig& config, 206 const MockDnsClientRuleList& rules) { 207 return scoped_ptr<DnsClient>(new MockDnsClient(config, rules)); 208 } 209 210 } // namespace net 211