Home | History | Annotate | Download | only in dns
      1 // Copyright (c) 2013 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/dns/mdns_cache.h"
      6 
      7 #include <algorithm>
      8 #include <utility>
      9 
     10 #include "base/stl_util.h"
     11 #include "base/strings/string_number_conversions.h"
     12 #include "net/dns/dns_protocol.h"
     13 #include "net/dns/record_parsed.h"
     14 #include "net/dns/record_rdata.h"
     15 
     16 // TODO(noamsml): Recursive CNAME closure (backwards and forwards).
     17 
     18 namespace net {
     19 
     20 // The effective TTL given to records with a nominal zero TTL.
     21 // Allows time for hosts to send updated records, as detailed in RFC 6762
     22 // Section 10.1.
     23 static const unsigned kZeroTTLSeconds = 1;
     24 
     25 MDnsCache::Key::Key(unsigned type, const std::string& name,
     26                     const std::string& optional)
     27     : type_(type), name_(name), optional_(optional) {
     28 }
     29 
     30 MDnsCache::Key::Key(
     31     const MDnsCache::Key& other)
     32     : type_(other.type_), name_(other.name_), optional_(other.optional_) {
     33 }
     34 
     35 
     36 MDnsCache::Key& MDnsCache::Key::operator=(
     37     const MDnsCache::Key& other) {
     38   type_ = other.type_;
     39   name_ = other.name_;
     40   optional_ = other.optional_;
     41   return *this;
     42 }
     43 
     44 MDnsCache::Key::~Key() {
     45 }
     46 
     47 bool MDnsCache::Key::operator<(const MDnsCache::Key& key) const {
     48   if (name_ != key.name_)
     49     return name_ < key.name_;
     50 
     51   if (type_ != key.type_)
     52     return type_ < key.type_;
     53 
     54   if (optional_ != key.optional_)
     55     return optional_ < key.optional_;
     56   return false;  // keys are equal
     57 }
     58 
     59 bool MDnsCache::Key::operator==(const MDnsCache::Key& key) const {
     60   return type_ == key.type_ && name_ == key.name_ && optional_ == key.optional_;
     61 }
     62 
     63 // static
     64 MDnsCache::Key MDnsCache::Key::CreateFor(const RecordParsed* record) {
     65   return Key(record->type(),
     66              record->name(),
     67              GetOptionalFieldForRecord(record));
     68 }
     69 
     70 
     71 MDnsCache::MDnsCache() {
     72 }
     73 
     74 MDnsCache::~MDnsCache() {
     75   Clear();
     76 }
     77 
     78 void MDnsCache::Clear() {
     79   next_expiration_ = base::Time();
     80   STLDeleteValues(&mdns_cache_);
     81 }
     82 
     83 const RecordParsed* MDnsCache::LookupKey(const Key& key) {
     84   RecordMap::iterator found = mdns_cache_.find(key);
     85   if (found != mdns_cache_.end()) {
     86     return found->second;
     87   }
     88   return NULL;
     89 }
     90 
     91 MDnsCache::UpdateType MDnsCache::UpdateDnsRecord(
     92     scoped_ptr<const RecordParsed> record) {
     93   Key cache_key = Key::CreateFor(record.get());
     94 
     95   // Ignore "goodbye" packets for records not in cache.
     96   if (record->ttl() == 0 && mdns_cache_.find(cache_key) == mdns_cache_.end())
     97     return NoChange;
     98 
     99   base::Time new_expiration = GetEffectiveExpiration(record.get());
    100   if (next_expiration_ != base::Time())
    101     new_expiration = std::min(new_expiration, next_expiration_);
    102 
    103   std::pair<RecordMap::iterator, bool> insert_result =
    104       mdns_cache_.insert(std::make_pair(cache_key, (const RecordParsed*)NULL));
    105   UpdateType type = NoChange;
    106   if (insert_result.second) {
    107     type = RecordAdded;
    108   } else {
    109     const RecordParsed* other_record = insert_result.first->second;
    110 
    111     if (record->ttl() != 0 && !record->IsEqual(other_record, true)) {
    112       type = RecordChanged;
    113     }
    114     delete other_record;
    115   }
    116 
    117   insert_result.first->second = record.release();
    118   next_expiration_ = new_expiration;
    119   return type;
    120 }
    121 
    122 void MDnsCache::CleanupRecords(
    123     base::Time now,
    124     const RecordRemovedCallback& record_removed_callback) {
    125   base::Time next_expiration;
    126 
    127   // We are guaranteed that |next_expiration_| will be at or before the next
    128   // expiration. This allows clients to eagrely call CleanupRecords with
    129   // impunity.
    130   if (now < next_expiration_) return;
    131 
    132   for (RecordMap::iterator i = mdns_cache_.begin();
    133        i != mdns_cache_.end(); ) {
    134     base::Time expiration = GetEffectiveExpiration(i->second);
    135     if (now >= expiration) {
    136       record_removed_callback.Run(i->second);
    137       delete i->second;
    138       mdns_cache_.erase(i++);
    139     } else {
    140       if (next_expiration == base::Time() ||  expiration < next_expiration) {
    141         next_expiration = expiration;
    142       }
    143       ++i;
    144     }
    145   }
    146 
    147   next_expiration_ = next_expiration;
    148 }
    149 
    150 void MDnsCache::FindDnsRecords(unsigned type,
    151                                const std::string& name,
    152                                std::vector<const RecordParsed*>* results,
    153                                base::Time now) const {
    154   DCHECK(results);
    155   results->clear();
    156 
    157   RecordMap::const_iterator i = mdns_cache_.lower_bound(Key(type, name, ""));
    158   for (; i != mdns_cache_.end(); ++i) {
    159     if (i->first.name() != name ||
    160         (type != 0 && i->first.type() != type)) {
    161       break;
    162     }
    163 
    164     const RecordParsed* record = i->second;
    165 
    166     // Records are deleted only upon request.
    167     if (now >= GetEffectiveExpiration(record)) continue;
    168 
    169     results->push_back(record);
    170   }
    171 }
    172 
    173 scoped_ptr<const RecordParsed> MDnsCache::RemoveRecord(
    174     const RecordParsed* record) {
    175   Key key = Key::CreateFor(record);
    176   RecordMap::iterator found = mdns_cache_.find(key);
    177 
    178   if (found != mdns_cache_.end() && found->second == record) {
    179     mdns_cache_.erase(key);
    180     return scoped_ptr<const RecordParsed>(record);
    181   }
    182 
    183   return scoped_ptr<const RecordParsed>();
    184 }
    185 
    186 // static
    187 std::string MDnsCache::GetOptionalFieldForRecord(
    188     const RecordParsed* record) {
    189   switch (record->type()) {
    190     case PtrRecordRdata::kType: {
    191       const PtrRecordRdata* rdata = record->rdata<PtrRecordRdata>();
    192       return rdata->ptrdomain();
    193     }
    194     default:  // Most records are considered unique for our purposes
    195       return "";
    196   }
    197 }
    198 
    199 // static
    200 base::Time MDnsCache::GetEffectiveExpiration(const RecordParsed* record) {
    201   base::TimeDelta ttl;
    202 
    203   if (record->ttl()) {
    204     ttl = base::TimeDelta::FromSeconds(record->ttl());
    205   } else {
    206     ttl = base::TimeDelta::FromSeconds(kZeroTTLSeconds);
    207   }
    208 
    209   return record->time_created() + ttl;
    210 }
    211 
    212 }  // namespace net
    213