Home | History | Annotate | Download | only in local_discovery
      1 // Copyright 2013 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include <utility>
      6 
      7 #include "base/logging.h"
      8 #include "base/memory/singleton.h"
      9 #include "base/message_loop/message_loop_proxy.h"
     10 #include "base/stl_util.h"
     11 #include "chrome/utility/local_discovery/service_discovery_client_impl.h"
     12 #include "net/dns/dns_protocol.h"
     13 #include "net/dns/record_rdata.h"
     14 
     15 namespace local_discovery {
     16 
     17 ServiceDiscoveryClientImpl::ServiceDiscoveryClientImpl(
     18     net::MDnsClient* mdns_client) : mdns_client_(mdns_client) {
     19 }
     20 
     21 ServiceDiscoveryClientImpl::~ServiceDiscoveryClientImpl() {
     22 }
     23 
     24 scoped_ptr<ServiceWatcher> ServiceDiscoveryClientImpl::CreateServiceWatcher(
     25     const std::string& service_type,
     26     const ServiceWatcher::UpdatedCallback& callback) {
     27   return scoped_ptr<ServiceWatcher>(new ServiceWatcherImpl(
     28       service_type,  callback, mdns_client_));
     29 }
     30 
     31 scoped_ptr<ServiceResolver> ServiceDiscoveryClientImpl::CreateServiceResolver(
     32     const std::string& service_name,
     33     const ServiceResolver::ResolveCompleteCallback& callback) {
     34   return scoped_ptr<ServiceResolver>(new ServiceResolverImpl(
     35       service_name, callback, mdns_client_));
     36 }
     37 
     38 scoped_ptr<LocalDomainResolver>
     39 ServiceDiscoveryClientImpl::CreateLocalDomainResolver(
     40       const std::string& domain,
     41       net::AddressFamily address_family,
     42       const LocalDomainResolver::IPAddressCallback& callback) {
     43   return scoped_ptr<LocalDomainResolver>(new LocalDomainResolverImpl(
     44       domain, address_family, callback, mdns_client_));
     45 }
     46 
     47 ServiceWatcherImpl::ServiceWatcherImpl(
     48     const std::string& service_type,
     49     const ServiceWatcher::UpdatedCallback& callback,
     50     net::MDnsClient* mdns_client)
     51     : service_type_(service_type), callback_(callback), started_(false),
     52       mdns_client_(mdns_client) {
     53 }
     54 
     55 void ServiceWatcherImpl::Start() {
     56   DCHECK(!started_);
     57   listener_ = mdns_client_->CreateListener(
     58       net::dns_protocol::kTypePTR, service_type_, this);
     59   started_ = listener_->Start();
     60   if (started_)
     61     ReadCachedServices();
     62 }
     63 
     64 ServiceWatcherImpl::~ServiceWatcherImpl() {
     65 }
     66 
     67 void ServiceWatcherImpl::DiscoverNewServices(bool force_update) {
     68   DCHECK(started_);
     69   if (force_update)
     70     services_.clear();
     71   CreateTransaction(true /*network*/, false /*cache*/, force_update,
     72                     &transaction_network_);
     73 }
     74 
     75 void ServiceWatcherImpl::ReadCachedServices() {
     76   DCHECK(started_);
     77   CreateTransaction(false /*network*/, true /*cache*/, false /*force refresh*/,
     78                     &transaction_cache_);
     79 }
     80 
     81 bool ServiceWatcherImpl::CreateTransaction(
     82     bool network, bool cache, bool force_refresh,
     83     scoped_ptr<net::MDnsTransaction>* transaction) {
     84   int transaction_flags = 0;
     85   if (network)
     86     transaction_flags |= net::MDnsTransaction::QUERY_NETWORK;
     87 
     88   if (cache)
     89     transaction_flags |= net::MDnsTransaction::QUERY_CACHE;
     90 
     91   // TODO(noamsml): Add flag for force_refresh when supported.
     92 
     93   if (transaction_flags) {
     94     *transaction = mdns_client_->CreateTransaction(
     95         net::dns_protocol::kTypePTR, service_type_, transaction_flags,
     96         base::Bind(&ServiceWatcherImpl::OnTransactionResponse,
     97                    base::Unretained(this), transaction));
     98     return (*transaction)->Start();
     99   }
    100 
    101   return true;
    102 }
    103 
    104 std::string ServiceWatcherImpl::GetServiceType() const {
    105   return listener_->GetName();
    106 }
    107 
    108 void ServiceWatcherImpl::OnRecordUpdate(
    109     net::MDnsListener::UpdateType update,
    110     const net::RecordParsed* record) {
    111   DCHECK(started_);
    112   if (record->type() == net::dns_protocol::kTypePTR) {
    113     DCHECK(record->name() == GetServiceType());
    114     const net::PtrRecordRdata* rdata = record->rdata<net::PtrRecordRdata>();
    115 
    116     switch (update) {
    117       case net::MDnsListener::RECORD_ADDED:
    118         AddService(rdata->ptrdomain());
    119         break;
    120       case net::MDnsListener::RECORD_CHANGED:
    121         NOTREACHED();
    122         break;
    123       case net::MDnsListener::RECORD_REMOVED:
    124         RemoveService(rdata->ptrdomain());
    125         break;
    126     }
    127   } else {
    128     DCHECK(record->type() == net::dns_protocol::kTypeSRV ||
    129            record->type() == net::dns_protocol::kTypeTXT);
    130     DCHECK(services_.find(record->name()) != services_.end());
    131 
    132     DeferUpdate(UPDATE_CHANGED, record->name());
    133   }
    134 }
    135 
    136 void ServiceWatcherImpl::OnCachePurged() {
    137   // Not yet implemented.
    138 }
    139 
    140 void ServiceWatcherImpl::OnTransactionResponse(
    141     scoped_ptr<net::MDnsTransaction>* transaction,
    142     net::MDnsTransaction::Result result,
    143     const net::RecordParsed* record) {
    144   DCHECK(started_);
    145   if (result == net::MDnsTransaction::RESULT_RECORD) {
    146     const net::PtrRecordRdata* rdata = record->rdata<net::PtrRecordRdata>();
    147     DCHECK(rdata);
    148     AddService(rdata->ptrdomain());
    149   } else if (result == net::MDnsTransaction::RESULT_DONE) {
    150     transaction->reset();
    151   }
    152 
    153   // Do nothing for NSEC records. It is an error for hosts to broadcast an NSEC
    154   // record for PTR records on any name.
    155 }
    156 
    157 ServiceWatcherImpl::ServiceListeners::ServiceListeners(
    158     const std::string& service_name,
    159     ServiceWatcherImpl* watcher,
    160     net::MDnsClient* mdns_client) : update_pending_(false) {
    161   srv_listener_ = mdns_client->CreateListener(
    162       net::dns_protocol::kTypeSRV, service_name, watcher);
    163   txt_listener_ = mdns_client->CreateListener(
    164       net::dns_protocol::kTypeTXT, service_name, watcher);
    165 }
    166 
    167 ServiceWatcherImpl::ServiceListeners::~ServiceListeners() {
    168 }
    169 
    170 bool ServiceWatcherImpl::ServiceListeners::Start() {
    171   if (!srv_listener_->Start())
    172     return false;
    173   return txt_listener_->Start();
    174 }
    175 
    176 void ServiceWatcherImpl::AddService(const std::string& service) {
    177   DCHECK(started_);
    178   std::pair<ServiceListenersMap::iterator, bool> found = services_.insert(
    179       make_pair(service, linked_ptr<ServiceListeners>(NULL)));
    180   if (found.second) {  // Newly inserted.
    181     found.first->second = linked_ptr<ServiceListeners>(
    182         new ServiceListeners(service, this, mdns_client_));
    183     bool success = found.first->second->Start();
    184 
    185     DeferUpdate(UPDATE_ADDED, service);
    186 
    187     DCHECK(success);
    188   }
    189 }
    190 
    191 void ServiceWatcherImpl::DeferUpdate(ServiceWatcher::UpdateType update_type,
    192                                      const std::string& service_name) {
    193   ServiceListenersMap::iterator found = services_.find(service_name);
    194 
    195   if (found != services_.end() && !found->second->update_pending()) {
    196     found->second->set_update_pending(true);
    197     base::MessageLoop::current()->PostTask(
    198         FROM_HERE,
    199         base::Bind(&ServiceWatcherImpl::DeliverDeferredUpdate, AsWeakPtr(),
    200                    update_type, service_name));
    201   }
    202 }
    203 
    204 void ServiceWatcherImpl::DeliverDeferredUpdate(
    205     ServiceWatcher::UpdateType update_type, const std::string& service_name) {
    206   ServiceListenersMap::iterator found = services_.find(service_name);
    207 
    208   if (found != services_.end()) {
    209     found->second->set_update_pending(false);
    210     if (!callback_.is_null())
    211       callback_.Run(update_type, service_name);
    212   }
    213 }
    214 
    215 void ServiceWatcherImpl::RemoveService(const std::string& service) {
    216   DCHECK(started_);
    217   ServiceListenersMap::iterator found = services_.find(service);
    218   if (found != services_.end()) {
    219     services_.erase(found);
    220     if (!callback_.is_null())
    221       callback_.Run(UPDATE_REMOVED, service);
    222   }
    223 }
    224 
    225 void ServiceWatcherImpl::OnNsecRecord(const std::string& name,
    226                                       unsigned rrtype) {
    227   // Do nothing. It is an error for hosts to broadcast an NSEC record for PTR
    228   // on any name.
    229 }
    230 
    231 ServiceResolverImpl::ServiceResolverImpl(
    232     const std::string& service_name,
    233     const ResolveCompleteCallback& callback,
    234     net::MDnsClient* mdns_client)
    235     : service_name_(service_name), callback_(callback),
    236       metadata_resolved_(false), address_resolved_(false),
    237       mdns_client_(mdns_client) {
    238   service_staging_.service_name = service_name_;
    239 }
    240 
    241 void ServiceResolverImpl::StartResolving() {
    242   address_resolved_ = false;
    243   metadata_resolved_ = false;
    244 
    245   if (!CreateTxtTransaction() || !CreateSrvTransaction()) {
    246     ServiceNotFound(ServiceResolver::STATUS_REQUEST_TIMEOUT);
    247   }
    248 }
    249 
    250 ServiceResolverImpl::~ServiceResolverImpl() {
    251 }
    252 
    253 bool ServiceResolverImpl::CreateTxtTransaction() {
    254   txt_transaction_ = mdns_client_->CreateTransaction(
    255       net::dns_protocol::kTypeTXT, service_name_,
    256       net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE |
    257       net::MDnsTransaction::QUERY_NETWORK,
    258       base::Bind(&ServiceResolverImpl::TxtRecordTransactionResponse,
    259                  AsWeakPtr()));
    260   return txt_transaction_->Start();
    261 }
    262 
    263 // TODO(noamsml): quick-resolve for AAAA records.  Since A records tend to be in
    264 void ServiceResolverImpl::CreateATransaction() {
    265   a_transaction_ = mdns_client_->CreateTransaction(
    266       net::dns_protocol::kTypeA,
    267       service_staging_.address.host(),
    268       net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE,
    269       base::Bind(&ServiceResolverImpl::ARecordTransactionResponse,
    270                  AsWeakPtr()));
    271   a_transaction_->Start();
    272 }
    273 
    274 bool ServiceResolverImpl::CreateSrvTransaction() {
    275   srv_transaction_ = mdns_client_->CreateTransaction(
    276       net::dns_protocol::kTypeSRV, service_name_,
    277       net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE |
    278       net::MDnsTransaction::QUERY_NETWORK,
    279       base::Bind(&ServiceResolverImpl::SrvRecordTransactionResponse,
    280                  AsWeakPtr()));
    281   return srv_transaction_->Start();
    282 }
    283 
    284 std::string ServiceResolverImpl::GetName() const {
    285   return service_name_;
    286 }
    287 
    288 void ServiceResolverImpl::SrvRecordTransactionResponse(
    289     net::MDnsTransaction::Result status, const net::RecordParsed* record) {
    290   srv_transaction_.reset();
    291   if (status == net::MDnsTransaction::RESULT_RECORD) {
    292     DCHECK(record);
    293     service_staging_.address = RecordToAddress(record);
    294     service_staging_.last_seen = record->time_created();
    295     CreateATransaction();
    296   } else {
    297     ServiceNotFound(MDnsStatusToRequestStatus(status));
    298   }
    299 }
    300 
    301 void ServiceResolverImpl::TxtRecordTransactionResponse(
    302     net::MDnsTransaction::Result status, const net::RecordParsed* record) {
    303   txt_transaction_.reset();
    304   if (status == net::MDnsTransaction::RESULT_RECORD) {
    305     DCHECK(record);
    306     service_staging_.metadata = RecordToMetadata(record);
    307   } else {
    308     service_staging_.metadata = std::vector<std::string>();
    309   }
    310 
    311   metadata_resolved_ = true;
    312   AlertCallbackIfReady();
    313 }
    314 
    315 void ServiceResolverImpl::ARecordTransactionResponse(
    316     net::MDnsTransaction::Result status, const net::RecordParsed* record) {
    317   a_transaction_.reset();
    318 
    319   if (status == net::MDnsTransaction::RESULT_RECORD) {
    320     DCHECK(record);
    321     service_staging_.ip_address = RecordToIPAddress(record);
    322   } else {
    323     service_staging_.ip_address = net::IPAddressNumber();
    324   }
    325 
    326   address_resolved_ = true;
    327   AlertCallbackIfReady();
    328 }
    329 
    330 void ServiceResolverImpl::AlertCallbackIfReady() {
    331   if (metadata_resolved_ && address_resolved_) {
    332     txt_transaction_.reset();
    333     srv_transaction_.reset();
    334     a_transaction_.reset();
    335     if (!callback_.is_null())
    336       callback_.Run(STATUS_SUCCESS, service_staging_);
    337     service_staging_ = ServiceDescription();
    338   }
    339 }
    340 
    341 void ServiceResolverImpl::ServiceNotFound(
    342     ServiceResolver::RequestStatus status) {
    343   txt_transaction_.reset();
    344   srv_transaction_.reset();
    345   a_transaction_.reset();
    346   if (!callback_.is_null())
    347     callback_.Run(status, ServiceDescription());
    348 }
    349 
    350 ServiceResolver::RequestStatus ServiceResolverImpl::MDnsStatusToRequestStatus(
    351     net::MDnsTransaction::Result status) const {
    352   switch (status) {
    353     case net::MDnsTransaction::RESULT_RECORD:
    354       return ServiceResolver::STATUS_SUCCESS;
    355     case net::MDnsTransaction::RESULT_NO_RESULTS:
    356       return ServiceResolver::STATUS_REQUEST_TIMEOUT;
    357     case net::MDnsTransaction::RESULT_NSEC:
    358       return ServiceResolver::STATUS_KNOWN_NONEXISTENT;
    359     case net::MDnsTransaction::RESULT_DONE:  // Pass through.
    360     default:
    361       NOTREACHED();
    362       return ServiceResolver::STATUS_REQUEST_TIMEOUT;
    363   }
    364 }
    365 
    366 const std::vector<std::string>& ServiceResolverImpl::RecordToMetadata(
    367     const net::RecordParsed* record) const {
    368   DCHECK(record->type() == net::dns_protocol::kTypeTXT);
    369   const net::TxtRecordRdata* txt_rdata = record->rdata<net::TxtRecordRdata>();
    370   DCHECK(txt_rdata);
    371   return txt_rdata->texts();
    372 }
    373 
    374 net::HostPortPair ServiceResolverImpl::RecordToAddress(
    375     const net::RecordParsed* record) const {
    376   DCHECK(record->type() == net::dns_protocol::kTypeSRV);
    377   const net::SrvRecordRdata* srv_rdata = record->rdata<net::SrvRecordRdata>();
    378   DCHECK(srv_rdata);
    379   return net::HostPortPair(srv_rdata->target(), srv_rdata->port());
    380 }
    381 
    382 const net::IPAddressNumber& ServiceResolverImpl::RecordToIPAddress(
    383     const net::RecordParsed* record) const {
    384   DCHECK(record->type() == net::dns_protocol::kTypeA);
    385   const net::ARecordRdata* a_rdata = record->rdata<net::ARecordRdata>();
    386   DCHECK(a_rdata);
    387   return a_rdata->address();
    388 }
    389 
    390 LocalDomainResolverImpl::LocalDomainResolverImpl(
    391     const std::string& domain,
    392     net::AddressFamily address_family,
    393     const IPAddressCallback& callback,
    394     net::MDnsClient* mdns_client)
    395     : domain_(domain), address_family_(address_family), callback_(callback),
    396       transaction_failures_(0), mdns_client_(mdns_client) {
    397 }
    398 
    399 LocalDomainResolverImpl::~LocalDomainResolverImpl() {
    400 }
    401 
    402 void LocalDomainResolverImpl::Start() {
    403   if (address_family_ == net::ADDRESS_FAMILY_IPV4 ||
    404       address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) {
    405     transaction_a_ = CreateTransaction(net::dns_protocol::kTypeA);
    406     transaction_a_->Start();
    407   }
    408 
    409   if (address_family_ == net::ADDRESS_FAMILY_IPV6 ||
    410       address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) {
    411     transaction_aaaa_ = CreateTransaction(net::dns_protocol::kTypeAAAA);
    412     transaction_aaaa_->Start();
    413   }
    414 }
    415 
    416 scoped_ptr<net::MDnsTransaction> LocalDomainResolverImpl::CreateTransaction(
    417     uint16 type) {
    418   return mdns_client_->CreateTransaction(
    419       type, domain_, net::MDnsTransaction::SINGLE_RESULT |
    420                      net::MDnsTransaction::QUERY_CACHE |
    421                      net::MDnsTransaction::QUERY_NETWORK,
    422       base::Bind(&LocalDomainResolverImpl::OnTransactionComplete,
    423                  base::Unretained(this)));
    424 }
    425 
    426 void LocalDomainResolverImpl::OnTransactionComplete(
    427     net::MDnsTransaction::Result result, const net::RecordParsed* record) {
    428   if (result != net::MDnsTransaction::RESULT_RECORD &&
    429       address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) {
    430     transaction_failures_++;
    431 
    432     if (transaction_failures_ < 2) {
    433       return;
    434     }
    435   }
    436 
    437   transaction_a_.reset();
    438   transaction_aaaa_.reset();
    439 
    440   net::IPAddressNumber address;
    441   if (result == net::MDnsTransaction::RESULT_RECORD) {
    442     if (record->type() == net::dns_protocol::kTypeA) {
    443       const net::ARecordRdata* rdata = record->rdata<net::ARecordRdata>();
    444       address = rdata->address();
    445     } else {
    446       DCHECK_EQ(net::dns_protocol::kTypeAAAA, record->type());
    447       const net::AAAARecordRdata* rdata = record->rdata<net::AAAARecordRdata>();
    448       address = rdata->address();
    449     }
    450   }
    451 
    452   callback_.Run(result == net::MDnsTransaction::RESULT_RECORD, address);
    453 }
    454 
    455 }  // namespace local_discovery
    456