1 // Copyright 2014 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "chrome/browser/local_discovery/service_discovery_client_mdns.h" 6 7 #include "base/memory/scoped_vector.h" 8 #include "base/metrics/histogram.h" 9 #include "chrome/common/local_discovery/service_discovery_client_impl.h" 10 #include "content/public/browser/browser_thread.h" 11 #include "net/dns/mdns_client.h" 12 #include "net/udp/datagram_server_socket.h" 13 14 namespace local_discovery { 15 16 using content::BrowserThread; 17 18 // Base class for objects returned by ServiceDiscoveryClient implementation. 19 // Handles interaction of client code on UI thread end net code on mdns thread. 20 class ServiceDiscoveryClientMdns::Proxy { 21 public: 22 typedef base::WeakPtr<Proxy> WeakPtr; 23 24 explicit Proxy(ServiceDiscoveryClientMdns* client) 25 : client_(client), 26 weak_ptr_factory_(this) { 27 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); 28 client_->proxies_.AddObserver(this); 29 } 30 31 virtual ~Proxy() { 32 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); 33 client_->proxies_.RemoveObserver(this); 34 } 35 36 // Notify proxies that mDNS layer is going to be destroyed. 37 virtual void OnMdnsDestroy() = 0; 38 39 // Notify proxies that new mDNS instance is ready. 40 virtual void OnNewMdnsReady() {} 41 42 // Run callback using this method to abort callback if instance of |Proxy| 43 // is deleted. 44 void RunCallback(const base::Closure& callback) { 45 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); 46 callback.Run(); 47 } 48 49 protected: 50 bool PostToMdnsThread(const base::Closure& task) { 51 return client_->PostToMdnsThread(task); 52 } 53 54 static bool PostToUIThread(const base::Closure& task) { 55 return BrowserThread::PostTask(BrowserThread::UI, FROM_HERE, task); 56 } 57 58 ServiceDiscoveryClient* client() { 59 return client_->client_.get(); 60 } 61 62 WeakPtr GetWeakPtr() { 63 return weak_ptr_factory_.GetWeakPtr(); 64 } 65 66 template<class T> 67 void DeleteOnMdnsThread(T* t) { 68 if (!t) 69 return; 70 if (!client_->mdns_runner_->DeleteSoon(FROM_HERE, t)) 71 delete t; 72 } 73 74 private: 75 scoped_refptr<ServiceDiscoveryClientMdns> client_; 76 base::WeakPtrFactory<Proxy> weak_ptr_factory_; 77 78 DISALLOW_COPY_AND_ASSIGN(Proxy); 79 }; 80 81 namespace { 82 83 const size_t kMaxDelayedTasks = 10000; 84 const int kMaxRestartAttempts = 10; 85 const int kRestartDelayOnNetworkChangeSeconds = 3; 86 87 typedef base::Callback<void(bool)> MdnsInitCallback; 88 89 class SocketFactory : public net::MDnsSocketFactory { 90 public: 91 explicit SocketFactory(const net::InterfaceIndexFamilyList& interfaces) 92 : interfaces_(interfaces) {} 93 94 // net::MDnsSocketFactory implementation: 95 virtual void CreateSockets( 96 ScopedVector<net::DatagramServerSocket>* sockets) OVERRIDE { 97 for (size_t i = 0; i < interfaces_.size(); ++i) { 98 DCHECK(interfaces_[i].second == net::ADDRESS_FAMILY_IPV4 || 99 interfaces_[i].second == net::ADDRESS_FAMILY_IPV6); 100 scoped_ptr<net::DatagramServerSocket> socket( 101 CreateAndBindMDnsSocket(interfaces_[i].second, interfaces_[i].first)); 102 if (socket) 103 sockets->push_back(socket.release()); 104 } 105 } 106 107 private: 108 net::InterfaceIndexFamilyList interfaces_; 109 }; 110 111 void InitMdns(const MdnsInitCallback& on_initialized, 112 const net::InterfaceIndexFamilyList& interfaces, 113 net::MDnsClient* mdns) { 114 SocketFactory socket_factory(interfaces); 115 BrowserThread::PostTask(BrowserThread::UI, FROM_HERE, 116 base::Bind(on_initialized, 117 mdns->StartListening(&socket_factory))); 118 } 119 120 template<class T> 121 class ProxyBase : public ServiceDiscoveryClientMdns::Proxy, public T { 122 public: 123 typedef base::WeakPtr<Proxy> WeakPtr; 124 typedef ProxyBase<T> Base; 125 126 explicit ProxyBase(ServiceDiscoveryClientMdns* client) 127 : Proxy(client) { 128 } 129 130 virtual ~ProxyBase() { 131 DeleteOnMdnsThread(implementation_.release()); 132 } 133 134 virtual void OnMdnsDestroy() OVERRIDE { 135 DeleteOnMdnsThread(implementation_.release()); 136 }; 137 138 protected: 139 void set_implementation(scoped_ptr<T> implementation) { 140 implementation_ = implementation.Pass(); 141 } 142 143 T* implementation() const { 144 return implementation_.get(); 145 } 146 147 private: 148 scoped_ptr<T> implementation_; 149 DISALLOW_COPY_AND_ASSIGN(ProxyBase); 150 }; 151 152 class ServiceWatcherProxy : public ProxyBase<ServiceWatcher> { 153 public: 154 ServiceWatcherProxy(ServiceDiscoveryClientMdns* client_mdns, 155 const std::string& service_type, 156 const ServiceWatcher::UpdatedCallback& callback) 157 : ProxyBase(client_mdns), 158 service_type_(service_type), 159 callback_(callback) { 160 // It's safe to call |CreateServiceWatcher| on UI thread, because 161 // |MDnsClient| is not used there. It's simplify implementation. 162 set_implementation(client()->CreateServiceWatcher( 163 service_type, 164 base::Bind(&ServiceWatcherProxy::OnCallback, GetWeakPtr(), callback))); 165 } 166 167 // ServiceWatcher methods. 168 virtual void Start() OVERRIDE { 169 if (implementation()) 170 PostToMdnsThread(base::Bind(&ServiceWatcher::Start, 171 base::Unretained(implementation()))); 172 } 173 174 virtual void DiscoverNewServices(bool force_update) OVERRIDE { 175 if (implementation()) 176 PostToMdnsThread(base::Bind(&ServiceWatcher::DiscoverNewServices, 177 base::Unretained(implementation()), 178 force_update)); 179 } 180 181 virtual void SetActivelyRefreshServices( 182 bool actively_refresh_services) OVERRIDE { 183 if (implementation()) 184 PostToMdnsThread(base::Bind(&ServiceWatcher::SetActivelyRefreshServices, 185 base::Unretained(implementation()), 186 actively_refresh_services)); 187 } 188 189 virtual std::string GetServiceType() const OVERRIDE { 190 return service_type_; 191 } 192 193 virtual void OnNewMdnsReady() OVERRIDE { 194 if (!implementation()) 195 callback_.Run(ServiceWatcher::UPDATE_INVALIDATED, ""); 196 } 197 198 private: 199 static void OnCallback(const WeakPtr& proxy, 200 const ServiceWatcher::UpdatedCallback& callback, 201 UpdateType a1, 202 const std::string& a2) { 203 DCHECK(!BrowserThread::CurrentlyOn(BrowserThread::UI)); 204 PostToUIThread(base::Bind(&Base::RunCallback, proxy, 205 base::Bind(callback, a1, a2))); 206 } 207 std::string service_type_; 208 ServiceWatcher::UpdatedCallback callback_; 209 DISALLOW_COPY_AND_ASSIGN(ServiceWatcherProxy); 210 }; 211 212 class ServiceResolverProxy : public ProxyBase<ServiceResolver> { 213 public: 214 ServiceResolverProxy(ServiceDiscoveryClientMdns* client_mdns, 215 const std::string& service_name, 216 const ServiceResolver::ResolveCompleteCallback& callback) 217 : ProxyBase(client_mdns), 218 service_name_(service_name) { 219 // It's safe to call |CreateServiceResolver| on UI thread, because 220 // |MDnsClient| is not used there. It's simplify implementation. 221 set_implementation(client()->CreateServiceResolver( 222 service_name, 223 base::Bind(&ServiceResolverProxy::OnCallback, GetWeakPtr(), callback))); 224 } 225 226 // ServiceResolver methods. 227 virtual void StartResolving() OVERRIDE { 228 if (implementation()) 229 PostToMdnsThread(base::Bind(&ServiceResolver::StartResolving, 230 base::Unretained(implementation()))); 231 }; 232 233 virtual std::string GetName() const OVERRIDE { 234 return service_name_; 235 } 236 237 private: 238 static void OnCallback( 239 const WeakPtr& proxy, 240 const ServiceResolver::ResolveCompleteCallback& callback, 241 RequestStatus a1, 242 const ServiceDescription& a2) { 243 DCHECK(!BrowserThread::CurrentlyOn(BrowserThread::UI)); 244 PostToUIThread(base::Bind(&Base::RunCallback, proxy, 245 base::Bind(callback, a1, a2))); 246 } 247 248 std::string service_name_; 249 DISALLOW_COPY_AND_ASSIGN(ServiceResolverProxy); 250 }; 251 252 class LocalDomainResolverProxy : public ProxyBase<LocalDomainResolver> { 253 public: 254 LocalDomainResolverProxy( 255 ServiceDiscoveryClientMdns* client_mdns, 256 const std::string& domain, 257 net::AddressFamily address_family, 258 const LocalDomainResolver::IPAddressCallback& callback) 259 : ProxyBase(client_mdns) { 260 // It's safe to call |CreateLocalDomainResolver| on UI thread, because 261 // |MDnsClient| is not used there. It's simplify implementation. 262 set_implementation(client()->CreateLocalDomainResolver( 263 domain, 264 address_family, 265 base::Bind( 266 &LocalDomainResolverProxy::OnCallback, GetWeakPtr(), callback))); 267 } 268 269 // LocalDomainResolver methods. 270 virtual void Start() OVERRIDE { 271 if (implementation()) 272 PostToMdnsThread(base::Bind(&LocalDomainResolver::Start, 273 base::Unretained(implementation()))); 274 }; 275 276 private: 277 static void OnCallback(const WeakPtr& proxy, 278 const LocalDomainResolver::IPAddressCallback& callback, 279 bool a1, 280 const net::IPAddressNumber& a2, 281 const net::IPAddressNumber& a3) { 282 DCHECK(!BrowserThread::CurrentlyOn(BrowserThread::UI)); 283 PostToUIThread(base::Bind(&Base::RunCallback, proxy, 284 base::Bind(callback, a1, a2, a3))); 285 } 286 287 DISALLOW_COPY_AND_ASSIGN(LocalDomainResolverProxy); 288 }; 289 290 } // namespace 291 292 ServiceDiscoveryClientMdns::ServiceDiscoveryClientMdns() 293 : mdns_runner_( 294 BrowserThread::GetMessageLoopProxyForThread(BrowserThread::IO)), 295 restart_attempts_(0), 296 need_dalay_mdns_tasks_(true), 297 weak_ptr_factory_(this) { 298 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); 299 net::NetworkChangeNotifier::AddNetworkChangeObserver(this); 300 StartNewClient(); 301 } 302 303 scoped_ptr<ServiceWatcher> ServiceDiscoveryClientMdns::CreateServiceWatcher( 304 const std::string& service_type, 305 const ServiceWatcher::UpdatedCallback& callback) { 306 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); 307 return scoped_ptr<ServiceWatcher>( 308 new ServiceWatcherProxy(this, service_type, callback)); 309 } 310 311 scoped_ptr<ServiceResolver> ServiceDiscoveryClientMdns::CreateServiceResolver( 312 const std::string& service_name, 313 const ServiceResolver::ResolveCompleteCallback& callback) { 314 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); 315 return scoped_ptr<ServiceResolver>( 316 new ServiceResolverProxy(this, service_name, callback)); 317 } 318 319 scoped_ptr<LocalDomainResolver> 320 ServiceDiscoveryClientMdns::CreateLocalDomainResolver( 321 const std::string& domain, 322 net::AddressFamily address_family, 323 const LocalDomainResolver::IPAddressCallback& callback) { 324 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); 325 return scoped_ptr<LocalDomainResolver>( 326 new LocalDomainResolverProxy(this, domain, address_family, callback)); 327 } 328 329 ServiceDiscoveryClientMdns::~ServiceDiscoveryClientMdns() { 330 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); 331 net::NetworkChangeNotifier::RemoveNetworkChangeObserver(this); 332 DestroyMdns(); 333 } 334 335 void ServiceDiscoveryClientMdns::OnNetworkChanged( 336 net::NetworkChangeNotifier::ConnectionType type) { 337 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); 338 // Only network changes resets counter. 339 restart_attempts_ = 0; 340 ScheduleStartNewClient(); 341 } 342 343 void ServiceDiscoveryClientMdns::ScheduleStartNewClient() { 344 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); 345 OnBeforeMdnsDestroy(); 346 if (restart_attempts_ < kMaxRestartAttempts) { 347 base::MessageLoop::current()->PostDelayedTask( 348 FROM_HERE, 349 base::Bind(&ServiceDiscoveryClientMdns::StartNewClient, 350 weak_ptr_factory_.GetWeakPtr()), 351 base::TimeDelta::FromSeconds( 352 kRestartDelayOnNetworkChangeSeconds * (1 << restart_attempts_))); 353 } else { 354 ReportSuccess(); 355 } 356 } 357 358 void ServiceDiscoveryClientMdns::StartNewClient() { 359 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); 360 ++restart_attempts_; 361 DestroyMdns(); 362 mdns_.reset(net::MDnsClient::CreateDefault().release()); 363 client_.reset(new ServiceDiscoveryClientImpl(mdns_.get())); 364 BrowserThread::PostTaskAndReplyWithResult( 365 BrowserThread::FILE, 366 FROM_HERE, 367 base::Bind(&net::GetMDnsInterfacesToBind), 368 base::Bind(&ServiceDiscoveryClientMdns::OnInterfaceListReady, 369 weak_ptr_factory_.GetWeakPtr())); 370 } 371 372 void ServiceDiscoveryClientMdns::OnInterfaceListReady( 373 const net::InterfaceIndexFamilyList& interfaces) { 374 mdns_runner_->PostTask( 375 FROM_HERE, 376 base::Bind(&InitMdns, 377 base::Bind(&ServiceDiscoveryClientMdns::OnMdnsInitialized, 378 weak_ptr_factory_.GetWeakPtr()), 379 interfaces, 380 base::Unretained(mdns_.get()))); 381 } 382 383 void ServiceDiscoveryClientMdns::OnMdnsInitialized(bool success) { 384 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); 385 if (!success) { 386 ScheduleStartNewClient(); 387 return; 388 } 389 ReportSuccess(); 390 391 // Initialization is done, no need to delay tasks. 392 need_dalay_mdns_tasks_ = false; 393 for (size_t i = 0; i < delayed_tasks_.size(); ++i) 394 mdns_runner_->PostTask(FROM_HERE, delayed_tasks_[i]); 395 delayed_tasks_.clear(); 396 397 FOR_EACH_OBSERVER(Proxy, proxies_, OnNewMdnsReady()); 398 } 399 400 void ServiceDiscoveryClientMdns::ReportSuccess() { 401 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); 402 UMA_HISTOGRAM_COUNTS_100("LocalDiscovery.ClientRestartAttempts", 403 restart_attempts_); 404 } 405 406 void ServiceDiscoveryClientMdns::OnBeforeMdnsDestroy() { 407 need_dalay_mdns_tasks_ = true; 408 delayed_tasks_.clear(); 409 weak_ptr_factory_.InvalidateWeakPtrs(); 410 FOR_EACH_OBSERVER(Proxy, proxies_, OnMdnsDestroy()); 411 } 412 413 void ServiceDiscoveryClientMdns::DestroyMdns() { 414 OnBeforeMdnsDestroy(); 415 // After calling |Proxy::OnMdnsDestroy| all references to client_ and mdns_ 416 // should be destroyed. 417 if (client_) 418 mdns_runner_->DeleteSoon(FROM_HERE, client_.release()); 419 if (mdns_) 420 mdns_runner_->DeleteSoon(FROM_HERE, mdns_.release()); 421 } 422 423 bool ServiceDiscoveryClientMdns::PostToMdnsThread(const base::Closure& task) { 424 // The first task on IO thread for each |mdns_| instance must be |InitMdns|. 425 // |OnInterfaceListReady| could be delayed by |GetMDnsInterfacesToBind| 426 // running on FILE thread, so |PostToMdnsThread| could be called to post 427 // task for |mdns_| that is not initialized yet. 428 if (!need_dalay_mdns_tasks_) 429 return mdns_runner_->PostTask(FROM_HERE, task); 430 if (kMaxDelayedTasks > delayed_tasks_.size()) 431 delayed_tasks_.push_back(task); 432 return true; 433 } 434 435 } // namespace local_discovery 436