1 // Copyright 2013 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include <utility> 6 7 #include "base/logging.h" 8 #include "base/memory/singleton.h" 9 #include "base/message_loop/message_loop_proxy.h" 10 #include "base/stl_util.h" 11 #include "chrome/utility/local_discovery/service_discovery_client_impl.h" 12 #include "net/dns/dns_protocol.h" 13 #include "net/dns/record_rdata.h" 14 15 namespace local_discovery { 16 17 namespace { 18 // TODO(noamsml): Make this configurable through the LocalDomainResolver 19 // interface. 20 const int kLocalDomainSecondAddressTimeoutMs = 100; 21 22 const int kInitialRequeryTimeSeconds = 1; 23 const int kMaxRequeryTimeSeconds = 2; // Time for last requery 24 } 25 26 ServiceDiscoveryClientImpl::ServiceDiscoveryClientImpl( 27 net::MDnsClient* mdns_client) : mdns_client_(mdns_client) { 28 } 29 30 ServiceDiscoveryClientImpl::~ServiceDiscoveryClientImpl() { 31 } 32 33 scoped_ptr<ServiceWatcher> ServiceDiscoveryClientImpl::CreateServiceWatcher( 34 const std::string& service_type, 35 const ServiceWatcher::UpdatedCallback& callback) { 36 return scoped_ptr<ServiceWatcher>(new ServiceWatcherImpl( 37 service_type, callback, mdns_client_)); 38 } 39 40 scoped_ptr<ServiceResolver> ServiceDiscoveryClientImpl::CreateServiceResolver( 41 const std::string& service_name, 42 const ServiceResolver::ResolveCompleteCallback& callback) { 43 return scoped_ptr<ServiceResolver>(new ServiceResolverImpl( 44 service_name, callback, mdns_client_)); 45 } 46 47 scoped_ptr<LocalDomainResolver> 48 ServiceDiscoveryClientImpl::CreateLocalDomainResolver( 49 const std::string& domain, 50 net::AddressFamily address_family, 51 const LocalDomainResolver::IPAddressCallback& callback) { 52 return scoped_ptr<LocalDomainResolver>(new LocalDomainResolverImpl( 53 domain, address_family, callback, mdns_client_)); 54 } 55 56 ServiceWatcherImpl::ServiceWatcherImpl( 57 const std::string& service_type, 58 const ServiceWatcher::UpdatedCallback& callback, 59 net::MDnsClient* mdns_client) 60 : service_type_(service_type), callback_(callback), started_(false), 61 mdns_client_(mdns_client) { 62 } 63 64 void ServiceWatcherImpl::Start() { 65 DCHECK(!started_); 66 listener_ = mdns_client_->CreateListener( 67 net::dns_protocol::kTypePTR, service_type_, this); 68 started_ = listener_->Start(); 69 if (started_) 70 ReadCachedServices(); 71 } 72 73 ServiceWatcherImpl::~ServiceWatcherImpl() { 74 } 75 76 void ServiceWatcherImpl::DiscoverNewServices(bool force_update) { 77 DCHECK(started_); 78 if (force_update) 79 services_.clear(); 80 SendQuery(kInitialRequeryTimeSeconds, force_update); 81 } 82 83 void ServiceWatcherImpl::ReadCachedServices() { 84 DCHECK(started_); 85 CreateTransaction(false /*network*/, true /*cache*/, false /*force refresh*/, 86 &transaction_cache_); 87 } 88 89 bool ServiceWatcherImpl::CreateTransaction( 90 bool network, bool cache, bool force_refresh, 91 scoped_ptr<net::MDnsTransaction>* transaction) { 92 int transaction_flags = 0; 93 if (network) 94 transaction_flags |= net::MDnsTransaction::QUERY_NETWORK; 95 96 if (cache) 97 transaction_flags |= net::MDnsTransaction::QUERY_CACHE; 98 99 // TODO(noamsml): Add flag for force_refresh when supported. 100 101 if (transaction_flags) { 102 *transaction = mdns_client_->CreateTransaction( 103 net::dns_protocol::kTypePTR, service_type_, transaction_flags, 104 base::Bind(&ServiceWatcherImpl::OnTransactionResponse, 105 base::Unretained(this), transaction)); 106 return (*transaction)->Start(); 107 } 108 109 return true; 110 } 111 112 std::string ServiceWatcherImpl::GetServiceType() const { 113 return listener_->GetName(); 114 } 115 116 void ServiceWatcherImpl::OnRecordUpdate( 117 net::MDnsListener::UpdateType update, 118 const net::RecordParsed* record) { 119 DCHECK(started_); 120 if (record->type() == net::dns_protocol::kTypePTR) { 121 DCHECK(record->name() == GetServiceType()); 122 const net::PtrRecordRdata* rdata = record->rdata<net::PtrRecordRdata>(); 123 124 switch (update) { 125 case net::MDnsListener::RECORD_ADDED: 126 AddService(rdata->ptrdomain()); 127 break; 128 case net::MDnsListener::RECORD_CHANGED: 129 NOTREACHED(); 130 break; 131 case net::MDnsListener::RECORD_REMOVED: 132 RemoveService(rdata->ptrdomain()); 133 break; 134 } 135 } else { 136 DCHECK(record->type() == net::dns_protocol::kTypeSRV || 137 record->type() == net::dns_protocol::kTypeTXT); 138 DCHECK(services_.find(record->name()) != services_.end()); 139 140 DeferUpdate(UPDATE_CHANGED, record->name()); 141 } 142 } 143 144 void ServiceWatcherImpl::OnCachePurged() { 145 // Not yet implemented. 146 } 147 148 void ServiceWatcherImpl::OnTransactionResponse( 149 scoped_ptr<net::MDnsTransaction>* transaction, 150 net::MDnsTransaction::Result result, 151 const net::RecordParsed* record) { 152 DCHECK(started_); 153 if (result == net::MDnsTransaction::RESULT_RECORD) { 154 const net::PtrRecordRdata* rdata = record->rdata<net::PtrRecordRdata>(); 155 DCHECK(rdata); 156 AddService(rdata->ptrdomain()); 157 } else if (result == net::MDnsTransaction::RESULT_DONE) { 158 transaction->reset(); 159 } 160 161 // Do nothing for NSEC records. It is an error for hosts to broadcast an NSEC 162 // record for PTR records on any name. 163 } 164 165 ServiceWatcherImpl::ServiceListeners::ServiceListeners( 166 const std::string& service_name, 167 ServiceWatcherImpl* watcher, 168 net::MDnsClient* mdns_client) : update_pending_(false) { 169 srv_listener_ = mdns_client->CreateListener( 170 net::dns_protocol::kTypeSRV, service_name, watcher); 171 txt_listener_ = mdns_client->CreateListener( 172 net::dns_protocol::kTypeTXT, service_name, watcher); 173 } 174 175 ServiceWatcherImpl::ServiceListeners::~ServiceListeners() { 176 } 177 178 bool ServiceWatcherImpl::ServiceListeners::Start() { 179 if (!srv_listener_->Start()) 180 return false; 181 return txt_listener_->Start(); 182 } 183 184 void ServiceWatcherImpl::AddService(const std::string& service) { 185 DCHECK(started_); 186 std::pair<ServiceListenersMap::iterator, bool> found = services_.insert( 187 make_pair(service, linked_ptr<ServiceListeners>(NULL))); 188 if (found.second) { // Newly inserted. 189 found.first->second = linked_ptr<ServiceListeners>( 190 new ServiceListeners(service, this, mdns_client_)); 191 bool success = found.first->second->Start(); 192 193 DeferUpdate(UPDATE_ADDED, service); 194 195 DCHECK(success); 196 } 197 } 198 199 void ServiceWatcherImpl::DeferUpdate(ServiceWatcher::UpdateType update_type, 200 const std::string& service_name) { 201 ServiceListenersMap::iterator found = services_.find(service_name); 202 203 if (found != services_.end() && !found->second->update_pending()) { 204 found->second->set_update_pending(true); 205 base::MessageLoop::current()->PostTask( 206 FROM_HERE, 207 base::Bind(&ServiceWatcherImpl::DeliverDeferredUpdate, AsWeakPtr(), 208 update_type, service_name)); 209 } 210 } 211 212 void ServiceWatcherImpl::DeliverDeferredUpdate( 213 ServiceWatcher::UpdateType update_type, const std::string& service_name) { 214 ServiceListenersMap::iterator found = services_.find(service_name); 215 216 if (found != services_.end()) { 217 found->second->set_update_pending(false); 218 if (!callback_.is_null()) 219 callback_.Run(update_type, service_name); 220 } 221 } 222 223 void ServiceWatcherImpl::RemoveService(const std::string& service) { 224 DCHECK(started_); 225 ServiceListenersMap::iterator found = services_.find(service); 226 if (found != services_.end()) { 227 services_.erase(found); 228 if (!callback_.is_null()) 229 callback_.Run(UPDATE_REMOVED, service); 230 } 231 } 232 233 void ServiceWatcherImpl::OnNsecRecord(const std::string& name, 234 unsigned rrtype) { 235 // Do nothing. It is an error for hosts to broadcast an NSEC record for PTR 236 // on any name. 237 } 238 239 void ServiceWatcherImpl::ScheduleQuery(int timeout_seconds) { 240 if (timeout_seconds <= kMaxRequeryTimeSeconds) { 241 base::MessageLoop::current()->PostDelayedTask( 242 FROM_HERE, 243 base::Bind(&ServiceWatcherImpl::SendQuery, 244 AsWeakPtr(), 245 timeout_seconds * 2 /*next_timeout_seconds*/, 246 false /*force_update*/), 247 base::TimeDelta::FromSeconds(timeout_seconds)); 248 } 249 } 250 251 void ServiceWatcherImpl::SendQuery(int next_timeout_seconds, 252 bool force_update) { 253 CreateTransaction(true /*network*/, false /*cache*/, force_update, 254 &transaction_network_); 255 ScheduleQuery(next_timeout_seconds); 256 } 257 258 ServiceResolverImpl::ServiceResolverImpl( 259 const std::string& service_name, 260 const ResolveCompleteCallback& callback, 261 net::MDnsClient* mdns_client) 262 : service_name_(service_name), callback_(callback), 263 metadata_resolved_(false), address_resolved_(false), 264 mdns_client_(mdns_client) { 265 } 266 267 void ServiceResolverImpl::StartResolving() { 268 address_resolved_ = false; 269 metadata_resolved_ = false; 270 service_staging_ = ServiceDescription(); 271 service_staging_.service_name = service_name_; 272 273 if (!CreateTxtTransaction() || !CreateSrvTransaction()) { 274 ServiceNotFound(ServiceResolver::STATUS_REQUEST_TIMEOUT); 275 } 276 } 277 278 ServiceResolverImpl::~ServiceResolverImpl() { 279 } 280 281 bool ServiceResolverImpl::CreateTxtTransaction() { 282 txt_transaction_ = mdns_client_->CreateTransaction( 283 net::dns_protocol::kTypeTXT, service_name_, 284 net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE | 285 net::MDnsTransaction::QUERY_NETWORK, 286 base::Bind(&ServiceResolverImpl::TxtRecordTransactionResponse, 287 AsWeakPtr())); 288 return txt_transaction_->Start(); 289 } 290 291 // TODO(noamsml): quick-resolve for AAAA records. Since A records tend to be in 292 void ServiceResolverImpl::CreateATransaction() { 293 a_transaction_ = mdns_client_->CreateTransaction( 294 net::dns_protocol::kTypeA, 295 service_staging_.address.host(), 296 net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE, 297 base::Bind(&ServiceResolverImpl::ARecordTransactionResponse, 298 AsWeakPtr())); 299 a_transaction_->Start(); 300 } 301 302 bool ServiceResolverImpl::CreateSrvTransaction() { 303 srv_transaction_ = mdns_client_->CreateTransaction( 304 net::dns_protocol::kTypeSRV, service_name_, 305 net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE | 306 net::MDnsTransaction::QUERY_NETWORK, 307 base::Bind(&ServiceResolverImpl::SrvRecordTransactionResponse, 308 AsWeakPtr())); 309 return srv_transaction_->Start(); 310 } 311 312 std::string ServiceResolverImpl::GetName() const { 313 return service_name_; 314 } 315 316 void ServiceResolverImpl::SrvRecordTransactionResponse( 317 net::MDnsTransaction::Result status, const net::RecordParsed* record) { 318 srv_transaction_.reset(); 319 if (status == net::MDnsTransaction::RESULT_RECORD) { 320 DCHECK(record); 321 service_staging_.address = RecordToAddress(record); 322 service_staging_.last_seen = record->time_created(); 323 CreateATransaction(); 324 } else { 325 ServiceNotFound(MDnsStatusToRequestStatus(status)); 326 } 327 } 328 329 void ServiceResolverImpl::TxtRecordTransactionResponse( 330 net::MDnsTransaction::Result status, const net::RecordParsed* record) { 331 txt_transaction_.reset(); 332 if (status == net::MDnsTransaction::RESULT_RECORD) { 333 DCHECK(record); 334 service_staging_.metadata = RecordToMetadata(record); 335 } else { 336 service_staging_.metadata = std::vector<std::string>(); 337 } 338 339 metadata_resolved_ = true; 340 AlertCallbackIfReady(); 341 } 342 343 void ServiceResolverImpl::ARecordTransactionResponse( 344 net::MDnsTransaction::Result status, const net::RecordParsed* record) { 345 a_transaction_.reset(); 346 347 if (status == net::MDnsTransaction::RESULT_RECORD) { 348 DCHECK(record); 349 service_staging_.ip_address = RecordToIPAddress(record); 350 } else { 351 service_staging_.ip_address = net::IPAddressNumber(); 352 } 353 354 address_resolved_ = true; 355 AlertCallbackIfReady(); 356 } 357 358 void ServiceResolverImpl::AlertCallbackIfReady() { 359 if (metadata_resolved_ && address_resolved_) { 360 txt_transaction_.reset(); 361 srv_transaction_.reset(); 362 a_transaction_.reset(); 363 if (!callback_.is_null()) 364 callback_.Run(STATUS_SUCCESS, service_staging_); 365 } 366 } 367 368 void ServiceResolverImpl::ServiceNotFound( 369 ServiceResolver::RequestStatus status) { 370 txt_transaction_.reset(); 371 srv_transaction_.reset(); 372 a_transaction_.reset(); 373 if (!callback_.is_null()) 374 callback_.Run(status, ServiceDescription()); 375 } 376 377 ServiceResolver::RequestStatus ServiceResolverImpl::MDnsStatusToRequestStatus( 378 net::MDnsTransaction::Result status) const { 379 switch (status) { 380 case net::MDnsTransaction::RESULT_RECORD: 381 return ServiceResolver::STATUS_SUCCESS; 382 case net::MDnsTransaction::RESULT_NO_RESULTS: 383 return ServiceResolver::STATUS_REQUEST_TIMEOUT; 384 case net::MDnsTransaction::RESULT_NSEC: 385 return ServiceResolver::STATUS_KNOWN_NONEXISTENT; 386 case net::MDnsTransaction::RESULT_DONE: // Pass through. 387 default: 388 NOTREACHED(); 389 return ServiceResolver::STATUS_REQUEST_TIMEOUT; 390 } 391 } 392 393 const std::vector<std::string>& ServiceResolverImpl::RecordToMetadata( 394 const net::RecordParsed* record) const { 395 DCHECK(record->type() == net::dns_protocol::kTypeTXT); 396 const net::TxtRecordRdata* txt_rdata = record->rdata<net::TxtRecordRdata>(); 397 DCHECK(txt_rdata); 398 return txt_rdata->texts(); 399 } 400 401 net::HostPortPair ServiceResolverImpl::RecordToAddress( 402 const net::RecordParsed* record) const { 403 DCHECK(record->type() == net::dns_protocol::kTypeSRV); 404 const net::SrvRecordRdata* srv_rdata = record->rdata<net::SrvRecordRdata>(); 405 DCHECK(srv_rdata); 406 return net::HostPortPair(srv_rdata->target(), srv_rdata->port()); 407 } 408 409 const net::IPAddressNumber& ServiceResolverImpl::RecordToIPAddress( 410 const net::RecordParsed* record) const { 411 DCHECK(record->type() == net::dns_protocol::kTypeA); 412 const net::ARecordRdata* a_rdata = record->rdata<net::ARecordRdata>(); 413 DCHECK(a_rdata); 414 return a_rdata->address(); 415 } 416 417 LocalDomainResolverImpl::LocalDomainResolverImpl( 418 const std::string& domain, 419 net::AddressFamily address_family, 420 const IPAddressCallback& callback, 421 net::MDnsClient* mdns_client) 422 : domain_(domain), address_family_(address_family), callback_(callback), 423 transactions_finished_(0), mdns_client_(mdns_client) { 424 } 425 426 LocalDomainResolverImpl::~LocalDomainResolverImpl() { 427 timeout_callback_.Cancel(); 428 } 429 430 void LocalDomainResolverImpl::Start() { 431 if (address_family_ == net::ADDRESS_FAMILY_IPV4 || 432 address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) { 433 transaction_a_ = CreateTransaction(net::dns_protocol::kTypeA); 434 transaction_a_->Start(); 435 } 436 437 if (address_family_ == net::ADDRESS_FAMILY_IPV6 || 438 address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) { 439 transaction_aaaa_ = CreateTransaction(net::dns_protocol::kTypeAAAA); 440 transaction_aaaa_->Start(); 441 } 442 } 443 444 scoped_ptr<net::MDnsTransaction> LocalDomainResolverImpl::CreateTransaction( 445 uint16 type) { 446 return mdns_client_->CreateTransaction( 447 type, domain_, net::MDnsTransaction::SINGLE_RESULT | 448 net::MDnsTransaction::QUERY_CACHE | 449 net::MDnsTransaction::QUERY_NETWORK, 450 base::Bind(&LocalDomainResolverImpl::OnTransactionComplete, 451 base::Unretained(this))); 452 } 453 454 void LocalDomainResolverImpl::OnTransactionComplete( 455 net::MDnsTransaction::Result result, const net::RecordParsed* record) { 456 transactions_finished_++; 457 458 if (result == net::MDnsTransaction::RESULT_RECORD) { 459 if (record->type() == net::dns_protocol::kTypeA) { 460 const net::ARecordRdata* rdata = record->rdata<net::ARecordRdata>(); 461 address_ipv4_ = rdata->address(); 462 } else { 463 DCHECK_EQ(net::dns_protocol::kTypeAAAA, record->type()); 464 const net::AAAARecordRdata* rdata = record->rdata<net::AAAARecordRdata>(); 465 address_ipv6_ = rdata->address(); 466 } 467 } 468 469 if (transactions_finished_ == 1 && 470 address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) { 471 timeout_callback_.Reset(base::Bind( 472 &LocalDomainResolverImpl::SendResolvedAddresses, 473 base::Unretained(this))); 474 475 base::MessageLoop::current()->PostDelayedTask( 476 FROM_HERE, 477 timeout_callback_.callback(), 478 base::TimeDelta::FromMilliseconds(kLocalDomainSecondAddressTimeoutMs)); 479 } else if (transactions_finished_ == 2 480 || address_family_ != net::ADDRESS_FAMILY_UNSPECIFIED) { 481 SendResolvedAddresses(); 482 } 483 } 484 485 bool LocalDomainResolverImpl::IsSuccess() { 486 return !address_ipv4_.empty() || !address_ipv6_.empty(); 487 } 488 489 void LocalDomainResolverImpl::SendResolvedAddresses() { 490 transaction_a_.reset(); 491 transaction_aaaa_.reset(); 492 timeout_callback_.Cancel(); 493 callback_.Run(IsSuccess(), address_ipv4_, address_ipv6_); 494 } 495 496 } // namespace local_discovery 497