1 // Copyright 2013 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "base/memory/weak_ptr.h" 6 #include "base/run_loop.h" 7 #include "chrome/utility/local_discovery/service_discovery_client_impl.h" 8 #include "net/base/net_errors.h" 9 #include "net/dns/dns_protocol.h" 10 #include "net/dns/mdns_client_impl.h" 11 #include "net/dns/mock_mdns_socket_factory.h" 12 #include "testing/gmock/include/gmock/gmock.h" 13 #include "testing/gtest/include/gtest/gtest.h" 14 15 using ::testing::_; 16 using ::testing::Invoke; 17 using ::testing::StrictMock; 18 using ::testing::NiceMock; 19 using ::testing::Mock; 20 using ::testing::SaveArg; 21 using ::testing::SetArgPointee; 22 using ::testing::Return; 23 using ::testing::Exactly; 24 25 namespace local_discovery { 26 27 namespace { 28 29 const uint8 kSamplePacketPTR[] = { 30 // Header 31 0x00, 0x00, // ID is zeroed out 32 0x81, 0x80, // Standard query response, RA, no error 33 0x00, 0x00, // No questions (for simplicity) 34 0x00, 0x01, // 1 RR (answers) 35 0x00, 0x00, // 0 authority RRs 36 0x00, 0x00, // 0 additional RRs 37 38 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', 39 0x04, '_', 't', 'c', 'p', 40 0x05, 'l', 'o', 'c', 'a', 'l', 41 0x00, 42 0x00, 0x0c, // TYPE is PTR. 43 0x00, 0x01, // CLASS is IN. 44 0x00, 0x00, // TTL (4 bytes) is 1 second. 45 0x00, 0x01, 46 0x00, 0x08, // RDLENGTH is 8 bytes. 47 0x05, 'h', 'e', 'l', 'l', 'o', 48 0xc0, 0x0c 49 }; 50 51 const uint8 kSamplePacketSRV[] = { 52 // Header 53 0x00, 0x00, // ID is zeroed out 54 0x81, 0x80, // Standard query response, RA, no error 55 0x00, 0x00, // No questions (for simplicity) 56 0x00, 0x01, // 1 RR (answers) 57 0x00, 0x00, // 0 authority RRs 58 0x00, 0x00, // 0 additional RRs 59 60 0x05, 'h', 'e', 'l', 'l', 'o', 61 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', 62 0x04, '_', 't', 'c', 'p', 63 0x05, 'l', 'o', 'c', 'a', 'l', 64 0x00, 65 0x00, 0x21, // TYPE is SRV. 66 0x00, 0x01, // CLASS is IN. 67 0x00, 0x00, // TTL (4 bytes) is 1 second. 68 0x00, 0x01, 69 0x00, 0x15, // RDLENGTH is 21 bytes. 70 0x00, 0x00, 71 0x00, 0x00, 72 0x22, 0xb8, // port 8888 73 0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o', 74 0x05, 'l', 'o', 'c', 'a', 'l', 75 0x00, 76 }; 77 78 const uint8 kSamplePacketTXT[] = { 79 // Header 80 0x00, 0x00, // ID is zeroed out 81 0x81, 0x80, // Standard query response, RA, no error 82 0x00, 0x00, // No questions (for simplicity) 83 0x00, 0x01, // 1 RR (answers) 84 0x00, 0x00, // 0 authority RRs 85 0x00, 0x00, // 0 additional RRs 86 87 0x05, 'h', 'e', 'l', 'l', 'o', 88 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', 89 0x04, '_', 't', 'c', 'p', 90 0x05, 'l', 'o', 'c', 'a', 'l', 91 0x00, 92 0x00, 0x10, // TYPE is PTR. 93 0x00, 0x01, // CLASS is IN. 94 0x00, 0x00, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds. 95 0x00, 0x01, 96 0x00, 0x06, // RDLENGTH is 21 bytes. 97 0x05, 'h', 'e', 'l', 'l', 'o' 98 }; 99 100 const uint8 kSamplePacketSRVA[] = { 101 // Header 102 0x00, 0x00, // ID is zeroed out 103 0x81, 0x80, // Standard query response, RA, no error 104 0x00, 0x00, // No questions (for simplicity) 105 0x00, 0x02, // 2 RR (answers) 106 0x00, 0x00, // 0 authority RRs 107 0x00, 0x00, // 0 additional RRs 108 109 0x05, 'h', 'e', 'l', 'l', 'o', 110 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', 111 0x04, '_', 't', 'c', 'p', 112 0x05, 'l', 'o', 'c', 'a', 'l', 113 0x00, 114 0x00, 0x21, // TYPE is SRV. 115 0x00, 0x01, // CLASS is IN. 116 0x00, 0x00, // TTL (4 bytes) is 16 seconds. 117 0x00, 0x10, 118 0x00, 0x15, // RDLENGTH is 21 bytes. 119 0x00, 0x00, 120 0x00, 0x00, 121 0x22, 0xb8, // port 8888 122 0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o', 123 0x05, 'l', 'o', 'c', 'a', 'l', 124 0x00, 125 126 0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o', 127 0x05, 'l', 'o', 'c', 'a', 'l', 128 0x00, 129 0x00, 0x01, // TYPE is A. 130 0x00, 0x01, // CLASS is IN. 131 0x00, 0x00, // TTL (4 bytes) is 16 seconds. 132 0x00, 0x10, 133 0x00, 0x04, // RDLENGTH is 4 bytes. 134 0x01, 0x02, 135 0x03, 0x04, 136 }; 137 138 const uint8 kSamplePacketPTR2[] = { 139 // Header 140 0x00, 0x00, // ID is zeroed out 141 0x81, 0x80, // Standard query response, RA, no error 142 0x00, 0x00, // No questions (for simplicity) 143 0x00, 0x02, // 2 RR (answers) 144 0x00, 0x00, // 0 authority RRs 145 0x00, 0x00, // 0 additional RRs 146 147 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', 148 0x04, '_', 't', 'c', 'p', 149 0x05, 'l', 'o', 'c', 'a', 'l', 150 0x00, 151 0x00, 0x0c, // TYPE is PTR. 152 0x00, 0x01, // CLASS is IN. 153 0x02, 0x00, // TTL (4 bytes) is 1 second. 154 0x00, 0x01, 155 0x00, 0x08, // RDLENGTH is 8 bytes. 156 0x05, 'g', 'd', 'b', 'y', 'e', 157 0xc0, 0x0c, 158 159 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', 160 0x04, '_', 't', 'c', 'p', 161 0x05, 'l', 'o', 'c', 'a', 'l', 162 0x00, 163 0x00, 0x0c, // TYPE is PTR. 164 0x00, 0x01, // CLASS is IN. 165 0x02, 0x00, // TTL (4 bytes) is 1 second. 166 0x00, 0x01, 167 0x00, 0x08, // RDLENGTH is 8 bytes. 168 0x05, 'h', 'e', 'l', 'l', 'o', 169 0xc0, 0x0c 170 }; 171 172 class MockServiceWatcherClient { 173 public: 174 MOCK_METHOD2(OnServiceUpdated, 175 void(ServiceWatcher::UpdateType, const std::string&)); 176 177 ServiceWatcher::UpdatedCallback GetCallback() { 178 return base::Bind(&MockServiceWatcherClient::OnServiceUpdated, 179 base::Unretained(this)); 180 } 181 }; 182 183 class ServiceDiscoveryTest : public ::testing::Test { 184 public: 185 ServiceDiscoveryTest() 186 : socket_factory_(new net::MockMDnsSocketFactory), 187 mdns_client_( 188 scoped_ptr<net::MDnsConnection::SocketFactory>( 189 socket_factory_)), 190 service_discovery_client_(&mdns_client_) { 191 mdns_client_.StartListening(); 192 } 193 194 virtual ~ServiceDiscoveryTest() { 195 } 196 197 protected: 198 void RunFor(base::TimeDelta time_period) { 199 base::CancelableCallback<void()> callback(base::Bind( 200 &ServiceDiscoveryTest::Stop, base::Unretained(this))); 201 base::MessageLoop::current()->PostDelayedTask( 202 FROM_HERE, callback.callback(), time_period); 203 204 base::MessageLoop::current()->Run(); 205 callback.Cancel(); 206 } 207 208 void Stop() { 209 base::MessageLoop::current()->Quit(); 210 } 211 212 net::MockMDnsSocketFactory* socket_factory_; 213 net::MDnsClientImpl mdns_client_; 214 ServiceDiscoveryClientImpl service_discovery_client_; 215 base::MessageLoop loop_; 216 }; 217 218 TEST_F(ServiceDiscoveryTest, AddRemoveService) { 219 StrictMock<MockServiceWatcherClient> delegate; 220 221 scoped_ptr<ServiceWatcher> watcher( 222 service_discovery_client_.CreateServiceWatcher( 223 "_privet._tcp.local", delegate.GetCallback())); 224 225 watcher->Start(); 226 227 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED, 228 "hello._privet._tcp.local")) 229 .Times(Exactly(1)); 230 231 socket_factory_->SimulateReceive( 232 kSamplePacketPTR, sizeof(kSamplePacketPTR)); 233 234 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_REMOVED, 235 "hello._privet._tcp.local")) 236 .Times(Exactly(1)); 237 238 RunFor(base::TimeDelta::FromSeconds(2)); 239 }; 240 241 TEST_F(ServiceDiscoveryTest, DiscoverNewServices) { 242 StrictMock<MockServiceWatcherClient> delegate; 243 244 scoped_ptr<ServiceWatcher> watcher( 245 service_discovery_client_.CreateServiceWatcher( 246 "_privet._tcp.local", delegate.GetCallback())); 247 248 watcher->Start(); 249 250 EXPECT_CALL(*socket_factory_, OnSendTo(_)) 251 .Times(2); 252 253 watcher->DiscoverNewServices(false); 254 }; 255 256 TEST_F(ServiceDiscoveryTest, ReadCachedServices) { 257 socket_factory_->SimulateReceive( 258 kSamplePacketPTR, sizeof(kSamplePacketPTR)); 259 260 StrictMock<MockServiceWatcherClient> delegate; 261 262 scoped_ptr<ServiceWatcher> watcher( 263 service_discovery_client_.CreateServiceWatcher( 264 "_privet._tcp.local", delegate.GetCallback())); 265 266 watcher->Start(); 267 268 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED, 269 "hello._privet._tcp.local")) 270 .Times(Exactly(1)); 271 272 base::MessageLoop::current()->RunUntilIdle(); 273 }; 274 275 276 TEST_F(ServiceDiscoveryTest, ReadCachedServicesMultiple) { 277 socket_factory_->SimulateReceive( 278 kSamplePacketPTR2, sizeof(kSamplePacketPTR2)); 279 280 StrictMock<MockServiceWatcherClient> delegate; 281 scoped_ptr<ServiceWatcher> watcher = 282 service_discovery_client_.CreateServiceWatcher( 283 "_privet._tcp.local", delegate.GetCallback()); 284 285 watcher->Start(); 286 287 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED, 288 "hello._privet._tcp.local")) 289 .Times(Exactly(1)); 290 291 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED, 292 "gdbye._privet._tcp.local")) 293 .Times(Exactly(1)); 294 295 base::MessageLoop::current()->RunUntilIdle(); 296 }; 297 298 299 TEST_F(ServiceDiscoveryTest, OnServiceChanged) { 300 StrictMock<MockServiceWatcherClient> delegate; 301 scoped_ptr<ServiceWatcher> watcher( 302 service_discovery_client_.CreateServiceWatcher( 303 "_privet._tcp.local", delegate.GetCallback())); 304 305 watcher->Start(); 306 307 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED, 308 "hello._privet._tcp.local")) 309 .Times(Exactly(1)); 310 311 socket_factory_->SimulateReceive( 312 kSamplePacketPTR, sizeof(kSamplePacketPTR)); 313 314 base::MessageLoop::current()->RunUntilIdle(); 315 316 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_CHANGED, 317 "hello._privet._tcp.local")) 318 .Times(Exactly(1)); 319 320 socket_factory_->SimulateReceive( 321 kSamplePacketSRV, sizeof(kSamplePacketSRV)); 322 323 socket_factory_->SimulateReceive( 324 kSamplePacketTXT, sizeof(kSamplePacketTXT)); 325 326 base::MessageLoop::current()->RunUntilIdle(); 327 }; 328 329 TEST_F(ServiceDiscoveryTest, SinglePacket) { 330 StrictMock<MockServiceWatcherClient> delegate; 331 scoped_ptr<ServiceWatcher> watcher( 332 service_discovery_client_.CreateServiceWatcher( 333 "_privet._tcp.local", delegate.GetCallback())); 334 335 watcher->Start(); 336 337 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED, 338 "hello._privet._tcp.local")) 339 .Times(Exactly(1)); 340 341 socket_factory_->SimulateReceive( 342 kSamplePacketPTR, sizeof(kSamplePacketPTR)); 343 344 // Reset the "already updated" flag. 345 base::MessageLoop::current()->RunUntilIdle(); 346 347 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_CHANGED, 348 "hello._privet._tcp.local")) 349 .Times(Exactly(1)); 350 351 socket_factory_->SimulateReceive( 352 kSamplePacketSRV, sizeof(kSamplePacketSRV)); 353 354 socket_factory_->SimulateReceive( 355 kSamplePacketTXT, sizeof(kSamplePacketTXT)); 356 357 base::MessageLoop::current()->RunUntilIdle(); 358 }; 359 360 class ServiceResolverTest : public ServiceDiscoveryTest { 361 public: 362 ServiceResolverTest() { 363 metadata_expected_.push_back("hello"); 364 address_expected_ = net::HostPortPair("myhello.local", 8888); 365 ip_address_expected_.push_back(1); 366 ip_address_expected_.push_back(2); 367 ip_address_expected_.push_back(3); 368 ip_address_expected_.push_back(4); 369 } 370 371 ~ServiceResolverTest() { 372 } 373 374 void SetUp() { 375 resolver_ = service_discovery_client_.CreateServiceResolver( 376 "hello._privet._tcp.local", 377 base::Bind(&ServiceResolverTest::OnFinishedResolving, 378 base::Unretained(this))); 379 } 380 381 void OnFinishedResolving(ServiceResolver::RequestStatus request_status, 382 const ServiceDescription& service_description) { 383 OnFinishedResolvingInternal(request_status, 384 service_description.address.ToString(), 385 service_description.metadata, 386 service_description.ip_address); 387 } 388 389 MOCK_METHOD4(OnFinishedResolvingInternal, 390 void(ServiceResolver::RequestStatus, 391 const std::string&, 392 const std::vector<std::string>&, 393 const net::IPAddressNumber&)); 394 395 protected: 396 scoped_ptr<ServiceResolver> resolver_; 397 net::IPAddressNumber ip_address_; 398 net::HostPortPair address_expected_; 399 std::vector<std::string> metadata_expected_; 400 net::IPAddressNumber ip_address_expected_; 401 }; 402 403 TEST_F(ServiceResolverTest, TxtAndSrvButNoA) { 404 EXPECT_CALL(*socket_factory_, OnSendTo(_)) 405 .Times(4); 406 407 resolver_->StartResolving(); 408 409 socket_factory_->SimulateReceive( 410 kSamplePacketSRV, sizeof(kSamplePacketSRV)); 411 412 base::MessageLoop::current()->RunUntilIdle(); 413 414 EXPECT_CALL(*this, 415 OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS, 416 address_expected_.ToString(), 417 metadata_expected_, 418 net::IPAddressNumber())); 419 420 socket_factory_->SimulateReceive( 421 kSamplePacketTXT, sizeof(kSamplePacketTXT)); 422 }; 423 424 TEST_F(ServiceResolverTest, TxtSrvAndA) { 425 EXPECT_CALL(*socket_factory_, OnSendTo(_)) 426 .Times(4); 427 428 resolver_->StartResolving(); 429 430 EXPECT_CALL(*this, 431 OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS, 432 address_expected_.ToString(), 433 metadata_expected_, 434 ip_address_expected_)); 435 436 socket_factory_->SimulateReceive( 437 kSamplePacketTXT, sizeof(kSamplePacketTXT)); 438 439 socket_factory_->SimulateReceive( 440 kSamplePacketSRVA, sizeof(kSamplePacketSRVA)); 441 }; 442 443 TEST_F(ServiceResolverTest, JustSrv) { 444 EXPECT_CALL(*socket_factory_, OnSendTo(_)) 445 .Times(4); 446 447 resolver_->StartResolving(); 448 449 EXPECT_CALL(*this, 450 OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS, 451 address_expected_.ToString(), 452 std::vector<std::string>(), 453 ip_address_expected_)); 454 455 socket_factory_->SimulateReceive( 456 kSamplePacketSRVA, sizeof(kSamplePacketSRVA)); 457 458 // TODO(noamsml): When NSEC record support is added, change this to use an 459 // NSEC record. 460 RunFor(base::TimeDelta::FromSeconds(4)); 461 }; 462 463 TEST_F(ServiceResolverTest, WithNothing) { 464 EXPECT_CALL(*socket_factory_, OnSendTo(_)) 465 .Times(4); 466 467 resolver_->StartResolving(); 468 469 EXPECT_CALL(*this, OnFinishedResolvingInternal( 470 ServiceResolver::STATUS_REQUEST_TIMEOUT, _, _, _)); 471 472 // TODO(noamsml): When NSEC record support is added, change this to use an 473 // NSEC record. 474 RunFor(base::TimeDelta::FromSeconds(4)); 475 }; 476 477 } // namespace 478 479 } // namespace local_discovery 480