Home | History | Annotate | Download | only in local_discovery
      1 // Copyright 2013 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "chrome/utility/local_discovery/service_discovery_message_handler.h"
      6 
      7 #include <algorithm>
      8 
      9 #include "base/lazy_instance.h"
     10 #include "chrome/common/local_discovery/local_discovery_messages.h"
     11 #include "chrome/common/local_discovery/service_discovery_client_impl.h"
     12 #include "content/public/utility/utility_thread.h"
     13 #include "net/socket/socket_descriptor.h"
     14 #include "net/udp/datagram_server_socket.h"
     15 
     16 namespace local_discovery {
     17 
     18 namespace {
     19 
     20 void ClosePlatformSocket(net::SocketDescriptor socket);
     21 
     22 // Sets socket factory used by |net::CreatePlatformSocket|. Implemetation
     23 // keeps single socket that will be returned to the first call to
     24 // |net::CreatePlatformSocket| during object lifetime.
     25 class ScopedSocketFactory : public net::PlatformSocketFactory {
     26  public:
     27   explicit ScopedSocketFactory(net::SocketDescriptor socket) : socket_(socket) {
     28     net::PlatformSocketFactory::SetInstance(this);
     29   }
     30 
     31   virtual ~ScopedSocketFactory() {
     32     net::PlatformSocketFactory::SetInstance(NULL);
     33     ClosePlatformSocket(socket_);
     34     socket_ = net::kInvalidSocket;
     35   }
     36 
     37   virtual net::SocketDescriptor CreateSocket(int family, int type,
     38                                              int protocol) OVERRIDE {
     39     DCHECK_EQ(type, SOCK_DGRAM);
     40     DCHECK(family == AF_INET || family == AF_INET6);
     41     net::SocketDescriptor result = net::kInvalidSocket;
     42     std::swap(result, socket_);
     43     return result;
     44   }
     45 
     46  private:
     47   net::SocketDescriptor socket_;
     48   DISALLOW_COPY_AND_ASSIGN(ScopedSocketFactory);
     49 };
     50 
     51 struct SocketInfo {
     52   SocketInfo(net::SocketDescriptor socket,
     53              net::AddressFamily address_family,
     54              uint32 interface_index)
     55       : socket(socket),
     56         address_family(address_family),
     57         interface_index(interface_index) {
     58   }
     59   net::SocketDescriptor socket;
     60   net::AddressFamily address_family;
     61   uint32 interface_index;
     62 };
     63 
     64 // Returns list of sockets preallocated before.
     65 class PreCreatedMDnsSocketFactory : public net::MDnsSocketFactory {
     66  public:
     67   PreCreatedMDnsSocketFactory() {}
     68   virtual ~PreCreatedMDnsSocketFactory() {
     69     // Not empty if process exits too fast, before starting mDns code. If
     70     // happened, destructors may crash accessing destroyed global objects.
     71     sockets_.weak_clear();
     72   }
     73 
     74   // net::MDnsSocketFactory implementation:
     75   virtual void CreateSockets(
     76       ScopedVector<net::DatagramServerSocket>* sockets) OVERRIDE {
     77     sockets->swap(sockets_);
     78     Reset();
     79   }
     80 
     81   void AddSocket(const SocketInfo& socket_info) {
     82     // Takes ownership of socket_info.socket;
     83     ScopedSocketFactory platform_factory(socket_info.socket);
     84     scoped_ptr<net::DatagramServerSocket> socket(
     85         net::CreateAndBindMDnsSocket(socket_info.address_family,
     86                                      socket_info.interface_index));
     87     if (socket) {
     88       socket->DetachFromThread();
     89       sockets_.push_back(socket.release());
     90     }
     91   }
     92 
     93   void Reset() {
     94     sockets_.clear();
     95   }
     96 
     97  private:
     98   ScopedVector<net::DatagramServerSocket> sockets_;
     99 
    100   DISALLOW_COPY_AND_ASSIGN(PreCreatedMDnsSocketFactory);
    101 };
    102 
    103 base::LazyInstance<PreCreatedMDnsSocketFactory>
    104     g_local_discovery_socket_factory = LAZY_INSTANCE_INITIALIZER;
    105 
    106 #if defined(OS_WIN)
    107 
    108 void ClosePlatformSocket(net::SocketDescriptor socket) {
    109   ::closesocket(socket);
    110 }
    111 
    112 void StaticInitializeSocketFactory() {
    113   net::InterfaceIndexFamilyList interfaces(net::GetMDnsInterfacesToBind());
    114   for (size_t i = 0; i < interfaces.size(); ++i) {
    115     DCHECK(interfaces[i].second == net::ADDRESS_FAMILY_IPV4 ||
    116            interfaces[i].second == net::ADDRESS_FAMILY_IPV6);
    117     net::SocketDescriptor descriptor =
    118         net::CreatePlatformSocket(
    119             net::ConvertAddressFamily(interfaces[i].second), SOCK_DGRAM,
    120                                       IPPROTO_UDP);
    121     g_local_discovery_socket_factory.Get().AddSocket(
    122         SocketInfo(descriptor, interfaces[i].second, interfaces[i].first));
    123   }
    124 }
    125 
    126 #else  // OS_WIN
    127 
    128 void ClosePlatformSocket(net::SocketDescriptor socket) {
    129   ::close(socket);
    130 }
    131 
    132 void StaticInitializeSocketFactory() {
    133 }
    134 
    135 #endif  // OS_WIN
    136 
    137 void SendHostMessageOnUtilityThread(IPC::Message* msg) {
    138   content::UtilityThread::Get()->Send(msg);
    139 }
    140 
    141 std::string WatcherUpdateToString(ServiceWatcher::UpdateType update) {
    142   switch (update) {
    143     case ServiceWatcher::UPDATE_ADDED:
    144       return "UPDATE_ADDED";
    145     case ServiceWatcher::UPDATE_CHANGED:
    146       return "UPDATE_CHANGED";
    147     case ServiceWatcher::UPDATE_REMOVED:
    148       return "UPDATE_REMOVED";
    149     case ServiceWatcher::UPDATE_INVALIDATED:
    150       return "UPDATE_INVALIDATED";
    151   }
    152   return "Unknown Update";
    153 }
    154 
    155 std::string ResolverStatusToString(ServiceResolver::RequestStatus status) {
    156   switch (status) {
    157     case ServiceResolver::STATUS_SUCCESS:
    158       return "STATUS_SUCESS";
    159     case ServiceResolver::STATUS_REQUEST_TIMEOUT:
    160       return "STATUS_REQUEST_TIMEOUT";
    161     case ServiceResolver::STATUS_KNOWN_NONEXISTENT:
    162       return "STATUS_KNOWN_NONEXISTENT";
    163   }
    164   return "Unknown Status";
    165 }
    166 
    167 }  // namespace
    168 
    169 ServiceDiscoveryMessageHandler::ServiceDiscoveryMessageHandler() {
    170 }
    171 
    172 ServiceDiscoveryMessageHandler::~ServiceDiscoveryMessageHandler() {
    173   DCHECK(!discovery_thread_);
    174 }
    175 
    176 void ServiceDiscoveryMessageHandler::PreSandboxStartup() {
    177   StaticInitializeSocketFactory();
    178 }
    179 
    180 void ServiceDiscoveryMessageHandler::InitializeMdns() {
    181   if (service_discovery_client_ || mdns_client_)
    182     return;
    183 
    184   mdns_client_ = net::MDnsClient::CreateDefault();
    185   bool result =
    186       mdns_client_->StartListening(g_local_discovery_socket_factory.Pointer());
    187   // Close unused sockets.
    188   g_local_discovery_socket_factory.Get().Reset();
    189   if (!result) {
    190     VLOG(1) << "Failed to start MDnsClient";
    191     Send(new LocalDiscoveryHostMsg_Error());
    192     return;
    193   }
    194 
    195   service_discovery_client_.reset(
    196       new local_discovery::ServiceDiscoveryClientImpl(mdns_client_.get()));
    197 }
    198 
    199 bool ServiceDiscoveryMessageHandler::InitializeThread() {
    200   if (discovery_task_runner_.get())
    201     return true;
    202   if (discovery_thread_)
    203     return false;
    204   utility_task_runner_ = base::MessageLoop::current()->message_loop_proxy();
    205   discovery_thread_.reset(new base::Thread("ServiceDiscoveryThread"));
    206   base::Thread::Options thread_options(base::MessageLoop::TYPE_IO, 0);
    207   if (discovery_thread_->StartWithOptions(thread_options)) {
    208     discovery_task_runner_ = discovery_thread_->message_loop_proxy();
    209     discovery_task_runner_->PostTask(FROM_HERE,
    210         base::Bind(&ServiceDiscoveryMessageHandler::InitializeMdns,
    211                     base::Unretained(this)));
    212   }
    213   return discovery_task_runner_.get() != NULL;
    214 }
    215 
    216 bool ServiceDiscoveryMessageHandler::OnMessageReceived(
    217     const IPC::Message& message) {
    218   bool handled = true;
    219   IPC_BEGIN_MESSAGE_MAP(ServiceDiscoveryMessageHandler, message)
    220 #if defined(OS_POSIX)
    221     IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_SetSockets, OnSetSockets)
    222 #endif  // OS_POSIX
    223     IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_StartWatcher, OnStartWatcher)
    224     IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DiscoverServices, OnDiscoverServices)
    225     IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_SetActivelyRefreshServices,
    226                         OnSetActivelyRefreshServices)
    227     IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyWatcher, OnDestroyWatcher)
    228     IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ResolveService, OnResolveService)
    229     IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyResolver, OnDestroyResolver)
    230     IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ResolveLocalDomain,
    231                         OnResolveLocalDomain)
    232     IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyLocalDomainResolver,
    233                         OnDestroyLocalDomainResolver)
    234     IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ShutdownLocalDiscovery,
    235                         ShutdownLocalDiscovery)
    236     IPC_MESSAGE_UNHANDLED(handled = false)
    237   IPC_END_MESSAGE_MAP()
    238   return handled;
    239 }
    240 
    241 void ServiceDiscoveryMessageHandler::PostTask(
    242     const tracked_objects::Location& from_here,
    243     const base::Closure& task) {
    244   if (!InitializeThread())
    245     return;
    246   discovery_task_runner_->PostTask(from_here, task);
    247 }
    248 
    249 #if defined(OS_POSIX)
    250 void ServiceDiscoveryMessageHandler::OnSetSockets(
    251     const std::vector<LocalDiscoveryMsg_SocketInfo>& sockets) {
    252   for (size_t i = 0; i < sockets.size(); ++i) {
    253     g_local_discovery_socket_factory.Get().AddSocket(
    254         SocketInfo(sockets[i].descriptor.fd, sockets[i].address_family,
    255                    sockets[i].interface_index));
    256   }
    257 }
    258 #endif  // OS_POSIX
    259 
    260 void ServiceDiscoveryMessageHandler::OnStartWatcher(
    261     uint64 id,
    262     const std::string& service_type) {
    263   PostTask(FROM_HERE,
    264            base::Bind(&ServiceDiscoveryMessageHandler::StartWatcher,
    265                       base::Unretained(this), id, service_type));
    266 }
    267 
    268 void ServiceDiscoveryMessageHandler::OnDiscoverServices(uint64 id,
    269                                                         bool force_update) {
    270   PostTask(FROM_HERE,
    271            base::Bind(&ServiceDiscoveryMessageHandler::DiscoverServices,
    272                       base::Unretained(this), id, force_update));
    273 }
    274 
    275 void ServiceDiscoveryMessageHandler::OnSetActivelyRefreshServices(
    276     uint64 id, bool actively_refresh_services) {
    277   PostTask(FROM_HERE,
    278            base::Bind(
    279                &ServiceDiscoveryMessageHandler::SetActivelyRefreshServices,
    280                base::Unretained(this), id, actively_refresh_services));
    281 }
    282 
    283 void ServiceDiscoveryMessageHandler::OnDestroyWatcher(uint64 id) {
    284   PostTask(FROM_HERE,
    285            base::Bind(&ServiceDiscoveryMessageHandler::DestroyWatcher,
    286                       base::Unretained(this), id));
    287 }
    288 
    289 void ServiceDiscoveryMessageHandler::OnResolveService(
    290     uint64 id,
    291     const std::string& service_name) {
    292   PostTask(FROM_HERE,
    293            base::Bind(&ServiceDiscoveryMessageHandler::ResolveService,
    294                       base::Unretained(this), id, service_name));
    295 }
    296 
    297 void ServiceDiscoveryMessageHandler::OnDestroyResolver(uint64 id) {
    298   PostTask(FROM_HERE,
    299            base::Bind(&ServiceDiscoveryMessageHandler::DestroyResolver,
    300                       base::Unretained(this), id));
    301 }
    302 
    303 void ServiceDiscoveryMessageHandler::OnResolveLocalDomain(
    304     uint64 id, const std::string& domain,
    305     net::AddressFamily address_family) {
    306     PostTask(FROM_HERE,
    307            base::Bind(&ServiceDiscoveryMessageHandler::ResolveLocalDomain,
    308                       base::Unretained(this), id, domain, address_family));
    309 }
    310 
    311 void ServiceDiscoveryMessageHandler::OnDestroyLocalDomainResolver(uint64 id) {
    312   PostTask(FROM_HERE,
    313            base::Bind(
    314                &ServiceDiscoveryMessageHandler::DestroyLocalDomainResolver,
    315                base::Unretained(this), id));
    316 }
    317 
    318 void ServiceDiscoveryMessageHandler::StartWatcher(
    319     uint64 id,
    320     const std::string& service_type) {
    321   VLOG(1) << "StartWatcher, id=" << id << ", type=" << service_type;
    322   if (!service_discovery_client_)
    323     return;
    324   DCHECK(!ContainsKey(service_watchers_, id));
    325   scoped_ptr<ServiceWatcher> watcher(
    326       service_discovery_client_->CreateServiceWatcher(
    327           service_type,
    328           base::Bind(&ServiceDiscoveryMessageHandler::OnServiceUpdated,
    329                      base::Unretained(this), id)));
    330   watcher->Start();
    331   service_watchers_[id].reset(watcher.release());
    332 }
    333 
    334 void ServiceDiscoveryMessageHandler::DiscoverServices(uint64 id,
    335                                                       bool force_update) {
    336   VLOG(1) << "DiscoverServices, id=" << id;
    337   if (!service_discovery_client_)
    338     return;
    339   DCHECK(ContainsKey(service_watchers_, id));
    340   service_watchers_[id]->DiscoverNewServices(force_update);
    341 }
    342 
    343 void ServiceDiscoveryMessageHandler::SetActivelyRefreshServices(
    344     uint64 id,
    345     bool actively_refresh_services) {
    346   VLOG(1) << "ActivelyRefreshServices, id=" << id;
    347   if (!service_discovery_client_)
    348     return;
    349   DCHECK(ContainsKey(service_watchers_, id));
    350   service_watchers_[id]->SetActivelyRefreshServices(actively_refresh_services);
    351 }
    352 
    353 void ServiceDiscoveryMessageHandler::DestroyWatcher(uint64 id) {
    354   VLOG(1) << "DestoryWatcher, id=" << id;
    355   if (!service_discovery_client_)
    356     return;
    357   service_watchers_.erase(id);
    358 }
    359 
    360 void ServiceDiscoveryMessageHandler::ResolveService(
    361     uint64 id,
    362     const std::string& service_name) {
    363   VLOG(1) << "ResolveService, id=" << id << ", name=" << service_name;
    364   if (!service_discovery_client_)
    365     return;
    366   DCHECK(!ContainsKey(service_resolvers_, id));
    367   scoped_ptr<ServiceResolver> resolver(
    368       service_discovery_client_->CreateServiceResolver(
    369           service_name,
    370           base::Bind(&ServiceDiscoveryMessageHandler::OnServiceResolved,
    371                      base::Unretained(this), id)));
    372   resolver->StartResolving();
    373   service_resolvers_[id].reset(resolver.release());
    374 }
    375 
    376 void ServiceDiscoveryMessageHandler::DestroyResolver(uint64 id) {
    377   VLOG(1) << "DestroyResolver, id=" << id;
    378   if (!service_discovery_client_)
    379     return;
    380   service_resolvers_.erase(id);
    381 }
    382 
    383 void ServiceDiscoveryMessageHandler::ResolveLocalDomain(
    384     uint64 id,
    385     const std::string& domain,
    386     net::AddressFamily address_family) {
    387   VLOG(1) << "ResolveLocalDomain, id=" << id << ", domain=" << domain;
    388   if (!service_discovery_client_)
    389     return;
    390   DCHECK(!ContainsKey(local_domain_resolvers_, id));
    391   scoped_ptr<LocalDomainResolver> resolver(
    392       service_discovery_client_->CreateLocalDomainResolver(
    393           domain, address_family,
    394           base::Bind(&ServiceDiscoveryMessageHandler::OnLocalDomainResolved,
    395                      base::Unretained(this), id)));
    396   resolver->Start();
    397   local_domain_resolvers_[id].reset(resolver.release());
    398 }
    399 
    400 void ServiceDiscoveryMessageHandler::DestroyLocalDomainResolver(uint64 id) {
    401   VLOG(1) << "DestroyLocalDomainResolver, id=" << id;
    402   if (!service_discovery_client_)
    403     return;
    404   local_domain_resolvers_.erase(id);
    405 }
    406 
    407 void ServiceDiscoveryMessageHandler::ShutdownLocalDiscovery() {
    408   if (!discovery_task_runner_.get())
    409     return;
    410 
    411   discovery_task_runner_->PostTask(
    412       FROM_HERE,
    413       base::Bind(&ServiceDiscoveryMessageHandler::ShutdownOnIOThread,
    414                  base::Unretained(this)));
    415 
    416   // This will wait for message loop to drain, so ShutdownOnIOThread will
    417   // definitely be called.
    418   discovery_thread_.reset();
    419 }
    420 
    421 void ServiceDiscoveryMessageHandler::ShutdownOnIOThread() {
    422   VLOG(1) << "ShutdownLocalDiscovery";
    423   service_watchers_.clear();
    424   service_resolvers_.clear();
    425   local_domain_resolvers_.clear();
    426   service_discovery_client_.reset();
    427   mdns_client_.reset();
    428 }
    429 
    430 void ServiceDiscoveryMessageHandler::OnServiceUpdated(
    431     uint64 id,
    432     ServiceWatcher::UpdateType update,
    433     const std::string& name) {
    434   VLOG(1) << "OnServiceUpdated, id=" << id
    435           << ", status=" << WatcherUpdateToString(update) << ", name=" << name;
    436   DCHECK(service_discovery_client_);
    437 
    438   Send(new LocalDiscoveryHostMsg_WatcherCallback(id, update, name));
    439 }
    440 
    441 void ServiceDiscoveryMessageHandler::OnServiceResolved(
    442     uint64 id,
    443     ServiceResolver::RequestStatus status,
    444     const ServiceDescription& description) {
    445   VLOG(1) << "OnServiceResolved, id=" << id
    446           << ", status=" << ResolverStatusToString(status)
    447           << ", name=" << description.service_name;
    448 
    449   DCHECK(service_discovery_client_);
    450   Send(new LocalDiscoveryHostMsg_ResolverCallback(id, status, description));
    451 }
    452 
    453 void ServiceDiscoveryMessageHandler::OnLocalDomainResolved(
    454     uint64 id,
    455     bool success,
    456     const net::IPAddressNumber& address_ipv4,
    457     const net::IPAddressNumber& address_ipv6) {
    458   VLOG(1) << "OnLocalDomainResolved, id=" << id
    459           << ", IPv4=" << (address_ipv4.empty() ? "" :
    460                            net::IPAddressToString(address_ipv4))
    461           << ", IPv6=" << (address_ipv6.empty() ? "" :
    462                            net::IPAddressToString(address_ipv6));
    463 
    464   DCHECK(service_discovery_client_);
    465   Send(new LocalDiscoveryHostMsg_LocalDomainResolverCallback(
    466           id, success, address_ipv4, address_ipv6));
    467 }
    468 
    469 void ServiceDiscoveryMessageHandler::Send(IPC::Message* msg) {
    470   utility_task_runner_->PostTask(FROM_HERE,
    471                                  base::Bind(&SendHostMessageOnUtilityThread,
    472                                             msg));
    473 }
    474 
    475 }  // namespace local_discovery
    476