1 // 2 // Copyright (C) 2012 The Android Open Source Project 3 // 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at 7 // 8 // http://www.apache.org/licenses/LICENSE-2.0 9 // 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 // 16 17 #include "shill/dns_client.h" 18 19 #include <arpa/inet.h> 20 #include <netdb.h> 21 #include <netinet/in.h> 22 #include <sys/socket.h> 23 24 #include <map> 25 #include <memory> 26 #include <set> 27 #include <string> 28 #include <vector> 29 30 #include <base/bind.h> 31 #include <base/bind_helpers.h> 32 #include <base/stl_util.h> 33 #include <base/strings/string_number_conversions.h> 34 35 #include "shill/logging.h" 36 #include "shill/net/shill_time.h" 37 #include "shill/shill_ares.h" 38 39 using base::Bind; 40 using base::Unretained; 41 using std::map; 42 using std::set; 43 using std::string; 44 using std::vector; 45 46 namespace shill { 47 48 namespace Logging { 49 static auto kModuleLogScope = ScopeLogger::kDNS; 50 static string ObjectID(DNSClient* d) { return d->interface_name(); } 51 } 52 53 const char DNSClient::kErrorNoData[] = "The query response contains no answers"; 54 const char DNSClient::kErrorFormErr[] = "The server says the query is bad"; 55 const char DNSClient::kErrorServerFail[] = "The server says it had a failure"; 56 const char DNSClient::kErrorNotFound[] = "The queried-for domain was not found"; 57 const char DNSClient::kErrorNotImp[] = "The server doesn't implement operation"; 58 const char DNSClient::kErrorRefused[] = "The server replied, refused the query"; 59 const char DNSClient::kErrorBadQuery[] = "Locally we could not format a query"; 60 const char DNSClient::kErrorNetRefused[] = "The network connection was refused"; 61 const char DNSClient::kErrorTimedOut[] = "The network connection was timed out"; 62 const char DNSClient::kErrorUnknown[] = "DNS Resolver unknown internal error"; 63 64 const int DNSClient::kDefaultDNSPort = 53; 65 66 // Private to the implementation of resolver so callers don't include ares.h 67 struct DNSClientState { 68 DNSClientState() : channel(nullptr), start_time{} {} 69 70 ares_channel channel; 71 map<ares_socket_t, std::shared_ptr<IOHandler>> read_handlers; 72 map<ares_socket_t, std::shared_ptr<IOHandler>> write_handlers; 73 struct timeval start_time; 74 }; 75 76 DNSClient::DNSClient(IPAddress::Family family, 77 const string& interface_name, 78 const vector<string>& dns_servers, 79 int timeout_ms, 80 EventDispatcher* dispatcher, 81 const ClientCallback& callback) 82 : address_(IPAddress(family)), 83 interface_name_(interface_name), 84 dns_servers_(dns_servers), 85 dispatcher_(dispatcher), 86 callback_(callback), 87 timeout_ms_(timeout_ms), 88 running_(false), 89 weak_ptr_factory_(this), 90 ares_(Ares::GetInstance()), 91 time_(Time::GetInstance()) {} 92 93 DNSClient::~DNSClient() { 94 Stop(); 95 } 96 97 bool DNSClient::Start(const string& hostname, Error* error) { 98 if (running_) { 99 Error::PopulateAndLog(FROM_HERE, error, Error::kInProgress, 100 "Only one DNS request is allowed at a time"); 101 return false; 102 } 103 104 if (!resolver_state_.get()) { 105 struct ares_options options; 106 memset(&options, 0, sizeof(options)); 107 options.timeout = timeout_ms_; 108 109 if (dns_servers_.empty()) { 110 Error::PopulateAndLog(FROM_HERE, error, Error::kInvalidArguments, 111 "No valid DNS server addresses"); 112 return false; 113 } 114 115 resolver_state_.reset(new DNSClientState); 116 int status = ares_->InitOptions(&resolver_state_->channel, 117 &options, 118 ARES_OPT_TIMEOUTMS); 119 if (status != ARES_SUCCESS) { 120 Error::PopulateAndLog(FROM_HERE, error, Error::kOperationFailed, 121 "ARES initialization returns error code: " + 122 base::IntToString(status)); 123 resolver_state_.reset(); 124 return false; 125 } 126 127 // Format DNS server addresses string as "host:port[,host:port...]" to be 128 // used in call to ares_set_servers_csv for setting DNS server addresses. 129 // There is a bug in ares library when parsing IPv6 addresses, where it 130 // always assumes the port number are specified when address contains ":". 131 // So when IPv6 address are given without port number as "xx:xx:xx::yy",the 132 // parser would parse the address as "xx:xx:xx:" and port number as "yy". 133 // To work around this bug, port number are added to each address. 134 // 135 // Alternatively, we can use ares_set_servers instead, where we would 136 // explicitly construct a link list of ares_addr_node. 137 string server_addresses; 138 bool first = true; 139 for (const auto& ip : dns_servers_) { 140 if (!first) { 141 server_addresses += ","; 142 } else { 143 first = false; 144 } 145 server_addresses += (ip + ":" + base::IntToString(kDefaultDNSPort)); 146 } 147 status = ares_->SetServersCsv(resolver_state_->channel, 148 server_addresses.c_str()); 149 if (status != ARES_SUCCESS) { 150 Error::PopulateAndLog(FROM_HERE, error, Error::kOperationFailed, 151 "ARES set DNS servers error code: " + 152 base::IntToString(status)); 153 resolver_state_.reset(); 154 return false; 155 } 156 157 ares_->SetLocalDev(resolver_state_->channel, interface_name_.c_str()); 158 } 159 160 running_ = true; 161 time_->GetTimeMonotonic(&resolver_state_->start_time); 162 ares_->GetHostByName(resolver_state_->channel, hostname.c_str(), 163 address_.family(), ReceiveDNSReplyCB, this); 164 165 if (!RefreshHandles()) { 166 LOG(ERROR) << "Impossibly short timeout."; 167 error->CopyFrom(error_); 168 Stop(); 169 return false; 170 } 171 172 return true; 173 } 174 175 void DNSClient::Stop() { 176 SLOG(this, 3) << "In " << __func__; 177 if (!resolver_state_.get()) { 178 return; 179 } 180 181 running_ = false; 182 weak_ptr_factory_.InvalidateWeakPtrs(); 183 error_.Reset(); 184 address_.SetAddressToDefault(); 185 ares_->Destroy(resolver_state_->channel); 186 resolver_state_.reset(); 187 } 188 189 bool DNSClient::IsActive() const { 190 return running_; 191 } 192 193 // We delay our call to completion so that we exit all IOHandlers, and 194 // can clean up all of our local state before calling the callback, or 195 // during the process of the execution of the callee (which is free to 196 // call our destructor safely). 197 void DNSClient::HandleCompletion() { 198 SLOG(this, 3) << "In " << __func__; 199 Error error; 200 error.CopyFrom(error_); 201 IPAddress address(address_); 202 if (!error.IsSuccess()) { 203 // If the DNS request did not succeed, do not trust it for future 204 // attempts. 205 Stop(); 206 } else { 207 // Prepare our state for the next request without destroying the 208 // current ARES state. 209 error_.Reset(); 210 address_.SetAddressToDefault(); 211 } 212 callback_.Run(error, address); 213 } 214 215 void DNSClient::HandleDNSRead(int fd) { 216 ares_->ProcessFd(resolver_state_->channel, fd, ARES_SOCKET_BAD); 217 RefreshHandles(); 218 } 219 220 void DNSClient::HandleDNSWrite(int fd) { 221 ares_->ProcessFd(resolver_state_->channel, ARES_SOCKET_BAD, fd); 222 RefreshHandles(); 223 } 224 225 void DNSClient::HandleTimeout() { 226 ares_->ProcessFd(resolver_state_->channel, ARES_SOCKET_BAD, ARES_SOCKET_BAD); 227 RefreshHandles(); 228 } 229 230 void DNSClient::ReceiveDNSReply(int status, struct hostent* hostent) { 231 if (!running_) { 232 // We can be called during ARES shutdown -- ignore these events. 233 return; 234 } 235 SLOG(this, 3) << "In " << __func__; 236 running_ = false; 237 timeout_closure_.Cancel(); 238 dispatcher_->PostTask(Bind(&DNSClient::HandleCompletion, 239 weak_ptr_factory_.GetWeakPtr())); 240 241 if (status == ARES_SUCCESS && 242 hostent != nullptr && 243 hostent->h_addrtype == address_.family() && 244 static_cast<size_t>(hostent->h_length) == 245 IPAddress::GetAddressLength(address_.family()) && 246 hostent->h_addr_list != nullptr && 247 hostent->h_addr_list[0] != nullptr) { 248 address_ = IPAddress(address_.family(), 249 ByteString(reinterpret_cast<unsigned char*>( 250 hostent->h_addr_list[0]), hostent->h_length)); 251 } else { 252 switch (status) { 253 case ARES_ENODATA: 254 error_.Populate(Error::kOperationFailed, kErrorNoData); 255 break; 256 case ARES_EFORMERR: 257 error_.Populate(Error::kOperationFailed, kErrorFormErr); 258 break; 259 case ARES_ESERVFAIL: 260 error_.Populate(Error::kOperationFailed, kErrorServerFail); 261 break; 262 case ARES_ENOTFOUND: 263 error_.Populate(Error::kOperationFailed, kErrorNotFound); 264 break; 265 case ARES_ENOTIMP: 266 error_.Populate(Error::kOperationFailed, kErrorNotImp); 267 break; 268 case ARES_EREFUSED: 269 error_.Populate(Error::kOperationFailed, kErrorRefused); 270 break; 271 case ARES_EBADQUERY: 272 case ARES_EBADNAME: 273 case ARES_EBADFAMILY: 274 case ARES_EBADRESP: 275 error_.Populate(Error::kOperationFailed, kErrorBadQuery); 276 break; 277 case ARES_ECONNREFUSED: 278 error_.Populate(Error::kOperationFailed, kErrorNetRefused); 279 break; 280 case ARES_ETIMEOUT: 281 error_.Populate(Error::kOperationTimeout, kErrorTimedOut); 282 break; 283 default: 284 error_.Populate(Error::kOperationFailed, kErrorUnknown); 285 if (status == ARES_SUCCESS) { 286 LOG(ERROR) << "ARES returned success but hostent was invalid!"; 287 } else { 288 LOG(ERROR) << "ARES returned unhandled error status " << status; 289 } 290 break; 291 } 292 } 293 } 294 295 void DNSClient::ReceiveDNSReplyCB(void* arg, int status, 296 int /*timeouts*/, 297 struct hostent* hostent) { 298 DNSClient* res = static_cast<DNSClient*>(arg); 299 res->ReceiveDNSReply(status, hostent); 300 } 301 302 bool DNSClient::RefreshHandles() { 303 map<ares_socket_t, std::shared_ptr<IOHandler>> old_read = 304 resolver_state_->read_handlers; 305 map<ares_socket_t, std::shared_ptr<IOHandler>> old_write = 306 resolver_state_->write_handlers; 307 308 resolver_state_->read_handlers.clear(); 309 resolver_state_->write_handlers.clear(); 310 311 ares_socket_t sockets[ARES_GETSOCK_MAXNUM]; 312 int action_bits = ares_->GetSock(resolver_state_->channel, sockets, 313 ARES_GETSOCK_MAXNUM); 314 315 base::Callback<void(int)> read_callback( 316 Bind(&DNSClient::HandleDNSRead, weak_ptr_factory_.GetWeakPtr())); 317 base::Callback<void(int)> write_callback( 318 Bind(&DNSClient::HandleDNSWrite, weak_ptr_factory_.GetWeakPtr())); 319 for (int i = 0; i < ARES_GETSOCK_MAXNUM; i++) { 320 if (ARES_GETSOCK_READABLE(action_bits, i)) { 321 if (ContainsKey(old_read, sockets[i])) { 322 resolver_state_->read_handlers[sockets[i]] = old_read[sockets[i]]; 323 } else { 324 resolver_state_->read_handlers[sockets[i]] = 325 std::shared_ptr<IOHandler> ( 326 dispatcher_->CreateReadyHandler(sockets[i], 327 IOHandler::kModeInput, 328 read_callback)); 329 } 330 } 331 if (ARES_GETSOCK_WRITABLE(action_bits, i)) { 332 if (ContainsKey(old_write, sockets[i])) { 333 resolver_state_->write_handlers[sockets[i]] = old_write[sockets[i]]; 334 } else { 335 resolver_state_->write_handlers[sockets[i]] = 336 std::shared_ptr<IOHandler> ( 337 dispatcher_->CreateReadyHandler(sockets[i], 338 IOHandler::kModeOutput, 339 write_callback)); 340 } 341 } 342 } 343 344 if (!running_) { 345 // We are here just to clean up socket handles, and the ARES state was 346 // cleaned up during the last call to ares_->ProcessFd(). 347 return false; 348 } 349 350 // Schedule timer event for the earlier of our timeout or one requested by 351 // the resolver library. 352 struct timeval now, elapsed_time, timeout_tv; 353 time_->GetTimeMonotonic(&now); 354 timersub(&now, &resolver_state_->start_time, &elapsed_time); 355 timeout_tv.tv_sec = timeout_ms_ / 1000; 356 timeout_tv.tv_usec = (timeout_ms_ % 1000) * 1000; 357 timeout_closure_.Cancel(); 358 359 if (timercmp(&elapsed_time, &timeout_tv, >=)) { 360 // There are 3 cases of interest: 361 // - If we got here from Start(), when we return, Stop() will be 362 // called, so our cleanup task will not run, so we will not have the 363 // side-effect of both invoking the callback and returning False 364 // in Start(). 365 // - If we got here from the tail of an IO event, we can't call 366 // Stop() since that will blow away the IOHandler we are running 367 // in. We will perform the cleanup in the posted task below. 368 // - If we got here from a timeout handler, we will perform cleanup 369 // in the posted task. 370 running_ = false; 371 error_.Populate(Error::kOperationTimeout, kErrorTimedOut); 372 dispatcher_->PostTask(Bind(&DNSClient::HandleCompletion, 373 weak_ptr_factory_.GetWeakPtr())); 374 return false; 375 } else { 376 struct timeval max, ret_tv; 377 timersub(&timeout_tv, &elapsed_time, &max); 378 struct timeval* tv = ares_->Timeout(resolver_state_->channel, 379 &max, &ret_tv); 380 timeout_closure_.Reset( 381 Bind(&DNSClient::HandleTimeout, weak_ptr_factory_.GetWeakPtr())); 382 dispatcher_->PostDelayedTask(timeout_closure_.callback(), 383 tv->tv_sec * 1000 + tv->tv_usec / 1000); 384 } 385 386 return true; 387 } 388 389 } // namespace shill 390