Home | History | Annotate | Download | only in ssl
      1 // Copyright 2014 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/ssl/default_channel_id_store.h"
      6 
      7 #include <map>
      8 #include <string>
      9 #include <vector>
     10 
     11 #include "base/bind.h"
     12 #include "base/compiler_specific.h"
     13 #include "base/logging.h"
     14 #include "base/memory/scoped_ptr.h"
     15 #include "base/message_loop/message_loop.h"
     16 #include "net/base/net_errors.h"
     17 #include "testing/gtest/include/gtest/gtest.h"
     18 
     19 namespace net {
     20 
     21 namespace {
     22 
     23 void CallCounter(int* counter) {
     24   (*counter)++;
     25 }
     26 
     27 void GetChannelIDCallbackNotCalled(int err,
     28                                    const std::string& server_identifier,
     29                                    base::Time expiration_time,
     30                                    const std::string& private_key_result,
     31                                    const std::string& cert_result) {
     32   ADD_FAILURE() << "Unexpected callback execution.";
     33 }
     34 
     35 class AsyncGetChannelIDHelper {
     36  public:
     37   AsyncGetChannelIDHelper() : called_(false) {}
     38 
     39   void Callback(int err,
     40                 const std::string& server_identifier,
     41                 base::Time expiration_time,
     42                 const std::string& private_key_result,
     43                 const std::string& cert_result) {
     44     err_ = err;
     45     server_identifier_ = server_identifier;
     46     expiration_time_ = expiration_time;
     47     private_key_ = private_key_result;
     48     cert_ = cert_result;
     49     called_ = true;
     50   }
     51 
     52   int err_;
     53   std::string server_identifier_;
     54   base::Time expiration_time_;
     55   std::string private_key_;
     56   std::string cert_;
     57   bool called_;
     58 };
     59 
     60 void GetAllCallback(
     61     ChannelIDStore::ChannelIDList* dest,
     62     const ChannelIDStore::ChannelIDList& result) {
     63   *dest = result;
     64 }
     65 
     66 class MockPersistentStore
     67     : public DefaultChannelIDStore::PersistentStore {
     68  public:
     69   MockPersistentStore();
     70 
     71   // DefaultChannelIDStore::PersistentStore implementation.
     72   virtual void Load(const LoadedCallback& loaded_callback) OVERRIDE;
     73   virtual void AddChannelID(
     74       const DefaultChannelIDStore::ChannelID& channel_id) OVERRIDE;
     75   virtual void DeleteChannelID(
     76       const DefaultChannelIDStore::ChannelID& channel_id) OVERRIDE;
     77   virtual void SetForceKeepSessionState() OVERRIDE;
     78 
     79  protected:
     80   virtual ~MockPersistentStore();
     81 
     82  private:
     83   typedef std::map<std::string, DefaultChannelIDStore::ChannelID>
     84       ChannelIDMap;
     85 
     86   ChannelIDMap channel_ids_;
     87 };
     88 
     89 MockPersistentStore::MockPersistentStore() {}
     90 
     91 void MockPersistentStore::Load(const LoadedCallback& loaded_callback) {
     92   scoped_ptr<ScopedVector<DefaultChannelIDStore::ChannelID> >
     93       channel_ids(new ScopedVector<DefaultChannelIDStore::ChannelID>());
     94   ChannelIDMap::iterator it;
     95 
     96   for (it = channel_ids_.begin(); it != channel_ids_.end(); ++it) {
     97     channel_ids->push_back(
     98         new DefaultChannelIDStore::ChannelID(it->second));
     99   }
    100 
    101   base::MessageLoop::current()->PostTask(
    102       FROM_HERE, base::Bind(loaded_callback, base::Passed(&channel_ids)));
    103 }
    104 
    105 void MockPersistentStore::AddChannelID(
    106     const DefaultChannelIDStore::ChannelID& channel_id) {
    107   channel_ids_[channel_id.server_identifier()] = channel_id;
    108 }
    109 
    110 void MockPersistentStore::DeleteChannelID(
    111     const DefaultChannelIDStore::ChannelID& channel_id) {
    112   channel_ids_.erase(channel_id.server_identifier());
    113 }
    114 
    115 void MockPersistentStore::SetForceKeepSessionState() {}
    116 
    117 MockPersistentStore::~MockPersistentStore() {}
    118 
    119 }  // namespace
    120 
    121 TEST(DefaultChannelIDStoreTest, TestLoading) {
    122   scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
    123 
    124   persistent_store->AddChannelID(
    125       DefaultChannelIDStore::ChannelID(
    126           "google.com",
    127           base::Time(),
    128           base::Time(),
    129           "a", "b"));
    130   persistent_store->AddChannelID(
    131       DefaultChannelIDStore::ChannelID(
    132           "verisign.com",
    133           base::Time(),
    134           base::Time(),
    135           "c", "d"));
    136 
    137   // Make sure channel_ids load properly.
    138   DefaultChannelIDStore store(persistent_store.get());
    139   // Load has not occurred yet.
    140   EXPECT_EQ(0, store.GetChannelIDCount());
    141   store.SetChannelID(
    142       "verisign.com",
    143       base::Time(),
    144       base::Time(),
    145       "e", "f");
    146   // Wait for load & queued set task.
    147   base::MessageLoop::current()->RunUntilIdle();
    148   EXPECT_EQ(2, store.GetChannelIDCount());
    149   store.SetChannelID(
    150       "twitter.com",
    151       base::Time(),
    152       base::Time(),
    153       "g", "h");
    154   // Set should be synchronous now that load is done.
    155   EXPECT_EQ(3, store.GetChannelIDCount());
    156 }
    157 
    158 //TODO(mattm): add more tests of without a persistent store?
    159 TEST(DefaultChannelIDStoreTest, TestSettingAndGetting) {
    160   // No persistent store, all calls will be synchronous.
    161   DefaultChannelIDStore store(NULL);
    162   base::Time expiration_time;
    163   std::string private_key, cert;
    164   EXPECT_EQ(0, store.GetChannelIDCount());
    165   EXPECT_EQ(ERR_FILE_NOT_FOUND,
    166             store.GetChannelID("verisign.com",
    167                                &expiration_time,
    168                                &private_key,
    169                                &cert,
    170                                base::Bind(&GetChannelIDCallbackNotCalled)));
    171   EXPECT_TRUE(private_key.empty());
    172   EXPECT_TRUE(cert.empty());
    173   store.SetChannelID(
    174       "verisign.com",
    175       base::Time::FromInternalValue(123),
    176       base::Time::FromInternalValue(456),
    177       "i", "j");
    178   EXPECT_EQ(OK,
    179             store.GetChannelID("verisign.com",
    180                                &expiration_time,
    181                                &private_key,
    182                                &cert,
    183                                base::Bind(&GetChannelIDCallbackNotCalled)));
    184   EXPECT_EQ(456, expiration_time.ToInternalValue());
    185   EXPECT_EQ("i", private_key);
    186   EXPECT_EQ("j", cert);
    187 }
    188 
    189 TEST(DefaultChannelIDStoreTest, TestDuplicateChannelIds) {
    190   scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
    191   DefaultChannelIDStore store(persistent_store.get());
    192 
    193   base::Time expiration_time;
    194   std::string private_key, cert;
    195   EXPECT_EQ(0, store.GetChannelIDCount());
    196   store.SetChannelID(
    197       "verisign.com",
    198       base::Time::FromInternalValue(123),
    199       base::Time::FromInternalValue(1234),
    200       "a", "b");
    201   store.SetChannelID(
    202       "verisign.com",
    203       base::Time::FromInternalValue(456),
    204       base::Time::FromInternalValue(4567),
    205       "c", "d");
    206 
    207   // Wait for load & queued set tasks.
    208   base::MessageLoop::current()->RunUntilIdle();
    209   EXPECT_EQ(1, store.GetChannelIDCount());
    210   EXPECT_EQ(OK,
    211             store.GetChannelID("verisign.com",
    212                                &expiration_time,
    213                                &private_key,
    214                                &cert,
    215                                base::Bind(&GetChannelIDCallbackNotCalled)));
    216   EXPECT_EQ(4567, expiration_time.ToInternalValue());
    217   EXPECT_EQ("c", private_key);
    218   EXPECT_EQ("d", cert);
    219 }
    220 
    221 TEST(DefaultChannelIDStoreTest, TestAsyncGet) {
    222   scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
    223   persistent_store->AddChannelID(ChannelIDStore::ChannelID(
    224       "verisign.com",
    225       base::Time::FromInternalValue(123),
    226       base::Time::FromInternalValue(1234),
    227       "a", "b"));
    228 
    229   DefaultChannelIDStore store(persistent_store.get());
    230   AsyncGetChannelIDHelper helper;
    231   base::Time expiration_time;
    232   std::string private_key;
    233   std::string cert = "not set";
    234   EXPECT_EQ(0, store.GetChannelIDCount());
    235   EXPECT_EQ(ERR_IO_PENDING,
    236             store.GetChannelID("verisign.com",
    237                                &expiration_time,
    238                                &private_key,
    239                                &cert,
    240                                base::Bind(&AsyncGetChannelIDHelper::Callback,
    241                                           base::Unretained(&helper))));
    242 
    243   // Wait for load & queued get tasks.
    244   base::MessageLoop::current()->RunUntilIdle();
    245   EXPECT_EQ(1, store.GetChannelIDCount());
    246   EXPECT_EQ("not set", cert);
    247   EXPECT_TRUE(helper.called_);
    248   EXPECT_EQ(OK, helper.err_);
    249   EXPECT_EQ("verisign.com", helper.server_identifier_);
    250   EXPECT_EQ(1234, helper.expiration_time_.ToInternalValue());
    251   EXPECT_EQ("a", helper.private_key_);
    252   EXPECT_EQ("b", helper.cert_);
    253 }
    254 
    255 TEST(DefaultChannelIDStoreTest, TestDeleteAll) {
    256   scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
    257   DefaultChannelIDStore store(persistent_store.get());
    258 
    259   store.SetChannelID(
    260       "verisign.com",
    261       base::Time(),
    262       base::Time(),
    263       "a", "b");
    264   store.SetChannelID(
    265       "google.com",
    266       base::Time(),
    267       base::Time(),
    268       "c", "d");
    269   store.SetChannelID(
    270       "harvard.com",
    271       base::Time(),
    272       base::Time(),
    273       "e", "f");
    274   // Wait for load & queued set tasks.
    275   base::MessageLoop::current()->RunUntilIdle();
    276 
    277   EXPECT_EQ(3, store.GetChannelIDCount());
    278   int delete_finished = 0;
    279   store.DeleteAll(base::Bind(&CallCounter, &delete_finished));
    280   ASSERT_EQ(1, delete_finished);
    281   EXPECT_EQ(0, store.GetChannelIDCount());
    282 }
    283 
    284 TEST(DefaultChannelIDStoreTest, TestAsyncGetAndDeleteAll) {
    285   scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
    286   persistent_store->AddChannelID(ChannelIDStore::ChannelID(
    287       "verisign.com",
    288       base::Time(),
    289       base::Time(),
    290       "a", "b"));
    291   persistent_store->AddChannelID(ChannelIDStore::ChannelID(
    292       "google.com",
    293       base::Time(),
    294       base::Time(),
    295       "c", "d"));
    296 
    297   ChannelIDStore::ChannelIDList pre_channel_ids;
    298   ChannelIDStore::ChannelIDList post_channel_ids;
    299   int delete_finished = 0;
    300   DefaultChannelIDStore store(persistent_store.get());
    301 
    302   store.GetAllChannelIDs(base::Bind(GetAllCallback, &pre_channel_ids));
    303   store.DeleteAll(base::Bind(&CallCounter, &delete_finished));
    304   store.GetAllChannelIDs(base::Bind(GetAllCallback, &post_channel_ids));
    305   // Tasks have not run yet.
    306   EXPECT_EQ(0u, pre_channel_ids.size());
    307   // Wait for load & queued tasks.
    308   base::MessageLoop::current()->RunUntilIdle();
    309   EXPECT_EQ(0, store.GetChannelIDCount());
    310   EXPECT_EQ(2u, pre_channel_ids.size());
    311   EXPECT_EQ(0u, post_channel_ids.size());
    312 }
    313 
    314 TEST(DefaultChannelIDStoreTest, TestDelete) {
    315   scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
    316   DefaultChannelIDStore store(persistent_store.get());
    317 
    318   base::Time expiration_time;
    319   std::string private_key, cert;
    320   EXPECT_EQ(0, store.GetChannelIDCount());
    321   store.SetChannelID(
    322       "verisign.com",
    323       base::Time(),
    324       base::Time(),
    325       "a", "b");
    326   // Wait for load & queued set task.
    327   base::MessageLoop::current()->RunUntilIdle();
    328 
    329   store.SetChannelID(
    330       "google.com",
    331       base::Time(),
    332       base::Time(),
    333       "c", "d");
    334 
    335   EXPECT_EQ(2, store.GetChannelIDCount());
    336   int delete_finished = 0;
    337   store.DeleteChannelID("verisign.com",
    338                               base::Bind(&CallCounter, &delete_finished));
    339   ASSERT_EQ(1, delete_finished);
    340   EXPECT_EQ(1, store.GetChannelIDCount());
    341   EXPECT_EQ(ERR_FILE_NOT_FOUND,
    342             store.GetChannelID("verisign.com",
    343                                &expiration_time,
    344                                &private_key,
    345                                &cert,
    346                                base::Bind(&GetChannelIDCallbackNotCalled)));
    347   EXPECT_EQ(OK,
    348             store.GetChannelID("google.com",
    349                                &expiration_time,
    350                                &private_key,
    351                                &cert,
    352                                base::Bind(&GetChannelIDCallbackNotCalled)));
    353   int delete2_finished = 0;
    354   store.DeleteChannelID("google.com",
    355                         base::Bind(&CallCounter, &delete2_finished));
    356   ASSERT_EQ(1, delete2_finished);
    357   EXPECT_EQ(0, store.GetChannelIDCount());
    358   EXPECT_EQ(ERR_FILE_NOT_FOUND,
    359             store.GetChannelID("google.com",
    360                                &expiration_time,
    361                                &private_key,
    362                                &cert,
    363                                base::Bind(&GetChannelIDCallbackNotCalled)));
    364 }
    365 
    366 TEST(DefaultChannelIDStoreTest, TestAsyncDelete) {
    367   scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
    368   persistent_store->AddChannelID(ChannelIDStore::ChannelID(
    369       "a.com",
    370       base::Time::FromInternalValue(1),
    371       base::Time::FromInternalValue(2),
    372       "a", "b"));
    373   persistent_store->AddChannelID(ChannelIDStore::ChannelID(
    374       "b.com",
    375       base::Time::FromInternalValue(3),
    376       base::Time::FromInternalValue(4),
    377       "c", "d"));
    378   DefaultChannelIDStore store(persistent_store.get());
    379   int delete_finished = 0;
    380   store.DeleteChannelID("a.com",
    381                         base::Bind(&CallCounter, &delete_finished));
    382 
    383   AsyncGetChannelIDHelper a_helper;
    384   AsyncGetChannelIDHelper b_helper;
    385   base::Time expiration_time;
    386   std::string private_key;
    387   std::string cert = "not set";
    388   EXPECT_EQ(0, store.GetChannelIDCount());
    389   EXPECT_EQ(ERR_IO_PENDING,
    390       store.GetChannelID(
    391           "a.com", &expiration_time, &private_key, &cert,
    392           base::Bind(&AsyncGetChannelIDHelper::Callback,
    393                      base::Unretained(&a_helper))));
    394   EXPECT_EQ(ERR_IO_PENDING,
    395       store.GetChannelID(
    396           "b.com", &expiration_time, &private_key, &cert,
    397           base::Bind(&AsyncGetChannelIDHelper::Callback,
    398                      base::Unretained(&b_helper))));
    399 
    400   EXPECT_EQ(0, delete_finished);
    401   EXPECT_FALSE(a_helper.called_);
    402   EXPECT_FALSE(b_helper.called_);
    403   // Wait for load & queued tasks.
    404   base::MessageLoop::current()->RunUntilIdle();
    405   EXPECT_EQ(1, delete_finished);
    406   EXPECT_EQ(1, store.GetChannelIDCount());
    407   EXPECT_EQ("not set", cert);
    408   EXPECT_TRUE(a_helper.called_);
    409   EXPECT_EQ(ERR_FILE_NOT_FOUND, a_helper.err_);
    410   EXPECT_EQ("a.com", a_helper.server_identifier_);
    411   EXPECT_EQ(0, a_helper.expiration_time_.ToInternalValue());
    412   EXPECT_EQ("", a_helper.private_key_);
    413   EXPECT_EQ("", a_helper.cert_);
    414   EXPECT_TRUE(b_helper.called_);
    415   EXPECT_EQ(OK, b_helper.err_);
    416   EXPECT_EQ("b.com", b_helper.server_identifier_);
    417   EXPECT_EQ(4, b_helper.expiration_time_.ToInternalValue());
    418   EXPECT_EQ("c", b_helper.private_key_);
    419   EXPECT_EQ("d", b_helper.cert_);
    420 }
    421 
    422 TEST(DefaultChannelIDStoreTest, TestGetAll) {
    423   scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
    424   DefaultChannelIDStore store(persistent_store.get());
    425 
    426   EXPECT_EQ(0, store.GetChannelIDCount());
    427   store.SetChannelID(
    428       "verisign.com",
    429       base::Time(),
    430       base::Time(),
    431       "a", "b");
    432   store.SetChannelID(
    433       "google.com",
    434       base::Time(),
    435       base::Time(),
    436       "c", "d");
    437   store.SetChannelID(
    438       "harvard.com",
    439       base::Time(),
    440       base::Time(),
    441       "e", "f");
    442   store.SetChannelID(
    443       "mit.com",
    444       base::Time(),
    445       base::Time(),
    446       "g", "h");
    447   // Wait for load & queued set tasks.
    448   base::MessageLoop::current()->RunUntilIdle();
    449 
    450   EXPECT_EQ(4, store.GetChannelIDCount());
    451   ChannelIDStore::ChannelIDList channel_ids;
    452   store.GetAllChannelIDs(base::Bind(GetAllCallback, &channel_ids));
    453   EXPECT_EQ(4u, channel_ids.size());
    454 }
    455 
    456 TEST(DefaultChannelIDStoreTest, TestInitializeFrom) {
    457   scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
    458   DefaultChannelIDStore store(persistent_store.get());
    459 
    460   store.SetChannelID(
    461       "preexisting.com",
    462       base::Time(),
    463       base::Time(),
    464       "a", "b");
    465   store.SetChannelID(
    466       "both.com",
    467       base::Time(),
    468       base::Time(),
    469       "c", "d");
    470   // Wait for load & queued set tasks.
    471   base::MessageLoop::current()->RunUntilIdle();
    472   EXPECT_EQ(2, store.GetChannelIDCount());
    473 
    474   ChannelIDStore::ChannelIDList source_channel_ids;
    475   source_channel_ids.push_back(ChannelIDStore::ChannelID(
    476       "both.com",
    477       base::Time(),
    478       base::Time(),
    479       // Key differs from above to test that existing entries are overwritten.
    480       "e", "f"));
    481   source_channel_ids.push_back(ChannelIDStore::ChannelID(
    482       "copied.com",
    483       base::Time(),
    484       base::Time(),
    485       "g", "h"));
    486   store.InitializeFrom(source_channel_ids);
    487   EXPECT_EQ(3, store.GetChannelIDCount());
    488 
    489   ChannelIDStore::ChannelIDList channel_ids;
    490   store.GetAllChannelIDs(base::Bind(GetAllCallback, &channel_ids));
    491   ASSERT_EQ(3u, channel_ids.size());
    492 
    493   ChannelIDStore::ChannelIDList::iterator channel_id = channel_ids.begin();
    494   EXPECT_EQ("both.com", channel_id->server_identifier());
    495   EXPECT_EQ("e", channel_id->private_key());
    496 
    497   ++channel_id;
    498   EXPECT_EQ("copied.com", channel_id->server_identifier());
    499   EXPECT_EQ("g", channel_id->private_key());
    500 
    501   ++channel_id;
    502   EXPECT_EQ("preexisting.com", channel_id->server_identifier());
    503   EXPECT_EQ("a", channel_id->private_key());
    504 }
    505 
    506 TEST(DefaultChannelIDStoreTest, TestAsyncInitializeFrom) {
    507   scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
    508   persistent_store->AddChannelID(ChannelIDStore::ChannelID(
    509       "preexisting.com",
    510       base::Time(),
    511       base::Time(),
    512       "a", "b"));
    513   persistent_store->AddChannelID(ChannelIDStore::ChannelID(
    514       "both.com",
    515       base::Time(),
    516       base::Time(),
    517       "c", "d"));
    518 
    519   DefaultChannelIDStore store(persistent_store.get());
    520   ChannelIDStore::ChannelIDList source_channel_ids;
    521   source_channel_ids.push_back(ChannelIDStore::ChannelID(
    522       "both.com",
    523       base::Time(),
    524       base::Time(),
    525       // Key differs from above to test that existing entries are overwritten.
    526       "e", "f"));
    527   source_channel_ids.push_back(ChannelIDStore::ChannelID(
    528       "copied.com",
    529       base::Time(),
    530       base::Time(),
    531       "g", "h"));
    532   store.InitializeFrom(source_channel_ids);
    533   EXPECT_EQ(0, store.GetChannelIDCount());
    534   // Wait for load & queued tasks.
    535   base::MessageLoop::current()->RunUntilIdle();
    536   EXPECT_EQ(3, store.GetChannelIDCount());
    537 
    538   ChannelIDStore::ChannelIDList channel_ids;
    539   store.GetAllChannelIDs(base::Bind(GetAllCallback, &channel_ids));
    540   ASSERT_EQ(3u, channel_ids.size());
    541 
    542   ChannelIDStore::ChannelIDList::iterator channel_id = channel_ids.begin();
    543   EXPECT_EQ("both.com", channel_id->server_identifier());
    544   EXPECT_EQ("e", channel_id->private_key());
    545 
    546   ++channel_id;
    547   EXPECT_EQ("copied.com", channel_id->server_identifier());
    548   EXPECT_EQ("g", channel_id->private_key());
    549 
    550   ++channel_id;
    551   EXPECT_EQ("preexisting.com", channel_id->server_identifier());
    552   EXPECT_EQ("a", channel_id->private_key());
    553 }
    554 
    555 }  // namespace net
    556