Home | History | Annotate | Download | only in ssl
      1 // Copyright 2014 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/ssl/default_channel_id_store.h"
      6 
      7 #include "base/bind.h"
      8 #include "base/message_loop/message_loop.h"
      9 #include "base/metrics/histogram.h"
     10 #include "net/base/net_errors.h"
     11 
     12 namespace net {
     13 
     14 // --------------------------------------------------------------------------
     15 // Task
     16 class DefaultChannelIDStore::Task {
     17  public:
     18   virtual ~Task();
     19 
     20   // Runs the task and invokes the client callback on the thread that
     21   // originally constructed the task.
     22   virtual void Run(DefaultChannelIDStore* store) = 0;
     23 
     24  protected:
     25   void InvokeCallback(base::Closure callback) const;
     26 };
     27 
     28 DefaultChannelIDStore::Task::~Task() {
     29 }
     30 
     31 void DefaultChannelIDStore::Task::InvokeCallback(
     32     base::Closure callback) const {
     33   if (!callback.is_null())
     34     callback.Run();
     35 }
     36 
     37 // --------------------------------------------------------------------------
     38 // GetChannelIDTask
     39 class DefaultChannelIDStore::GetChannelIDTask
     40     : public DefaultChannelIDStore::Task {
     41  public:
     42   GetChannelIDTask(const std::string& server_identifier,
     43                    const GetChannelIDCallback& callback);
     44   virtual ~GetChannelIDTask();
     45   virtual void Run(DefaultChannelIDStore* store) OVERRIDE;
     46 
     47  private:
     48   std::string server_identifier_;
     49   GetChannelIDCallback callback_;
     50 };
     51 
     52 DefaultChannelIDStore::GetChannelIDTask::GetChannelIDTask(
     53     const std::string& server_identifier,
     54     const GetChannelIDCallback& callback)
     55     : server_identifier_(server_identifier),
     56       callback_(callback) {
     57 }
     58 
     59 DefaultChannelIDStore::GetChannelIDTask::~GetChannelIDTask() {
     60 }
     61 
     62 void DefaultChannelIDStore::GetChannelIDTask::Run(
     63     DefaultChannelIDStore* store) {
     64   base::Time expiration_time;
     65   std::string private_key_result;
     66   std::string cert_result;
     67   int err = store->GetChannelID(
     68       server_identifier_, &expiration_time, &private_key_result,
     69       &cert_result, GetChannelIDCallback());
     70   DCHECK(err != ERR_IO_PENDING);
     71 
     72   InvokeCallback(base::Bind(callback_, err, server_identifier_,
     73                             expiration_time, private_key_result, cert_result));
     74 }
     75 
     76 // --------------------------------------------------------------------------
     77 // SetChannelIDTask
     78 class DefaultChannelIDStore::SetChannelIDTask
     79     : public DefaultChannelIDStore::Task {
     80  public:
     81   SetChannelIDTask(const std::string& server_identifier,
     82                    base::Time creation_time,
     83                    base::Time expiration_time,
     84                    const std::string& private_key,
     85                    const std::string& cert);
     86   virtual ~SetChannelIDTask();
     87   virtual void Run(DefaultChannelIDStore* store) OVERRIDE;
     88 
     89  private:
     90   std::string server_identifier_;
     91   base::Time creation_time_;
     92   base::Time expiration_time_;
     93   std::string private_key_;
     94   std::string cert_;
     95 };
     96 
     97 DefaultChannelIDStore::SetChannelIDTask::SetChannelIDTask(
     98     const std::string& server_identifier,
     99     base::Time creation_time,
    100     base::Time expiration_time,
    101     const std::string& private_key,
    102     const std::string& cert)
    103     : server_identifier_(server_identifier),
    104       creation_time_(creation_time),
    105       expiration_time_(expiration_time),
    106       private_key_(private_key),
    107       cert_(cert) {
    108 }
    109 
    110 DefaultChannelIDStore::SetChannelIDTask::~SetChannelIDTask() {
    111 }
    112 
    113 void DefaultChannelIDStore::SetChannelIDTask::Run(
    114     DefaultChannelIDStore* store) {
    115   store->SyncSetChannelID(server_identifier_, creation_time_,
    116                           expiration_time_, private_key_, cert_);
    117 }
    118 
    119 // --------------------------------------------------------------------------
    120 // DeleteChannelIDTask
    121 class DefaultChannelIDStore::DeleteChannelIDTask
    122     : public DefaultChannelIDStore::Task {
    123  public:
    124   DeleteChannelIDTask(const std::string& server_identifier,
    125                       const base::Closure& callback);
    126   virtual ~DeleteChannelIDTask();
    127   virtual void Run(DefaultChannelIDStore* store) OVERRIDE;
    128 
    129  private:
    130   std::string server_identifier_;
    131   base::Closure callback_;
    132 };
    133 
    134 DefaultChannelIDStore::DeleteChannelIDTask::
    135     DeleteChannelIDTask(
    136         const std::string& server_identifier,
    137         const base::Closure& callback)
    138         : server_identifier_(server_identifier),
    139           callback_(callback) {
    140 }
    141 
    142 DefaultChannelIDStore::DeleteChannelIDTask::
    143     ~DeleteChannelIDTask() {
    144 }
    145 
    146 void DefaultChannelIDStore::DeleteChannelIDTask::Run(
    147     DefaultChannelIDStore* store) {
    148   store->SyncDeleteChannelID(server_identifier_);
    149 
    150   InvokeCallback(callback_);
    151 }
    152 
    153 // --------------------------------------------------------------------------
    154 // DeleteAllCreatedBetweenTask
    155 class DefaultChannelIDStore::DeleteAllCreatedBetweenTask
    156     : public DefaultChannelIDStore::Task {
    157  public:
    158   DeleteAllCreatedBetweenTask(base::Time delete_begin,
    159                               base::Time delete_end,
    160                               const base::Closure& callback);
    161   virtual ~DeleteAllCreatedBetweenTask();
    162   virtual void Run(DefaultChannelIDStore* store) OVERRIDE;
    163 
    164  private:
    165   base::Time delete_begin_;
    166   base::Time delete_end_;
    167   base::Closure callback_;
    168 };
    169 
    170 DefaultChannelIDStore::DeleteAllCreatedBetweenTask::
    171     DeleteAllCreatedBetweenTask(
    172         base::Time delete_begin,
    173         base::Time delete_end,
    174         const base::Closure& callback)
    175         : delete_begin_(delete_begin),
    176           delete_end_(delete_end),
    177           callback_(callback) {
    178 }
    179 
    180 DefaultChannelIDStore::DeleteAllCreatedBetweenTask::
    181     ~DeleteAllCreatedBetweenTask() {
    182 }
    183 
    184 void DefaultChannelIDStore::DeleteAllCreatedBetweenTask::Run(
    185     DefaultChannelIDStore* store) {
    186   store->SyncDeleteAllCreatedBetween(delete_begin_, delete_end_);
    187 
    188   InvokeCallback(callback_);
    189 }
    190 
    191 // --------------------------------------------------------------------------
    192 // GetAllChannelIDsTask
    193 class DefaultChannelIDStore::GetAllChannelIDsTask
    194     : public DefaultChannelIDStore::Task {
    195  public:
    196   explicit GetAllChannelIDsTask(const GetChannelIDListCallback& callback);
    197   virtual ~GetAllChannelIDsTask();
    198   virtual void Run(DefaultChannelIDStore* store) OVERRIDE;
    199 
    200  private:
    201   std::string server_identifier_;
    202   GetChannelIDListCallback callback_;
    203 };
    204 
    205 DefaultChannelIDStore::GetAllChannelIDsTask::
    206     GetAllChannelIDsTask(const GetChannelIDListCallback& callback)
    207         : callback_(callback) {
    208 }
    209 
    210 DefaultChannelIDStore::GetAllChannelIDsTask::
    211     ~GetAllChannelIDsTask() {
    212 }
    213 
    214 void DefaultChannelIDStore::GetAllChannelIDsTask::Run(
    215     DefaultChannelIDStore* store) {
    216   ChannelIDList cert_list;
    217   store->SyncGetAllChannelIDs(&cert_list);
    218 
    219   InvokeCallback(base::Bind(callback_, cert_list));
    220 }
    221 
    222 // --------------------------------------------------------------------------
    223 // DefaultChannelIDStore
    224 
    225 DefaultChannelIDStore::DefaultChannelIDStore(
    226     PersistentStore* store)
    227     : initialized_(false),
    228       loaded_(false),
    229       store_(store),
    230       weak_ptr_factory_(this) {}
    231 
    232 int DefaultChannelIDStore::GetChannelID(
    233     const std::string& server_identifier,
    234     base::Time* expiration_time,
    235     std::string* private_key_result,
    236     std::string* cert_result,
    237     const GetChannelIDCallback& callback) {
    238   DCHECK(CalledOnValidThread());
    239   InitIfNecessary();
    240 
    241   if (!loaded_) {
    242     EnqueueTask(scoped_ptr<Task>(
    243         new GetChannelIDTask(server_identifier, callback)));
    244     return ERR_IO_PENDING;
    245   }
    246 
    247   ChannelIDMap::iterator it = channel_ids_.find(server_identifier);
    248 
    249   if (it == channel_ids_.end())
    250     return ERR_FILE_NOT_FOUND;
    251 
    252   ChannelID* channel_id = it->second;
    253   *expiration_time = channel_id->expiration_time();
    254   *private_key_result = channel_id->private_key();
    255   *cert_result = channel_id->cert();
    256 
    257   return OK;
    258 }
    259 
    260 void DefaultChannelIDStore::SetChannelID(
    261     const std::string& server_identifier,
    262     base::Time creation_time,
    263     base::Time expiration_time,
    264     const std::string& private_key,
    265     const std::string& cert) {
    266   RunOrEnqueueTask(scoped_ptr<Task>(new SetChannelIDTask(
    267       server_identifier, creation_time, expiration_time, private_key,
    268       cert)));
    269 }
    270 
    271 void DefaultChannelIDStore::DeleteChannelID(
    272     const std::string& server_identifier,
    273     const base::Closure& callback) {
    274   RunOrEnqueueTask(scoped_ptr<Task>(
    275       new DeleteChannelIDTask(server_identifier, callback)));
    276 }
    277 
    278 void DefaultChannelIDStore::DeleteAllCreatedBetween(
    279     base::Time delete_begin,
    280     base::Time delete_end,
    281     const base::Closure& callback) {
    282   RunOrEnqueueTask(scoped_ptr<Task>(
    283       new DeleteAllCreatedBetweenTask(delete_begin, delete_end, callback)));
    284 }
    285 
    286 void DefaultChannelIDStore::DeleteAll(
    287     const base::Closure& callback) {
    288   DeleteAllCreatedBetween(base::Time(), base::Time(), callback);
    289 }
    290 
    291 void DefaultChannelIDStore::GetAllChannelIDs(
    292     const GetChannelIDListCallback& callback) {
    293   RunOrEnqueueTask(scoped_ptr<Task>(new GetAllChannelIDsTask(callback)));
    294 }
    295 
    296 int DefaultChannelIDStore::GetChannelIDCount() {
    297   DCHECK(CalledOnValidThread());
    298 
    299   return channel_ids_.size();
    300 }
    301 
    302 void DefaultChannelIDStore::SetForceKeepSessionState() {
    303   DCHECK(CalledOnValidThread());
    304   InitIfNecessary();
    305 
    306   if (store_.get())
    307     store_->SetForceKeepSessionState();
    308 }
    309 
    310 DefaultChannelIDStore::~DefaultChannelIDStore() {
    311   DeleteAllInMemory();
    312 }
    313 
    314 void DefaultChannelIDStore::DeleteAllInMemory() {
    315   DCHECK(CalledOnValidThread());
    316 
    317   for (ChannelIDMap::iterator it = channel_ids_.begin();
    318        it != channel_ids_.end(); ++it) {
    319     delete it->second;
    320   }
    321   channel_ids_.clear();
    322 }
    323 
    324 void DefaultChannelIDStore::InitStore() {
    325   DCHECK(CalledOnValidThread());
    326   DCHECK(store_.get()) << "Store must exist to initialize";
    327   DCHECK(!loaded_);
    328 
    329   store_->Load(base::Bind(&DefaultChannelIDStore::OnLoaded,
    330                           weak_ptr_factory_.GetWeakPtr()));
    331 }
    332 
    333 void DefaultChannelIDStore::OnLoaded(
    334     scoped_ptr<ScopedVector<ChannelID> > channel_ids) {
    335   DCHECK(CalledOnValidThread());
    336 
    337   for (std::vector<ChannelID*>::const_iterator it = channel_ids->begin();
    338        it != channel_ids->end(); ++it) {
    339     DCHECK(channel_ids_.find((*it)->server_identifier()) ==
    340            channel_ids_.end());
    341     channel_ids_[(*it)->server_identifier()] = *it;
    342   }
    343   channel_ids->weak_clear();
    344 
    345   loaded_ = true;
    346 
    347   base::TimeDelta wait_time;
    348   if (!waiting_tasks_.empty())
    349     wait_time = base::TimeTicks::Now() - waiting_tasks_start_time_;
    350   DVLOG(1) << "Task delay " << wait_time.InMilliseconds();
    351   UMA_HISTOGRAM_CUSTOM_TIMES("DomainBoundCerts.TaskMaxWaitTime",
    352                              wait_time,
    353                              base::TimeDelta::FromMilliseconds(1),
    354                              base::TimeDelta::FromMinutes(1),
    355                              50);
    356   UMA_HISTOGRAM_COUNTS_100("DomainBoundCerts.TaskWaitCount",
    357                            waiting_tasks_.size());
    358 
    359 
    360   for (ScopedVector<Task>::iterator i = waiting_tasks_.begin();
    361        i != waiting_tasks_.end(); ++i)
    362     (*i)->Run(this);
    363   waiting_tasks_.clear();
    364 }
    365 
    366 void DefaultChannelIDStore::SyncSetChannelID(
    367     const std::string& server_identifier,
    368     base::Time creation_time,
    369     base::Time expiration_time,
    370     const std::string& private_key,
    371     const std::string& cert) {
    372   DCHECK(CalledOnValidThread());
    373   DCHECK(loaded_);
    374 
    375   InternalDeleteChannelID(server_identifier);
    376   InternalInsertChannelID(
    377       server_identifier,
    378       new ChannelID(
    379           server_identifier, creation_time, expiration_time, private_key,
    380           cert));
    381 }
    382 
    383 void DefaultChannelIDStore::SyncDeleteChannelID(
    384     const std::string& server_identifier) {
    385   DCHECK(CalledOnValidThread());
    386   DCHECK(loaded_);
    387   InternalDeleteChannelID(server_identifier);
    388 }
    389 
    390 void DefaultChannelIDStore::SyncDeleteAllCreatedBetween(
    391     base::Time delete_begin,
    392     base::Time delete_end) {
    393   DCHECK(CalledOnValidThread());
    394   DCHECK(loaded_);
    395   for (ChannelIDMap::iterator it = channel_ids_.begin();
    396        it != channel_ids_.end();) {
    397     ChannelIDMap::iterator cur = it;
    398     ++it;
    399     ChannelID* channel_id = cur->second;
    400     if ((delete_begin.is_null() ||
    401          channel_id->creation_time() >= delete_begin) &&
    402         (delete_end.is_null() || channel_id->creation_time() < delete_end)) {
    403       if (store_.get())
    404         store_->DeleteChannelID(*channel_id);
    405       delete channel_id;
    406       channel_ids_.erase(cur);
    407     }
    408   }
    409 }
    410 
    411 void DefaultChannelIDStore::SyncGetAllChannelIDs(
    412     ChannelIDList* channel_id_list) {
    413   DCHECK(CalledOnValidThread());
    414   DCHECK(loaded_);
    415   for (ChannelIDMap::iterator it = channel_ids_.begin();
    416        it != channel_ids_.end(); ++it)
    417     channel_id_list->push_back(*it->second);
    418 }
    419 
    420 void DefaultChannelIDStore::EnqueueTask(scoped_ptr<Task> task) {
    421   DCHECK(CalledOnValidThread());
    422   DCHECK(!loaded_);
    423   if (waiting_tasks_.empty())
    424     waiting_tasks_start_time_ = base::TimeTicks::Now();
    425   waiting_tasks_.push_back(task.release());
    426 }
    427 
    428 void DefaultChannelIDStore::RunOrEnqueueTask(scoped_ptr<Task> task) {
    429   DCHECK(CalledOnValidThread());
    430   InitIfNecessary();
    431 
    432   if (!loaded_) {
    433     EnqueueTask(task.Pass());
    434     return;
    435   }
    436 
    437   task->Run(this);
    438 }
    439 
    440 void DefaultChannelIDStore::InternalDeleteChannelID(
    441     const std::string& server_identifier) {
    442   DCHECK(CalledOnValidThread());
    443   DCHECK(loaded_);
    444 
    445   ChannelIDMap::iterator it = channel_ids_.find(server_identifier);
    446   if (it == channel_ids_.end())
    447     return;  // There is nothing to delete.
    448 
    449   ChannelID* channel_id = it->second;
    450   if (store_.get())
    451     store_->DeleteChannelID(*channel_id);
    452   channel_ids_.erase(it);
    453   delete channel_id;
    454 }
    455 
    456 void DefaultChannelIDStore::InternalInsertChannelID(
    457     const std::string& server_identifier,
    458     ChannelID* channel_id) {
    459   DCHECK(CalledOnValidThread());
    460   DCHECK(loaded_);
    461 
    462   if (store_.get())
    463     store_->AddChannelID(*channel_id);
    464   channel_ids_[server_identifier] = channel_id;
    465 }
    466 
    467 DefaultChannelIDStore::PersistentStore::PersistentStore() {}
    468 
    469 DefaultChannelIDStore::PersistentStore::~PersistentStore() {}
    470 
    471 }  // namespace net
    472