1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/dns/dns_test_util.h" 6 7 #include <string> 8 9 #include "base/bind.h" 10 #include "base/memory/weak_ptr.h" 11 #include "base/message_loop/message_loop.h" 12 #include "base/sys_byteorder.h" 13 #include "net/base/big_endian.h" 14 #include "net/base/dns_util.h" 15 #include "net/base/io_buffer.h" 16 #include "net/base/net_errors.h" 17 #include "net/dns/address_sorter.h" 18 #include "net/dns/dns_query.h" 19 #include "net/dns/dns_response.h" 20 #include "net/dns/dns_transaction.h" 21 #include "testing/gtest/include/gtest/gtest.h" 22 23 namespace net { 24 namespace { 25 26 class MockAddressSorter : public AddressSorter { 27 public: 28 virtual ~MockAddressSorter() {} 29 virtual void Sort(const AddressList& list, 30 const CallbackType& callback) const OVERRIDE { 31 // Do nothing. 32 callback.Run(true, list); 33 } 34 }; 35 36 // A DnsTransaction which uses MockDnsClientRuleList to determine the response. 37 class MockTransaction : public DnsTransaction, 38 public base::SupportsWeakPtr<MockTransaction> { 39 public: 40 MockTransaction(const MockDnsClientRuleList& rules, 41 const std::string& hostname, 42 uint16 qtype, 43 const DnsTransactionFactory::CallbackType& callback) 44 : result_(MockDnsClientRule::FAIL), 45 hostname_(hostname), 46 qtype_(qtype), 47 callback_(callback), 48 started_(false), 49 delayed_(false) { 50 // Find the relevant rule which matches |qtype| and prefix of |hostname|. 51 for (size_t i = 0; i < rules.size(); ++i) { 52 const std::string& prefix = rules[i].prefix; 53 if ((rules[i].qtype == qtype) && 54 (hostname.size() >= prefix.size()) && 55 (hostname.compare(0, prefix.size(), prefix) == 0)) { 56 result_ = rules[i].result; 57 delayed_ = rules[i].delay; 58 break; 59 } 60 } 61 } 62 63 virtual const std::string& GetHostname() const OVERRIDE { 64 return hostname_; 65 } 66 67 virtual uint16 GetType() const OVERRIDE { 68 return qtype_; 69 } 70 71 virtual void Start() OVERRIDE { 72 EXPECT_FALSE(started_); 73 started_ = true; 74 if (delayed_) 75 return; 76 // Using WeakPtr to cleanly cancel when transaction is destroyed. 77 base::MessageLoop::current()->PostTask( 78 FROM_HERE, base::Bind(&MockTransaction::Finish, AsWeakPtr())); 79 } 80 81 void FinishDelayedTransaction() { 82 EXPECT_TRUE(delayed_); 83 delayed_ = false; 84 Finish(); 85 } 86 87 bool delayed() const { return delayed_; } 88 89 private: 90 void Finish() { 91 switch (result_) { 92 case MockDnsClientRule::EMPTY: 93 case MockDnsClientRule::OK: { 94 std::string qname; 95 DNSDomainFromDot(hostname_, &qname); 96 DnsQuery query(0, qname, qtype_); 97 98 DnsResponse response; 99 char* buffer = response.io_buffer()->data(); 100 int nbytes = query.io_buffer()->size(); 101 memcpy(buffer, query.io_buffer()->data(), nbytes); 102 dns_protocol::Header* header = 103 reinterpret_cast<dns_protocol::Header*>(buffer); 104 header->flags |= dns_protocol::kFlagResponse; 105 106 if (MockDnsClientRule::OK == result_) { 107 const uint16 kPointerToQueryName = 108 static_cast<uint16>(0xc000 | sizeof(*header)); 109 110 const uint32 kTTL = 86400; // One day. 111 112 // Size of RDATA which is a IPv4 or IPv6 address. 113 size_t rdata_size = qtype_ == net::dns_protocol::kTypeA ? 114 net::kIPv4AddressSize : net::kIPv6AddressSize; 115 116 // 12 is the sum of sizes of the compressed name reference, TYPE, 117 // CLASS, TTL and RDLENGTH. 118 size_t answer_size = 12 + rdata_size; 119 120 // Write answer with loopback IP address. 121 header->ancount = base::HostToNet16(1); 122 BigEndianWriter writer(buffer + nbytes, answer_size); 123 writer.WriteU16(kPointerToQueryName); 124 writer.WriteU16(qtype_); 125 writer.WriteU16(net::dns_protocol::kClassIN); 126 writer.WriteU32(kTTL); 127 writer.WriteU16(rdata_size); 128 if (qtype_ == net::dns_protocol::kTypeA) { 129 char kIPv4Loopback[] = { 0x7f, 0, 0, 1 }; 130 writer.WriteBytes(kIPv4Loopback, sizeof(kIPv4Loopback)); 131 } else { 132 char kIPv6Loopback[] = { 0, 0, 0, 0, 0, 0, 0, 0, 133 0, 0, 0, 0, 0, 0, 0, 1 }; 134 writer.WriteBytes(kIPv6Loopback, sizeof(kIPv6Loopback)); 135 } 136 nbytes += answer_size; 137 } 138 EXPECT_TRUE(response.InitParse(nbytes, query)); 139 callback_.Run(this, OK, &response); 140 } break; 141 case MockDnsClientRule::FAIL: 142 callback_.Run(this, ERR_NAME_NOT_RESOLVED, NULL); 143 break; 144 case MockDnsClientRule::TIMEOUT: 145 callback_.Run(this, ERR_DNS_TIMED_OUT, NULL); 146 break; 147 default: 148 NOTREACHED(); 149 break; 150 } 151 } 152 153 MockDnsClientRule::Result result_; 154 const std::string hostname_; 155 const uint16 qtype_; 156 DnsTransactionFactory::CallbackType callback_; 157 bool started_; 158 bool delayed_; 159 }; 160 161 } // namespace 162 163 // A DnsTransactionFactory which creates MockTransaction. 164 class MockTransactionFactory : public DnsTransactionFactory { 165 public: 166 explicit MockTransactionFactory(const MockDnsClientRuleList& rules) 167 : rules_(rules) {} 168 169 virtual ~MockTransactionFactory() {} 170 171 virtual scoped_ptr<DnsTransaction> CreateTransaction( 172 const std::string& hostname, 173 uint16 qtype, 174 const DnsTransactionFactory::CallbackType& callback, 175 const BoundNetLog&) OVERRIDE { 176 MockTransaction* transaction = 177 new MockTransaction(rules_, hostname, qtype, callback); 178 if (transaction->delayed()) 179 delayed_transactions_.push_back(transaction->AsWeakPtr()); 180 return scoped_ptr<DnsTransaction>(transaction); 181 } 182 183 void CompleteDelayedTransactions() { 184 DelayedTransactionList old_delayed_transactions; 185 old_delayed_transactions.swap(delayed_transactions_); 186 for (DelayedTransactionList::iterator it = old_delayed_transactions.begin(); 187 it != old_delayed_transactions.end(); ++it) { 188 if (it->get()) 189 (*it)->FinishDelayedTransaction(); 190 } 191 } 192 193 private: 194 typedef std::vector<base::WeakPtr<MockTransaction> > DelayedTransactionList; 195 196 MockDnsClientRuleList rules_; 197 DelayedTransactionList delayed_transactions_; 198 }; 199 200 MockDnsClient::MockDnsClient(const DnsConfig& config, 201 const MockDnsClientRuleList& rules) 202 : config_(config), 203 factory_(new MockTransactionFactory(rules)), 204 address_sorter_(new MockAddressSorter()) { 205 } 206 207 MockDnsClient::~MockDnsClient() {} 208 209 void MockDnsClient::SetConfig(const DnsConfig& config) { 210 config_ = config; 211 } 212 213 const DnsConfig* MockDnsClient::GetConfig() const { 214 return config_.IsValid() ? &config_ : NULL; 215 } 216 217 DnsTransactionFactory* MockDnsClient::GetTransactionFactory() { 218 return config_.IsValid() ? factory_.get() : NULL; 219 } 220 221 AddressSorter* MockDnsClient::GetAddressSorter() { 222 return address_sorter_.get(); 223 } 224 225 void MockDnsClient::CompleteDelayedTransactions() { 226 factory_->CompleteDelayedTransactions(); 227 } 228 229 } // namespace net 230