Home | History | Annotate | Download | only in dns
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/dns/mock_host_resolver.h"
      6 
      7 #include <string>
      8 #include <vector>
      9 
     10 #include "base/bind.h"
     11 #include "base/memory/ref_counted.h"
     12 #include "base/message_loop/message_loop.h"
     13 #include "base/stl_util.h"
     14 #include "base/strings/string_split.h"
     15 #include "base/strings/string_util.h"
     16 #include "base/threading/platform_thread.h"
     17 #include "net/base/net_errors.h"
     18 #include "net/base/net_util.h"
     19 #include "net/base/test_completion_callback.h"
     20 #include "net/dns/host_cache.h"
     21 
     22 #if defined(OS_WIN)
     23 #include "net/base/winsock_init.h"
     24 #endif
     25 
     26 namespace net {
     27 
     28 namespace {
     29 
     30 // Cache size for the MockCachingHostResolver.
     31 const unsigned kMaxCacheEntries = 100;
     32 // TTL for the successful resolutions. Failures are not cached.
     33 const unsigned kCacheEntryTTLSeconds = 60;
     34 
     35 }  // namespace
     36 
     37 int ParseAddressList(const std::string& host_list,
     38                      const std::string& canonical_name,
     39                      AddressList* addrlist) {
     40   *addrlist = AddressList();
     41   std::vector<std::string> addresses;
     42   base::SplitString(host_list, ',', &addresses);
     43   addrlist->set_canonical_name(canonical_name);
     44   for (size_t index = 0; index < addresses.size(); ++index) {
     45     IPAddressNumber ip_number;
     46     if (!ParseIPLiteralToNumber(addresses[index], &ip_number)) {
     47       LOG(WARNING) << "Not a supported IP literal: " << addresses[index];
     48       return ERR_UNEXPECTED;
     49     }
     50     addrlist->push_back(IPEndPoint(ip_number, -1));
     51   }
     52   return OK;
     53 }
     54 
     55 struct MockHostResolverBase::Request {
     56   Request(const RequestInfo& req_info,
     57           AddressList* addr,
     58           const CompletionCallback& cb)
     59       : info(req_info), addresses(addr), callback(cb) {}
     60   RequestInfo info;
     61   AddressList* addresses;
     62   CompletionCallback callback;
     63 };
     64 
     65 MockHostResolverBase::~MockHostResolverBase() {
     66   STLDeleteValues(&requests_);
     67 }
     68 
     69 int MockHostResolverBase::Resolve(const RequestInfo& info,
     70                                   AddressList* addresses,
     71                                   const CompletionCallback& callback,
     72                                   RequestHandle* handle,
     73                                   const BoundNetLog& net_log) {
     74   DCHECK(CalledOnValidThread());
     75   num_resolve_++;
     76   size_t id = next_request_id_++;
     77   int rv = ResolveFromIPLiteralOrCache(info, addresses);
     78   if (rv != ERR_DNS_CACHE_MISS) {
     79     return rv;
     80   }
     81   if (synchronous_mode_) {
     82     return ResolveProc(id, info, addresses);
     83   }
     84   // Store the request for asynchronous resolution
     85   Request* req = new Request(info, addresses, callback);
     86   requests_[id] = req;
     87   if (handle)
     88     *handle = reinterpret_cast<RequestHandle>(id);
     89 
     90   if (!ondemand_mode_) {
     91     base::MessageLoop::current()->PostTask(
     92         FROM_HERE,
     93         base::Bind(&MockHostResolverBase::ResolveNow, AsWeakPtr(), id));
     94   }
     95 
     96   return ERR_IO_PENDING;
     97 }
     98 
     99 int MockHostResolverBase::ResolveFromCache(const RequestInfo& info,
    100                                            AddressList* addresses,
    101                                            const BoundNetLog& net_log) {
    102   num_resolve_from_cache_++;
    103   DCHECK(CalledOnValidThread());
    104   next_request_id_++;
    105   int rv = ResolveFromIPLiteralOrCache(info, addresses);
    106   return rv;
    107 }
    108 
    109 void MockHostResolverBase::CancelRequest(RequestHandle handle) {
    110   DCHECK(CalledOnValidThread());
    111   size_t id = reinterpret_cast<size_t>(handle);
    112   RequestMap::iterator it = requests_.find(id);
    113   if (it != requests_.end()) {
    114     scoped_ptr<Request> req(it->second);
    115     requests_.erase(it);
    116   } else {
    117     NOTREACHED() << "CancelRequest must NOT be called after request is "
    118         "complete or canceled.";
    119   }
    120 }
    121 
    122 HostCache* MockHostResolverBase::GetHostCache() {
    123   return cache_.get();
    124 }
    125 
    126 void MockHostResolverBase::ResolveAllPending() {
    127   DCHECK(CalledOnValidThread());
    128   DCHECK(ondemand_mode_);
    129   for (RequestMap::iterator i = requests_.begin(); i != requests_.end(); ++i) {
    130     base::MessageLoop::current()->PostTask(
    131         FROM_HERE,
    132         base::Bind(&MockHostResolverBase::ResolveNow, AsWeakPtr(), i->first));
    133   }
    134 }
    135 
    136 // start id from 1 to distinguish from NULL RequestHandle
    137 MockHostResolverBase::MockHostResolverBase(bool use_caching)
    138     : synchronous_mode_(false),
    139       ondemand_mode_(false),
    140       next_request_id_(1),
    141       num_resolve_(0),
    142       num_resolve_from_cache_(0) {
    143   rules_ = CreateCatchAllHostResolverProc();
    144 
    145   if (use_caching) {
    146     cache_.reset(new HostCache(kMaxCacheEntries));
    147   }
    148 }
    149 
    150 int MockHostResolverBase::ResolveFromIPLiteralOrCache(const RequestInfo& info,
    151                                                       AddressList* addresses) {
    152   IPAddressNumber ip;
    153   if (ParseIPLiteralToNumber(info.hostname(), &ip)) {
    154     *addresses = AddressList::CreateFromIPAddress(ip, info.port());
    155     if (info.host_resolver_flags() & HOST_RESOLVER_CANONNAME)
    156       addresses->SetDefaultCanonicalName();
    157     return OK;
    158   }
    159   int rv = ERR_DNS_CACHE_MISS;
    160   if (cache_.get() && info.allow_cached_response()) {
    161     HostCache::Key key(info.hostname(),
    162                        info.address_family(),
    163                        info.host_resolver_flags());
    164     const HostCache::Entry* entry = cache_->Lookup(key, base::TimeTicks::Now());
    165     if (entry) {
    166       rv = entry->error;
    167       if (rv == OK)
    168         *addresses = AddressList::CopyWithPort(entry->addrlist, info.port());
    169     }
    170   }
    171   return rv;
    172 }
    173 
    174 int MockHostResolverBase::ResolveProc(size_t id,
    175                                       const RequestInfo& info,
    176                                       AddressList* addresses) {
    177   AddressList addr;
    178   int rv = rules_->Resolve(info.hostname(),
    179                            info.address_family(),
    180                            info.host_resolver_flags(),
    181                            &addr,
    182                            NULL);
    183   if (cache_.get()) {
    184     HostCache::Key key(info.hostname(),
    185                        info.address_family(),
    186                        info.host_resolver_flags());
    187     // Storing a failure with TTL 0 so that it overwrites previous value.
    188     base::TimeDelta ttl;
    189     if (rv == OK)
    190       ttl = base::TimeDelta::FromSeconds(kCacheEntryTTLSeconds);
    191     cache_->Set(key, HostCache::Entry(rv, addr), base::TimeTicks::Now(), ttl);
    192   }
    193   if (rv == OK)
    194     *addresses = AddressList::CopyWithPort(addr, info.port());
    195   return rv;
    196 }
    197 
    198 void MockHostResolverBase::ResolveNow(size_t id) {
    199   RequestMap::iterator it = requests_.find(id);
    200   if (it == requests_.end())
    201     return;  // was canceled
    202 
    203   scoped_ptr<Request> req(it->second);
    204   requests_.erase(it);
    205   int rv = ResolveProc(id, req->info, req->addresses);
    206   if (!req->callback.is_null())
    207     req->callback.Run(rv);
    208 }
    209 
    210 //-----------------------------------------------------------------------------
    211 
    212 struct RuleBasedHostResolverProc::Rule {
    213   enum ResolverType {
    214     kResolverTypeFail,
    215     kResolverTypeSystem,
    216     kResolverTypeIPLiteral,
    217   };
    218 
    219   ResolverType resolver_type;
    220   std::string host_pattern;
    221   AddressFamily address_family;
    222   HostResolverFlags host_resolver_flags;
    223   std::string replacement;
    224   std::string canonical_name;
    225   int latency_ms;  // In milliseconds.
    226 
    227   Rule(ResolverType resolver_type,
    228        const std::string& host_pattern,
    229        AddressFamily address_family,
    230        HostResolverFlags host_resolver_flags,
    231        const std::string& replacement,
    232        const std::string& canonical_name,
    233        int latency_ms)
    234       : resolver_type(resolver_type),
    235         host_pattern(host_pattern),
    236         address_family(address_family),
    237         host_resolver_flags(host_resolver_flags),
    238         replacement(replacement),
    239         canonical_name(canonical_name),
    240         latency_ms(latency_ms) {}
    241 };
    242 
    243 RuleBasedHostResolverProc::RuleBasedHostResolverProc(HostResolverProc* previous)
    244     : HostResolverProc(previous) {
    245 }
    246 
    247 void RuleBasedHostResolverProc::AddRule(const std::string& host_pattern,
    248                                         const std::string& replacement) {
    249   AddRuleForAddressFamily(host_pattern, ADDRESS_FAMILY_UNSPECIFIED,
    250                           replacement);
    251 }
    252 
    253 void RuleBasedHostResolverProc::AddRuleForAddressFamily(
    254     const std::string& host_pattern,
    255     AddressFamily address_family,
    256     const std::string& replacement) {
    257   DCHECK(!replacement.empty());
    258   HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
    259       HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
    260   Rule rule(Rule::kResolverTypeSystem,
    261             host_pattern,
    262             address_family,
    263             flags,
    264             replacement,
    265             std::string(),
    266             0);
    267   rules_.push_back(rule);
    268 }
    269 
    270 void RuleBasedHostResolverProc::AddIPLiteralRule(
    271     const std::string& host_pattern,
    272     const std::string& ip_literal,
    273     const std::string& canonical_name) {
    274   // Literals are always resolved to themselves by HostResolverImpl,
    275   // consequently we do not support remapping them.
    276   IPAddressNumber ip_number;
    277   DCHECK(!ParseIPLiteralToNumber(host_pattern, &ip_number));
    278   HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
    279       HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
    280   if (!canonical_name.empty())
    281     flags |= HOST_RESOLVER_CANONNAME;
    282   Rule rule(Rule::kResolverTypeIPLiteral, host_pattern,
    283             ADDRESS_FAMILY_UNSPECIFIED, flags, ip_literal, canonical_name,
    284             0);
    285   rules_.push_back(rule);
    286 }
    287 
    288 void RuleBasedHostResolverProc::AddRuleWithLatency(
    289     const std::string& host_pattern,
    290     const std::string& replacement,
    291     int latency_ms) {
    292   DCHECK(!replacement.empty());
    293   HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
    294       HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
    295   Rule rule(Rule::kResolverTypeSystem,
    296             host_pattern,
    297             ADDRESS_FAMILY_UNSPECIFIED,
    298             flags,
    299             replacement,
    300             std::string(),
    301             latency_ms);
    302   rules_.push_back(rule);
    303 }
    304 
    305 void RuleBasedHostResolverProc::AllowDirectLookup(
    306     const std::string& host_pattern) {
    307   HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
    308       HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
    309   Rule rule(Rule::kResolverTypeSystem,
    310             host_pattern,
    311             ADDRESS_FAMILY_UNSPECIFIED,
    312             flags,
    313             std::string(),
    314             std::string(),
    315             0);
    316   rules_.push_back(rule);
    317 }
    318 
    319 void RuleBasedHostResolverProc::AddSimulatedFailure(
    320     const std::string& host_pattern) {
    321   HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
    322       HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
    323   Rule rule(Rule::kResolverTypeFail,
    324             host_pattern,
    325             ADDRESS_FAMILY_UNSPECIFIED,
    326             flags,
    327             std::string(),
    328             std::string(),
    329             0);
    330   rules_.push_back(rule);
    331 }
    332 
    333 void RuleBasedHostResolverProc::ClearRules() {
    334   rules_.clear();
    335 }
    336 
    337 int RuleBasedHostResolverProc::Resolve(const std::string& host,
    338                                        AddressFamily address_family,
    339                                        HostResolverFlags host_resolver_flags,
    340                                        AddressList* addrlist,
    341                                        int* os_error) {
    342   RuleList::iterator r;
    343   for (r = rules_.begin(); r != rules_.end(); ++r) {
    344     bool matches_address_family =
    345         r->address_family == ADDRESS_FAMILY_UNSPECIFIED ||
    346         r->address_family == address_family;
    347     // Flags match if all of the bitflags in host_resolver_flags are enabled
    348     // in the rule's host_resolver_flags. However, the rule may have additional
    349     // flags specified, in which case the flags should still be considered a
    350     // match.
    351     bool matches_flags = (r->host_resolver_flags & host_resolver_flags) ==
    352         host_resolver_flags;
    353     if (matches_flags && matches_address_family &&
    354         MatchPattern(host, r->host_pattern)) {
    355       if (r->latency_ms != 0) {
    356         base::PlatformThread::Sleep(
    357             base::TimeDelta::FromMilliseconds(r->latency_ms));
    358       }
    359 
    360       // Remap to a new host.
    361       const std::string& effective_host =
    362           r->replacement.empty() ? host : r->replacement;
    363 
    364       // Apply the resolving function to the remapped hostname.
    365       switch (r->resolver_type) {
    366         case Rule::kResolverTypeFail:
    367           return ERR_NAME_NOT_RESOLVED;
    368         case Rule::kResolverTypeSystem:
    369 #if defined(OS_WIN)
    370           net::EnsureWinsockInit();
    371 #endif
    372           return SystemHostResolverCall(effective_host,
    373                                         address_family,
    374                                         host_resolver_flags,
    375                                         addrlist, os_error);
    376         case Rule::kResolverTypeIPLiteral:
    377           return ParseAddressList(effective_host,
    378                                   r->canonical_name,
    379                                   addrlist);
    380         default:
    381           NOTREACHED();
    382           return ERR_UNEXPECTED;
    383       }
    384     }
    385   }
    386   return ResolveUsingPrevious(host, address_family,
    387                               host_resolver_flags, addrlist, os_error);
    388 }
    389 
    390 RuleBasedHostResolverProc::~RuleBasedHostResolverProc() {
    391 }
    392 
    393 RuleBasedHostResolverProc* CreateCatchAllHostResolverProc() {
    394   RuleBasedHostResolverProc* catchall = new RuleBasedHostResolverProc(NULL);
    395   catchall->AddIPLiteralRule("*", "127.0.0.1", "localhost");
    396 
    397   // Next add a rules-based layer the use controls.
    398   return new RuleBasedHostResolverProc(catchall);
    399 }
    400 
    401 //-----------------------------------------------------------------------------
    402 
    403 int HangingHostResolver::Resolve(const RequestInfo& info,
    404                                  AddressList* addresses,
    405                                  const CompletionCallback& callback,
    406                                  RequestHandle* out_req,
    407                                  const BoundNetLog& net_log) {
    408   return ERR_IO_PENDING;
    409 }
    410 
    411 int HangingHostResolver::ResolveFromCache(const RequestInfo& info,
    412                                           AddressList* addresses,
    413                                           const BoundNetLog& net_log) {
    414   return ERR_DNS_CACHE_MISS;
    415 }
    416 
    417 //-----------------------------------------------------------------------------
    418 
    419 ScopedDefaultHostResolverProc::ScopedDefaultHostResolverProc() {}
    420 
    421 ScopedDefaultHostResolverProc::ScopedDefaultHostResolverProc(
    422     HostResolverProc* proc) {
    423   Init(proc);
    424 }
    425 
    426 ScopedDefaultHostResolverProc::~ScopedDefaultHostResolverProc() {
    427   HostResolverProc* old_proc =
    428       HostResolverProc::SetDefault(previous_proc_.get());
    429   // The lifetimes of multiple instances must be nested.
    430   CHECK_EQ(old_proc, current_proc_);
    431 }
    432 
    433 void ScopedDefaultHostResolverProc::Init(HostResolverProc* proc) {
    434   current_proc_ = proc;
    435   previous_proc_ = HostResolverProc::SetDefault(current_proc_.get());
    436   current_proc_->SetLastProc(previous_proc_.get());
    437 }
    438 
    439 }  // namespace net
    440