Home | History | Annotate | Download | only in prototype
      1 // Copyright 2013 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "cloud_print/gcp20/prototype/dns_sd_server.h"
      6 
      7 #include <string.h>
      8 
      9 #include "base/basictypes.h"
     10 #include "base/bind.h"
     11 #include "base/command_line.h"
     12 #include "base/message_loop/message_loop.h"
     13 #include "base/strings/stringprintf.h"
     14 #include "cloud_print/gcp20/prototype/dns_packet_parser.h"
     15 #include "cloud_print/gcp20/prototype/dns_response_builder.h"
     16 #include "net/base/big_endian.h"
     17 #include "net/base/dns_util.h"
     18 #include "net/base/net_errors.h"
     19 #include "net/base/net_util.h"
     20 #include "net/dns/dns_protocol.h"
     21 
     22 namespace {
     23 
     24 const char kDefaultIpAddressMulticast[] = "224.0.0.251";
     25 const uint16 kDefaultPortMulticast = 5353;
     26 
     27 const double kTimeToNextAnnouncement = 0.8;  // relatively to TTL
     28 const int kDnsBufSize = 65537;
     29 
     30 const uint16 kSrvPriority = 0;
     31 const uint16 kSrvWeight = 0;
     32 
     33 void DoNothingAfterSendToSocket(int /*val*/) {
     34   NOTREACHED();
     35   // TODO(maksymb): Delete this function once empty callback for SendTo() method
     36   // will be allowed.
     37 }
     38 
     39 }  // namespace
     40 
     41 DnsSdServer::DnsSdServer()
     42     : recv_buf_(new net::IOBufferWithSize(kDnsBufSize)),
     43       full_ttl_(0) {
     44 }
     45 
     46 DnsSdServer::~DnsSdServer() {
     47   Shutdown();
     48 }
     49 
     50 bool DnsSdServer::Start(const ServiceParameters& serv_params, uint32 full_ttl,
     51                         const std::vector<std::string>& metadata) {
     52   if (IsOnline())
     53     return true;
     54 
     55   if (!CreateSocket())
     56     return false;
     57 
     58   // Initializing server with parameters from arguments.
     59   serv_params_ = serv_params;
     60   full_ttl_ = full_ttl;
     61   metadata_ = metadata;
     62 
     63   LOG(INFO) << "DNS server started";
     64   LOG(WARNING) << "DNS server does not support probing";
     65 
     66   SendAnnouncement(full_ttl_);
     67   base::MessageLoop::current()->PostTask(
     68       FROM_HERE,
     69       base::Bind(&DnsSdServer::OnDatagramReceived, AsWeakPtr()));
     70 
     71   return true;
     72 }
     73 
     74 void DnsSdServer::Update() {
     75   if (!IsOnline())
     76     return;
     77 
     78   SendAnnouncement(full_ttl_);
     79 }
     80 
     81 void DnsSdServer::Shutdown() {
     82   if (!IsOnline())
     83     return;
     84 
     85   SendAnnouncement(0);  // TTL is 0
     86   socket_->Close();
     87   socket_.reset(NULL);
     88   LOG(INFO) << "DNS server stopped";
     89 }
     90 
     91 void DnsSdServer::UpdateMetadata(const std::vector<std::string>& metadata) {
     92   if (!IsOnline())
     93     return;
     94 
     95   metadata_ = metadata;
     96 
     97   // TODO(maksymb): If less than 20% of full TTL left before next announcement
     98   // then send it now.
     99 
    100   uint32 current_ttl = GetCurrentTLL();
    101   if (!CommandLine::ForCurrentProcess()->HasSwitch("no-announcement")) {
    102     DnsResponseBuilder builder(current_ttl);
    103 
    104     builder.AppendTxt(serv_params_.service_name_, current_ttl, metadata_);
    105     scoped_refptr<net::IOBufferWithSize> buffer(builder.Build());
    106 
    107     DCHECK(buffer.get() != NULL);
    108 
    109     socket_->SendTo(buffer.get(), buffer.get()->size(), multicast_address_,
    110                     base::Bind(&DoNothingAfterSendToSocket));
    111   }
    112 }
    113 
    114 bool DnsSdServer::CreateSocket() {
    115   net::IPAddressNumber local_ip_any;
    116   bool success = net::ParseIPLiteralToNumber("0.0.0.0", &local_ip_any);
    117   DCHECK(success);
    118 
    119   net::IPAddressNumber multicast_dns_ip_address;
    120   success = net::ParseIPLiteralToNumber(kDefaultIpAddressMulticast,
    121                                         &multicast_dns_ip_address);
    122   DCHECK(success);
    123 
    124 
    125   socket_.reset(new net::UDPSocket(net::DatagramSocket::DEFAULT_BIND,
    126                                    net::RandIntCallback(), NULL,
    127                                    net::NetLog::Source()));
    128 
    129   net::IPEndPoint local_address = net::IPEndPoint(local_ip_any,
    130                                                   kDefaultPortMulticast);
    131   multicast_address_ = net::IPEndPoint(multicast_dns_ip_address,
    132                                        kDefaultPortMulticast);
    133 
    134   socket_->AllowAddressReuse();
    135 
    136   int status = socket_->Bind(local_address);
    137   if (status < 0)
    138     return false;
    139 
    140   socket_->SetMulticastLoopbackMode(false);
    141   status = socket_->JoinGroup(multicast_dns_ip_address);
    142 
    143   if (status < 0)
    144     return false;
    145 
    146   DCHECK(socket_->is_connected());
    147 
    148   return true;
    149 }
    150 
    151 void DnsSdServer::ProcessMessage(int len, net::IOBufferWithSize* buf) {
    152   VLOG(1) << "Received new message with length: " << len;
    153 
    154   // Parse the message.
    155   DnsPacketParser parser(buf->data(), len);
    156 
    157   if (!parser.IsValid())  // Was unable to parse header.
    158     return;
    159 
    160   // TODO(maksymb): Handle truncated messages.
    161   if (parser.header().flags & net::dns_protocol::kFlagResponse)  // Not a query.
    162     return;
    163 
    164   DnsResponseBuilder builder(parser.header().id);
    165 
    166   uint32 current_ttl = GetCurrentTLL();
    167 
    168   DnsQueryRecord query;
    169   // TODO(maksymb): Check known answers.
    170   for (int query_idx = 0; query_idx < parser.header().qdcount; ++query_idx) {
    171     bool success = parser.ReadRecord(&query);
    172     if (success) {
    173       ProccessQuery(current_ttl, query, &builder);
    174     } else {  // if (success)
    175       LOG(INFO) << "Broken package";
    176       break;
    177     }
    178   }
    179 
    180   scoped_refptr<net::IOBufferWithSize> buffer(builder.Build());
    181   if (buffer.get() == NULL)
    182     return;  // No answers.
    183 
    184   VLOG(1) << "Current TTL for respond: " << current_ttl;
    185 
    186   bool unicast_respond =
    187       CommandLine::ForCurrentProcess()->HasSwitch("unicast-respond");
    188   socket_->SendTo(buffer.get(), buffer.get()->size(),
    189                   unicast_respond ? recv_address_ : multicast_address_,
    190                   base::Bind(&DoNothingAfterSendToSocket));
    191   VLOG(1) << "Responded to "
    192       << (unicast_respond ? recv_address_ : multicast_address_).ToString();
    193 }
    194 
    195 void DnsSdServer::ProccessQuery(uint32 current_ttl, const DnsQueryRecord& query,
    196                                 DnsResponseBuilder* builder) const {
    197   std::string log;
    198   bool responded = false;
    199   switch (query.qtype) {
    200     // TODO(maksymb): Add IPv6 support.
    201     case net::dns_protocol::kTypePTR:
    202       log = "Processing PTR query";
    203       if (query.qname == serv_params_.service_type_) {
    204         builder->AppendPtr(serv_params_.service_type_, current_ttl,
    205                            serv_params_.service_name_);
    206         responded = true;
    207       }
    208       break;
    209     case net::dns_protocol::kTypeSRV:
    210       log = "Processing SRV query";
    211       if (query.qname == serv_params_.service_name_) {
    212         builder->AppendSrv(serv_params_.service_name_, current_ttl,
    213                            kSrvPriority, kSrvWeight, serv_params_.http_port_,
    214                            serv_params_.service_domain_name_);
    215         responded = true;
    216       }
    217       break;
    218     case net::dns_protocol::kTypeA:
    219       log = "Processing A query";
    220       if (query.qname == serv_params_.service_domain_name_) {
    221         builder->AppendA(serv_params_.service_domain_name_, current_ttl,
    222                          serv_params_.http_ipv4_);
    223         responded = true;
    224       }
    225       break;
    226     case net::dns_protocol::kTypeTXT:
    227       log = "Processing TXT query";
    228       if (query.qname == serv_params_.service_name_) {
    229         builder->AppendTxt(serv_params_.service_name_, current_ttl, metadata_);
    230         responded = true;
    231       }
    232       break;
    233     default:
    234       base::SStringPrintf(&log, "Unknown query type (%d)", query.qtype);
    235   }
    236   log += responded ? ": responded" : ": ignored";
    237   VLOG(1) << log;
    238 }
    239 
    240 void DnsSdServer::DoLoop(int rv) {
    241   // TODO(maksymb): Check what happened if buffer will be overflowed
    242   do {
    243     if (rv > 0)
    244       ProcessMessage(rv, recv_buf_.get());
    245     rv = socket_->RecvFrom(recv_buf_.get(), recv_buf_->size(), &recv_address_,
    246                            base::Bind(&DnsSdServer::DoLoop, AsWeakPtr()));
    247   } while (rv > 0);
    248 
    249   // TODO(maksymb): Add handler for errors
    250   DCHECK(rv == net::ERR_IO_PENDING);
    251 }
    252 
    253 void DnsSdServer::OnDatagramReceived() {
    254   DoLoop(0);
    255 }
    256 
    257 void DnsSdServer::SendAnnouncement(uint32 ttl) {
    258   if (!CommandLine::ForCurrentProcess()->HasSwitch("no-announcement")) {
    259     DnsResponseBuilder builder(ttl);
    260 
    261     builder.AppendPtr(serv_params_.service_type_, ttl,
    262                       serv_params_.service_name_);
    263     builder.AppendSrv(serv_params_.service_name_, ttl, kSrvPriority, kSrvWeight,
    264                       serv_params_.http_port_,
    265                       serv_params_.service_domain_name_);
    266     builder.AppendA(serv_params_.service_domain_name_, ttl,
    267                     serv_params_.http_ipv4_);
    268     builder.AppendTxt(serv_params_.service_name_, ttl, metadata_);
    269     scoped_refptr<net::IOBufferWithSize> buffer(builder.Build());
    270 
    271     DCHECK(buffer.get() != NULL);
    272 
    273     socket_->SendTo(buffer.get(), buffer.get()->size(), multicast_address_,
    274                     base::Bind(&DoNothingAfterSendToSocket));
    275 
    276     VLOG(1) << "Announcement was sent with TTL: " << ttl;
    277   }
    278 
    279   time_until_live_ = base::Time::Now() +
    280       base::TimeDelta::FromSeconds(full_ttl_);
    281 
    282   // Schedule next announcement.
    283   base::MessageLoop::current()->PostDelayedTask(
    284       FROM_HERE,
    285       base::Bind(&DnsSdServer::Update, AsWeakPtr()),
    286       base::TimeDelta::FromSeconds(static_cast<int64>(
    287           kTimeToNextAnnouncement*full_ttl_)));
    288 }
    289 
    290 uint32 DnsSdServer::GetCurrentTLL() const {
    291   uint32 current_ttl = (time_until_live_ - base::Time::Now()).InSeconds();
    292   if (time_until_live_ < base::Time::Now() || current_ttl == 0) {
    293     // This should not be reachable. But still we don't need to fail.
    294     current_ttl = 1;  // Service is still alive.
    295     LOG(ERROR) << "|current_ttl| was equal to zero.";
    296   }
    297   return current_ttl;
    298 }
    299 
    300