Home | History | Annotate | Download | only in dns
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/dns/dns_test_util.h"
      6 
      7 #include <string>
      8 
      9 #include "base/big_endian.h"
     10 #include "base/bind.h"
     11 #include "base/memory/weak_ptr.h"
     12 #include "base/message_loop/message_loop.h"
     13 #include "base/sys_byteorder.h"
     14 #include "net/base/dns_util.h"
     15 #include "net/base/io_buffer.h"
     16 #include "net/base/net_errors.h"
     17 #include "net/dns/address_sorter.h"
     18 #include "net/dns/dns_query.h"
     19 #include "net/dns/dns_response.h"
     20 #include "net/dns/dns_transaction.h"
     21 #include "testing/gtest/include/gtest/gtest.h"
     22 
     23 namespace net {
     24 namespace {
     25 
     26 class MockAddressSorter : public AddressSorter {
     27  public:
     28   virtual ~MockAddressSorter() {}
     29   virtual void Sort(const AddressList& list,
     30                     const CallbackType& callback) const OVERRIDE {
     31     // Do nothing.
     32     callback.Run(true, list);
     33   }
     34 };
     35 
     36 // A DnsTransaction which uses MockDnsClientRuleList to determine the response.
     37 class MockTransaction : public DnsTransaction,
     38                         public base::SupportsWeakPtr<MockTransaction> {
     39  public:
     40   MockTransaction(const MockDnsClientRuleList& rules,
     41                   const std::string& hostname,
     42                   uint16 qtype,
     43                   const DnsTransactionFactory::CallbackType& callback)
     44       : result_(MockDnsClientRule::FAIL),
     45         hostname_(hostname),
     46         qtype_(qtype),
     47         callback_(callback),
     48         started_(false),
     49         delayed_(false) {
     50     // Find the relevant rule which matches |qtype| and prefix of |hostname|.
     51     for (size_t i = 0; i < rules.size(); ++i) {
     52       const std::string& prefix = rules[i].prefix;
     53       if ((rules[i].qtype == qtype) &&
     54           (hostname.size() >= prefix.size()) &&
     55           (hostname.compare(0, prefix.size(), prefix) == 0)) {
     56         result_ = rules[i].result;
     57         delayed_ = rules[i].delay;
     58         break;
     59       }
     60     }
     61   }
     62 
     63   virtual const std::string& GetHostname() const OVERRIDE {
     64     return hostname_;
     65   }
     66 
     67   virtual uint16 GetType() const OVERRIDE {
     68     return qtype_;
     69   }
     70 
     71   virtual void Start() OVERRIDE {
     72     EXPECT_FALSE(started_);
     73     started_ = true;
     74     if (delayed_)
     75       return;
     76     // Using WeakPtr to cleanly cancel when transaction is destroyed.
     77     base::MessageLoop::current()->PostTask(
     78         FROM_HERE, base::Bind(&MockTransaction::Finish, AsWeakPtr()));
     79   }
     80 
     81   void FinishDelayedTransaction() {
     82     EXPECT_TRUE(delayed_);
     83     delayed_ = false;
     84     Finish();
     85   }
     86 
     87   bool delayed() const { return delayed_; }
     88 
     89  private:
     90   void Finish() {
     91     switch (result_) {
     92       case MockDnsClientRule::EMPTY:
     93       case MockDnsClientRule::OK: {
     94         std::string qname;
     95         DNSDomainFromDot(hostname_, &qname);
     96         DnsQuery query(0, qname, qtype_);
     97 
     98         DnsResponse response;
     99         char* buffer = response.io_buffer()->data();
    100         int nbytes = query.io_buffer()->size();
    101         memcpy(buffer, query.io_buffer()->data(), nbytes);
    102         dns_protocol::Header* header =
    103             reinterpret_cast<dns_protocol::Header*>(buffer);
    104         header->flags |= dns_protocol::kFlagResponse;
    105 
    106         if (MockDnsClientRule::OK == result_) {
    107           const uint16 kPointerToQueryName =
    108               static_cast<uint16>(0xc000 | sizeof(*header));
    109 
    110           const uint32 kTTL = 86400;  // One day.
    111 
    112           // Size of RDATA which is a IPv4 or IPv6 address.
    113           size_t rdata_size = qtype_ == net::dns_protocol::kTypeA ?
    114                               net::kIPv4AddressSize : net::kIPv6AddressSize;
    115 
    116           // 12 is the sum of sizes of the compressed name reference, TYPE,
    117           // CLASS, TTL and RDLENGTH.
    118           size_t answer_size = 12 + rdata_size;
    119 
    120           // Write answer with loopback IP address.
    121           header->ancount = base::HostToNet16(1);
    122           base::BigEndianWriter writer(buffer + nbytes, answer_size);
    123           writer.WriteU16(kPointerToQueryName);
    124           writer.WriteU16(qtype_);
    125           writer.WriteU16(net::dns_protocol::kClassIN);
    126           writer.WriteU32(kTTL);
    127           writer.WriteU16(rdata_size);
    128           if (qtype_ == net::dns_protocol::kTypeA) {
    129             char kIPv4Loopback[] = { 0x7f, 0, 0, 1 };
    130             writer.WriteBytes(kIPv4Loopback, sizeof(kIPv4Loopback));
    131           } else {
    132             char kIPv6Loopback[] = { 0, 0, 0, 0, 0, 0, 0, 0,
    133                                      0, 0, 0, 0, 0, 0, 0, 1 };
    134             writer.WriteBytes(kIPv6Loopback, sizeof(kIPv6Loopback));
    135           }
    136           nbytes += answer_size;
    137         }
    138         EXPECT_TRUE(response.InitParse(nbytes, query));
    139         callback_.Run(this, OK, &response);
    140       } break;
    141       case MockDnsClientRule::FAIL:
    142         callback_.Run(this, ERR_NAME_NOT_RESOLVED, NULL);
    143         break;
    144       case MockDnsClientRule::TIMEOUT:
    145         callback_.Run(this, ERR_DNS_TIMED_OUT, NULL);
    146         break;
    147       default:
    148         NOTREACHED();
    149         break;
    150     }
    151   }
    152 
    153   MockDnsClientRule::Result result_;
    154   const std::string hostname_;
    155   const uint16 qtype_;
    156   DnsTransactionFactory::CallbackType callback_;
    157   bool started_;
    158   bool delayed_;
    159 };
    160 
    161 }  // namespace
    162 
    163 // A DnsTransactionFactory which creates MockTransaction.
    164 class MockTransactionFactory : public DnsTransactionFactory {
    165  public:
    166   explicit MockTransactionFactory(const MockDnsClientRuleList& rules)
    167       : rules_(rules) {}
    168 
    169   virtual ~MockTransactionFactory() {}
    170 
    171   virtual scoped_ptr<DnsTransaction> CreateTransaction(
    172       const std::string& hostname,
    173       uint16 qtype,
    174       const DnsTransactionFactory::CallbackType& callback,
    175       const BoundNetLog&) OVERRIDE {
    176     MockTransaction* transaction =
    177         new MockTransaction(rules_, hostname, qtype, callback);
    178     if (transaction->delayed())
    179       delayed_transactions_.push_back(transaction->AsWeakPtr());
    180     return scoped_ptr<DnsTransaction>(transaction);
    181   }
    182 
    183   void CompleteDelayedTransactions() {
    184     DelayedTransactionList old_delayed_transactions;
    185     old_delayed_transactions.swap(delayed_transactions_);
    186     for (DelayedTransactionList::iterator it = old_delayed_transactions.begin();
    187          it != old_delayed_transactions.end(); ++it) {
    188       if (it->get())
    189         (*it)->FinishDelayedTransaction();
    190     }
    191   }
    192 
    193  private:
    194   typedef std::vector<base::WeakPtr<MockTransaction> > DelayedTransactionList;
    195 
    196   MockDnsClientRuleList rules_;
    197   DelayedTransactionList delayed_transactions_;
    198 };
    199 
    200 MockDnsClient::MockDnsClient(const DnsConfig& config,
    201                              const MockDnsClientRuleList& rules)
    202       : config_(config),
    203         factory_(new MockTransactionFactory(rules)),
    204         address_sorter_(new MockAddressSorter()) {
    205 }
    206 
    207 MockDnsClient::~MockDnsClient() {}
    208 
    209 void MockDnsClient::SetConfig(const DnsConfig& config) {
    210   config_ = config;
    211 }
    212 
    213 const DnsConfig* MockDnsClient::GetConfig() const {
    214   return config_.IsValid() ? &config_ : NULL;
    215 }
    216 
    217 DnsTransactionFactory* MockDnsClient::GetTransactionFactory() {
    218   return config_.IsValid() ? factory_.get() : NULL;
    219 }
    220 
    221 AddressSorter* MockDnsClient::GetAddressSorter() {
    222   return address_sorter_.get();
    223 }
    224 
    225 void MockDnsClient::CompleteDelayedTransactions() {
    226   factory_->CompleteDelayedTransactions();
    227 }
    228 
    229 }  // namespace net
    230