Home | History | Annotate | Download | only in local_discovery
      1 // Copyright 2014 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "chrome/browser/local_discovery/service_discovery_client_mdns.h"
      6 
      7 #include "base/memory/scoped_vector.h"
      8 #include "base/metrics/histogram.h"
      9 #include "chrome/common/local_discovery/service_discovery_client_impl.h"
     10 #include "content/public/browser/browser_thread.h"
     11 #include "net/dns/mdns_client.h"
     12 #include "net/udp/datagram_server_socket.h"
     13 
     14 namespace local_discovery {
     15 
     16 using content::BrowserThread;
     17 
     18 // Base class for objects returned by ServiceDiscoveryClient implementation.
     19 // Handles interaction of client code on UI thread end net code on mdns thread.
     20 class ServiceDiscoveryClientMdns::Proxy {
     21  public:
     22   typedef base::WeakPtr<Proxy> WeakPtr;
     23 
     24   explicit Proxy(ServiceDiscoveryClientMdns* client)
     25       : client_(client),
     26         weak_ptr_factory_(this) {
     27     DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
     28     client_->proxies_.AddObserver(this);
     29   }
     30 
     31   virtual ~Proxy() {
     32     DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
     33     client_->proxies_.RemoveObserver(this);
     34   }
     35 
     36   // Returns true if object is not yet shutdown.
     37   virtual bool IsValid() = 0;
     38 
     39   // Notifies proxies that mDNS layer is going to be destroyed.
     40   virtual void OnMdnsDestroy() = 0;
     41 
     42   // Notifies proxies that new mDNS instance is ready.
     43   virtual void OnNewMdnsReady() {
     44     DCHECK(!client_->need_dalay_mdns_tasks_);
     45     if (IsValid()) {
     46       for (size_t i = 0; i < delayed_tasks_.size(); ++i)
     47         client_->mdns_runner_->PostTask(FROM_HERE, delayed_tasks_[i]);
     48     }
     49     delayed_tasks_.clear();
     50   }
     51 
     52   // Runs callback using this method to abort callback if instance of |Proxy|
     53   // is deleted.
     54   void RunCallback(const base::Closure& callback) {
     55     DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
     56     callback.Run();
     57   }
     58 
     59  protected:
     60   void PostToMdnsThread(const base::Closure& task) {
     61     DCHECK(IsValid());
     62     // The first task on IO thread for each |mdns_| instance must be |InitMdns|.
     63     // |OnInterfaceListReady| could be delayed by |GetMDnsInterfacesToBind|
     64     // running on FILE thread, so |PostToMdnsThread| could be called to post
     65     // task for |mdns_| that is not initialized yet.
     66     if (!client_->need_dalay_mdns_tasks_) {
     67       client_->mdns_runner_->PostTask(FROM_HERE, task);
     68       return;
     69     }
     70     delayed_tasks_.push_back(task);
     71   }
     72 
     73   static bool PostToUIThread(const base::Closure& task) {
     74     return BrowserThread::PostTask(BrowserThread::UI, FROM_HERE, task);
     75   }
     76 
     77   ServiceDiscoveryClient* client() {
     78     return client_->client_.get();
     79   }
     80 
     81   WeakPtr GetWeakPtr() {
     82     return weak_ptr_factory_.GetWeakPtr();
     83   }
     84 
     85   template<class T>
     86   void DeleteOnMdnsThread(T* t) {
     87     if (!t)
     88       return;
     89     if (!client_->mdns_runner_->DeleteSoon(FROM_HERE, t))
     90       delete t;
     91   }
     92 
     93  private:
     94   scoped_refptr<ServiceDiscoveryClientMdns> client_;
     95   base::WeakPtrFactory<Proxy> weak_ptr_factory_;
     96   // Delayed |mdns_runner_| tasks.
     97   std::vector<base::Closure> delayed_tasks_;
     98   DISALLOW_COPY_AND_ASSIGN(Proxy);
     99 };
    100 
    101 namespace {
    102 
    103 const int kMaxRestartAttempts = 10;
    104 const int kRestartDelayOnNetworkChangeSeconds = 3;
    105 
    106 typedef base::Callback<void(bool)> MdnsInitCallback;
    107 
    108 class SocketFactory : public net::MDnsSocketFactory {
    109  public:
    110   explicit SocketFactory(const net::InterfaceIndexFamilyList& interfaces)
    111       : interfaces_(interfaces) {}
    112 
    113   // net::MDnsSocketFactory implementation:
    114   virtual void CreateSockets(
    115       ScopedVector<net::DatagramServerSocket>* sockets) OVERRIDE {
    116     for (size_t i = 0; i < interfaces_.size(); ++i) {
    117       DCHECK(interfaces_[i].second == net::ADDRESS_FAMILY_IPV4 ||
    118              interfaces_[i].second == net::ADDRESS_FAMILY_IPV6);
    119       scoped_ptr<net::DatagramServerSocket> socket(
    120           CreateAndBindMDnsSocket(interfaces_[i].second, interfaces_[i].first));
    121       if (socket)
    122         sockets->push_back(socket.release());
    123     }
    124   }
    125 
    126  private:
    127   net::InterfaceIndexFamilyList interfaces_;
    128 };
    129 
    130 void InitMdns(const MdnsInitCallback& on_initialized,
    131               const net::InterfaceIndexFamilyList& interfaces,
    132               net::MDnsClient* mdns) {
    133   SocketFactory socket_factory(interfaces);
    134   BrowserThread::PostTask(BrowserThread::UI, FROM_HERE,
    135                           base::Bind(on_initialized,
    136                                      mdns->StartListening(&socket_factory)));
    137 }
    138 
    139 template<class T>
    140 class ProxyBase : public ServiceDiscoveryClientMdns::Proxy, public T {
    141  public:
    142   typedef ProxyBase<T> Base;
    143 
    144   explicit ProxyBase(ServiceDiscoveryClientMdns* client)
    145       : Proxy(client) {
    146   }
    147 
    148   virtual ~ProxyBase() {
    149     DeleteOnMdnsThread(implementation_.release());
    150   }
    151 
    152   virtual bool IsValid() OVERRIDE {
    153     return !!implementation();
    154   }
    155 
    156   virtual void OnMdnsDestroy() OVERRIDE {
    157     DeleteOnMdnsThread(implementation_.release());
    158   };
    159 
    160  protected:
    161   void set_implementation(scoped_ptr<T> implementation) {
    162     implementation_ = implementation.Pass();
    163   }
    164 
    165   T* implementation()  const {
    166     return implementation_.get();
    167   }
    168 
    169  private:
    170   scoped_ptr<T> implementation_;
    171   DISALLOW_COPY_AND_ASSIGN(ProxyBase);
    172 };
    173 
    174 class ServiceWatcherProxy : public ProxyBase<ServiceWatcher> {
    175  public:
    176   ServiceWatcherProxy(ServiceDiscoveryClientMdns* client_mdns,
    177                       const std::string& service_type,
    178                       const ServiceWatcher::UpdatedCallback& callback)
    179       : ProxyBase(client_mdns),
    180         service_type_(service_type),
    181         callback_(callback) {
    182     // It's safe to call |CreateServiceWatcher| on UI thread, because
    183     // |MDnsClient| is not used there. It's simplify implementation.
    184     set_implementation(client()->CreateServiceWatcher(
    185         service_type,
    186         base::Bind(&ServiceWatcherProxy::OnCallback, GetWeakPtr(), callback)));
    187   }
    188 
    189   // ServiceWatcher methods.
    190   virtual void Start() OVERRIDE {
    191     if (implementation()) {
    192       PostToMdnsThread(base::Bind(&ServiceWatcher::Start,
    193                                   base::Unretained(implementation())));
    194     }
    195   }
    196 
    197   virtual void DiscoverNewServices(bool force_update) OVERRIDE {
    198     if (implementation()) {
    199       PostToMdnsThread(base::Bind(&ServiceWatcher::DiscoverNewServices,
    200                                   base::Unretained(implementation()),
    201                                   force_update));
    202     }
    203   }
    204 
    205   virtual void SetActivelyRefreshServices(
    206       bool actively_refresh_services) OVERRIDE {
    207     if (implementation()) {
    208       PostToMdnsThread(base::Bind(&ServiceWatcher::SetActivelyRefreshServices,
    209                                   base::Unretained(implementation()),
    210                                   actively_refresh_services));
    211     }
    212   }
    213 
    214   virtual std::string GetServiceType() const OVERRIDE {
    215     return service_type_;
    216   }
    217 
    218   virtual void OnNewMdnsReady() OVERRIDE {
    219     ProxyBase<ServiceWatcher>::OnNewMdnsReady();
    220     if (!implementation())
    221       callback_.Run(ServiceWatcher::UPDATE_INVALIDATED, "");
    222   }
    223 
    224  private:
    225   static void OnCallback(const WeakPtr& proxy,
    226                          const ServiceWatcher::UpdatedCallback& callback,
    227                          UpdateType a1,
    228                          const std::string& a2) {
    229     DCHECK(!BrowserThread::CurrentlyOn(BrowserThread::UI));
    230     PostToUIThread(base::Bind(&Base::RunCallback, proxy,
    231                               base::Bind(callback, a1, a2)));
    232   }
    233   std::string service_type_;
    234   ServiceWatcher::UpdatedCallback callback_;
    235   DISALLOW_COPY_AND_ASSIGN(ServiceWatcherProxy);
    236 };
    237 
    238 class ServiceResolverProxy : public ProxyBase<ServiceResolver> {
    239  public:
    240   ServiceResolverProxy(ServiceDiscoveryClientMdns* client_mdns,
    241                        const std::string& service_name,
    242                        const ServiceResolver::ResolveCompleteCallback& callback)
    243       : ProxyBase(client_mdns),
    244         service_name_(service_name) {
    245     // It's safe to call |CreateServiceResolver| on UI thread, because
    246     // |MDnsClient| is not used there. It's simplify implementation.
    247     set_implementation(client()->CreateServiceResolver(
    248         service_name,
    249         base::Bind(&ServiceResolverProxy::OnCallback, GetWeakPtr(), callback)));
    250   }
    251 
    252   // ServiceResolver methods.
    253   virtual void StartResolving() OVERRIDE {
    254     if (implementation()) {
    255       PostToMdnsThread(base::Bind(&ServiceResolver::StartResolving,
    256                                   base::Unretained(implementation())));
    257     }
    258   };
    259 
    260   virtual std::string GetName() const OVERRIDE {
    261     return service_name_;
    262   }
    263 
    264  private:
    265   static void OnCallback(
    266       const WeakPtr& proxy,
    267       const ServiceResolver::ResolveCompleteCallback& callback,
    268       RequestStatus a1,
    269       const ServiceDescription& a2) {
    270     DCHECK(!BrowserThread::CurrentlyOn(BrowserThread::UI));
    271     PostToUIThread(base::Bind(&Base::RunCallback, proxy,
    272                               base::Bind(callback, a1, a2)));
    273   }
    274 
    275   std::string service_name_;
    276   DISALLOW_COPY_AND_ASSIGN(ServiceResolverProxy);
    277 };
    278 
    279 class LocalDomainResolverProxy : public ProxyBase<LocalDomainResolver> {
    280  public:
    281   LocalDomainResolverProxy(
    282       ServiceDiscoveryClientMdns* client_mdns,
    283       const std::string& domain,
    284       net::AddressFamily address_family,
    285       const LocalDomainResolver::IPAddressCallback& callback)
    286       : ProxyBase(client_mdns) {
    287     // It's safe to call |CreateLocalDomainResolver| on UI thread, because
    288     // |MDnsClient| is not used there. It's simplify implementation.
    289     set_implementation(client()->CreateLocalDomainResolver(
    290         domain,
    291         address_family,
    292         base::Bind(
    293             &LocalDomainResolverProxy::OnCallback, GetWeakPtr(), callback)));
    294   }
    295 
    296   // LocalDomainResolver methods.
    297   virtual void Start() OVERRIDE {
    298     if (implementation()) {
    299       PostToMdnsThread(base::Bind(&LocalDomainResolver::Start,
    300                                   base::Unretained(implementation())));
    301     }
    302   };
    303 
    304  private:
    305   static void OnCallback(const WeakPtr& proxy,
    306                          const LocalDomainResolver::IPAddressCallback& callback,
    307                          bool a1,
    308                          const net::IPAddressNumber& a2,
    309                          const net::IPAddressNumber& a3) {
    310     DCHECK(!BrowserThread::CurrentlyOn(BrowserThread::UI));
    311     PostToUIThread(base::Bind(&Base::RunCallback, proxy,
    312                               base::Bind(callback, a1, a2, a3)));
    313   }
    314 
    315   DISALLOW_COPY_AND_ASSIGN(LocalDomainResolverProxy);
    316 };
    317 
    318 }  // namespace
    319 
    320 ServiceDiscoveryClientMdns::ServiceDiscoveryClientMdns()
    321     : mdns_runner_(
    322           BrowserThread::GetMessageLoopProxyForThread(BrowserThread::IO)),
    323       restart_attempts_(0),
    324       need_dalay_mdns_tasks_(true),
    325       weak_ptr_factory_(this) {
    326   DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
    327   net::NetworkChangeNotifier::AddNetworkChangeObserver(this);
    328   StartNewClient();
    329 }
    330 
    331 scoped_ptr<ServiceWatcher> ServiceDiscoveryClientMdns::CreateServiceWatcher(
    332     const std::string& service_type,
    333     const ServiceWatcher::UpdatedCallback& callback) {
    334   DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
    335   return scoped_ptr<ServiceWatcher>(
    336       new ServiceWatcherProxy(this, service_type, callback));
    337 }
    338 
    339 scoped_ptr<ServiceResolver> ServiceDiscoveryClientMdns::CreateServiceResolver(
    340     const std::string& service_name,
    341     const ServiceResolver::ResolveCompleteCallback& callback) {
    342   DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
    343   return scoped_ptr<ServiceResolver>(
    344       new ServiceResolverProxy(this, service_name, callback));
    345 }
    346 
    347 scoped_ptr<LocalDomainResolver>
    348 ServiceDiscoveryClientMdns::CreateLocalDomainResolver(
    349     const std::string& domain,
    350     net::AddressFamily address_family,
    351     const LocalDomainResolver::IPAddressCallback& callback) {
    352   DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
    353   return scoped_ptr<LocalDomainResolver>(
    354       new LocalDomainResolverProxy(this, domain, address_family, callback));
    355 }
    356 
    357 ServiceDiscoveryClientMdns::~ServiceDiscoveryClientMdns() {
    358   DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
    359   net::NetworkChangeNotifier::RemoveNetworkChangeObserver(this);
    360   DestroyMdns();
    361 }
    362 
    363 void ServiceDiscoveryClientMdns::OnNetworkChanged(
    364     net::NetworkChangeNotifier::ConnectionType type) {
    365   DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
    366   // Only network changes resets counter.
    367   restart_attempts_ = 0;
    368   ScheduleStartNewClient();
    369 }
    370 
    371 void ServiceDiscoveryClientMdns::ScheduleStartNewClient() {
    372   DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
    373   OnBeforeMdnsDestroy();
    374   if (restart_attempts_ < kMaxRestartAttempts) {
    375     base::MessageLoop::current()->PostDelayedTask(
    376         FROM_HERE,
    377         base::Bind(&ServiceDiscoveryClientMdns::StartNewClient,
    378                    weak_ptr_factory_.GetWeakPtr()),
    379         base::TimeDelta::FromSeconds(
    380             kRestartDelayOnNetworkChangeSeconds * (1 << restart_attempts_)));
    381   } else {
    382     ReportSuccess();
    383   }
    384 }
    385 
    386 void ServiceDiscoveryClientMdns::StartNewClient() {
    387   DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
    388   ++restart_attempts_;
    389   DestroyMdns();
    390   mdns_.reset(net::MDnsClient::CreateDefault().release());
    391   client_.reset(new ServiceDiscoveryClientImpl(mdns_.get()));
    392   BrowserThread::PostTaskAndReplyWithResult(
    393       BrowserThread::FILE,
    394       FROM_HERE,
    395       base::Bind(&net::GetMDnsInterfacesToBind),
    396       base::Bind(&ServiceDiscoveryClientMdns::OnInterfaceListReady,
    397                  weak_ptr_factory_.GetWeakPtr()));
    398 }
    399 
    400 void ServiceDiscoveryClientMdns::OnInterfaceListReady(
    401     const net::InterfaceIndexFamilyList& interfaces) {
    402   mdns_runner_->PostTask(
    403       FROM_HERE,
    404       base::Bind(&InitMdns,
    405                  base::Bind(&ServiceDiscoveryClientMdns::OnMdnsInitialized,
    406                             weak_ptr_factory_.GetWeakPtr()),
    407                  interfaces,
    408                  base::Unretained(mdns_.get())));
    409 }
    410 
    411 void ServiceDiscoveryClientMdns::OnMdnsInitialized(bool success) {
    412   DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
    413   if (!success) {
    414     ScheduleStartNewClient();
    415     return;
    416   }
    417   ReportSuccess();
    418 
    419   // Initialization is done, no need to delay tasks.
    420   need_dalay_mdns_tasks_ = false;
    421   FOR_EACH_OBSERVER(Proxy, proxies_, OnNewMdnsReady());
    422 }
    423 
    424 void ServiceDiscoveryClientMdns::ReportSuccess() {
    425   DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
    426   UMA_HISTOGRAM_COUNTS_100("LocalDiscovery.ClientRestartAttempts",
    427                            restart_attempts_);
    428 }
    429 
    430 void ServiceDiscoveryClientMdns::OnBeforeMdnsDestroy() {
    431   need_dalay_mdns_tasks_ = true;
    432   weak_ptr_factory_.InvalidateWeakPtrs();
    433   FOR_EACH_OBSERVER(Proxy, proxies_, OnMdnsDestroy());
    434 }
    435 
    436 void ServiceDiscoveryClientMdns::DestroyMdns() {
    437   OnBeforeMdnsDestroy();
    438   // After calling |Proxy::OnMdnsDestroy| all references to client_ and mdns_
    439   // should be destroyed.
    440   if (client_)
    441     mdns_runner_->DeleteSoon(FROM_HERE, client_.release());
    442   if (mdns_)
    443     mdns_runner_->DeleteSoon(FROM_HERE, mdns_.release());
    444 }
    445 
    446 }  // namespace local_discovery
    447