1 // Copyright 2014 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include <utility> 6 7 #include "base/logging.h" 8 #include "base/memory/singleton.h" 9 #include "base/message_loop/message_loop_proxy.h" 10 #include "base/stl_util.h" 11 #include "chrome/common/local_discovery/service_discovery_client_impl.h" 12 #include "net/dns/dns_protocol.h" 13 #include "net/dns/record_rdata.h" 14 15 namespace local_discovery { 16 17 namespace { 18 // TODO(noamsml): Make this configurable through the LocalDomainResolver 19 // interface. 20 const int kLocalDomainSecondAddressTimeoutMs = 100; 21 22 const int kInitialRequeryTimeSeconds = 1; 23 const int kMaxRequeryTimeSeconds = 2; // Time for last requery 24 } 25 26 ServiceDiscoveryClientImpl::ServiceDiscoveryClientImpl( 27 net::MDnsClient* mdns_client) : mdns_client_(mdns_client) { 28 } 29 30 ServiceDiscoveryClientImpl::~ServiceDiscoveryClientImpl() { 31 } 32 33 scoped_ptr<ServiceWatcher> ServiceDiscoveryClientImpl::CreateServiceWatcher( 34 const std::string& service_type, 35 const ServiceWatcher::UpdatedCallback& callback) { 36 return scoped_ptr<ServiceWatcher>(new ServiceWatcherImpl( 37 service_type, callback, mdns_client_)); 38 } 39 40 scoped_ptr<ServiceResolver> ServiceDiscoveryClientImpl::CreateServiceResolver( 41 const std::string& service_name, 42 const ServiceResolver::ResolveCompleteCallback& callback) { 43 return scoped_ptr<ServiceResolver>(new ServiceResolverImpl( 44 service_name, callback, mdns_client_)); 45 } 46 47 scoped_ptr<LocalDomainResolver> 48 ServiceDiscoveryClientImpl::CreateLocalDomainResolver( 49 const std::string& domain, 50 net::AddressFamily address_family, 51 const LocalDomainResolver::IPAddressCallback& callback) { 52 return scoped_ptr<LocalDomainResolver>(new LocalDomainResolverImpl( 53 domain, address_family, callback, mdns_client_)); 54 } 55 56 ServiceWatcherImpl::ServiceWatcherImpl( 57 const std::string& service_type, 58 const ServiceWatcher::UpdatedCallback& callback, 59 net::MDnsClient* mdns_client) 60 : service_type_(service_type), callback_(callback), started_(false), 61 actively_refresh_services_(false), mdns_client_(mdns_client) { 62 } 63 64 void ServiceWatcherImpl::Start() { 65 DCHECK(!started_); 66 listener_ = mdns_client_->CreateListener( 67 net::dns_protocol::kTypePTR, service_type_, this); 68 started_ = listener_->Start(); 69 if (started_) 70 ReadCachedServices(); 71 } 72 73 ServiceWatcherImpl::~ServiceWatcherImpl() { 74 } 75 76 void ServiceWatcherImpl::DiscoverNewServices(bool force_update) { 77 DCHECK(started_); 78 if (force_update) 79 services_.clear(); 80 SendQuery(kInitialRequeryTimeSeconds, force_update); 81 } 82 83 void ServiceWatcherImpl::SetActivelyRefreshServices( 84 bool actively_refresh_services) { 85 DCHECK(started_); 86 actively_refresh_services_ = actively_refresh_services; 87 88 for (ServiceListenersMap::iterator i = services_.begin(); 89 i != services_.end(); i++) { 90 i->second->SetActiveRefresh(actively_refresh_services); 91 } 92 } 93 94 void ServiceWatcherImpl::ReadCachedServices() { 95 DCHECK(started_); 96 CreateTransaction(false /*network*/, true /*cache*/, false /*force refresh*/, 97 &transaction_cache_); 98 } 99 100 bool ServiceWatcherImpl::CreateTransaction( 101 bool network, bool cache, bool force_refresh, 102 scoped_ptr<net::MDnsTransaction>* transaction) { 103 int transaction_flags = 0; 104 if (network) 105 transaction_flags |= net::MDnsTransaction::QUERY_NETWORK; 106 107 if (cache) 108 transaction_flags |= net::MDnsTransaction::QUERY_CACHE; 109 110 // TODO(noamsml): Add flag for force_refresh when supported. 111 112 if (transaction_flags) { 113 *transaction = mdns_client_->CreateTransaction( 114 net::dns_protocol::kTypePTR, service_type_, transaction_flags, 115 base::Bind(&ServiceWatcherImpl::OnTransactionResponse, 116 base::Unretained(this), transaction)); 117 return (*transaction)->Start(); 118 } 119 120 return true; 121 } 122 123 std::string ServiceWatcherImpl::GetServiceType() const { 124 return listener_->GetName(); 125 } 126 127 void ServiceWatcherImpl::OnRecordUpdate( 128 net::MDnsListener::UpdateType update, 129 const net::RecordParsed* record) { 130 DCHECK(started_); 131 if (record->type() == net::dns_protocol::kTypePTR) { 132 DCHECK(record->name() == GetServiceType()); 133 const net::PtrRecordRdata* rdata = record->rdata<net::PtrRecordRdata>(); 134 135 switch (update) { 136 case net::MDnsListener::RECORD_ADDED: 137 AddService(rdata->ptrdomain()); 138 break; 139 case net::MDnsListener::RECORD_CHANGED: 140 NOTREACHED(); 141 break; 142 case net::MDnsListener::RECORD_REMOVED: 143 RemovePTR(rdata->ptrdomain()); 144 break; 145 } 146 } else { 147 DCHECK(record->type() == net::dns_protocol::kTypeSRV || 148 record->type() == net::dns_protocol::kTypeTXT); 149 DCHECK(services_.find(record->name()) != services_.end()); 150 151 if (record->type() == net::dns_protocol::kTypeSRV) { 152 if (update == net::MDnsListener::RECORD_REMOVED) { 153 RemoveSRV(record->name()); 154 } else if (update == net::MDnsListener::RECORD_ADDED) { 155 AddSRV(record->name()); 156 } 157 } 158 159 // If this is the first time we see an SRV record, do not send 160 // an UPDATE_CHANGED. 161 if (record->type() != net::dns_protocol::kTypeSRV || 162 update != net::MDnsListener::RECORD_ADDED) { 163 DeferUpdate(UPDATE_CHANGED, record->name()); 164 } 165 } 166 } 167 168 void ServiceWatcherImpl::OnCachePurged() { 169 // Not yet implemented. 170 } 171 172 void ServiceWatcherImpl::OnTransactionResponse( 173 scoped_ptr<net::MDnsTransaction>* transaction, 174 net::MDnsTransaction::Result result, 175 const net::RecordParsed* record) { 176 DCHECK(started_); 177 if (result == net::MDnsTransaction::RESULT_RECORD) { 178 const net::PtrRecordRdata* rdata = record->rdata<net::PtrRecordRdata>(); 179 DCHECK(rdata); 180 AddService(rdata->ptrdomain()); 181 } else if (result == net::MDnsTransaction::RESULT_DONE) { 182 transaction->reset(); 183 } 184 185 // Do nothing for NSEC records. It is an error for hosts to broadcast an NSEC 186 // record for PTR records on any name. 187 } 188 189 ServiceWatcherImpl::ServiceListeners::ServiceListeners( 190 const std::string& service_name, 191 ServiceWatcherImpl* watcher, 192 net::MDnsClient* mdns_client) 193 : service_name_(service_name), mdns_client_(mdns_client), 194 update_pending_(false), has_ptr_(true), has_srv_(false) { 195 srv_listener_ = mdns_client->CreateListener( 196 net::dns_protocol::kTypeSRV, service_name, watcher); 197 txt_listener_ = mdns_client->CreateListener( 198 net::dns_protocol::kTypeTXT, service_name, watcher); 199 } 200 201 ServiceWatcherImpl::ServiceListeners::~ServiceListeners() { 202 } 203 204 bool ServiceWatcherImpl::ServiceListeners::Start() { 205 if (!srv_listener_->Start()) 206 return false; 207 return txt_listener_->Start(); 208 } 209 210 void ServiceWatcherImpl::ServiceListeners::SetActiveRefresh( 211 bool active_refresh) { 212 srv_listener_->SetActiveRefresh(active_refresh); 213 214 if (active_refresh && !has_srv_) { 215 DCHECK(has_ptr_); 216 srv_transaction_ = mdns_client_->CreateTransaction( 217 net::dns_protocol::kTypeSRV, service_name_, 218 net::MDnsTransaction::SINGLE_RESULT | 219 net::MDnsTransaction::QUERY_CACHE | net::MDnsTransaction::QUERY_NETWORK, 220 base::Bind(&ServiceWatcherImpl::ServiceListeners::OnSRVRecord, 221 base::Unretained(this))); 222 srv_transaction_->Start(); 223 } else if (!active_refresh) { 224 srv_transaction_.reset(); 225 } 226 } 227 228 void ServiceWatcherImpl::ServiceListeners::OnSRVRecord( 229 net::MDnsTransaction::Result result, 230 const net::RecordParsed* record) { 231 set_has_srv(record != NULL); 232 } 233 234 void ServiceWatcherImpl::ServiceListeners::set_has_srv(bool has_srv) { 235 has_srv_ = has_srv; 236 237 srv_transaction_.reset(); 238 } 239 240 void ServiceWatcherImpl::AddService(const std::string& service) { 241 DCHECK(started_); 242 std::pair<ServiceListenersMap::iterator, bool> found = services_.insert( 243 make_pair(service, linked_ptr<ServiceListeners>(NULL))); 244 245 if (found.second) { // Newly inserted. 246 found.first->second = linked_ptr<ServiceListeners>( 247 new ServiceListeners(service, this, mdns_client_)); 248 bool success = found.first->second->Start(); 249 found.first->second->SetActiveRefresh(actively_refresh_services_); 250 DeferUpdate(UPDATE_ADDED, service); 251 252 DCHECK(success); 253 } 254 255 found.first->second->set_has_ptr(true); 256 } 257 258 void ServiceWatcherImpl::AddSRV(const std::string& service) { 259 DCHECK(started_); 260 261 ServiceListenersMap::iterator found = services_.find(service); 262 if (found != services_.end()) { 263 found->second->set_has_srv(true); 264 } 265 } 266 267 void ServiceWatcherImpl::DeferUpdate(ServiceWatcher::UpdateType update_type, 268 const std::string& service_name) { 269 ServiceListenersMap::iterator found = services_.find(service_name); 270 271 if (found != services_.end() && !found->second->update_pending()) { 272 found->second->set_update_pending(true); 273 base::MessageLoop::current()->PostTask( 274 FROM_HERE, 275 base::Bind(&ServiceWatcherImpl::DeliverDeferredUpdate, AsWeakPtr(), 276 update_type, service_name)); 277 } 278 } 279 280 void ServiceWatcherImpl::DeliverDeferredUpdate( 281 ServiceWatcher::UpdateType update_type, const std::string& service_name) { 282 ServiceListenersMap::iterator found = services_.find(service_name); 283 284 if (found != services_.end()) { 285 found->second->set_update_pending(false); 286 if (!callback_.is_null()) 287 callback_.Run(update_type, service_name); 288 } 289 } 290 291 void ServiceWatcherImpl::RemovePTR(const std::string& service) { 292 DCHECK(started_); 293 294 ServiceListenersMap::iterator found = services_.find(service); 295 if (found != services_.end()) { 296 found->second->set_has_ptr(false); 297 298 if (!found->second->has_ptr_or_srv()) { 299 services_.erase(found); 300 if (!callback_.is_null()) 301 callback_.Run(UPDATE_REMOVED, service); 302 } 303 } 304 } 305 306 void ServiceWatcherImpl::RemoveSRV(const std::string& service) { 307 DCHECK(started_); 308 309 ServiceListenersMap::iterator found = services_.find(service); 310 if (found != services_.end()) { 311 found->second->set_has_srv(false); 312 313 if (!found->second->has_ptr_or_srv()) { 314 services_.erase(found); 315 if (!callback_.is_null()) 316 callback_.Run(UPDATE_REMOVED, service); 317 } 318 } 319 } 320 321 void ServiceWatcherImpl::OnNsecRecord(const std::string& name, 322 unsigned rrtype) { 323 // Do nothing. It is an error for hosts to broadcast an NSEC record for PTR 324 // on any name. 325 } 326 327 void ServiceWatcherImpl::ScheduleQuery(int timeout_seconds) { 328 if (timeout_seconds <= kMaxRequeryTimeSeconds) { 329 base::MessageLoop::current()->PostDelayedTask( 330 FROM_HERE, 331 base::Bind(&ServiceWatcherImpl::SendQuery, 332 AsWeakPtr(), 333 timeout_seconds * 2 /*next_timeout_seconds*/, 334 false /*force_update*/), 335 base::TimeDelta::FromSeconds(timeout_seconds)); 336 } 337 } 338 339 void ServiceWatcherImpl::SendQuery(int next_timeout_seconds, 340 bool force_update) { 341 CreateTransaction(true /*network*/, false /*cache*/, force_update, 342 &transaction_network_); 343 ScheduleQuery(next_timeout_seconds); 344 } 345 346 ServiceResolverImpl::ServiceResolverImpl( 347 const std::string& service_name, 348 const ResolveCompleteCallback& callback, 349 net::MDnsClient* mdns_client) 350 : service_name_(service_name), callback_(callback), 351 metadata_resolved_(false), address_resolved_(false), 352 mdns_client_(mdns_client) { 353 } 354 355 void ServiceResolverImpl::StartResolving() { 356 address_resolved_ = false; 357 metadata_resolved_ = false; 358 service_staging_ = ServiceDescription(); 359 service_staging_.service_name = service_name_; 360 361 if (!CreateTxtTransaction() || !CreateSrvTransaction()) { 362 ServiceNotFound(ServiceResolver::STATUS_REQUEST_TIMEOUT); 363 } 364 } 365 366 ServiceResolverImpl::~ServiceResolverImpl() { 367 } 368 369 bool ServiceResolverImpl::CreateTxtTransaction() { 370 txt_transaction_ = mdns_client_->CreateTransaction( 371 net::dns_protocol::kTypeTXT, service_name_, 372 net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE | 373 net::MDnsTransaction::QUERY_NETWORK, 374 base::Bind(&ServiceResolverImpl::TxtRecordTransactionResponse, 375 AsWeakPtr())); 376 return txt_transaction_->Start(); 377 } 378 379 // TODO(noamsml): quick-resolve for AAAA records. Since A records tend to be in 380 void ServiceResolverImpl::CreateATransaction() { 381 a_transaction_ = mdns_client_->CreateTransaction( 382 net::dns_protocol::kTypeA, 383 service_staging_.address.host(), 384 net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE, 385 base::Bind(&ServiceResolverImpl::ARecordTransactionResponse, 386 AsWeakPtr())); 387 a_transaction_->Start(); 388 } 389 390 bool ServiceResolverImpl::CreateSrvTransaction() { 391 srv_transaction_ = mdns_client_->CreateTransaction( 392 net::dns_protocol::kTypeSRV, service_name_, 393 net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE | 394 net::MDnsTransaction::QUERY_NETWORK, 395 base::Bind(&ServiceResolverImpl::SrvRecordTransactionResponse, 396 AsWeakPtr())); 397 return srv_transaction_->Start(); 398 } 399 400 std::string ServiceResolverImpl::GetName() const { 401 return service_name_; 402 } 403 404 void ServiceResolverImpl::SrvRecordTransactionResponse( 405 net::MDnsTransaction::Result status, const net::RecordParsed* record) { 406 srv_transaction_.reset(); 407 if (status == net::MDnsTransaction::RESULT_RECORD) { 408 DCHECK(record); 409 service_staging_.address = RecordToAddress(record); 410 service_staging_.last_seen = record->time_created(); 411 CreateATransaction(); 412 } else { 413 ServiceNotFound(MDnsStatusToRequestStatus(status)); 414 } 415 } 416 417 void ServiceResolverImpl::TxtRecordTransactionResponse( 418 net::MDnsTransaction::Result status, const net::RecordParsed* record) { 419 txt_transaction_.reset(); 420 if (status == net::MDnsTransaction::RESULT_RECORD) { 421 DCHECK(record); 422 service_staging_.metadata = RecordToMetadata(record); 423 } else { 424 service_staging_.metadata = std::vector<std::string>(); 425 } 426 427 metadata_resolved_ = true; 428 AlertCallbackIfReady(); 429 } 430 431 void ServiceResolverImpl::ARecordTransactionResponse( 432 net::MDnsTransaction::Result status, const net::RecordParsed* record) { 433 a_transaction_.reset(); 434 435 if (status == net::MDnsTransaction::RESULT_RECORD) { 436 DCHECK(record); 437 service_staging_.ip_address = RecordToIPAddress(record); 438 } else { 439 service_staging_.ip_address = net::IPAddressNumber(); 440 } 441 442 address_resolved_ = true; 443 AlertCallbackIfReady(); 444 } 445 446 void ServiceResolverImpl::AlertCallbackIfReady() { 447 if (metadata_resolved_ && address_resolved_) { 448 txt_transaction_.reset(); 449 srv_transaction_.reset(); 450 a_transaction_.reset(); 451 if (!callback_.is_null()) 452 callback_.Run(STATUS_SUCCESS, service_staging_); 453 } 454 } 455 456 void ServiceResolverImpl::ServiceNotFound( 457 ServiceResolver::RequestStatus status) { 458 txt_transaction_.reset(); 459 srv_transaction_.reset(); 460 a_transaction_.reset(); 461 if (!callback_.is_null()) 462 callback_.Run(status, ServiceDescription()); 463 } 464 465 ServiceResolver::RequestStatus ServiceResolverImpl::MDnsStatusToRequestStatus( 466 net::MDnsTransaction::Result status) const { 467 switch (status) { 468 case net::MDnsTransaction::RESULT_RECORD: 469 return ServiceResolver::STATUS_SUCCESS; 470 case net::MDnsTransaction::RESULT_NO_RESULTS: 471 return ServiceResolver::STATUS_REQUEST_TIMEOUT; 472 case net::MDnsTransaction::RESULT_NSEC: 473 return ServiceResolver::STATUS_KNOWN_NONEXISTENT; 474 case net::MDnsTransaction::RESULT_DONE: // Pass through. 475 default: 476 NOTREACHED(); 477 return ServiceResolver::STATUS_REQUEST_TIMEOUT; 478 } 479 } 480 481 const std::vector<std::string>& ServiceResolverImpl::RecordToMetadata( 482 const net::RecordParsed* record) const { 483 DCHECK(record->type() == net::dns_protocol::kTypeTXT); 484 const net::TxtRecordRdata* txt_rdata = record->rdata<net::TxtRecordRdata>(); 485 DCHECK(txt_rdata); 486 return txt_rdata->texts(); 487 } 488 489 net::HostPortPair ServiceResolverImpl::RecordToAddress( 490 const net::RecordParsed* record) const { 491 DCHECK(record->type() == net::dns_protocol::kTypeSRV); 492 const net::SrvRecordRdata* srv_rdata = record->rdata<net::SrvRecordRdata>(); 493 DCHECK(srv_rdata); 494 return net::HostPortPair(srv_rdata->target(), srv_rdata->port()); 495 } 496 497 const net::IPAddressNumber& ServiceResolverImpl::RecordToIPAddress( 498 const net::RecordParsed* record) const { 499 DCHECK(record->type() == net::dns_protocol::kTypeA); 500 const net::ARecordRdata* a_rdata = record->rdata<net::ARecordRdata>(); 501 DCHECK(a_rdata); 502 return a_rdata->address(); 503 } 504 505 LocalDomainResolverImpl::LocalDomainResolverImpl( 506 const std::string& domain, 507 net::AddressFamily address_family, 508 const IPAddressCallback& callback, 509 net::MDnsClient* mdns_client) 510 : domain_(domain), address_family_(address_family), callback_(callback), 511 transactions_finished_(0), mdns_client_(mdns_client) { 512 } 513 514 LocalDomainResolverImpl::~LocalDomainResolverImpl() { 515 timeout_callback_.Cancel(); 516 } 517 518 void LocalDomainResolverImpl::Start() { 519 if (address_family_ == net::ADDRESS_FAMILY_IPV4 || 520 address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) { 521 transaction_a_ = CreateTransaction(net::dns_protocol::kTypeA); 522 transaction_a_->Start(); 523 } 524 525 if (address_family_ == net::ADDRESS_FAMILY_IPV6 || 526 address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) { 527 transaction_aaaa_ = CreateTransaction(net::dns_protocol::kTypeAAAA); 528 transaction_aaaa_->Start(); 529 } 530 } 531 532 scoped_ptr<net::MDnsTransaction> LocalDomainResolverImpl::CreateTransaction( 533 uint16 type) { 534 return mdns_client_->CreateTransaction( 535 type, domain_, net::MDnsTransaction::SINGLE_RESULT | 536 net::MDnsTransaction::QUERY_CACHE | 537 net::MDnsTransaction::QUERY_NETWORK, 538 base::Bind(&LocalDomainResolverImpl::OnTransactionComplete, 539 base::Unretained(this))); 540 } 541 542 void LocalDomainResolverImpl::OnTransactionComplete( 543 net::MDnsTransaction::Result result, const net::RecordParsed* record) { 544 transactions_finished_++; 545 546 if (result == net::MDnsTransaction::RESULT_RECORD) { 547 if (record->type() == net::dns_protocol::kTypeA) { 548 const net::ARecordRdata* rdata = record->rdata<net::ARecordRdata>(); 549 address_ipv4_ = rdata->address(); 550 } else { 551 DCHECK_EQ(net::dns_protocol::kTypeAAAA, record->type()); 552 const net::AAAARecordRdata* rdata = record->rdata<net::AAAARecordRdata>(); 553 address_ipv6_ = rdata->address(); 554 } 555 } 556 557 if (transactions_finished_ == 1 && 558 address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) { 559 timeout_callback_.Reset(base::Bind( 560 &LocalDomainResolverImpl::SendResolvedAddresses, 561 base::Unretained(this))); 562 563 base::MessageLoop::current()->PostDelayedTask( 564 FROM_HERE, 565 timeout_callback_.callback(), 566 base::TimeDelta::FromMilliseconds(kLocalDomainSecondAddressTimeoutMs)); 567 } else if (transactions_finished_ == 2 568 || address_family_ != net::ADDRESS_FAMILY_UNSPECIFIED) { 569 SendResolvedAddresses(); 570 } 571 } 572 573 bool LocalDomainResolverImpl::IsSuccess() { 574 return !address_ipv4_.empty() || !address_ipv6_.empty(); 575 } 576 577 void LocalDomainResolverImpl::SendResolvedAddresses() { 578 transaction_a_.reset(); 579 transaction_aaaa_.reset(); 580 timeout_callback_.Cancel(); 581 callback_.Run(IsSuccess(), address_ipv4_, address_ipv6_); 582 } 583 584 } // namespace local_discovery 585