1 /* 2 * Copyright (C) 2019 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "RandomVariable.h" 18 19 #include <algorithm> 20 #include <memory> 21 #include <set> 22 #include <string> 23 #include <unordered_set> 24 #include <vector> 25 26 #include "RandomGraphGeneratorUtils.h" 27 28 namespace android { 29 namespace nn { 30 namespace fuzzing_test { 31 32 unsigned int RandomVariableBase::globalIndex = 0; 33 int RandomVariable::defaultValue = 10; 34 35 RandomVariableBase::RandomVariableBase(int value) 36 : index(globalIndex++), 37 type(RandomVariableType::CONST), 38 range(value), 39 value(value), 40 timestamp(RandomVariableNetwork::get()->getGlobalTime()) {} 41 42 RandomVariableBase::RandomVariableBase(int lower, int upper) 43 : index(globalIndex++), 44 type(RandomVariableType::FREE), 45 range(lower, upper), 46 timestamp(RandomVariableNetwork::get()->getGlobalTime()) {} 47 48 RandomVariableBase::RandomVariableBase(const std::vector<int>& choices) 49 : index(globalIndex++), 50 type(RandomVariableType::FREE), 51 range(choices), 52 timestamp(RandomVariableNetwork::get()->getGlobalTime()) {} 53 54 RandomVariableBase::RandomVariableBase(const RandomVariableNode& lhs, const RandomVariableNode& rhs, 55 const std::shared_ptr<const IRandomVariableOp>& op) 56 : index(globalIndex++), 57 type(RandomVariableType::OP), 58 range(op->getInitRange(lhs->range, rhs == nullptr ? 
RandomVariableRange(0) : rhs->range)), 59 op(op), 60 parent1(lhs), 61 parent2(rhs), 62 timestamp(RandomVariableNetwork::get()->getGlobalTime()) {} 63 64 void RandomVariableRange::setRange(int lower, int upper) { 65 // kInvalidValue indicates unlimited bound. 66 auto head = lower == kInvalidValue ? mChoices.begin() 67 : std::lower_bound(mChoices.begin(), mChoices.end(), lower); 68 auto tail = upper == kInvalidValue ? mChoices.end() 69 : std::upper_bound(mChoices.begin(), mChoices.end(), upper); 70 NN_FUZZER_CHECK(head <= tail) << "Invalid range!"; 71 if (head != mChoices.begin() || tail != mChoices.end()) { 72 mChoices = std::vector<int>(head, tail); 73 } 74 } 75 76 int RandomVariableRange::toConst() { 77 if (mChoices.size() > 1) mChoices = {getRandomChoice(mChoices)}; 78 return mChoices[0]; 79 } 80 81 RandomVariableRange operator&(const RandomVariableRange& lhs, const RandomVariableRange& rhs) { 82 std::vector<int> result(lhs.size() + rhs.size()); 83 auto it = std::set_intersection(lhs.mChoices.begin(), lhs.mChoices.end(), rhs.mChoices.begin(), 84 rhs.mChoices.end(), result.begin()); 85 result.resize(it - result.begin()); 86 return RandomVariableRange(std::move(result)); 87 } 88 89 void RandomVariableBase::freeze() { 90 if (type == RandomVariableType::CONST) return; 91 value = range.toConst(); 92 type = RandomVariableType::CONST; 93 } 94 95 int RandomVariableBase::getValue() const { 96 switch (type) { 97 case RandomVariableType::CONST: 98 return value; 99 case RandomVariableType::OP: 100 return op->eval(parent1->getValue(), parent2 == nullptr ? 
0 : parent2->getValue()); 101 default: 102 NN_FUZZER_CHECK(false) << "Invalid type when getting value of var" << index; 103 return 0; 104 } 105 } 106 107 void RandomVariableBase::updateTimestamp() { 108 timestamp = RandomVariableNetwork::get()->getGlobalTime(); 109 NN_FUZZER_LOG << "Update timestamp of var" << index << " to " << timestamp; 110 } 111 112 RandomVariable::RandomVariable(int value) : mVar(new RandomVariableBase(value)) { 113 NN_FUZZER_LOG << "New RandomVariable " << toString(mVar); 114 RandomVariableNetwork::get()->add(mVar); 115 } 116 RandomVariable::RandomVariable(int lower, int upper) : mVar(new RandomVariableBase(lower, upper)) { 117 NN_FUZZER_LOG << "New RandomVariable " << toString(mVar); 118 RandomVariableNetwork::get()->add(mVar); 119 } 120 RandomVariable::RandomVariable(const std::vector<int>& choices) 121 : mVar(new RandomVariableBase(choices)) { 122 NN_FUZZER_LOG << "New RandomVariable " << toString(mVar); 123 RandomVariableNetwork::get()->add(mVar); 124 } 125 RandomVariable::RandomVariable(RandomVariableType type) 126 : mVar(new RandomVariableBase(1, defaultValue)) { 127 NN_FUZZER_CHECK(type == RandomVariableType::FREE); 128 NN_FUZZER_LOG << "New RandomVariable " << toString(mVar); 129 RandomVariableNetwork::get()->add(mVar); 130 } 131 RandomVariable::RandomVariable(const RandomVariable& lhs, const RandomVariable& rhs, 132 const std::shared_ptr<const IRandomVariableOp>& op) 133 : mVar(new RandomVariableBase(lhs.get(), rhs.get(), op)) { 134 // Make a copy if the parent is CONST. This will resolve the fake dependency problem. 
135 if (mVar->parent1->type == RandomVariableType::CONST) { 136 mVar->parent1 = RandomVariable(mVar->parent1->value).get(); 137 } 138 if (mVar->parent2 != nullptr && mVar->parent2->type == RandomVariableType::CONST) { 139 mVar->parent2 = RandomVariable(mVar->parent2->value).get(); 140 } 141 mVar->parent1->children.push_back(mVar); 142 if (mVar->parent2 != nullptr) mVar->parent2->children.push_back(mVar); 143 RandomVariableNetwork::get()->add(mVar); 144 NN_FUZZER_LOG << "New RandomVariable " << toString(mVar); 145 } 146 147 void RandomVariable::setRange(int lower, int upper) { 148 NN_FUZZER_CHECK(mVar != nullptr) << "setRange() on nullptr"; 149 NN_FUZZER_LOG << "Set range [" << lower << ", " << upper << "] on var" << mVar->index; 150 size_t oldSize = mVar->range.size(); 151 mVar->range.setRange(lower, upper); 152 // Only update the timestamp if the range is *indeed* narrowed down. 153 if (mVar->range.size() != oldSize) mVar->updateTimestamp(); 154 } 155 156 RandomVariableRange IRandomVariableOp::getInitRange(const RandomVariableRange& lhs, 157 const RandomVariableRange& rhs) const { 158 std::set<int> st; 159 for (auto i : lhs.getChoices()) { 160 for (auto j : rhs.getChoices()) { 161 int res = this->eval(i, j); 162 if (res > kMaxValue || res < -kMaxValue) continue; 163 st.insert(res); 164 } 165 } 166 return RandomVariableRange(st); 167 } 168 169 // Check if the range contains exactly all values in [min, max]. 170 static inline bool isContinuous(const std::set<int>* range) { 171 return (*(range->rbegin()) - *(range->begin()) + 1) == static_cast<int>(range->size()); 172 } 173 174 // Fill the set with a range of values specified by [lower, upper]. 175 static inline void fillRange(std::set<int>* range, int lower, int upper) { 176 for (int i = lower; i <= upper; i++) range->insert(i); 177 } 178 179 // The slowest algorithm: iterate through every combinations of parents and save the valid pairs. 
180 void IRandomVariableOp::eval(const std::set<int>* parent1In, const std::set<int>* parent2In, 181 const std::set<int>* childIn, std::set<int>* parent1Out, 182 std::set<int>* parent2Out, std::set<int>* childOut) const { 183 // Avoid the binary search if the child is a closed range. 184 bool isChildInContinuous = isContinuous(childIn); 185 std::pair<int, int> child = {*childIn->begin(), *childIn->rbegin()}; 186 for (auto i : *parent1In) { 187 bool valid = false; 188 for (auto j : *parent2In) { 189 int res = this->eval(i, j); 190 // Avoid the binary search if obviously out of range. 191 if (res > child.second || res < child.first) continue; 192 if (isChildInContinuous || childIn->find(res) != childIn->end()) { 193 parent2Out->insert(j); 194 childOut->insert(res); 195 valid = true; 196 } 197 } 198 if (valid) parent1Out->insert(i); 199 } 200 } 201 202 // A helper template to make a class into a Singleton. 203 template <class T> 204 class Singleton : public T { 205 public: 206 static const std::shared_ptr<const T>& get() { 207 static std::shared_ptr<const T> instance(new T); 208 return instance; 209 } 210 }; 211 212 // A set of operations that only compute on a single input value. 213 class IUnaryOp : public IRandomVariableOp { 214 public: 215 using IRandomVariableOp::eval; 216 virtual int eval(int val) const = 0; 217 virtual int eval(int lhs, int) const override { return eval(lhs); } 218 // The slowest algorithm: iterate through every value of the parent and save the valid one. 
219 virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In, 220 const std::set<int>* childIn, std::set<int>* parent1Out, 221 std::set<int>* parent2Out, std::set<int>* childOut) const override { 222 NN_FUZZER_CHECK(parent2In == nullptr); 223 NN_FUZZER_CHECK(parent2Out == nullptr); 224 bool isChildInContinuous = isContinuous(childIn); 225 std::pair<int, int> child = {*childIn->begin(), *childIn->rbegin()}; 226 for (auto i : *parent1In) { 227 int res = this->eval(i); 228 if (res > child.second || res < child.first) continue; 229 if (isChildInContinuous || childIn->find(res) != childIn->end()) { 230 parent1Out->insert(i); 231 childOut->insert(res); 232 } 233 } 234 } 235 }; 236 237 // A set of operations that only check conditional constraints. 238 class IConstraintOp : public IRandomVariableOp { 239 public: 240 using IRandomVariableOp::eval; 241 virtual bool check(int lhs, int rhs) const = 0; 242 virtual int eval(int lhs, int rhs) const override { 243 return check(lhs, rhs) ? 0 : kInvalidValue; 244 } 245 // The range for a constraint op is always {0}. 246 virtual RandomVariableRange getInitRange(const RandomVariableRange&, 247 const RandomVariableRange&) const override { 248 return RandomVariableRange(0); 249 } 250 // The slowest algorithm: 251 // iterate through every combinations of parents and save the valid pairs. 
252 virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In, 253 const std::set<int>*, std::set<int>* parent1Out, std::set<int>* parent2Out, 254 std::set<int>* childOut) const override { 255 for (auto i : *parent1In) { 256 bool valid = false; 257 for (auto j : *parent2In) { 258 if (this->check(i, j)) { 259 parent2Out->insert(j); 260 valid = true; 261 } 262 } 263 if (valid) parent1Out->insert(i); 264 } 265 if (!parent1Out->empty()) childOut->insert(0); 266 } 267 }; 268 269 class Addition : public IRandomVariableOp { 270 public: 271 virtual int eval(int lhs, int rhs) const override { return lhs + rhs; } 272 virtual RandomVariableRange getInitRange(const RandomVariableRange& lhs, 273 const RandomVariableRange& rhs) const override { 274 return RandomVariableRange(lhs.min() + rhs.min(), lhs.max() + rhs.max()); 275 } 276 virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In, 277 const std::set<int>* childIn, std::set<int>* parent1Out, 278 std::set<int>* parent2Out, std::set<int>* childOut) const override { 279 if (!isContinuous(parent1In) || !isContinuous(parent2In) || !isContinuous(childIn)) { 280 IRandomVariableOp::eval(parent1In, parent2In, childIn, parent1Out, parent2Out, 281 childOut); 282 } else { 283 // For parents and child with close range, the out range can be computed directly 284 // without iterations. 285 std::pair<int, int> parent1 = {*parent1In->begin(), *parent1In->rbegin()}; 286 std::pair<int, int> parent2 = {*parent2In->begin(), *parent2In->rbegin()}; 287 std::pair<int, int> child = {*childIn->begin(), *childIn->rbegin()}; 288 289 // From ranges for parent, evalute range for child. 290 // [a, b] + [c, d] -> [a + c, b + d] 291 fillRange(childOut, std::max(child.first, parent1.first + parent2.first), 292 std::min(child.second, parent1.second + parent2.second)); 293 294 // From ranges for child and one parent, evalute range for another parent. 
295 // [a, b] - [c, d] -> [a - d, b - c] 296 fillRange(parent1Out, std::max(parent1.first, child.first - parent2.second), 297 std::min(parent1.second, child.second - parent2.first)); 298 fillRange(parent2Out, std::max(parent2.first, child.first - parent1.second), 299 std::min(parent2.second, child.second - parent1.first)); 300 } 301 } 302 virtual const char* getName() const override { return "ADD"; } 303 }; 304 305 class Subtraction : public IRandomVariableOp { 306 public: 307 virtual int eval(int lhs, int rhs) const override { return lhs - rhs; } 308 virtual RandomVariableRange getInitRange(const RandomVariableRange& lhs, 309 const RandomVariableRange& rhs) const override { 310 return RandomVariableRange(lhs.min() - rhs.max(), lhs.max() - rhs.min()); 311 } 312 virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In, 313 const std::set<int>* childIn, std::set<int>* parent1Out, 314 std::set<int>* parent2Out, std::set<int>* childOut) const override { 315 if (!isContinuous(parent1In) || !isContinuous(parent2In) || !isContinuous(childIn)) { 316 IRandomVariableOp::eval(parent1In, parent2In, childIn, parent1Out, parent2Out, 317 childOut); 318 } else { 319 // Similar algorithm as Addition. 
320 std::pair<int, int> parent1 = {*parent1In->begin(), *parent1In->rbegin()}; 321 std::pair<int, int> parent2 = {*parent2In->begin(), *parent2In->rbegin()}; 322 std::pair<int, int> child = {*childIn->begin(), *childIn->rbegin()}; 323 fillRange(childOut, std::max(child.first, parent1.first - parent2.second), 324 std::min(child.second, parent1.second - parent2.first)); 325 fillRange(parent1Out, std::max(parent1.first, child.first + parent2.first), 326 std::min(parent1.second, child.second + parent2.second)); 327 fillRange(parent2Out, std::max(parent2.first, parent1.first - child.second), 328 std::min(parent2.second, parent1.second - child.first)); 329 } 330 } 331 virtual const char* getName() const override { return "SUB"; } 332 }; 333 334 class Multiplication : public IRandomVariableOp { 335 public: 336 virtual int eval(int lhs, int rhs) const override { return lhs * rhs; } 337 virtual RandomVariableRange getInitRange(const RandomVariableRange& lhs, 338 const RandomVariableRange& rhs) const override { 339 if (lhs.min() < 0 || rhs.min() < 0) { 340 return IRandomVariableOp::getInitRange(lhs, rhs); 341 } else { 342 int lower = std::min(lhs.min() * rhs.min(), kMaxValue); 343 int upper = std::min(lhs.max() * rhs.max(), kMaxValue); 344 return RandomVariableRange(lower, upper); 345 } 346 } 347 virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In, 348 const std::set<int>* childIn, std::set<int>* parent1Out, 349 std::set<int>* parent2Out, std::set<int>* childOut) const override { 350 if (*parent1In->begin() < 0 || *parent2In->begin() < 0 || *childIn->begin() < 0) { 351 IRandomVariableOp::eval(parent1In, parent2In, childIn, parent1Out, parent2Out, 352 childOut); 353 } else { 354 bool isChildInContinuous = isContinuous(childIn); 355 std::pair<int, int> child = {*childIn->begin(), *childIn->rbegin()}; 356 for (auto i : *parent1In) { 357 bool valid = false; 358 for (auto j : *parent2In) { 359 int res = this->eval(i, j); 360 // Since MUL increases 
monotonically with one value, break the loop if the 361 // result is larger than the limit. 362 if (res > child.second) break; 363 if (res < child.first) continue; 364 if (isChildInContinuous || childIn->find(res) != childIn->end()) { 365 valid = true; 366 parent2Out->insert(j); 367 childOut->insert(res); 368 } 369 } 370 if (valid) parent1Out->insert(i); 371 } 372 } 373 } 374 virtual const char* getName() const override { return "MUL"; } 375 }; 376 377 class Division : public IRandomVariableOp { 378 public: 379 virtual int eval(int lhs, int rhs) const override { 380 return rhs == 0 ? kInvalidValue : lhs / rhs; 381 } 382 virtual RandomVariableRange getInitRange(const RandomVariableRange& lhs, 383 const RandomVariableRange& rhs) const override { 384 if (lhs.min() < 0 || rhs.min() <= 0) { 385 return IRandomVariableOp::getInitRange(lhs, rhs); 386 } else { 387 return RandomVariableRange(lhs.min() / rhs.max(), lhs.max() / rhs.min()); 388 } 389 } 390 virtual const char* getName() const override { return "DIV"; } 391 }; 392 393 class ExactDivision : public Division { 394 public: 395 virtual int eval(int lhs, int rhs) const override { 396 return (rhs == 0 || lhs % rhs != 0) ? kInvalidValue : lhs / rhs; 397 } 398 virtual const char* getName() const override { return "EXACT_DIV"; } 399 }; 400 401 class Modulo : public IRandomVariableOp { 402 public: 403 virtual int eval(int lhs, int rhs) const override { 404 return rhs == 0 ? 
kInvalidValue : lhs % rhs; 405 } 406 virtual RandomVariableRange getInitRange(const RandomVariableRange&, 407 const RandomVariableRange& rhs) const override { 408 return RandomVariableRange(0, rhs.max()); 409 } 410 virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In, 411 const std::set<int>* childIn, std::set<int>* parent1Out, 412 std::set<int>* parent2Out, std::set<int>* childOut) const override { 413 if (*childIn->begin() != 0 || childIn->size() != 1u) { 414 IRandomVariableOp::eval(parent1In, parent2In, childIn, parent1Out, parent2Out, 415 childOut); 416 } else { 417 // For the special case that child is a const 0, it would be faster if the range for 418 // parents are evaluated separately. 419 420 // Evalute parent1 directly. 421 for (auto i : *parent1In) { 422 for (auto j : *parent2In) { 423 if (i % j == 0) { 424 parent1Out->insert(i); 425 break; 426 } 427 } 428 } 429 // Evalute parent2, see if a multiple of parent2 value can be found in parent1. 430 int parent1Max = *parent1In->rbegin(); 431 for (auto i : *parent2In) { 432 int jMax = parent1Max / i; 433 for (int j = 1; j <= jMax; j++) { 434 if (parent1In->find(i * j) != parent1In->end()) { 435 parent2Out->insert(i); 436 break; 437 } 438 } 439 } 440 if (!parent1Out->empty()) childOut->insert(0); 441 } 442 } 443 virtual const char* getName() const override { return "MOD"; } 444 }; 445 446 class Maximum : public IRandomVariableOp { 447 public: 448 virtual int eval(int lhs, int rhs) const override { return std::max(lhs, rhs); } 449 virtual const char* getName() const override { return "MAX"; } 450 }; 451 452 class Minimum : public IRandomVariableOp { 453 public: 454 virtual int eval(int lhs, int rhs) const override { return std::min(lhs, rhs); } 455 virtual const char* getName() const override { return "MIN"; } 456 }; 457 458 class Square : public IUnaryOp { 459 public: 460 virtual int eval(int val) const override { return val * val; } 461 virtual const char* getName() const override { 
return "SQUARE"; } 462 }; 463 464 class UnaryEqual : public IUnaryOp { 465 public: 466 virtual int eval(int val) const override { return val; } 467 virtual const char* getName() const override { return "UNARY_EQUAL"; } 468 }; 469 470 class Equal : public IConstraintOp { 471 public: 472 virtual bool check(int lhs, int rhs) const override { return lhs == rhs; } 473 virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In, 474 const std::set<int>* childIn, std::set<int>* parent1Out, 475 std::set<int>* parent2Out, std::set<int>* childOut) const override { 476 NN_FUZZER_CHECK(childIn->size() == 1u && *childIn->begin() == 0); 477 // The intersection of two sets can be found in O(n). 478 std::set_intersection(parent1In->begin(), parent1In->end(), parent2In->begin(), 479 parent2In->end(), std::inserter(*parent1Out, parent1Out->begin())); 480 *parent2Out = *parent1Out; 481 childOut->insert(0); 482 } 483 virtual const char* getName() const override { return "EQUAL"; } 484 }; 485 486 class GreaterThan : public IConstraintOp { 487 public: 488 virtual bool check(int lhs, int rhs) const override { return lhs > rhs; } 489 virtual const char* getName() const override { return "GREATER_THAN"; } 490 }; 491 492 class GreaterEqual : public IConstraintOp { 493 public: 494 virtual bool check(int lhs, int rhs) const override { return lhs >= rhs; } 495 virtual const char* getName() const override { return "GREATER_EQUAL"; } 496 }; 497 498 class FloatMultiplication : public IUnaryOp { 499 public: 500 FloatMultiplication(float multiplicand) : mMultiplicand(multiplicand) {} 501 virtual int eval(int val) const override { 502 return static_cast<int>(std::floor(static_cast<float>(val) * mMultiplicand)); 503 } 504 virtual const char* getName() const override { return "MUL_FLOAT"; } 505 506 private: 507 float mMultiplicand; 508 }; 509 510 // Arithmetic operators and methods on RandomVariables will create OP RandomVariableNodes. 
511 // Since there must be at most one edge between two RandomVariableNodes, we have to do something 512 // special when both sides are refering to the same node. 513 514 RandomVariable operator+(const RandomVariable& lhs, const RandomVariable& rhs) { 515 return lhs.get() == rhs.get() ? RandomVariable(lhs, 2, Singleton<Multiplication>::get()) 516 : RandomVariable(lhs, rhs, Singleton<Addition>::get()); 517 } 518 RandomVariable operator-(const RandomVariable& lhs, const RandomVariable& rhs) { 519 return lhs.get() == rhs.get() ? RandomVariable(0) 520 : RandomVariable(lhs, rhs, Singleton<Subtraction>::get()); 521 } 522 RandomVariable operator*(const RandomVariable& lhs, const RandomVariable& rhs) { 523 return lhs.get() == rhs.get() ? RandomVariable(lhs, RandomVariable(), Singleton<Square>::get()) 524 : RandomVariable(lhs, rhs, Singleton<Multiplication>::get()); 525 } 526 RandomVariable operator*(const RandomVariable& lhs, const float& rhs) { 527 return RandomVariable(lhs, RandomVariable(), std::make_shared<FloatMultiplication>(rhs)); 528 } 529 RandomVariable operator/(const RandomVariable& lhs, const RandomVariable& rhs) { 530 return lhs.get() == rhs.get() ? RandomVariable(1) 531 : RandomVariable(lhs, rhs, Singleton<Division>::get()); 532 } 533 RandomVariable operator%(const RandomVariable& lhs, const RandomVariable& rhs) { 534 return lhs.get() == rhs.get() ? RandomVariable(0) 535 : RandomVariable(lhs, rhs, Singleton<Modulo>::get()); 536 } 537 RandomVariable max(const RandomVariable& lhs, const RandomVariable& rhs) { 538 return lhs.get() == rhs.get() ? lhs : RandomVariable(lhs, rhs, Singleton<Maximum>::get()); 539 } 540 RandomVariable min(const RandomVariable& lhs, const RandomVariable& rhs) { 541 return lhs.get() == rhs.get() ? lhs : RandomVariable(lhs, rhs, Singleton<Minimum>::get()); 542 } 543 544 RandomVariable RandomVariable::exactDiv(const RandomVariable& other) { 545 return mVar == other.get() ? 
RandomVariable(1) 546 : RandomVariable(*this, other, Singleton<ExactDivision>::get()); 547 } 548 549 RandomVariable RandomVariable::setEqual(const RandomVariable& other) const { 550 RandomVariableNode node1 = mVar, node2 = other.get(); 551 NN_FUZZER_LOG << "Set equality of var" << node1->index << " and var" << node2->index; 552 553 // Do not setEqual on the same pair twice. 554 if (node1 == node2 || (node1->op == Singleton<UnaryEqual>::get() && node1->parent1 == node2) || 555 (node2->op == Singleton<UnaryEqual>::get() && node2->parent1 == node1)) { 556 NN_FUZZER_LOG << "Already equal. Return."; 557 return RandomVariable(); 558 } 559 560 // If possible, always try UnaryEqual first to reduce the search space. 561 // UnaryEqual can be used if node B is FREE and is evaluated later than node A. 562 // TODO: Reduce code duplication. 563 if (RandomVariableNetwork::get()->isSubordinate(node1, node2)) { 564 NN_FUZZER_LOG << " Make var" << node2->index << " a child of var" << node1->index; 565 node2->type = RandomVariableType::OP; 566 node2->parent1 = node1; 567 node2->op = Singleton<UnaryEqual>::get(); 568 node1->children.push_back(node2); 569 RandomVariableNetwork::get()->join(node1, node2); 570 node1->updateTimestamp(); 571 return other; 572 } 573 if (RandomVariableNetwork::get()->isSubordinate(node2, node1)) { 574 NN_FUZZER_LOG << " Make var" << node1->index << " a child of var" << node2->index; 575 node1->type = RandomVariableType::OP; 576 node1->parent1 = node2; 577 node1->op = Singleton<UnaryEqual>::get(); 578 node2->children.push_back(node1); 579 RandomVariableNetwork::get()->join(node2, node1); 580 node1->updateTimestamp(); 581 return *this; 582 } 583 return RandomVariable(*this, other, Singleton<Equal>::get()); 584 } 585 586 RandomVariable RandomVariable::setGreaterThan(const RandomVariable& other) const { 587 NN_FUZZER_CHECK(mVar != other.get()); 588 return RandomVariable(*this, other, Singleton<GreaterThan>::get()); 589 } 590 RandomVariable 
RandomVariable::setGreaterEqual(const RandomVariable& other) const { 591 return mVar == other.get() ? *this 592 : RandomVariable(*this, other, Singleton<GreaterEqual>::get()); 593 } 594 595 void DisjointNetwork::add(const RandomVariableNode& var) { 596 // Find the subnet index of the parents and decide the index for var. 597 int ind1 = var->parent1 == nullptr ? -1 : mIndexMap[var->parent1]; 598 int ind2 = var->parent2 == nullptr ? -1 : mIndexMap[var->parent2]; 599 int ind = join(ind1, ind2); 600 // If no parent, put it into a new subnet component. 601 if (ind == -1) ind = mNextIndex++; 602 NN_FUZZER_LOG << "Add RandomVariable var" << var->index << " to network #" << ind; 603 mIndexMap[var] = ind; 604 mEvalOrderMap[ind].push_back(var); 605 } 606 607 int DisjointNetwork::join(int ind1, int ind2) { 608 if (ind1 == -1) return ind2; 609 if (ind2 == -1) return ind1; 610 if (ind1 == ind2) return ind1; 611 NN_FUZZER_LOG << "Join network #" << ind1 << " and #" << ind2; 612 auto &order1 = mEvalOrderMap[ind1], &order2 = mEvalOrderMap[ind2]; 613 // Append every node in ind2 to the end of ind1 614 for (const auto& var : order2) { 615 order1.push_back(var); 616 mIndexMap[var] = ind1; 617 } 618 // Remove ind2 from mEvalOrderMap. 619 mEvalOrderMap.erase(mEvalOrderMap.find(ind2)); 620 return ind1; 621 } 622 623 RandomVariableNetwork* RandomVariableNetwork::get() { 624 static RandomVariableNetwork instance; 625 return &instance; 626 } 627 628 void RandomVariableNetwork::initialize(int defaultValue) { 629 RandomVariableBase::globalIndex = 0; 630 RandomVariable::defaultValue = defaultValue; 631 mIndexMap.clear(); 632 mEvalOrderMap.clear(); 633 mDimProd.clear(); 634 mNextIndex = 0; 635 mGlobalTime = 0; 636 mTimestamp = -1; 637 } 638 639 bool RandomVariableNetwork::isSubordinate(const RandomVariableNode& node1, 640 const RandomVariableNode& node2) { 641 if (node2->type != RandomVariableType::FREE) return false; 642 int ind1 = mIndexMap[node1]; 643 // node2 is of a different subnet. 
644 if (ind1 != mIndexMap[node2]) return true; 645 for (const auto& node : mEvalOrderMap[ind1]) { 646 if (node == node2) return false; 647 // node2 is of the same subnet but evaluated later than node1. 648 if (node == node1) return true; 649 } 650 NN_FUZZER_CHECK(false) << "Code executed in non-reachable region."; 651 return false; 652 } 653 654 struct EvalInfo { 655 // The RandomVariableNode that this EvalInfo is associated with. 656 // var->value is the current value during evaluation. 657 RandomVariableNode var; 658 659 // The RandomVariable value is staged when a valid combination is found. 660 std::set<int> staging; 661 662 // The staging values are committed after a subnet evaluation. 663 std::set<int> committed; 664 665 // Keeps track of the latest timestamp that committed is updated. 666 int timestamp; 667 668 // For evalSubnetWithLocalNetwork. 669 RandomVariableType originalType; 670 671 // Should only invoke eval on OP RandomVariable. 672 bool eval() { 673 NN_FUZZER_CHECK(var->type == RandomVariableType::OP); 674 var->value = var->op->eval(var->parent1->value, 675 var->parent2 == nullptr ? 0 : var->parent2->value); 676 if (var->value == kInvalidValue) return false; 677 return committed.find(var->value) != committed.end(); 678 } 679 void stage() { staging.insert(var->value); } 680 void commit() { 681 // Only update committed and timestamp if the range is *indeed* changed. 682 if (staging.size() != committed.size()) { 683 committed = std::move(staging); 684 timestamp = RandomVariableNetwork::get()->getGlobalTime(); 685 } 686 staging.clear(); 687 } 688 void updateRange() { 689 // Only update range and timestamp if the range is *indeed* changed. 
690 if (committed.size() != var->range.size()) { 691 var->range = RandomVariableRange(committed); 692 var->timestamp = timestamp; 693 } 694 committed.clear(); 695 } 696 697 EvalInfo(const RandomVariableNode& var) 698 : var(var), 699 committed(var->range.getChoices().begin(), var->range.getChoices().end()), 700 timestamp(var->timestamp) {} 701 }; 702 using EvalContext = std::unordered_map<RandomVariableNode, EvalInfo>; 703 704 // For logging only. 705 inline std::string toString(const RandomVariableNode& var, EvalContext* context) { 706 std::stringstream ss; 707 ss << "var" << var->index << " = "; 708 const auto& committed = context->at(var).committed; 709 switch (var->type) { 710 case RandomVariableType::FREE: 711 ss << "FREE [" 712 << joinStr(", ", 20, std::vector<int>(committed.begin(), committed.end())) << "]"; 713 break; 714 case RandomVariableType::CONST: 715 ss << "CONST " << toString(var->value); 716 break; 717 case RandomVariableType::OP: 718 ss << "var" << var->parent1->index << " " << var->op->getName(); 719 if (var->parent2 != nullptr) ss << " var" << var->parent2->index; 720 ss << ", [" << joinStr(", ", 20, std::vector<int>(committed.begin(), committed.end())) 721 << "]"; 722 break; 723 default: 724 NN_FUZZER_CHECK(false); 725 } 726 ss << ", timestamp = " << context->at(var).timestamp; 727 return ss.str(); 728 } 729 730 // Check if the subnet needs to be re-evaluated by comparing the timestamps. 731 static inline bool needEvaluate(const EvaluationOrder& evalOrder, int subnetTime, 732 EvalContext* context = nullptr) { 733 for (const auto& var : evalOrder) { 734 int timestamp = context == nullptr ? var->timestamp : context->at(var).timestamp; 735 // If we find a node that has been modified since last evaluation, the subnet needs to be 736 // re-evaluated. 737 if (timestamp > subnetTime) return true; 738 } 739 return false; 740 } 741 742 // Helper function to evaluate the subnet recursively. 
743 // Iterate through all combinations of FREE RandomVariables choices. 744 static void evalSubnetHelper(const EvaluationOrder& evalOrder, EvalContext* context, size_t i = 0) { 745 if (i == evalOrder.size()) { 746 // Reach the end of the evaluation, find a valid combination. 747 for (auto& var : evalOrder) context->at(var).stage(); 748 return; 749 } 750 const auto& var = evalOrder[i]; 751 if (var->type == RandomVariableType::FREE) { 752 // For FREE RandomVariable, iterate through all valid choices. 753 for (int val : context->at(var).committed) { 754 var->value = val; 755 evalSubnetHelper(evalOrder, context, i + 1); 756 } 757 return; 758 } else if (var->type == RandomVariableType::OP) { 759 // For OP RandomVariable, evaluate from parents and terminate if the result is invalid. 760 if (!context->at(var).eval()) return; 761 } 762 evalSubnetHelper(evalOrder, context, i + 1); 763 } 764 765 // Check if the subnet has only one single OP RandomVariable. 766 static inline bool isSingleOpSubnet(const EvaluationOrder& evalOrder) { 767 int numOp = 0; 768 for (const auto& var : evalOrder) { 769 if (var->type == RandomVariableType::OP) numOp++; 770 if (numOp > 1) return false; 771 } 772 return numOp != 0; 773 } 774 775 // Evaluate with a potentially faster approach provided by IRandomVariableOp. 776 static inline void evalSubnetSingleOpHelper(const EvaluationOrder& evalOrder, 777 EvalContext* context) { 778 NN_FUZZER_LOG << "Identified as single op subnet"; 779 const auto& var = evalOrder.back(); 780 NN_FUZZER_CHECK(var->type == RandomVariableType::OP); 781 var->op->eval(&context->at(var->parent1).committed, 782 var->parent2 == nullptr ? nullptr : &context->at(var->parent2).committed, 783 &context->at(var).committed, &context->at(var->parent1).staging, 784 var->parent2 == nullptr ? nullptr : &context->at(var->parent2).staging, 785 &context->at(var).staging); 786 } 787 788 // Check if the number of combinations of FREE RandomVariables exceeds the limit. 
// Returns the product of the number of choices of all FREE nodes in evalOrder, capped at kLimit.
// The sizes are read from the committed sets in `context` if it is provided, and from the ranges
// of the nodes otherwise.
static inline uint64_t getNumCombinations(const EvaluationOrder& evalOrder,
                                          EvalContext* context = nullptr) {
    constexpr uint64_t kLimit = 1e8;
    uint64_t numCombinations = 1;
    for (const auto& var : evalOrder) {
        if (var->type == RandomVariableType::FREE) {
            size_t size =
                    context == nullptr ? var->range.size() : context->at(var).committed.size();
            numCombinations *= size;
            // To prevent overflow.
            if (numCombinations > kLimit) return kLimit;
        }
    }
    return numCombinations;
}

// Evaluate the subnet recursively. Will return fail if the number of combinations of FREE
// RandomVariable exceeds the threshold kMaxNumCombinations.
// Also returns false if any node ends up without a staged value. On success, the staged values
// of every node are committed.
static bool evalSubnetWithBruteForce(const EvaluationOrder& evalOrder, EvalContext* context) {
    constexpr uint64_t kMaxNumCombinations = 1e7;
    NN_FUZZER_LOG << "Evaluate with brute force";
    if (isSingleOpSubnet(evalOrder)) {
        // If the network only has one single OP, dispatch to a faster evaluation.
        evalSubnetSingleOpHelper(evalOrder, context);
    } else {
        if (getNumCombinations(evalOrder, context) > kMaxNumCombinations) {
            NN_FUZZER_LOG << "Terminate the evaluation because of large search range";
            std::cout << "[ ] Terminate the evaluation because of large search range"
                      << std::endl;
            return false;
        }
        evalSubnetHelper(evalOrder, context);
    }
    for (auto& var : evalOrder) {
        // A node without any staged value means that no valid combination exists.
        if (context->at(var).staging.empty()) {
            NN_FUZZER_LOG << "Evaluation failed at " << toString(var, context);
            return false;
        }
        context->at(var).commit();
    }
    return true;
}

// A part of a subnet that is evaluated on its own.
struct LocalNetwork {
    // The nodes of this local network, in evaluation order.
    EvaluationOrder evalOrder;
    // The nodes that are treated as FREE while this local network is evaluated.
    std::vector<RandomVariableNode> bridgeNodes;
    int timestamp = 0;

    // Evaluates this local network with brute force and returns whether it succeeded.
    bool eval(EvalContext* context) {
        NN_FUZZER_LOG << "Evaluate local network with timestamp = " << timestamp;
        // Temporarily treat bridge nodes as FREE RandomVariables.
        for (const auto& var : bridgeNodes) {
            context->at(var).originalType = var->type;
            var->type = RandomVariableType::FREE;
        }
        // Start from empty staging sets.
        for (const auto& var : evalOrder) {
            context->at(var).staging.clear();
            NN_FUZZER_LOG << " - " << toString(var, context);
        }
        bool success = evalSubnetWithBruteForce(evalOrder, context);
        // Reset the RandomVariable types for bridge nodes.
        for (const auto& var : bridgeNodes) var->type = context->at(var).originalType;
        return success;
    }
};

// Partition the network further into LocalNetworks based on the result from bridge annotation
// algorithm.
class GraphPartitioner : public DisjointNetwork {
   public:
    GraphPartitioner() = default;

    // Annotates the bridges of the subnet, re-adds every node with the bridges treated as absent
    // edges, and returns the resulting LocalNetworks.
    std::vector<LocalNetwork> partition(const EvaluationOrder& evalOrder, int timestamp) {
        annotateBridge(evalOrder);
        for (const auto& var : evalOrder) add(var);
        return get(timestamp);
    }

   private:
    GraphPartitioner(const GraphPartitioner&) = delete;
    GraphPartitioner& operator=(const GraphPartitioner&) = delete;

    // Find the parent-child relationship between var1 and var2, and set the bridge flag on the
    // child for the matching parent slot.
    // NOTE(review): if neither node is a parent of the other, the swapped recursive call below
    // never terminates. Callers appear to pass adjacent nodes only -- confirm.
    void setBridgeFlag(const RandomVariableNode& var1, const RandomVariableNode& var2) {
        if (var1->parent1 == var2) {
            mBridgeInfo[var1].isParent1Bridge = true;
        } else if (var1->parent2 == var2) {
            mBridgeInfo[var1].isParent2Bridge = true;
        } else {
            setBridgeFlag(var2, var1);
        }
    }

    // Annotate the bridges with DFS -- an edge [u, v] is a bridge if none of u's ancestors is
    // reachable from a node in the subtree of v. The complexity is O(V + E).
    // discoveryTime: The timestamp a node is visited
    // lowTime: The min discovery time of all reachable nodes from the subtree of the node.
    void annotateBridgeHelper(const RandomVariableNode& var, int* time) {
        mBridgeInfo[var].visited = true;
        mBridgeInfo[var].discoveryTime = mBridgeInfo[var].lowTime = (*time)++;

        // The algorithm operates on undirected graph. First find all adjacent nodes.
        auto adj = var->children;
        if (var->parent1 != nullptr) adj.push_back(var->parent1);
        if (var->parent2 != nullptr) adj.push_back(var->parent2);

        for (const auto& child : adj) {
            // Skip the nodes that have no entry in mBridgeInfo.
            if (mBridgeInfo.find(child) == mBridgeInfo.end()) continue;
            if (!mBridgeInfo[child].visited) {
                mBridgeInfo[child].parent = var;
                annotateBridgeHelper(child, time);

                // If none of nodes in the subtree of child is connected to any ancestors of var,
                // then it is a bridge.
                mBridgeInfo[var].lowTime =
                        std::min(mBridgeInfo[var].lowTime, mBridgeInfo[child].lowTime);
                if (mBridgeInfo[child].lowTime > mBridgeInfo[var].discoveryTime)
                    setBridgeFlag(var, child);
            } else if (mBridgeInfo[var].parent != child) {
                // An already visited node other than the DFS parent: a back edge.
                mBridgeInfo[var].lowTime =
                        std::min(mBridgeInfo[var].lowTime, mBridgeInfo[child].discoveryTime);
            }
        }
    }

    // Find all bridges in the subnet with DFS.
    void annotateBridge(const EvaluationOrder& evalOrder) {
        // Create a default entry for every node of the subnet first.
        for (const auto& var : evalOrder) mBridgeInfo[var];
        int time = 0;
        for (const auto& var : evalOrder) {
            if (!mBridgeInfo[var].visited) annotateBridgeHelper(var, &time);
        }
    }

    // Re-partition the network by treating bridges as no edge.
    // The parents connected through a bridge are hidden while the node is added, and restored
    // right afterwards.
    void add(const RandomVariableNode& var) {
        auto parent1 = var->parent1;
        auto parent2 = var->parent2;
        if (mBridgeInfo[var].isParent1Bridge) var->parent1 = nullptr;
        if (mBridgeInfo[var].isParent2Bridge) var->parent2 = nullptr;
        DisjointNetwork::add(var);
        var->parent1 = parent1;
        var->parent2 = parent2;
    }

    // Add bridge nodes to the local network and remove single node subnet.
935 std::vector<LocalNetwork> get(int timestamp) { 936 std::vector<LocalNetwork> res; 937 for (auto& pair : mEvalOrderMap) { 938 // We do not need to evaluate subnet with only a single node. 939 if (pair.second.size() == 1 && pair.second[0]->parent1 == nullptr) continue; 940 res.emplace_back(); 941 for (const auto& var : pair.second) { 942 if (mBridgeInfo[var].isParent1Bridge) { 943 res.back().evalOrder.push_back(var->parent1); 944 res.back().bridgeNodes.push_back(var->parent1); 945 } 946 if (mBridgeInfo[var].isParent2Bridge) { 947 res.back().evalOrder.push_back(var->parent2); 948 res.back().bridgeNodes.push_back(var->parent2); 949 } 950 res.back().evalOrder.push_back(var); 951 } 952 res.back().timestamp = timestamp; 953 } 954 return res; 955 } 956 957 // For bridge discovery algorithm. 958 struct BridgeInfo { 959 bool isParent1Bridge = false; 960 bool isParent2Bridge = false; 961 int discoveryTime = 0; 962 int lowTime = 0; 963 bool visited = false; 964 std::shared_ptr<RandomVariableBase> parent = nullptr; 965 }; 966 std::unordered_map<RandomVariableNode, BridgeInfo> mBridgeInfo; 967 }; 968 969 // Evaluate subnets repeatedly until converge. 970 // Class T_Subnet must have member evalOrder, timestamp, and member function eval. 971 template <class T_Subnet> 972 inline bool evalSubnetsRepeatedly(std::vector<T_Subnet>* subnets, EvalContext* context) { 973 bool terminate = false; 974 while (!terminate) { 975 terminate = true; 976 for (auto& subnet : *subnets) { 977 if (needEvaluate(subnet.evalOrder, subnet.timestamp, context)) { 978 if (!subnet.eval(context)) return false; 979 subnet.timestamp = RandomVariableNetwork::get()->getGlobalTime(); 980 terminate = false; 981 } 982 } 983 } 984 return true; 985 } 986 987 // Evaluate the subnet by first partitioning it further into LocalNetworks. 
988 static bool evalSubnetWithLocalNetwork(const EvaluationOrder& evalOrder, int timestamp, 989 EvalContext* context) { 990 NN_FUZZER_LOG << "Evaluate with local network"; 991 auto localNetworks = GraphPartitioner().partition(evalOrder, timestamp); 992 return evalSubnetsRepeatedly(&localNetworks, context); 993 } 994 995 struct LeafNetwork { 996 EvaluationOrder evalOrder; 997 int timestamp = 0; 998 LeafNetwork(const RandomVariableNode& var, int timestamp) : timestamp(timestamp) { 999 std::set<RandomVariableNode> visited; 1000 constructorHelper(var, &visited); 1001 } 1002 // Construct the leaf network by recursively including parent nodes. 1003 void constructorHelper(const RandomVariableNode& var, std::set<RandomVariableNode>* visited) { 1004 if (var == nullptr || visited->find(var) != visited->end()) return; 1005 constructorHelper(var->parent1, visited); 1006 constructorHelper(var->parent2, visited); 1007 visited->insert(var); 1008 evalOrder.push_back(var); 1009 } 1010 bool eval(EvalContext* context) { 1011 return evalSubnetWithLocalNetwork(evalOrder, timestamp, context); 1012 } 1013 }; 1014 1015 // Evaluate the subnet by leaf network. 1016 // NOTE: This algorithm will only produce correct result for *most* of the time (> 99%). 1017 // The random graph generator is expected to retry if it fails. 1018 static bool evalSubnetWithLeafNetwork(const EvaluationOrder& evalOrder, int timestamp, 1019 EvalContext* context) { 1020 NN_FUZZER_LOG << "Evaluate with leaf network"; 1021 // Construct leaf networks. 
1022 std::vector<LeafNetwork> leafNetworks; 1023 for (const auto& var : evalOrder) { 1024 if (var->children.empty()) { 1025 NN_FUZZER_LOG << "Found leaf " << toString(var, context); 1026 leafNetworks.emplace_back(var, timestamp); 1027 } 1028 } 1029 return evalSubnetsRepeatedly(&leafNetworks, context); 1030 } 1031 1032 void RandomVariableNetwork::addDimensionProd(const std::vector<RandomVariable>& dims) { 1033 if (dims.size() <= 1) return; 1034 EvaluationOrder order; 1035 for (const auto& dim : dims) order.push_back(dim.get()); 1036 mDimProd.push_back(order); 1037 } 1038 1039 bool enforceDimProd(const std::vector<EvaluationOrder>& mDimProd, 1040 const std::unordered_map<RandomVariableNode, int>& indexMap, 1041 EvalContext* context, std::unordered_set<int>* dirtySubnets) { 1042 for (auto& evalOrder : mDimProd) { 1043 NN_FUZZER_LOG << " Dimension product network size = " << evalOrder.size(); 1044 // Initialize EvalInfo of each RandomVariable. 1045 for (auto& var : evalOrder) { 1046 if (context->find(var) == context->end()) context->emplace(var, var); 1047 NN_FUZZER_LOG << " - " << toString(var, context); 1048 } 1049 1050 // Enforce the product of the dimension values below kMaxValue: 1051 // max(dimA) = kMaxValue / (min(dimB) * min(dimC) * ...) 1052 int prod = 1; 1053 for (const auto& var : evalOrder) prod *= (*context->at(var).committed.begin()); 1054 for (auto& var : evalOrder) { 1055 auto& committed = context->at(var).committed; 1056 int maxValue = kMaxValue / (prod / *committed.begin()); 1057 auto it = committed.upper_bound(maxValue); 1058 // var has empty range -> no solution. 1059 if (it == committed.begin()) return false; 1060 // The range is not modified -> continue. 1061 if (it == committed.end()) continue; 1062 // The range is modified -> the subnet of var is dirty, i.e. needs re-evaluation. 
1063 committed.erase(it, committed.end()); 1064 context->at(var).timestamp = RandomVariableNetwork::get()->getGlobalTime(); 1065 dirtySubnets->insert(indexMap.at(var)); 1066 } 1067 } 1068 return true; 1069 } 1070 1071 bool RandomVariableNetwork::evalRange() { 1072 constexpr uint64_t kMaxNumCombinationsWithBruteForce = 500; 1073 constexpr uint64_t kMaxNumCombinationsWithLocalNetwork = 1e5; 1074 NN_FUZZER_LOG << "Evaluate on " << mEvalOrderMap.size() << " sub-networks"; 1075 EvalContext context; 1076 std::unordered_set<int> dirtySubnets; // Which subnets needs evaluation. 1077 for (auto& pair : mEvalOrderMap) { 1078 const auto& evalOrder = pair.second; 1079 // Decide whether needs evaluation by timestamp -- if no range has changed after the last 1080 // evaluation, then the subnet does not need re-evaluation. 1081 if (evalOrder.size() == 1 || !needEvaluate(evalOrder, mTimestamp)) continue; 1082 dirtySubnets.insert(pair.first); 1083 } 1084 if (!enforceDimProd(mDimProd, mIndexMap, &context, &dirtySubnets)) return false; 1085 1086 // Repeat until the ranges converge. 1087 while (!dirtySubnets.empty()) { 1088 for (int ind : dirtySubnets) { 1089 const auto& evalOrder = mEvalOrderMap[ind]; 1090 NN_FUZZER_LOG << " Sub-network #" << ind << " size = " << evalOrder.size(); 1091 1092 // Initialize EvalInfo of each RandomVariable. 1093 for (auto& var : evalOrder) { 1094 if (context.find(var) == context.end()) context.emplace(var, var); 1095 NN_FUZZER_LOG << " - " << toString(var, &context); 1096 } 1097 1098 // Dispatch to different algorithm according to search range. 
1099 bool success; 1100 uint64_t numCombinations = getNumCombinations(evalOrder); 1101 if (numCombinations <= kMaxNumCombinationsWithBruteForce) { 1102 success = evalSubnetWithBruteForce(evalOrder, &context); 1103 } else if (numCombinations <= kMaxNumCombinationsWithLocalNetwork) { 1104 success = evalSubnetWithLocalNetwork(evalOrder, mTimestamp, &context); 1105 } else { 1106 success = evalSubnetWithLeafNetwork(evalOrder, mTimestamp, &context); 1107 } 1108 if (!success) return false; 1109 } 1110 dirtySubnets.clear(); 1111 if (!enforceDimProd(mDimProd, mIndexMap, &context, &dirtySubnets)) return false; 1112 } 1113 // A successful evaluation, update RandomVariables from EvalContext. 1114 for (auto& pair : context) pair.second.updateRange(); 1115 mTimestamp = getGlobalTime(); 1116 NN_FUZZER_LOG << "Finish range evaluation"; 1117 return true; 1118 } 1119 1120 static void unsetEqual(const RandomVariableNode& node) { 1121 if (node == nullptr) return; 1122 NN_FUZZER_LOG << "Unset equality of var" << node->index; 1123 RandomVariableNode parent1 = node->parent1, parent2 = node->parent2; 1124 parent1->children.erase(std::find(parent1->children.begin(), parent1->children.end(), node)); 1125 node->parent1 = nullptr; 1126 if (parent2 != nullptr) { 1127 // For Equal. 1128 parent2->children.erase( 1129 std::find(parent2->children.begin(), parent2->children.end(), node)); 1130 node->parent2 = nullptr; 1131 } else { 1132 // For UnaryEqual. 1133 node->type = RandomVariableType::FREE; 1134 node->op = nullptr; 1135 } 1136 } 1137 1138 // A class to revert all the changes made to RandomVariableNetwork since the Reverter object is 1139 // constructed. Only used when setEqualIfCompatible results in incompatible. 1140 class RandomVariableNetwork::Reverter { 1141 public: 1142 // Take a snapshot of RandomVariableNetwork when Reverter is constructed. 1143 Reverter() : mSnapshot(*RandomVariableNetwork::get()) {} 1144 // Add constraint (Equal) nodes to the reverter. 
1145 void addNode(const RandomVariableNode& node) { mEqualNodes.push_back(node); } 1146 void revert() { 1147 NN_FUZZER_LOG << "Revert RandomVariableNetwork"; 1148 // Release the constraints. 1149 for (const auto& node : mEqualNodes) unsetEqual(node); 1150 // Reset all member variables. 1151 *RandomVariableNetwork::get() = std::move(mSnapshot); 1152 } 1153 1154 private: 1155 Reverter(const Reverter&) = delete; 1156 Reverter& operator=(const Reverter&) = delete; 1157 RandomVariableNetwork mSnapshot; 1158 std::vector<RandomVariableNode> mEqualNodes; 1159 }; 1160 1161 bool RandomVariableNetwork::setEqualIfCompatible(const std::vector<RandomVariable>& lhs, 1162 const std::vector<RandomVariable>& rhs) { 1163 NN_FUZZER_LOG << "Check compatibility of {" << joinStr(", ", lhs) << "} and {" 1164 << joinStr(", ", rhs) << "}"; 1165 if (lhs.size() != rhs.size()) return false; 1166 Reverter reverter; 1167 bool result = true; 1168 for (size_t i = 0; i < lhs.size(); i++) { 1169 auto node = lhs[i].setEqual(rhs[i]).get(); 1170 reverter.addNode(node); 1171 // Early terminate if there is no common choice between two ranges. 1172 if (node != nullptr && node->range.empty()) result = false; 1173 } 1174 result = result && evalRange(); 1175 if (!result) reverter.revert(); 1176 NN_FUZZER_LOG << "setEqualIfCompatible: " << (result ? "[COMPATIBLE]" : "[INCOMPATIBLE]"); 1177 return result; 1178 } 1179 1180 bool RandomVariableNetwork::freeze() { 1181 NN_FUZZER_LOG << "Freeze the random network"; 1182 if (!evalRange()) return false; 1183 for (const auto& pair : mEvalOrderMap) { 1184 // Find all FREE RandomVariables in the subnet. 1185 std::vector<RandomVariableNode> nodes; 1186 for (const auto& var : pair.second) { 1187 if (var->type == RandomVariableType::FREE) nodes.push_back(var); 1188 } 1189 // Randomly shuffle the order, this is for a more uniform randomness. 1190 randomShuffle(&nodes); 1191 // An inefficient algorithm that does freeze -> re-evaluate for every FREE RandomVariable. 
1192 // TODO: Might be able to optimize this. 1193 for (const auto& var : nodes) { 1194 size_t size = var->range.size(); 1195 NN_FUZZER_LOG << "Freeze " << toString(var); 1196 var->freeze(); 1197 NN_FUZZER_LOG << " " << toString(var); 1198 // There is no need to re-evaluate if the FREE RandomVariable have only one choice. 1199 if (size > 1) { 1200 var->updateTimestamp(); 1201 if (!evalRange()) { 1202 NN_FUZZER_LOG << "Freeze failed at " << toString(var); 1203 return false; 1204 } 1205 } 1206 } 1207 } 1208 NN_FUZZER_LOG << "Finish freezing the random network"; 1209 return true; 1210 } 1211 1212 } // namespace fuzzing_test 1213 } // namespace nn 1214 } // namespace android 1215