Home | History | Annotate | Download | only in local_discovery
      1 // Copyright 2014 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "base/memory/weak_ptr.h"
      6 #include "base/run_loop.h"
      7 #include "chrome/common/local_discovery/service_discovery_client_impl.h"
      8 #include "net/base/net_errors.h"
      9 #include "net/dns/dns_protocol.h"
     10 #include "net/dns/mdns_client_impl.h"
     11 #include "net/dns/mock_mdns_socket_factory.h"
     12 #include "testing/gmock/include/gmock/gmock.h"
     13 #include "testing/gtest/include/gtest/gtest.h"
     14 
     15 using ::testing::_;
     16 using ::testing::Invoke;
     17 using ::testing::StrictMock;
     18 using ::testing::NiceMock;
     19 using ::testing::Mock;
     20 using ::testing::SaveArg;
     21 using ::testing::SetArgPointee;
     22 using ::testing::Return;
     23 using ::testing::Exactly;
     24 
     25 namespace local_discovery {
     26 
     27 namespace {
     28 
     29 const uint8 kSamplePacketPTR[] = {
     30   // Header
     31   0x00, 0x00,               // ID is zeroed out
     32   0x81, 0x80,               // Standard query response, RA, no error
     33   0x00, 0x00,               // No questions (for simplicity)
     34   0x00, 0x01,               // 1 RR (answers)
     35   0x00, 0x00,               // 0 authority RRs
     36   0x00, 0x00,               // 0 additional RRs
     37 
     38   0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
     39   0x04, '_', 't', 'c', 'p',
     40   0x05, 'l', 'o', 'c', 'a', 'l',
     41   0x00,
     42   0x00, 0x0c,        // TYPE is PTR.
     43   0x00, 0x01,        // CLASS is IN.
     44   0x00, 0x00,        // TTL (4 bytes) is 1 second.
     45   0x00, 0x01,
     46   0x00, 0x08,        // RDLENGTH is 8 bytes.
     47   0x05, 'h', 'e', 'l', 'l', 'o',
     48   0xc0, 0x0c
     49 };
     50 
     51 const uint8 kSamplePacketSRV[] = {
     52   // Header
     53   0x00, 0x00,               // ID is zeroed out
     54   0x81, 0x80,               // Standard query response, RA, no error
     55   0x00, 0x00,               // No questions (for simplicity)
     56   0x00, 0x01,               // 1 RR (answers)
     57   0x00, 0x00,               // 0 authority RRs
     58   0x00, 0x00,               // 0 additional RRs
     59 
     60   0x05, 'h', 'e', 'l', 'l', 'o',
     61   0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
     62   0x04, '_', 't', 'c', 'p',
     63   0x05, 'l', 'o', 'c', 'a', 'l',
     64   0x00,
     65   0x00, 0x21,        // TYPE is SRV.
     66   0x00, 0x01,        // CLASS is IN.
     67   0x00, 0x00,        // TTL (4 bytes) is 1 second.
     68   0x00, 0x01,
     69   0x00, 0x15,        // RDLENGTH is 21 bytes.
     70   0x00, 0x00,
     71   0x00, 0x00,
     72   0x22, 0xb8,  // port 8888
     73   0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o',
     74   0x05, 'l', 'o', 'c', 'a', 'l',
     75   0x00,
     76 };
     77 
     78 const uint8 kSamplePacketTXT[] = {
     79   // Header
     80   0x00, 0x00,               // ID is zeroed out
     81   0x81, 0x80,               // Standard query response, RA, no error
     82   0x00, 0x00,               // No questions (for simplicity)
     83   0x00, 0x01,               // 1 RR (answers)
     84   0x00, 0x00,               // 0 authority RRs
     85   0x00, 0x00,               // 0 additional RRs
     86 
     87   0x05, 'h', 'e', 'l', 'l', 'o',
     88   0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
     89   0x04, '_', 't', 'c', 'p',
     90   0x05, 'l', 'o', 'c', 'a', 'l',
     91   0x00,
     92   0x00, 0x10,        // TYPE is PTR.
     93   0x00, 0x01,        // CLASS is IN.
     94   0x00, 0x00,        // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
     95   0x00, 0x01,
     96   0x00, 0x06,        // RDLENGTH is 21 bytes.
     97   0x05, 'h', 'e', 'l', 'l', 'o'
     98 };
     99 
    100 const uint8 kSamplePacketSRVA[] = {
    101   // Header
    102   0x00, 0x00,               // ID is zeroed out
    103   0x81, 0x80,               // Standard query response, RA, no error
    104   0x00, 0x00,               // No questions (for simplicity)
    105   0x00, 0x02,               // 2 RR (answers)
    106   0x00, 0x00,               // 0 authority RRs
    107   0x00, 0x00,               // 0 additional RRs
    108 
    109   0x05, 'h', 'e', 'l', 'l', 'o',
    110   0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
    111   0x04, '_', 't', 'c', 'p',
    112   0x05, 'l', 'o', 'c', 'a', 'l',
    113   0x00,
    114   0x00, 0x21,        // TYPE is SRV.
    115   0x00, 0x01,        // CLASS is IN.
    116   0x00, 0x00,        // TTL (4 bytes) is 16 seconds.
    117   0x00, 0x10,
    118   0x00, 0x15,        // RDLENGTH is 21 bytes.
    119   0x00, 0x00,
    120   0x00, 0x00,
    121   0x22, 0xb8,  // port 8888
    122   0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o',
    123   0x05, 'l', 'o', 'c', 'a', 'l',
    124   0x00,
    125 
    126   0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o',
    127   0x05, 'l', 'o', 'c', 'a', 'l',
    128   0x00,
    129   0x00, 0x01,        // TYPE is A.
    130   0x00, 0x01,        // CLASS is IN.
    131   0x00, 0x00,        // TTL (4 bytes) is 16 seconds.
    132   0x00, 0x10,
    133   0x00, 0x04,        // RDLENGTH is 4 bytes.
    134   0x01, 0x02,
    135   0x03, 0x04,
    136 };
    137 
    138 const uint8 kSamplePacketPTR2[] = {
    139   // Header
    140   0x00, 0x00,               // ID is zeroed out
    141   0x81, 0x80,               // Standard query response, RA, no error
    142   0x00, 0x00,               // No questions (for simplicity)
    143   0x00, 0x02,               // 2 RR (answers)
    144   0x00, 0x00,               // 0 authority RRs
    145   0x00, 0x00,               // 0 additional RRs
    146 
    147   0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
    148   0x04, '_', 't', 'c', 'p',
    149   0x05, 'l', 'o', 'c', 'a', 'l',
    150   0x00,
    151   0x00, 0x0c,        // TYPE is PTR.
    152   0x00, 0x01,        // CLASS is IN.
    153   0x02, 0x00,        // TTL (4 bytes) is 1 second.
    154   0x00, 0x01,
    155   0x00, 0x08,        // RDLENGTH is 8 bytes.
    156   0x05, 'g', 'd', 'b', 'y', 'e',
    157   0xc0, 0x0c,
    158 
    159   0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
    160   0x04, '_', 't', 'c', 'p',
    161   0x05, 'l', 'o', 'c', 'a', 'l',
    162   0x00,
    163   0x00, 0x0c,        // TYPE is PTR.
    164   0x00, 0x01,        // CLASS is IN.
    165   0x02, 0x00,        // TTL (4 bytes) is 1 second.
    166   0x00, 0x01,
    167   0x00, 0x08,        // RDLENGTH is 8 bytes.
    168   0x05, 'h', 'e', 'l', 'l', 'o',
    169   0xc0, 0x0c
    170 };
    171 
    172 const uint8 kSamplePacketQuerySRV[] = {
    173   // Header
    174   0x00, 0x00,               // ID is zeroed out
    175   0x00, 0x00,               // No flags.
    176   0x00, 0x01,               // One question.
    177   0x00, 0x00,               // 0 RRs (answers)
    178   0x00, 0x00,               // 0 authority RRs
    179   0x00, 0x00,               // 0 additional RRs
    180 
    181   // Question
    182   0x05, 'h', 'e', 'l', 'l', 'o',
    183   0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
    184   0x04, '_', 't', 'c', 'p',
    185   0x05, 'l', 'o', 'c', 'a', 'l',
    186   0x00,
    187   0x00, 0x21,        // TYPE is SRV.
    188   0x00, 0x01,        // CLASS is IN.
    189 };
    190 
    191 
    192 class MockServiceWatcherClient {
    193  public:
    194   MOCK_METHOD2(OnServiceUpdated,
    195                void(ServiceWatcher::UpdateType, const std::string&));
    196 
    197   ServiceWatcher::UpdatedCallback GetCallback() {
    198     return base::Bind(&MockServiceWatcherClient::OnServiceUpdated,
    199                       base::Unretained(this));
    200   }
    201 };
    202 
    203 class ServiceDiscoveryTest : public ::testing::Test {
    204  public:
    205   ServiceDiscoveryTest()
    206       : service_discovery_client_(&mdns_client_) {
    207     mdns_client_.StartListening(&socket_factory_);
    208   }
    209 
    210   virtual ~ServiceDiscoveryTest() {
    211   }
    212 
    213  protected:
    214   void RunFor(base::TimeDelta time_period) {
    215     base::CancelableCallback<void()> callback(base::Bind(
    216         &ServiceDiscoveryTest::Stop, base::Unretained(this)));
    217     base::MessageLoop::current()->PostDelayedTask(
    218         FROM_HERE, callback.callback(), time_period);
    219 
    220     base::MessageLoop::current()->Run();
    221     callback.Cancel();
    222   }
    223 
    224   void Stop() {
    225     base::MessageLoop::current()->Quit();
    226   }
    227 
    228   net::MockMDnsSocketFactory socket_factory_;
    229   net::MDnsClientImpl mdns_client_;
    230   ServiceDiscoveryClientImpl service_discovery_client_;
    231   base::MessageLoop loop_;
    232 };
    233 
    234 TEST_F(ServiceDiscoveryTest, AddRemoveService) {
    235   StrictMock<MockServiceWatcherClient> delegate;
    236 
    237   scoped_ptr<ServiceWatcher> watcher(
    238       service_discovery_client_.CreateServiceWatcher(
    239           "_privet._tcp.local", delegate.GetCallback()));
    240 
    241   watcher->Start();
    242 
    243   EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
    244                                          "hello._privet._tcp.local"))
    245       .Times(Exactly(1));
    246 
    247   socket_factory_.SimulateReceive(kSamplePacketPTR, sizeof(kSamplePacketPTR));
    248 
    249   EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_REMOVED,
    250                                          "hello._privet._tcp.local"))
    251       .Times(Exactly(1));
    252 
    253   RunFor(base::TimeDelta::FromSeconds(2));
    254 };
    255 
    256 TEST_F(ServiceDiscoveryTest, DiscoverNewServices) {
    257   StrictMock<MockServiceWatcherClient> delegate;
    258 
    259   scoped_ptr<ServiceWatcher> watcher(
    260       service_discovery_client_.CreateServiceWatcher(
    261           "_privet._tcp.local", delegate.GetCallback()));
    262 
    263   watcher->Start();
    264 
    265   EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(2);
    266 
    267   watcher->DiscoverNewServices(false);
    268 
    269   EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(2);
    270 
    271   RunFor(base::TimeDelta::FromSeconds(2));
    272 };
    273 
    274 TEST_F(ServiceDiscoveryTest, ReadCachedServices) {
    275   socket_factory_.SimulateReceive(kSamplePacketPTR, sizeof(kSamplePacketPTR));
    276 
    277   StrictMock<MockServiceWatcherClient> delegate;
    278 
    279   scoped_ptr<ServiceWatcher> watcher(
    280       service_discovery_client_.CreateServiceWatcher(
    281           "_privet._tcp.local", delegate.GetCallback()));
    282 
    283   watcher->Start();
    284 
    285   EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
    286                                          "hello._privet._tcp.local"))
    287       .Times(Exactly(1));
    288 
    289   base::MessageLoop::current()->RunUntilIdle();
    290 };
    291 
    292 
    293 TEST_F(ServiceDiscoveryTest, ReadCachedServicesMultiple) {
    294   socket_factory_.SimulateReceive(kSamplePacketPTR2, sizeof(kSamplePacketPTR2));
    295 
    296   StrictMock<MockServiceWatcherClient> delegate;
    297   scoped_ptr<ServiceWatcher> watcher =
    298       service_discovery_client_.CreateServiceWatcher(
    299           "_privet._tcp.local", delegate.GetCallback());
    300 
    301   watcher->Start();
    302 
    303   EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
    304                                          "hello._privet._tcp.local"))
    305       .Times(Exactly(1));
    306 
    307   EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
    308                                          "gdbye._privet._tcp.local"))
    309       .Times(Exactly(1));
    310 
    311   base::MessageLoop::current()->RunUntilIdle();
    312 };
    313 
    314 
    315 TEST_F(ServiceDiscoveryTest, OnServiceChanged) {
    316   StrictMock<MockServiceWatcherClient> delegate;
    317   scoped_ptr<ServiceWatcher> watcher(
    318       service_discovery_client_.CreateServiceWatcher(
    319           "_privet._tcp.local", delegate.GetCallback()));
    320 
    321   watcher->Start();
    322 
    323   EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
    324                                          "hello._privet._tcp.local"))
    325       .Times(Exactly(1));
    326 
    327   socket_factory_.SimulateReceive(kSamplePacketPTR, sizeof(kSamplePacketPTR));
    328 
    329   base::MessageLoop::current()->RunUntilIdle();
    330 
    331   EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_CHANGED,
    332                                          "hello._privet._tcp.local"))
    333       .Times(Exactly(1));
    334 
    335   socket_factory_.SimulateReceive(kSamplePacketSRV, sizeof(kSamplePacketSRV));
    336 
    337   socket_factory_.SimulateReceive(kSamplePacketTXT, sizeof(kSamplePacketTXT));
    338 
    339   base::MessageLoop::current()->RunUntilIdle();
    340 };
    341 
    342 TEST_F(ServiceDiscoveryTest, SinglePacket) {
    343   StrictMock<MockServiceWatcherClient> delegate;
    344   scoped_ptr<ServiceWatcher> watcher(
    345       service_discovery_client_.CreateServiceWatcher(
    346           "_privet._tcp.local", delegate.GetCallback()));
    347 
    348   watcher->Start();
    349 
    350   EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
    351                                          "hello._privet._tcp.local"))
    352       .Times(Exactly(1));
    353 
    354   socket_factory_.SimulateReceive(kSamplePacketPTR, sizeof(kSamplePacketPTR));
    355 
    356   // Reset the "already updated" flag.
    357   base::MessageLoop::current()->RunUntilIdle();
    358 
    359   EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_CHANGED,
    360                                          "hello._privet._tcp.local"))
    361       .Times(Exactly(1));
    362 
    363   socket_factory_.SimulateReceive(kSamplePacketSRV, sizeof(kSamplePacketSRV));
    364 
    365   socket_factory_.SimulateReceive(kSamplePacketTXT, sizeof(kSamplePacketTXT));
    366 
    367   base::MessageLoop::current()->RunUntilIdle();
    368 };
    369 
    370 TEST_F(ServiceDiscoveryTest, ActivelyRefreshServices) {
    371   StrictMock<MockServiceWatcherClient> delegate;
    372   scoped_ptr<ServiceWatcher> watcher(
    373       service_discovery_client_.CreateServiceWatcher(
    374           "_privet._tcp.local", delegate.GetCallback()));
    375 
    376   watcher->Start();
    377   watcher->SetActivelyRefreshServices(true);
    378 
    379   EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
    380                                          "hello._privet._tcp.local"))
    381       .Times(Exactly(1));
    382 
    383   std::string query_packet = std::string((const char*)(kSamplePacketQuerySRV),
    384                                          sizeof(kSamplePacketQuerySRV));
    385 
    386   EXPECT_CALL(socket_factory_, OnSendTo(query_packet))
    387       .Times(2);
    388 
    389   socket_factory_.SimulateReceive(kSamplePacketPTR, sizeof(kSamplePacketPTR));
    390 
    391   base::MessageLoop::current()->RunUntilIdle();
    392 
    393   socket_factory_.SimulateReceive(kSamplePacketSRV, sizeof(kSamplePacketSRV));
    394 
    395   EXPECT_CALL(socket_factory_, OnSendTo(query_packet))
    396       .Times(4);  // IPv4 and IPv6 at 85% and 95%
    397 
    398   EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_REMOVED,
    399                                          "hello._privet._tcp.local"))
    400       .Times(Exactly(1));
    401 
    402   RunFor(base::TimeDelta::FromSeconds(2));
    403 
    404   base::MessageLoop::current()->RunUntilIdle();
    405 };
    406 
    407 
    408 class ServiceResolverTest : public ServiceDiscoveryTest {
    409  public:
    410   ServiceResolverTest() {
    411     metadata_expected_.push_back("hello");
    412     address_expected_ = net::HostPortPair("myhello.local", 8888);
    413     ip_address_expected_.push_back(1);
    414     ip_address_expected_.push_back(2);
    415     ip_address_expected_.push_back(3);
    416     ip_address_expected_.push_back(4);
    417   }
    418 
    419   ~ServiceResolverTest() {
    420   }
    421 
    422   void SetUp()  {
    423     resolver_ = service_discovery_client_.CreateServiceResolver(
    424                     "hello._privet._tcp.local",
    425                      base::Bind(&ServiceResolverTest::OnFinishedResolving,
    426                                 base::Unretained(this)));
    427   }
    428 
    429   void OnFinishedResolving(ServiceResolver::RequestStatus request_status,
    430                            const ServiceDescription& service_description) {
    431     OnFinishedResolvingInternal(request_status,
    432                                 service_description.address.ToString(),
    433                                 service_description.metadata,
    434                                 service_description.ip_address);
    435   }
    436 
    437   MOCK_METHOD4(OnFinishedResolvingInternal,
    438                void(ServiceResolver::RequestStatus,
    439                     const std::string&,
    440                     const std::vector<std::string>&,
    441                     const net::IPAddressNumber&));
    442 
    443  protected:
    444   scoped_ptr<ServiceResolver> resolver_;
    445   net::IPAddressNumber ip_address_;
    446   net::HostPortPair address_expected_;
    447   std::vector<std::string> metadata_expected_;
    448   net::IPAddressNumber ip_address_expected_;
    449 };
    450 
    451 TEST_F(ServiceResolverTest, TxtAndSrvButNoA) {
    452   EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(4);
    453 
    454   resolver_->StartResolving();
    455 
    456   socket_factory_.SimulateReceive(kSamplePacketSRV, sizeof(kSamplePacketSRV));
    457 
    458   base::MessageLoop::current()->RunUntilIdle();
    459 
    460   EXPECT_CALL(*this,
    461               OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS,
    462                                           address_expected_.ToString(),
    463                                           metadata_expected_,
    464                                           net::IPAddressNumber()));
    465 
    466   socket_factory_.SimulateReceive(kSamplePacketTXT, sizeof(kSamplePacketTXT));
    467 };
    468 
    469 TEST_F(ServiceResolverTest, TxtSrvAndA) {
    470   EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(4);
    471 
    472   resolver_->StartResolving();
    473 
    474   EXPECT_CALL(*this,
    475               OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS,
    476                                           address_expected_.ToString(),
    477                                           metadata_expected_,
    478                                           ip_address_expected_));
    479 
    480   socket_factory_.SimulateReceive(kSamplePacketTXT, sizeof(kSamplePacketTXT));
    481 
    482   socket_factory_.SimulateReceive(kSamplePacketSRVA, sizeof(kSamplePacketSRVA));
    483 };
    484 
    485 TEST_F(ServiceResolverTest, JustSrv) {
    486   EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(4);
    487 
    488   resolver_->StartResolving();
    489 
    490   EXPECT_CALL(*this,
    491               OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS,
    492                                           address_expected_.ToString(),
    493                                           std::vector<std::string>(),
    494                                           ip_address_expected_));
    495 
    496   socket_factory_.SimulateReceive(kSamplePacketSRVA, sizeof(kSamplePacketSRVA));
    497 
    498   // TODO(noamsml): When NSEC record support is added, change this to use an
    499   // NSEC record.
    500   RunFor(base::TimeDelta::FromSeconds(4));
    501 };
    502 
    503 TEST_F(ServiceResolverTest, WithNothing) {
    504   EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(4);
    505 
    506   resolver_->StartResolving();
    507 
    508   EXPECT_CALL(*this, OnFinishedResolvingInternal(
    509                          ServiceResolver::STATUS_REQUEST_TIMEOUT, _, _, _));
    510 
    511   // TODO(noamsml): When NSEC record support is added, change this to use an
    512   // NSEC record.
    513   RunFor(base::TimeDelta::FromSeconds(4));
    514 };
    515 
    516 }  // namespace
    517 
    518 }  // namespace local_discovery
    519