1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/dns/mock_host_resolver.h" 6 7 #include <string> 8 #include <vector> 9 10 #include "base/bind.h" 11 #include "base/memory/ref_counted.h" 12 #include "base/message_loop/message_loop.h" 13 #include "base/stl_util.h" 14 #include "base/strings/string_split.h" 15 #include "base/strings/string_util.h" 16 #include "base/threading/platform_thread.h" 17 #include "net/base/ip_endpoint.h" 18 #include "net/base/net_errors.h" 19 #include "net/base/net_util.h" 20 #include "net/base/test_completion_callback.h" 21 #include "net/dns/host_cache.h" 22 23 #if defined(OS_WIN) 24 #include "net/base/winsock_init.h" 25 #endif 26 27 namespace net { 28 29 namespace { 30 31 // Cache size for the MockCachingHostResolver. 32 const unsigned kMaxCacheEntries = 100; 33 // TTL for the successful resolutions. Failures are not cached. 34 const unsigned kCacheEntryTTLSeconds = 60; 35 36 } // namespace 37 38 int ParseAddressList(const std::string& host_list, 39 const std::string& canonical_name, 40 AddressList* addrlist) { 41 *addrlist = AddressList(); 42 std::vector<std::string> addresses; 43 base::SplitString(host_list, ',', &addresses); 44 addrlist->set_canonical_name(canonical_name); 45 for (size_t index = 0; index < addresses.size(); ++index) { 46 IPAddressNumber ip_number; 47 if (!ParseIPLiteralToNumber(addresses[index], &ip_number)) { 48 LOG(WARNING) << "Not a supported IP literal: " << addresses[index]; 49 return ERR_UNEXPECTED; 50 } 51 addrlist->push_back(IPEndPoint(ip_number, -1)); 52 } 53 return OK; 54 } 55 56 struct MockHostResolverBase::Request { 57 Request(const RequestInfo& req_info, 58 AddressList* addr, 59 const CompletionCallback& cb) 60 : info(req_info), addresses(addr), callback(cb) {} 61 RequestInfo info; 62 AddressList* addresses; 63 CompletionCallback callback; 64 }; 65 66 MockHostResolverBase::~MockHostResolverBase() { 67 STLDeleteValues(&requests_); 68 } 69 70 int MockHostResolverBase::Resolve(const RequestInfo& info, 71 RequestPriority priority, 72 AddressList* addresses, 73 const CompletionCallback& callback, 74 RequestHandle* handle, 75 const BoundNetLog& net_log) { 76 DCHECK(CalledOnValidThread()); 77 last_request_priority_ = priority; 78 num_resolve_++; 79 size_t id = next_request_id_++; 80 int rv = ResolveFromIPLiteralOrCache(info, addresses); 81 if (rv != ERR_DNS_CACHE_MISS) { 82 return rv; 83 } 84 if (synchronous_mode_) { 85 return ResolveProc(id, info, addresses); 86 } 87 // Store the request for asynchronous resolution 88 Request* req = new Request(info, addresses, callback); 89 requests_[id] = req; 90 if (handle) 91 *handle = reinterpret_cast<RequestHandle>(id); 92 93 if (!ondemand_mode_) { 94 base::MessageLoop::current()->PostTask( 95 FROM_HERE, 96 base::Bind(&MockHostResolverBase::ResolveNow, AsWeakPtr(), id)); 97 } 98 99 return ERR_IO_PENDING; 100 } 101 102 int MockHostResolverBase::ResolveFromCache(const RequestInfo& info, 103 AddressList* addresses, 104 const BoundNetLog& net_log) { 105 num_resolve_from_cache_++; 106 DCHECK(CalledOnValidThread()); 107 next_request_id_++; 108 int rv = ResolveFromIPLiteralOrCache(info, addresses); 109 return rv; 110 } 111 112 void MockHostResolverBase::CancelRequest(RequestHandle handle) { 113 DCHECK(CalledOnValidThread()); 114 size_t id = reinterpret_cast<size_t>(handle); 115 RequestMap::iterator it = requests_.find(id); 116 if (it != requests_.end()) { 117 scoped_ptr<Request> req(it->second); 118 requests_.erase(it); 119 } else { 120 NOTREACHED() << "CancelRequest must NOT be called after request is " 121 "complete or canceled."; 122 } 123 } 124 125 HostCache* MockHostResolverBase::GetHostCache() { 126 return cache_.get(); 127 } 128 129 void MockHostResolverBase::ResolveAllPending() { 130 DCHECK(CalledOnValidThread()); 131 DCHECK(ondemand_mode_); 132 for (RequestMap::iterator i = requests_.begin(); i != requests_.end(); ++i) { 133 base::MessageLoop::current()->PostTask( 134 FROM_HERE, 135 base::Bind(&MockHostResolverBase::ResolveNow, AsWeakPtr(), i->first)); 136 } 137 } 138 139 // start id from 1 to distinguish from NULL RequestHandle 140 MockHostResolverBase::MockHostResolverBase(bool use_caching) 141 : last_request_priority_(DEFAULT_PRIORITY), 142 synchronous_mode_(false), 143 ondemand_mode_(false), 144 next_request_id_(1), 145 num_resolve_(0), 146 num_resolve_from_cache_(0) { 147 rules_ = CreateCatchAllHostResolverProc(); 148 149 if (use_caching) { 150 cache_.reset(new HostCache(kMaxCacheEntries)); 151 } 152 } 153 154 int MockHostResolverBase::ResolveFromIPLiteralOrCache(const RequestInfo& info, 155 AddressList* addresses) { 156 IPAddressNumber ip; 157 if (ParseIPLiteralToNumber(info.hostname(), &ip)) { 158 // This matches the behavior HostResolverImpl. 159 if (info.address_family() != ADDRESS_FAMILY_UNSPECIFIED && 160 info.address_family() != GetAddressFamily(ip)) { 161 return ERR_NAME_NOT_RESOLVED; 162 } 163 164 *addresses = AddressList::CreateFromIPAddress(ip, info.port()); 165 if (info.host_resolver_flags() & HOST_RESOLVER_CANONNAME) 166 addresses->SetDefaultCanonicalName(); 167 return OK; 168 } 169 int rv = ERR_DNS_CACHE_MISS; 170 if (cache_.get() && info.allow_cached_response()) { 171 HostCache::Key key(info.hostname(), 172 info.address_family(), 173 info.host_resolver_flags()); 174 const HostCache::Entry* entry = cache_->Lookup(key, base::TimeTicks::Now()); 175 if (entry) { 176 rv = entry->error; 177 if (rv == OK) 178 *addresses = AddressList::CopyWithPort(entry->addrlist, info.port()); 179 } 180 } 181 return rv; 182 } 183 184 int MockHostResolverBase::ResolveProc(size_t id, 185 const RequestInfo& info, 186 AddressList* addresses) { 187 AddressList addr; 188 int rv = rules_->Resolve(info.hostname(), 189 info.address_family(), 190 info.host_resolver_flags(), 191 &addr, 192 NULL); 193 if (cache_.get()) { 194 HostCache::Key key(info.hostname(), 195 info.address_family(), 196 info.host_resolver_flags()); 197 // Storing a failure with TTL 0 so that it overwrites previous value. 198 base::TimeDelta ttl; 199 if (rv == OK) 200 ttl = base::TimeDelta::FromSeconds(kCacheEntryTTLSeconds); 201 cache_->Set(key, HostCache::Entry(rv, addr), base::TimeTicks::Now(), ttl); 202 } 203 if (rv == OK) 204 *addresses = AddressList::CopyWithPort(addr, info.port()); 205 return rv; 206 } 207 208 void MockHostResolverBase::ResolveNow(size_t id) { 209 RequestMap::iterator it = requests_.find(id); 210 if (it == requests_.end()) 211 return; // was canceled 212 213 scoped_ptr<Request> req(it->second); 214 requests_.erase(it); 215 int rv = ResolveProc(id, req->info, req->addresses); 216 if (!req->callback.is_null()) 217 req->callback.Run(rv); 218 } 219 220 //----------------------------------------------------------------------------- 221 222 struct RuleBasedHostResolverProc::Rule { 223 enum ResolverType { 224 kResolverTypeFail, 225 kResolverTypeSystem, 226 kResolverTypeIPLiteral, 227 }; 228 229 ResolverType resolver_type; 230 std::string host_pattern; 231 AddressFamily address_family; 232 HostResolverFlags host_resolver_flags; 233 std::string replacement; 234 std::string canonical_name; 235 int latency_ms; // In milliseconds. 236 237 Rule(ResolverType resolver_type, 238 const std::string& host_pattern, 239 AddressFamily address_family, 240 HostResolverFlags host_resolver_flags, 241 const std::string& replacement, 242 const std::string& canonical_name, 243 int latency_ms) 244 : resolver_type(resolver_type), 245 host_pattern(host_pattern), 246 address_family(address_family), 247 host_resolver_flags(host_resolver_flags), 248 replacement(replacement), 249 canonical_name(canonical_name), 250 latency_ms(latency_ms) {} 251 }; 252 253 RuleBasedHostResolverProc::RuleBasedHostResolverProc(HostResolverProc* previous) 254 : HostResolverProc(previous) { 255 } 256 257 void RuleBasedHostResolverProc::AddRule(const std::string& host_pattern, 258 const std::string& replacement) { 259 AddRuleForAddressFamily(host_pattern, ADDRESS_FAMILY_UNSPECIFIED, 260 replacement); 261 } 262 263 void RuleBasedHostResolverProc::AddRuleForAddressFamily( 264 const std::string& host_pattern, 265 AddressFamily address_family, 266 const std::string& replacement) { 267 DCHECK(!replacement.empty()); 268 HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY | 269 HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; 270 Rule rule(Rule::kResolverTypeSystem, 271 host_pattern, 272 address_family, 273 flags, 274 replacement, 275 std::string(), 276 0); 277 rules_.push_back(rule); 278 } 279 280 void RuleBasedHostResolverProc::AddIPLiteralRule( 281 const std::string& host_pattern, 282 const std::string& ip_literal, 283 const std::string& canonical_name) { 284 // Literals are always resolved to themselves by HostResolverImpl, 285 // consequently we do not support remapping them. 286 IPAddressNumber ip_number; 287 DCHECK(!ParseIPLiteralToNumber(host_pattern, &ip_number)); 288 HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY | 289 HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; 290 if (!canonical_name.empty()) 291 flags |= HOST_RESOLVER_CANONNAME; 292 Rule rule(Rule::kResolverTypeIPLiteral, host_pattern, 293 ADDRESS_FAMILY_UNSPECIFIED, flags, ip_literal, canonical_name, 294 0); 295 rules_.push_back(rule); 296 } 297 298 void RuleBasedHostResolverProc::AddRuleWithLatency( 299 const std::string& host_pattern, 300 const std::string& replacement, 301 int latency_ms) { 302 DCHECK(!replacement.empty()); 303 HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY | 304 HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; 305 Rule rule(Rule::kResolverTypeSystem, 306 host_pattern, 307 ADDRESS_FAMILY_UNSPECIFIED, 308 flags, 309 replacement, 310 std::string(), 311 latency_ms); 312 rules_.push_back(rule); 313 } 314 315 void RuleBasedHostResolverProc::AllowDirectLookup( 316 const std::string& host_pattern) { 317 HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY | 318 HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; 319 Rule rule(Rule::kResolverTypeSystem, 320 host_pattern, 321 ADDRESS_FAMILY_UNSPECIFIED, 322 flags, 323 std::string(), 324 std::string(), 325 0); 326 rules_.push_back(rule); 327 } 328 329 void RuleBasedHostResolverProc::AddSimulatedFailure( 330 const std::string& host_pattern) { 331 HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY | 332 HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; 333 Rule rule(Rule::kResolverTypeFail, 334 host_pattern, 335 ADDRESS_FAMILY_UNSPECIFIED, 336 flags, 337 std::string(), 338 std::string(), 339 0); 340 rules_.push_back(rule); 341 } 342 343 void RuleBasedHostResolverProc::ClearRules() { 344 rules_.clear(); 345 } 346 347 int RuleBasedHostResolverProc::Resolve(const std::string& host, 348 AddressFamily address_family, 349 HostResolverFlags host_resolver_flags, 350 AddressList* addrlist, 351 int* os_error) { 352 RuleList::iterator r; 353 for (r = rules_.begin(); r != rules_.end(); ++r) { 354 bool matches_address_family = 355 r->address_family == ADDRESS_FAMILY_UNSPECIFIED || 356 r->address_family == address_family; 357 // Ignore HOST_RESOLVER_SYSTEM_ONLY, since it should have no impact on 358 // whether a rule matches. 359 HostResolverFlags flags = host_resolver_flags & ~HOST_RESOLVER_SYSTEM_ONLY; 360 // Flags match if all of the bitflags in host_resolver_flags are enabled 361 // in the rule's host_resolver_flags. However, the rule may have additional 362 // flags specified, in which case the flags should still be considered a 363 // match. 364 bool matches_flags = (r->host_resolver_flags & flags) == flags; 365 if (matches_flags && matches_address_family && 366 MatchPattern(host, r->host_pattern)) { 367 if (r->latency_ms != 0) { 368 base::PlatformThread::Sleep( 369 base::TimeDelta::FromMilliseconds(r->latency_ms)); 370 } 371 372 // Remap to a new host. 373 const std::string& effective_host = 374 r->replacement.empty() ? host : r->replacement; 375 376 // Apply the resolving function to the remapped hostname. 377 switch (r->resolver_type) { 378 case Rule::kResolverTypeFail: 379 return ERR_NAME_NOT_RESOLVED; 380 case Rule::kResolverTypeSystem: 381 #if defined(OS_WIN) 382 net::EnsureWinsockInit(); 383 #endif 384 return SystemHostResolverCall(effective_host, 385 address_family, 386 host_resolver_flags, 387 addrlist, os_error); 388 case Rule::kResolverTypeIPLiteral: 389 return ParseAddressList(effective_host, 390 r->canonical_name, 391 addrlist); 392 default: 393 NOTREACHED(); 394 return ERR_UNEXPECTED; 395 } 396 } 397 } 398 return ResolveUsingPrevious(host, address_family, 399 host_resolver_flags, addrlist, os_error); 400 } 401 402 RuleBasedHostResolverProc::~RuleBasedHostResolverProc() { 403 } 404 405 RuleBasedHostResolverProc* CreateCatchAllHostResolverProc() { 406 RuleBasedHostResolverProc* catchall = new RuleBasedHostResolverProc(NULL); 407 catchall->AddIPLiteralRule("*", "127.0.0.1", "localhost"); 408 409 // Next add a rules-based layer the use controls. 410 return new RuleBasedHostResolverProc(catchall); 411 } 412 413 //----------------------------------------------------------------------------- 414 415 int HangingHostResolver::Resolve(const RequestInfo& info, 416 RequestPriority priority, 417 AddressList* addresses, 418 const CompletionCallback& callback, 419 RequestHandle* out_req, 420 const BoundNetLog& net_log) { 421 return ERR_IO_PENDING; 422 } 423 424 int HangingHostResolver::ResolveFromCache(const RequestInfo& info, 425 AddressList* addresses, 426 const BoundNetLog& net_log) { 427 return ERR_DNS_CACHE_MISS; 428 } 429 430 //----------------------------------------------------------------------------- 431 432 ScopedDefaultHostResolverProc::ScopedDefaultHostResolverProc() {} 433 434 ScopedDefaultHostResolverProc::ScopedDefaultHostResolverProc( 435 HostResolverProc* proc) { 436 Init(proc); 437 } 438 439 ScopedDefaultHostResolverProc::~ScopedDefaultHostResolverProc() { 440 HostResolverProc* old_proc = 441 HostResolverProc::SetDefault(previous_proc_.get()); 442 // The lifetimes of multiple instances must be nested. 443 CHECK_EQ(old_proc, current_proc_.get()); 444 } 445 446 void ScopedDefaultHostResolverProc::Init(HostResolverProc* proc) { 447 current_proc_ = proc; 448 previous_proc_ = HostResolverProc::SetDefault(current_proc_.get()); 449 current_proc_->SetLastProc(previous_proc_.get()); 450 } 451 452 } // namespace net 453