Home | History | Annotate | Download | only in dns
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/dns/mock_host_resolver.h"
      6 
      7 #include <string>
      8 #include <vector>
      9 
     10 #include "base/bind.h"
     11 #include "base/memory/ref_counted.h"
     12 #include "base/message_loop/message_loop.h"
     13 #include "base/stl_util.h"
     14 #include "base/strings/string_split.h"
     15 #include "base/strings/string_util.h"
     16 #include "base/threading/platform_thread.h"
     17 #include "net/base/ip_endpoint.h"
     18 #include "net/base/net_errors.h"
     19 #include "net/base/net_util.h"
     20 #include "net/base/test_completion_callback.h"
     21 #include "net/dns/host_cache.h"
     22 
     23 #if defined(OS_WIN)
     24 #include "net/base/winsock_init.h"
     25 #endif
     26 
     27 namespace net {
     28 
     namespace {

     // Capacity of the HostCache that MockHostResolverBase creates when it is
     // constructed with |use_caching| set.
     const unsigned int kMaxCacheEntries = 100;

     // Lifetime, in seconds, of successful resolutions written to that cache.
     // Failed resolutions are written with a zero TTL instead, so they only
     // overwrite an earlier entry rather than being served later.
     const unsigned int kCacheEntryTTLSeconds = 60;

     }  // namespace
     37 
     38 int ParseAddressList(const std::string& host_list,
     39                      const std::string& canonical_name,
     40                      AddressList* addrlist) {
     41   *addrlist = AddressList();
     42   std::vector<std::string> addresses;
     43   base::SplitString(host_list, ',', &addresses);
     44   addrlist->set_canonical_name(canonical_name);
     45   for (size_t index = 0; index < addresses.size(); ++index) {
     46     IPAddressNumber ip_number;
     47     if (!ParseIPLiteralToNumber(addresses[index], &ip_number)) {
     48       LOG(WARNING) << "Not a supported IP literal: " << addresses[index];
     49       return ERR_UNEXPECTED;
     50     }
     51     addrlist->push_back(IPEndPoint(ip_number, -1));
     52   }
     53   return OK;
     54 }
     55 
     56 struct MockHostResolverBase::Request {
     57   Request(const RequestInfo& req_info,
     58           AddressList* addr,
     59           const CompletionCallback& cb)
     60       : info(req_info), addresses(addr), callback(cb) {}
     61   RequestInfo info;
     62   AddressList* addresses;
     63   CompletionCallback callback;
     64 };
     65 
     66 MockHostResolverBase::~MockHostResolverBase() {
     67   STLDeleteValues(&requests_);
     68 }
     69 
     70 int MockHostResolverBase::Resolve(const RequestInfo& info,
     71                                   RequestPriority priority,
     72                                   AddressList* addresses,
     73                                   const CompletionCallback& callback,
     74                                   RequestHandle* handle,
     75                                   const BoundNetLog& net_log) {
     76   DCHECK(CalledOnValidThread());
     77   last_request_priority_ = priority;
     78   num_resolve_++;
     79   size_t id = next_request_id_++;
     80   int rv = ResolveFromIPLiteralOrCache(info, addresses);
     81   if (rv != ERR_DNS_CACHE_MISS) {
     82     return rv;
     83   }
     84   if (synchronous_mode_) {
     85     return ResolveProc(id, info, addresses);
     86   }
     87   // Store the request for asynchronous resolution
     88   Request* req = new Request(info, addresses, callback);
     89   requests_[id] = req;
     90   if (handle)
     91     *handle = reinterpret_cast<RequestHandle>(id);
     92 
     93   if (!ondemand_mode_) {
     94     base::MessageLoop::current()->PostTask(
     95         FROM_HERE,
     96         base::Bind(&MockHostResolverBase::ResolveNow, AsWeakPtr(), id));
     97   }
     98 
     99   return ERR_IO_PENDING;
    100 }
    101 
    102 int MockHostResolverBase::ResolveFromCache(const RequestInfo& info,
    103                                            AddressList* addresses,
    104                                            const BoundNetLog& net_log) {
    105   num_resolve_from_cache_++;
    106   DCHECK(CalledOnValidThread());
    107   next_request_id_++;
    108   int rv = ResolveFromIPLiteralOrCache(info, addresses);
    109   return rv;
    110 }
    111 
    112 void MockHostResolverBase::CancelRequest(RequestHandle handle) {
    113   DCHECK(CalledOnValidThread());
    114   size_t id = reinterpret_cast<size_t>(handle);
    115   RequestMap::iterator it = requests_.find(id);
    116   if (it != requests_.end()) {
    117     scoped_ptr<Request> req(it->second);
    118     requests_.erase(it);
    119   } else {
    120     NOTREACHED() << "CancelRequest must NOT be called after request is "
    121         "complete or canceled.";
    122   }
    123 }
    124 
    125 HostCache* MockHostResolverBase::GetHostCache() {
    126   return cache_.get();
    127 }
    128 
    129 void MockHostResolverBase::ResolveAllPending() {
    130   DCHECK(CalledOnValidThread());
    131   DCHECK(ondemand_mode_);
    132   for (RequestMap::iterator i = requests_.begin(); i != requests_.end(); ++i) {
    133     base::MessageLoop::current()->PostTask(
    134         FROM_HERE,
    135         base::Bind(&MockHostResolverBase::ResolveNow, AsWeakPtr(), i->first));
    136   }
    137 }
    138 
    139 // start id from 1 to distinguish from NULL RequestHandle
    140 MockHostResolverBase::MockHostResolverBase(bool use_caching)
    141     : last_request_priority_(DEFAULT_PRIORITY),
    142       synchronous_mode_(false),
    143       ondemand_mode_(false),
    144       next_request_id_(1),
    145       num_resolve_(0),
    146       num_resolve_from_cache_(0) {
    147   rules_ = CreateCatchAllHostResolverProc();
    148 
    149   if (use_caching) {
    150     cache_.reset(new HostCache(kMaxCacheEntries));
    151   }
    152 }
    153 
    154 int MockHostResolverBase::ResolveFromIPLiteralOrCache(const RequestInfo& info,
    155                                                       AddressList* addresses) {
    156   IPAddressNumber ip;
    157   if (ParseIPLiteralToNumber(info.hostname(), &ip)) {
    158     // This matches the behavior HostResolverImpl.
    159     if (info.address_family() != ADDRESS_FAMILY_UNSPECIFIED &&
    160         info.address_family() != GetAddressFamily(ip)) {
    161       return ERR_NAME_NOT_RESOLVED;
    162     }
    163 
    164     *addresses = AddressList::CreateFromIPAddress(ip, info.port());
    165     if (info.host_resolver_flags() & HOST_RESOLVER_CANONNAME)
    166       addresses->SetDefaultCanonicalName();
    167     return OK;
    168   }
    169   int rv = ERR_DNS_CACHE_MISS;
    170   if (cache_.get() && info.allow_cached_response()) {
    171     HostCache::Key key(info.hostname(),
    172                        info.address_family(),
    173                        info.host_resolver_flags());
    174     const HostCache::Entry* entry = cache_->Lookup(key, base::TimeTicks::Now());
    175     if (entry) {
    176       rv = entry->error;
    177       if (rv == OK)
    178         *addresses = AddressList::CopyWithPort(entry->addrlist, info.port());
    179     }
    180   }
    181   return rv;
    182 }
    183 
    184 int MockHostResolverBase::ResolveProc(size_t id,
    185                                       const RequestInfo& info,
    186                                       AddressList* addresses) {
    187   AddressList addr;
    188   int rv = rules_->Resolve(info.hostname(),
    189                            info.address_family(),
    190                            info.host_resolver_flags(),
    191                            &addr,
    192                            NULL);
    193   if (cache_.get()) {
    194     HostCache::Key key(info.hostname(),
    195                        info.address_family(),
    196                        info.host_resolver_flags());
    197     // Storing a failure with TTL 0 so that it overwrites previous value.
    198     base::TimeDelta ttl;
    199     if (rv == OK)
    200       ttl = base::TimeDelta::FromSeconds(kCacheEntryTTLSeconds);
    201     cache_->Set(key, HostCache::Entry(rv, addr), base::TimeTicks::Now(), ttl);
    202   }
    203   if (rv == OK)
    204     *addresses = AddressList::CopyWithPort(addr, info.port());
    205   return rv;
    206 }
    207 
    208 void MockHostResolverBase::ResolveNow(size_t id) {
    209   RequestMap::iterator it = requests_.find(id);
    210   if (it == requests_.end())
    211     return;  // was canceled
    212 
    213   scoped_ptr<Request> req(it->second);
    214   requests_.erase(it);
    215   int rv = ResolveProc(id, req->info, req->addresses);
    216   if (!req->callback.is_null())
    217     req->callback.Run(rv);
    218 }
    219 
    220 //-----------------------------------------------------------------------------
    221 
    222 struct RuleBasedHostResolverProc::Rule {
    223   enum ResolverType {
    224     kResolverTypeFail,
    225     kResolverTypeSystem,
    226     kResolverTypeIPLiteral,
    227   };
    228 
    229   ResolverType resolver_type;
    230   std::string host_pattern;
    231   AddressFamily address_family;
    232   HostResolverFlags host_resolver_flags;
    233   std::string replacement;
    234   std::string canonical_name;
    235   int latency_ms;  // In milliseconds.
    236 
    237   Rule(ResolverType resolver_type,
    238        const std::string& host_pattern,
    239        AddressFamily address_family,
    240        HostResolverFlags host_resolver_flags,
    241        const std::string& replacement,
    242        const std::string& canonical_name,
    243        int latency_ms)
    244       : resolver_type(resolver_type),
    245         host_pattern(host_pattern),
    246         address_family(address_family),
    247         host_resolver_flags(host_resolver_flags),
    248         replacement(replacement),
    249         canonical_name(canonical_name),
    250         latency_ms(latency_ms) {}
    251 };
    252 
    253 RuleBasedHostResolverProc::RuleBasedHostResolverProc(HostResolverProc* previous)
    254     : HostResolverProc(previous) {
    255 }
    256 
    257 void RuleBasedHostResolverProc::AddRule(const std::string& host_pattern,
    258                                         const std::string& replacement) {
    259   AddRuleForAddressFamily(host_pattern, ADDRESS_FAMILY_UNSPECIFIED,
    260                           replacement);
    261 }
    262 
    263 void RuleBasedHostResolverProc::AddRuleForAddressFamily(
    264     const std::string& host_pattern,
    265     AddressFamily address_family,
    266     const std::string& replacement) {
    267   DCHECK(!replacement.empty());
    268   HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
    269       HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
    270   Rule rule(Rule::kResolverTypeSystem,
    271             host_pattern,
    272             address_family,
    273             flags,
    274             replacement,
    275             std::string(),
    276             0);
    277   rules_.push_back(rule);
    278 }
    279 
    280 void RuleBasedHostResolverProc::AddIPLiteralRule(
    281     const std::string& host_pattern,
    282     const std::string& ip_literal,
    283     const std::string& canonical_name) {
    284   // Literals are always resolved to themselves by HostResolverImpl,
    285   // consequently we do not support remapping them.
    286   IPAddressNumber ip_number;
    287   DCHECK(!ParseIPLiteralToNumber(host_pattern, &ip_number));
    288   HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
    289       HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
    290   if (!canonical_name.empty())
    291     flags |= HOST_RESOLVER_CANONNAME;
    292   Rule rule(Rule::kResolverTypeIPLiteral, host_pattern,
    293             ADDRESS_FAMILY_UNSPECIFIED, flags, ip_literal, canonical_name,
    294             0);
    295   rules_.push_back(rule);
    296 }
    297 
    298 void RuleBasedHostResolverProc::AddRuleWithLatency(
    299     const std::string& host_pattern,
    300     const std::string& replacement,
    301     int latency_ms) {
    302   DCHECK(!replacement.empty());
    303   HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
    304       HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
    305   Rule rule(Rule::kResolverTypeSystem,
    306             host_pattern,
    307             ADDRESS_FAMILY_UNSPECIFIED,
    308             flags,
    309             replacement,
    310             std::string(),
    311             latency_ms);
    312   rules_.push_back(rule);
    313 }
    314 
    315 void RuleBasedHostResolverProc::AllowDirectLookup(
    316     const std::string& host_pattern) {
    317   HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
    318       HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
    319   Rule rule(Rule::kResolverTypeSystem,
    320             host_pattern,
    321             ADDRESS_FAMILY_UNSPECIFIED,
    322             flags,
    323             std::string(),
    324             std::string(),
    325             0);
    326   rules_.push_back(rule);
    327 }
    328 
    329 void RuleBasedHostResolverProc::AddSimulatedFailure(
    330     const std::string& host_pattern) {
    331   HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
    332       HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
    333   Rule rule(Rule::kResolverTypeFail,
    334             host_pattern,
    335             ADDRESS_FAMILY_UNSPECIFIED,
    336             flags,
    337             std::string(),
    338             std::string(),
    339             0);
    340   rules_.push_back(rule);
    341 }
    342 
    343 void RuleBasedHostResolverProc::ClearRules() {
    344   rules_.clear();
    345 }
    346 
    347 int RuleBasedHostResolverProc::Resolve(const std::string& host,
    348                                        AddressFamily address_family,
    349                                        HostResolverFlags host_resolver_flags,
    350                                        AddressList* addrlist,
    351                                        int* os_error) {
    352   RuleList::iterator r;
    353   for (r = rules_.begin(); r != rules_.end(); ++r) {
    354     bool matches_address_family =
    355         r->address_family == ADDRESS_FAMILY_UNSPECIFIED ||
    356         r->address_family == address_family;
    357     // Ignore HOST_RESOLVER_SYSTEM_ONLY, since it should have no impact on
    358     // whether a rule matches.
    359     HostResolverFlags flags = host_resolver_flags & ~HOST_RESOLVER_SYSTEM_ONLY;
    360     // Flags match if all of the bitflags in host_resolver_flags are enabled
    361     // in the rule's host_resolver_flags. However, the rule may have additional
    362     // flags specified, in which case the flags should still be considered a
    363     // match.
    364     bool matches_flags = (r->host_resolver_flags & flags) == flags;
    365     if (matches_flags && matches_address_family &&
    366         MatchPattern(host, r->host_pattern)) {
    367       if (r->latency_ms != 0) {
    368         base::PlatformThread::Sleep(
    369             base::TimeDelta::FromMilliseconds(r->latency_ms));
    370       }
    371 
    372       // Remap to a new host.
    373       const std::string& effective_host =
    374           r->replacement.empty() ? host : r->replacement;
    375 
    376       // Apply the resolving function to the remapped hostname.
    377       switch (r->resolver_type) {
    378         case Rule::kResolverTypeFail:
    379           return ERR_NAME_NOT_RESOLVED;
    380         case Rule::kResolverTypeSystem:
    381 #if defined(OS_WIN)
    382           net::EnsureWinsockInit();
    383 #endif
    384           return SystemHostResolverCall(effective_host,
    385                                         address_family,
    386                                         host_resolver_flags,
    387                                         addrlist, os_error);
    388         case Rule::kResolverTypeIPLiteral:
    389           return ParseAddressList(effective_host,
    390                                   r->canonical_name,
    391                                   addrlist);
    392         default:
    393           NOTREACHED();
    394           return ERR_UNEXPECTED;
    395       }
    396     }
    397   }
    398   return ResolveUsingPrevious(host, address_family,
    399                               host_resolver_flags, addrlist, os_error);
    400 }
    401 
    402 RuleBasedHostResolverProc::~RuleBasedHostResolverProc() {
    403 }
    404 
    405 RuleBasedHostResolverProc* CreateCatchAllHostResolverProc() {
    406   RuleBasedHostResolverProc* catchall = new RuleBasedHostResolverProc(NULL);
    407   catchall->AddIPLiteralRule("*", "127.0.0.1", "localhost");
    408 
    409   // Next add a rules-based layer the use controls.
    410   return new RuleBasedHostResolverProc(catchall);
    411 }
    412 
    413 //-----------------------------------------------------------------------------
    414 
    415 int HangingHostResolver::Resolve(const RequestInfo& info,
    416                                  RequestPriority priority,
    417                                  AddressList* addresses,
    418                                  const CompletionCallback& callback,
    419                                  RequestHandle* out_req,
    420                                  const BoundNetLog& net_log) {
    421   return ERR_IO_PENDING;
    422 }
    423 
    424 int HangingHostResolver::ResolveFromCache(const RequestInfo& info,
    425                                           AddressList* addresses,
    426                                           const BoundNetLog& net_log) {
    427   return ERR_DNS_CACHE_MISS;
    428 }
    429 
    430 //-----------------------------------------------------------------------------
    431 
    432 ScopedDefaultHostResolverProc::ScopedDefaultHostResolverProc() {}
    433 
    434 ScopedDefaultHostResolverProc::ScopedDefaultHostResolverProc(
    435     HostResolverProc* proc) {
    436   Init(proc);
    437 }
    438 
    439 ScopedDefaultHostResolverProc::~ScopedDefaultHostResolverProc() {
    440   HostResolverProc* old_proc =
    441       HostResolverProc::SetDefault(previous_proc_.get());
    442   // The lifetimes of multiple instances must be nested.
    443   CHECK_EQ(old_proc, current_proc_.get());
    444 }
    445 
    446 void ScopedDefaultHostResolverProc::Init(HostResolverProc* proc) {
    447   current_proc_ = proc;
    448   previous_proc_ = HostResolverProc::SetDefault(current_proc_.get());
    449   current_proc_->SetLastProc(previous_proc_.get());
    450 }
    451 
    452 }  // namespace net
    453