Home | History | Annotate | Download | only in prototype
      1 // Copyright 2013 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "cloud_print/gcp20/prototype/dns_sd_server.h"
      6 
      7 #include <string.h>
      8 
      9 #include "base/basictypes.h"
     10 #include "base/bind.h"
     11 #include "base/command_line.h"
     12 #include "base/message_loop/message_loop.h"
     13 #include "base/strings/stringprintf.h"
     14 #include "cloud_print/gcp20/prototype/dns_packet_parser.h"
     15 #include "cloud_print/gcp20/prototype/dns_response_builder.h"
     16 #include "cloud_print/gcp20/prototype/gcp20_switches.h"
     17 #include "net/base/dns_util.h"
     18 #include "net/base/net_errors.h"
     19 #include "net/base/net_util.h"
     20 #include "net/dns/dns_protocol.h"
     21 
     22 namespace {
     23 
     24 const char kDefaultIpAddressMulticast[] = "224.0.0.251";
     25 const uint16 kDefaultPortMulticast = 5353;
     26 
     27 const double kTimeToNextAnnouncement = 0.8;  // relatively to TTL
     28 const int kDnsBufSize = 65537;
     29 
     30 const uint16 kSrvPriority = 0;
     31 const uint16 kSrvWeight = 0;
     32 
     33 void DoNothingAfterSendToSocket(int /*val*/) {
     34   NOTREACHED();
     35   // TODO(maksymb): Delete this function once empty callback for SendTo() method
     36   // will be allowed.
     37 }
     38 
     39 }  // namespace
     40 
     41 DnsSdServer::DnsSdServer()
     42     : recv_buf_(new net::IOBufferWithSize(kDnsBufSize)),
     43       full_ttl_(0) {
     44 }
     45 
     46 DnsSdServer::~DnsSdServer() {
     47   Shutdown();
     48 }
     49 
     50 bool DnsSdServer::Start(const ServiceParameters& serv_params, uint32 full_ttl,
     51                         const std::vector<std::string>& metadata) {
     52   if (IsOnline())
     53     return true;
     54 
     55   if (!CreateSocket())
     56     return false;
     57 
     58   // Initializing server with parameters from arguments.
     59   serv_params_ = serv_params;
     60   full_ttl_ = full_ttl;
     61   metadata_ = metadata;
     62 
     63   VLOG(0) << "DNS server started";
     64   LOG(WARNING) << "DNS server does not support probing";
     65 
     66   SendAnnouncement(full_ttl_);
     67   base::MessageLoop::current()->PostTask(
     68       FROM_HERE,
     69       base::Bind(&DnsSdServer::OnDatagramReceived, AsWeakPtr()));
     70 
     71   return true;
     72 }
     73 
     74 void DnsSdServer::Update() {
     75   if (!IsOnline())
     76     return;
     77 
     78   SendAnnouncement(full_ttl_);
     79 }
     80 
     81 void DnsSdServer::Shutdown() {
     82   if (!IsOnline())
     83     return;
     84 
     85   SendAnnouncement(0);  // TTL is 0
     86   socket_->Close();
     87   socket_.reset(NULL);
     88   VLOG(0) << "DNS server stopped";
     89 }
     90 
     91 void DnsSdServer::UpdateMetadata(const std::vector<std::string>& metadata) {
     92   if (!IsOnline())
     93     return;
     94 
     95   metadata_ = metadata;
     96 
     97   // TODO(maksymb): If less than 20% of full TTL left before next announcement
     98   // then send it now.
     99 
    100   uint32 current_ttl = GetCurrentTLL();
    101   if (!CommandLine::ForCurrentProcess()->HasSwitch(switches::kNoAnnouncement)) {
    102     DnsResponseBuilder builder(current_ttl);
    103 
    104     builder.AppendTxt(serv_params_.service_name_, current_ttl, metadata_, true);
    105     scoped_refptr<net::IOBufferWithSize> buffer(builder.Build());
    106 
    107     DCHECK(buffer.get() != NULL);
    108 
    109     socket_->SendTo(buffer.get(), buffer.get()->size(), multicast_address_,
    110                     base::Bind(&DoNothingAfterSendToSocket));
    111   }
    112 }
    113 
    114 bool DnsSdServer::CreateSocket() {
    115   net::IPAddressNumber local_ip_any;
    116   bool success = net::ParseIPLiteralToNumber("0.0.0.0", &local_ip_any);
    117   DCHECK(success);
    118 
    119   net::IPAddressNumber multicast_dns_ip_address;
    120   success = net::ParseIPLiteralToNumber(kDefaultIpAddressMulticast,
    121                                         &multicast_dns_ip_address);
    122   DCHECK(success);
    123 
    124 
    125   socket_.reset(new net::UDPSocket(net::DatagramSocket::DEFAULT_BIND,
    126                                    net::RandIntCallback(), NULL,
    127                                    net::NetLog::Source()));
    128 
    129   net::IPEndPoint local_address = net::IPEndPoint(local_ip_any,
    130                                                   kDefaultPortMulticast);
    131   multicast_address_ = net::IPEndPoint(multicast_dns_ip_address,
    132                                        kDefaultPortMulticast);
    133 
    134   socket_->AllowAddressReuse();
    135 
    136   int status = socket_->Bind(local_address);
    137   if (status < 0)
    138     return false;
    139 
    140   socket_->SetMulticastLoopbackMode(false);
    141   status = socket_->JoinGroup(multicast_dns_ip_address);
    142 
    143   if (status < 0)
    144     return false;
    145 
    146   DCHECK(socket_->is_connected());
    147 
    148   return true;
    149 }
    150 
    151 void DnsSdServer::ProcessMessage(int len, net::IOBufferWithSize* buf) {
    152   VLOG(1) << "Received new message with length: " << len;
    153 
    154   // Parse the message.
    155   DnsPacketParser parser(buf->data(), len);
    156 
    157   if (!parser.IsValid())  // Was unable to parse header.
    158     return;
    159 
    160   // TODO(maksymb): Handle truncated messages.
    161   if (parser.header().flags & net::dns_protocol::kFlagResponse)  // Not a query.
    162     return;
    163 
    164   DnsResponseBuilder builder(parser.header().id);
    165 
    166   uint32 current_ttl = GetCurrentTLL();
    167 
    168   DnsQueryRecord query;
    169   // TODO(maksymb): Check known answers.
    170   for (int query_idx = 0; query_idx < parser.header().qdcount; ++query_idx) {
    171     bool success = parser.ReadRecord(&query);
    172     if (success) {
    173       ProccessQuery(current_ttl, query, &builder);
    174     } else {  // if (success)
    175       VLOG(0) << "Broken package";
    176       break;
    177     }
    178   }
    179 
    180   scoped_refptr<net::IOBufferWithSize> buffer(builder.Build());
    181   if (buffer.get() == NULL)
    182     return;  // No answers.
    183 
    184   VLOG(1) << "Current TTL for respond: " << current_ttl;
    185 
    186   bool unicast_respond =
    187       CommandLine::ForCurrentProcess()->HasSwitch(switches::kUnicastRespond);
    188   socket_->SendTo(buffer.get(), buffer.get()->size(),
    189                   unicast_respond ? recv_address_ : multicast_address_,
    190                   base::Bind(&DoNothingAfterSendToSocket));
    191   VLOG(1) << "Responded to "
    192       << (unicast_respond ? recv_address_ : multicast_address_).ToString();
    193 }
    194 
    195 void DnsSdServer::ProccessQuery(uint32 current_ttl, const DnsQueryRecord& query,
    196                                 DnsResponseBuilder* builder) const {
    197   std::string log;
    198   bool responded = false;
    199   switch (query.qtype) {
    200     // TODO(maksymb): Add IPv6 support.
    201     case net::dns_protocol::kTypePTR:
    202       log = "Processing PTR query";
    203       if (query.qname == serv_params_.service_type_ ||
    204           query.qname == serv_params_.secondary_service_type_) {
    205         builder->AppendPtr(query.qname, current_ttl,
    206                            serv_params_.service_name_, true);
    207 
    208         if (CommandLine::ForCurrentProcess()->HasSwitch(
    209                 switches::kExtendedResponce)) {
    210           builder->AppendSrv(serv_params_.service_name_, current_ttl,
    211                              kSrvPriority, kSrvWeight, serv_params_.http_port_,
    212                              serv_params_.service_domain_name_, false);
    213           builder->AppendA(serv_params_.service_domain_name_, current_ttl,
    214                            serv_params_.http_ipv4_, false);
    215           builder->AppendAAAA(serv_params_.service_domain_name_, current_ttl,
    216                               serv_params_.http_ipv6_, false);
    217           builder->AppendTxt(serv_params_.service_name_, current_ttl, metadata_,
    218                              false);
    219         }
    220 
    221         responded = true;
    222       }
    223 
    224       break;
    225     case net::dns_protocol::kTypeSRV:
    226       log = "Processing SRV query";
    227       if (query.qname == serv_params_.service_name_) {
    228         builder->AppendSrv(serv_params_.service_name_, current_ttl,
    229                            kSrvPriority, kSrvWeight, serv_params_.http_port_,
    230                            serv_params_.service_domain_name_, true);
    231         responded = true;
    232       }
    233       break;
    234     case net::dns_protocol::kTypeA:
    235       log = "Processing A query";
    236       if (query.qname == serv_params_.service_domain_name_) {
    237         builder->AppendA(serv_params_.service_domain_name_, current_ttl,
    238                          serv_params_.http_ipv4_, true);
    239         responded = true;
    240       }
    241       break;
    242     case net::dns_protocol::kTypeAAAA:
    243       log = "Processing AAAA query";
    244       if (query.qname == serv_params_.service_domain_name_) {
    245         builder->AppendAAAA(serv_params_.service_domain_name_, current_ttl,
    246                             serv_params_.http_ipv6_, true);
    247         responded = true;
    248       }
    249       break;
    250     case net::dns_protocol::kTypeTXT:
    251       log = "Processing TXT query";
    252       if (query.qname == serv_params_.service_name_) {
    253         builder->AppendTxt(serv_params_.service_name_, current_ttl, metadata_,
    254                            true);
    255         responded = true;
    256       }
    257       break;
    258     default:
    259       base::SStringPrintf(&log, "Unknown query type (%d)", query.qtype);
    260   }
    261   log += responded ? ": responded" : ": ignored";
    262   VLOG(1) << log;
    263 }
    264 
    265 void DnsSdServer::DoLoop(int rv) {
    266   // TODO(maksymb): Check what happened if buffer will be overflowed
    267   do {
    268     if (rv > 0)
    269       ProcessMessage(rv, recv_buf_.get());
    270     rv = socket_->RecvFrom(recv_buf_.get(), recv_buf_->size(), &recv_address_,
    271                            base::Bind(&DnsSdServer::DoLoop, AsWeakPtr()));
    272   } while (rv > 0);
    273 
    274   // TODO(maksymb): Add handler for errors
    275   DCHECK(rv == net::ERR_IO_PENDING);
    276 }
    277 
    278 void DnsSdServer::OnDatagramReceived() {
    279   DoLoop(0);
    280 }
    281 
    282 void DnsSdServer::SendAnnouncement(uint32 ttl) {
    283   if (!CommandLine::ForCurrentProcess()->HasSwitch(switches::kNoAnnouncement)) {
    284     DnsResponseBuilder builder(ttl);
    285 
    286     builder.AppendPtr(serv_params_.service_type_, ttl,
    287                      serv_params_.service_name_, true);
    288     builder.AppendPtr(serv_params_.secondary_service_type_, ttl,
    289                       serv_params_.service_name_, true);
    290     builder.AppendSrv(serv_params_.service_name_, ttl, kSrvPriority,
    291                       kSrvWeight, serv_params_.http_port_,
    292                       serv_params_.service_domain_name_, true);
    293     builder.AppendA(serv_params_.service_domain_name_, ttl,
    294                     serv_params_.http_ipv4_, true);
    295     builder.AppendAAAA(serv_params_.service_domain_name_, ttl,
    296                        serv_params_.http_ipv6_, true);
    297     builder.AppendTxt(serv_params_.service_name_, ttl, metadata_, true);
    298 
    299     scoped_refptr<net::IOBufferWithSize> buffer(builder.Build());
    300 
    301     DCHECK(buffer.get() != NULL);
    302 
    303     socket_->SendTo(buffer.get(), buffer.get()->size(), multicast_address_,
    304                     base::Bind(&DoNothingAfterSendToSocket));
    305 
    306     VLOG(1) << "Announcement was sent with TTL: " << ttl;
    307   }
    308 
    309   time_until_live_ = base::Time::Now() +
    310       base::TimeDelta::FromSeconds(full_ttl_);
    311 
    312   // Schedule next announcement.
    313   base::MessageLoop::current()->PostDelayedTask(
    314       FROM_HERE,
    315       base::Bind(&DnsSdServer::Update, AsWeakPtr()),
    316       base::TimeDelta::FromSeconds(static_cast<int64>(
    317           kTimeToNextAnnouncement*full_ttl_)));
    318 }
    319 
    320 uint32 DnsSdServer::GetCurrentTLL() const {
    321   uint32 current_ttl = (time_until_live_ - base::Time::Now()).InSeconds();
    322   if (time_until_live_ < base::Time::Now() || current_ttl == 0) {
    323     // This should not be reachable. But still we don't need to fail.
    324     current_ttl = 1;  // Service is still alive.
    325     LOG(ERROR) << "|current_ttl| was equal to zero.";
    326   }
    327   return current_ttl;
    328 }
    329