Home | History | Annotate | Download | only in net
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "chrome/browser/net/transport_security_persister.h"
      6 
      7 #include <map>
      8 #include <string>
      9 #include <vector>
     10 
     11 #include "base/file_util.h"
     12 #include "base/files/file_path.h"
     13 #include "base/files/scoped_temp_dir.h"
     14 #include "base/message_loop/message_loop.h"
     15 #include "content/public/test/test_browser_thread.h"
     16 #include "net/http/transport_security_state.h"
     17 #include "testing/gtest/include/gtest/gtest.h"
     18 
     19 using net::TransportSecurityState;
     20 
     21 class TransportSecurityPersisterTest : public testing::Test {
     22  public:
     23   TransportSecurityPersisterTest()
     24       : message_loop_(base::MessageLoop::TYPE_IO),
     25         test_file_thread_(content::BrowserThread::FILE, &message_loop_),
     26         test_io_thread_(content::BrowserThread::IO, &message_loop_) {
     27   }
     28 
     29   virtual ~TransportSecurityPersisterTest() {
     30     message_loop_.RunUntilIdle();
     31   }
     32 
     33   virtual void SetUp() OVERRIDE {
     34     ASSERT_TRUE(temp_dir_.CreateUniqueTempDir());
     35     persister_.reset(
     36         new TransportSecurityPersister(&state_, temp_dir_.path(), false));
     37   }
     38 
     39  protected:
     40   // Ordering is important here. If member variables are not destroyed in the
     41   // right order, then DCHECKs will fail all over the place.
     42   base::MessageLoop message_loop_;
     43 
     44   // Needed for ImportantFileWriter, which TransportSecurityPersister uses.
     45   content::TestBrowserThread test_file_thread_;
     46 
     47   // TransportSecurityPersister runs on the IO thread.
     48   content::TestBrowserThread test_io_thread_;
     49 
     50   base::ScopedTempDir temp_dir_;
     51   TransportSecurityState state_;
     52   scoped_ptr<TransportSecurityPersister> persister_;
     53 };
     54 
     55 TEST_F(TransportSecurityPersisterTest, SerializeData1) {
     56   std::string output;
     57   bool dirty;
     58 
     59   EXPECT_TRUE(persister_->SerializeData(&output));
     60   EXPECT_TRUE(persister_->LoadEntries(output, &dirty));
     61   EXPECT_FALSE(dirty);
     62 }
     63 
     64 TEST_F(TransportSecurityPersisterTest, SerializeData2) {
     65   TransportSecurityState::DomainState domain_state;
     66   const base::Time current_time(base::Time::Now());
     67   const base::Time expiry = current_time + base::TimeDelta::FromSeconds(1000);
     68   static const char kYahooDomain[] = "yahoo.com";
     69 
     70   EXPECT_FALSE(state_.GetDomainState(kYahooDomain, true, &domain_state));
     71 
     72   bool include_subdomains = true;
     73   state_.AddHSTS(kYahooDomain, expiry, include_subdomains);
     74 
     75   std::string output;
     76   bool dirty;
     77   EXPECT_TRUE(persister_->SerializeData(&output));
     78   EXPECT_TRUE(persister_->LoadEntries(output, &dirty));
     79 
     80   EXPECT_TRUE(state_.GetDomainState(kYahooDomain, true, &domain_state));
     81   EXPECT_EQ(domain_state.upgrade_mode,
     82             TransportSecurityState::DomainState::MODE_FORCE_HTTPS);
     83   EXPECT_TRUE(state_.GetDomainState("foo.yahoo.com", true, &domain_state));
     84   EXPECT_EQ(domain_state.upgrade_mode,
     85             TransportSecurityState::DomainState::MODE_FORCE_HTTPS);
     86   EXPECT_TRUE(state_.GetDomainState("foo.bar.yahoo.com", true, &domain_state));
     87   EXPECT_EQ(domain_state.upgrade_mode,
     88             TransportSecurityState::DomainState::MODE_FORCE_HTTPS);
     89   EXPECT_TRUE(state_.GetDomainState("foo.bar.baz.yahoo.com", true,
     90                                    &domain_state));
     91   EXPECT_EQ(domain_state.upgrade_mode,
     92             TransportSecurityState::DomainState::MODE_FORCE_HTTPS);
     93   EXPECT_FALSE(state_.GetDomainState("com", true, &domain_state));
     94 }
     95 
     96 TEST_F(TransportSecurityPersisterTest, SerializeData3) {
     97   // Add an entry.
     98   net::HashValue fp1(net::HASH_VALUE_SHA1);
     99   memset(fp1.data(), 0, fp1.size());
    100   net::HashValue fp2(net::HASH_VALUE_SHA1);
    101   memset(fp2.data(), 1, fp2.size());
    102   base::Time expiry =
    103       base::Time::Now() + base::TimeDelta::FromSeconds(1000);
    104   net::HashValueVector dynamic_spki_hashes;
    105   dynamic_spki_hashes.push_back(fp1);
    106   dynamic_spki_hashes.push_back(fp2);
    107   bool include_subdomains = false;
    108   state_.AddHSTS("www.example.com", expiry, include_subdomains);
    109   state_.AddHPKP("www.example.com", expiry, include_subdomains,
    110                  dynamic_spki_hashes);
    111 
    112   // Add another entry.
    113   memset(fp1.data(), 2, fp1.size());
    114   memset(fp2.data(), 3, fp2.size());
    115   expiry =
    116       base::Time::Now() + base::TimeDelta::FromSeconds(3000);
    117   dynamic_spki_hashes.push_back(fp1);
    118   dynamic_spki_hashes.push_back(fp2);
    119   state_.AddHSTS("www.example.net", expiry, include_subdomains);
    120   state_.AddHPKP("www.example.net", expiry, include_subdomains,
    121                  dynamic_spki_hashes);
    122 
    123   // Save a copy of everything.
    124   std::map<std::string, TransportSecurityState::DomainState> saved;
    125   TransportSecurityState::Iterator i(state_);
    126   while (i.HasNext()) {
    127     saved[i.hostname()] = i.domain_state();
    128     i.Advance();
    129   }
    130 
    131   std::string serialized;
    132   EXPECT_TRUE(persister_->SerializeData(&serialized));
    133 
    134   // Persist the data to the file. For the test to be fast and not flaky, we
    135   // just do it directly rather than call persister_->StateIsDirty. (That uses
    136   // ImportantFileWriter, which has an asynchronous commit interval rather
    137   // than block.) Use a different basename just for cleanliness.
    138   base::FilePath path =
    139       temp_dir_.path().AppendASCII("TransportSecurityPersisterTest");
    140   EXPECT_TRUE(file_util::WriteFile(path, serialized.c_str(),
    141                                    serialized.size()));
    142 
    143   // Read the data back.
    144   std::string persisted;
    145   EXPECT_TRUE(file_util::ReadFileToString(path, &persisted));
    146   EXPECT_EQ(persisted, serialized);
    147   bool dirty;
    148   EXPECT_TRUE(persister_->LoadEntries(persisted, &dirty));
    149   EXPECT_FALSE(dirty);
    150 
    151   // Check that states are the same as saved.
    152   size_t count = 0;
    153   TransportSecurityState::Iterator j(state_);
    154   while (j.HasNext()) {
    155     count++;
    156     j.Advance();
    157   }
    158   EXPECT_EQ(count, saved.size());
    159 }
    160 
    161 TEST_F(TransportSecurityPersisterTest, SerializeDataOld) {
    162   // This is an old-style piece of transport state JSON, which has no creation
    163   // date.
    164   std::string output =
    165       "{ "
    166       "\"NiyD+3J1r6z1wjl2n1ALBu94Zj9OsEAMo0kCN8js0Uk=\": {"
    167       "\"expiry\": 1266815027.983453, "
    168       "\"include_subdomains\": false, "
    169       "\"mode\": \"strict\" "
    170       "}"
    171       "}";
    172   bool dirty;
    173   EXPECT_TRUE(persister_->LoadEntries(output, &dirty));
    174   EXPECT_TRUE(dirty);
    175 }
    176 
    177 TEST_F(TransportSecurityPersisterTest, PublicKeyHashes) {
    178   TransportSecurityState::DomainState domain_state;
    179   static const char kTestDomain[] = "example.com";
    180   EXPECT_FALSE(state_.GetDomainState(kTestDomain, false, &domain_state));
    181   net::HashValueVector hashes;
    182   EXPECT_FALSE(domain_state.CheckPublicKeyPins(hashes));
    183 
    184   net::HashValue sha1(net::HASH_VALUE_SHA1);
    185   memset(sha1.data(), '1', sha1.size());
    186   domain_state.dynamic_spki_hashes.push_back(sha1);
    187 
    188   EXPECT_FALSE(domain_state.CheckPublicKeyPins(hashes));
    189 
    190   hashes.push_back(sha1);
    191   EXPECT_TRUE(domain_state.CheckPublicKeyPins(hashes));
    192 
    193   hashes[0].data()[0] = '2';
    194   EXPECT_FALSE(domain_state.CheckPublicKeyPins(hashes));
    195 
    196   const base::Time current_time(base::Time::Now());
    197   const base::Time expiry = current_time + base::TimeDelta::FromSeconds(1000);
    198   bool include_subdomains = false;
    199   state_.AddHSTS(kTestDomain, expiry, include_subdomains);
    200   state_.AddHPKP(kTestDomain, expiry, include_subdomains,
    201                  domain_state.dynamic_spki_hashes);
    202   std::string ser;
    203   EXPECT_TRUE(persister_->SerializeData(&ser));
    204   bool dirty;
    205   EXPECT_TRUE(persister_->LoadEntries(ser, &dirty));
    206   EXPECT_TRUE(state_.GetDomainState(kTestDomain, false, &domain_state));
    207   EXPECT_EQ(1u, domain_state.dynamic_spki_hashes.size());
    208   EXPECT_EQ(sha1.tag, domain_state.dynamic_spki_hashes[0].tag);
    209   EXPECT_EQ(0, memcmp(domain_state.dynamic_spki_hashes[0].data(), sha1.data(),
    210                       sha1.size()));
    211 }
    212