1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/ssl/default_server_bound_cert_store.h" 6 7 #include <map> 8 #include <string> 9 #include <vector> 10 11 #include "base/bind.h" 12 #include "base/compiler_specific.h" 13 #include "base/logging.h" 14 #include "base/memory/scoped_ptr.h" 15 #include "base/message_loop/message_loop.h" 16 #include "net/base/net_errors.h" 17 #include "testing/gtest/include/gtest/gtest.h" 18 19 namespace net { 20 21 namespace { 22 23 void CallCounter(int* counter) { 24 (*counter)++; 25 } 26 27 void NotCalled() { 28 ADD_FAILURE() << "Unexpected callback execution."; 29 } 30 31 void GetCertCallbackNotCalled(int err, 32 const std::string& server_identifier, 33 base::Time expiration_time, 34 const std::string& private_key_result, 35 const std::string& cert_result) { 36 ADD_FAILURE() << "Unexpected callback execution."; 37 } 38 39 class AsyncGetCertHelper { 40 public: 41 AsyncGetCertHelper() : called_(false) {} 42 43 void Callback(int err, 44 const std::string& server_identifier, 45 base::Time expiration_time, 46 const std::string& private_key_result, 47 const std::string& cert_result) { 48 err_ = err; 49 server_identifier_ = server_identifier; 50 expiration_time_ = expiration_time; 51 private_key_ = private_key_result; 52 cert_ = cert_result; 53 called_ = true; 54 } 55 56 int err_; 57 std::string server_identifier_; 58 base::Time expiration_time_; 59 std::string private_key_; 60 std::string cert_; 61 bool called_; 62 }; 63 64 void GetAllCallback( 65 ServerBoundCertStore::ServerBoundCertList* dest, 66 const ServerBoundCertStore::ServerBoundCertList& result) { 67 *dest = result; 68 } 69 70 class MockPersistentStore 71 : public DefaultServerBoundCertStore::PersistentStore { 72 public: 73 MockPersistentStore(); 74 75 // DefaultServerBoundCertStore::PersistentStore implementation. 76 virtual void Load(const LoadedCallback& loaded_callback) OVERRIDE; 77 virtual void AddServerBoundCert( 78 const DefaultServerBoundCertStore::ServerBoundCert& cert) OVERRIDE; 79 virtual void DeleteServerBoundCert( 80 const DefaultServerBoundCertStore::ServerBoundCert& cert) OVERRIDE; 81 virtual void SetForceKeepSessionState() OVERRIDE; 82 83 protected: 84 virtual ~MockPersistentStore(); 85 86 private: 87 typedef std::map<std::string, DefaultServerBoundCertStore::ServerBoundCert> 88 ServerBoundCertMap; 89 90 ServerBoundCertMap origin_certs_; 91 }; 92 93 MockPersistentStore::MockPersistentStore() {} 94 95 void MockPersistentStore::Load(const LoadedCallback& loaded_callback) { 96 scoped_ptr<ScopedVector<DefaultServerBoundCertStore::ServerBoundCert> > 97 certs(new ScopedVector<DefaultServerBoundCertStore::ServerBoundCert>()); 98 ServerBoundCertMap::iterator it; 99 100 for (it = origin_certs_.begin(); it != origin_certs_.end(); ++it) { 101 certs->push_back( 102 new DefaultServerBoundCertStore::ServerBoundCert(it->second)); 103 } 104 105 base::MessageLoop::current()->PostTask( 106 FROM_HERE, base::Bind(loaded_callback, base::Passed(&certs))); 107 } 108 109 void MockPersistentStore::AddServerBoundCert( 110 const DefaultServerBoundCertStore::ServerBoundCert& cert) { 111 origin_certs_[cert.server_identifier()] = cert; 112 } 113 114 void MockPersistentStore::DeleteServerBoundCert( 115 const DefaultServerBoundCertStore::ServerBoundCert& cert) { 116 origin_certs_.erase(cert.server_identifier()); 117 } 118 119 void MockPersistentStore::SetForceKeepSessionState() {} 120 121 MockPersistentStore::~MockPersistentStore() {} 122 123 } // namespace 124 125 TEST(DefaultServerBoundCertStoreTest, TestLoading) { 126 scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore); 127 128 persistent_store->AddServerBoundCert( 129 DefaultServerBoundCertStore::ServerBoundCert( 130 "google.com", 131 base::Time(), 132 base::Time(), 133 "a", "b")); 134 persistent_store->AddServerBoundCert( 135 DefaultServerBoundCertStore::ServerBoundCert( 136 "verisign.com", 137 base::Time(), 138 base::Time(), 139 "c", "d")); 140 141 // Make sure certs load properly. 142 DefaultServerBoundCertStore store(persistent_store.get()); 143 // Load has not occurred yet. 144 EXPECT_EQ(0, store.GetCertCount()); 145 store.SetServerBoundCert( 146 "verisign.com", 147 base::Time(), 148 base::Time(), 149 "e", "f"); 150 // Wait for load & queued set task. 151 base::MessageLoop::current()->RunUntilIdle(); 152 EXPECT_EQ(2, store.GetCertCount()); 153 store.SetServerBoundCert( 154 "twitter.com", 155 base::Time(), 156 base::Time(), 157 "g", "h"); 158 // Set should be synchronous now that load is done. 159 EXPECT_EQ(3, store.GetCertCount()); 160 } 161 162 //TODO(mattm): add more tests of without a persistent store? 163 TEST(DefaultServerBoundCertStoreTest, TestSettingAndGetting) { 164 // No persistent store, all calls will be synchronous. 165 DefaultServerBoundCertStore store(NULL); 166 base::Time expiration_time; 167 std::string private_key, cert; 168 EXPECT_EQ(0, store.GetCertCount()); 169 EXPECT_EQ(ERR_FILE_NOT_FOUND, 170 store.GetServerBoundCert("verisign.com", 171 &expiration_time, 172 &private_key, 173 &cert, 174 base::Bind(&GetCertCallbackNotCalled))); 175 EXPECT_TRUE(private_key.empty()); 176 EXPECT_TRUE(cert.empty()); 177 store.SetServerBoundCert( 178 "verisign.com", 179 base::Time::FromInternalValue(123), 180 base::Time::FromInternalValue(456), 181 "i", "j"); 182 EXPECT_EQ(OK, 183 store.GetServerBoundCert("verisign.com", 184 &expiration_time, 185 &private_key, 186 &cert, 187 base::Bind(&GetCertCallbackNotCalled))); 188 EXPECT_EQ(456, expiration_time.ToInternalValue()); 189 EXPECT_EQ("i", private_key); 190 EXPECT_EQ("j", cert); 191 } 192 193 TEST(DefaultServerBoundCertStoreTest, TestDuplicateCerts) { 194 scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore); 195 DefaultServerBoundCertStore store(persistent_store.get()); 196 197 base::Time expiration_time; 198 std::string private_key, cert; 199 EXPECT_EQ(0, store.GetCertCount()); 200 store.SetServerBoundCert( 201 "verisign.com", 202 base::Time::FromInternalValue(123), 203 base::Time::FromInternalValue(1234), 204 "a", "b"); 205 store.SetServerBoundCert( 206 "verisign.com", 207 base::Time::FromInternalValue(456), 208 base::Time::FromInternalValue(4567), 209 "c", "d"); 210 211 // Wait for load & queued set tasks. 212 base::MessageLoop::current()->RunUntilIdle(); 213 EXPECT_EQ(1, store.GetCertCount()); 214 EXPECT_EQ(OK, 215 store.GetServerBoundCert("verisign.com", 216 &expiration_time, 217 &private_key, 218 &cert, 219 base::Bind(&GetCertCallbackNotCalled))); 220 EXPECT_EQ(4567, expiration_time.ToInternalValue()); 221 EXPECT_EQ("c", private_key); 222 EXPECT_EQ("d", cert); 223 } 224 225 TEST(DefaultServerBoundCertStoreTest, TestAsyncGet) { 226 scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore); 227 persistent_store->AddServerBoundCert(ServerBoundCertStore::ServerBoundCert( 228 "verisign.com", 229 base::Time::FromInternalValue(123), 230 base::Time::FromInternalValue(1234), 231 "a", "b")); 232 233 DefaultServerBoundCertStore store(persistent_store.get()); 234 AsyncGetCertHelper helper; 235 base::Time expiration_time; 236 std::string private_key; 237 std::string cert = "not set"; 238 EXPECT_EQ(0, store.GetCertCount()); 239 EXPECT_EQ(ERR_IO_PENDING, 240 store.GetServerBoundCert("verisign.com", 241 &expiration_time, 242 &private_key, 243 &cert, 244 base::Bind(&AsyncGetCertHelper::Callback, 245 base::Unretained(&helper)))); 246 247 // Wait for load & queued get tasks. 248 base::MessageLoop::current()->RunUntilIdle(); 249 EXPECT_EQ(1, store.GetCertCount()); 250 EXPECT_EQ("not set", cert); 251 EXPECT_TRUE(helper.called_); 252 EXPECT_EQ(OK, helper.err_); 253 EXPECT_EQ("verisign.com", helper.server_identifier_); 254 EXPECT_EQ(1234, helper.expiration_time_.ToInternalValue()); 255 EXPECT_EQ("a", helper.private_key_); 256 EXPECT_EQ("b", helper.cert_); 257 } 258 259 TEST(DefaultServerBoundCertStoreTest, TestDeleteAll) { 260 scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore); 261 DefaultServerBoundCertStore store(persistent_store.get()); 262 263 store.SetServerBoundCert( 264 "verisign.com", 265 base::Time(), 266 base::Time(), 267 "a", "b"); 268 store.SetServerBoundCert( 269 "google.com", 270 base::Time(), 271 base::Time(), 272 "c", "d"); 273 store.SetServerBoundCert( 274 "harvard.com", 275 base::Time(), 276 base::Time(), 277 "e", "f"); 278 // Wait for load & queued set tasks. 279 base::MessageLoop::current()->RunUntilIdle(); 280 281 EXPECT_EQ(3, store.GetCertCount()); 282 int delete_finished = 0; 283 store.DeleteAll(base::Bind(&CallCounter, &delete_finished)); 284 ASSERT_EQ(1, delete_finished); 285 EXPECT_EQ(0, store.GetCertCount()); 286 } 287 288 TEST(DefaultServerBoundCertStoreTest, TestAsyncGetAndDeleteAll) { 289 scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore); 290 persistent_store->AddServerBoundCert(ServerBoundCertStore::ServerBoundCert( 291 "verisign.com", 292 base::Time(), 293 base::Time(), 294 "a", "b")); 295 persistent_store->AddServerBoundCert(ServerBoundCertStore::ServerBoundCert( 296 "google.com", 297 base::Time(), 298 base::Time(), 299 "c", "d")); 300 301 ServerBoundCertStore::ServerBoundCertList pre_certs; 302 ServerBoundCertStore::ServerBoundCertList post_certs; 303 int delete_finished = 0; 304 DefaultServerBoundCertStore store(persistent_store.get()); 305 306 store.GetAllServerBoundCerts(base::Bind(GetAllCallback, &pre_certs)); 307 store.DeleteAll(base::Bind(&CallCounter, &delete_finished)); 308 store.GetAllServerBoundCerts(base::Bind(GetAllCallback, &post_certs)); 309 // Tasks have not run yet. 310 EXPECT_EQ(0u, pre_certs.size()); 311 // Wait for load & queued tasks. 312 base::MessageLoop::current()->RunUntilIdle(); 313 EXPECT_EQ(0, store.GetCertCount()); 314 EXPECT_EQ(2u, pre_certs.size()); 315 EXPECT_EQ(0u, post_certs.size()); 316 } 317 318 TEST(DefaultServerBoundCertStoreTest, TestDelete) { 319 scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore); 320 DefaultServerBoundCertStore store(persistent_store.get()); 321 322 base::Time expiration_time; 323 std::string private_key, cert; 324 EXPECT_EQ(0, store.GetCertCount()); 325 store.SetServerBoundCert( 326 "verisign.com", 327 base::Time(), 328 base::Time(), 329 "a", "b"); 330 // Wait for load & queued set task. 331 base::MessageLoop::current()->RunUntilIdle(); 332 333 store.SetServerBoundCert( 334 "google.com", 335 base::Time(), 336 base::Time(), 337 "c", "d"); 338 339 EXPECT_EQ(2, store.GetCertCount()); 340 int delete_finished = 0; 341 store.DeleteServerBoundCert("verisign.com", 342 base::Bind(&CallCounter, &delete_finished)); 343 ASSERT_EQ(1, delete_finished); 344 EXPECT_EQ(1, store.GetCertCount()); 345 EXPECT_EQ(ERR_FILE_NOT_FOUND, 346 store.GetServerBoundCert("verisign.com", 347 &expiration_time, 348 &private_key, 349 &cert, 350 base::Bind(&GetCertCallbackNotCalled))); 351 EXPECT_EQ(OK, 352 store.GetServerBoundCert("google.com", 353 &expiration_time, 354 &private_key, 355 &cert, 356 base::Bind(&GetCertCallbackNotCalled))); 357 int delete2_finished = 0; 358 store.DeleteServerBoundCert("google.com", 359 base::Bind(&CallCounter, &delete2_finished)); 360 ASSERT_EQ(1, delete2_finished); 361 EXPECT_EQ(0, store.GetCertCount()); 362 EXPECT_EQ(ERR_FILE_NOT_FOUND, 363 store.GetServerBoundCert("google.com", 364 &expiration_time, 365 &private_key, 366 &cert, 367 base::Bind(&GetCertCallbackNotCalled))); 368 } 369 370 TEST(DefaultServerBoundCertStoreTest, TestAsyncDelete) { 371 scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore); 372 persistent_store->AddServerBoundCert(ServerBoundCertStore::ServerBoundCert( 373 "a.com", 374 base::Time::FromInternalValue(1), 375 base::Time::FromInternalValue(2), 376 "a", "b")); 377 persistent_store->AddServerBoundCert(ServerBoundCertStore::ServerBoundCert( 378 "b.com", 379 base::Time::FromInternalValue(3), 380 base::Time::FromInternalValue(4), 381 "c", "d")); 382 DefaultServerBoundCertStore store(persistent_store.get()); 383 int delete_finished = 0; 384 store.DeleteServerBoundCert("a.com", 385 base::Bind(&CallCounter, &delete_finished)); 386 387 AsyncGetCertHelper a_helper; 388 AsyncGetCertHelper b_helper; 389 base::Time expiration_time; 390 std::string private_key; 391 std::string cert = "not set"; 392 EXPECT_EQ(0, store.GetCertCount()); 393 EXPECT_EQ(ERR_IO_PENDING, 394 store.GetServerBoundCert( 395 "a.com", &expiration_time, &private_key, &cert, 396 base::Bind(&AsyncGetCertHelper::Callback, 397 base::Unretained(&a_helper)))); 398 EXPECT_EQ(ERR_IO_PENDING, 399 store.GetServerBoundCert( 400 "b.com", &expiration_time, &private_key, &cert, 401 base::Bind(&AsyncGetCertHelper::Callback, 402 base::Unretained(&b_helper)))); 403 404 EXPECT_EQ(0, delete_finished); 405 EXPECT_FALSE(a_helper.called_); 406 EXPECT_FALSE(b_helper.called_); 407 // Wait for load & queued tasks. 408 base::MessageLoop::current()->RunUntilIdle(); 409 EXPECT_EQ(1, delete_finished); 410 EXPECT_EQ(1, store.GetCertCount()); 411 EXPECT_EQ("not set", cert); 412 EXPECT_TRUE(a_helper.called_); 413 EXPECT_EQ(ERR_FILE_NOT_FOUND, a_helper.err_); 414 EXPECT_EQ("a.com", a_helper.server_identifier_); 415 EXPECT_EQ(0, a_helper.expiration_time_.ToInternalValue()); 416 EXPECT_EQ("", a_helper.private_key_); 417 EXPECT_EQ("", a_helper.cert_); 418 EXPECT_TRUE(b_helper.called_); 419 EXPECT_EQ(OK, b_helper.err_); 420 EXPECT_EQ("b.com", b_helper.server_identifier_); 421 EXPECT_EQ(4, b_helper.expiration_time_.ToInternalValue()); 422 EXPECT_EQ("c", b_helper.private_key_); 423 EXPECT_EQ("d", b_helper.cert_); 424 } 425 426 TEST(DefaultServerBoundCertStoreTest, TestGetAll) { 427 scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore); 428 DefaultServerBoundCertStore store(persistent_store.get()); 429 430 EXPECT_EQ(0, store.GetCertCount()); 431 store.SetServerBoundCert( 432 "verisign.com", 433 base::Time(), 434 base::Time(), 435 "a", "b"); 436 store.SetServerBoundCert( 437 "google.com", 438 base::Time(), 439 base::Time(), 440 "c", "d"); 441 store.SetServerBoundCert( 442 "harvard.com", 443 base::Time(), 444 base::Time(), 445 "e", "f"); 446 store.SetServerBoundCert( 447 "mit.com", 448 base::Time(), 449 base::Time(), 450 "g", "h"); 451 // Wait for load & queued set tasks. 452 base::MessageLoop::current()->RunUntilIdle(); 453 454 EXPECT_EQ(4, store.GetCertCount()); 455 ServerBoundCertStore::ServerBoundCertList certs; 456 store.GetAllServerBoundCerts(base::Bind(GetAllCallback, &certs)); 457 EXPECT_EQ(4u, certs.size()); 458 } 459 460 TEST(DefaultServerBoundCertStoreTest, TestInitializeFrom) { 461 scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore); 462 DefaultServerBoundCertStore store(persistent_store.get()); 463 464 store.SetServerBoundCert( 465 "preexisting.com", 466 base::Time(), 467 base::Time(), 468 "a", "b"); 469 store.SetServerBoundCert( 470 "both.com", 471 base::Time(), 472 base::Time(), 473 "c", "d"); 474 // Wait for load & queued set tasks. 475 base::MessageLoop::current()->RunUntilIdle(); 476 EXPECT_EQ(2, store.GetCertCount()); 477 478 ServerBoundCertStore::ServerBoundCertList source_certs; 479 source_certs.push_back(ServerBoundCertStore::ServerBoundCert( 480 "both.com", 481 base::Time(), 482 base::Time(), 483 // Key differs from above to test that existing entries are overwritten. 484 "e", "f")); 485 source_certs.push_back(ServerBoundCertStore::ServerBoundCert( 486 "copied.com", 487 base::Time(), 488 base::Time(), 489 "g", "h")); 490 store.InitializeFrom(source_certs); 491 EXPECT_EQ(3, store.GetCertCount()); 492 493 ServerBoundCertStore::ServerBoundCertList certs; 494 store.GetAllServerBoundCerts(base::Bind(GetAllCallback, &certs)); 495 ASSERT_EQ(3u, certs.size()); 496 497 ServerBoundCertStore::ServerBoundCertList::iterator cert = certs.begin(); 498 EXPECT_EQ("both.com", cert->server_identifier()); 499 EXPECT_EQ("e", cert->private_key()); 500 501 ++cert; 502 EXPECT_EQ("copied.com", cert->server_identifier()); 503 EXPECT_EQ("g", cert->private_key()); 504 505 ++cert; 506 EXPECT_EQ("preexisting.com", cert->server_identifier()); 507 EXPECT_EQ("a", cert->private_key()); 508 } 509 510 TEST(DefaultServerBoundCertStoreTest, TestAsyncInitializeFrom) { 511 scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore); 512 persistent_store->AddServerBoundCert(ServerBoundCertStore::ServerBoundCert( 513 "preexisting.com", 514 base::Time(), 515 base::Time(), 516 "a", "b")); 517 persistent_store->AddServerBoundCert(ServerBoundCertStore::ServerBoundCert( 518 "both.com", 519 base::Time(), 520 base::Time(), 521 "c", "d")); 522 523 DefaultServerBoundCertStore store(persistent_store.get()); 524 ServerBoundCertStore::ServerBoundCertList source_certs; 525 source_certs.push_back(ServerBoundCertStore::ServerBoundCert( 526 "both.com", 527 base::Time(), 528 base::Time(), 529 // Key differs from above to test that existing entries are overwritten. 530 "e", "f")); 531 source_certs.push_back(ServerBoundCertStore::ServerBoundCert( 532 "copied.com", 533 base::Time(), 534 base::Time(), 535 "g", "h")); 536 store.InitializeFrom(source_certs); 537 EXPECT_EQ(0, store.GetCertCount()); 538 // Wait for load & queued tasks. 539 base::MessageLoop::current()->RunUntilIdle(); 540 EXPECT_EQ(3, store.GetCertCount()); 541 542 ServerBoundCertStore::ServerBoundCertList certs; 543 store.GetAllServerBoundCerts(base::Bind(GetAllCallback, &certs)); 544 ASSERT_EQ(3u, certs.size()); 545 546 ServerBoundCertStore::ServerBoundCertList::iterator cert = certs.begin(); 547 EXPECT_EQ("both.com", cert->server_identifier()); 548 EXPECT_EQ("e", cert->private_key()); 549 550 ++cert; 551 EXPECT_EQ("copied.com", cert->server_identifier()); 552 EXPECT_EQ("g", cert->private_key()); 553 554 ++cert; 555 EXPECT_EQ("preexisting.com", cert->server_identifier()); 556 EXPECT_EQ("a", cert->private_key()); 557 } 558 559 } // namespace net 560