Home | History | Annotate | Download | only in shill
      1 //
      2 // Copyright (C) 2012 The Android Open Source Project
      3 //
      4 // Licensed under the Apache License, Version 2.0 (the "License");
      5 // you may not use this file except in compliance with the License.
      6 // You may obtain a copy of the License at
      7 //
      8 //      http://www.apache.org/licenses/LICENSE-2.0
      9 //
     10 // Unless required by applicable law or agreed to in writing, software
     11 // distributed under the License is distributed on an "AS IS" BASIS,
     12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 // See the License for the specific language governing permissions and
     14 // limitations under the License.
     15 //
     16 
     17 #include "shill/dns_client.h"
     18 
     19 #include <arpa/inet.h>
     20 #include <netdb.h>
     21 #include <netinet/in.h>
     22 #include <sys/socket.h>
     23 
     24 #include <map>
     25 #include <memory>
     26 #include <set>
     27 #include <string>
     28 #include <vector>
     29 
     30 #include <base/bind.h>
     31 #include <base/bind_helpers.h>
     32 #include <base/stl_util.h>
     33 #include <base/strings/string_number_conversions.h>
     34 
     35 #include "shill/logging.h"
     36 #include "shill/net/shill_time.h"
     37 #include "shill/shill_ares.h"
     38 
     39 using base::Bind;
     40 using base::Unretained;
     41 using std::map;
     42 using std::set;
     43 using std::string;
     44 using std::vector;
     45 
     46 namespace shill {
     47 
     48 namespace Logging {
     49 static auto kModuleLogScope = ScopeLogger::kDNS;
     50 static string ObjectID(DNSClient* d) { return d->interface_name(); }
     51 }
     52 
     53 const char DNSClient::kErrorNoData[] = "The query response contains no answers";
     54 const char DNSClient::kErrorFormErr[] = "The server says the query is bad";
     55 const char DNSClient::kErrorServerFail[] = "The server says it had a failure";
     56 const char DNSClient::kErrorNotFound[] = "The queried-for domain was not found";
     57 const char DNSClient::kErrorNotImp[] = "The server doesn't implement operation";
     58 const char DNSClient::kErrorRefused[] = "The server replied, refused the query";
     59 const char DNSClient::kErrorBadQuery[] = "Locally we could not format a query";
     60 const char DNSClient::kErrorNetRefused[] = "The network connection was refused";
     61 const char DNSClient::kErrorTimedOut[] = "The network connection was timed out";
     62 const char DNSClient::kErrorUnknown[] = "DNS Resolver unknown internal error";
     63 
     64 const int DNSClient::kDefaultDNSPort = 53;
     65 
     66 // Private to the implementation of resolver so callers don't include ares.h
     67 struct DNSClientState {
     68   DNSClientState() : channel(nullptr), start_time{} {}
     69 
     70   ares_channel channel;
     71   map<ares_socket_t, std::shared_ptr<IOHandler>> read_handlers;
     72   map<ares_socket_t, std::shared_ptr<IOHandler>> write_handlers;
     73   struct timeval start_time;
     74 };
     75 
     76 DNSClient::DNSClient(IPAddress::Family family,
     77                      const string& interface_name,
     78                      const vector<string>& dns_servers,
     79                      int timeout_ms,
     80                      EventDispatcher* dispatcher,
     81                      const ClientCallback& callback)
     82     : address_(IPAddress(family)),
     83       interface_name_(interface_name),
     84       dns_servers_(dns_servers),
     85       dispatcher_(dispatcher),
     86       callback_(callback),
     87       timeout_ms_(timeout_ms),
     88       running_(false),
     89       weak_ptr_factory_(this),
     90       ares_(Ares::GetInstance()),
     91       time_(Time::GetInstance()) {}
     92 
     93 DNSClient::~DNSClient() {
     94   Stop();
     95 }
     96 
     97 bool DNSClient::Start(const string& hostname, Error* error) {
     98   if (running_) {
     99     Error::PopulateAndLog(FROM_HERE, error, Error::kInProgress,
    100                           "Only one DNS request is allowed at a time");
    101     return false;
    102   }
    103 
    104   if (!resolver_state_.get()) {
    105     struct ares_options options;
    106     memset(&options, 0, sizeof(options));
    107     options.timeout = timeout_ms_;
    108 
    109     if (dns_servers_.empty()) {
    110       Error::PopulateAndLog(FROM_HERE, error, Error::kInvalidArguments,
    111                             "No valid DNS server addresses");
    112       return false;
    113     }
    114 
    115     resolver_state_.reset(new DNSClientState);
    116     int status = ares_->InitOptions(&resolver_state_->channel,
    117                                    &options,
    118                                    ARES_OPT_TIMEOUTMS);
    119     if (status != ARES_SUCCESS) {
    120       Error::PopulateAndLog(FROM_HERE, error, Error::kOperationFailed,
    121                             "ARES initialization returns error code: " +
    122                             base::IntToString(status));
    123       resolver_state_.reset();
    124       return false;
    125     }
    126 
    127     // Format DNS server addresses string as "host:port[,host:port...]" to be
    128     // used in call to ares_set_servers_csv for setting DNS server addresses.
    129     // There is a bug in ares library when parsing IPv6 addresses, where it
    130     // always assumes the port number are specified when address contains ":".
    131     // So when IPv6 address are given without port number as "xx:xx:xx::yy",the
    132     // parser would parse the address as "xx:xx:xx:" and port number as "yy".
    133     // To work around this bug, port number are added to each address.
    134     //
    135     // Alternatively, we can use ares_set_servers instead, where we would
    136     // explicitly construct a link list of ares_addr_node.
    137     string server_addresses;
    138     bool first = true;
    139     for (const auto& ip : dns_servers_) {
    140       if (!first) {
    141         server_addresses += ",";
    142       } else {
    143         first = false;
    144       }
    145       server_addresses += (ip + ":" + base::IntToString(kDefaultDNSPort));
    146     }
    147     status = ares_->SetServersCsv(resolver_state_->channel,
    148                                   server_addresses.c_str());
    149     if (status != ARES_SUCCESS) {
    150       Error::PopulateAndLog(FROM_HERE, error, Error::kOperationFailed,
    151                             "ARES set DNS servers error code: " +
    152                             base::IntToString(status));
    153       resolver_state_.reset();
    154       return false;
    155     }
    156 
    157     ares_->SetLocalDev(resolver_state_->channel, interface_name_.c_str());
    158   }
    159 
    160   running_ = true;
    161   time_->GetTimeMonotonic(&resolver_state_->start_time);
    162   ares_->GetHostByName(resolver_state_->channel, hostname.c_str(),
    163                        address_.family(), ReceiveDNSReplyCB, this);
    164 
    165   if (!RefreshHandles()) {
    166     LOG(ERROR) << "Impossibly short timeout.";
    167     error->CopyFrom(error_);
    168     Stop();
    169     return false;
    170   }
    171 
    172   return true;
    173 }
    174 
    175 void DNSClient::Stop() {
    176   SLOG(this, 3) << "In " << __func__;
    177   if (!resolver_state_.get()) {
    178     return;
    179   }
    180 
    181   running_ = false;
    182   weak_ptr_factory_.InvalidateWeakPtrs();
    183   error_.Reset();
    184   address_.SetAddressToDefault();
    185   ares_->Destroy(resolver_state_->channel);
    186   resolver_state_.reset();
    187 }
    188 
    189 bool DNSClient::IsActive() const {
    190   return running_;
    191 }
    192 
    193 // We delay our call to completion so that we exit all IOHandlers, and
    194 // can clean up all of our local state before calling the callback, or
    195 // during the process of the execution of the callee (which is free to
    196 // call our destructor safely).
    197 void DNSClient::HandleCompletion() {
    198   SLOG(this, 3) << "In " << __func__;
    199   Error error;
    200   error.CopyFrom(error_);
    201   IPAddress address(address_);
    202   if (!error.IsSuccess()) {
    203     // If the DNS request did not succeed, do not trust it for future
    204     // attempts.
    205     Stop();
    206   } else {
    207     // Prepare our state for the next request without destroying the
    208     // current ARES state.
    209     error_.Reset();
    210     address_.SetAddressToDefault();
    211   }
    212   callback_.Run(error, address);
    213 }
    214 
    215 void DNSClient::HandleDNSRead(int fd) {
    216   ares_->ProcessFd(resolver_state_->channel, fd, ARES_SOCKET_BAD);
    217   RefreshHandles();
    218 }
    219 
    220 void DNSClient::HandleDNSWrite(int fd) {
    221   ares_->ProcessFd(resolver_state_->channel, ARES_SOCKET_BAD, fd);
    222   RefreshHandles();
    223 }
    224 
    225 void DNSClient::HandleTimeout() {
    226   ares_->ProcessFd(resolver_state_->channel, ARES_SOCKET_BAD, ARES_SOCKET_BAD);
    227   RefreshHandles();
    228 }
    229 
    230 void DNSClient::ReceiveDNSReply(int status, struct hostent* hostent) {
    231   if (!running_) {
    232     // We can be called during ARES shutdown -- ignore these events.
    233     return;
    234   }
    235   SLOG(this, 3) << "In " << __func__;
    236   running_ = false;
    237   timeout_closure_.Cancel();
    238   dispatcher_->PostTask(Bind(&DNSClient::HandleCompletion,
    239                              weak_ptr_factory_.GetWeakPtr()));
    240 
    241   if (status == ARES_SUCCESS &&
    242       hostent != nullptr &&
    243       hostent->h_addrtype == address_.family() &&
    244       static_cast<size_t>(hostent->h_length) ==
    245       IPAddress::GetAddressLength(address_.family()) &&
    246       hostent->h_addr_list != nullptr &&
    247       hostent->h_addr_list[0] != nullptr) {
    248     address_ = IPAddress(address_.family(),
    249                          ByteString(reinterpret_cast<unsigned char*>(
    250                              hostent->h_addr_list[0]), hostent->h_length));
    251   } else {
    252     switch (status) {
    253       case ARES_ENODATA:
    254         error_.Populate(Error::kOperationFailed, kErrorNoData);
    255         break;
    256       case ARES_EFORMERR:
    257         error_.Populate(Error::kOperationFailed, kErrorFormErr);
    258         break;
    259       case ARES_ESERVFAIL:
    260         error_.Populate(Error::kOperationFailed, kErrorServerFail);
    261         break;
    262       case ARES_ENOTFOUND:
    263         error_.Populate(Error::kOperationFailed, kErrorNotFound);
    264         break;
    265       case ARES_ENOTIMP:
    266         error_.Populate(Error::kOperationFailed, kErrorNotImp);
    267         break;
    268       case ARES_EREFUSED:
    269         error_.Populate(Error::kOperationFailed, kErrorRefused);
    270         break;
    271       case ARES_EBADQUERY:
    272       case ARES_EBADNAME:
    273       case ARES_EBADFAMILY:
    274       case ARES_EBADRESP:
    275         error_.Populate(Error::kOperationFailed, kErrorBadQuery);
    276         break;
    277       case ARES_ECONNREFUSED:
    278         error_.Populate(Error::kOperationFailed, kErrorNetRefused);
    279         break;
    280       case ARES_ETIMEOUT:
    281         error_.Populate(Error::kOperationTimeout, kErrorTimedOut);
    282         break;
    283       default:
    284         error_.Populate(Error::kOperationFailed, kErrorUnknown);
    285         if (status == ARES_SUCCESS) {
    286           LOG(ERROR) << "ARES returned success but hostent was invalid!";
    287         } else {
    288           LOG(ERROR) << "ARES returned unhandled error status " << status;
    289         }
    290         break;
    291     }
    292   }
    293 }
    294 
    295 void DNSClient::ReceiveDNSReplyCB(void* arg, int status,
    296                                   int /*timeouts*/,
    297                                   struct hostent* hostent) {
    298   DNSClient* res = static_cast<DNSClient*>(arg);
    299   res->ReceiveDNSReply(status, hostent);
    300 }
    301 
    302 bool DNSClient::RefreshHandles() {
    303   map<ares_socket_t, std::shared_ptr<IOHandler>> old_read =
    304       resolver_state_->read_handlers;
    305   map<ares_socket_t, std::shared_ptr<IOHandler>> old_write =
    306       resolver_state_->write_handlers;
    307 
    308   resolver_state_->read_handlers.clear();
    309   resolver_state_->write_handlers.clear();
    310 
    311   ares_socket_t sockets[ARES_GETSOCK_MAXNUM];
    312   int action_bits = ares_->GetSock(resolver_state_->channel, sockets,
    313                                    ARES_GETSOCK_MAXNUM);
    314 
    315   base::Callback<void(int)> read_callback(
    316       Bind(&DNSClient::HandleDNSRead, weak_ptr_factory_.GetWeakPtr()));
    317   base::Callback<void(int)> write_callback(
    318       Bind(&DNSClient::HandleDNSWrite, weak_ptr_factory_.GetWeakPtr()));
    319   for (int i = 0; i < ARES_GETSOCK_MAXNUM; i++) {
    320     if (ARES_GETSOCK_READABLE(action_bits, i)) {
    321       if (ContainsKey(old_read, sockets[i])) {
    322         resolver_state_->read_handlers[sockets[i]] = old_read[sockets[i]];
    323       } else {
    324         resolver_state_->read_handlers[sockets[i]] =
    325             std::shared_ptr<IOHandler> (
    326                 dispatcher_->CreateReadyHandler(sockets[i],
    327                                                 IOHandler::kModeInput,
    328                                                 read_callback));
    329       }
    330     }
    331     if (ARES_GETSOCK_WRITABLE(action_bits, i)) {
    332       if (ContainsKey(old_write, sockets[i])) {
    333         resolver_state_->write_handlers[sockets[i]] = old_write[sockets[i]];
    334       } else {
    335         resolver_state_->write_handlers[sockets[i]] =
    336             std::shared_ptr<IOHandler> (
    337                 dispatcher_->CreateReadyHandler(sockets[i],
    338                                                 IOHandler::kModeOutput,
    339                                                 write_callback));
    340       }
    341     }
    342   }
    343 
    344   if (!running_) {
    345     // We are here just to clean up socket handles, and the ARES state was
    346     // cleaned up during the last call to ares_->ProcessFd().
    347     return false;
    348   }
    349 
    350   // Schedule timer event for the earlier of our timeout or one requested by
    351   // the resolver library.
    352   struct timeval now, elapsed_time, timeout_tv;
    353   time_->GetTimeMonotonic(&now);
    354   timersub(&now, &resolver_state_->start_time, &elapsed_time);
    355   timeout_tv.tv_sec = timeout_ms_ / 1000;
    356   timeout_tv.tv_usec = (timeout_ms_ % 1000) * 1000;
    357   timeout_closure_.Cancel();
    358 
    359   if (timercmp(&elapsed_time, &timeout_tv, >=)) {
    360     // There are 3 cases of interest:
    361     //  - If we got here from Start(), when we return, Stop() will be
    362     //    called, so our cleanup task will not run, so we will not have the
    363     //    side-effect of both invoking the callback and returning False
    364     //    in Start().
    365     //  - If we got here from the tail of an IO event, we can't call
    366     //    Stop() since that will blow away the IOHandler we are running
    367     //    in.  We will perform the cleanup in the posted task below.
    368     //  - If we got here from a timeout handler, we will perform cleanup
    369     //    in the posted task.
    370     running_ = false;
    371     error_.Populate(Error::kOperationTimeout, kErrorTimedOut);
    372     dispatcher_->PostTask(Bind(&DNSClient::HandleCompletion,
    373                                weak_ptr_factory_.GetWeakPtr()));
    374     return false;
    375   } else {
    376     struct timeval max, ret_tv;
    377     timersub(&timeout_tv, &elapsed_time, &max);
    378     struct timeval* tv = ares_->Timeout(resolver_state_->channel,
    379                                         &max, &ret_tv);
    380     timeout_closure_.Reset(
    381         Bind(&DNSClient::HandleTimeout, weak_ptr_factory_.GetWeakPtr()));
    382     dispatcher_->PostDelayedTask(timeout_closure_.callback(),
    383                                  tv->tv_sec * 1000 + tv->tv_usec / 1000);
    384   }
    385 
    386   return true;
    387 }
    388 
    389 }  // namespace shill
    390