1 // Copyright 2013 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "chrome/utility/local_discovery/service_discovery_message_handler.h" 6 7 #include <algorithm> 8 9 #include "base/lazy_instance.h" 10 #include "chrome/common/local_discovery/local_discovery_messages.h" 11 #include "chrome/common/local_discovery/service_discovery_client_impl.h" 12 #include "content/public/utility/utility_thread.h" 13 #include "net/socket/socket_descriptor.h" 14 #include "net/udp/datagram_server_socket.h" 15 16 namespace local_discovery { 17 18 namespace { 19 20 void ClosePlatformSocket(net::SocketDescriptor socket); 21 22 // Sets socket factory used by |net::CreatePlatformSocket|. Implemetation 23 // keeps single socket that will be returned to the first call to 24 // |net::CreatePlatformSocket| during object lifetime. 25 class ScopedSocketFactory : public net::PlatformSocketFactory { 26 public: 27 explicit ScopedSocketFactory(net::SocketDescriptor socket) : socket_(socket) { 28 net::PlatformSocketFactory::SetInstance(this); 29 } 30 31 virtual ~ScopedSocketFactory() { 32 net::PlatformSocketFactory::SetInstance(NULL); 33 ClosePlatformSocket(socket_); 34 socket_ = net::kInvalidSocket; 35 } 36 37 virtual net::SocketDescriptor CreateSocket(int family, int type, 38 int protocol) OVERRIDE { 39 DCHECK_EQ(type, SOCK_DGRAM); 40 DCHECK(family == AF_INET || family == AF_INET6); 41 net::SocketDescriptor result = net::kInvalidSocket; 42 std::swap(result, socket_); 43 return result; 44 } 45 46 private: 47 net::SocketDescriptor socket_; 48 DISALLOW_COPY_AND_ASSIGN(ScopedSocketFactory); 49 }; 50 51 struct SocketInfo { 52 SocketInfo(net::SocketDescriptor socket, 53 net::AddressFamily address_family, 54 uint32 interface_index) 55 : socket(socket), 56 address_family(address_family), 57 interface_index(interface_index) { 58 } 59 net::SocketDescriptor socket; 60 net::AddressFamily address_family; 61 uint32 interface_index; 62 }; 63 64 // Returns list of sockets preallocated before. 65 class PreCreatedMDnsSocketFactory : public net::MDnsSocketFactory { 66 public: 67 PreCreatedMDnsSocketFactory() {} 68 virtual ~PreCreatedMDnsSocketFactory() { 69 // Not empty if process exits too fast, before starting mDns code. If 70 // happened, destructors may crash accessing destroyed global objects. 71 sockets_.weak_clear(); 72 } 73 74 // net::MDnsSocketFactory implementation: 75 virtual void CreateSockets( 76 ScopedVector<net::DatagramServerSocket>* sockets) OVERRIDE { 77 sockets->swap(sockets_); 78 Reset(); 79 } 80 81 void AddSocket(const SocketInfo& socket_info) { 82 // Takes ownership of socket_info.socket; 83 ScopedSocketFactory platform_factory(socket_info.socket); 84 scoped_ptr<net::DatagramServerSocket> socket( 85 net::CreateAndBindMDnsSocket(socket_info.address_family, 86 socket_info.interface_index)); 87 if (socket) { 88 socket->DetachFromThread(); 89 sockets_.push_back(socket.release()); 90 } 91 } 92 93 void Reset() { 94 sockets_.clear(); 95 } 96 97 private: 98 ScopedVector<net::DatagramServerSocket> sockets_; 99 100 DISALLOW_COPY_AND_ASSIGN(PreCreatedMDnsSocketFactory); 101 }; 102 103 base::LazyInstance<PreCreatedMDnsSocketFactory> 104 g_local_discovery_socket_factory = LAZY_INSTANCE_INITIALIZER; 105 106 #if defined(OS_WIN) 107 108 void ClosePlatformSocket(net::SocketDescriptor socket) { 109 ::closesocket(socket); 110 } 111 112 void StaticInitializeSocketFactory() { 113 net::InterfaceIndexFamilyList interfaces(net::GetMDnsInterfacesToBind()); 114 for (size_t i = 0; i < interfaces.size(); ++i) { 115 DCHECK(interfaces[i].second == net::ADDRESS_FAMILY_IPV4 || 116 interfaces[i].second == net::ADDRESS_FAMILY_IPV6); 117 net::SocketDescriptor descriptor = 118 net::CreatePlatformSocket( 119 net::ConvertAddressFamily(interfaces[i].second), SOCK_DGRAM, 120 IPPROTO_UDP); 121 g_local_discovery_socket_factory.Get().AddSocket( 122 SocketInfo(descriptor, interfaces[i].second, interfaces[i].first)); 123 } 124 } 125 126 #else // OS_WIN 127 128 void ClosePlatformSocket(net::SocketDescriptor socket) { 129 ::close(socket); 130 } 131 132 void StaticInitializeSocketFactory() { 133 } 134 135 #endif // OS_WIN 136 137 void SendHostMessageOnUtilityThread(IPC::Message* msg) { 138 content::UtilityThread::Get()->Send(msg); 139 } 140 141 std::string WatcherUpdateToString(ServiceWatcher::UpdateType update) { 142 switch (update) { 143 case ServiceWatcher::UPDATE_ADDED: 144 return "UPDATE_ADDED"; 145 case ServiceWatcher::UPDATE_CHANGED: 146 return "UPDATE_CHANGED"; 147 case ServiceWatcher::UPDATE_REMOVED: 148 return "UPDATE_REMOVED"; 149 case ServiceWatcher::UPDATE_INVALIDATED: 150 return "UPDATE_INVALIDATED"; 151 } 152 return "Unknown Update"; 153 } 154 155 std::string ResolverStatusToString(ServiceResolver::RequestStatus status) { 156 switch (status) { 157 case ServiceResolver::STATUS_SUCCESS: 158 return "STATUS_SUCESS"; 159 case ServiceResolver::STATUS_REQUEST_TIMEOUT: 160 return "STATUS_REQUEST_TIMEOUT"; 161 case ServiceResolver::STATUS_KNOWN_NONEXISTENT: 162 return "STATUS_KNOWN_NONEXISTENT"; 163 } 164 return "Unknown Status"; 165 } 166 167 } // namespace 168 169 ServiceDiscoveryMessageHandler::ServiceDiscoveryMessageHandler() { 170 } 171 172 ServiceDiscoveryMessageHandler::~ServiceDiscoveryMessageHandler() { 173 DCHECK(!discovery_thread_); 174 } 175 176 void ServiceDiscoveryMessageHandler::PreSandboxStartup() { 177 StaticInitializeSocketFactory(); 178 } 179 180 void ServiceDiscoveryMessageHandler::InitializeMdns() { 181 if (service_discovery_client_ || mdns_client_) 182 return; 183 184 mdns_client_ = net::MDnsClient::CreateDefault(); 185 bool result = 186 mdns_client_->StartListening(g_local_discovery_socket_factory.Pointer()); 187 // Close unused sockets. 188 g_local_discovery_socket_factory.Get().Reset(); 189 if (!result) { 190 VLOG(1) << "Failed to start MDnsClient"; 191 Send(new LocalDiscoveryHostMsg_Error()); 192 return; 193 } 194 195 service_discovery_client_.reset( 196 new local_discovery::ServiceDiscoveryClientImpl(mdns_client_.get())); 197 } 198 199 bool ServiceDiscoveryMessageHandler::InitializeThread() { 200 if (discovery_task_runner_.get()) 201 return true; 202 if (discovery_thread_) 203 return false; 204 utility_task_runner_ = base::MessageLoop::current()->message_loop_proxy(); 205 discovery_thread_.reset(new base::Thread("ServiceDiscoveryThread")); 206 base::Thread::Options thread_options(base::MessageLoop::TYPE_IO, 0); 207 if (discovery_thread_->StartWithOptions(thread_options)) { 208 discovery_task_runner_ = discovery_thread_->message_loop_proxy(); 209 discovery_task_runner_->PostTask(FROM_HERE, 210 base::Bind(&ServiceDiscoveryMessageHandler::InitializeMdns, 211 base::Unretained(this))); 212 } 213 return discovery_task_runner_.get() != NULL; 214 } 215 216 bool ServiceDiscoveryMessageHandler::OnMessageReceived( 217 const IPC::Message& message) { 218 bool handled = true; 219 IPC_BEGIN_MESSAGE_MAP(ServiceDiscoveryMessageHandler, message) 220 #if defined(OS_POSIX) 221 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_SetSockets, OnSetSockets) 222 #endif // OS_POSIX 223 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_StartWatcher, OnStartWatcher) 224 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DiscoverServices, OnDiscoverServices) 225 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_SetActivelyRefreshServices, 226 OnSetActivelyRefreshServices) 227 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyWatcher, OnDestroyWatcher) 228 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ResolveService, OnResolveService) 229 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyResolver, OnDestroyResolver) 230 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ResolveLocalDomain, 231 OnResolveLocalDomain) 232 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyLocalDomainResolver, 233 OnDestroyLocalDomainResolver) 234 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ShutdownLocalDiscovery, 235 ShutdownLocalDiscovery) 236 IPC_MESSAGE_UNHANDLED(handled = false) 237 IPC_END_MESSAGE_MAP() 238 return handled; 239 } 240 241 void ServiceDiscoveryMessageHandler::PostTask( 242 const tracked_objects::Location& from_here, 243 const base::Closure& task) { 244 if (!InitializeThread()) 245 return; 246 discovery_task_runner_->PostTask(from_here, task); 247 } 248 249 #if defined(OS_POSIX) 250 void ServiceDiscoveryMessageHandler::OnSetSockets( 251 const std::vector<LocalDiscoveryMsg_SocketInfo>& sockets) { 252 for (size_t i = 0; i < sockets.size(); ++i) { 253 g_local_discovery_socket_factory.Get().AddSocket( 254 SocketInfo(sockets[i].descriptor.fd, sockets[i].address_family, 255 sockets[i].interface_index)); 256 } 257 } 258 #endif // OS_POSIX 259 260 void ServiceDiscoveryMessageHandler::OnStartWatcher( 261 uint64 id, 262 const std::string& service_type) { 263 PostTask(FROM_HERE, 264 base::Bind(&ServiceDiscoveryMessageHandler::StartWatcher, 265 base::Unretained(this), id, service_type)); 266 } 267 268 void ServiceDiscoveryMessageHandler::OnDiscoverServices(uint64 id, 269 bool force_update) { 270 PostTask(FROM_HERE, 271 base::Bind(&ServiceDiscoveryMessageHandler::DiscoverServices, 272 base::Unretained(this), id, force_update)); 273 } 274 275 void ServiceDiscoveryMessageHandler::OnSetActivelyRefreshServices( 276 uint64 id, bool actively_refresh_services) { 277 PostTask(FROM_HERE, 278 base::Bind( 279 &ServiceDiscoveryMessageHandler::SetActivelyRefreshServices, 280 base::Unretained(this), id, actively_refresh_services)); 281 } 282 283 void ServiceDiscoveryMessageHandler::OnDestroyWatcher(uint64 id) { 284 PostTask(FROM_HERE, 285 base::Bind(&ServiceDiscoveryMessageHandler::DestroyWatcher, 286 base::Unretained(this), id)); 287 } 288 289 void ServiceDiscoveryMessageHandler::OnResolveService( 290 uint64 id, 291 const std::string& service_name) { 292 PostTask(FROM_HERE, 293 base::Bind(&ServiceDiscoveryMessageHandler::ResolveService, 294 base::Unretained(this), id, service_name)); 295 } 296 297 void ServiceDiscoveryMessageHandler::OnDestroyResolver(uint64 id) { 298 PostTask(FROM_HERE, 299 base::Bind(&ServiceDiscoveryMessageHandler::DestroyResolver, 300 base::Unretained(this), id)); 301 } 302 303 void ServiceDiscoveryMessageHandler::OnResolveLocalDomain( 304 uint64 id, const std::string& domain, 305 net::AddressFamily address_family) { 306 PostTask(FROM_HERE, 307 base::Bind(&ServiceDiscoveryMessageHandler::ResolveLocalDomain, 308 base::Unretained(this), id, domain, address_family)); 309 } 310 311 void ServiceDiscoveryMessageHandler::OnDestroyLocalDomainResolver(uint64 id) { 312 PostTask(FROM_HERE, 313 base::Bind( 314 &ServiceDiscoveryMessageHandler::DestroyLocalDomainResolver, 315 base::Unretained(this), id)); 316 } 317 318 void ServiceDiscoveryMessageHandler::StartWatcher( 319 uint64 id, 320 const std::string& service_type) { 321 VLOG(1) << "StartWatcher, id=" << id << ", type=" << service_type; 322 if (!service_discovery_client_) 323 return; 324 DCHECK(!ContainsKey(service_watchers_, id)); 325 scoped_ptr<ServiceWatcher> watcher( 326 service_discovery_client_->CreateServiceWatcher( 327 service_type, 328 base::Bind(&ServiceDiscoveryMessageHandler::OnServiceUpdated, 329 base::Unretained(this), id))); 330 watcher->Start(); 331 service_watchers_[id].reset(watcher.release()); 332 } 333 334 void ServiceDiscoveryMessageHandler::DiscoverServices(uint64 id, 335 bool force_update) { 336 VLOG(1) << "DiscoverServices, id=" << id; 337 if (!service_discovery_client_) 338 return; 339 DCHECK(ContainsKey(service_watchers_, id)); 340 service_watchers_[id]->DiscoverNewServices(force_update); 341 } 342 343 void ServiceDiscoveryMessageHandler::SetActivelyRefreshServices( 344 uint64 id, 345 bool actively_refresh_services) { 346 VLOG(1) << "ActivelyRefreshServices, id=" << id; 347 if (!service_discovery_client_) 348 return; 349 DCHECK(ContainsKey(service_watchers_, id)); 350 service_watchers_[id]->SetActivelyRefreshServices(actively_refresh_services); 351 } 352 353 void ServiceDiscoveryMessageHandler::DestroyWatcher(uint64 id) { 354 VLOG(1) << "DestoryWatcher, id=" << id; 355 if (!service_discovery_client_) 356 return; 357 service_watchers_.erase(id); 358 } 359 360 void ServiceDiscoveryMessageHandler::ResolveService( 361 uint64 id, 362 const std::string& service_name) { 363 VLOG(1) << "ResolveService, id=" << id << ", name=" << service_name; 364 if (!service_discovery_client_) 365 return; 366 DCHECK(!ContainsKey(service_resolvers_, id)); 367 scoped_ptr<ServiceResolver> resolver( 368 service_discovery_client_->CreateServiceResolver( 369 service_name, 370 base::Bind(&ServiceDiscoveryMessageHandler::OnServiceResolved, 371 base::Unretained(this), id))); 372 resolver->StartResolving(); 373 service_resolvers_[id].reset(resolver.release()); 374 } 375 376 void ServiceDiscoveryMessageHandler::DestroyResolver(uint64 id) { 377 VLOG(1) << "DestroyResolver, id=" << id; 378 if (!service_discovery_client_) 379 return; 380 service_resolvers_.erase(id); 381 } 382 383 void ServiceDiscoveryMessageHandler::ResolveLocalDomain( 384 uint64 id, 385 const std::string& domain, 386 net::AddressFamily address_family) { 387 VLOG(1) << "ResolveLocalDomain, id=" << id << ", domain=" << domain; 388 if (!service_discovery_client_) 389 return; 390 DCHECK(!ContainsKey(local_domain_resolvers_, id)); 391 scoped_ptr<LocalDomainResolver> resolver( 392 service_discovery_client_->CreateLocalDomainResolver( 393 domain, address_family, 394 base::Bind(&ServiceDiscoveryMessageHandler::OnLocalDomainResolved, 395 base::Unretained(this), id))); 396 resolver->Start(); 397 local_domain_resolvers_[id].reset(resolver.release()); 398 } 399 400 void ServiceDiscoveryMessageHandler::DestroyLocalDomainResolver(uint64 id) { 401 VLOG(1) << "DestroyLocalDomainResolver, id=" << id; 402 if (!service_discovery_client_) 403 return; 404 local_domain_resolvers_.erase(id); 405 } 406 407 void ServiceDiscoveryMessageHandler::ShutdownLocalDiscovery() { 408 if (!discovery_task_runner_.get()) 409 return; 410 411 discovery_task_runner_->PostTask( 412 FROM_HERE, 413 base::Bind(&ServiceDiscoveryMessageHandler::ShutdownOnIOThread, 414 base::Unretained(this))); 415 416 // This will wait for message loop to drain, so ShutdownOnIOThread will 417 // definitely be called. 418 discovery_thread_.reset(); 419 } 420 421 void ServiceDiscoveryMessageHandler::ShutdownOnIOThread() { 422 VLOG(1) << "ShutdownLocalDiscovery"; 423 service_watchers_.clear(); 424 service_resolvers_.clear(); 425 local_domain_resolvers_.clear(); 426 service_discovery_client_.reset(); 427 mdns_client_.reset(); 428 } 429 430 void ServiceDiscoveryMessageHandler::OnServiceUpdated( 431 uint64 id, 432 ServiceWatcher::UpdateType update, 433 const std::string& name) { 434 VLOG(1) << "OnServiceUpdated, id=" << id 435 << ", status=" << WatcherUpdateToString(update) << ", name=" << name; 436 DCHECK(service_discovery_client_); 437 438 Send(new LocalDiscoveryHostMsg_WatcherCallback(id, update, name)); 439 } 440 441 void ServiceDiscoveryMessageHandler::OnServiceResolved( 442 uint64 id, 443 ServiceResolver::RequestStatus status, 444 const ServiceDescription& description) { 445 VLOG(1) << "OnServiceResolved, id=" << id 446 << ", status=" << ResolverStatusToString(status) 447 << ", name=" << description.service_name; 448 449 DCHECK(service_discovery_client_); 450 Send(new LocalDiscoveryHostMsg_ResolverCallback(id, status, description)); 451 } 452 453 void ServiceDiscoveryMessageHandler::OnLocalDomainResolved( 454 uint64 id, 455 bool success, 456 const net::IPAddressNumber& address_ipv4, 457 const net::IPAddressNumber& address_ipv6) { 458 VLOG(1) << "OnLocalDomainResolved, id=" << id 459 << ", IPv4=" << (address_ipv4.empty() ? "" : 460 net::IPAddressToString(address_ipv4)) 461 << ", IPv6=" << (address_ipv6.empty() ? "" : 462 net::IPAddressToString(address_ipv6)); 463 464 DCHECK(service_discovery_client_); 465 Send(new LocalDiscoveryHostMsg_LocalDomainResolverCallback( 466 id, success, address_ipv4, address_ipv6)); 467 } 468 469 void ServiceDiscoveryMessageHandler::Send(IPC::Message* msg) { 470 utility_task_runner_->PostTask(FROM_HERE, 471 base::Bind(&SendHostMessageOnUtilityThread, 472 msg)); 473 } 474 475 } // namespace local_discovery 476