1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "remoting/protocol/pairing_registry.h" 6 7 #include <stdlib.h> 8 9 #include <algorithm> 10 11 #include "base/bind.h" 12 #include "base/compiler_specific.h" 13 #include "base/memory/scoped_ptr.h" 14 #include "base/message_loop/message_loop.h" 15 #include "base/run_loop.h" 16 #include "base/thread_task_runner_handle.h" 17 #include "base/values.h" 18 #include "remoting/protocol/protocol_mock_objects.h" 19 #include "testing/gmock/include/gmock/gmock.h" 20 #include "testing/gtest/include/gtest/gtest.h" 21 22 using testing::Sequence; 23 24 namespace { 25 26 using remoting::protocol::PairingRegistry; 27 28 class MockPairingRegistryCallbacks { 29 public: 30 MockPairingRegistryCallbacks() {} 31 virtual ~MockPairingRegistryCallbacks() {} 32 33 MOCK_METHOD1(DoneCallback, void(bool)); 34 MOCK_METHOD1(GetAllPairingsCallbackPtr, void(base::ListValue*)); 35 MOCK_METHOD1(GetPairingCallback, void(PairingRegistry::Pairing)); 36 37 void GetAllPairingsCallback(scoped_ptr<base::ListValue> pairings) { 38 GetAllPairingsCallbackPtr(pairings.get()); 39 } 40 41 private: 42 DISALLOW_COPY_AND_ASSIGN(MockPairingRegistryCallbacks); 43 }; 44 45 // Verify that a pairing Dictionary has correct entries, but doesn't include 46 // any shared secret. 47 void VerifyPairing(PairingRegistry::Pairing expected, 48 const base::DictionaryValue& actual) { 49 std::string value; 50 EXPECT_TRUE(actual.GetString(PairingRegistry::kClientNameKey, &value)); 51 EXPECT_EQ(expected.client_name(), value); 52 EXPECT_TRUE(actual.GetString(PairingRegistry::kClientIdKey, &value)); 53 EXPECT_EQ(expected.client_id(), value); 54 55 EXPECT_FALSE(actual.HasKey(PairingRegistry::kSharedSecretKey)); 56 } 57 58 } // namespace 59 60 namespace remoting { 61 namespace protocol { 62 63 class PairingRegistryTest : public testing::Test { 64 public: 65 virtual void SetUp() OVERRIDE { 66 callback_count_ = 0; 67 } 68 69 void set_pairings(scoped_ptr<base::ListValue> pairings) { 70 pairings_ = pairings.Pass(); 71 } 72 73 void ExpectSecret(const std::string& expected, 74 PairingRegistry::Pairing actual) { 75 EXPECT_EQ(expected, actual.shared_secret()); 76 ++callback_count_; 77 } 78 79 void ExpectSaveSuccess(bool success) { 80 EXPECT_TRUE(success); 81 ++callback_count_; 82 } 83 84 protected: 85 base::MessageLoop message_loop_; 86 base::RunLoop run_loop_; 87 88 int callback_count_; 89 scoped_ptr<base::ListValue> pairings_; 90 }; 91 92 TEST_F(PairingRegistryTest, CreateAndGetPairings) { 93 scoped_refptr<PairingRegistry> registry = new SynchronousPairingRegistry( 94 scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate())); 95 PairingRegistry::Pairing pairing_1 = registry->CreatePairing("my_client"); 96 PairingRegistry::Pairing pairing_2 = registry->CreatePairing("my_client"); 97 98 EXPECT_NE(pairing_1.shared_secret(), pairing_2.shared_secret()); 99 100 registry->GetPairing(pairing_1.client_id(), 101 base::Bind(&PairingRegistryTest::ExpectSecret, 102 base::Unretained(this), 103 pairing_1.shared_secret())); 104 EXPECT_EQ(1, callback_count_); 105 106 // Check that the second client is paired with a different shared secret. 107 registry->GetPairing(pairing_2.client_id(), 108 base::Bind(&PairingRegistryTest::ExpectSecret, 109 base::Unretained(this), 110 pairing_2.shared_secret())); 111 EXPECT_EQ(2, callback_count_); 112 } 113 114 TEST_F(PairingRegistryTest, GetAllPairings) { 115 scoped_refptr<PairingRegistry> registry = new SynchronousPairingRegistry( 116 scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate())); 117 PairingRegistry::Pairing pairing_1 = registry->CreatePairing("client1"); 118 PairingRegistry::Pairing pairing_2 = registry->CreatePairing("client2"); 119 120 registry->GetAllPairings( 121 base::Bind(&PairingRegistryTest::set_pairings, 122 base::Unretained(this))); 123 124 ASSERT_EQ(2u, pairings_->GetSize()); 125 const base::DictionaryValue* actual_pairing_1; 126 const base::DictionaryValue* actual_pairing_2; 127 ASSERT_TRUE(pairings_->GetDictionary(0, &actual_pairing_1)); 128 ASSERT_TRUE(pairings_->GetDictionary(1, &actual_pairing_2)); 129 130 // Ordering is not guaranteed, so swap if necessary. 131 std::string actual_client_id; 132 ASSERT_TRUE(actual_pairing_1->GetString(PairingRegistry::kClientIdKey, 133 &actual_client_id)); 134 if (actual_client_id != pairing_1.client_id()) { 135 std::swap(actual_pairing_1, actual_pairing_2); 136 } 137 138 VerifyPairing(pairing_1, *actual_pairing_1); 139 VerifyPairing(pairing_2, *actual_pairing_2); 140 } 141 142 TEST_F(PairingRegistryTest, DeletePairing) { 143 scoped_refptr<PairingRegistry> registry = new SynchronousPairingRegistry( 144 scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate())); 145 PairingRegistry::Pairing pairing_1 = registry->CreatePairing("client1"); 146 PairingRegistry::Pairing pairing_2 = registry->CreatePairing("client2"); 147 148 registry->DeletePairing( 149 pairing_1.client_id(), 150 base::Bind(&PairingRegistryTest::ExpectSaveSuccess, 151 base::Unretained(this))); 152 153 // Re-read the list, and verify it only has the pairing_2 client. 154 registry->GetAllPairings( 155 base::Bind(&PairingRegistryTest::set_pairings, 156 base::Unretained(this))); 157 158 ASSERT_EQ(1u, pairings_->GetSize()); 159 const base::DictionaryValue* actual_pairing_2; 160 ASSERT_TRUE(pairings_->GetDictionary(0, &actual_pairing_2)); 161 std::string actual_client_id; 162 ASSERT_TRUE(actual_pairing_2->GetString(PairingRegistry::kClientIdKey, 163 &actual_client_id)); 164 EXPECT_EQ(pairing_2.client_id(), actual_client_id); 165 } 166 167 TEST_F(PairingRegistryTest, ClearAllPairings) { 168 scoped_refptr<PairingRegistry> registry = new SynchronousPairingRegistry( 169 scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate())); 170 PairingRegistry::Pairing pairing_1 = registry->CreatePairing("client1"); 171 PairingRegistry::Pairing pairing_2 = registry->CreatePairing("client2"); 172 173 registry->ClearAllPairings( 174 base::Bind(&PairingRegistryTest::ExpectSaveSuccess, 175 base::Unretained(this))); 176 177 // Re-read the list, and verify it is empty. 178 registry->GetAllPairings( 179 base::Bind(&PairingRegistryTest::set_pairings, 180 base::Unretained(this))); 181 182 EXPECT_TRUE(pairings_->empty()); 183 } 184 185 ACTION_P(QuitMessageLoop, callback) { 186 callback.Run(); 187 } 188 189 MATCHER_P(EqualsClientName, client_name, "") { 190 return arg.client_name() == client_name; 191 } 192 193 MATCHER(NoPairings, "") { 194 return arg->empty(); 195 } 196 197 TEST_F(PairingRegistryTest, SerializedRequests) { 198 MockPairingRegistryCallbacks callbacks; 199 Sequence s; 200 EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName("client1"))) 201 .InSequence(s); 202 EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName("client2"))) 203 .InSequence(s); 204 EXPECT_CALL(callbacks, DoneCallback(true)) 205 .InSequence(s); 206 EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName("client1"))) 207 .InSequence(s); 208 EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName(""))) 209 .InSequence(s); 210 EXPECT_CALL(callbacks, DoneCallback(true)) 211 .InSequence(s); 212 EXPECT_CALL(callbacks, GetAllPairingsCallbackPtr(NoPairings())) 213 .InSequence(s); 214 EXPECT_CALL(callbacks, GetPairingCallback(EqualsClientName("client3"))) 215 .InSequence(s) 216 .WillOnce(QuitMessageLoop(run_loop_.QuitClosure())); 217 218 scoped_refptr<PairingRegistry> registry = new PairingRegistry( 219 base::ThreadTaskRunnerHandle::Get(), 220 scoped_ptr<PairingRegistry::Delegate>(new MockPairingRegistryDelegate())); 221 PairingRegistry::Pairing pairing_1 = registry->CreatePairing("client1"); 222 PairingRegistry::Pairing pairing_2 = registry->CreatePairing("client2"); 223 registry->GetPairing( 224 pairing_1.client_id(), 225 base::Bind(&MockPairingRegistryCallbacks::GetPairingCallback, 226 base::Unretained(&callbacks))); 227 registry->GetPairing( 228 pairing_2.client_id(), 229 base::Bind(&MockPairingRegistryCallbacks::GetPairingCallback, 230 base::Unretained(&callbacks))); 231 registry->DeletePairing( 232 pairing_2.client_id(), 233 base::Bind(&MockPairingRegistryCallbacks::DoneCallback, 234 base::Unretained(&callbacks))); 235 registry->GetPairing( 236 pairing_1.client_id(), 237 base::Bind(&MockPairingRegistryCallbacks::GetPairingCallback, 238 base::Unretained(&callbacks))); 239 registry->GetPairing( 240 pairing_2.client_id(), 241 base::Bind(&MockPairingRegistryCallbacks::GetPairingCallback, 242 base::Unretained(&callbacks))); 243 registry->ClearAllPairings( 244 base::Bind(&MockPairingRegistryCallbacks::DoneCallback, 245 base::Unretained(&callbacks))); 246 registry->GetAllPairings( 247 base::Bind(&MockPairingRegistryCallbacks::GetAllPairingsCallback, 248 base::Unretained(&callbacks))); 249 PairingRegistry::Pairing pairing_3 = registry->CreatePairing("client3"); 250 registry->GetPairing( 251 pairing_3.client_id(), 252 base::Bind(&MockPairingRegistryCallbacks::GetPairingCallback, 253 base::Unretained(&callbacks))); 254 255 run_loop_.Run(); 256 } 257 258 } // namespace protocol 259 } // namespace remoting 260