Home | History | Annotate | Download | only in local_discovery
      1 // Copyright 2013 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include <utility>
      6 
      7 #include "base/logging.h"
      8 #include "base/memory/singleton.h"
      9 #include "base/message_loop/message_loop_proxy.h"
     10 #include "base/stl_util.h"
     11 #include "chrome/utility/local_discovery/service_discovery_client_impl.h"
     12 #include "net/dns/dns_protocol.h"
     13 #include "net/dns/record_rdata.h"
     14 
     15 namespace local_discovery {
     16 
     17 namespace {
     18 // TODO(noamsml): Make this configurable through the LocalDomainResolver
     19 // interface.
     20 const int kLocalDomainSecondAddressTimeoutMs = 100;
     21 
     22 const int kInitialRequeryTimeSeconds = 1;
     23 const int kMaxRequeryTimeSeconds = 2; // Time for last requery
     24 }
     25 
     26 ServiceDiscoveryClientImpl::ServiceDiscoveryClientImpl(
     27     net::MDnsClient* mdns_client) : mdns_client_(mdns_client) {
     28 }
     29 
     30 ServiceDiscoveryClientImpl::~ServiceDiscoveryClientImpl() {
     31 }
     32 
     33 scoped_ptr<ServiceWatcher> ServiceDiscoveryClientImpl::CreateServiceWatcher(
     34     const std::string& service_type,
     35     const ServiceWatcher::UpdatedCallback& callback) {
     36   return scoped_ptr<ServiceWatcher>(new ServiceWatcherImpl(
     37       service_type, callback, mdns_client_));
     38 }
     39 
     40 scoped_ptr<ServiceResolver> ServiceDiscoveryClientImpl::CreateServiceResolver(
     41     const std::string& service_name,
     42     const ServiceResolver::ResolveCompleteCallback& callback) {
     43   return scoped_ptr<ServiceResolver>(new ServiceResolverImpl(
     44       service_name, callback, mdns_client_));
     45 }
     46 
     47 scoped_ptr<LocalDomainResolver>
     48 ServiceDiscoveryClientImpl::CreateLocalDomainResolver(
     49       const std::string& domain,
     50       net::AddressFamily address_family,
     51       const LocalDomainResolver::IPAddressCallback& callback) {
     52   return scoped_ptr<LocalDomainResolver>(new LocalDomainResolverImpl(
     53       domain, address_family, callback, mdns_client_));
     54 }
     55 
     56 ServiceWatcherImpl::ServiceWatcherImpl(
     57     const std::string& service_type,
     58     const ServiceWatcher::UpdatedCallback& callback,
     59     net::MDnsClient* mdns_client)
     60     : service_type_(service_type), callback_(callback), started_(false),
     61       mdns_client_(mdns_client) {
     62 }
     63 
     64 void ServiceWatcherImpl::Start() {
     65   DCHECK(!started_);
     66   listener_ = mdns_client_->CreateListener(
     67       net::dns_protocol::kTypePTR, service_type_, this);
     68   started_ = listener_->Start();
     69   if (started_)
     70     ReadCachedServices();
     71 }
     72 
     73 ServiceWatcherImpl::~ServiceWatcherImpl() {
     74 }
     75 
     76 void ServiceWatcherImpl::DiscoverNewServices(bool force_update) {
     77   DCHECK(started_);
     78   if (force_update)
     79     services_.clear();
     80   SendQuery(kInitialRequeryTimeSeconds, force_update);
     81 }
     82 
     83 void ServiceWatcherImpl::ReadCachedServices() {
     84   DCHECK(started_);
     85   CreateTransaction(false /*network*/, true /*cache*/, false /*force refresh*/,
     86                     &transaction_cache_);
     87 }
     88 
     89 bool ServiceWatcherImpl::CreateTransaction(
     90     bool network, bool cache, bool force_refresh,
     91     scoped_ptr<net::MDnsTransaction>* transaction) {
     92   int transaction_flags = 0;
     93   if (network)
     94     transaction_flags |= net::MDnsTransaction::QUERY_NETWORK;
     95 
     96   if (cache)
     97     transaction_flags |= net::MDnsTransaction::QUERY_CACHE;
     98 
     99   // TODO(noamsml): Add flag for force_refresh when supported.
    100 
    101   if (transaction_flags) {
    102     *transaction = mdns_client_->CreateTransaction(
    103         net::dns_protocol::kTypePTR, service_type_, transaction_flags,
    104         base::Bind(&ServiceWatcherImpl::OnTransactionResponse,
    105                    base::Unretained(this), transaction));
    106     return (*transaction)->Start();
    107   }
    108 
    109   return true;
    110 }
    111 
    112 std::string ServiceWatcherImpl::GetServiceType() const {
    113   return listener_->GetName();
    114 }
    115 
    116 void ServiceWatcherImpl::OnRecordUpdate(
    117     net::MDnsListener::UpdateType update,
    118     const net::RecordParsed* record) {
    119   DCHECK(started_);
    120   if (record->type() == net::dns_protocol::kTypePTR) {
    121     DCHECK(record->name() == GetServiceType());
    122     const net::PtrRecordRdata* rdata = record->rdata<net::PtrRecordRdata>();
    123 
    124     switch (update) {
    125       case net::MDnsListener::RECORD_ADDED:
    126         AddService(rdata->ptrdomain());
    127         break;
    128       case net::MDnsListener::RECORD_CHANGED:
    129         NOTREACHED();
    130         break;
    131       case net::MDnsListener::RECORD_REMOVED:
    132         RemoveService(rdata->ptrdomain());
    133         break;
    134     }
    135   } else {
    136     DCHECK(record->type() == net::dns_protocol::kTypeSRV ||
    137            record->type() == net::dns_protocol::kTypeTXT);
    138     DCHECK(services_.find(record->name()) != services_.end());
    139 
    140     DeferUpdate(UPDATE_CHANGED, record->name());
    141   }
    142 }
    143 
    144 void ServiceWatcherImpl::OnCachePurged() {
    145   // Not yet implemented.
    146 }
    147 
    148 void ServiceWatcherImpl::OnTransactionResponse(
    149     scoped_ptr<net::MDnsTransaction>* transaction,
    150     net::MDnsTransaction::Result result,
    151     const net::RecordParsed* record) {
    152   DCHECK(started_);
    153   if (result == net::MDnsTransaction::RESULT_RECORD) {
    154     const net::PtrRecordRdata* rdata = record->rdata<net::PtrRecordRdata>();
    155     DCHECK(rdata);
    156     AddService(rdata->ptrdomain());
    157   } else if (result == net::MDnsTransaction::RESULT_DONE) {
    158     transaction->reset();
    159   }
    160 
    161   // Do nothing for NSEC records. It is an error for hosts to broadcast an NSEC
    162   // record for PTR records on any name.
    163 }
    164 
    165 ServiceWatcherImpl::ServiceListeners::ServiceListeners(
    166     const std::string& service_name,
    167     ServiceWatcherImpl* watcher,
    168     net::MDnsClient* mdns_client) : update_pending_(false) {
    169   srv_listener_ = mdns_client->CreateListener(
    170       net::dns_protocol::kTypeSRV, service_name, watcher);
    171   txt_listener_ = mdns_client->CreateListener(
    172       net::dns_protocol::kTypeTXT, service_name, watcher);
    173 }
    174 
    175 ServiceWatcherImpl::ServiceListeners::~ServiceListeners() {
    176 }
    177 
    178 bool ServiceWatcherImpl::ServiceListeners::Start() {
    179   if (!srv_listener_->Start())
    180     return false;
    181   return txt_listener_->Start();
    182 }
    183 
    184 void ServiceWatcherImpl::AddService(const std::string& service) {
    185   DCHECK(started_);
    186   std::pair<ServiceListenersMap::iterator, bool> found = services_.insert(
    187       make_pair(service, linked_ptr<ServiceListeners>(NULL)));
    188   if (found.second) {  // Newly inserted.
    189     found.first->second = linked_ptr<ServiceListeners>(
    190         new ServiceListeners(service, this, mdns_client_));
    191     bool success = found.first->second->Start();
    192 
    193     DeferUpdate(UPDATE_ADDED, service);
    194 
    195     DCHECK(success);
    196   }
    197 }
    198 
    199 void ServiceWatcherImpl::DeferUpdate(ServiceWatcher::UpdateType update_type,
    200                                      const std::string& service_name) {
    201   ServiceListenersMap::iterator found = services_.find(service_name);
    202 
    203   if (found != services_.end() && !found->second->update_pending()) {
    204     found->second->set_update_pending(true);
    205     base::MessageLoop::current()->PostTask(
    206         FROM_HERE,
    207         base::Bind(&ServiceWatcherImpl::DeliverDeferredUpdate, AsWeakPtr(),
    208                    update_type, service_name));
    209   }
    210 }
    211 
    212 void ServiceWatcherImpl::DeliverDeferredUpdate(
    213     ServiceWatcher::UpdateType update_type, const std::string& service_name) {
    214   ServiceListenersMap::iterator found = services_.find(service_name);
    215 
    216   if (found != services_.end()) {
    217     found->second->set_update_pending(false);
    218     if (!callback_.is_null())
    219       callback_.Run(update_type, service_name);
    220   }
    221 }
    222 
    223 void ServiceWatcherImpl::RemoveService(const std::string& service) {
    224   DCHECK(started_);
    225   ServiceListenersMap::iterator found = services_.find(service);
    226   if (found != services_.end()) {
    227     services_.erase(found);
    228     if (!callback_.is_null())
    229       callback_.Run(UPDATE_REMOVED, service);
    230   }
    231 }
    232 
    233 void ServiceWatcherImpl::OnNsecRecord(const std::string& name,
    234                                       unsigned rrtype) {
    235   // Do nothing. It is an error for hosts to broadcast an NSEC record for PTR
    236   // on any name.
    237 }
    238 
    239 void ServiceWatcherImpl::ScheduleQuery(int timeout_seconds) {
    240   if (timeout_seconds <= kMaxRequeryTimeSeconds) {
    241     base::MessageLoop::current()->PostDelayedTask(
    242         FROM_HERE,
    243         base::Bind(&ServiceWatcherImpl::SendQuery,
    244                    AsWeakPtr(),
    245                    timeout_seconds * 2 /*next_timeout_seconds*/,
    246                    false /*force_update*/),
    247         base::TimeDelta::FromSeconds(timeout_seconds));
    248   }
    249 }
    250 
    251 void ServiceWatcherImpl::SendQuery(int next_timeout_seconds,
    252                                    bool force_update) {
    253   CreateTransaction(true /*network*/, false /*cache*/, force_update,
    254                     &transaction_network_);
    255   ScheduleQuery(next_timeout_seconds);
    256 }
    257 
    258 ServiceResolverImpl::ServiceResolverImpl(
    259     const std::string& service_name,
    260     const ResolveCompleteCallback& callback,
    261     net::MDnsClient* mdns_client)
    262     : service_name_(service_name), callback_(callback),
    263       metadata_resolved_(false), address_resolved_(false),
    264       mdns_client_(mdns_client) {
    265 }
    266 
    267 void ServiceResolverImpl::StartResolving() {
    268   address_resolved_ = false;
    269   metadata_resolved_ = false;
    270   service_staging_ = ServiceDescription();
    271   service_staging_.service_name = service_name_;
    272 
    273   if (!CreateTxtTransaction() || !CreateSrvTransaction()) {
    274     ServiceNotFound(ServiceResolver::STATUS_REQUEST_TIMEOUT);
    275   }
    276 }
    277 
    278 ServiceResolverImpl::~ServiceResolverImpl() {
    279 }
    280 
    281 bool ServiceResolverImpl::CreateTxtTransaction() {
    282   txt_transaction_ = mdns_client_->CreateTransaction(
    283       net::dns_protocol::kTypeTXT, service_name_,
    284       net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE |
    285       net::MDnsTransaction::QUERY_NETWORK,
    286       base::Bind(&ServiceResolverImpl::TxtRecordTransactionResponse,
    287                  AsWeakPtr()));
    288   return txt_transaction_->Start();
    289 }
    290 
    291 // TODO(noamsml): quick-resolve for AAAA records.  Since A records tend to be in
    292 void ServiceResolverImpl::CreateATransaction() {
    293   a_transaction_ = mdns_client_->CreateTransaction(
    294       net::dns_protocol::kTypeA,
    295       service_staging_.address.host(),
    296       net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE,
    297       base::Bind(&ServiceResolverImpl::ARecordTransactionResponse,
    298                  AsWeakPtr()));
    299   a_transaction_->Start();
    300 }
    301 
    302 bool ServiceResolverImpl::CreateSrvTransaction() {
    303   srv_transaction_ = mdns_client_->CreateTransaction(
    304       net::dns_protocol::kTypeSRV, service_name_,
    305       net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE |
    306       net::MDnsTransaction::QUERY_NETWORK,
    307       base::Bind(&ServiceResolverImpl::SrvRecordTransactionResponse,
    308                  AsWeakPtr()));
    309   return srv_transaction_->Start();
    310 }
    311 
    312 std::string ServiceResolverImpl::GetName() const {
    313   return service_name_;
    314 }
    315 
    316 void ServiceResolverImpl::SrvRecordTransactionResponse(
    317     net::MDnsTransaction::Result status, const net::RecordParsed* record) {
    318   srv_transaction_.reset();
    319   if (status == net::MDnsTransaction::RESULT_RECORD) {
    320     DCHECK(record);
    321     service_staging_.address = RecordToAddress(record);
    322     service_staging_.last_seen = record->time_created();
    323     CreateATransaction();
    324   } else {
    325     ServiceNotFound(MDnsStatusToRequestStatus(status));
    326   }
    327 }
    328 
    329 void ServiceResolverImpl::TxtRecordTransactionResponse(
    330     net::MDnsTransaction::Result status, const net::RecordParsed* record) {
    331   txt_transaction_.reset();
    332   if (status == net::MDnsTransaction::RESULT_RECORD) {
    333     DCHECK(record);
    334     service_staging_.metadata = RecordToMetadata(record);
    335   } else {
    336     service_staging_.metadata = std::vector<std::string>();
    337   }
    338 
    339   metadata_resolved_ = true;
    340   AlertCallbackIfReady();
    341 }
    342 
    343 void ServiceResolverImpl::ARecordTransactionResponse(
    344     net::MDnsTransaction::Result status, const net::RecordParsed* record) {
    345   a_transaction_.reset();
    346 
    347   if (status == net::MDnsTransaction::RESULT_RECORD) {
    348     DCHECK(record);
    349     service_staging_.ip_address = RecordToIPAddress(record);
    350   } else {
    351     service_staging_.ip_address = net::IPAddressNumber();
    352   }
    353 
    354   address_resolved_ = true;
    355   AlertCallbackIfReady();
    356 }
    357 
    358 void ServiceResolverImpl::AlertCallbackIfReady() {
    359   if (metadata_resolved_ && address_resolved_) {
    360     txt_transaction_.reset();
    361     srv_transaction_.reset();
    362     a_transaction_.reset();
    363     if (!callback_.is_null())
    364       callback_.Run(STATUS_SUCCESS, service_staging_);
    365   }
    366 }
    367 
    368 void ServiceResolverImpl::ServiceNotFound(
    369     ServiceResolver::RequestStatus status) {
    370   txt_transaction_.reset();
    371   srv_transaction_.reset();
    372   a_transaction_.reset();
    373   if (!callback_.is_null())
    374     callback_.Run(status, ServiceDescription());
    375 }
    376 
    377 ServiceResolver::RequestStatus ServiceResolverImpl::MDnsStatusToRequestStatus(
    378     net::MDnsTransaction::Result status) const {
    379   switch (status) {
    380     case net::MDnsTransaction::RESULT_RECORD:
    381       return ServiceResolver::STATUS_SUCCESS;
    382     case net::MDnsTransaction::RESULT_NO_RESULTS:
    383       return ServiceResolver::STATUS_REQUEST_TIMEOUT;
    384     case net::MDnsTransaction::RESULT_NSEC:
    385       return ServiceResolver::STATUS_KNOWN_NONEXISTENT;
    386     case net::MDnsTransaction::RESULT_DONE:  // Pass through.
    387     default:
    388       NOTREACHED();
    389       return ServiceResolver::STATUS_REQUEST_TIMEOUT;
    390   }
    391 }
    392 
    393 const std::vector<std::string>& ServiceResolverImpl::RecordToMetadata(
    394     const net::RecordParsed* record) const {
    395   DCHECK(record->type() == net::dns_protocol::kTypeTXT);
    396   const net::TxtRecordRdata* txt_rdata = record->rdata<net::TxtRecordRdata>();
    397   DCHECK(txt_rdata);
    398   return txt_rdata->texts();
    399 }
    400 
    401 net::HostPortPair ServiceResolverImpl::RecordToAddress(
    402     const net::RecordParsed* record) const {
    403   DCHECK(record->type() == net::dns_protocol::kTypeSRV);
    404   const net::SrvRecordRdata* srv_rdata = record->rdata<net::SrvRecordRdata>();
    405   DCHECK(srv_rdata);
    406   return net::HostPortPair(srv_rdata->target(), srv_rdata->port());
    407 }
    408 
    409 const net::IPAddressNumber& ServiceResolverImpl::RecordToIPAddress(
    410     const net::RecordParsed* record) const {
    411   DCHECK(record->type() == net::dns_protocol::kTypeA);
    412   const net::ARecordRdata* a_rdata = record->rdata<net::ARecordRdata>();
    413   DCHECK(a_rdata);
    414   return a_rdata->address();
    415 }
    416 
    417 LocalDomainResolverImpl::LocalDomainResolverImpl(
    418     const std::string& domain,
    419     net::AddressFamily address_family,
    420     const IPAddressCallback& callback,
    421     net::MDnsClient* mdns_client)
    422     : domain_(domain), address_family_(address_family), callback_(callback),
    423       transactions_finished_(0), mdns_client_(mdns_client) {
    424 }
    425 
    426 LocalDomainResolverImpl::~LocalDomainResolverImpl() {
    427   timeout_callback_.Cancel();
    428 }
    429 
    430 void LocalDomainResolverImpl::Start() {
    431   if (address_family_ == net::ADDRESS_FAMILY_IPV4 ||
    432       address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) {
    433     transaction_a_ = CreateTransaction(net::dns_protocol::kTypeA);
    434     transaction_a_->Start();
    435   }
    436 
    437   if (address_family_ == net::ADDRESS_FAMILY_IPV6 ||
    438       address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) {
    439     transaction_aaaa_ = CreateTransaction(net::dns_protocol::kTypeAAAA);
    440     transaction_aaaa_->Start();
    441   }
    442 }
    443 
    444 scoped_ptr<net::MDnsTransaction> LocalDomainResolverImpl::CreateTransaction(
    445     uint16 type) {
    446   return mdns_client_->CreateTransaction(
    447       type, domain_, net::MDnsTransaction::SINGLE_RESULT |
    448                      net::MDnsTransaction::QUERY_CACHE |
    449                      net::MDnsTransaction::QUERY_NETWORK,
    450       base::Bind(&LocalDomainResolverImpl::OnTransactionComplete,
    451                  base::Unretained(this)));
    452 }
    453 
    454 void LocalDomainResolverImpl::OnTransactionComplete(
    455     net::MDnsTransaction::Result result, const net::RecordParsed* record) {
    456   transactions_finished_++;
    457 
    458   if (result == net::MDnsTransaction::RESULT_RECORD) {
    459     if (record->type() == net::dns_protocol::kTypeA) {
    460       const net::ARecordRdata* rdata = record->rdata<net::ARecordRdata>();
    461       address_ipv4_ = rdata->address();
    462     } else {
    463       DCHECK_EQ(net::dns_protocol::kTypeAAAA, record->type());
    464       const net::AAAARecordRdata* rdata = record->rdata<net::AAAARecordRdata>();
    465       address_ipv6_ = rdata->address();
    466     }
    467   }
    468 
    469   if (transactions_finished_ == 1 &&
    470       address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) {
    471     timeout_callback_.Reset(base::Bind(
    472         &LocalDomainResolverImpl::SendResolvedAddresses,
    473         base::Unretained(this)));
    474 
    475     base::MessageLoop::current()->PostDelayedTask(
    476         FROM_HERE,
    477         timeout_callback_.callback(),
    478         base::TimeDelta::FromMilliseconds(kLocalDomainSecondAddressTimeoutMs));
    479   } else if (transactions_finished_ == 2
    480       || address_family_ != net::ADDRESS_FAMILY_UNSPECIFIED) {
    481     SendResolvedAddresses();
    482   }
    483 }
    484 
    485 bool LocalDomainResolverImpl::IsSuccess() {
    486   return !address_ipv4_.empty() || !address_ipv6_.empty();
    487 }
    488 
    489 void LocalDomainResolverImpl::SendResolvedAddresses() {
    490   transaction_a_.reset();
    491   transaction_aaaa_.reset();
    492   timeout_callback_.Cancel();
    493   callback_.Run(IsSuccess(), address_ipv4_, address_ipv6_);
    494 }
    495 
    496 }  // namespace local_discovery
    497