1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/http/transport_security_persister.h" 6 7 #include <map> 8 #include <string> 9 #include <vector> 10 11 #include "base/file_util.h" 12 #include "base/files/file_path.h" 13 #include "base/files/scoped_temp_dir.h" 14 #include "base/message_loop/message_loop.h" 15 #include "net/http/transport_security_state.h" 16 #include "testing/gtest/include/gtest/gtest.h" 17 18 using net::TransportSecurityPersister; 19 using net::TransportSecurityState; 20 21 class TransportSecurityPersisterTest : public testing::Test { 22 public: 23 TransportSecurityPersisterTest() { 24 } 25 26 virtual ~TransportSecurityPersisterTest() { 27 base::MessageLoopForIO::current()->RunUntilIdle(); 28 } 29 30 virtual void SetUp() OVERRIDE { 31 ASSERT_TRUE(temp_dir_.CreateUniqueTempDir()); 32 persister_.reset(new TransportSecurityPersister( 33 &state_, 34 temp_dir_.path(), 35 base::MessageLoopForIO::current()->message_loop_proxy(), 36 false)); 37 } 38 39 protected: 40 base::ScopedTempDir temp_dir_; 41 TransportSecurityState state_; 42 scoped_ptr<TransportSecurityPersister> persister_; 43 }; 44 45 TEST_F(TransportSecurityPersisterTest, SerializeData1) { 46 std::string output; 47 bool dirty; 48 49 EXPECT_TRUE(persister_->SerializeData(&output)); 50 EXPECT_TRUE(persister_->LoadEntries(output, &dirty)); 51 EXPECT_FALSE(dirty); 52 } 53 54 TEST_F(TransportSecurityPersisterTest, SerializeData2) { 55 TransportSecurityState::DomainState domain_state; 56 const base::Time current_time(base::Time::Now()); 57 const base::Time expiry = current_time + base::TimeDelta::FromSeconds(1000); 58 static const char kYahooDomain[] = "yahoo.com"; 59 60 EXPECT_FALSE(state_.GetStaticDomainState(kYahooDomain, true, &domain_state)); 61 EXPECT_FALSE(state_.GetDynamicDomainState(kYahooDomain, &domain_state)); 62 63 bool include_subdomains = true; 64 state_.AddHSTS(kYahooDomain, expiry, include_subdomains); 65 66 std::string output; 67 bool dirty; 68 EXPECT_TRUE(persister_->SerializeData(&output)); 69 EXPECT_TRUE(persister_->LoadEntries(output, &dirty)); 70 71 EXPECT_TRUE(state_.GetDynamicDomainState(kYahooDomain, &domain_state)); 72 EXPECT_EQ(domain_state.sts.upgrade_mode, 73 TransportSecurityState::DomainState::MODE_FORCE_HTTPS); 74 EXPECT_TRUE(state_.GetDynamicDomainState("foo.yahoo.com", &domain_state)); 75 EXPECT_EQ(domain_state.sts.upgrade_mode, 76 TransportSecurityState::DomainState::MODE_FORCE_HTTPS); 77 EXPECT_TRUE(state_.GetDynamicDomainState("foo.bar.yahoo.com", &domain_state)); 78 EXPECT_EQ(domain_state.sts.upgrade_mode, 79 TransportSecurityState::DomainState::MODE_FORCE_HTTPS); 80 EXPECT_TRUE( 81 state_.GetDynamicDomainState("foo.bar.baz.yahoo.com", &domain_state)); 82 EXPECT_EQ(domain_state.sts.upgrade_mode, 83 TransportSecurityState::DomainState::MODE_FORCE_HTTPS); 84 EXPECT_FALSE(state_.GetStaticDomainState("com", true, &domain_state)); 85 } 86 87 TEST_F(TransportSecurityPersisterTest, SerializeData3) { 88 // Add an entry. 89 net::HashValue fp1(net::HASH_VALUE_SHA1); 90 memset(fp1.data(), 0, fp1.size()); 91 net::HashValue fp2(net::HASH_VALUE_SHA1); 92 memset(fp2.data(), 1, fp2.size()); 93 base::Time expiry = 94 base::Time::Now() + base::TimeDelta::FromSeconds(1000); 95 net::HashValueVector dynamic_spki_hashes; 96 dynamic_spki_hashes.push_back(fp1); 97 dynamic_spki_hashes.push_back(fp2); 98 bool include_subdomains = false; 99 state_.AddHSTS("www.example.com", expiry, include_subdomains); 100 state_.AddHPKP("www.example.com", expiry, include_subdomains, 101 dynamic_spki_hashes); 102 103 // Add another entry. 104 memset(fp1.data(), 2, fp1.size()); 105 memset(fp2.data(), 3, fp2.size()); 106 expiry = 107 base::Time::Now() + base::TimeDelta::FromSeconds(3000); 108 dynamic_spki_hashes.push_back(fp1); 109 dynamic_spki_hashes.push_back(fp2); 110 state_.AddHSTS("www.example.net", expiry, include_subdomains); 111 state_.AddHPKP("www.example.net", expiry, include_subdomains, 112 dynamic_spki_hashes); 113 114 // Save a copy of everything. 115 std::map<std::string, TransportSecurityState::DomainState> saved; 116 TransportSecurityState::Iterator i(state_); 117 while (i.HasNext()) { 118 saved[i.hostname()] = i.domain_state(); 119 i.Advance(); 120 } 121 122 std::string serialized; 123 EXPECT_TRUE(persister_->SerializeData(&serialized)); 124 125 // Persist the data to the file. For the test to be fast and not flaky, we 126 // just do it directly rather than call persister_->StateIsDirty. (That uses 127 // ImportantFileWriter, which has an asynchronous commit interval rather 128 // than block.) Use a different basename just for cleanliness. 129 base::FilePath path = 130 temp_dir_.path().AppendASCII("TransportSecurityPersisterTest"); 131 EXPECT_TRUE(base::WriteFile(path, serialized.c_str(), serialized.size())); 132 133 // Read the data back. 134 std::string persisted; 135 EXPECT_TRUE(base::ReadFileToString(path, &persisted)); 136 EXPECT_EQ(persisted, serialized); 137 bool dirty; 138 EXPECT_TRUE(persister_->LoadEntries(persisted, &dirty)); 139 EXPECT_FALSE(dirty); 140 141 // Check that states are the same as saved. 142 size_t count = 0; 143 TransportSecurityState::Iterator j(state_); 144 while (j.HasNext()) { 145 count++; 146 j.Advance(); 147 } 148 EXPECT_EQ(count, saved.size()); 149 } 150 151 TEST_F(TransportSecurityPersisterTest, SerializeDataOld) { 152 // This is an old-style piece of transport state JSON, which has no creation 153 // date. 154 std::string output = 155 "{ " 156 "\"NiyD+3J1r6z1wjl2n1ALBu94Zj9OsEAMo0kCN8js0Uk=\": {" 157 "\"expiry\": 1266815027.983453, " 158 "\"include_subdomains\": false, " 159 "\"mode\": \"strict\" " 160 "}" 161 "}"; 162 bool dirty; 163 EXPECT_TRUE(persister_->LoadEntries(output, &dirty)); 164 EXPECT_TRUE(dirty); 165 } 166 167 TEST_F(TransportSecurityPersisterTest, PublicKeyHashes) { 168 TransportSecurityState::DomainState domain_state; 169 static const char kTestDomain[] = "example.com"; 170 EXPECT_FALSE(state_.GetDynamicDomainState(kTestDomain, &domain_state)); 171 net::HashValueVector hashes; 172 std::string failure_log; 173 EXPECT_FALSE(domain_state.CheckPublicKeyPins(hashes, &failure_log)); 174 175 net::HashValue sha1(net::HASH_VALUE_SHA1); 176 memset(sha1.data(), '1', sha1.size()); 177 domain_state.pkp.spki_hashes.push_back(sha1); 178 179 EXPECT_FALSE(domain_state.CheckPublicKeyPins(hashes, &failure_log)); 180 181 hashes.push_back(sha1); 182 EXPECT_TRUE(domain_state.CheckPublicKeyPins(hashes, &failure_log)); 183 184 hashes[0].data()[0] = '2'; 185 EXPECT_FALSE(domain_state.CheckPublicKeyPins(hashes, &failure_log)); 186 187 const base::Time current_time(base::Time::Now()); 188 const base::Time expiry = current_time + base::TimeDelta::FromSeconds(1000); 189 bool include_subdomains = false; 190 state_.AddHSTS(kTestDomain, expiry, include_subdomains); 191 state_.AddHPKP( 192 kTestDomain, expiry, include_subdomains, domain_state.pkp.spki_hashes); 193 std::string serialized; 194 EXPECT_TRUE(persister_->SerializeData(&serialized)); 195 bool dirty; 196 EXPECT_TRUE(persister_->LoadEntries(serialized, &dirty)); 197 198 TransportSecurityState::DomainState new_domain_state; 199 EXPECT_TRUE(state_.GetDynamicDomainState(kTestDomain, &new_domain_state)); 200 EXPECT_EQ(1u, new_domain_state.pkp.spki_hashes.size()); 201 EXPECT_EQ(sha1.tag, new_domain_state.pkp.spki_hashes[0].tag); 202 EXPECT_EQ(0, 203 memcmp(new_domain_state.pkp.spki_hashes[0].data(), 204 sha1.data(), 205 sha1.size())); 206 } 207