1 // Copyright 2014 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "base/memory/weak_ptr.h" 6 #include "base/run_loop.h" 7 #include "chrome/common/local_discovery/service_discovery_client_impl.h" 8 #include "net/base/net_errors.h" 9 #include "net/dns/dns_protocol.h" 10 #include "net/dns/mdns_client_impl.h" 11 #include "net/dns/mock_mdns_socket_factory.h" 12 #include "testing/gmock/include/gmock/gmock.h" 13 #include "testing/gtest/include/gtest/gtest.h" 14 15 using ::testing::_; 16 using ::testing::Invoke; 17 using ::testing::StrictMock; 18 using ::testing::NiceMock; 19 using ::testing::Mock; 20 using ::testing::SaveArg; 21 using ::testing::SetArgPointee; 22 using ::testing::Return; 23 using ::testing::Exactly; 24 25 namespace local_discovery { 26 27 namespace { 28 29 const uint8 kSamplePacketPTR[] = { 30 // Header 31 0x00, 0x00, // ID is zeroed out 32 0x81, 0x80, // Standard query response, RA, no error 33 0x00, 0x00, // No questions (for simplicity) 34 0x00, 0x01, // 1 RR (answers) 35 0x00, 0x00, // 0 authority RRs 36 0x00, 0x00, // 0 additional RRs 37 38 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', 39 0x04, '_', 't', 'c', 'p', 40 0x05, 'l', 'o', 'c', 'a', 'l', 41 0x00, 42 0x00, 0x0c, // TYPE is PTR. 43 0x00, 0x01, // CLASS is IN. 44 0x00, 0x00, // TTL (4 bytes) is 1 second. 45 0x00, 0x01, 46 0x00, 0x08, // RDLENGTH is 8 bytes. 47 0x05, 'h', 'e', 'l', 'l', 'o', 48 0xc0, 0x0c 49 }; 50 51 const uint8 kSamplePacketSRV[] = { 52 // Header 53 0x00, 0x00, // ID is zeroed out 54 0x81, 0x80, // Standard query response, RA, no error 55 0x00, 0x00, // No questions (for simplicity) 56 0x00, 0x01, // 1 RR (answers) 57 0x00, 0x00, // 0 authority RRs 58 0x00, 0x00, // 0 additional RRs 59 60 0x05, 'h', 'e', 'l', 'l', 'o', 61 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', 62 0x04, '_', 't', 'c', 'p', 63 0x05, 'l', 'o', 'c', 'a', 'l', 64 0x00, 65 0x00, 0x21, // TYPE is SRV. 66 0x00, 0x01, // CLASS is IN. 67 0x00, 0x00, // TTL (4 bytes) is 1 second. 68 0x00, 0x01, 69 0x00, 0x15, // RDLENGTH is 21 bytes. 70 0x00, 0x00, 71 0x00, 0x00, 72 0x22, 0xb8, // port 8888 73 0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o', 74 0x05, 'l', 'o', 'c', 'a', 'l', 75 0x00, 76 }; 77 78 const uint8 kSamplePacketTXT[] = { 79 // Header 80 0x00, 0x00, // ID is zeroed out 81 0x81, 0x80, // Standard query response, RA, no error 82 0x00, 0x00, // No questions (for simplicity) 83 0x00, 0x01, // 1 RR (answers) 84 0x00, 0x00, // 0 authority RRs 85 0x00, 0x00, // 0 additional RRs 86 87 0x05, 'h', 'e', 'l', 'l', 'o', 88 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', 89 0x04, '_', 't', 'c', 'p', 90 0x05, 'l', 'o', 'c', 'a', 'l', 91 0x00, 92 0x00, 0x10, // TYPE is PTR. 93 0x00, 0x01, // CLASS is IN. 94 0x00, 0x00, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds. 95 0x00, 0x01, 96 0x00, 0x06, // RDLENGTH is 21 bytes. 97 0x05, 'h', 'e', 'l', 'l', 'o' 98 }; 99 100 const uint8 kSamplePacketSRVA[] = { 101 // Header 102 0x00, 0x00, // ID is zeroed out 103 0x81, 0x80, // Standard query response, RA, no error 104 0x00, 0x00, // No questions (for simplicity) 105 0x00, 0x02, // 2 RR (answers) 106 0x00, 0x00, // 0 authority RRs 107 0x00, 0x00, // 0 additional RRs 108 109 0x05, 'h', 'e', 'l', 'l', 'o', 110 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', 111 0x04, '_', 't', 'c', 'p', 112 0x05, 'l', 'o', 'c', 'a', 'l', 113 0x00, 114 0x00, 0x21, // TYPE is SRV. 115 0x00, 0x01, // CLASS is IN. 116 0x00, 0x00, // TTL (4 bytes) is 16 seconds. 117 0x00, 0x10, 118 0x00, 0x15, // RDLENGTH is 21 bytes. 119 0x00, 0x00, 120 0x00, 0x00, 121 0x22, 0xb8, // port 8888 122 0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o', 123 0x05, 'l', 'o', 'c', 'a', 'l', 124 0x00, 125 126 0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o', 127 0x05, 'l', 'o', 'c', 'a', 'l', 128 0x00, 129 0x00, 0x01, // TYPE is A. 130 0x00, 0x01, // CLASS is IN. 131 0x00, 0x00, // TTL (4 bytes) is 16 seconds. 132 0x00, 0x10, 133 0x00, 0x04, // RDLENGTH is 4 bytes. 134 0x01, 0x02, 135 0x03, 0x04, 136 }; 137 138 const uint8 kSamplePacketPTR2[] = { 139 // Header 140 0x00, 0x00, // ID is zeroed out 141 0x81, 0x80, // Standard query response, RA, no error 142 0x00, 0x00, // No questions (for simplicity) 143 0x00, 0x02, // 2 RR (answers) 144 0x00, 0x00, // 0 authority RRs 145 0x00, 0x00, // 0 additional RRs 146 147 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', 148 0x04, '_', 't', 'c', 'p', 149 0x05, 'l', 'o', 'c', 'a', 'l', 150 0x00, 151 0x00, 0x0c, // TYPE is PTR. 152 0x00, 0x01, // CLASS is IN. 153 0x02, 0x00, // TTL (4 bytes) is 1 second. 154 0x00, 0x01, 155 0x00, 0x08, // RDLENGTH is 8 bytes. 156 0x05, 'g', 'd', 'b', 'y', 'e', 157 0xc0, 0x0c, 158 159 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', 160 0x04, '_', 't', 'c', 'p', 161 0x05, 'l', 'o', 'c', 'a', 'l', 162 0x00, 163 0x00, 0x0c, // TYPE is PTR. 164 0x00, 0x01, // CLASS is IN. 165 0x02, 0x00, // TTL (4 bytes) is 1 second. 166 0x00, 0x01, 167 0x00, 0x08, // RDLENGTH is 8 bytes. 168 0x05, 'h', 'e', 'l', 'l', 'o', 169 0xc0, 0x0c 170 }; 171 172 const uint8 kSamplePacketQuerySRV[] = { 173 // Header 174 0x00, 0x00, // ID is zeroed out 175 0x00, 0x00, // No flags. 176 0x00, 0x01, // One question. 177 0x00, 0x00, // 0 RRs (answers) 178 0x00, 0x00, // 0 authority RRs 179 0x00, 0x00, // 0 additional RRs 180 181 // Question 182 0x05, 'h', 'e', 'l', 'l', 'o', 183 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', 184 0x04, '_', 't', 'c', 'p', 185 0x05, 'l', 'o', 'c', 'a', 'l', 186 0x00, 187 0x00, 0x21, // TYPE is SRV. 188 0x00, 0x01, // CLASS is IN. 189 }; 190 191 192 class MockServiceWatcherClient { 193 public: 194 MOCK_METHOD2(OnServiceUpdated, 195 void(ServiceWatcher::UpdateType, const std::string&)); 196 197 ServiceWatcher::UpdatedCallback GetCallback() { 198 return base::Bind(&MockServiceWatcherClient::OnServiceUpdated, 199 base::Unretained(this)); 200 } 201 }; 202 203 class ServiceDiscoveryTest : public ::testing::Test { 204 public: 205 ServiceDiscoveryTest() 206 : service_discovery_client_(&mdns_client_) { 207 mdns_client_.StartListening(&socket_factory_); 208 } 209 210 virtual ~ServiceDiscoveryTest() { 211 } 212 213 protected: 214 void RunFor(base::TimeDelta time_period) { 215 base::CancelableCallback<void()> callback(base::Bind( 216 &ServiceDiscoveryTest::Stop, base::Unretained(this))); 217 base::MessageLoop::current()->PostDelayedTask( 218 FROM_HERE, callback.callback(), time_period); 219 220 base::MessageLoop::current()->Run(); 221 callback.Cancel(); 222 } 223 224 void Stop() { 225 base::MessageLoop::current()->Quit(); 226 } 227 228 net::MockMDnsSocketFactory socket_factory_; 229 net::MDnsClientImpl mdns_client_; 230 ServiceDiscoveryClientImpl service_discovery_client_; 231 base::MessageLoop loop_; 232 }; 233 234 TEST_F(ServiceDiscoveryTest, AddRemoveService) { 235 StrictMock<MockServiceWatcherClient> delegate; 236 237 scoped_ptr<ServiceWatcher> watcher( 238 service_discovery_client_.CreateServiceWatcher( 239 "_privet._tcp.local", delegate.GetCallback())); 240 241 watcher->Start(); 242 243 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED, 244 "hello._privet._tcp.local")) 245 .Times(Exactly(1)); 246 247 socket_factory_.SimulateReceive(kSamplePacketPTR, sizeof(kSamplePacketPTR)); 248 249 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_REMOVED, 250 "hello._privet._tcp.local")) 251 .Times(Exactly(1)); 252 253 RunFor(base::TimeDelta::FromSeconds(2)); 254 }; 255 256 TEST_F(ServiceDiscoveryTest, DiscoverNewServices) { 257 StrictMock<MockServiceWatcherClient> delegate; 258 259 scoped_ptr<ServiceWatcher> watcher( 260 service_discovery_client_.CreateServiceWatcher( 261 "_privet._tcp.local", delegate.GetCallback())); 262 263 watcher->Start(); 264 265 EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(2); 266 267 watcher->DiscoverNewServices(false); 268 269 EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(2); 270 271 RunFor(base::TimeDelta::FromSeconds(2)); 272 }; 273 274 TEST_F(ServiceDiscoveryTest, ReadCachedServices) { 275 socket_factory_.SimulateReceive(kSamplePacketPTR, sizeof(kSamplePacketPTR)); 276 277 StrictMock<MockServiceWatcherClient> delegate; 278 279 scoped_ptr<ServiceWatcher> watcher( 280 service_discovery_client_.CreateServiceWatcher( 281 "_privet._tcp.local", delegate.GetCallback())); 282 283 watcher->Start(); 284 285 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED, 286 "hello._privet._tcp.local")) 287 .Times(Exactly(1)); 288 289 base::MessageLoop::current()->RunUntilIdle(); 290 }; 291 292 293 TEST_F(ServiceDiscoveryTest, ReadCachedServicesMultiple) { 294 socket_factory_.SimulateReceive(kSamplePacketPTR2, sizeof(kSamplePacketPTR2)); 295 296 StrictMock<MockServiceWatcherClient> delegate; 297 scoped_ptr<ServiceWatcher> watcher = 298 service_discovery_client_.CreateServiceWatcher( 299 "_privet._tcp.local", delegate.GetCallback()); 300 301 watcher->Start(); 302 303 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED, 304 "hello._privet._tcp.local")) 305 .Times(Exactly(1)); 306 307 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED, 308 "gdbye._privet._tcp.local")) 309 .Times(Exactly(1)); 310 311 base::MessageLoop::current()->RunUntilIdle(); 312 }; 313 314 315 TEST_F(ServiceDiscoveryTest, OnServiceChanged) { 316 StrictMock<MockServiceWatcherClient> delegate; 317 scoped_ptr<ServiceWatcher> watcher( 318 service_discovery_client_.CreateServiceWatcher( 319 "_privet._tcp.local", delegate.GetCallback())); 320 321 watcher->Start(); 322 323 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED, 324 "hello._privet._tcp.local")) 325 .Times(Exactly(1)); 326 327 socket_factory_.SimulateReceive(kSamplePacketPTR, sizeof(kSamplePacketPTR)); 328 329 base::MessageLoop::current()->RunUntilIdle(); 330 331 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_CHANGED, 332 "hello._privet._tcp.local")) 333 .Times(Exactly(1)); 334 335 socket_factory_.SimulateReceive(kSamplePacketSRV, sizeof(kSamplePacketSRV)); 336 337 socket_factory_.SimulateReceive(kSamplePacketTXT, sizeof(kSamplePacketTXT)); 338 339 base::MessageLoop::current()->RunUntilIdle(); 340 }; 341 342 TEST_F(ServiceDiscoveryTest, SinglePacket) { 343 StrictMock<MockServiceWatcherClient> delegate; 344 scoped_ptr<ServiceWatcher> watcher( 345 service_discovery_client_.CreateServiceWatcher( 346 "_privet._tcp.local", delegate.GetCallback())); 347 348 watcher->Start(); 349 350 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED, 351 "hello._privet._tcp.local")) 352 .Times(Exactly(1)); 353 354 socket_factory_.SimulateReceive(kSamplePacketPTR, sizeof(kSamplePacketPTR)); 355 356 // Reset the "already updated" flag. 357 base::MessageLoop::current()->RunUntilIdle(); 358 359 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_CHANGED, 360 "hello._privet._tcp.local")) 361 .Times(Exactly(1)); 362 363 socket_factory_.SimulateReceive(kSamplePacketSRV, sizeof(kSamplePacketSRV)); 364 365 socket_factory_.SimulateReceive(kSamplePacketTXT, sizeof(kSamplePacketTXT)); 366 367 base::MessageLoop::current()->RunUntilIdle(); 368 }; 369 370 TEST_F(ServiceDiscoveryTest, ActivelyRefreshServices) { 371 StrictMock<MockServiceWatcherClient> delegate; 372 scoped_ptr<ServiceWatcher> watcher( 373 service_discovery_client_.CreateServiceWatcher( 374 "_privet._tcp.local", delegate.GetCallback())); 375 376 watcher->Start(); 377 watcher->SetActivelyRefreshServices(true); 378 379 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED, 380 "hello._privet._tcp.local")) 381 .Times(Exactly(1)); 382 383 std::string query_packet = std::string((const char*)(kSamplePacketQuerySRV), 384 sizeof(kSamplePacketQuerySRV)); 385 386 EXPECT_CALL(socket_factory_, OnSendTo(query_packet)) 387 .Times(2); 388 389 socket_factory_.SimulateReceive(kSamplePacketPTR, sizeof(kSamplePacketPTR)); 390 391 base::MessageLoop::current()->RunUntilIdle(); 392 393 socket_factory_.SimulateReceive(kSamplePacketSRV, sizeof(kSamplePacketSRV)); 394 395 EXPECT_CALL(socket_factory_, OnSendTo(query_packet)) 396 .Times(4); // IPv4 and IPv6 at 85% and 95% 397 398 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_REMOVED, 399 "hello._privet._tcp.local")) 400 .Times(Exactly(1)); 401 402 RunFor(base::TimeDelta::FromSeconds(2)); 403 404 base::MessageLoop::current()->RunUntilIdle(); 405 }; 406 407 408 class ServiceResolverTest : public ServiceDiscoveryTest { 409 public: 410 ServiceResolverTest() { 411 metadata_expected_.push_back("hello"); 412 address_expected_ = net::HostPortPair("myhello.local", 8888); 413 ip_address_expected_.push_back(1); 414 ip_address_expected_.push_back(2); 415 ip_address_expected_.push_back(3); 416 ip_address_expected_.push_back(4); 417 } 418 419 ~ServiceResolverTest() { 420 } 421 422 void SetUp() { 423 resolver_ = service_discovery_client_.CreateServiceResolver( 424 "hello._privet._tcp.local", 425 base::Bind(&ServiceResolverTest::OnFinishedResolving, 426 base::Unretained(this))); 427 } 428 429 void OnFinishedResolving(ServiceResolver::RequestStatus request_status, 430 const ServiceDescription& service_description) { 431 OnFinishedResolvingInternal(request_status, 432 service_description.address.ToString(), 433 service_description.metadata, 434 service_description.ip_address); 435 } 436 437 MOCK_METHOD4(OnFinishedResolvingInternal, 438 void(ServiceResolver::RequestStatus, 439 const std::string&, 440 const std::vector<std::string>&, 441 const net::IPAddressNumber&)); 442 443 protected: 444 scoped_ptr<ServiceResolver> resolver_; 445 net::IPAddressNumber ip_address_; 446 net::HostPortPair address_expected_; 447 std::vector<std::string> metadata_expected_; 448 net::IPAddressNumber ip_address_expected_; 449 }; 450 451 TEST_F(ServiceResolverTest, TxtAndSrvButNoA) { 452 EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(4); 453 454 resolver_->StartResolving(); 455 456 socket_factory_.SimulateReceive(kSamplePacketSRV, sizeof(kSamplePacketSRV)); 457 458 base::MessageLoop::current()->RunUntilIdle(); 459 460 EXPECT_CALL(*this, 461 OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS, 462 address_expected_.ToString(), 463 metadata_expected_, 464 net::IPAddressNumber())); 465 466 socket_factory_.SimulateReceive(kSamplePacketTXT, sizeof(kSamplePacketTXT)); 467 }; 468 469 TEST_F(ServiceResolverTest, TxtSrvAndA) { 470 EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(4); 471 472 resolver_->StartResolving(); 473 474 EXPECT_CALL(*this, 475 OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS, 476 address_expected_.ToString(), 477 metadata_expected_, 478 ip_address_expected_)); 479 480 socket_factory_.SimulateReceive(kSamplePacketTXT, sizeof(kSamplePacketTXT)); 481 482 socket_factory_.SimulateReceive(kSamplePacketSRVA, sizeof(kSamplePacketSRVA)); 483 }; 484 485 TEST_F(ServiceResolverTest, JustSrv) { 486 EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(4); 487 488 resolver_->StartResolving(); 489 490 EXPECT_CALL(*this, 491 OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS, 492 address_expected_.ToString(), 493 std::vector<std::string>(), 494 ip_address_expected_)); 495 496 socket_factory_.SimulateReceive(kSamplePacketSRVA, sizeof(kSamplePacketSRVA)); 497 498 // TODO(noamsml): When NSEC record support is added, change this to use an 499 // NSEC record. 500 RunFor(base::TimeDelta::FromSeconds(4)); 501 }; 502 503 TEST_F(ServiceResolverTest, WithNothing) { 504 EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(4); 505 506 resolver_->StartResolving(); 507 508 EXPECT_CALL(*this, OnFinishedResolvingInternal( 509 ServiceResolver::STATUS_REQUEST_TIMEOUT, _, _, _)); 510 511 // TODO(noamsml): When NSEC record support is added, change this to use an 512 // NSEC record. 513 RunFor(base::TimeDelta::FromSeconds(4)); 514 }; 515 516 } // namespace 517 518 } // namespace local_discovery 519