Home | History | Annotate | Download | only in local_discovery
      1 // Copyright 2013 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "base/memory/weak_ptr.h"
      6 #include "base/run_loop.h"
      7 #include "chrome/utility/local_discovery/service_discovery_client_impl.h"
      8 #include "net/base/net_errors.h"
      9 #include "net/dns/dns_protocol.h"
     10 #include "net/dns/mdns_client_impl.h"
     11 #include "net/dns/mock_mdns_socket_factory.h"
     12 #include "testing/gmock/include/gmock/gmock.h"
     13 #include "testing/gtest/include/gtest/gtest.h"
     14 
     15 using ::testing::_;
     16 using ::testing::Invoke;
     17 using ::testing::StrictMock;
     18 using ::testing::NiceMock;
     19 using ::testing::Mock;
     20 using ::testing::SaveArg;
     21 using ::testing::SetArgPointee;
     22 using ::testing::Return;
     23 using ::testing::Exactly;
     24 
     25 namespace local_discovery {
     26 
     27 namespace {
     28 
     29 const uint8 kSamplePacketPTR[] = {
     30   // Header
     31   0x00, 0x00,               // ID is zeroed out
     32   0x81, 0x80,               // Standard query response, RA, no error
     33   0x00, 0x00,               // No questions (for simplicity)
     34   0x00, 0x01,               // 1 RR (answers)
     35   0x00, 0x00,               // 0 authority RRs
     36   0x00, 0x00,               // 0 additional RRs
     37 
     38   0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
     39   0x04, '_', 't', 'c', 'p',
     40   0x05, 'l', 'o', 'c', 'a', 'l',
     41   0x00,
     42   0x00, 0x0c,        // TYPE is PTR.
     43   0x00, 0x01,        // CLASS is IN.
     44   0x00, 0x00,        // TTL (4 bytes) is 1 second.
     45   0x00, 0x01,
     46   0x00, 0x08,        // RDLENGTH is 8 bytes.
     47   0x05, 'h', 'e', 'l', 'l', 'o',
     48   0xc0, 0x0c
     49 };
     50 
     51 const uint8 kSamplePacketSRV[] = {
     52   // Header
     53   0x00, 0x00,               // ID is zeroed out
     54   0x81, 0x80,               // Standard query response, RA, no error
     55   0x00, 0x00,               // No questions (for simplicity)
     56   0x00, 0x01,               // 1 RR (answers)
     57   0x00, 0x00,               // 0 authority RRs
     58   0x00, 0x00,               // 0 additional RRs
     59 
     60   0x05, 'h', 'e', 'l', 'l', 'o',
     61   0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
     62   0x04, '_', 't', 'c', 'p',
     63   0x05, 'l', 'o', 'c', 'a', 'l',
     64   0x00,
     65   0x00, 0x21,        // TYPE is SRV.
     66   0x00, 0x01,        // CLASS is IN.
     67   0x00, 0x00,        // TTL (4 bytes) is 1 second.
     68   0x00, 0x01,
     69   0x00, 0x15,        // RDLENGTH is 21 bytes.
     70   0x00, 0x00,
     71   0x00, 0x00,
     72   0x22, 0xb8,  // port 8888
     73   0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o',
     74   0x05, 'l', 'o', 'c', 'a', 'l',
     75   0x00,
     76 };
     77 
     78 const uint8 kSamplePacketTXT[] = {
     79   // Header
     80   0x00, 0x00,               // ID is zeroed out
     81   0x81, 0x80,               // Standard query response, RA, no error
     82   0x00, 0x00,               // No questions (for simplicity)
     83   0x00, 0x01,               // 1 RR (answers)
     84   0x00, 0x00,               // 0 authority RRs
     85   0x00, 0x00,               // 0 additional RRs
     86 
     87   0x05, 'h', 'e', 'l', 'l', 'o',
     88   0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
     89   0x04, '_', 't', 'c', 'p',
     90   0x05, 'l', 'o', 'c', 'a', 'l',
     91   0x00,
     92   0x00, 0x10,        // TYPE is PTR.
     93   0x00, 0x01,        // CLASS is IN.
     94   0x00, 0x00,        // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
     95   0x00, 0x01,
     96   0x00, 0x06,        // RDLENGTH is 21 bytes.
     97   0x05, 'h', 'e', 'l', 'l', 'o'
     98 };
     99 
    100 const uint8 kSamplePacketSRVA[] = {
    101   // Header
    102   0x00, 0x00,               // ID is zeroed out
    103   0x81, 0x80,               // Standard query response, RA, no error
    104   0x00, 0x00,               // No questions (for simplicity)
    105   0x00, 0x02,               // 2 RR (answers)
    106   0x00, 0x00,               // 0 authority RRs
    107   0x00, 0x00,               // 0 additional RRs
    108 
    109   0x05, 'h', 'e', 'l', 'l', 'o',
    110   0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
    111   0x04, '_', 't', 'c', 'p',
    112   0x05, 'l', 'o', 'c', 'a', 'l',
    113   0x00,
    114   0x00, 0x21,        // TYPE is SRV.
    115   0x00, 0x01,        // CLASS is IN.
    116   0x00, 0x00,        // TTL (4 bytes) is 16 seconds.
    117   0x00, 0x10,
    118   0x00, 0x15,        // RDLENGTH is 21 bytes.
    119   0x00, 0x00,
    120   0x00, 0x00,
    121   0x22, 0xb8,  // port 8888
    122   0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o',
    123   0x05, 'l', 'o', 'c', 'a', 'l',
    124   0x00,
    125 
    126   0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o',
    127   0x05, 'l', 'o', 'c', 'a', 'l',
    128   0x00,
    129   0x00, 0x01,        // TYPE is A.
    130   0x00, 0x01,        // CLASS is IN.
    131   0x00, 0x00,        // TTL (4 bytes) is 16 seconds.
    132   0x00, 0x10,
    133   0x00, 0x04,        // RDLENGTH is 4 bytes.
    134   0x01, 0x02,
    135   0x03, 0x04,
    136 };
    137 
    138 const uint8 kSamplePacketPTR2[] = {
    139   // Header
    140   0x00, 0x00,               // ID is zeroed out
    141   0x81, 0x80,               // Standard query response, RA, no error
    142   0x00, 0x00,               // No questions (for simplicity)
    143   0x00, 0x02,               // 2 RR (answers)
    144   0x00, 0x00,               // 0 authority RRs
    145   0x00, 0x00,               // 0 additional RRs
    146 
    147   0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
    148   0x04, '_', 't', 'c', 'p',
    149   0x05, 'l', 'o', 'c', 'a', 'l',
    150   0x00,
    151   0x00, 0x0c,        // TYPE is PTR.
    152   0x00, 0x01,        // CLASS is IN.
    153   0x02, 0x00,        // TTL (4 bytes) is 1 second.
    154   0x00, 0x01,
    155   0x00, 0x08,        // RDLENGTH is 8 bytes.
    156   0x05, 'g', 'd', 'b', 'y', 'e',
    157   0xc0, 0x0c,
    158 
    159   0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
    160   0x04, '_', 't', 'c', 'p',
    161   0x05, 'l', 'o', 'c', 'a', 'l',
    162   0x00,
    163   0x00, 0x0c,        // TYPE is PTR.
    164   0x00, 0x01,        // CLASS is IN.
    165   0x02, 0x00,        // TTL (4 bytes) is 1 second.
    166   0x00, 0x01,
    167   0x00, 0x08,        // RDLENGTH is 8 bytes.
    168   0x05, 'h', 'e', 'l', 'l', 'o',
    169   0xc0, 0x0c
    170 };
    171 
    172 class MockServiceWatcherClient {
    173  public:
    174   MOCK_METHOD2(OnServiceUpdated,
    175                void(ServiceWatcher::UpdateType, const std::string&));
    176 
    177   ServiceWatcher::UpdatedCallback GetCallback() {
    178     return base::Bind(&MockServiceWatcherClient::OnServiceUpdated,
    179                       base::Unretained(this));
    180   }
    181 };
    182 
    183 class ServiceDiscoveryTest : public ::testing::Test {
    184  public:
    185   ServiceDiscoveryTest()
    186       : socket_factory_(new net::MockMDnsSocketFactory),
    187         mdns_client_(
    188             scoped_ptr<net::MDnsConnection::SocketFactory>(
    189                 socket_factory_)),
    190         service_discovery_client_(&mdns_client_) {
    191     mdns_client_.StartListening();
    192   }
    193 
    194   virtual ~ServiceDiscoveryTest() {
    195   }
    196 
    197  protected:
    198   void RunFor(base::TimeDelta time_period) {
    199     base::CancelableCallback<void()> callback(base::Bind(
    200         &ServiceDiscoveryTest::Stop, base::Unretained(this)));
    201     base::MessageLoop::current()->PostDelayedTask(
    202         FROM_HERE, callback.callback(), time_period);
    203 
    204     base::MessageLoop::current()->Run();
    205     callback.Cancel();
    206   }
    207 
    208   void Stop() {
    209     base::MessageLoop::current()->Quit();
    210   }
    211 
    212   net::MockMDnsSocketFactory* socket_factory_;
    213   net::MDnsClientImpl mdns_client_;
    214   ServiceDiscoveryClientImpl service_discovery_client_;
    215   base::MessageLoop loop_;
    216 };
    217 
    218 TEST_F(ServiceDiscoveryTest, AddRemoveService) {
    219   StrictMock<MockServiceWatcherClient> delegate;
    220 
    221   scoped_ptr<ServiceWatcher> watcher(
    222       service_discovery_client_.CreateServiceWatcher(
    223           "_privet._tcp.local", delegate.GetCallback()));
    224 
    225   watcher->Start();
    226 
    227   EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
    228                                          "hello._privet._tcp.local"))
    229       .Times(Exactly(1));
    230 
    231   socket_factory_->SimulateReceive(
    232       kSamplePacketPTR, sizeof(kSamplePacketPTR));
    233 
    234   EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_REMOVED,
    235                                          "hello._privet._tcp.local"))
    236       .Times(Exactly(1));
    237 
    238   RunFor(base::TimeDelta::FromSeconds(2));
    239 };
    240 
    241 TEST_F(ServiceDiscoveryTest, DiscoverNewServices) {
    242   StrictMock<MockServiceWatcherClient> delegate;
    243 
    244   scoped_ptr<ServiceWatcher> watcher(
    245       service_discovery_client_.CreateServiceWatcher(
    246           "_privet._tcp.local", delegate.GetCallback()));
    247 
    248   watcher->Start();
    249 
    250   EXPECT_CALL(*socket_factory_, OnSendTo(_))
    251       .Times(2);
    252 
    253   watcher->DiscoverNewServices(false);
    254 };
    255 
    256 TEST_F(ServiceDiscoveryTest, ReadCachedServices) {
    257   socket_factory_->SimulateReceive(
    258       kSamplePacketPTR, sizeof(kSamplePacketPTR));
    259 
    260   StrictMock<MockServiceWatcherClient> delegate;
    261 
    262   scoped_ptr<ServiceWatcher> watcher(
    263       service_discovery_client_.CreateServiceWatcher(
    264           "_privet._tcp.local", delegate.GetCallback()));
    265 
    266   watcher->Start();
    267 
    268   EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
    269                                          "hello._privet._tcp.local"))
    270       .Times(Exactly(1));
    271 
    272   base::MessageLoop::current()->RunUntilIdle();
    273 };
    274 
    275 
    276 TEST_F(ServiceDiscoveryTest, ReadCachedServicesMultiple) {
    277   socket_factory_->SimulateReceive(
    278       kSamplePacketPTR2, sizeof(kSamplePacketPTR2));
    279 
    280   StrictMock<MockServiceWatcherClient> delegate;
    281   scoped_ptr<ServiceWatcher> watcher =
    282       service_discovery_client_.CreateServiceWatcher(
    283           "_privet._tcp.local", delegate.GetCallback());
    284 
    285   watcher->Start();
    286 
    287   EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
    288                                          "hello._privet._tcp.local"))
    289       .Times(Exactly(1));
    290 
    291   EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
    292                                          "gdbye._privet._tcp.local"))
    293       .Times(Exactly(1));
    294 
    295   base::MessageLoop::current()->RunUntilIdle();
    296 };
    297 
    298 
    299 TEST_F(ServiceDiscoveryTest, OnServiceChanged) {
    300   StrictMock<MockServiceWatcherClient> delegate;
    301   scoped_ptr<ServiceWatcher> watcher(
    302       service_discovery_client_.CreateServiceWatcher(
    303           "_privet._tcp.local", delegate.GetCallback()));
    304 
    305   watcher->Start();
    306 
    307   EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
    308                                          "hello._privet._tcp.local"))
    309       .Times(Exactly(1));
    310 
    311   socket_factory_->SimulateReceive(
    312       kSamplePacketPTR, sizeof(kSamplePacketPTR));
    313 
    314   base::MessageLoop::current()->RunUntilIdle();
    315 
    316   EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_CHANGED,
    317                                          "hello._privet._tcp.local"))
    318       .Times(Exactly(1));
    319 
    320   socket_factory_->SimulateReceive(
    321       kSamplePacketSRV, sizeof(kSamplePacketSRV));
    322 
    323   socket_factory_->SimulateReceive(
    324       kSamplePacketTXT, sizeof(kSamplePacketTXT));
    325 
    326   base::MessageLoop::current()->RunUntilIdle();
    327 };
    328 
    329 TEST_F(ServiceDiscoveryTest, SinglePacket) {
    330   StrictMock<MockServiceWatcherClient> delegate;
    331   scoped_ptr<ServiceWatcher> watcher(
    332       service_discovery_client_.CreateServiceWatcher(
    333           "_privet._tcp.local", delegate.GetCallback()));
    334 
    335   watcher->Start();
    336 
    337   EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
    338                                          "hello._privet._tcp.local"))
    339       .Times(Exactly(1));
    340 
    341   socket_factory_->SimulateReceive(
    342       kSamplePacketPTR, sizeof(kSamplePacketPTR));
    343 
    344   // Reset the "already updated" flag.
    345   base::MessageLoop::current()->RunUntilIdle();
    346 
    347   EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_CHANGED,
    348                                          "hello._privet._tcp.local"))
    349       .Times(Exactly(1));
    350 
    351   socket_factory_->SimulateReceive(
    352       kSamplePacketSRV, sizeof(kSamplePacketSRV));
    353 
    354   socket_factory_->SimulateReceive(
    355       kSamplePacketTXT, sizeof(kSamplePacketTXT));
    356 
    357   base::MessageLoop::current()->RunUntilIdle();
    358 };
    359 
    360 class ServiceResolverTest : public ServiceDiscoveryTest {
    361  public:
    362   ServiceResolverTest() {
    363     metadata_expected_.push_back("hello");
    364     address_expected_ = net::HostPortPair("myhello.local", 8888);
    365     ip_address_expected_.push_back(1);
    366     ip_address_expected_.push_back(2);
    367     ip_address_expected_.push_back(3);
    368     ip_address_expected_.push_back(4);
    369   }
    370 
    371   ~ServiceResolverTest() {
    372   }
    373 
    374   void SetUp()  {
    375     resolver_ = service_discovery_client_.CreateServiceResolver(
    376                     "hello._privet._tcp.local",
    377                      base::Bind(&ServiceResolverTest::OnFinishedResolving,
    378                                 base::Unretained(this)));
    379   }
    380 
    381   void OnFinishedResolving(ServiceResolver::RequestStatus request_status,
    382                            const ServiceDescription& service_description) {
    383     OnFinishedResolvingInternal(request_status,
    384                                 service_description.address.ToString(),
    385                                 service_description.metadata,
    386                                 service_description.ip_address);
    387   }
    388 
    389   MOCK_METHOD4(OnFinishedResolvingInternal,
    390                void(ServiceResolver::RequestStatus,
    391                     const std::string&,
    392                     const std::vector<std::string>&,
    393                     const net::IPAddressNumber&));
    394 
    395  protected:
    396   scoped_ptr<ServiceResolver> resolver_;
    397   net::IPAddressNumber ip_address_;
    398   net::HostPortPair address_expected_;
    399   std::vector<std::string> metadata_expected_;
    400   net::IPAddressNumber ip_address_expected_;
    401 };
    402 
    403 TEST_F(ServiceResolverTest, TxtAndSrvButNoA) {
    404   EXPECT_CALL(*socket_factory_, OnSendTo(_))
    405       .Times(4);
    406 
    407   resolver_->StartResolving();
    408 
    409   socket_factory_->SimulateReceive(
    410       kSamplePacketSRV, sizeof(kSamplePacketSRV));
    411 
    412   base::MessageLoop::current()->RunUntilIdle();
    413 
    414   EXPECT_CALL(*this,
    415               OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS,
    416                                           address_expected_.ToString(),
    417                                           metadata_expected_,
    418                                           net::IPAddressNumber()));
    419 
    420   socket_factory_->SimulateReceive(
    421       kSamplePacketTXT, sizeof(kSamplePacketTXT));
    422 };
    423 
    424 TEST_F(ServiceResolverTest, TxtSrvAndA) {
    425   EXPECT_CALL(*socket_factory_, OnSendTo(_))
    426       .Times(4);
    427 
    428   resolver_->StartResolving();
    429 
    430   EXPECT_CALL(*this,
    431               OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS,
    432                                           address_expected_.ToString(),
    433                                           metadata_expected_,
    434                                           ip_address_expected_));
    435 
    436   socket_factory_->SimulateReceive(
    437       kSamplePacketTXT, sizeof(kSamplePacketTXT));
    438 
    439   socket_factory_->SimulateReceive(
    440       kSamplePacketSRVA, sizeof(kSamplePacketSRVA));
    441 };
    442 
    443 TEST_F(ServiceResolverTest, JustSrv) {
    444   EXPECT_CALL(*socket_factory_, OnSendTo(_))
    445       .Times(4);
    446 
    447   resolver_->StartResolving();
    448 
    449   EXPECT_CALL(*this,
    450               OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS,
    451                                           address_expected_.ToString(),
    452                                           std::vector<std::string>(),
    453                                           ip_address_expected_));
    454 
    455   socket_factory_->SimulateReceive(
    456       kSamplePacketSRVA, sizeof(kSamplePacketSRVA));
    457 
    458   // TODO(noamsml): When NSEC record support is added, change this to use an
    459   // NSEC record.
    460   RunFor(base::TimeDelta::FromSeconds(4));
    461 };
    462 
    463 TEST_F(ServiceResolverTest, WithNothing) {
    464   EXPECT_CALL(*socket_factory_, OnSendTo(_))
    465       .Times(4);
    466 
    467   resolver_->StartResolving();
    468 
    469   EXPECT_CALL(*this, OnFinishedResolvingInternal(
    470                         ServiceResolver::STATUS_REQUEST_TIMEOUT, _, _, _));
    471 
    472   // TODO(noamsml): When NSEC record support is added, change this to use an
    473   // NSEC record.
    474   RunFor(base::TimeDelta::FromSeconds(4));
    475 };
    476 
    477 }  // namespace
    478 
    479 }  // namespace local_discovery
    480