Home | History | Annotate | Download | only in http
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/http/transport_security_persister.h"
      6 
      7 #include "base/base64.h"
      8 #include "base/bind.h"
      9 #include "base/files/file_path.h"
     10 #include "base/files/file_util.h"
     11 #include "base/json/json_reader.h"
     12 #include "base/json/json_writer.h"
     13 #include "base/message_loop/message_loop.h"
     14 #include "base/message_loop/message_loop_proxy.h"
     15 #include "base/sequenced_task_runner.h"
     16 #include "base/task_runner_util.h"
     17 #include "base/values.h"
     18 #include "crypto/sha2.h"
     19 #include "net/cert/x509_certificate.h"
     20 #include "net/http/transport_security_state.h"
     21 
     22 using net::HashValue;
     23 using net::HashValueTag;
     24 using net::HashValueVector;
     25 using net::TransportSecurityState;
     26 
     27 namespace {
     28 
     29 base::ListValue* SPKIHashesToListValue(const HashValueVector& hashes) {
     30   base::ListValue* pins = new base::ListValue;
     31   for (size_t i = 0; i != hashes.size(); i++)
     32     pins->Append(new base::StringValue(hashes[i].ToString()));
     33   return pins;
     34 }
     35 
     36 void SPKIHashesFromListValue(const base::ListValue& pins,
     37                              HashValueVector* hashes) {
     38   size_t num_pins = pins.GetSize();
     39   for (size_t i = 0; i < num_pins; ++i) {
     40     std::string type_and_base64;
     41     HashValue fingerprint;
     42     if (pins.GetString(i, &type_and_base64) &&
     43         fingerprint.FromString(type_and_base64)) {
     44       hashes->push_back(fingerprint);
     45     }
     46   }
     47 }
     48 
     49 // This function converts the binary hashes to a base64 string which we can
     50 // include in a JSON file.
     51 std::string HashedDomainToExternalString(const std::string& hashed) {
     52   std::string out;
     53   base::Base64Encode(hashed, &out);
     54   return out;
     55 }
     56 
     57 // This inverts |HashedDomainToExternalString|, above. It turns an external
     58 // string (from a JSON file) into an internal (binary) string.
     59 std::string ExternalStringToHashedDomain(const std::string& external) {
     60   std::string out;
     61   if (!base::Base64Decode(external, &out) ||
     62       out.size() != crypto::kSHA256Length) {
     63     return std::string();
     64   }
     65 
     66   return out;
     67 }
     68 
     69 const char kIncludeSubdomains[] = "include_subdomains";
     70 const char kStsIncludeSubdomains[] = "sts_include_subdomains";
     71 const char kPkpIncludeSubdomains[] = "pkp_include_subdomains";
     72 const char kMode[] = "mode";
     73 const char kExpiry[] = "expiry";
     74 const char kDynamicSPKIHashesExpiry[] = "dynamic_spki_hashes_expiry";
     75 const char kDynamicSPKIHashes[] = "dynamic_spki_hashes";
     76 const char kForceHTTPS[] = "force-https";
     77 const char kStrict[] = "strict";
     78 const char kDefault[] = "default";
     79 const char kPinningOnly[] = "pinning-only";
     80 const char kCreated[] = "created";
     81 const char kStsObserved[] = "sts_observed";
     82 const char kPkpObserved[] = "pkp_observed";
     83 
     84 std::string LoadState(const base::FilePath& path) {
     85   std::string result;
     86   if (!base::ReadFileToString(path, &result)) {
     87     return "";
     88   }
     89   return result;
     90 }
     91 
     92 }  // namespace
     93 
     94 
     95 namespace net {
     96 
// Creates a persister for |state| backed by the file "TransportSecurity"
// inside |profile_path|. File I/O is done on |background_runner|. If
// |readonly| is true, StateIsDirty() never schedules a write.
TransportSecurityPersister::TransportSecurityPersister(
    TransportSecurityState* state,
    const base::FilePath& profile_path,
    const scoped_refptr<base::SequencedTaskRunner>& background_runner,
    bool readonly)
    : transport_security_state_(state),
      writer_(profile_path.AppendASCII("TransportSecurity"), background_runner),
      // The thread constructing the persister is the "foreground" thread;
      // the other methods DCHECK that they run on it.
      foreground_runner_(base::MessageLoop::current()->message_loop_proxy()),
      background_runner_(background_runner),
      readonly_(readonly),
      weak_ptr_factory_(this) {
  // Register for StateIsDirty() callbacks; undone in the destructor.
  transport_security_state_->SetDelegate(this);

  // Read the file on the background runner, then hand its contents to
  // CompleteLoad(), which DCHECKs that it runs on the foreground runner. The
  // reply is bound through a WeakPtr, presumably so that it is skipped if
  // this persister is destroyed before the load finishes.
  base::PostTaskAndReplyWithResult(
      background_runner_.get(),
      FROM_HERE,
      base::Bind(&::LoadState, writer_.path()),
      base::Bind(&TransportSecurityPersister::CompleteLoad,
                 weak_ptr_factory_.GetWeakPtr()));
}
    117 
    118 TransportSecurityPersister::~TransportSecurityPersister() {
    119   DCHECK(foreground_runner_->RunsTasksOnCurrentThread());
    120 
    121   if (writer_.HasPendingWrite())
    122     writer_.DoScheduledWrite();
    123 
    124   transport_security_state_->SetDelegate(NULL);
    125 }
    126 
    127 void TransportSecurityPersister::StateIsDirty(
    128     TransportSecurityState* state) {
    129   DCHECK(foreground_runner_->RunsTasksOnCurrentThread());
    130   DCHECK_EQ(transport_security_state_, state);
    131 
    132   if (!readonly_)
    133     writer_.ScheduleWrite(this);
    134 }
    135 
    136 bool TransportSecurityPersister::SerializeData(std::string* output) {
    137   DCHECK(foreground_runner_->RunsTasksOnCurrentThread());
    138 
    139   base::DictionaryValue toplevel;
    140   base::Time now = base::Time::Now();
    141   TransportSecurityState::Iterator state(*transport_security_state_);
    142   for (; state.HasNext(); state.Advance()) {
    143     const std::string& hostname = state.hostname();
    144     const TransportSecurityState::DomainState& domain_state =
    145         state.domain_state();
    146 
    147     base::DictionaryValue* serialized = new base::DictionaryValue;
    148     serialized->SetBoolean(kStsIncludeSubdomains,
    149                            domain_state.sts.include_subdomains);
    150     serialized->SetBoolean(kPkpIncludeSubdomains,
    151                            domain_state.pkp.include_subdomains);
    152     serialized->SetDouble(kStsObserved,
    153                           domain_state.sts.last_observed.ToDoubleT());
    154     serialized->SetDouble(kPkpObserved,
    155                           domain_state.pkp.last_observed.ToDoubleT());
    156     serialized->SetDouble(kExpiry, domain_state.sts.expiry.ToDoubleT());
    157     serialized->SetDouble(kDynamicSPKIHashesExpiry,
    158                           domain_state.pkp.expiry.ToDoubleT());
    159 
    160     switch (domain_state.sts.upgrade_mode) {
    161       case TransportSecurityState::DomainState::MODE_FORCE_HTTPS:
    162         serialized->SetString(kMode, kForceHTTPS);
    163         break;
    164       case TransportSecurityState::DomainState::MODE_DEFAULT:
    165         serialized->SetString(kMode, kDefault);
    166         break;
    167       default:
    168         NOTREACHED() << "DomainState with unknown mode";
    169         delete serialized;
    170         continue;
    171     }
    172 
    173     if (now < domain_state.pkp.expiry) {
    174       serialized->Set(kDynamicSPKIHashes,
    175                       SPKIHashesToListValue(domain_state.pkp.spki_hashes));
    176     }
    177 
    178     toplevel.Set(HashedDomainToExternalString(hostname), serialized);
    179   }
    180 
    181   base::JSONWriter::WriteWithOptions(&toplevel,
    182                                      base::JSONWriter::OPTIONS_PRETTY_PRINT,
    183                                      output);
    184   return true;
    185 }
    186 
    187 bool TransportSecurityPersister::LoadEntries(const std::string& serialized,
    188                                              bool* dirty) {
    189   DCHECK(foreground_runner_->RunsTasksOnCurrentThread());
    190 
    191   transport_security_state_->ClearDynamicData();
    192   return Deserialize(serialized, dirty, transport_security_state_);
    193 }
    194 
    195 // static
    196 bool TransportSecurityPersister::Deserialize(const std::string& serialized,
    197                                              bool* dirty,
    198                                              TransportSecurityState* state) {
    199   scoped_ptr<base::Value> value(base::JSONReader::Read(serialized));
    200   base::DictionaryValue* dict_value = NULL;
    201   if (!value.get() || !value->GetAsDictionary(&dict_value))
    202     return false;
    203 
    204   const base::Time current_time(base::Time::Now());
    205   bool dirtied = false;
    206 
    207   for (base::DictionaryValue::Iterator i(*dict_value);
    208        !i.IsAtEnd(); i.Advance()) {
    209     const base::DictionaryValue* parsed = NULL;
    210     if (!i.value().GetAsDictionary(&parsed)) {
    211       LOG(WARNING) << "Could not parse entry " << i.key() << "; skipping entry";
    212       continue;
    213     }
    214 
    215     TransportSecurityState::DomainState domain_state;
    216 
    217     // kIncludeSubdomains is a legacy synonym for kStsIncludeSubdomains and
    218     // kPkpIncludeSubdomains. Parse at least one of these properties,
    219     // preferably the new ones.
    220     bool include_subdomains = false;
    221     bool parsed_include_subdomains = parsed->GetBoolean(kIncludeSubdomains,
    222                                                         &include_subdomains);
    223     domain_state.sts.include_subdomains = include_subdomains;
    224     domain_state.pkp.include_subdomains = include_subdomains;
    225     if (parsed->GetBoolean(kStsIncludeSubdomains, &include_subdomains)) {
    226       domain_state.sts.include_subdomains = include_subdomains;
    227       parsed_include_subdomains = true;
    228     }
    229     if (parsed->GetBoolean(kPkpIncludeSubdomains, &include_subdomains)) {
    230       domain_state.pkp.include_subdomains = include_subdomains;
    231       parsed_include_subdomains = true;
    232     }
    233 
    234     std::string mode_string;
    235     double expiry = 0;
    236     if (!parsed_include_subdomains ||
    237         !parsed->GetString(kMode, &mode_string) ||
    238         !parsed->GetDouble(kExpiry, &expiry)) {
    239       LOG(WARNING) << "Could not parse some elements of entry " << i.key()
    240                    << "; skipping entry";
    241       continue;
    242     }
    243 
    244     // Don't fail if this key is not present.
    245     double dynamic_spki_hashes_expiry = 0;
    246     parsed->GetDouble(kDynamicSPKIHashesExpiry,
    247                       &dynamic_spki_hashes_expiry);
    248 
    249     const base::ListValue* pins_list = NULL;
    250     if (parsed->GetList(kDynamicSPKIHashes, &pins_list)) {
    251       SPKIHashesFromListValue(*pins_list, &domain_state.pkp.spki_hashes);
    252     }
    253 
    254     if (mode_string == kForceHTTPS || mode_string == kStrict) {
    255       domain_state.sts.upgrade_mode =
    256           TransportSecurityState::DomainState::MODE_FORCE_HTTPS;
    257     } else if (mode_string == kDefault || mode_string == kPinningOnly) {
    258       domain_state.sts.upgrade_mode =
    259           TransportSecurityState::DomainState::MODE_DEFAULT;
    260     } else {
    261       LOG(WARNING) << "Unknown TransportSecurityState mode string "
    262                    << mode_string << " found for entry " << i.key()
    263                    << "; skipping entry";
    264       continue;
    265     }
    266 
    267     domain_state.sts.expiry = base::Time::FromDoubleT(expiry);
    268     domain_state.pkp.expiry =
    269         base::Time::FromDoubleT(dynamic_spki_hashes_expiry);
    270 
    271     double sts_observed;
    272     double pkp_observed;
    273     if (parsed->GetDouble(kStsObserved, &sts_observed)) {
    274       domain_state.sts.last_observed = base::Time::FromDoubleT(sts_observed);
    275     } else if (parsed->GetDouble(kCreated, &sts_observed)) {
    276       // kCreated is a legacy synonym for both kStsObserved and kPkpObserved.
    277       domain_state.sts.last_observed = base::Time::FromDoubleT(sts_observed);
    278     } else {
    279       // We're migrating an old entry with no observation date. Make sure we
    280       // write the new date back in a reasonable time frame.
    281       dirtied = true;
    282       domain_state.sts.last_observed = base::Time::Now();
    283     }
    284     if (parsed->GetDouble(kPkpObserved, &pkp_observed)) {
    285       domain_state.pkp.last_observed = base::Time::FromDoubleT(pkp_observed);
    286     } else if (parsed->GetDouble(kCreated, &pkp_observed)) {
    287       domain_state.pkp.last_observed = base::Time::FromDoubleT(pkp_observed);
    288     } else {
    289       dirtied = true;
    290       domain_state.pkp.last_observed = base::Time::Now();
    291     }
    292 
    293     if (domain_state.sts.expiry <= current_time &&
    294         domain_state.pkp.expiry <= current_time) {
    295       // Make sure we dirty the state if we drop an entry.
    296       dirtied = true;
    297       continue;
    298     }
    299 
    300     std::string hashed = ExternalStringToHashedDomain(i.key());
    301     if (hashed.empty()) {
    302       dirtied = true;
    303       continue;
    304     }
    305 
    306     state->AddOrUpdateEnabledHosts(hashed, domain_state);
    307   }
    308 
    309   *dirty = dirtied;
    310   return true;
    311 }
    312 
    313 void TransportSecurityPersister::CompleteLoad(const std::string& state) {
    314   DCHECK(foreground_runner_->RunsTasksOnCurrentThread());
    315 
    316   if (state.empty())
    317     return;
    318 
    319   bool dirty = false;
    320   if (!LoadEntries(state, &dirty)) {
    321     LOG(ERROR) << "Failed to deserialize state: " << state;
    322     return;
    323   }
    324   if (dirty)
    325     StateIsDirty(transport_security_state_);
    326 }
    327 
    328 }  // namespace net
    329