Home | History | Annotate | Download | only in local_discovery
      1 // Copyright 2013 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "base/memory/weak_ptr.h"
      6 #include "base/run_loop.h"
      7 #include "chrome/utility/local_discovery/service_discovery_client_impl.h"
      8 #include "net/base/net_errors.h"
      9 #include "net/dns/dns_protocol.h"
     10 #include "net/dns/mdns_client_impl.h"
     11 #include "net/dns/mock_mdns_socket_factory.h"
     12 #include "testing/gmock/include/gmock/gmock.h"
     13 #include "testing/gtest/include/gtest/gtest.h"
     14 
     15 using ::testing::_;
     16 using ::testing::Invoke;
     17 using ::testing::StrictMock;
     18 using ::testing::NiceMock;
     19 using ::testing::Mock;
     20 using ::testing::SaveArg;
     21 using ::testing::SetArgPointee;
     22 using ::testing::Return;
     23 using ::testing::Exactly;
     24 
     25 namespace local_discovery {
     26 
     27 namespace {
     28 
     29 const uint8 kSamplePacketPTR[] = {
     30   // Header
     31   0x00, 0x00,               // ID is zeroed out
     32   0x81, 0x80,               // Standard query response, RA, no error
     33   0x00, 0x00,               // No questions (for simplicity)
     34   0x00, 0x01,               // 1 RR (answers)
     35   0x00, 0x00,               // 0 authority RRs
     36   0x00, 0x00,               // 0 additional RRs
     37 
     38   0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
     39   0x04, '_', 't', 'c', 'p',
     40   0x05, 'l', 'o', 'c', 'a', 'l',
     41   0x00,
     42   0x00, 0x0c,        // TYPE is PTR.
     43   0x00, 0x01,        // CLASS is IN.
     44   0x00, 0x00,        // TTL (4 bytes) is 1 second.
     45   0x00, 0x01,
     46   0x00, 0x08,        // RDLENGTH is 8 bytes.
     47   0x05, 'h', 'e', 'l', 'l', 'o',
     48   0xc0, 0x0c
     49 };
     50 
     51 const uint8 kSamplePacketSRV[] = {
     52   // Header
     53   0x00, 0x00,               // ID is zeroed out
     54   0x81, 0x80,               // Standard query response, RA, no error
     55   0x00, 0x00,               // No questions (for simplicity)
     56   0x00, 0x01,               // 1 RR (answers)
     57   0x00, 0x00,               // 0 authority RRs
     58   0x00, 0x00,               // 0 additional RRs
     59 
     60   0x05, 'h', 'e', 'l', 'l', 'o',
     61   0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
     62   0x04, '_', 't', 'c', 'p',
     63   0x05, 'l', 'o', 'c', 'a', 'l',
     64   0x00,
     65   0x00, 0x21,        // TYPE is SRV.
     66   0x00, 0x01,        // CLASS is IN.
     67   0x00, 0x00,        // TTL (4 bytes) is 1 second.
     68   0x00, 0x01,
     69   0x00, 0x15,        // RDLENGTH is 21 bytes.
     70   0x00, 0x00,
     71   0x00, 0x00,
     72   0x22, 0xb8,  // port 8888
     73   0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o',
     74   0x05, 'l', 'o', 'c', 'a', 'l',
     75   0x00,
     76 };
     77 
     78 const uint8 kSamplePacketTXT[] = {
     79   // Header
     80   0x00, 0x00,               // ID is zeroed out
     81   0x81, 0x80,               // Standard query response, RA, no error
     82   0x00, 0x00,               // No questions (for simplicity)
     83   0x00, 0x01,               // 1 RR (answers)
     84   0x00, 0x00,               // 0 authority RRs
     85   0x00, 0x00,               // 0 additional RRs
     86 
     87   0x05, 'h', 'e', 'l', 'l', 'o',
     88   0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
     89   0x04, '_', 't', 'c', 'p',
     90   0x05, 'l', 'o', 'c', 'a', 'l',
     91   0x00,
     92   0x00, 0x10,        // TYPE is PTR.
     93   0x00, 0x01,        // CLASS is IN.
     94   0x00, 0x00,        // TTL (4 bytes) is 20 hours, 47 minutes, 48 seconds.
     95   0x00, 0x01,
     96   0x00, 0x06,        // RDLENGTH is 21 bytes.
     97   0x05, 'h', 'e', 'l', 'l', 'o'
     98 };
     99 
    100 const uint8 kSamplePacketSRVA[] = {
    101   // Header
    102   0x00, 0x00,               // ID is zeroed out
    103   0x81, 0x80,               // Standard query response, RA, no error
    104   0x00, 0x00,               // No questions (for simplicity)
    105   0x00, 0x02,               // 2 RR (answers)
    106   0x00, 0x00,               // 0 authority RRs
    107   0x00, 0x00,               // 0 additional RRs
    108 
    109   0x05, 'h', 'e', 'l', 'l', 'o',
    110   0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
    111   0x04, '_', 't', 'c', 'p',
    112   0x05, 'l', 'o', 'c', 'a', 'l',
    113   0x00,
    114   0x00, 0x21,        // TYPE is SRV.
    115   0x00, 0x01,        // CLASS is IN.
    116   0x00, 0x00,        // TTL (4 bytes) is 16 seconds.
    117   0x00, 0x10,
    118   0x00, 0x15,        // RDLENGTH is 21 bytes.
    119   0x00, 0x00,
    120   0x00, 0x00,
    121   0x22, 0xb8,  // port 8888
    122   0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o',
    123   0x05, 'l', 'o', 'c', 'a', 'l',
    124   0x00,
    125 
    126   0x07, 'm', 'y', 'h', 'e', 'l', 'l', 'o',
    127   0x05, 'l', 'o', 'c', 'a', 'l',
    128   0x00,
    129   0x00, 0x01,        // TYPE is A.
    130   0x00, 0x01,        // CLASS is IN.
    131   0x00, 0x00,        // TTL (4 bytes) is 16 seconds.
    132   0x00, 0x10,
    133   0x00, 0x04,        // RDLENGTH is 4 bytes.
    134   0x01, 0x02,
    135   0x03, 0x04,
    136 };
    137 
    138 const uint8 kSamplePacketPTR2[] = {
    139   // Header
    140   0x00, 0x00,               // ID is zeroed out
    141   0x81, 0x80,               // Standard query response, RA, no error
    142   0x00, 0x00,               // No questions (for simplicity)
    143   0x00, 0x02,               // 2 RR (answers)
    144   0x00, 0x00,               // 0 authority RRs
    145   0x00, 0x00,               // 0 additional RRs
    146 
    147   0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
    148   0x04, '_', 't', 'c', 'p',
    149   0x05, 'l', 'o', 'c', 'a', 'l',
    150   0x00,
    151   0x00, 0x0c,        // TYPE is PTR.
    152   0x00, 0x01,        // CLASS is IN.
    153   0x02, 0x00,        // TTL (4 bytes) is 1 second.
    154   0x00, 0x01,
    155   0x00, 0x08,        // RDLENGTH is 8 bytes.
    156   0x05, 'g', 'd', 'b', 'y', 'e',
    157   0xc0, 0x0c,
    158 
    159   0x07, '_', 'p', 'r', 'i', 'v', 'e', 't',
    160   0x04, '_', 't', 'c', 'p',
    161   0x05, 'l', 'o', 'c', 'a', 'l',
    162   0x00,
    163   0x00, 0x0c,        // TYPE is PTR.
    164   0x00, 0x01,        // CLASS is IN.
    165   0x02, 0x00,        // TTL (4 bytes) is 1 second.
    166   0x00, 0x01,
    167   0x00, 0x08,        // RDLENGTH is 8 bytes.
    168   0x05, 'h', 'e', 'l', 'l', 'o',
    169   0xc0, 0x0c
    170 };
    171 
    172 class MockServiceWatcherClient {
    173  public:
    174   MOCK_METHOD2(OnServiceUpdated,
    175                void(ServiceWatcher::UpdateType, const std::string&));
    176 
    177   ServiceWatcher::UpdatedCallback GetCallback() {
    178     return base::Bind(&MockServiceWatcherClient::OnServiceUpdated,
    179                       base::Unretained(this));
    180   }
    181 };
    182 
    183 class ServiceDiscoveryTest : public ::testing::Test {
    184  public:
    185   ServiceDiscoveryTest()
    186       : service_discovery_client_(&mdns_client_) {
    187     mdns_client_.StartListening(&socket_factory_);
    188   }
    189 
    190   virtual ~ServiceDiscoveryTest() {
    191   }
    192 
    193  protected:
    194   void RunFor(base::TimeDelta time_period) {
    195     base::CancelableCallback<void()> callback(base::Bind(
    196         &ServiceDiscoveryTest::Stop, base::Unretained(this)));
    197     base::MessageLoop::current()->PostDelayedTask(
    198         FROM_HERE, callback.callback(), time_period);
    199 
    200     base::MessageLoop::current()->Run();
    201     callback.Cancel();
    202   }
    203 
    204   void Stop() {
    205     base::MessageLoop::current()->Quit();
    206   }
    207 
    208   net::MockMDnsSocketFactory socket_factory_;
    209   net::MDnsClientImpl mdns_client_;
    210   ServiceDiscoveryClientImpl service_discovery_client_;
    211   base::MessageLoop loop_;
    212 };
    213 
    214 TEST_F(ServiceDiscoveryTest, AddRemoveService) {
    215   StrictMock<MockServiceWatcherClient> delegate;
    216 
    217   scoped_ptr<ServiceWatcher> watcher(
    218       service_discovery_client_.CreateServiceWatcher(
    219           "_privet._tcp.local", delegate.GetCallback()));
    220 
    221   watcher->Start();
    222 
    223   EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
    224                                          "hello._privet._tcp.local"))
    225       .Times(Exactly(1));
    226 
    227   socket_factory_.SimulateReceive(kSamplePacketPTR, sizeof(kSamplePacketPTR));
    228 
    229   EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_REMOVED,
    230                                          "hello._privet._tcp.local"))
    231       .Times(Exactly(1));
    232 
    233   RunFor(base::TimeDelta::FromSeconds(2));
    234 };
    235 
    236 TEST_F(ServiceDiscoveryTest, DiscoverNewServices) {
    237   StrictMock<MockServiceWatcherClient> delegate;
    238 
    239   scoped_ptr<ServiceWatcher> watcher(
    240       service_discovery_client_.CreateServiceWatcher(
    241           "_privet._tcp.local", delegate.GetCallback()));
    242 
    243   watcher->Start();
    244 
    245   EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(2);
    246 
    247   watcher->DiscoverNewServices(false);
    248 
    249   EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(2);
    250 
    251   RunFor(base::TimeDelta::FromSeconds(2));
    252 };
    253 
    254 TEST_F(ServiceDiscoveryTest, ReadCachedServices) {
    255   socket_factory_.SimulateReceive(kSamplePacketPTR, sizeof(kSamplePacketPTR));
    256 
    257   StrictMock<MockServiceWatcherClient> delegate;
    258 
    259   scoped_ptr<ServiceWatcher> watcher(
    260       service_discovery_client_.CreateServiceWatcher(
    261           "_privet._tcp.local", delegate.GetCallback()));
    262 
    263   watcher->Start();
    264 
    265   EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
    266                                          "hello._privet._tcp.local"))
    267       .Times(Exactly(1));
    268 
    269   base::MessageLoop::current()->RunUntilIdle();
    270 };
    271 
    272 
    273 TEST_F(ServiceDiscoveryTest, ReadCachedServicesMultiple) {
    274   socket_factory_.SimulateReceive(kSamplePacketPTR2, sizeof(kSamplePacketPTR2));
    275 
    276   StrictMock<MockServiceWatcherClient> delegate;
    277   scoped_ptr<ServiceWatcher> watcher =
    278       service_discovery_client_.CreateServiceWatcher(
    279           "_privet._tcp.local", delegate.GetCallback());
    280 
    281   watcher->Start();
    282 
    283   EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
    284                                          "hello._privet._tcp.local"))
    285       .Times(Exactly(1));
    286 
    287   EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
    288                                          "gdbye._privet._tcp.local"))
    289       .Times(Exactly(1));
    290 
    291   base::MessageLoop::current()->RunUntilIdle();
    292 };
    293 
    294 
    295 TEST_F(ServiceDiscoveryTest, OnServiceChanged) {
    296   StrictMock<MockServiceWatcherClient> delegate;
    297   scoped_ptr<ServiceWatcher> watcher(
    298       service_discovery_client_.CreateServiceWatcher(
    299           "_privet._tcp.local", delegate.GetCallback()));
    300 
    301   watcher->Start();
    302 
    303   EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
    304                                          "hello._privet._tcp.local"))
    305       .Times(Exactly(1));
    306 
    307   socket_factory_.SimulateReceive(kSamplePacketPTR, sizeof(kSamplePacketPTR));
    308 
    309   base::MessageLoop::current()->RunUntilIdle();
    310 
    311   EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_CHANGED,
    312                                          "hello._privet._tcp.local"))
    313       .Times(Exactly(1));
    314 
    315   socket_factory_.SimulateReceive(kSamplePacketSRV, sizeof(kSamplePacketSRV));
    316 
    317   socket_factory_.SimulateReceive(kSamplePacketTXT, sizeof(kSamplePacketTXT));
    318 
    319   base::MessageLoop::current()->RunUntilIdle();
    320 };
    321 
    322 TEST_F(ServiceDiscoveryTest, SinglePacket) {
    323   StrictMock<MockServiceWatcherClient> delegate;
    324   scoped_ptr<ServiceWatcher> watcher(
    325       service_discovery_client_.CreateServiceWatcher(
    326           "_privet._tcp.local", delegate.GetCallback()));
    327 
    328   watcher->Start();
    329 
    330   EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_ADDED,
    331                                          "hello._privet._tcp.local"))
    332       .Times(Exactly(1));
    333 
    334   socket_factory_.SimulateReceive(kSamplePacketPTR, sizeof(kSamplePacketPTR));
    335 
    336   // Reset the "already updated" flag.
    337   base::MessageLoop::current()->RunUntilIdle();
    338 
    339   EXPECT_CALL(delegate, OnServiceUpdated(ServiceWatcher::UPDATE_CHANGED,
    340                                          "hello._privet._tcp.local"))
    341       .Times(Exactly(1));
    342 
    343   socket_factory_.SimulateReceive(kSamplePacketSRV, sizeof(kSamplePacketSRV));
    344 
    345   socket_factory_.SimulateReceive(kSamplePacketTXT, sizeof(kSamplePacketTXT));
    346 
    347   base::MessageLoop::current()->RunUntilIdle();
    348 };
    349 
    350 class ServiceResolverTest : public ServiceDiscoveryTest {
    351  public:
    352   ServiceResolverTest() {
    353     metadata_expected_.push_back("hello");
    354     address_expected_ = net::HostPortPair("myhello.local", 8888);
    355     ip_address_expected_.push_back(1);
    356     ip_address_expected_.push_back(2);
    357     ip_address_expected_.push_back(3);
    358     ip_address_expected_.push_back(4);
    359   }
    360 
    361   ~ServiceResolverTest() {
    362   }
    363 
    364   void SetUp()  {
    365     resolver_ = service_discovery_client_.CreateServiceResolver(
    366                     "hello._privet._tcp.local",
    367                      base::Bind(&ServiceResolverTest::OnFinishedResolving,
    368                                 base::Unretained(this)));
    369   }
    370 
    371   void OnFinishedResolving(ServiceResolver::RequestStatus request_status,
    372                            const ServiceDescription& service_description) {
    373     OnFinishedResolvingInternal(request_status,
    374                                 service_description.address.ToString(),
    375                                 service_description.metadata,
    376                                 service_description.ip_address);
    377   }
    378 
    379   MOCK_METHOD4(OnFinishedResolvingInternal,
    380                void(ServiceResolver::RequestStatus,
    381                     const std::string&,
    382                     const std::vector<std::string>&,
    383                     const net::IPAddressNumber&));
    384 
    385  protected:
    386   scoped_ptr<ServiceResolver> resolver_;
    387   net::IPAddressNumber ip_address_;
    388   net::HostPortPair address_expected_;
    389   std::vector<std::string> metadata_expected_;
    390   net::IPAddressNumber ip_address_expected_;
    391 };
    392 
    393 TEST_F(ServiceResolverTest, TxtAndSrvButNoA) {
    394   EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(4);
    395 
    396   resolver_->StartResolving();
    397 
    398   socket_factory_.SimulateReceive(kSamplePacketSRV, sizeof(kSamplePacketSRV));
    399 
    400   base::MessageLoop::current()->RunUntilIdle();
    401 
    402   EXPECT_CALL(*this,
    403               OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS,
    404                                           address_expected_.ToString(),
    405                                           metadata_expected_,
    406                                           net::IPAddressNumber()));
    407 
    408   socket_factory_.SimulateReceive(kSamplePacketTXT, sizeof(kSamplePacketTXT));
    409 };
    410 
    411 TEST_F(ServiceResolverTest, TxtSrvAndA) {
    412   EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(4);
    413 
    414   resolver_->StartResolving();
    415 
    416   EXPECT_CALL(*this,
    417               OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS,
    418                                           address_expected_.ToString(),
    419                                           metadata_expected_,
    420                                           ip_address_expected_));
    421 
    422   socket_factory_.SimulateReceive(kSamplePacketTXT, sizeof(kSamplePacketTXT));
    423 
    424   socket_factory_.SimulateReceive(kSamplePacketSRVA, sizeof(kSamplePacketSRVA));
    425 };
    426 
    427 TEST_F(ServiceResolverTest, JustSrv) {
    428   EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(4);
    429 
    430   resolver_->StartResolving();
    431 
    432   EXPECT_CALL(*this,
    433               OnFinishedResolvingInternal(ServiceResolver::STATUS_SUCCESS,
    434                                           address_expected_.ToString(),
    435                                           std::vector<std::string>(),
    436                                           ip_address_expected_));
    437 
    438   socket_factory_.SimulateReceive(kSamplePacketSRVA, sizeof(kSamplePacketSRVA));
    439 
    440   // TODO(noamsml): When NSEC record support is added, change this to use an
    441   // NSEC record.
    442   RunFor(base::TimeDelta::FromSeconds(4));
    443 };
    444 
    445 TEST_F(ServiceResolverTest, WithNothing) {
    446   EXPECT_CALL(socket_factory_, OnSendTo(_)).Times(4);
    447 
    448   resolver_->StartResolving();
    449 
    450   EXPECT_CALL(*this, OnFinishedResolvingInternal(
    451                          ServiceResolver::STATUS_REQUEST_TIMEOUT, _, _, _));
    452 
    453   // TODO(noamsml): When NSEC record support is added, change this to use an
    454   // NSEC record.
    455   RunFor(base::TimeDelta::FromSeconds(4));
    456 };
    457 
    458 }  // namespace
    459 
    460 }  // namespace local_discovery
    461