Home | History | Annotate | Download | only in dns
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/dns/dns_test_util.h"
      6 
      7 #include <string>
      8 
      9 #include "base/bind.h"
     10 #include "base/memory/weak_ptr.h"
     11 #include "base/message_loop/message_loop.h"
     12 #include "base/sys_byteorder.h"
     13 #include "net/base/big_endian.h"
     14 #include "net/base/dns_util.h"
     15 #include "net/base/io_buffer.h"
     16 #include "net/base/net_errors.h"
     17 #include "net/dns/address_sorter.h"
     18 #include "net/dns/dns_client.h"
     19 #include "net/dns/dns_config_service.h"
     20 #include "net/dns/dns_protocol.h"
     21 #include "net/dns/dns_query.h"
     22 #include "net/dns/dns_response.h"
     23 #include "net/dns/dns_transaction.h"
     24 #include "testing/gtest/include/gtest/gtest.h"
     25 
     26 namespace net {
     27 namespace {
     28 
     29 // A DnsTransaction which uses MockDnsClientRuleList to determine the response.
     30 class MockTransaction : public DnsTransaction,
     31                         public base::SupportsWeakPtr<MockTransaction> {
     32  public:
     33   MockTransaction(const MockDnsClientRuleList& rules,
     34                   const std::string& hostname,
     35                   uint16 qtype,
     36                   const DnsTransactionFactory::CallbackType& callback)
     37       : result_(MockDnsClientRule::FAIL),
     38         hostname_(hostname),
     39         qtype_(qtype),
     40         callback_(callback),
     41         started_(false) {
     42     // Find the relevant rule which matches |qtype| and prefix of |hostname|.
     43     for (size_t i = 0; i < rules.size(); ++i) {
     44       const std::string& prefix = rules[i].prefix;
     45       if ((rules[i].qtype == qtype) &&
     46           (hostname.size() >= prefix.size()) &&
     47           (hostname.compare(0, prefix.size(), prefix) == 0)) {
     48         result_ = rules[i].result;
     49         break;
     50       }
     51     }
     52   }
     53 
     54   virtual const std::string& GetHostname() const OVERRIDE {
     55     return hostname_;
     56   }
     57 
     58   virtual uint16 GetType() const OVERRIDE {
     59     return qtype_;
     60   }
     61 
     62   virtual void Start() OVERRIDE {
     63     EXPECT_FALSE(started_);
     64     started_ = true;
     65     // Using WeakPtr to cleanly cancel when transaction is destroyed.
     66     base::MessageLoop::current()->PostTask(
     67         FROM_HERE, base::Bind(&MockTransaction::Finish, AsWeakPtr()));
     68   }
     69 
     70  private:
     71   void Finish() {
     72     switch (result_) {
     73       case MockDnsClientRule::EMPTY:
     74       case MockDnsClientRule::OK: {
     75         std::string qname;
     76         DNSDomainFromDot(hostname_, &qname);
     77         DnsQuery query(0, qname, qtype_);
     78 
     79         DnsResponse response;
     80         char* buffer = response.io_buffer()->data();
     81         int nbytes = query.io_buffer()->size();
     82         memcpy(buffer, query.io_buffer()->data(), nbytes);
     83         dns_protocol::Header* header =
     84             reinterpret_cast<dns_protocol::Header*>(buffer);
     85         header->flags |= dns_protocol::kFlagResponse;
     86 
     87         if (MockDnsClientRule::OK == result_) {
     88           const uint16 kPointerToQueryName =
     89               static_cast<uint16>(0xc000 | sizeof(*header));
     90 
     91           const uint32 kTTL = 86400;  // One day.
     92 
     93           // Size of RDATA which is a IPv4 or IPv6 address.
     94           size_t rdata_size = qtype_ == net::dns_protocol::kTypeA ?
     95                               net::kIPv4AddressSize : net::kIPv6AddressSize;
     96 
     97           // 12 is the sum of sizes of the compressed name reference, TYPE,
     98           // CLASS, TTL and RDLENGTH.
     99           size_t answer_size = 12 + rdata_size;
    100 
    101           // Write answer with loopback IP address.
    102           header->ancount = base::HostToNet16(1);
    103           BigEndianWriter writer(buffer + nbytes, answer_size);
    104           writer.WriteU16(kPointerToQueryName);
    105           writer.WriteU16(qtype_);
    106           writer.WriteU16(net::dns_protocol::kClassIN);
    107           writer.WriteU32(kTTL);
    108           writer.WriteU16(rdata_size);
    109           if (qtype_ == net::dns_protocol::kTypeA) {
    110             char kIPv4Loopback[] = { 0x7f, 0, 0, 1 };
    111             writer.WriteBytes(kIPv4Loopback, sizeof(kIPv4Loopback));
    112           } else {
    113             char kIPv6Loopback[] = { 0, 0, 0, 0, 0, 0, 0, 0,
    114                                      0, 0, 0, 0, 0, 0, 0, 1 };
    115             writer.WriteBytes(kIPv6Loopback, sizeof(kIPv6Loopback));
    116           }
    117           nbytes += answer_size;
    118         }
    119         EXPECT_TRUE(response.InitParse(nbytes, query));
    120         callback_.Run(this, OK, &response);
    121       } break;
    122       case MockDnsClientRule::FAIL:
    123         callback_.Run(this, ERR_NAME_NOT_RESOLVED, NULL);
    124         break;
    125       case MockDnsClientRule::TIMEOUT:
    126         callback_.Run(this, ERR_DNS_TIMED_OUT, NULL);
    127         break;
    128       default:
    129         NOTREACHED();
    130         break;
    131     }
    132   }
    133 
    134   MockDnsClientRule::Result result_;
    135   const std::string hostname_;
    136   const uint16 qtype_;
    137   DnsTransactionFactory::CallbackType callback_;
    138   bool started_;
    139 };
    140 
    141 
    142 // A DnsTransactionFactory which creates MockTransaction.
    143 class MockTransactionFactory : public DnsTransactionFactory {
    144  public:
    145   explicit MockTransactionFactory(const MockDnsClientRuleList& rules)
    146       : rules_(rules) {}
    147   virtual ~MockTransactionFactory() {}
    148 
    149   virtual scoped_ptr<DnsTransaction> CreateTransaction(
    150       const std::string& hostname,
    151       uint16 qtype,
    152       const DnsTransactionFactory::CallbackType& callback,
    153       const BoundNetLog&) OVERRIDE {
    154     return scoped_ptr<DnsTransaction>(
    155         new MockTransaction(rules_, hostname, qtype, callback));
    156   }
    157 
    158  private:
    159   MockDnsClientRuleList rules_;
    160 };
    161 
    162 class MockAddressSorter : public AddressSorter {
    163  public:
    164   virtual ~MockAddressSorter() {}
    165   virtual void Sort(const AddressList& list,
    166                     const CallbackType& callback) const OVERRIDE {
    167     // Do nothing.
    168     callback.Run(true, list);
    169   }
    170 };
    171 
    172 // MockDnsClient provides MockTransactionFactory.
    173 class MockDnsClient : public DnsClient {
    174  public:
    175   MockDnsClient(const DnsConfig& config,
    176                 const MockDnsClientRuleList& rules)
    177       : config_(config), factory_(rules) {}
    178   virtual ~MockDnsClient() {}
    179 
    180   virtual void SetConfig(const DnsConfig& config) OVERRIDE {
    181     config_ = config;
    182   }
    183 
    184   virtual const DnsConfig* GetConfig() const OVERRIDE {
    185     return config_.IsValid() ? &config_ : NULL;
    186   }
    187 
    188   virtual DnsTransactionFactory* GetTransactionFactory() OVERRIDE {
    189     return config_.IsValid() ? &factory_ : NULL;
    190   }
    191 
    192   virtual AddressSorter* GetAddressSorter() OVERRIDE {
    193     return &address_sorter_;
    194   }
    195 
    196  private:
    197   DnsConfig config_;
    198   MockTransactionFactory factory_;
    199   MockAddressSorter address_sorter_;
    200 };
    201 
    202 }  // namespace
    203 
    204 // static
    205 scoped_ptr<DnsClient> CreateMockDnsClient(const DnsConfig& config,
    206                                           const MockDnsClientRuleList& rules) {
    207   return scoped_ptr<DnsClient>(new MockDnsClient(config, rules));
    208 }
    209 
    210 }  // namespace net
    211