Home | History | Annotate | Download | only in dns
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/dns/mock_host_resolver.h"
      6 
      7 #include <string>
      8 #include <vector>
      9 
     10 #include "base/bind.h"
     11 #include "base/memory/ref_counted.h"
     12 #include "base/message_loop/message_loop.h"
     13 #include "base/stl_util.h"
     14 #include "base/strings/string_split.h"
     15 #include "base/strings/string_util.h"
     16 #include "base/threading/platform_thread.h"
     17 #include "net/base/net_errors.h"
     18 #include "net/base/net_util.h"
     19 #include "net/base/test_completion_callback.h"
     20 #include "net/dns/host_cache.h"
     21 
     22 #if defined(OS_WIN)
     23 #include "net/base/winsock_init.h"
     24 #endif
     25 
     26 namespace net {
     27 
     28 namespace {
     29 
     30 // Cache size for the MockCachingHostResolver.
     31 const unsigned kMaxCacheEntries = 100;
     32 // TTL for the successful resolutions. Failures are not cached.
     33 const unsigned kCacheEntryTTLSeconds = 60;
     34 
     35 }  // namespace
     36 
     37 int ParseAddressList(const std::string& host_list,
     38                      const std::string& canonical_name,
     39                      AddressList* addrlist) {
     40   *addrlist = AddressList();
     41   std::vector<std::string> addresses;
     42   base::SplitString(host_list, ',', &addresses);
     43   addrlist->set_canonical_name(canonical_name);
     44   for (size_t index = 0; index < addresses.size(); ++index) {
     45     IPAddressNumber ip_number;
     46     if (!ParseIPLiteralToNumber(addresses[index], &ip_number)) {
     47       LOG(WARNING) << "Not a supported IP literal: " << addresses[index];
     48       return ERR_UNEXPECTED;
     49     }
     50     addrlist->push_back(IPEndPoint(ip_number, -1));
     51   }
     52   return OK;
     53 }
     54 
     55 struct MockHostResolverBase::Request {
     56   Request(const RequestInfo& req_info,
     57           AddressList* addr,
     58           const CompletionCallback& cb)
     59       : info(req_info), addresses(addr), callback(cb) {}
     60   RequestInfo info;
     61   AddressList* addresses;
     62   CompletionCallback callback;
     63 };
     64 
     65 MockHostResolverBase::~MockHostResolverBase() {
     66   STLDeleteValues(&requests_);
     67 }
     68 
     69 int MockHostResolverBase::Resolve(const RequestInfo& info,
     70                                   RequestPriority priority,
     71                                   AddressList* addresses,
     72                                   const CompletionCallback& callback,
     73                                   RequestHandle* handle,
     74                                   const BoundNetLog& net_log) {
     75   DCHECK(CalledOnValidThread());
     76   last_request_priority_ = priority;
     77   num_resolve_++;
     78   size_t id = next_request_id_++;
     79   int rv = ResolveFromIPLiteralOrCache(info, addresses);
     80   if (rv != ERR_DNS_CACHE_MISS) {
     81     return rv;
     82   }
     83   if (synchronous_mode_) {
     84     return ResolveProc(id, info, addresses);
     85   }
     86   // Store the request for asynchronous resolution
     87   Request* req = new Request(info, addresses, callback);
     88   requests_[id] = req;
     89   if (handle)
     90     *handle = reinterpret_cast<RequestHandle>(id);
     91 
     92   if (!ondemand_mode_) {
     93     base::MessageLoop::current()->PostTask(
     94         FROM_HERE,
     95         base::Bind(&MockHostResolverBase::ResolveNow, AsWeakPtr(), id));
     96   }
     97 
     98   return ERR_IO_PENDING;
     99 }
    100 
    101 int MockHostResolverBase::ResolveFromCache(const RequestInfo& info,
    102                                            AddressList* addresses,
    103                                            const BoundNetLog& net_log) {
    104   num_resolve_from_cache_++;
    105   DCHECK(CalledOnValidThread());
    106   next_request_id_++;
    107   int rv = ResolveFromIPLiteralOrCache(info, addresses);
    108   return rv;
    109 }
    110 
    111 void MockHostResolverBase::CancelRequest(RequestHandle handle) {
    112   DCHECK(CalledOnValidThread());
    113   size_t id = reinterpret_cast<size_t>(handle);
    114   RequestMap::iterator it = requests_.find(id);
    115   if (it != requests_.end()) {
    116     scoped_ptr<Request> req(it->second);
    117     requests_.erase(it);
    118   } else {
    119     NOTREACHED() << "CancelRequest must NOT be called after request is "
    120         "complete or canceled.";
    121   }
    122 }
    123 
    124 HostCache* MockHostResolverBase::GetHostCache() {
    125   return cache_.get();
    126 }
    127 
    128 void MockHostResolverBase::ResolveAllPending() {
    129   DCHECK(CalledOnValidThread());
    130   DCHECK(ondemand_mode_);
    131   for (RequestMap::iterator i = requests_.begin(); i != requests_.end(); ++i) {
    132     base::MessageLoop::current()->PostTask(
    133         FROM_HERE,
    134         base::Bind(&MockHostResolverBase::ResolveNow, AsWeakPtr(), i->first));
    135   }
    136 }
    137 
    138 // start id from 1 to distinguish from NULL RequestHandle
    139 MockHostResolverBase::MockHostResolverBase(bool use_caching)
    140     : last_request_priority_(DEFAULT_PRIORITY),
    141       synchronous_mode_(false),
    142       ondemand_mode_(false),
    143       next_request_id_(1),
    144       num_resolve_(0),
    145       num_resolve_from_cache_(0) {
    146   rules_ = CreateCatchAllHostResolverProc();
    147 
    148   if (use_caching) {
    149     cache_.reset(new HostCache(kMaxCacheEntries));
    150   }
    151 }
    152 
    153 int MockHostResolverBase::ResolveFromIPLiteralOrCache(const RequestInfo& info,
    154                                                       AddressList* addresses) {
    155   IPAddressNumber ip;
    156   if (ParseIPLiteralToNumber(info.hostname(), &ip)) {
    157     *addresses = AddressList::CreateFromIPAddress(ip, info.port());
    158     if (info.host_resolver_flags() & HOST_RESOLVER_CANONNAME)
    159       addresses->SetDefaultCanonicalName();
    160     return OK;
    161   }
    162   int rv = ERR_DNS_CACHE_MISS;
    163   if (cache_.get() && info.allow_cached_response()) {
    164     HostCache::Key key(info.hostname(),
    165                        info.address_family(),
    166                        info.host_resolver_flags());
    167     const HostCache::Entry* entry = cache_->Lookup(key, base::TimeTicks::Now());
    168     if (entry) {
    169       rv = entry->error;
    170       if (rv == OK)
    171         *addresses = AddressList::CopyWithPort(entry->addrlist, info.port());
    172     }
    173   }
    174   return rv;
    175 }
    176 
    177 int MockHostResolverBase::ResolveProc(size_t id,
    178                                       const RequestInfo& info,
    179                                       AddressList* addresses) {
    180   AddressList addr;
    181   int rv = rules_->Resolve(info.hostname(),
    182                            info.address_family(),
    183                            info.host_resolver_flags(),
    184                            &addr,
    185                            NULL);
    186   if (cache_.get()) {
    187     HostCache::Key key(info.hostname(),
    188                        info.address_family(),
    189                        info.host_resolver_flags());
    190     // Storing a failure with TTL 0 so that it overwrites previous value.
    191     base::TimeDelta ttl;
    192     if (rv == OK)
    193       ttl = base::TimeDelta::FromSeconds(kCacheEntryTTLSeconds);
    194     cache_->Set(key, HostCache::Entry(rv, addr), base::TimeTicks::Now(), ttl);
    195   }
    196   if (rv == OK)
    197     *addresses = AddressList::CopyWithPort(addr, info.port());
    198   return rv;
    199 }
    200 
    201 void MockHostResolverBase::ResolveNow(size_t id) {
    202   RequestMap::iterator it = requests_.find(id);
    203   if (it == requests_.end())
    204     return;  // was canceled
    205 
    206   scoped_ptr<Request> req(it->second);
    207   requests_.erase(it);
    208   int rv = ResolveProc(id, req->info, req->addresses);
    209   if (!req->callback.is_null())
    210     req->callback.Run(rv);
    211 }
    212 
    213 //-----------------------------------------------------------------------------
    214 
    215 struct RuleBasedHostResolverProc::Rule {
    216   enum ResolverType {
    217     kResolverTypeFail,
    218     kResolverTypeSystem,
    219     kResolverTypeIPLiteral,
    220   };
    221 
    222   ResolverType resolver_type;
    223   std::string host_pattern;
    224   AddressFamily address_family;
    225   HostResolverFlags host_resolver_flags;
    226   std::string replacement;
    227   std::string canonical_name;
    228   int latency_ms;  // In milliseconds.
    229 
    230   Rule(ResolverType resolver_type,
    231        const std::string& host_pattern,
    232        AddressFamily address_family,
    233        HostResolverFlags host_resolver_flags,
    234        const std::string& replacement,
    235        const std::string& canonical_name,
    236        int latency_ms)
    237       : resolver_type(resolver_type),
    238         host_pattern(host_pattern),
    239         address_family(address_family),
    240         host_resolver_flags(host_resolver_flags),
    241         replacement(replacement),
    242         canonical_name(canonical_name),
    243         latency_ms(latency_ms) {}
    244 };
    245 
    246 RuleBasedHostResolverProc::RuleBasedHostResolverProc(HostResolverProc* previous)
    247     : HostResolverProc(previous) {
    248 }
    249 
    250 void RuleBasedHostResolverProc::AddRule(const std::string& host_pattern,
    251                                         const std::string& replacement) {
    252   AddRuleForAddressFamily(host_pattern, ADDRESS_FAMILY_UNSPECIFIED,
    253                           replacement);
    254 }
    255 
    256 void RuleBasedHostResolverProc::AddRuleForAddressFamily(
    257     const std::string& host_pattern,
    258     AddressFamily address_family,
    259     const std::string& replacement) {
    260   DCHECK(!replacement.empty());
    261   HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
    262       HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
    263   Rule rule(Rule::kResolverTypeSystem,
    264             host_pattern,
    265             address_family,
    266             flags,
    267             replacement,
    268             std::string(),
    269             0);
    270   rules_.push_back(rule);
    271 }
    272 
    273 void RuleBasedHostResolverProc::AddIPLiteralRule(
    274     const std::string& host_pattern,
    275     const std::string& ip_literal,
    276     const std::string& canonical_name) {
    277   // Literals are always resolved to themselves by HostResolverImpl,
    278   // consequently we do not support remapping them.
    279   IPAddressNumber ip_number;
    280   DCHECK(!ParseIPLiteralToNumber(host_pattern, &ip_number));
    281   HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
    282       HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
    283   if (!canonical_name.empty())
    284     flags |= HOST_RESOLVER_CANONNAME;
    285   Rule rule(Rule::kResolverTypeIPLiteral, host_pattern,
    286             ADDRESS_FAMILY_UNSPECIFIED, flags, ip_literal, canonical_name,
    287             0);
    288   rules_.push_back(rule);
    289 }
    290 
    291 void RuleBasedHostResolverProc::AddRuleWithLatency(
    292     const std::string& host_pattern,
    293     const std::string& replacement,
    294     int latency_ms) {
    295   DCHECK(!replacement.empty());
    296   HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
    297       HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
    298   Rule rule(Rule::kResolverTypeSystem,
    299             host_pattern,
    300             ADDRESS_FAMILY_UNSPECIFIED,
    301             flags,
    302             replacement,
    303             std::string(),
    304             latency_ms);
    305   rules_.push_back(rule);
    306 }
    307 
    308 void RuleBasedHostResolverProc::AllowDirectLookup(
    309     const std::string& host_pattern) {
    310   HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
    311       HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
    312   Rule rule(Rule::kResolverTypeSystem,
    313             host_pattern,
    314             ADDRESS_FAMILY_UNSPECIFIED,
    315             flags,
    316             std::string(),
    317             std::string(),
    318             0);
    319   rules_.push_back(rule);
    320 }
    321 
    322 void RuleBasedHostResolverProc::AddSimulatedFailure(
    323     const std::string& host_pattern) {
    324   HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
    325       HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
    326   Rule rule(Rule::kResolverTypeFail,
    327             host_pattern,
    328             ADDRESS_FAMILY_UNSPECIFIED,
    329             flags,
    330             std::string(),
    331             std::string(),
    332             0);
    333   rules_.push_back(rule);
    334 }
    335 
    336 void RuleBasedHostResolverProc::ClearRules() {
    337   rules_.clear();
    338 }
    339 
    340 int RuleBasedHostResolverProc::Resolve(const std::string& host,
    341                                        AddressFamily address_family,
    342                                        HostResolverFlags host_resolver_flags,
    343                                        AddressList* addrlist,
    344                                        int* os_error) {
    345   RuleList::iterator r;
    346   for (r = rules_.begin(); r != rules_.end(); ++r) {
    347     bool matches_address_family =
    348         r->address_family == ADDRESS_FAMILY_UNSPECIFIED ||
    349         r->address_family == address_family;
    350     // Ignore HOST_RESOLVER_SYSTEM_ONLY, since it should have no impact on
    351     // whether a rule matches.
    352     HostResolverFlags flags = host_resolver_flags & ~HOST_RESOLVER_SYSTEM_ONLY;
    353     // Flags match if all of the bitflags in host_resolver_flags are enabled
    354     // in the rule's host_resolver_flags. However, the rule may have additional
    355     // flags specified, in which case the flags should still be considered a
    356     // match.
    357     bool matches_flags = (r->host_resolver_flags & flags) == flags;
    358     if (matches_flags && matches_address_family &&
    359         MatchPattern(host, r->host_pattern)) {
    360       if (r->latency_ms != 0) {
    361         base::PlatformThread::Sleep(
    362             base::TimeDelta::FromMilliseconds(r->latency_ms));
    363       }
    364 
    365       // Remap to a new host.
    366       const std::string& effective_host =
    367           r->replacement.empty() ? host : r->replacement;
    368 
    369       // Apply the resolving function to the remapped hostname.
    370       switch (r->resolver_type) {
    371         case Rule::kResolverTypeFail:
    372           return ERR_NAME_NOT_RESOLVED;
    373         case Rule::kResolverTypeSystem:
    374 #if defined(OS_WIN)
    375           net::EnsureWinsockInit();
    376 #endif
    377           return SystemHostResolverCall(effective_host,
    378                                         address_family,
    379                                         host_resolver_flags,
    380                                         addrlist, os_error);
    381         case Rule::kResolverTypeIPLiteral:
    382           return ParseAddressList(effective_host,
    383                                   r->canonical_name,
    384                                   addrlist);
    385         default:
    386           NOTREACHED();
    387           return ERR_UNEXPECTED;
    388       }
    389     }
    390   }
    391   return ResolveUsingPrevious(host, address_family,
    392                               host_resolver_flags, addrlist, os_error);
    393 }
    394 
    395 RuleBasedHostResolverProc::~RuleBasedHostResolverProc() {
    396 }
    397 
    398 RuleBasedHostResolverProc* CreateCatchAllHostResolverProc() {
    399   RuleBasedHostResolverProc* catchall = new RuleBasedHostResolverProc(NULL);
    400   catchall->AddIPLiteralRule("*", "127.0.0.1", "localhost");
    401 
    402   // Next add a rules-based layer the use controls.
    403   return new RuleBasedHostResolverProc(catchall);
    404 }
    405 
    406 //-----------------------------------------------------------------------------
    407 
    408 int HangingHostResolver::Resolve(const RequestInfo& info,
    409                                  RequestPriority priority,
    410                                  AddressList* addresses,
    411                                  const CompletionCallback& callback,
    412                                  RequestHandle* out_req,
    413                                  const BoundNetLog& net_log) {
    414   return ERR_IO_PENDING;
    415 }
    416 
    417 int HangingHostResolver::ResolveFromCache(const RequestInfo& info,
    418                                           AddressList* addresses,
    419                                           const BoundNetLog& net_log) {
    420   return ERR_DNS_CACHE_MISS;
    421 }
    422 
    423 //-----------------------------------------------------------------------------
    424 
    425 ScopedDefaultHostResolverProc::ScopedDefaultHostResolverProc() {}
    426 
    427 ScopedDefaultHostResolverProc::ScopedDefaultHostResolverProc(
    428     HostResolverProc* proc) {
    429   Init(proc);
    430 }
    431 
    432 ScopedDefaultHostResolverProc::~ScopedDefaultHostResolverProc() {
    433   HostResolverProc* old_proc =
    434       HostResolverProc::SetDefault(previous_proc_.get());
    435   // The lifetimes of multiple instances must be nested.
    436   CHECK_EQ(old_proc, current_proc_);
    437 }
    438 
    439 void ScopedDefaultHostResolverProc::Init(HostResolverProc* proc) {
    440   current_proc_ = proc;
    441   previous_proc_ = HostResolverProc::SetDefault(current_proc_.get());
    442   current_proc_->SetLastProc(previous_proc_.get());
    443 }
    444 
    445 }  // namespace net
    446