1 // Copyright (c) 2013 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/quic/crypto/strike_register.h" 6 7 #include <limits> 8 9 #include "base/logging.h" 10 11 using std::make_pair; 12 using std::max; 13 using std::min; 14 using std::pair; 15 using std::set; 16 using std::vector; 17 18 namespace net { 19 20 namespace { 21 22 uint32 GetInitialHorizon(uint32 current_time_internal, 23 uint32 window_secs, 24 StrikeRegister::StartupType startup) { 25 if (startup == StrikeRegister::DENY_REQUESTS_AT_STARTUP) { 26 // The horizon is initially set |window_secs| into the future because, if 27 // we just crashed, then we may have accepted nonces in the span 28 // [current_time...current_time+window_secs] and so we conservatively 29 // reject the whole timespan unless |startup| tells us otherwise. 30 return current_time_internal + window_secs + 1; 31 } else { // startup == StrikeRegister::NO_STARTUP_PERIOD_NEEDED 32 // The orbit can be assumed to be globally unique. Use a horizon 33 // in the past. 34 return 0; 35 } 36 } 37 38 } // namespace 39 40 // static 41 const uint32 StrikeRegister::kExternalNodeSize = 24; 42 // static 43 const uint32 StrikeRegister::kNil = (1u << 31) | 1; 44 // static 45 const uint32 StrikeRegister::kExternalFlag = 1 << 23; 46 47 // InternalNode represents a non-leaf node in the critbit tree. See the comment 48 // in the .h file for details. 49 class StrikeRegister::InternalNode { 50 public: 51 void SetChild(unsigned direction, uint32 child) { 52 data_[direction] = (data_[direction] & 0xff) | (child << 8); 53 } 54 55 void SetCritByte(uint8 critbyte) { 56 data_[0] = (data_[0] & 0xffffff00) | critbyte; 57 } 58 59 void SetOtherBits(uint8 otherbits) { 60 data_[1] = (data_[1] & 0xffffff00) | otherbits; 61 } 62 63 void SetNextPtr(uint32 next) { data_[0] = next; } 64 65 uint32 next() const { return data_[0]; } 66 67 uint32 child(unsigned n) const { return data_[n] >> 8; } 68 69 uint8 critbyte() const { return data_[0]; } 70 71 uint8 otherbits() const { return data_[1]; } 72 73 // These bytes are organised thus: 74 // <24 bits> left child 75 // <8 bits> crit-byte 76 // <24 bits> right child 77 // <8 bits> other-bits 78 uint32 data_[2]; 79 }; 80 81 // kCreationTimeFromInternalEpoch contains the number of seconds between the 82 // start of the internal epoch and the creation time. This allows us 83 // to consider times that are before the creation time. 84 static const uint32 kCreationTimeFromInternalEpoch = 63115200; // 2 years. 85 86 void StrikeRegister::ValidateStrikeRegisterConfig(unsigned max_entries) { 87 // We only have 23 bits of index available. 88 CHECK_LT(max_entries, 1u << 23); 89 CHECK_GT(max_entries, 1u); // There must be at least two entries. 90 CHECK_EQ(sizeof(InternalNode), 8u); // in case of compiler changes. 91 } 92 93 StrikeRegister::StrikeRegister(unsigned max_entries, 94 uint32 current_time, 95 uint32 window_secs, 96 const uint8 orbit[8], 97 StartupType startup) 98 : max_entries_(max_entries), 99 window_secs_(window_secs), 100 internal_epoch_(current_time > kCreationTimeFromInternalEpoch 101 ? current_time - kCreationTimeFromInternalEpoch 102 : 0), 103 horizon_(GetInitialHorizon( 104 ExternalTimeToInternal(current_time), window_secs, startup)) { 105 memcpy(orbit_, orbit, sizeof(orbit_)); 106 107 ValidateStrikeRegisterConfig(max_entries); 108 internal_nodes_ = new InternalNode[max_entries]; 109 external_nodes_.reset(new uint8[kExternalNodeSize * max_entries]); 110 111 Reset(); 112 } 113 114 StrikeRegister::~StrikeRegister() { delete[] internal_nodes_; } 115 116 void StrikeRegister::Reset() { 117 // Thread a free list through all of the internal nodes. 118 internal_node_free_head_ = 0; 119 for (unsigned i = 0; i < max_entries_ - 1; i++) 120 internal_nodes_[i].SetNextPtr(i + 1); 121 internal_nodes_[max_entries_ - 1].SetNextPtr(kNil); 122 123 // Also thread a free list through the external nodes. 124 external_node_free_head_ = 0; 125 for (unsigned i = 0; i < max_entries_ - 1; i++) 126 external_node_next_ptr(i) = i + 1; 127 external_node_next_ptr(max_entries_ - 1) = kNil; 128 129 // This is the root of the tree. 130 internal_node_head_ = kNil; 131 } 132 133 InsertStatus StrikeRegister::Insert(const uint8 nonce[32], 134 uint32 current_time_external) { 135 // Make space for the insertion if the strike register is full. 136 while (external_node_free_head_ == kNil || 137 internal_node_free_head_ == kNil) { 138 DropOldestNode(); 139 } 140 141 const uint32 current_time = ExternalTimeToInternal(current_time_external); 142 143 // Check to see if the orbit is correct. 144 if (memcmp(nonce + sizeof(current_time), orbit_, sizeof(orbit_))) { 145 return NONCE_INVALID_ORBIT_FAILURE; 146 } 147 148 const uint32 nonce_time = ExternalTimeToInternal(TimeFromBytes(nonce)); 149 150 // Check that the timestamp is in the valid range. 151 pair<uint32, uint32> valid_range = 152 StrikeRegister::GetValidRange(current_time); 153 if (nonce_time < valid_range.first || nonce_time > valid_range.second) { 154 return NONCE_INVALID_TIME_FAILURE; 155 } 156 157 // We strip the orbit out of the nonce. 158 uint8 value[24]; 159 memcpy(value, nonce, sizeof(nonce_time)); 160 memcpy(value + sizeof(nonce_time), 161 nonce + sizeof(nonce_time) + sizeof(orbit_), 162 sizeof(value) - sizeof(nonce_time)); 163 164 // Find the best match to |value| in the crit-bit tree. The best match is 165 // simply the value which /could/ match |value|, if any does, so we still 166 // need a memcmp to check. 167 uint32 best_match_index = BestMatch(value); 168 if (best_match_index == kNil) { 169 // Empty tree. Just insert the new value at the root. 170 uint32 index = GetFreeExternalNode(); 171 memcpy(external_node(index), value, sizeof(value)); 172 internal_node_head_ = (index | kExternalFlag) << 8; 173 DCHECK_LE(horizon_, nonce_time); 174 return NONCE_OK; 175 } 176 177 const uint8* best_match = external_node(best_match_index); 178 if (memcmp(best_match, value, sizeof(value)) == 0) { 179 // We found the value in the tree. 180 return NONCE_NOT_UNIQUE_FAILURE; 181 } 182 183 // We are going to insert a new entry into the tree, so get the nodes now. 184 uint32 internal_node_index = GetFreeInternalNode(); 185 uint32 external_node_index = GetFreeExternalNode(); 186 187 // If we just evicted the best match, then we have to try and match again. 188 // We know that we didn't just empty the tree because we require that 189 // max_entries_ >= 2. Also, we know that it doesn't match because, if it 190 // did, it would have been returned previously. 191 if (external_node_index == best_match_index) { 192 best_match_index = BestMatch(value); 193 best_match = external_node(best_match_index); 194 } 195 196 // Now we need to find the first bit where we differ from |best_match|. 197 unsigned differing_byte; 198 uint8 new_other_bits; 199 for (differing_byte = 0; differing_byte < sizeof(value); differing_byte++) { 200 new_other_bits = value[differing_byte] ^ best_match[differing_byte]; 201 if (new_other_bits) { 202 break; 203 } 204 } 205 206 // Once we have the XOR the of first differing byte in new_other_bits we need 207 // to find the most significant differing bit. We could do this with a simple 208 // for loop, testing bits 7..0. Instead we fold the bits so that we end up 209 // with a byte where all the bits below the most significant one, are set. 210 new_other_bits |= new_other_bits >> 1; 211 new_other_bits |= new_other_bits >> 2; 212 new_other_bits |= new_other_bits >> 4; 213 // Now this bit trick results in all the bits set, except the original 214 // most-significant one. 215 new_other_bits = (new_other_bits & ~(new_other_bits >> 1)) ^ 255; 216 217 // Consider the effect of ORing against |new_other_bits|. If |value| did not 218 // have the critical bit set, the result is the same as |new_other_bits|. If 219 // it did, the result is all ones. 220 221 unsigned newdirection; 222 if ((new_other_bits | value[differing_byte]) == 0xff) { 223 newdirection = 1; 224 } else { 225 newdirection = 0; 226 } 227 228 memcpy(external_node(external_node_index), value, sizeof(value)); 229 InternalNode* inode = &internal_nodes_[internal_node_index]; 230 231 inode->SetChild(newdirection, external_node_index | kExternalFlag); 232 inode->SetCritByte(differing_byte); 233 inode->SetOtherBits(new_other_bits); 234 235 // |where_index| is a pointer to the uint32 which needs to be updated in 236 // order to insert the new internal node into the tree. The internal nodes 237 // store the child indexes in the top 24-bits of a 32-bit word and, to keep 238 // the code simple, we define that |internal_node_head_| is organised the 239 // same way. 240 DCHECK_EQ(internal_node_head_ & 0xff, 0u); 241 uint32* where_index = &internal_node_head_; 242 while (((*where_index >> 8) & kExternalFlag) == 0) { 243 InternalNode* node = &internal_nodes_[*where_index >> 8]; 244 if (node->critbyte() > differing_byte) { 245 break; 246 } 247 if (node->critbyte() == differing_byte && 248 node->otherbits() > new_other_bits) { 249 break; 250 } 251 if (node->critbyte() == differing_byte && 252 node->otherbits() == new_other_bits) { 253 CHECK(false); 254 } 255 256 uint8 c = value[node->critbyte()]; 257 const int direction = 258 (1 + static_cast<unsigned>(node->otherbits() | c)) >> 8; 259 where_index = &node->data_[direction]; 260 } 261 262 inode->SetChild(newdirection ^ 1, *where_index >> 8); 263 *where_index = (*where_index & 0xff) | (internal_node_index << 8); 264 265 DCHECK_LE(horizon_, nonce_time); 266 return NONCE_OK; 267 } 268 269 const uint8* StrikeRegister::orbit() const { 270 return orbit_; 271 } 272 273 uint32 StrikeRegister::GetCurrentValidWindowSecs( 274 uint32 current_time_external) const { 275 uint32 current_time = ExternalTimeToInternal(current_time_external); 276 pair<uint32, uint32> valid_range = StrikeRegister::GetValidRange( 277 current_time); 278 if (valid_range.second >= valid_range.first) { 279 return valid_range.second - current_time + 1; 280 } else { 281 return 0; 282 } 283 } 284 285 void StrikeRegister::Validate() { 286 set<uint32> free_internal_nodes; 287 for (uint32 i = internal_node_free_head_; i != kNil; 288 i = internal_nodes_[i].next()) { 289 CHECK_LT(i, max_entries_); 290 CHECK_EQ(free_internal_nodes.count(i), 0u); 291 free_internal_nodes.insert(i); 292 } 293 294 set<uint32> free_external_nodes; 295 for (uint32 i = external_node_free_head_; i != kNil; 296 i = external_node_next_ptr(i)) { 297 CHECK_LT(i, max_entries_); 298 CHECK_EQ(free_external_nodes.count(i), 0u); 299 free_external_nodes.insert(i); 300 } 301 302 set<uint32> used_external_nodes; 303 set<uint32> used_internal_nodes; 304 305 if (internal_node_head_ != kNil && 306 ((internal_node_head_ >> 8) & kExternalFlag) == 0) { 307 vector<pair<unsigned, bool> > bits; 308 ValidateTree(internal_node_head_ >> 8, -1, bits, free_internal_nodes, 309 free_external_nodes, &used_internal_nodes, 310 &used_external_nodes); 311 } 312 } 313 314 // static 315 uint32 StrikeRegister::TimeFromBytes(const uint8 d[4]) { 316 return static_cast<uint32>(d[0]) << 24 | 317 static_cast<uint32>(d[1]) << 16 | 318 static_cast<uint32>(d[2]) << 8 | 319 static_cast<uint32>(d[3]); 320 } 321 322 pair<uint32, uint32> StrikeRegister::GetValidRange( 323 uint32 current_time_internal) const { 324 if (current_time_internal < horizon_) { 325 // Empty valid range. 326 return make_pair(std::numeric_limits<uint32>::max(), 0); 327 } 328 329 uint32 lower_bound; 330 if (current_time_internal >= window_secs_) { 331 lower_bound = max(horizon_, current_time_internal - window_secs_); 332 } else { 333 lower_bound = horizon_; 334 } 335 336 // Also limit the upper range based on horizon_. This makes the 337 // strike register reject inserts that are far in the future and 338 // would consume strike register resources for a long time. This 339 // allows the strike server to degrade optimally in cases where the 340 // insert rate exceeds |max_entries_ / (2 * window_secs_)| entries 341 // per second. 342 uint32 upper_bound = 343 current_time_internal + min(current_time_internal - horizon_, 344 window_secs_); 345 346 return make_pair(lower_bound, upper_bound); 347 } 348 349 uint32 StrikeRegister::ExternalTimeToInternal(uint32 external_time) const { 350 return external_time - internal_epoch_; 351 } 352 353 uint32 StrikeRegister::BestMatch(const uint8 v[24]) const { 354 if (internal_node_head_ == kNil) { 355 return kNil; 356 } 357 358 uint32 next = internal_node_head_ >> 8; 359 while ((next & kExternalFlag) == 0) { 360 InternalNode* node = &internal_nodes_[next]; 361 uint8 b = v[node->critbyte()]; 362 unsigned direction = 363 (1 + static_cast<unsigned>(node->otherbits() | b)) >> 8; 364 next = node->child(direction); 365 } 366 367 return next & ~kExternalFlag; 368 } 369 370 uint32& StrikeRegister::external_node_next_ptr(unsigned i) { 371 return *reinterpret_cast<uint32*>(&external_nodes_[i * kExternalNodeSize]); 372 } 373 374 uint8* StrikeRegister::external_node(unsigned i) { 375 return &external_nodes_[i * kExternalNodeSize]; 376 } 377 378 uint32 StrikeRegister::GetFreeExternalNode() { 379 uint32 index = external_node_free_head_; 380 DCHECK(index != kNil); 381 external_node_free_head_ = external_node_next_ptr(index); 382 return index; 383 } 384 385 uint32 StrikeRegister::GetFreeInternalNode() { 386 uint32 index = internal_node_free_head_; 387 DCHECK(index != kNil); 388 internal_node_free_head_ = internal_nodes_[index].next(); 389 return index; 390 } 391 392 void StrikeRegister::DropOldestNode() { 393 // DropOldestNode should never be called on an empty tree. 394 DCHECK(internal_node_head_ != kNil); 395 396 // An internal node in a crit-bit tree always has exactly two children. 397 // This means that, if we are removing an external node (which is one of 398 // those children), then we also need to remove an internal node. In order 399 // to do that we keep pointers to the parent (wherep) and grandparent 400 // (whereq) when walking down the tree. 401 402 uint32 p = internal_node_head_ >> 8, *wherep = &internal_node_head_, 403 *whereq = NULL; 404 while ((p & kExternalFlag) == 0) { 405 whereq = wherep; 406 InternalNode* inode = &internal_nodes_[p]; 407 // We always go left, towards the smallest element, exploiting the fact 408 // that the timestamp is big-endian and at the start of the value. 409 wherep = &inode->data_[0]; 410 p = (*wherep) >> 8; 411 } 412 413 const uint32 ext_index = p & ~kExternalFlag; 414 const uint8* ext_node = external_node(ext_index); 415 uint32 new_horizon = ExternalTimeToInternal(TimeFromBytes(ext_node)) + 1; 416 DCHECK_LE(horizon_, new_horizon); 417 horizon_ = new_horizon; 418 419 if (!whereq) { 420 // We are removing the last element in a tree. 421 internal_node_head_ = kNil; 422 FreeExternalNode(ext_index); 423 return; 424 } 425 426 // |wherep| points to the left child pointer in the parent so we can add 427 // one and dereference to get the right child. 428 const uint32 other_child = wherep[1]; 429 FreeInternalNode((*whereq) >> 8); 430 *whereq = (*whereq & 0xff) | (other_child & 0xffffff00); 431 FreeExternalNode(ext_index); 432 } 433 434 void StrikeRegister::FreeExternalNode(uint32 index) { 435 external_node_next_ptr(index) = external_node_free_head_; 436 external_node_free_head_ = index; 437 } 438 439 void StrikeRegister::FreeInternalNode(uint32 index) { 440 internal_nodes_[index].SetNextPtr(internal_node_free_head_); 441 internal_node_free_head_ = index; 442 } 443 444 void StrikeRegister::ValidateTree( 445 uint32 internal_node, 446 int last_bit, 447 const vector<pair<unsigned, bool> >& bits, 448 const set<uint32>& free_internal_nodes, 449 const set<uint32>& free_external_nodes, 450 set<uint32>* used_internal_nodes, 451 set<uint32>* used_external_nodes) { 452 CHECK_LT(internal_node, max_entries_); 453 const InternalNode* i = &internal_nodes_[internal_node]; 454 unsigned bit = 0; 455 switch (i->otherbits()) { 456 case 0xff & ~(1 << 7): 457 bit = 0; 458 break; 459 case 0xff & ~(1 << 6): 460 bit = 1; 461 break; 462 case 0xff & ~(1 << 5): 463 bit = 2; 464 break; 465 case 0xff & ~(1 << 4): 466 bit = 3; 467 break; 468 case 0xff & ~(1 << 3): 469 bit = 4; 470 break; 471 case 0xff & ~(1 << 2): 472 bit = 5; 473 break; 474 case 0xff & ~(1 << 1): 475 bit = 6; 476 break; 477 case 0xff & ~1: 478 bit = 7; 479 break; 480 default: 481 CHECK(false); 482 } 483 484 bit += 8 * i->critbyte(); 485 if (last_bit > -1) { 486 CHECK_GT(bit, static_cast<unsigned>(last_bit)); 487 } 488 489 CHECK_EQ(free_internal_nodes.count(internal_node), 0u); 490 491 for (unsigned child = 0; child < 2; child++) { 492 if (i->child(child) & kExternalFlag) { 493 uint32 ext = i->child(child) & ~kExternalFlag; 494 CHECK_EQ(free_external_nodes.count(ext), 0u); 495 CHECK_EQ(used_external_nodes->count(ext), 0u); 496 used_external_nodes->insert(ext); 497 const uint8* bytes = external_node(ext); 498 for (vector<pair<unsigned, bool> >::const_iterator i = bits.begin(); 499 i != bits.end(); i++) { 500 unsigned byte = i->first / 8; 501 DCHECK_LE(byte, 0xffu); 502 unsigned bit = i->first % 8; 503 static const uint8 kMasks[8] = 504 {0x80, 0x40, 0x20, 0x10, 0x08, 0x04, 0x02, 0x01}; 505 CHECK_EQ((bytes[byte] & kMasks[bit]) != 0, i->second); 506 } 507 } else { 508 uint32 inter = i->child(child); 509 vector<pair<unsigned, bool> > new_bits(bits); 510 new_bits.push_back(pair<unsigned, bool>(bit, child != 0)); 511 CHECK_EQ(free_internal_nodes.count(inter), 0u); 512 CHECK_EQ(used_internal_nodes->count(inter), 0u); 513 used_internal_nodes->insert(inter); 514 ValidateTree(inter, bit, bits, free_internal_nodes, free_external_nodes, 515 used_internal_nodes, used_external_nodes); 516 } 517 } 518 } 519 520 } // namespace net 521