1 // Copyright 2013 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include <utility> 6 7 #include "base/logging.h" 8 #include "base/memory/singleton.h" 9 #include "base/message_loop/message_loop_proxy.h" 10 #include "base/stl_util.h" 11 #include "chrome/utility/local_discovery/service_discovery_client_impl.h" 12 #include "net/dns/dns_protocol.h" 13 #include "net/dns/record_rdata.h" 14 15 namespace local_discovery { 16 17 ServiceDiscoveryClientImpl::ServiceDiscoveryClientImpl( 18 net::MDnsClient* mdns_client) : mdns_client_(mdns_client) { 19 } 20 21 ServiceDiscoveryClientImpl::~ServiceDiscoveryClientImpl() { 22 } 23 24 scoped_ptr<ServiceWatcher> ServiceDiscoveryClientImpl::CreateServiceWatcher( 25 const std::string& service_type, 26 const ServiceWatcher::UpdatedCallback& callback) { 27 return scoped_ptr<ServiceWatcher>(new ServiceWatcherImpl( 28 service_type, callback, mdns_client_)); 29 } 30 31 scoped_ptr<ServiceResolver> ServiceDiscoveryClientImpl::CreateServiceResolver( 32 const std::string& service_name, 33 const ServiceResolver::ResolveCompleteCallback& callback) { 34 return scoped_ptr<ServiceResolver>(new ServiceResolverImpl( 35 service_name, callback, mdns_client_)); 36 } 37 38 scoped_ptr<LocalDomainResolver> 39 ServiceDiscoveryClientImpl::CreateLocalDomainResolver( 40 const std::string& domain, 41 net::AddressFamily address_family, 42 const LocalDomainResolver::IPAddressCallback& callback) { 43 return scoped_ptr<LocalDomainResolver>(new LocalDomainResolverImpl( 44 domain, address_family, callback, mdns_client_)); 45 } 46 47 ServiceWatcherImpl::ServiceWatcherImpl( 48 const std::string& service_type, 49 const ServiceWatcher::UpdatedCallback& callback, 50 net::MDnsClient* mdns_client) 51 : service_type_(service_type), callback_(callback), started_(false), 52 mdns_client_(mdns_client) { 53 } 54 55 void ServiceWatcherImpl::Start() { 56 DCHECK(!started_); 57 listener_ = mdns_client_->CreateListener( 58 net::dns_protocol::kTypePTR, service_type_, this); 59 started_ = listener_->Start(); 60 if (started_) 61 ReadCachedServices(); 62 } 63 64 ServiceWatcherImpl::~ServiceWatcherImpl() { 65 } 66 67 void ServiceWatcherImpl::DiscoverNewServices(bool force_update) { 68 DCHECK(started_); 69 if (force_update) 70 services_.clear(); 71 CreateTransaction(true /*network*/, false /*cache*/, force_update, 72 &transaction_network_); 73 } 74 75 void ServiceWatcherImpl::ReadCachedServices() { 76 DCHECK(started_); 77 CreateTransaction(false /*network*/, true /*cache*/, false /*force refresh*/, 78 &transaction_cache_); 79 } 80 81 bool ServiceWatcherImpl::CreateTransaction( 82 bool network, bool cache, bool force_refresh, 83 scoped_ptr<net::MDnsTransaction>* transaction) { 84 int transaction_flags = 0; 85 if (network) 86 transaction_flags |= net::MDnsTransaction::QUERY_NETWORK; 87 88 if (cache) 89 transaction_flags |= net::MDnsTransaction::QUERY_CACHE; 90 91 // TODO(noamsml): Add flag for force_refresh when supported. 92 93 if (transaction_flags) { 94 *transaction = mdns_client_->CreateTransaction( 95 net::dns_protocol::kTypePTR, service_type_, transaction_flags, 96 base::Bind(&ServiceWatcherImpl::OnTransactionResponse, 97 base::Unretained(this), transaction)); 98 return (*transaction)->Start(); 99 } 100 101 return true; 102 } 103 104 std::string ServiceWatcherImpl::GetServiceType() const { 105 return listener_->GetName(); 106 } 107 108 void ServiceWatcherImpl::OnRecordUpdate( 109 net::MDnsListener::UpdateType update, 110 const net::RecordParsed* record) { 111 DCHECK(started_); 112 if (record->type() == net::dns_protocol::kTypePTR) { 113 DCHECK(record->name() == GetServiceType()); 114 const net::PtrRecordRdata* rdata = record->rdata<net::PtrRecordRdata>(); 115 116 switch (update) { 117 case net::MDnsListener::RECORD_ADDED: 118 AddService(rdata->ptrdomain()); 119 break; 120 case net::MDnsListener::RECORD_CHANGED: 121 NOTREACHED(); 122 break; 123 case net::MDnsListener::RECORD_REMOVED: 124 RemoveService(rdata->ptrdomain()); 125 break; 126 } 127 } else { 128 DCHECK(record->type() == net::dns_protocol::kTypeSRV || 129 record->type() == net::dns_protocol::kTypeTXT); 130 DCHECK(services_.find(record->name()) != services_.end()); 131 132 DeferUpdate(UPDATE_CHANGED, record->name()); 133 } 134 } 135 136 void ServiceWatcherImpl::OnCachePurged() { 137 // Not yet implemented. 138 } 139 140 void ServiceWatcherImpl::OnTransactionResponse( 141 scoped_ptr<net::MDnsTransaction>* transaction, 142 net::MDnsTransaction::Result result, 143 const net::RecordParsed* record) { 144 DCHECK(started_); 145 if (result == net::MDnsTransaction::RESULT_RECORD) { 146 const net::PtrRecordRdata* rdata = record->rdata<net::PtrRecordRdata>(); 147 DCHECK(rdata); 148 AddService(rdata->ptrdomain()); 149 } else if (result == net::MDnsTransaction::RESULT_DONE) { 150 transaction->reset(); 151 } 152 153 // Do nothing for NSEC records. It is an error for hosts to broadcast an NSEC 154 // record for PTR records on any name. 155 } 156 157 ServiceWatcherImpl::ServiceListeners::ServiceListeners( 158 const std::string& service_name, 159 ServiceWatcherImpl* watcher, 160 net::MDnsClient* mdns_client) : update_pending_(false) { 161 srv_listener_ = mdns_client->CreateListener( 162 net::dns_protocol::kTypeSRV, service_name, watcher); 163 txt_listener_ = mdns_client->CreateListener( 164 net::dns_protocol::kTypeTXT, service_name, watcher); 165 } 166 167 ServiceWatcherImpl::ServiceListeners::~ServiceListeners() { 168 } 169 170 bool ServiceWatcherImpl::ServiceListeners::Start() { 171 if (!srv_listener_->Start()) 172 return false; 173 return txt_listener_->Start(); 174 } 175 176 void ServiceWatcherImpl::AddService(const std::string& service) { 177 DCHECK(started_); 178 std::pair<ServiceListenersMap::iterator, bool> found = services_.insert( 179 make_pair(service, linked_ptr<ServiceListeners>(NULL))); 180 if (found.second) { // Newly inserted. 181 found.first->second = linked_ptr<ServiceListeners>( 182 new ServiceListeners(service, this, mdns_client_)); 183 bool success = found.first->second->Start(); 184 185 DeferUpdate(UPDATE_ADDED, service); 186 187 DCHECK(success); 188 } 189 } 190 191 void ServiceWatcherImpl::DeferUpdate(ServiceWatcher::UpdateType update_type, 192 const std::string& service_name) { 193 ServiceListenersMap::iterator found = services_.find(service_name); 194 195 if (found != services_.end() && !found->second->update_pending()) { 196 found->second->set_update_pending(true); 197 base::MessageLoop::current()->PostTask( 198 FROM_HERE, 199 base::Bind(&ServiceWatcherImpl::DeliverDeferredUpdate, AsWeakPtr(), 200 update_type, service_name)); 201 } 202 } 203 204 void ServiceWatcherImpl::DeliverDeferredUpdate( 205 ServiceWatcher::UpdateType update_type, const std::string& service_name) { 206 ServiceListenersMap::iterator found = services_.find(service_name); 207 208 if (found != services_.end()) { 209 found->second->set_update_pending(false); 210 if (!callback_.is_null()) 211 callback_.Run(update_type, service_name); 212 } 213 } 214 215 void ServiceWatcherImpl::RemoveService(const std::string& service) { 216 DCHECK(started_); 217 ServiceListenersMap::iterator found = services_.find(service); 218 if (found != services_.end()) { 219 services_.erase(found); 220 if (!callback_.is_null()) 221 callback_.Run(UPDATE_REMOVED, service); 222 } 223 } 224 225 void ServiceWatcherImpl::OnNsecRecord(const std::string& name, 226 unsigned rrtype) { 227 // Do nothing. It is an error for hosts to broadcast an NSEC record for PTR 228 // on any name. 229 } 230 231 ServiceResolverImpl::ServiceResolverImpl( 232 const std::string& service_name, 233 const ResolveCompleteCallback& callback, 234 net::MDnsClient* mdns_client) 235 : service_name_(service_name), callback_(callback), 236 metadata_resolved_(false), address_resolved_(false), 237 mdns_client_(mdns_client) { 238 service_staging_.service_name = service_name_; 239 } 240 241 void ServiceResolverImpl::StartResolving() { 242 address_resolved_ = false; 243 metadata_resolved_ = false; 244 245 if (!CreateTxtTransaction() || !CreateSrvTransaction()) { 246 ServiceNotFound(ServiceResolver::STATUS_REQUEST_TIMEOUT); 247 } 248 } 249 250 ServiceResolverImpl::~ServiceResolverImpl() { 251 } 252 253 bool ServiceResolverImpl::CreateTxtTransaction() { 254 txt_transaction_ = mdns_client_->CreateTransaction( 255 net::dns_protocol::kTypeTXT, service_name_, 256 net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE | 257 net::MDnsTransaction::QUERY_NETWORK, 258 base::Bind(&ServiceResolverImpl::TxtRecordTransactionResponse, 259 AsWeakPtr())); 260 return txt_transaction_->Start(); 261 } 262 263 // TODO(noamsml): quick-resolve for AAAA records. Since A records tend to be in 264 void ServiceResolverImpl::CreateATransaction() { 265 a_transaction_ = mdns_client_->CreateTransaction( 266 net::dns_protocol::kTypeA, 267 service_staging_.address.host(), 268 net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE, 269 base::Bind(&ServiceResolverImpl::ARecordTransactionResponse, 270 AsWeakPtr())); 271 a_transaction_->Start(); 272 } 273 274 bool ServiceResolverImpl::CreateSrvTransaction() { 275 srv_transaction_ = mdns_client_->CreateTransaction( 276 net::dns_protocol::kTypeSRV, service_name_, 277 net::MDnsTransaction::SINGLE_RESULT | net::MDnsTransaction::QUERY_CACHE | 278 net::MDnsTransaction::QUERY_NETWORK, 279 base::Bind(&ServiceResolverImpl::SrvRecordTransactionResponse, 280 AsWeakPtr())); 281 return srv_transaction_->Start(); 282 } 283 284 std::string ServiceResolverImpl::GetName() const { 285 return service_name_; 286 } 287 288 void ServiceResolverImpl::SrvRecordTransactionResponse( 289 net::MDnsTransaction::Result status, const net::RecordParsed* record) { 290 srv_transaction_.reset(); 291 if (status == net::MDnsTransaction::RESULT_RECORD) { 292 DCHECK(record); 293 service_staging_.address = RecordToAddress(record); 294 service_staging_.last_seen = record->time_created(); 295 CreateATransaction(); 296 } else { 297 ServiceNotFound(MDnsStatusToRequestStatus(status)); 298 } 299 } 300 301 void ServiceResolverImpl::TxtRecordTransactionResponse( 302 net::MDnsTransaction::Result status, const net::RecordParsed* record) { 303 txt_transaction_.reset(); 304 if (status == net::MDnsTransaction::RESULT_RECORD) { 305 DCHECK(record); 306 service_staging_.metadata = RecordToMetadata(record); 307 } else { 308 service_staging_.metadata = std::vector<std::string>(); 309 } 310 311 metadata_resolved_ = true; 312 AlertCallbackIfReady(); 313 } 314 315 void ServiceResolverImpl::ARecordTransactionResponse( 316 net::MDnsTransaction::Result status, const net::RecordParsed* record) { 317 a_transaction_.reset(); 318 319 if (status == net::MDnsTransaction::RESULT_RECORD) { 320 DCHECK(record); 321 service_staging_.ip_address = RecordToIPAddress(record); 322 } else { 323 service_staging_.ip_address = net::IPAddressNumber(); 324 } 325 326 address_resolved_ = true; 327 AlertCallbackIfReady(); 328 } 329 330 void ServiceResolverImpl::AlertCallbackIfReady() { 331 if (metadata_resolved_ && address_resolved_) { 332 txt_transaction_.reset(); 333 srv_transaction_.reset(); 334 a_transaction_.reset(); 335 if (!callback_.is_null()) 336 callback_.Run(STATUS_SUCCESS, service_staging_); 337 service_staging_ = ServiceDescription(); 338 } 339 } 340 341 void ServiceResolverImpl::ServiceNotFound( 342 ServiceResolver::RequestStatus status) { 343 txt_transaction_.reset(); 344 srv_transaction_.reset(); 345 a_transaction_.reset(); 346 if (!callback_.is_null()) 347 callback_.Run(status, ServiceDescription()); 348 } 349 350 ServiceResolver::RequestStatus ServiceResolverImpl::MDnsStatusToRequestStatus( 351 net::MDnsTransaction::Result status) const { 352 switch (status) { 353 case net::MDnsTransaction::RESULT_RECORD: 354 return ServiceResolver::STATUS_SUCCESS; 355 case net::MDnsTransaction::RESULT_NO_RESULTS: 356 return ServiceResolver::STATUS_REQUEST_TIMEOUT; 357 case net::MDnsTransaction::RESULT_NSEC: 358 return ServiceResolver::STATUS_KNOWN_NONEXISTENT; 359 case net::MDnsTransaction::RESULT_DONE: // Pass through. 360 default: 361 NOTREACHED(); 362 return ServiceResolver::STATUS_REQUEST_TIMEOUT; 363 } 364 } 365 366 const std::vector<std::string>& ServiceResolverImpl::RecordToMetadata( 367 const net::RecordParsed* record) const { 368 DCHECK(record->type() == net::dns_protocol::kTypeTXT); 369 const net::TxtRecordRdata* txt_rdata = record->rdata<net::TxtRecordRdata>(); 370 DCHECK(txt_rdata); 371 return txt_rdata->texts(); 372 } 373 374 net::HostPortPair ServiceResolverImpl::RecordToAddress( 375 const net::RecordParsed* record) const { 376 DCHECK(record->type() == net::dns_protocol::kTypeSRV); 377 const net::SrvRecordRdata* srv_rdata = record->rdata<net::SrvRecordRdata>(); 378 DCHECK(srv_rdata); 379 return net::HostPortPair(srv_rdata->target(), srv_rdata->port()); 380 } 381 382 const net::IPAddressNumber& ServiceResolverImpl::RecordToIPAddress( 383 const net::RecordParsed* record) const { 384 DCHECK(record->type() == net::dns_protocol::kTypeA); 385 const net::ARecordRdata* a_rdata = record->rdata<net::ARecordRdata>(); 386 DCHECK(a_rdata); 387 return a_rdata->address(); 388 } 389 390 LocalDomainResolverImpl::LocalDomainResolverImpl( 391 const std::string& domain, 392 net::AddressFamily address_family, 393 const IPAddressCallback& callback, 394 net::MDnsClient* mdns_client) 395 : domain_(domain), address_family_(address_family), callback_(callback), 396 transaction_failures_(0), mdns_client_(mdns_client) { 397 } 398 399 LocalDomainResolverImpl::~LocalDomainResolverImpl() { 400 } 401 402 void LocalDomainResolverImpl::Start() { 403 if (address_family_ == net::ADDRESS_FAMILY_IPV4 || 404 address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) { 405 transaction_a_ = CreateTransaction(net::dns_protocol::kTypeA); 406 transaction_a_->Start(); 407 } 408 409 if (address_family_ == net::ADDRESS_FAMILY_IPV6 || 410 address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) { 411 transaction_aaaa_ = CreateTransaction(net::dns_protocol::kTypeAAAA); 412 transaction_aaaa_->Start(); 413 } 414 } 415 416 scoped_ptr<net::MDnsTransaction> LocalDomainResolverImpl::CreateTransaction( 417 uint16 type) { 418 return mdns_client_->CreateTransaction( 419 type, domain_, net::MDnsTransaction::SINGLE_RESULT | 420 net::MDnsTransaction::QUERY_CACHE | 421 net::MDnsTransaction::QUERY_NETWORK, 422 base::Bind(&LocalDomainResolverImpl::OnTransactionComplete, 423 base::Unretained(this))); 424 } 425 426 void LocalDomainResolverImpl::OnTransactionComplete( 427 net::MDnsTransaction::Result result, const net::RecordParsed* record) { 428 if (result != net::MDnsTransaction::RESULT_RECORD && 429 address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED) { 430 transaction_failures_++; 431 432 if (transaction_failures_ < 2) { 433 return; 434 } 435 } 436 437 transaction_a_.reset(); 438 transaction_aaaa_.reset(); 439 440 net::IPAddressNumber address; 441 if (result == net::MDnsTransaction::RESULT_RECORD) { 442 if (record->type() == net::dns_protocol::kTypeA) { 443 const net::ARecordRdata* rdata = record->rdata<net::ARecordRdata>(); 444 address = rdata->address(); 445 } else { 446 DCHECK_EQ(net::dns_protocol::kTypeAAAA, record->type()); 447 const net::AAAARecordRdata* rdata = record->rdata<net::AAAARecordRdata>(); 448 address = rdata->address(); 449 } 450 } 451 452 callback_.Run(result == net::MDnsTransaction::RESULT_RECORD, address); 453 } 454 455 } // namespace local_discovery 456