1 // Copyright 2013 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/dns/mdns_client_impl.h" 6 7 #include "base/bind.h" 8 #include "base/message_loop/message_loop_proxy.h" 9 #include "base/stl_util.h" 10 #include "base/time/default_clock.h" 11 #include "net/base/dns_util.h" 12 #include "net/base/net_errors.h" 13 #include "net/base/net_log.h" 14 #include "net/base/rand_callback.h" 15 #include "net/dns/dns_protocol.h" 16 #include "net/dns/record_rdata.h" 17 #include "net/udp/datagram_socket.h" 18 19 // TODO(gene): Remove this temporary method of disabling NSEC support once it 20 // becomes clear whether this feature should be 21 // supported. http://crbug.com/255232 22 #define ENABLE_NSEC 23 24 namespace net { 25 26 namespace { 27 28 const unsigned MDnsTransactionTimeoutSeconds = 3; 29 30 } // namespace 31 32 void MDnsSocketFactoryImpl::CreateSockets( 33 ScopedVector<DatagramServerSocket>* sockets) { 34 InterfaceIndexFamilyList interfaces(GetMDnsInterfacesToBind()); 35 for (size_t i = 0; i < interfaces.size(); ++i) { 36 DCHECK(interfaces[i].second == net::ADDRESS_FAMILY_IPV4 || 37 interfaces[i].second == net::ADDRESS_FAMILY_IPV6); 38 scoped_ptr<DatagramServerSocket> socket( 39 CreateAndBindMDnsSocket(interfaces[i].second, interfaces[i].first)); 40 if (socket) 41 sockets->push_back(socket.release()); 42 } 43 } 44 45 MDnsConnection::SocketHandler::SocketHandler( 46 scoped_ptr<DatagramServerSocket> socket, 47 MDnsConnection* connection) 48 : socket_(socket.Pass()), 49 connection_(connection), 50 response_(dns_protocol::kMaxMulticastSize) { 51 } 52 53 MDnsConnection::SocketHandler::~SocketHandler() { 54 } 55 56 int MDnsConnection::SocketHandler::Start() { 57 IPEndPoint end_point; 58 int rv = socket_->GetLocalAddress(&end_point); 59 if (rv != OK) 60 return rv; 61 DCHECK(end_point.GetFamily() == ADDRESS_FAMILY_IPV4 || 62 end_point.GetFamily() == ADDRESS_FAMILY_IPV6); 63 multicast_addr_ = GetMDnsIPEndPoint(end_point.GetFamily()); 64 return DoLoop(0); 65 } 66 67 int MDnsConnection::SocketHandler::DoLoop(int rv) { 68 do { 69 if (rv > 0) 70 connection_->OnDatagramReceived(&response_, recv_addr_, rv); 71 72 rv = socket_->RecvFrom( 73 response_.io_buffer(), 74 response_.io_buffer()->size(), 75 &recv_addr_, 76 base::Bind(&MDnsConnection::SocketHandler::OnDatagramReceived, 77 base::Unretained(this))); 78 } while (rv > 0); 79 80 if (rv != ERR_IO_PENDING) 81 return rv; 82 83 return OK; 84 } 85 86 void MDnsConnection::SocketHandler::OnDatagramReceived(int rv) { 87 if (rv >= OK) 88 rv = DoLoop(rv); 89 90 if (rv != OK) 91 connection_->OnError(this, rv); 92 } 93 94 int MDnsConnection::SocketHandler::Send(IOBuffer* buffer, unsigned size) { 95 return socket_->SendTo(buffer, size, multicast_addr_, 96 base::Bind(&MDnsConnection::SocketHandler::SendDone, 97 base::Unretained(this) )); 98 } 99 100 void MDnsConnection::SocketHandler::SendDone(int rv) { 101 // TODO(noamsml): Retry logic. 102 } 103 104 MDnsConnection::MDnsConnection(MDnsConnection::Delegate* delegate) : 105 delegate_(delegate) { 106 } 107 108 MDnsConnection::~MDnsConnection() { 109 } 110 111 bool MDnsConnection::Init(MDnsSocketFactory* socket_factory) { 112 ScopedVector<DatagramServerSocket> sockets; 113 socket_factory->CreateSockets(&sockets); 114 115 for (size_t i = 0; i < sockets.size(); ++i) { 116 socket_handlers_.push_back( 117 new MDnsConnection::SocketHandler(make_scoped_ptr(sockets[i]), this)); 118 } 119 sockets.weak_clear(); 120 121 // All unbound sockets need to be bound before processing untrusted input. 122 // This is done for security reasons, so that an attacker can't get an unbound 123 // socket. 124 for (size_t i = 0; i < socket_handlers_.size();) { 125 int rv = socket_handlers_[i]->Start(); 126 if (rv != OK) { 127 socket_handlers_.erase(socket_handlers_.begin() + i); 128 VLOG(1) << "Start failed, socket=" << i << ", error=" << rv; 129 } else { 130 ++i; 131 } 132 } 133 VLOG(1) << "Sockets ready:" << socket_handlers_.size(); 134 return !socket_handlers_.empty(); 135 } 136 137 bool MDnsConnection::Send(IOBuffer* buffer, unsigned size) { 138 bool success = false; 139 for (size_t i = 0; i < socket_handlers_.size(); ++i) { 140 int rv = socket_handlers_[i]->Send(buffer, size); 141 if (rv >= OK || rv == ERR_IO_PENDING) { 142 success = true; 143 } else { 144 VLOG(1) << "Send failed, socket=" << i << ", error=" << rv; 145 } 146 } 147 return success; 148 } 149 150 void MDnsConnection::OnError(SocketHandler* loop, 151 int error) { 152 // TODO(noamsml): Specific handling of intermittent errors that can be handled 153 // in the connection. 154 delegate_->OnConnectionError(error); 155 } 156 157 void MDnsConnection::OnDatagramReceived( 158 DnsResponse* response, 159 const IPEndPoint& recv_addr, 160 int bytes_read) { 161 // TODO(noamsml): More sophisticated error handling. 162 DCHECK_GT(bytes_read, 0); 163 delegate_->HandlePacket(response, bytes_read); 164 } 165 166 MDnsClientImpl::Core::Core(MDnsClientImpl* client) 167 : client_(client), connection_(new MDnsConnection(this)) { 168 } 169 170 MDnsClientImpl::Core::~Core() { 171 STLDeleteValues(&listeners_); 172 } 173 174 bool MDnsClientImpl::Core::Init(MDnsSocketFactory* socket_factory) { 175 return connection_->Init(socket_factory); 176 } 177 178 bool MDnsClientImpl::Core::SendQuery(uint16 rrtype, std::string name) { 179 std::string name_dns; 180 if (!DNSDomainFromDot(name, &name_dns)) 181 return false; 182 183 DnsQuery query(0, name_dns, rrtype); 184 query.set_flags(0); // Remove the RD flag from the query. It is unneeded. 185 186 return connection_->Send(query.io_buffer(), query.io_buffer()->size()); 187 } 188 189 void MDnsClientImpl::Core::HandlePacket(DnsResponse* response, 190 int bytes_read) { 191 unsigned offset; 192 // Note: We store cache keys rather than record pointers to avoid 193 // erroneous behavior in case a packet contains multiple exclusive 194 // records with the same type and name. 195 std::map<MDnsCache::Key, MDnsListener::UpdateType> update_keys; 196 197 if (!response->InitParseWithoutQuery(bytes_read)) { 198 LOG(WARNING) << "Could not understand an mDNS packet."; 199 return; // Message is unreadable. 200 } 201 202 // TODO(noamsml): duplicate query suppression. 203 if (!(response->flags() & dns_protocol::kFlagResponse)) 204 return; // Message is a query. ignore it. 205 206 DnsRecordParser parser = response->Parser(); 207 unsigned answer_count = response->answer_count() + 208 response->additional_answer_count(); 209 210 for (unsigned i = 0; i < answer_count; i++) { 211 offset = parser.GetOffset(); 212 scoped_ptr<const RecordParsed> record = RecordParsed::CreateFrom( 213 &parser, base::Time::Now()); 214 215 if (!record) { 216 LOG(WARNING) << "Could not understand an mDNS record."; 217 218 if (offset == parser.GetOffset()) { 219 LOG(WARNING) << "Abandoned parsing the rest of the packet."; 220 return; // The parser did not advance, abort reading the packet. 221 } else { 222 continue; // We may be able to extract other records from the packet. 223 } 224 } 225 226 if ((record->klass() & dns_protocol::kMDnsClassMask) != 227 dns_protocol::kClassIN) { 228 LOG(WARNING) << "Received an mDNS record with non-IN class. Ignoring."; 229 continue; // Ignore all records not in the IN class. 230 } 231 232 MDnsCache::Key update_key = MDnsCache::Key::CreateFor(record.get()); 233 MDnsCache::UpdateType update = cache_.UpdateDnsRecord(record.Pass()); 234 235 // Cleanup time may have changed. 236 ScheduleCleanup(cache_.next_expiration()); 237 238 if (update != MDnsCache::NoChange) { 239 MDnsListener::UpdateType update_external; 240 241 switch (update) { 242 case MDnsCache::RecordAdded: 243 update_external = MDnsListener::RECORD_ADDED; 244 break; 245 case MDnsCache::RecordChanged: 246 update_external = MDnsListener::RECORD_CHANGED; 247 break; 248 case MDnsCache::NoChange: 249 default: 250 NOTREACHED(); 251 // Dummy assignment to suppress compiler warning. 252 update_external = MDnsListener::RECORD_CHANGED; 253 break; 254 } 255 256 update_keys.insert(std::make_pair(update_key, update_external)); 257 } 258 } 259 260 for (std::map<MDnsCache::Key, MDnsListener::UpdateType>::iterator i = 261 update_keys.begin(); i != update_keys.end(); i++) { 262 const RecordParsed* record = cache_.LookupKey(i->first); 263 if (!record) 264 continue; 265 266 if (record->type() == dns_protocol::kTypeNSEC) { 267 #if defined(ENABLE_NSEC) 268 NotifyNsecRecord(record); 269 #endif 270 } else { 271 AlertListeners(i->second, ListenerKey(record->name(), record->type()), 272 record); 273 } 274 } 275 } 276 277 void MDnsClientImpl::Core::NotifyNsecRecord(const RecordParsed* record) { 278 DCHECK_EQ(dns_protocol::kTypeNSEC, record->type()); 279 const NsecRecordRdata* rdata = record->rdata<NsecRecordRdata>(); 280 DCHECK(rdata); 281 282 // Remove all cached records matching the nonexistent RR types. 283 std::vector<const RecordParsed*> records_to_remove; 284 285 cache_.FindDnsRecords(0, record->name(), &records_to_remove, 286 base::Time::Now()); 287 288 for (std::vector<const RecordParsed*>::iterator i = records_to_remove.begin(); 289 i != records_to_remove.end(); i++) { 290 if ((*i)->type() == dns_protocol::kTypeNSEC) 291 continue; 292 if (!rdata->GetBit((*i)->type())) { 293 scoped_ptr<const RecordParsed> record_removed = cache_.RemoveRecord((*i)); 294 DCHECK(record_removed); 295 OnRecordRemoved(record_removed.get()); 296 } 297 } 298 299 // Alert all listeners waiting for the nonexistent RR types. 300 ListenerMap::iterator i = 301 listeners_.upper_bound(ListenerKey(record->name(), 0)); 302 for (; i != listeners_.end() && i->first.first == record->name(); i++) { 303 if (!rdata->GetBit(i->first.second)) { 304 FOR_EACH_OBSERVER(MDnsListenerImpl, *i->second, AlertNsecRecord()); 305 } 306 } 307 } 308 309 void MDnsClientImpl::Core::OnConnectionError(int error) { 310 // TODO(noamsml): On connection error, recreate connection and flush cache. 311 } 312 313 void MDnsClientImpl::Core::AlertListeners( 314 MDnsListener::UpdateType update_type, 315 const ListenerKey& key, 316 const RecordParsed* record) { 317 ListenerMap::iterator listener_map_iterator = listeners_.find(key); 318 if (listener_map_iterator == listeners_.end()) return; 319 320 FOR_EACH_OBSERVER(MDnsListenerImpl, *listener_map_iterator->second, 321 AlertDelegate(update_type, record)); 322 } 323 324 void MDnsClientImpl::Core::AddListener( 325 MDnsListenerImpl* listener) { 326 ListenerKey key(listener->GetName(), listener->GetType()); 327 std::pair<ListenerMap::iterator, bool> observer_insert_result = 328 listeners_.insert( 329 make_pair(key, static_cast<ObserverList<MDnsListenerImpl>*>(NULL))); 330 331 // If an equivalent key does not exist, actually create the observer list. 332 if (observer_insert_result.second) 333 observer_insert_result.first->second = new ObserverList<MDnsListenerImpl>(); 334 335 ObserverList<MDnsListenerImpl>* observer_list = 336 observer_insert_result.first->second; 337 338 observer_list->AddObserver(listener); 339 } 340 341 void MDnsClientImpl::Core::RemoveListener(MDnsListenerImpl* listener) { 342 ListenerKey key(listener->GetName(), listener->GetType()); 343 ListenerMap::iterator observer_list_iterator = listeners_.find(key); 344 345 DCHECK(observer_list_iterator != listeners_.end()); 346 DCHECK(observer_list_iterator->second->HasObserver(listener)); 347 348 observer_list_iterator->second->RemoveObserver(listener); 349 350 // Remove the observer list from the map if it is empty 351 if (!observer_list_iterator->second->might_have_observers()) { 352 // Schedule the actual removal for later in case the listener removal 353 // happens while iterating over the observer list. 354 base::MessageLoop::current()->PostTask( 355 FROM_HERE, base::Bind( 356 &MDnsClientImpl::Core::CleanupObserverList, AsWeakPtr(), key)); 357 } 358 } 359 360 void MDnsClientImpl::Core::CleanupObserverList(const ListenerKey& key) { 361 ListenerMap::iterator found = listeners_.find(key); 362 if (found != listeners_.end() && !found->second->might_have_observers()) { 363 delete found->second; 364 listeners_.erase(found); 365 } 366 } 367 368 void MDnsClientImpl::Core::ScheduleCleanup(base::Time cleanup) { 369 // Cleanup is already scheduled, no need to do anything. 370 if (cleanup == scheduled_cleanup_) return; 371 scheduled_cleanup_ = cleanup; 372 373 // This cancels the previously scheduled cleanup. 374 cleanup_callback_.Reset(base::Bind( 375 &MDnsClientImpl::Core::DoCleanup, base::Unretained(this))); 376 377 // If |cleanup| is empty, then no cleanup necessary. 378 if (cleanup != base::Time()) { 379 base::MessageLoop::current()->PostDelayedTask( 380 FROM_HERE, 381 cleanup_callback_.callback(), 382 cleanup - base::Time::Now()); 383 } 384 } 385 386 void MDnsClientImpl::Core::DoCleanup() { 387 cache_.CleanupRecords(base::Time::Now(), base::Bind( 388 &MDnsClientImpl::Core::OnRecordRemoved, base::Unretained(this))); 389 390 ScheduleCleanup(cache_.next_expiration()); 391 } 392 393 void MDnsClientImpl::Core::OnRecordRemoved( 394 const RecordParsed* record) { 395 AlertListeners(MDnsListener::RECORD_REMOVED, 396 ListenerKey(record->name(), record->type()), record); 397 } 398 399 void MDnsClientImpl::Core::QueryCache( 400 uint16 rrtype, const std::string& name, 401 std::vector<const RecordParsed*>* records) const { 402 cache_.FindDnsRecords(rrtype, name, records, base::Time::Now()); 403 } 404 405 MDnsClientImpl::MDnsClientImpl() { 406 } 407 408 MDnsClientImpl::~MDnsClientImpl() { 409 } 410 411 bool MDnsClientImpl::StartListening(MDnsSocketFactory* socket_factory) { 412 DCHECK(!core_.get()); 413 core_.reset(new Core(this)); 414 if (!core_->Init(socket_factory)) { 415 core_.reset(); 416 return false; 417 } 418 return true; 419 } 420 421 void MDnsClientImpl::StopListening() { 422 core_.reset(); 423 } 424 425 bool MDnsClientImpl::IsListening() const { 426 return core_.get() != NULL; 427 } 428 429 scoped_ptr<MDnsListener> MDnsClientImpl::CreateListener( 430 uint16 rrtype, 431 const std::string& name, 432 MDnsListener::Delegate* delegate) { 433 return scoped_ptr<net::MDnsListener>( 434 new MDnsListenerImpl(rrtype, name, delegate, this)); 435 } 436 437 scoped_ptr<MDnsTransaction> MDnsClientImpl::CreateTransaction( 438 uint16 rrtype, 439 const std::string& name, 440 int flags, 441 const MDnsTransaction::ResultCallback& callback) { 442 return scoped_ptr<MDnsTransaction>( 443 new MDnsTransactionImpl(rrtype, name, flags, callback, this)); 444 } 445 446 MDnsListenerImpl::MDnsListenerImpl( 447 uint16 rrtype, 448 const std::string& name, 449 MDnsListener::Delegate* delegate, 450 MDnsClientImpl* client) 451 : rrtype_(rrtype), name_(name), client_(client), delegate_(delegate), 452 started_(false) { 453 } 454 455 bool MDnsListenerImpl::Start() { 456 DCHECK(!started_); 457 458 started_ = true; 459 460 DCHECK(client_->core()); 461 client_->core()->AddListener(this); 462 463 return true; 464 } 465 466 MDnsListenerImpl::~MDnsListenerImpl() { 467 if (started_) { 468 DCHECK(client_->core()); 469 client_->core()->RemoveListener(this); 470 } 471 } 472 473 const std::string& MDnsListenerImpl::GetName() const { 474 return name_; 475 } 476 477 uint16 MDnsListenerImpl::GetType() const { 478 return rrtype_; 479 } 480 481 void MDnsListenerImpl::AlertDelegate(MDnsListener::UpdateType update_type, 482 const RecordParsed* record) { 483 DCHECK(started_); 484 delegate_->OnRecordUpdate(update_type, record); 485 } 486 487 void MDnsListenerImpl::AlertNsecRecord() { 488 DCHECK(started_); 489 delegate_->OnNsecRecord(name_, rrtype_); 490 } 491 492 MDnsTransactionImpl::MDnsTransactionImpl( 493 uint16 rrtype, 494 const std::string& name, 495 int flags, 496 const MDnsTransaction::ResultCallback& callback, 497 MDnsClientImpl* client) 498 : rrtype_(rrtype), name_(name), callback_(callback), client_(client), 499 started_(false), flags_(flags) { 500 DCHECK((flags_ & MDnsTransaction::FLAG_MASK) == flags_); 501 DCHECK(flags_ & MDnsTransaction::QUERY_CACHE || 502 flags_ & MDnsTransaction::QUERY_NETWORK); 503 } 504 505 MDnsTransactionImpl::~MDnsTransactionImpl() { 506 timeout_.Cancel(); 507 } 508 509 bool MDnsTransactionImpl::Start() { 510 DCHECK(!started_); 511 started_ = true; 512 513 base::WeakPtr<MDnsTransactionImpl> weak_this = AsWeakPtr(); 514 if (flags_ & MDnsTransaction::QUERY_CACHE) { 515 ServeRecordsFromCache(); 516 517 if (!weak_this || !is_active()) return true; 518 } 519 520 if (flags_ & MDnsTransaction::QUERY_NETWORK) { 521 return QueryAndListen(); 522 } 523 524 // If this is a cache only query, signal that the transaction is over 525 // immediately. 526 SignalTransactionOver(); 527 return true; 528 } 529 530 const std::string& MDnsTransactionImpl::GetName() const { 531 return name_; 532 } 533 534 uint16 MDnsTransactionImpl::GetType() const { 535 return rrtype_; 536 } 537 538 void MDnsTransactionImpl::CacheRecordFound(const RecordParsed* record) { 539 DCHECK(started_); 540 OnRecordUpdate(MDnsListener::RECORD_ADDED, record); 541 } 542 543 void MDnsTransactionImpl::TriggerCallback(MDnsTransaction::Result result, 544 const RecordParsed* record) { 545 DCHECK(started_); 546 if (!is_active()) return; 547 548 // Ensure callback is run after touching all class state, so that 549 // the callback can delete the transaction. 550 MDnsTransaction::ResultCallback callback = callback_; 551 552 // Reset the transaction if it expects a single result, or if the result 553 // is a final one (everything except for a record). 554 if (flags_ & MDnsTransaction::SINGLE_RESULT || 555 result != MDnsTransaction::RESULT_RECORD) { 556 Reset(); 557 } 558 559 callback.Run(result, record); 560 } 561 562 void MDnsTransactionImpl::Reset() { 563 callback_.Reset(); 564 listener_.reset(); 565 timeout_.Cancel(); 566 } 567 568 void MDnsTransactionImpl::OnRecordUpdate(MDnsListener::UpdateType update, 569 const RecordParsed* record) { 570 DCHECK(started_); 571 if (update == MDnsListener::RECORD_ADDED || 572 update == MDnsListener::RECORD_CHANGED) 573 TriggerCallback(MDnsTransaction::RESULT_RECORD, record); 574 } 575 576 void MDnsTransactionImpl::SignalTransactionOver() { 577 DCHECK(started_); 578 if (flags_ & MDnsTransaction::SINGLE_RESULT) { 579 TriggerCallback(MDnsTransaction::RESULT_NO_RESULTS, NULL); 580 } else { 581 TriggerCallback(MDnsTransaction::RESULT_DONE, NULL); 582 } 583 } 584 585 void MDnsTransactionImpl::ServeRecordsFromCache() { 586 std::vector<const RecordParsed*> records; 587 base::WeakPtr<MDnsTransactionImpl> weak_this = AsWeakPtr(); 588 589 if (client_->core()) { 590 client_->core()->QueryCache(rrtype_, name_, &records); 591 for (std::vector<const RecordParsed*>::iterator i = records.begin(); 592 i != records.end() && weak_this; ++i) { 593 weak_this->TriggerCallback(MDnsTransaction::RESULT_RECORD, *i); 594 } 595 596 #if defined(ENABLE_NSEC) 597 if (records.empty()) { 598 DCHECK(weak_this); 599 client_->core()->QueryCache(dns_protocol::kTypeNSEC, name_, &records); 600 if (!records.empty()) { 601 const NsecRecordRdata* rdata = 602 records.front()->rdata<NsecRecordRdata>(); 603 DCHECK(rdata); 604 if (!rdata->GetBit(rrtype_)) 605 weak_this->TriggerCallback(MDnsTransaction::RESULT_NSEC, NULL); 606 } 607 } 608 #endif 609 } 610 } 611 612 bool MDnsTransactionImpl::QueryAndListen() { 613 listener_ = client_->CreateListener(rrtype_, name_, this); 614 if (!listener_->Start()) 615 return false; 616 617 DCHECK(client_->core()); 618 if (!client_->core()->SendQuery(rrtype_, name_)) 619 return false; 620 621 timeout_.Reset(base::Bind(&MDnsTransactionImpl::SignalTransactionOver, 622 AsWeakPtr())); 623 base::MessageLoop::current()->PostDelayedTask( 624 FROM_HERE, 625 timeout_.callback(), 626 base::TimeDelta::FromSeconds(MDnsTransactionTimeoutSeconds)); 627 628 return true; 629 } 630 631 void MDnsTransactionImpl::OnNsecRecord(const std::string& name, unsigned type) { 632 TriggerCallback(RESULT_NSEC, NULL); 633 } 634 635 void MDnsTransactionImpl::OnCachePurged() { 636 // TODO(noamsml): Cache purge situations not yet implemented 637 } 638 639 } // namespace net 640