Home | History | Annotate | Download | only in ssl
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/ssl/default_server_bound_cert_store.h"
      6 
      7 #include <map>
      8 #include <string>
      9 #include <vector>
     10 
     11 #include "base/bind.h"
     12 #include "base/compiler_specific.h"
     13 #include "base/logging.h"
     14 #include "base/memory/scoped_ptr.h"
     15 #include "base/message_loop/message_loop.h"
     16 #include "net/base/net_errors.h"
     17 #include "testing/gtest/include/gtest/gtest.h"
     18 
     19 namespace net {
     20 
     21 namespace {
     22 
     23 void CallCounter(int* counter) {
     24   (*counter)++;
     25 }
     26 
     27 void NotCalled() {
     28   ADD_FAILURE() << "Unexpected callback execution.";
     29 }
     30 
     31 void GetCertCallbackNotCalled(int err,
     32                               const std::string& server_identifier,
     33                               base::Time expiration_time,
     34                               const std::string& private_key_result,
     35                               const std::string& cert_result) {
     36   ADD_FAILURE() << "Unexpected callback execution.";
     37 }
     38 
     39 class AsyncGetCertHelper {
     40  public:
     41   AsyncGetCertHelper() : called_(false) {}
     42 
     43   void Callback(int err,
     44                 const std::string& server_identifier,
     45                 base::Time expiration_time,
     46                 const std::string& private_key_result,
     47                 const std::string& cert_result) {
     48     err_ = err;
     49     server_identifier_ = server_identifier;
     50     expiration_time_ = expiration_time;
     51     private_key_ = private_key_result;
     52     cert_ = cert_result;
     53     called_ = true;
     54   }
     55 
     56   int err_;
     57   std::string server_identifier_;
     58   base::Time expiration_time_;
     59   std::string private_key_;
     60   std::string cert_;
     61   bool called_;
     62 };
     63 
     64 void GetAllCallback(
     65     ServerBoundCertStore::ServerBoundCertList* dest,
     66     const ServerBoundCertStore::ServerBoundCertList& result) {
     67   *dest = result;
     68 }
     69 
     70 class MockPersistentStore
     71     : public DefaultServerBoundCertStore::PersistentStore {
     72  public:
     73   MockPersistentStore();
     74 
     75   // DefaultServerBoundCertStore::PersistentStore implementation.
     76   virtual void Load(const LoadedCallback& loaded_callback) OVERRIDE;
     77   virtual void AddServerBoundCert(
     78       const DefaultServerBoundCertStore::ServerBoundCert& cert) OVERRIDE;
     79   virtual void DeleteServerBoundCert(
     80       const DefaultServerBoundCertStore::ServerBoundCert& cert) OVERRIDE;
     81   virtual void SetForceKeepSessionState() OVERRIDE;
     82 
     83  protected:
     84   virtual ~MockPersistentStore();
     85 
     86  private:
     87   typedef std::map<std::string, DefaultServerBoundCertStore::ServerBoundCert>
     88       ServerBoundCertMap;
     89 
     90   ServerBoundCertMap origin_certs_;
     91 };
     92 
     93 MockPersistentStore::MockPersistentStore() {}
     94 
     95 void MockPersistentStore::Load(const LoadedCallback& loaded_callback) {
     96   scoped_ptr<ScopedVector<DefaultServerBoundCertStore::ServerBoundCert> >
     97       certs(new ScopedVector<DefaultServerBoundCertStore::ServerBoundCert>());
     98   ServerBoundCertMap::iterator it;
     99 
    100   for (it = origin_certs_.begin(); it != origin_certs_.end(); ++it) {
    101     certs->push_back(
    102         new DefaultServerBoundCertStore::ServerBoundCert(it->second));
    103   }
    104 
    105   base::MessageLoop::current()->PostTask(
    106       FROM_HERE, base::Bind(loaded_callback, base::Passed(&certs)));
    107 }
    108 
    109 void MockPersistentStore::AddServerBoundCert(
    110     const DefaultServerBoundCertStore::ServerBoundCert& cert) {
    111   origin_certs_[cert.server_identifier()] = cert;
    112 }
    113 
    114 void MockPersistentStore::DeleteServerBoundCert(
    115     const DefaultServerBoundCertStore::ServerBoundCert& cert) {
    116   origin_certs_.erase(cert.server_identifier());
    117 }
    118 
    119 void MockPersistentStore::SetForceKeepSessionState() {}
    120 
    121 MockPersistentStore::~MockPersistentStore() {}
    122 
    123 }  // namespace
    124 
    125 TEST(DefaultServerBoundCertStoreTest, TestLoading) {
    126   scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
    127 
    128   persistent_store->AddServerBoundCert(
    129       DefaultServerBoundCertStore::ServerBoundCert(
    130           "google.com",
    131           base::Time(),
    132           base::Time(),
    133           "a", "b"));
    134   persistent_store->AddServerBoundCert(
    135       DefaultServerBoundCertStore::ServerBoundCert(
    136           "verisign.com",
    137           base::Time(),
    138           base::Time(),
    139           "c", "d"));
    140 
    141   // Make sure certs load properly.
    142   DefaultServerBoundCertStore store(persistent_store.get());
    143   // Load has not occurred yet.
    144   EXPECT_EQ(0, store.GetCertCount());
    145   store.SetServerBoundCert(
    146       "verisign.com",
    147       base::Time(),
    148       base::Time(),
    149       "e", "f");
    150   // Wait for load & queued set task.
    151   base::MessageLoop::current()->RunUntilIdle();
    152   EXPECT_EQ(2, store.GetCertCount());
    153   store.SetServerBoundCert(
    154       "twitter.com",
    155       base::Time(),
    156       base::Time(),
    157       "g", "h");
    158   // Set should be synchronous now that load is done.
    159   EXPECT_EQ(3, store.GetCertCount());
    160 }
    161 
// TODO(mattm): Add more tests that run without a persistent store?
    163 TEST(DefaultServerBoundCertStoreTest, TestSettingAndGetting) {
    164   // No persistent store, all calls will be synchronous.
    165   DefaultServerBoundCertStore store(NULL);
    166   base::Time expiration_time;
    167   std::string private_key, cert;
    168   EXPECT_EQ(0, store.GetCertCount());
    169   EXPECT_EQ(ERR_FILE_NOT_FOUND,
    170             store.GetServerBoundCert("verisign.com",
    171                                      &expiration_time,
    172                                      &private_key,
    173                                      &cert,
    174                                      base::Bind(&GetCertCallbackNotCalled)));
    175   EXPECT_TRUE(private_key.empty());
    176   EXPECT_TRUE(cert.empty());
    177   store.SetServerBoundCert(
    178       "verisign.com",
    179       base::Time::FromInternalValue(123),
    180       base::Time::FromInternalValue(456),
    181       "i", "j");
    182   EXPECT_EQ(OK,
    183             store.GetServerBoundCert("verisign.com",
    184                                      &expiration_time,
    185                                      &private_key,
    186                                      &cert,
    187                                      base::Bind(&GetCertCallbackNotCalled)));
    188   EXPECT_EQ(456, expiration_time.ToInternalValue());
    189   EXPECT_EQ("i", private_key);
    190   EXPECT_EQ("j", cert);
    191 }
    192 
    193 TEST(DefaultServerBoundCertStoreTest, TestDuplicateCerts) {
    194   scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
    195   DefaultServerBoundCertStore store(persistent_store.get());
    196 
    197   base::Time expiration_time;
    198   std::string private_key, cert;
    199   EXPECT_EQ(0, store.GetCertCount());
    200   store.SetServerBoundCert(
    201       "verisign.com",
    202       base::Time::FromInternalValue(123),
    203       base::Time::FromInternalValue(1234),
    204       "a", "b");
    205   store.SetServerBoundCert(
    206       "verisign.com",
    207       base::Time::FromInternalValue(456),
    208       base::Time::FromInternalValue(4567),
    209       "c", "d");
    210 
    211   // Wait for load & queued set tasks.
    212   base::MessageLoop::current()->RunUntilIdle();
    213   EXPECT_EQ(1, store.GetCertCount());
    214   EXPECT_EQ(OK,
    215             store.GetServerBoundCert("verisign.com",
    216                                      &expiration_time,
    217                                      &private_key,
    218                                      &cert,
    219                                      base::Bind(&GetCertCallbackNotCalled)));
    220   EXPECT_EQ(4567, expiration_time.ToInternalValue());
    221   EXPECT_EQ("c", private_key);
    222   EXPECT_EQ("d", cert);
    223 }
    224 
    225 TEST(DefaultServerBoundCertStoreTest, TestAsyncGet) {
    226   scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
    227   persistent_store->AddServerBoundCert(ServerBoundCertStore::ServerBoundCert(
    228       "verisign.com",
    229       base::Time::FromInternalValue(123),
    230       base::Time::FromInternalValue(1234),
    231       "a", "b"));
    232 
    233   DefaultServerBoundCertStore store(persistent_store.get());
    234   AsyncGetCertHelper helper;
    235   base::Time expiration_time;
    236   std::string private_key;
    237   std::string cert = "not set";
    238   EXPECT_EQ(0, store.GetCertCount());
    239   EXPECT_EQ(ERR_IO_PENDING,
    240             store.GetServerBoundCert("verisign.com",
    241                                      &expiration_time,
    242                                      &private_key,
    243                                      &cert,
    244                                      base::Bind(&AsyncGetCertHelper::Callback,
    245                                                 base::Unretained(&helper))));
    246 
    247   // Wait for load & queued get tasks.
    248   base::MessageLoop::current()->RunUntilIdle();
    249   EXPECT_EQ(1, store.GetCertCount());
    250   EXPECT_EQ("not set", cert);
    251   EXPECT_TRUE(helper.called_);
    252   EXPECT_EQ(OK, helper.err_);
    253   EXPECT_EQ("verisign.com", helper.server_identifier_);
    254   EXPECT_EQ(1234, helper.expiration_time_.ToInternalValue());
    255   EXPECT_EQ("a", helper.private_key_);
    256   EXPECT_EQ("b", helper.cert_);
    257 }
    258 
    259 TEST(DefaultServerBoundCertStoreTest, TestDeleteAll) {
    260   scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
    261   DefaultServerBoundCertStore store(persistent_store.get());
    262 
    263   store.SetServerBoundCert(
    264       "verisign.com",
    265       base::Time(),
    266       base::Time(),
    267       "a", "b");
    268   store.SetServerBoundCert(
    269       "google.com",
    270       base::Time(),
    271       base::Time(),
    272       "c", "d");
    273   store.SetServerBoundCert(
    274       "harvard.com",
    275       base::Time(),
    276       base::Time(),
    277       "e", "f");
    278   // Wait for load & queued set tasks.
    279   base::MessageLoop::current()->RunUntilIdle();
    280 
    281   EXPECT_EQ(3, store.GetCertCount());
    282   int delete_finished = 0;
    283   store.DeleteAll(base::Bind(&CallCounter, &delete_finished));
    284   ASSERT_EQ(1, delete_finished);
    285   EXPECT_EQ(0, store.GetCertCount());
    286 }
    287 
    288 TEST(DefaultServerBoundCertStoreTest, TestAsyncGetAndDeleteAll) {
    289   scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
    290   persistent_store->AddServerBoundCert(ServerBoundCertStore::ServerBoundCert(
    291       "verisign.com",
    292       base::Time(),
    293       base::Time(),
    294       "a", "b"));
    295   persistent_store->AddServerBoundCert(ServerBoundCertStore::ServerBoundCert(
    296       "google.com",
    297       base::Time(),
    298       base::Time(),
    299       "c", "d"));
    300 
    301   ServerBoundCertStore::ServerBoundCertList pre_certs;
    302   ServerBoundCertStore::ServerBoundCertList post_certs;
    303   int delete_finished = 0;
    304   DefaultServerBoundCertStore store(persistent_store.get());
    305 
    306   store.GetAllServerBoundCerts(base::Bind(GetAllCallback, &pre_certs));
    307   store.DeleteAll(base::Bind(&CallCounter, &delete_finished));
    308   store.GetAllServerBoundCerts(base::Bind(GetAllCallback, &post_certs));
    309   // Tasks have not run yet.
    310   EXPECT_EQ(0u, pre_certs.size());
    311   // Wait for load & queued tasks.
    312   base::MessageLoop::current()->RunUntilIdle();
    313   EXPECT_EQ(0, store.GetCertCount());
    314   EXPECT_EQ(2u, pre_certs.size());
    315   EXPECT_EQ(0u, post_certs.size());
    316 }
    317 
    318 TEST(DefaultServerBoundCertStoreTest, TestDelete) {
    319   scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
    320   DefaultServerBoundCertStore store(persistent_store.get());
    321 
    322   base::Time expiration_time;
    323   std::string private_key, cert;
    324   EXPECT_EQ(0, store.GetCertCount());
    325   store.SetServerBoundCert(
    326       "verisign.com",
    327       base::Time(),
    328       base::Time(),
    329       "a", "b");
    330   // Wait for load & queued set task.
    331   base::MessageLoop::current()->RunUntilIdle();
    332 
    333   store.SetServerBoundCert(
    334       "google.com",
    335       base::Time(),
    336       base::Time(),
    337       "c", "d");
    338 
    339   EXPECT_EQ(2, store.GetCertCount());
    340   int delete_finished = 0;
    341   store.DeleteServerBoundCert("verisign.com",
    342                               base::Bind(&CallCounter, &delete_finished));
    343   ASSERT_EQ(1, delete_finished);
    344   EXPECT_EQ(1, store.GetCertCount());
    345   EXPECT_EQ(ERR_FILE_NOT_FOUND,
    346             store.GetServerBoundCert("verisign.com",
    347                                      &expiration_time,
    348                                      &private_key,
    349                                      &cert,
    350                                      base::Bind(&GetCertCallbackNotCalled)));
    351   EXPECT_EQ(OK,
    352             store.GetServerBoundCert("google.com",
    353                                      &expiration_time,
    354                                      &private_key,
    355                                      &cert,
    356                                      base::Bind(&GetCertCallbackNotCalled)));
    357   int delete2_finished = 0;
    358   store.DeleteServerBoundCert("google.com",
    359                               base::Bind(&CallCounter, &delete2_finished));
    360   ASSERT_EQ(1, delete2_finished);
    361   EXPECT_EQ(0, store.GetCertCount());
    362   EXPECT_EQ(ERR_FILE_NOT_FOUND,
    363             store.GetServerBoundCert("google.com",
    364                                      &expiration_time,
    365                                      &private_key,
    366                                      &cert,
    367                                      base::Bind(&GetCertCallbackNotCalled)));
    368 }
    369 
    370 TEST(DefaultServerBoundCertStoreTest, TestAsyncDelete) {
    371   scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
    372   persistent_store->AddServerBoundCert(ServerBoundCertStore::ServerBoundCert(
    373       "a.com",
    374       base::Time::FromInternalValue(1),
    375       base::Time::FromInternalValue(2),
    376       "a", "b"));
    377   persistent_store->AddServerBoundCert(ServerBoundCertStore::ServerBoundCert(
    378       "b.com",
    379       base::Time::FromInternalValue(3),
    380       base::Time::FromInternalValue(4),
    381       "c", "d"));
    382   DefaultServerBoundCertStore store(persistent_store.get());
    383   int delete_finished = 0;
    384   store.DeleteServerBoundCert("a.com",
    385                               base::Bind(&CallCounter, &delete_finished));
    386 
    387   AsyncGetCertHelper a_helper;
    388   AsyncGetCertHelper b_helper;
    389   base::Time expiration_time;
    390   std::string private_key;
    391   std::string cert = "not set";
    392   EXPECT_EQ(0, store.GetCertCount());
    393   EXPECT_EQ(ERR_IO_PENDING,
    394       store.GetServerBoundCert(
    395           "a.com", &expiration_time, &private_key, &cert,
    396           base::Bind(&AsyncGetCertHelper::Callback,
    397                      base::Unretained(&a_helper))));
    398   EXPECT_EQ(ERR_IO_PENDING,
    399       store.GetServerBoundCert(
    400           "b.com", &expiration_time, &private_key, &cert,
    401           base::Bind(&AsyncGetCertHelper::Callback,
    402                      base::Unretained(&b_helper))));
    403 
    404   EXPECT_EQ(0, delete_finished);
    405   EXPECT_FALSE(a_helper.called_);
    406   EXPECT_FALSE(b_helper.called_);
    407   // Wait for load & queued tasks.
    408   base::MessageLoop::current()->RunUntilIdle();
    409   EXPECT_EQ(1, delete_finished);
    410   EXPECT_EQ(1, store.GetCertCount());
    411   EXPECT_EQ("not set", cert);
    412   EXPECT_TRUE(a_helper.called_);
    413   EXPECT_EQ(ERR_FILE_NOT_FOUND, a_helper.err_);
    414   EXPECT_EQ("a.com", a_helper.server_identifier_);
    415   EXPECT_EQ(0, a_helper.expiration_time_.ToInternalValue());
    416   EXPECT_EQ("", a_helper.private_key_);
    417   EXPECT_EQ("", a_helper.cert_);
    418   EXPECT_TRUE(b_helper.called_);
    419   EXPECT_EQ(OK, b_helper.err_);
    420   EXPECT_EQ("b.com", b_helper.server_identifier_);
    421   EXPECT_EQ(4, b_helper.expiration_time_.ToInternalValue());
    422   EXPECT_EQ("c", b_helper.private_key_);
    423   EXPECT_EQ("d", b_helper.cert_);
    424 }
    425 
    426 TEST(DefaultServerBoundCertStoreTest, TestGetAll) {
    427   scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
    428   DefaultServerBoundCertStore store(persistent_store.get());
    429 
    430   EXPECT_EQ(0, store.GetCertCount());
    431   store.SetServerBoundCert(
    432       "verisign.com",
    433       base::Time(),
    434       base::Time(),
    435       "a", "b");
    436   store.SetServerBoundCert(
    437       "google.com",
    438       base::Time(),
    439       base::Time(),
    440       "c", "d");
    441   store.SetServerBoundCert(
    442       "harvard.com",
    443       base::Time(),
    444       base::Time(),
    445       "e", "f");
    446   store.SetServerBoundCert(
    447       "mit.com",
    448       base::Time(),
    449       base::Time(),
    450       "g", "h");
    451   // Wait for load & queued set tasks.
    452   base::MessageLoop::current()->RunUntilIdle();
    453 
    454   EXPECT_EQ(4, store.GetCertCount());
    455   ServerBoundCertStore::ServerBoundCertList certs;
    456   store.GetAllServerBoundCerts(base::Bind(GetAllCallback, &certs));
    457   EXPECT_EQ(4u, certs.size());
    458 }
    459 
    460 TEST(DefaultServerBoundCertStoreTest, TestInitializeFrom) {
    461   scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
    462   DefaultServerBoundCertStore store(persistent_store.get());
    463 
    464   store.SetServerBoundCert(
    465       "preexisting.com",
    466       base::Time(),
    467       base::Time(),
    468       "a", "b");
    469   store.SetServerBoundCert(
    470       "both.com",
    471       base::Time(),
    472       base::Time(),
    473       "c", "d");
    474   // Wait for load & queued set tasks.
    475   base::MessageLoop::current()->RunUntilIdle();
    476   EXPECT_EQ(2, store.GetCertCount());
    477 
    478   ServerBoundCertStore::ServerBoundCertList source_certs;
    479   source_certs.push_back(ServerBoundCertStore::ServerBoundCert(
    480       "both.com",
    481       base::Time(),
    482       base::Time(),
    483       // Key differs from above to test that existing entries are overwritten.
    484       "e", "f"));
    485   source_certs.push_back(ServerBoundCertStore::ServerBoundCert(
    486       "copied.com",
    487       base::Time(),
    488       base::Time(),
    489       "g", "h"));
    490   store.InitializeFrom(source_certs);
    491   EXPECT_EQ(3, store.GetCertCount());
    492 
    493   ServerBoundCertStore::ServerBoundCertList certs;
    494   store.GetAllServerBoundCerts(base::Bind(GetAllCallback, &certs));
    495   ASSERT_EQ(3u, certs.size());
    496 
    497   ServerBoundCertStore::ServerBoundCertList::iterator cert = certs.begin();
    498   EXPECT_EQ("both.com", cert->server_identifier());
    499   EXPECT_EQ("e", cert->private_key());
    500 
    501   ++cert;
    502   EXPECT_EQ("copied.com", cert->server_identifier());
    503   EXPECT_EQ("g", cert->private_key());
    504 
    505   ++cert;
    506   EXPECT_EQ("preexisting.com", cert->server_identifier());
    507   EXPECT_EQ("a", cert->private_key());
    508 }
    509 
    510 TEST(DefaultServerBoundCertStoreTest, TestAsyncInitializeFrom) {
    511   scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
    512   persistent_store->AddServerBoundCert(ServerBoundCertStore::ServerBoundCert(
    513       "preexisting.com",
    514       base::Time(),
    515       base::Time(),
    516       "a", "b"));
    517   persistent_store->AddServerBoundCert(ServerBoundCertStore::ServerBoundCert(
    518       "both.com",
    519       base::Time(),
    520       base::Time(),
    521       "c", "d"));
    522 
    523   DefaultServerBoundCertStore store(persistent_store.get());
    524   ServerBoundCertStore::ServerBoundCertList source_certs;
    525   source_certs.push_back(ServerBoundCertStore::ServerBoundCert(
    526       "both.com",
    527       base::Time(),
    528       base::Time(),
    529       // Key differs from above to test that existing entries are overwritten.
    530       "e", "f"));
    531   source_certs.push_back(ServerBoundCertStore::ServerBoundCert(
    532       "copied.com",
    533       base::Time(),
    534       base::Time(),
    535       "g", "h"));
    536   store.InitializeFrom(source_certs);
    537   EXPECT_EQ(0, store.GetCertCount());
    538   // Wait for load & queued tasks.
    539   base::MessageLoop::current()->RunUntilIdle();
    540   EXPECT_EQ(3, store.GetCertCount());
    541 
    542   ServerBoundCertStore::ServerBoundCertList certs;
    543   store.GetAllServerBoundCerts(base::Bind(GetAllCallback, &certs));
    544   ASSERT_EQ(3u, certs.size());
    545 
    546   ServerBoundCertStore::ServerBoundCertList::iterator cert = certs.begin();
    547   EXPECT_EQ("both.com", cert->server_identifier());
    548   EXPECT_EQ("e", cert->private_key());
    549 
    550   ++cert;
    551   EXPECT_EQ("copied.com", cert->server_identifier());
    552   EXPECT_EQ("g", cert->private_key());
    553 
    554   ++cert;
    555   EXPECT_EQ("preexisting.com", cert->server_identifier());
    556   EXPECT_EQ("a", cert->private_key());
    557 }
    558 
    559 }  // namespace net
    560