1 // Copyright 2013 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "base/memory/weak_ptr.h" 6 #include "base/run_loop.h" 7 #include "chrome/utility/local_discovery/service_discovery_client_impl.h" 8 #include "net/base/net_errors.h" 9 #include "net/dns/dns_protocol.h" 10 #include "net/dns/mdns_client_impl.h" 11 #include "net/dns/mock_mdns_socket_factory.h" 12 #include "testing/gmock/include/gmock/gmock.h" 13 #include "testing/gtest/include/gtest/gtest.h" 14 15 using ::testing::_; 16 using ::testing::Invoke; 17 using ::testing::StrictMock; 18 using ::testing::NiceMock; 19 using ::testing::Mock; 20 using ::testing::SaveArg; 21 using ::testing::SetArgPointee; 22 using ::testing::Return; 23 using ::testing::Exactly; 24 25 namespace local_discovery { 26 27 namespace { 28 29 const uint8 kSamplePacketPTR[] = { 30 // Header 31 0x00, 0x00, // ID is zeroed out 32 0x81, 0x80, // Standard query response, RA, no error 33 0x00, 0x00, // No questions (for simplicity) 34 0x00, 0x01, // 1 RR (answers) 35 0x00, 0x00, // 0 authority RRs 36 0x00, 0x00, // 0 additional RRs 37 38 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', 39 0x04, '_', 't', 'c', 'p', 40 0x05, 'l', 'o', 'c', 'a', 'l', 41 0x00, 42 0x00, 0x0c, // TYPE is PTR. 43 0x00, 0x01, // CLASS is IN. 44 0x00, 0x00, // TTL (4 bytes) is 1 second. 45 0x00, 0x01, 46 0x00, 0x08, // RDLENGTH is 8 bytes. 47 0x05, 'h', 'e', 'l', 'l', 'o', 48 0xc0, 0x0c 49 }; 50 51 const uint8 kSamplePacketSRV[] = { 52 // Header 53 0x00, 0x00, // ID is zeroed out 54 0x81, 0x80, // Standard query response, RA, no error 55 0x00, 0x00, // No questions (for simplicity) 56 0x00, 0x01, // 1 RR (answers) 57 0x00, 0x00, // 0 authority RRs 58 0x00, 0x00, // 0 additional RRs 59 60 0x05, 'h', 'e', 'l', 'l', 'o', 61 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', 62 0x04, '_', 't', 'c', 'p', 63 0x05, 'l', 'o', 'c', 'a', 'l', 64 0x00, 65 0x00, 0x21, // TYPE is SRV. 66 0x00, 0x01, // CLASS is IN. 67 0x00, 0x00, // TTL (4 bytes) is 1 second. 68 0x00, 0x01, 69 0x00, 0x15, // RDLENGTH is 21 bytes. 70 0x00, 0x00, 71 0x00, 0x00, 72 0x22, 0xb8, // port 8888 73 0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o', 74 0x05, 'l', 'o', 'c', 'a', 'l', 75 0x00, 76 }; 77 78 const uint8 kSamplePacketTXT[] = { 79 // Header 80 0x00, 0x00, // ID is zeroed out 81 0x81, 0x80, // Standard query response, RA, no error 82 0x00, 0x00, // No questions (for simplicity) 83 0x00, 0x01, // 1 RR (answers) 84 0x00, 0x00, // 0 authority RRs 85 0x00, 0x00, // 0 additional RRs 86 87 0x05, 'h', 'e', 'l', 'l', 'o', 88 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', 89 0x04, '_', 't', 'c', 'p', 90 0x05, 'l', 'o', 'c', 'a', 'l', 91 0x00, 92 0x00, 0x10, // TYPE is PTR. 93 0x00, 0x01, // CLASS is IN. 94 0x00, 0x00, // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds. 95 0x00, 0x01, 96 0x00, 0x06, // RDLENGTH is 21 bytes. 97 0x05, 'h', 'e', 'l', 'l', 'o' 98 }; 99 100 const uint8 kSamplePacketSRVA[] = { 101 // Header 102 0x00, 0x00, // ID is zeroed out 103 0x81, 0x80, // Standard query response, RA, no error 104 0x00, 0x00, // No questions (for simplicity) 105 0x00, 0x02, // 2 RR (answers) 106 0x00, 0x00, // 0 authority RRs 107 0x00, 0x00, // 0 additional RRs 108 109 0x05, 'h', 'e', 'l', 'l', 'o', 110 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', 111 0x04, '_', 't', 'c', 'p', 112 0x05, 'l', 'o', 'c', 'a', 'l', 113 0x00, 114 0x00, 0x21, // TYPE is SRV. 115 0x00, 0x01, // CLASS is IN. 116 0x00, 0x00, // TTL (4 bytes) is 16 seconds. 117 0x00, 0x10, 118 0x00, 0x15, // RDLENGTH is 21 bytes. 119 0x00, 0x00, 120 0x00, 0x00, 121 0x22, 0xb8, // port 8888 122 0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o', 123 0x05, 'l', 'o', 'c', 'a', 'l', 124 0x00, 125 126 0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o', 127 0x05, 'l', 'o', 'c', 'a', 'l', 128 0x00, 129 0x00, 0x01, // TYPE is A. 130 0x00, 0x01, // CLASS is IN. 131 0x00, 0x00, // TTL (4 bytes) is 16 seconds. 132 0x00, 0x10, 133 0x00, 0x04, // RDLENGTH is 4 bytes. 134 0x01, 0x02, 135 0x03, 0x04, 136 }; 137 138 const uint8 kSamplePacketPTR2[] = { 139 // Header 140 0x00, 0x00, // ID is zeroed out 141 0x81, 0x80, // Standard query response, RA, no error 142 0x00, 0x00, // No questions (for simplicity) 143 0x00, 0x02, // 2 RR (answers) 144 0x00, 0x00, // 0 authority RRs 145 0x00, 0x00, // 0 additional RRs 146 147 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', 148 0x04, '_', 't', 'c', 'p', 149 0x05, 'l', 'o', 'c', 'a', 'l', 150 0x00, 151 0x00, 0x0c, // TYPE is PTR. 152 0x00, 0x01, // CLASS is IN. 153 0x02, 0x00, // TTL (4 bytes) is 1 second. 154 0x00, 0x01, 155 0x00, 0x08, // RDLENGTH is 8 bytes. 156 0x05, 'g', 'd', 'b', 'y', 'e', 157 0xc0, 0x0c, 158 159 0x07, '_', 'p', 'r', 'i', 'v', 'e', 't', 160 0x04, '_', 't', 'c', 'p', 161 0x05, 'l', 'o', 'c', 'a', 'l', 162 0x00, 163 0x00, 0x0c, // TYPE is PTR. 164 0x00, 0x01, // CLASS is IN. 165 0x02, 0x00, // TTL (4 bytes) is 1 second. 166 0x00, 0x01, 167 0x00, 0x08, // RDLENGTH is 8 bytes. 168 0x05, 'h', 'e', 'l', 'l', 'o', 169 0xc0, 0x0c 170 }; 171 172 class MockServiceWatcherClient { 173 public: 174 MOCK_METHOD2(OnServiceUpdated, 175 void(ServiceWatcher::UpdateType, const std::string&)); 176 177 ServiceWatcher::UpdatedCallback GetCallback() { 178 return base::Bind(&MockServiceWatcherClient::OnServiceUpdated, 179 base::Unretained(this)); 180 } 181 }; 182 183 class ServiceDiscoveryTest : public ::testing::Test { 184 public: 185 ServiceDiscoveryTest() 186 : service_discovery_client_(&mdns_client_) { 187 mdns_client_.StartListening(&socket_factory_); 188 } 189 190 virtual ~ServiceDiscoveryTest() { 191 } 192 193 protected: 194 void RunFor(base::TimeDelta time_period) { 195 base::CancelableCallback<void()> callback(base::Bind( 196 &ServiceDiscoveryTest::Stop, base::Unretained(this))); 197 base::MessageLoop::current()->PostDelayedTask( 198 FROM_HERE, callback.callback(), time_period); 199 200 base::MessageLoop::current()->Run(); 201 callback.Cancel(); 202 } 203 204 void Stop() { 205 base::MessageLoop::current()->Quit(); 206 } 207 208 net::MockMDnsSocketFactory socket_factory_; 209 net::MDnsClientImpl mdns_client_; 210 ServiceDiscoveryClientImpl service_discovery_client_; 211 base::MessageLoop loop_; 212 }; 213 214 TEST_F(ServiceDiscoveryTest, AddRemoveService) { 215 StrictMock<MockServiceWatcherClient> delegate; 216 217 scoped_ptr<ServiceWatcher> watcher( 218 service_discovery_client_.CreateServiceWatcher( 219 "_privet._tcp.local", delegate.GetCallback())); 220 221 watcher->Start(); 222 223 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED, 224 "hello._privet._tcp.local")) 225 .Times(Exactly(1)); 226 227 socket_factory_.SimulateReceive(kSamplePacketPTR, sizeof(kSamplePacketPTR)); 228 229 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_REMOVED, 230 "hello._privet._tcp.local")) 231 .Times(Exactly(1)); 232 233 RunFor(base::TimeDelta::FromSeconds(2)); 234 }; 235 236 TEST_F(ServiceDiscoveryTest, DiscoverNewServices) { 237 StrictMock<MockServiceWatcherClient> delegate; 238 239 scoped_ptr<ServiceWatcher> watcher( 240 service_discovery_client_.CreateServiceWatcher( 241 "_privet._tcp.local", delegate.GetCallback())); 242 243 watcher->Start(); 244 245 EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(2); 246 247 watcher->DiscoverNewServices(false); 248 249 EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(2); 250 251 RunFor(base::TimeDelta::FromSeconds(2)); 252 }; 253 254 TEST_F(ServiceDiscoveryTest, ReadCachedServices) { 255 socket_factory_.SimulateReceive(kSamplePacketPTR, sizeof(kSamplePacketPTR)); 256 257 StrictMock<MockServiceWatcherClient> delegate; 258 259 scoped_ptr<ServiceWatcher> watcher( 260 service_discovery_client_.CreateServiceWatcher( 261 "_privet._tcp.local", delegate.GetCallback())); 262 263 watcher->Start(); 264 265 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED, 266 "hello._privet._tcp.local")) 267 .Times(Exactly(1)); 268 269 base::MessageLoop::current()->RunUntilIdle(); 270 }; 271 272 273 TEST_F(ServiceDiscoveryTest, ReadCachedServicesMultiple) { 274 socket_factory_.SimulateReceive(kSamplePacketPTR2, sizeof(kSamplePacketPTR2)); 275 276 StrictMock<MockServiceWatcherClient> delegate; 277 scoped_ptr<ServiceWatcher> watcher = 278 service_discovery_client_.CreateServiceWatcher( 279 "_privet._tcp.local", delegate.GetCallback()); 280 281 watcher->Start(); 282 283 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED, 284 "hello._privet._tcp.local")) 285 .Times(Exactly(1)); 286 287 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED, 288 "gdbye._privet._tcp.local")) 289 .Times(Exactly(1)); 290 291 base::MessageLoop::current()->RunUntilIdle(); 292 }; 293 294 295 TEST_F(ServiceDiscoveryTest, OnServiceChanged) { 296 StrictMock<MockServiceWatcherClient> delegate; 297 scoped_ptr<ServiceWatcher> watcher( 298 service_discovery_client_.CreateServiceWatcher( 299 "_privet._tcp.local", delegate.GetCallback())); 300 301 watcher->Start(); 302 303 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED, 304 "hello._privet._tcp.local")) 305 .Times(Exactly(1)); 306 307 socket_factory_.SimulateReceive(kSamplePacketPTR, sizeof(kSamplePacketPTR)); 308 309 base::MessageLoop::current()->RunUntilIdle(); 310 311 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_CHANGED, 312 "hello._privet._tcp.local")) 313 .Times(Exactly(1)); 314 315 socket_factory_.SimulateReceive(kSamplePacketSRV, sizeof(kSamplePacketSRV)); 316 317 socket_factory_.SimulateReceive(kSamplePacketTXT, sizeof(kSamplePacketTXT)); 318 319 base::MessageLoop::current()->RunUntilIdle(); 320 }; 321 322 TEST_F(ServiceDiscoveryTest, SinglePacket) { 323 StrictMock<MockServiceWatcherClient> delegate; 324 scoped_ptr<ServiceWatcher> watcher( 325 service_discovery_client_.CreateServiceWatcher( 326 "_privet._tcp.local", delegate.GetCallback())); 327 328 watcher->Start(); 329 330 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED, 331 "hello._privet._tcp.local")) 332 .Times(Exactly(1)); 333 334 socket_factory_.SimulateReceive(kSamplePacketPTR, sizeof(kSamplePacketPTR)); 335 336 // Reset the "already updated" flag. 337 base::MessageLoop::current()->RunUntilIdle(); 338 339 EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_CHANGED, 340 "hello._privet._tcp.local")) 341 .Times(Exactly(1)); 342 343 socket_factory_.SimulateReceive(kSamplePacketSRV, sizeof(kSamplePacketSRV)); 344 345 socket_factory_.SimulateReceive(kSamplePacketTXT, sizeof(kSamplePacketTXT)); 346 347 base::MessageLoop::current()->RunUntilIdle(); 348 }; 349 350 class ServiceResolverTest : public ServiceDiscoveryTest { 351 public: 352 ServiceResolverTest() { 353 metadata_expected_.push_back("hello"); 354 address_expected_ = net::HostPortPair("myhello.local", 8888); 355 ip_address_expected_.push_back(1); 356 ip_address_expected_.push_back(2); 357 ip_address_expected_.push_back(3); 358 ip_address_expected_.push_back(4); 359 } 360 361 ~ServiceResolverTest() { 362 } 363 364 void SetUp() { 365 resolver_ = service_discovery_client_.CreateServiceResolver( 366 "hello._privet._tcp.local", 367 base::Bind(&ServiceResolverTest::OnFinishedResolving, 368 base::Unretained(this))); 369 } 370 371 void OnFinishedResolving(ServiceResolver::RequestStatus request_status, 372 const ServiceDescription& service_description) { 373 OnFinishedResolvingInternal(request_status, 374 service_description.address.ToString(), 375 service_description.metadata, 376 service_description.ip_address); 377 } 378 379 MOCK_METHOD4(OnFinishedResolvingInternal, 380 void(ServiceResolver::RequestStatus, 381 const std::string&, 382 const std::vector<std::string>&, 383 const net::IPAddressNumber&)); 384 385 protected: 386 scoped_ptr<ServiceResolver> resolver_; 387 net::IPAddressNumber ip_address_; 388 net::HostPortPair address_expected_; 389 std::vector<std::string> metadata_expected_; 390 net::IPAddressNumber ip_address_expected_; 391 }; 392 393 TEST_F(ServiceResolverTest, TxtAndSrvButNoA) { 394 EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(4); 395 396 resolver_->StartResolving(); 397 398 socket_factory_.SimulateReceive(kSamplePacketSRV, sizeof(kSamplePacketSRV)); 399 400 base::MessageLoop::current()->RunUntilIdle(); 401 402 EXPECT_CALL(*this, 403 OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS, 404 address_expected_.ToString(), 405 metadata_expected_, 406 net::IPAddressNumber())); 407 408 socket_factory_.SimulateReceive(kSamplePacketTXT, sizeof(kSamplePacketTXT)); 409 }; 410 411 TEST_F(ServiceResolverTest, TxtSrvAndA) { 412 EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(4); 413 414 resolver_->StartResolving(); 415 416 EXPECT_CALL(*this, 417 OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS, 418 address_expected_.ToString(), 419 metadata_expected_, 420 ip_address_expected_)); 421 422 socket_factory_.SimulateReceive(kSamplePacketTXT, sizeof(kSamplePacketTXT)); 423 424 socket_factory_.SimulateReceive(kSamplePacketSRVA, sizeof(kSamplePacketSRVA)); 425 }; 426 427 TEST_F(ServiceResolverTest, JustSrv) { 428 EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(4); 429 430 resolver_->StartResolving(); 431 432 EXPECT_CALL(*this, 433 OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS, 434 address_expected_.ToString(), 435 std::vector<std::string>(), 436 ip_address_expected_)); 437 438 socket_factory_.SimulateReceive(kSamplePacketSRVA, sizeof(kSamplePacketSRVA)); 439 440 // TODO(noamsml): When NSEC record support is added, change this to use an 441 // NSEC record. 442 RunFor(base::TimeDelta::FromSeconds(4)); 443 }; 444 445 TEST_F(ServiceResolverTest, WithNothing) { 446 EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(4); 447 448 resolver_->StartResolving(); 449 450 EXPECT_CALL(*this, OnFinishedResolvingInternal( 451 ServiceResolver::STATUS_REQUEST_TIMEOUT, _, _, _)); 452 453 // TODO(noamsml): When NSEC record support is added, change this to use an 454 // NSEC record. 455 RunFor(base::TimeDelta::FromSeconds(4)); 456 }; 457 458 } // namespace 459 460 } // namespace local_discovery 461