Home | History | Annotate | Download | only in protocol
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "remoting/protocol/pairing_registry.h"
      6 
      7 #include <stdlib.h>
      8 
      9 #include <algorithm>
     10 
     11 #include "base/bind.h"
     12 #include "base/compiler_specific.h"
     13 #include "base/memory/scoped_ptr.h"
     14 #include "base/message_loop/message_loop.h"
     15 #include "base/run_loop.h"
     16 #include "base/thread_task_runner_handle.h"
     17 #include "base/values.h"
     18 #include "remoting/protocol/protocol_mock_objects.h"
     19 #include "testing/gmock/include/gmock/gmock.h"
     20 #include "testing/gtest/include/gtest/gtest.h"
     21 
     22 using testing::Sequence;
     23 
     24 namespace {
     25 
     26 using remoting::protocol::PairingRegistry;
     27 
     28 class MockPairingRegistryCallbacks {
     29  public:
     30   MockPairingRegistryCallbacks() {}
     31   virtual ~MockPairingRegistryCallbacks() {}
     32 
     33   MOCK_METHOD1(DoneCallback, void(bool));
     34   MOCK_METHOD1(GetAllPairingsCallbackPtr, void(base::ListValue*));
     35   MOCK_METHOD1(GetPairingCallback, void(PairingRegistry::Pairing));
     36 
     37   void GetAllPairingsCallback(scoped_ptr<base::ListValue> pairings) {
     38     GetAllPairingsCallbackPtr(pairings.get());
     39   }
     40 
     41  private:
     42   DISALLOW_COPY_AND_ASSIGN(MockPairingRegistryCallbacks);
     43 };
     44 
     45 // Verify that a pairing Dictionary has correct entries, but doesn't include
     46 // any shared secret.
     47 void VerifyPairing(PairingRegistry::Pairing expected,
     48                    const base::DictionaryValue& actual) {
     49   std::string value;
     50   EXPECT_TRUE(actual.GetString(PairingRegistry::kClientNameKey, &value));
     51   EXPECT_EQ(expected.client_name(), value);
     52   EXPECT_TRUE(actual.GetString(PairingRegistry::kClientIdKey, &value));
     53   EXPECT_EQ(expected.client_id(), value);
     54 
     55   EXPECT_FALSE(actual.HasKey(PairingRegistry::kSharedSecretKey));
     56 }
     57 
     58 }  // namespace
     59 
     60 namespace remoting {
     61 namespace protocol {
     62 
     63 class PairingRegistryTest : public testing::Test {
     64  public:
     65   virtual void SetUp() OVERRIDE {
     66     callback_count_ = 0;
     67   }
     68 
     69   void set_pairings(scoped_ptr<base::ListValue> pairings) {
     70     pairings_ = pairings.Pass();
     71   }
     72 
     73   void ExpectSecret(const std::string& expected,
     74                     PairingRegistry::Pairing actual) {
     75     EXPECT_EQ(expected, actual.shared_secret());
     76     ++callback_count_;
     77   }
     78 
     79   void ExpectSaveSuccess(bool success) {
     80     EXPECT_TRUE(success);
     81     ++callback_count_;
     82   }
     83 
     84  protected:
     85   base::MessageLoop message_loop_;
     86   base::RunLoop run_loop_;
     87 
     88   int callback_count_;
     89   scoped_ptr<base::ListValue> pairings_;
     90 };
     91 
     92 TEST_F(PairingRegistryTest, CreateAndGetPairings) {
     93   scoped_refptr<PairingRegistry> registry = new SynchronousPairingRegistry(
     94       scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate()));
     95   PairingRegistry::Pairing pairing_1 = registry->CreatePairing("my_client");
     96   PairingRegistry::Pairing pairing_2 = registry->CreatePairing("my_client");
     97 
     98   EXPECT_NE(pairing_1.shared_secret(), pairing_2.shared_secret());
     99 
    100   registry->GetPairing(pairing_1.client_id(),
    101                        base::Bind(&PairingRegistryTest::ExpectSecret,
    102                                   base::Unretained(this),
    103                                   pairing_1.shared_secret()));
    104   EXPECT_EQ(1, callback_count_);
    105 
    106   // Check that the second client is paired with a different shared secret.
    107   registry->GetPairing(pairing_2.client_id(),
    108                        base::Bind(&PairingRegistryTest::ExpectSecret,
    109                                   base::Unretained(this),
    110                                   pairing_2.shared_secret()));
    111   EXPECT_EQ(2, callback_count_);
    112 }
    113 
    114 TEST_F(PairingRegistryTest, GetAllPairings) {
    115   scoped_refptr<PairingRegistry> registry = new SynchronousPairingRegistry(
    116       scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate()));
    117   PairingRegistry::Pairing pairing_1 = registry->CreatePairing("client1");
    118   PairingRegistry::Pairing pairing_2 = registry->CreatePairing("client2");
    119 
    120   registry->GetAllPairings(
    121       base::Bind(&PairingRegistryTest::set_pairings,
    122                  base::Unretained(this)));
    123 
    124   ASSERT_EQ(2u, pairings_->GetSize());
    125   const base::DictionaryValue* actual_pairing_1;
    126   const base::DictionaryValue* actual_pairing_2;
    127   ASSERT_TRUE(pairings_->GetDictionary(0, &actual_pairing_1));
    128   ASSERT_TRUE(pairings_->GetDictionary(1, &actual_pairing_2));
    129 
    130   // Ordering is not guaranteed, so swap if necessary.
    131   std::string actual_client_id;
    132   ASSERT_TRUE(actual_pairing_1->GetString(PairingRegistry::kClientIdKey,
    133                                           &actual_client_id));
    134   if (actual_client_id != pairing_1.client_id()) {
    135     std::swap(actual_pairing_1, actual_pairing_2);
    136   }
    137 
    138   VerifyPairing(pairing_1, *actual_pairing_1);
    139   VerifyPairing(pairing_2, *actual_pairing_2);
    140 }
    141 
    142 TEST_F(PairingRegistryTest, DeletePairing) {
    143   scoped_refptr<PairingRegistry> registry = new SynchronousPairingRegistry(
    144       scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate()));
    145   PairingRegistry::Pairing pairing_1 = registry->CreatePairing("client1");
    146   PairingRegistry::Pairing pairing_2 = registry->CreatePairing("client2");
    147 
    148   registry->DeletePairing(
    149       pairing_1.client_id(),
    150       base::Bind(&PairingRegistryTest::ExpectSaveSuccess,
    151                  base::Unretained(this)));
    152 
    153   // Re-read the list, and verify it only has the pairing_2 client.
    154   registry->GetAllPairings(
    155       base::Bind(&PairingRegistryTest::set_pairings,
    156                  base::Unretained(this)));
    157 
    158   ASSERT_EQ(1u, pairings_->GetSize());
    159   const base::DictionaryValue* actual_pairing_2;
    160   ASSERT_TRUE(pairings_->GetDictionary(0, &actual_pairing_2));
    161   std::string actual_client_id;
    162   ASSERT_TRUE(actual_pairing_2->GetString(PairingRegistry::kClientIdKey,
    163                                           &actual_client_id));
    164   EXPECT_EQ(pairing_2.client_id(), actual_client_id);
    165 }
    166 
    167 TEST_F(PairingRegistryTest, ClearAllPairings) {
    168   scoped_refptr<PairingRegistry> registry = new SynchronousPairingRegistry(
    169       scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate()));
    170   PairingRegistry::Pairing pairing_1 = registry->CreatePairing("client1");
    171   PairingRegistry::Pairing pairing_2 = registry->CreatePairing("client2");
    172 
    173   registry->ClearAllPairings(
    174       base::Bind(&PairingRegistryTest::ExpectSaveSuccess,
    175                  base::Unretained(this)));
    176 
    177   // Re-read the list, and verify it is empty.
    178   registry->GetAllPairings(
    179       base::Bind(&PairingRegistryTest::set_pairings,
    180                  base::Unretained(this)));
    181 
    182   EXPECT_TRUE(pairings_->empty());
    183 }
    184 
    185 ACTION_P(QuitMessageLoop, callback) {
    186   callback.Run();
    187 }
    188 
    189 MATCHER_P(EqualsClientName, client_name, "") {
    190   return arg.client_name() == client_name;
    191 }
    192 
    193 MATCHER(NoPairings, "") {
    194   return arg->empty();
    195 }
    196 
    197 TEST_F(PairingRegistryTest, SerializedRequests) {
    198   MockPairingRegistryCallbacks callbacks;
    199   Sequence s;
    200   EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName("client1")))
    201       .InSequence(s);
    202   EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName("client2")))
    203       .InSequence(s);
    204   EXPECT_CALL(callbacks, DoneCallback(true))
    205       .InSequence(s);
    206   EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName("client1")))
    207       .InSequence(s);
    208   EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName("")))
    209       .InSequence(s);
    210   EXPECT_CALL(callbacks, DoneCallback(true))
    211       .InSequence(s);
    212   EXPECT_CALL(callbacks, GetAllPairingsCallbackPtr(NoPairings()))
    213       .InSequence(s);
    214   EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName("client3")))
    215       .InSequence(s)
    216       .WillOnce(QuitMessageLoop(run_loop_.QuitClosure()));
    217 
    218   scoped_refptr<PairingRegistry> registry = new PairingRegistry(
    219       base::ThreadTaskRunnerHandle::Get(),
    220       scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate()));
    221   PairingRegistry::Pairing pairing_1 = registry->CreatePairing("client1");
    222   PairingRegistry::Pairing pairing_2 = registry->CreatePairing("client2");
    223   registry->GetPairing(
    224       pairing_1.client_id(),
    225       base::Bind(&MockPairingRegistryCallbacks::GetPairingCallback,
    226                  base::Unretained(&callbacks)));
    227   registry->GetPairing(
    228       pairing_2.client_id(),
    229       base::Bind(&MockPairingRegistryCallbacks::GetPairingCallback,
    230                  base::Unretained(&callbacks)));
    231   registry->DeletePairing(
    232       pairing_2.client_id(),
    233       base::Bind(&MockPairingRegistryCallbacks::DoneCallback,
    234                  base::Unretained(&callbacks)));
    235   registry->GetPairing(
    236       pairing_1.client_id(),
    237       base::Bind(&MockPairingRegistryCallbacks::GetPairingCallback,
    238                  base::Unretained(&callbacks)));
    239   registry->GetPairing(
    240       pairing_2.client_id(),
    241       base::Bind(&MockPairingRegistryCallbacks::GetPairingCallback,
    242                  base::Unretained(&callbacks)));
    243   registry->ClearAllPairings(
    244       base::Bind(&MockPairingRegistryCallbacks::DoneCallback,
    245                  base::Unretained(&callbacks)));
    246   registry->GetAllPairings(
    247       base::Bind(&MockPairingRegistryCallbacks::GetAllPairingsCallback,
    248                  base::Unretained(&callbacks)));
    249   PairingRegistry::Pairing pairing_3 = registry->CreatePairing("client3");
    250   registry->GetPairing(
    251       pairing_3.client_id(),
    252       base::Bind(&MockPairingRegistryCallbacks::GetPairingCallback,
    253                  base::Unretained(&callbacks)));
    254 
    255   run_loop_.Run();
    256 }
    257 
    258 }  // namespace protocol
    259 }  // namespace remoting
    260