Home | History | Annotate | Download | only in safe_browsing
      1 // Copyright (c) 2010 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "chrome/browser/safe_browsing/bloom_filter.h"
      6 
      7 #include "base/metrics/histogram.h"
      8 #include "base/rand_util.h"
      9 #include "net/base/file_stream.h"
     10 #include "net/base/net_errors.h"
     11 
     12 namespace {
     13 
     14 // The Jenkins 96 bit mix function:
     15 // http://www.concentric.net/~Ttwang/tech/inthash.htm
     16 uint32 HashMix(BloomFilter::HashKey hash_key, uint32 c) {
     17   uint32 a = static_cast<uint32>(hash_key)       & 0xFFFFFFFF;
     18   uint32 b = static_cast<uint32>(hash_key >> 32) & 0xFFFFFFFF;
     19 
     20   a -= (b + c);  a ^= (c >> 13);
     21   b -= (c + a);  b ^= (a << 8);
     22   c -= (a + b);  c ^= (b >> 13);
     23   a -= (b + c);  a ^= (c >> 12);
     24   b -= (c + a);  b ^= (a << 16);
     25   c -= (a + b);  c ^= (b >> 5);
     26   a -= (b + c);  a ^= (c >> 3);
     27   b -= (c + a);  b ^= (a << 10);
     28   c -= (a + b);  c ^= (b >> 15);
     29 
     30   return c;
     31 }
     32 
     33 }  // namespace
     34 
     35 // static
     36 int BloomFilter::FilterSizeForKeyCount(int key_count) {
     37   const int default_min = BloomFilter::kBloomFilterMinSize;
     38   const int number_of_keys = std::max(key_count, default_min);
     39   return std::min(number_of_keys * BloomFilter::kBloomFilterSizeRatio,
     40                   BloomFilter::kBloomFilterMaxSize * 8);
     41 }
     42 
// static
// Reports |failure_type| to the "SB2.BloomFailure" UMA enumeration histogram,
// passing FAILURE_FILTER_MAX as the enumeration's boundary value.
void BloomFilter::RecordFailure(FailureType failure_type) {
  UMA_HISTOGRAM_ENUMERATION("SB2.BloomFailure", failure_type,
                            FAILURE_FILTER_MAX);
}
     48 
     49 BloomFilter::BloomFilter(int bit_size) {
     50   for (int i = 0; i < kNumHashKeys; ++i)
     51     hash_keys_.push_back(base::RandUint64());
     52 
     53   // Round up to the next boundary which fits bit_size.
     54   byte_size_ = (bit_size + 7) / 8;
     55   bit_size_ = byte_size_ * 8;
     56   DCHECK_LE(bit_size, bit_size_);  // strictly more bits.
     57   data_.reset(new char[byte_size_]);
     58   memset(data_.get(), 0, byte_size_);
     59 }
     60 
     61 BloomFilter::BloomFilter(char* data, int size, const HashKeys& keys)
     62     : hash_keys_(keys) {
     63   byte_size_ = size;
     64   bit_size_ = byte_size_ * 8;
     65   data_.reset(data);
     66 }
     67 
     68 BloomFilter::~BloomFilter() {
     69 }
     70 
     71 void BloomFilter::Insert(SBPrefix hash) {
     72   uint32 hash_uint32 = static_cast<uint32>(hash);
     73   for (size_t i = 0; i < hash_keys_.size(); ++i) {
     74     uint32 index = HashMix(hash_keys_[i], hash_uint32) % bit_size_;
     75     data_[index / 8] |= 1 << (index % 8);
     76   }
     77 }
     78 
     79 bool BloomFilter::Exists(SBPrefix hash) const {
     80   uint32 hash_uint32 = static_cast<uint32>(hash);
     81   for (size_t i = 0; i < hash_keys_.size(); ++i) {
     82     uint32 index = HashMix(hash_keys_[i], hash_uint32) % bit_size_;
     83     if (!(data_[index / 8] & (1 << (index % 8))))
     84       return false;
     85   }
     86   return true;
     87 }
     88 
     89 // static.
     90 BloomFilter* BloomFilter::LoadFile(const FilePath& filter_name) {
     91   net::FileStream filter;
     92 
     93   if (filter.Open(filter_name,
     94                   base::PLATFORM_FILE_OPEN |
     95                   base::PLATFORM_FILE_READ) != net::OK) {
     96     RecordFailure(FAILURE_FILTER_READ_OPEN);
     97     return NULL;
     98   }
     99 
    100   // Make sure we have a file version that we can understand.
    101   int file_version;
    102   int bytes_read = filter.Read(reinterpret_cast<char*>(&file_version),
    103                                sizeof(file_version), NULL);
    104   if (bytes_read != sizeof(file_version) || file_version != kFileVersion) {
    105     RecordFailure(FAILURE_FILTER_READ_VERSION);
    106     return NULL;
    107   }
    108 
    109   // Get all the random hash keys.
    110   int num_keys;
    111   bytes_read = filter.Read(reinterpret_cast<char*>(&num_keys),
    112                            sizeof(num_keys), NULL);
    113   if (bytes_read != sizeof(num_keys) ||
    114       num_keys < 1 || num_keys > kNumHashKeys) {
    115     RecordFailure(FAILURE_FILTER_READ_NUM_KEYS);
    116     return NULL;
    117   }
    118 
    119   HashKeys hash_keys;
    120   for (int i = 0; i < num_keys; ++i) {
    121     HashKey key;
    122     bytes_read = filter.Read(reinterpret_cast<char*>(&key), sizeof(key), NULL);
    123     if (bytes_read != sizeof(key)) {
    124       RecordFailure(FAILURE_FILTER_READ_KEY);
    125       return NULL;
    126     }
    127     hash_keys.push_back(key);
    128   }
    129 
    130   // Read in the filter data, with sanity checks on min and max sizes.
    131   int64 remaining64 = filter.Available();
    132   if (remaining64 < kBloomFilterMinSize) {
    133     RecordFailure(FAILURE_FILTER_READ_DATA_MINSIZE);
    134     return NULL;
    135   } else if (remaining64 > kBloomFilterMaxSize) {
    136     RecordFailure(FAILURE_FILTER_READ_DATA_MAXSIZE);
    137     return NULL;
    138   }
    139 
    140   int byte_size = static_cast<int>(remaining64);
    141   scoped_array<char> data(new char[byte_size]);
    142   bytes_read = filter.Read(data.get(), byte_size, NULL);
    143   if (bytes_read < byte_size) {
    144     RecordFailure(FAILURE_FILTER_READ_DATA_SHORT);
    145     return NULL;
    146   } else if (bytes_read != byte_size) {
    147     RecordFailure(FAILURE_FILTER_READ_DATA);
    148     return NULL;
    149   }
    150 
    151   // We've read everything okay, commit the data.
    152   return new BloomFilter(data.release(), byte_size, hash_keys);
    153 }
    154 
    155 bool BloomFilter::WriteFile(const FilePath& filter_name) const {
    156   net::FileStream filter;
    157 
    158   if (filter.Open(filter_name,
    159                   base::PLATFORM_FILE_WRITE |
    160                   base::PLATFORM_FILE_CREATE_ALWAYS) != net::OK)
    161     return false;
    162 
    163   // Write the version information.
    164   int version = kFileVersion;
    165   int bytes_written = filter.Write(reinterpret_cast<char*>(&version),
    166                                    sizeof(version), NULL);
    167   if (bytes_written != sizeof(version))
    168     return false;
    169 
    170   // Write the number of random hash keys.
    171   int num_keys = static_cast<int>(hash_keys_.size());
    172   bytes_written = filter.Write(reinterpret_cast<char*>(&num_keys),
    173                                sizeof(num_keys), NULL);
    174   if (bytes_written != sizeof(num_keys))
    175     return false;
    176 
    177   for (int i = 0; i < num_keys; ++i) {
    178     bytes_written = filter.Write(reinterpret_cast<const char*>(&hash_keys_[i]),
    179                                  sizeof(hash_keys_[i]), NULL);
    180     if (bytes_written != sizeof(hash_keys_[i]))
    181       return false;
    182   }
    183 
    184   // Write the filter data.
    185   bytes_written = filter.Write(data_.get(), byte_size_, NULL);
    186   if (bytes_written != byte_size_)
    187     return false;
    188 
    189   return true;
    190 }
    191