Home | History | Annotate | Download | only in fuzzing
      1 /*
      2  * Copyright (C) 2019 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
#include "RandomVariable.h"

#include <algorithm>
#include <cmath>
#include <iterator>
#include <memory>
#include <set>
#include <string>
#include <unordered_set>
#include <vector>

#include "RandomGraphGeneratorUtils.h"
     27 
     28 namespace android {
     29 namespace nn {
     30 namespace fuzzing_test {
     31 
     32 unsigned int RandomVariableBase::globalIndex = 0;
     33 int RandomVariable::defaultValue = 10;
     34 
// Creates a CONST variable holding `value`; its range is RandomVariableRange(value).
// Every constructor below takes a fresh index from globalIndex and stamps the variable with the
// network's current global time.
RandomVariableBase::RandomVariableBase(int value)
    : index(globalIndex++),
      type(RandomVariableType::CONST),
      range(value),
      value(value),
      timestamp(RandomVariableNetwork::get()->getGlobalTime()) {}

// Creates a FREE variable with range RandomVariableRange(lower, upper) (presumably the inclusive
// interval [lower, upper] -- TODO confirm against RandomVariableRange).
RandomVariableBase::RandomVariableBase(int lower, int upper)
    : index(globalIndex++),
      type(RandomVariableType::FREE),
      range(lower, upper),
      timestamp(RandomVariableNetwork::get()->getGlobalTime()) {}

// Creates a FREE variable that may take any of the given `choices`.
RandomVariableBase::RandomVariableBase(const std::vector<int>& choices)
    : index(globalIndex++),
      type(RandomVariableType::FREE),
      range(choices),
      timestamp(RandomVariableNetwork::get()->getGlobalTime()) {}

// Creates an OP variable representing op(lhs, rhs). `rhs` may be null (unary ops); its range is
// then taken as RandomVariableRange(0) when computing the initial range via op->getInitRange().
// Inside the initializer list, `op` names the constructor parameter, which shadows the member of
// the same name.
RandomVariableBase::RandomVariableBase(const RandomVariableNode& lhs, const RandomVariableNode& rhs,
                                       const std::shared_ptr<const IRandomVariableOp>& op)
    : index(globalIndex++),
      type(RandomVariableType::OP),
      range(op->getInitRange(lhs->range, rhs == nullptr ? RandomVariableRange(0) : rhs->range)),
      op(op),
      parent1(lhs),
      parent2(rhs),
      timestamp(RandomVariableNetwork::get()->getGlobalTime()) {}
     63 
     64 void RandomVariableRange::setRange(int lower, int upper) {
     65     // kInvalidValue indicates unlimited bound.
     66     auto head = lower == kInvalidValue ? mChoices.begin()
     67                                        : std::lower_bound(mChoices.begin(), mChoices.end(), lower);
     68     auto tail = upper == kInvalidValue ? mChoices.end()
     69                                        : std::upper_bound(mChoices.begin(), mChoices.end(), upper);
     70     NN_FUZZER_CHECK(head <= tail) << "Invalid range!";
     71     if (head != mChoices.begin() || tail != mChoices.end()) {
     72         mChoices = std::vector<int>(head, tail);
     73     }
     74 }
     75 
     76 int RandomVariableRange::toConst() {
     77     if (mChoices.size() > 1) mChoices = {getRandomChoice(mChoices)};
     78     return mChoices[0];
     79 }
     80 
     81 RandomVariableRange operator&(const RandomVariableRange& lhs, const RandomVariableRange& rhs) {
     82     std::vector<int> result(lhs.size() + rhs.size());
     83     auto it = std::set_intersection(lhs.mChoices.begin(), lhs.mChoices.end(), rhs.mChoices.begin(),
     84                                     rhs.mChoices.end(), result.begin());
     85     result.resize(it - result.begin());
     86     return RandomVariableRange(std::move(result));
     87 }
     88 
     89 void RandomVariableBase::freeze() {
     90     if (type == RandomVariableType::CONST) return;
     91     value = range.toConst();
     92     type = RandomVariableType::CONST;
     93 }
     94 
     95 int RandomVariableBase::getValue() const {
     96     switch (type) {
     97         case RandomVariableType::CONST:
     98             return value;
     99         case RandomVariableType::OP:
    100             return op->eval(parent1->getValue(), parent2 == nullptr ? 0 : parent2->getValue());
    101         default:
    102             NN_FUZZER_CHECK(false) << "Invalid type when getting value of var" << index;
    103             return 0;
    104     }
    105 }
    106 
    107 void RandomVariableBase::updateTimestamp() {
    108     timestamp = RandomVariableNetwork::get()->getGlobalTime();
    109     NN_FUZZER_LOG << "Update timestamp of var" << index << " to " << timestamp;
    110 }
    111 
// Each constructor below allocates a new RandomVariableBase, logs it, and registers it with the
// global RandomVariableNetwork.

// A CONST variable holding `value`.
RandomVariable::RandomVariable(int value) : mVar(new RandomVariableBase(value)) {
    NN_FUZZER_LOG << "New RandomVariable " << toString(mVar);
    RandomVariableNetwork::get()->add(mVar);
}
// A FREE variable with range (lower, upper).
RandomVariable::RandomVariable(int lower, int upper) : mVar(new RandomVariableBase(lower, upper)) {
    NN_FUZZER_LOG << "New RandomVariable " << toString(mVar);
    RandomVariableNetwork::get()->add(mVar);
}
// A FREE variable restricted to the listed `choices`.
RandomVariable::RandomVariable(const std::vector<int>& choices)
    : mVar(new RandomVariableBase(choices)) {
    NN_FUZZER_LOG << "New RandomVariable " << toString(mVar);
    RandomVariableNetwork::get()->add(mVar);
}
// A FREE variable with the default range (1, defaultValue). Only RandomVariableType::FREE is
// accepted; note the check runs after the base object has already been allocated.
RandomVariable::RandomVariable(RandomVariableType type)
    : mVar(new RandomVariableBase(1, defaultValue)) {
    NN_FUZZER_CHECK(type == RandomVariableType::FREE);
    NN_FUZZER_LOG << "New RandomVariable " << toString(mVar);
    RandomVariableNetwork::get()->add(mVar);
}
// Creates an OP variable computing op(lhs, rhs); `rhs` may wrap a null node for unary ops.
// CONST parents are replaced by private CONST copies, so that two OP nodes sharing a constant do
// not end up connected through it. The copies are created (and thus registered with the network)
// before this node, which is linked as a child of its parents and then registered itself.
RandomVariable::RandomVariable(const RandomVariable& lhs, const RandomVariable& rhs,
                               const std::shared_ptr<const IRandomVariableOp>& op)
    : mVar(new RandomVariableBase(lhs.get(), rhs.get(), op)) {
    // Make a copy if the parent is CONST. This will resolve the fake dependency problem.
    if (mVar->parent1->type == RandomVariableType::CONST) {
        mVar->parent1 = RandomVariable(mVar->parent1->value).get();
    }
    if (mVar->parent2 != nullptr && mVar->parent2->type == RandomVariableType::CONST) {
        mVar->parent2 = RandomVariable(mVar->parent2->value).get();
    }
    // Record the reverse edges from each parent to this node.
    mVar->parent1->children.push_back(mVar);
    if (mVar->parent2 != nullptr) mVar->parent2->children.push_back(mVar);
    RandomVariableNetwork::get()->add(mVar);
    NN_FUZZER_LOG << "New RandomVariable " << toString(mVar);
}
    146 
    147 void RandomVariable::setRange(int lower, int upper) {
    148     NN_FUZZER_CHECK(mVar != nullptr) << "setRange() on nullptr";
    149     NN_FUZZER_LOG << "Set range [" << lower << ", " << upper << "] on var" << mVar->index;
    150     size_t oldSize = mVar->range.size();
    151     mVar->range.setRange(lower, upper);
    152     // Only update the timestamp if the range is *indeed* narrowed down.
    153     if (mVar->range.size() != oldSize) mVar->updateTimestamp();
    154 }
    155 
    156 RandomVariableRange IRandomVariableOp::getInitRange(const RandomVariableRange& lhs,
    157                                                     const RandomVariableRange& rhs) const {
    158     std::set<int> st;
    159     for (auto i : lhs.getChoices()) {
    160         for (auto j : rhs.getChoices()) {
    161             int res = this->eval(i, j);
    162             if (res > kMaxValue || res < -kMaxValue) continue;
    163             st.insert(res);
    164         }
    165     }
    166     return RandomVariableRange(st);
    167 }
    168 
    169 // Check if the range contains exactly all values in [min, max].
    170 static inline bool isContinuous(const std::set<int>* range) {
    171     return (*(range->rbegin()) - *(range->begin()) + 1) == static_cast<int>(range->size());
    172 }
    173 
    174 // Fill the set with a range of values specified by [lower, upper].
    175 static inline void fillRange(std::set<int>* range, int lower, int upper) {
    176     for (int i = lower; i <= upper; i++) range->insert(i);
    177 }
    178 
    179 // The slowest algorithm: iterate through every combinations of parents and save the valid pairs.
    180 void IRandomVariableOp::eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
    181                              const std::set<int>* childIn, std::set<int>* parent1Out,
    182                              std::set<int>* parent2Out, std::set<int>* childOut) const {
    183     // Avoid the binary search if the child is a closed range.
    184     bool isChildInContinuous = isContinuous(childIn);
    185     std::pair<int, int> child = {*childIn->begin(), *childIn->rbegin()};
    186     for (auto i : *parent1In) {
    187         bool valid = false;
    188         for (auto j : *parent2In) {
    189             int res = this->eval(i, j);
    190             // Avoid the binary search if obviously out of range.
    191             if (res > child.second || res < child.first) continue;
    192             if (isChildInContinuous || childIn->find(res) != childIn->end()) {
    193                 parent2Out->insert(j);
    194                 childOut->insert(res);
    195                 valid = true;
    196             }
    197         }
    198         if (valid) parent1Out->insert(i);
    199     }
    200 }
    201 
    202 // A helper template to make a class into a Singleton.
    203 template <class T>
    204 class Singleton : public T {
    205    public:
    206     static const std::shared_ptr<const T>& get() {
    207         static std::shared_ptr<const T> instance(new T);
    208         return instance;
    209     }
    210 };
    211 
    212 // A set of operations that only compute on a single input value.
    213 class IUnaryOp : public IRandomVariableOp {
    214    public:
    215     using IRandomVariableOp::eval;
    216     virtual int eval(int val) const = 0;
    217     virtual int eval(int lhs, int) const override { return eval(lhs); }
    218     // The slowest algorithm: iterate through every value of the parent and save the valid one.
    219     virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
    220                       const std::set<int>* childIn, std::set<int>* parent1Out,
    221                       std::set<int>* parent2Out, std::set<int>* childOut) const override {
    222         NN_FUZZER_CHECK(parent2In == nullptr);
    223         NN_FUZZER_CHECK(parent2Out == nullptr);
    224         bool isChildInContinuous = isContinuous(childIn);
    225         std::pair<int, int> child = {*childIn->begin(), *childIn->rbegin()};
    226         for (auto i : *parent1In) {
    227             int res = this->eval(i);
    228             if (res > child.second || res < child.first) continue;
    229             if (isChildInContinuous || childIn->find(res) != childIn->end()) {
    230                 parent1Out->insert(i);
    231                 childOut->insert(res);
    232             }
    233         }
    234     }
    235 };
    236 
    237 // A set of operations that only check conditional constraints.
    238 class IConstraintOp : public IRandomVariableOp {
    239    public:
    240     using IRandomVariableOp::eval;
    241     virtual bool check(int lhs, int rhs) const = 0;
    242     virtual int eval(int lhs, int rhs) const override {
    243         return check(lhs, rhs) ? 0 : kInvalidValue;
    244     }
    245     // The range for a constraint op is always {0}.
    246     virtual RandomVariableRange getInitRange(const RandomVariableRange&,
    247                                              const RandomVariableRange&) const override {
    248         return RandomVariableRange(0);
    249     }
    250     // The slowest algorithm:
    251     // iterate through every combinations of parents and save the valid pairs.
    252     virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
    253                       const std::set<int>*, std::set<int>* parent1Out, std::set<int>* parent2Out,
    254                       std::set<int>* childOut) const override {
    255         for (auto i : *parent1In) {
    256             bool valid = false;
    257             for (auto j : *parent2In) {
    258                 if (this->check(i, j)) {
    259                     parent2Out->insert(j);
    260                     valid = true;
    261                 }
    262             }
    263             if (valid) parent1Out->insert(i);
    264         }
    265         if (!parent1Out->empty()) childOut->insert(0);
    266     }
    267 };
    268 
    269 class Addition : public IRandomVariableOp {
    270    public:
    271     virtual int eval(int lhs, int rhs) const override { return lhs + rhs; }
    272     virtual RandomVariableRange getInitRange(const RandomVariableRange& lhs,
    273                                              const RandomVariableRange& rhs) const override {
    274         return RandomVariableRange(lhs.min() + rhs.min(), lhs.max() + rhs.max());
    275     }
    276     virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
    277                       const std::set<int>* childIn, std::set<int>* parent1Out,
    278                       std::set<int>* parent2Out, std::set<int>* childOut) const override {
    279         if (!isContinuous(parent1In) || !isContinuous(parent2In) || !isContinuous(childIn)) {
    280             IRandomVariableOp::eval(parent1In, parent2In, childIn, parent1Out, parent2Out,
    281                                     childOut);
    282         } else {
    283             // For parents and child with close range, the out range can be computed directly
    284             // without iterations.
    285             std::pair<int, int> parent1 = {*parent1In->begin(), *parent1In->rbegin()};
    286             std::pair<int, int> parent2 = {*parent2In->begin(), *parent2In->rbegin()};
    287             std::pair<int, int> child = {*childIn->begin(), *childIn->rbegin()};
    288 
    289             // From ranges for parent, evalute range for child.
    290             // [a, b] + [c, d] -> [a + c, b + d]
    291             fillRange(childOut, std::max(child.first, parent1.first + parent2.first),
    292                       std::min(child.second, parent1.second + parent2.second));
    293 
    294             // From ranges for child and one parent, evalute range for another parent.
    295             // [a, b] - [c, d] -> [a - d, b - c]
    296             fillRange(parent1Out, std::max(parent1.first, child.first - parent2.second),
    297                       std::min(parent1.second, child.second - parent2.first));
    298             fillRange(parent2Out, std::max(parent2.first, child.first - parent1.second),
    299                       std::min(parent2.second, child.second - parent1.first));
    300         }
    301     }
    302     virtual const char* getName() const override { return "ADD"; }
    303 };
    304 
    305 class Subtraction : public IRandomVariableOp {
    306    public:
    307     virtual int eval(int lhs, int rhs) const override { return lhs - rhs; }
    308     virtual RandomVariableRange getInitRange(const RandomVariableRange& lhs,
    309                                              const RandomVariableRange& rhs) const override {
    310         return RandomVariableRange(lhs.min() - rhs.max(), lhs.max() - rhs.min());
    311     }
    312     virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
    313                       const std::set<int>* childIn, std::set<int>* parent1Out,
    314                       std::set<int>* parent2Out, std::set<int>* childOut) const override {
    315         if (!isContinuous(parent1In) || !isContinuous(parent2In) || !isContinuous(childIn)) {
    316             IRandomVariableOp::eval(parent1In, parent2In, childIn, parent1Out, parent2Out,
    317                                     childOut);
    318         } else {
    319             // Similar algorithm as Addition.
    320             std::pair<int, int> parent1 = {*parent1In->begin(), *parent1In->rbegin()};
    321             std::pair<int, int> parent2 = {*parent2In->begin(), *parent2In->rbegin()};
    322             std::pair<int, int> child = {*childIn->begin(), *childIn->rbegin()};
    323             fillRange(childOut, std::max(child.first, parent1.first - parent2.second),
    324                       std::min(child.second, parent1.second - parent2.first));
    325             fillRange(parent1Out, std::max(parent1.first, child.first + parent2.first),
    326                       std::min(parent1.second, child.second + parent2.second));
    327             fillRange(parent2Out, std::max(parent2.first, parent1.first - child.second),
    328                       std::min(parent2.second, parent1.second - child.first));
    329         }
    330     }
    331     virtual const char* getName() const override { return "SUB"; }
    332 };
    333 
    334 class Multiplication : public IRandomVariableOp {
    335    public:
    336     virtual int eval(int lhs, int rhs) const override { return lhs * rhs; }
    337     virtual RandomVariableRange getInitRange(const RandomVariableRange& lhs,
    338                                              const RandomVariableRange& rhs) const override {
    339         if (lhs.min() < 0 || rhs.min() < 0) {
    340             return IRandomVariableOp::getInitRange(lhs, rhs);
    341         } else {
    342             int lower = std::min(lhs.min() * rhs.min(), kMaxValue);
    343             int upper = std::min(lhs.max() * rhs.max(), kMaxValue);
    344             return RandomVariableRange(lower, upper);
    345         }
    346     }
    347     virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
    348                       const std::set<int>* childIn, std::set<int>* parent1Out,
    349                       std::set<int>* parent2Out, std::set<int>* childOut) const override {
    350         if (*parent1In->begin() < 0 || *parent2In->begin() < 0 || *childIn->begin() < 0) {
    351             IRandomVariableOp::eval(parent1In, parent2In, childIn, parent1Out, parent2Out,
    352                                     childOut);
    353         } else {
    354             bool isChildInContinuous = isContinuous(childIn);
    355             std::pair<int, int> child = {*childIn->begin(), *childIn->rbegin()};
    356             for (auto i : *parent1In) {
    357                 bool valid = false;
    358                 for (auto j : *parent2In) {
    359                     int res = this->eval(i, j);
    360                     // Since MUL increases monotonically with one value, break the loop if the
    361                     // result is larger than the limit.
    362                     if (res > child.second) break;
    363                     if (res < child.first) continue;
    364                     if (isChildInContinuous || childIn->find(res) != childIn->end()) {
    365                         valid = true;
    366                         parent2Out->insert(j);
    367                         childOut->insert(res);
    368                     }
    369                 }
    370                 if (valid) parent1Out->insert(i);
    371             }
    372         }
    373     }
    374     virtual const char* getName() const override { return "MUL"; }
    375 };
    376 
    377 class Division : public IRandomVariableOp {
    378    public:
    379     virtual int eval(int lhs, int rhs) const override {
    380         return rhs == 0 ? kInvalidValue : lhs / rhs;
    381     }
    382     virtual RandomVariableRange getInitRange(const RandomVariableRange& lhs,
    383                                              const RandomVariableRange& rhs) const override {
    384         if (lhs.min() < 0 || rhs.min() <= 0) {
    385             return IRandomVariableOp::getInitRange(lhs, rhs);
    386         } else {
    387             return RandomVariableRange(lhs.min() / rhs.max(), lhs.max() / rhs.min());
    388         }
    389     }
    390     virtual const char* getName() const override { return "DIV"; }
    391 };
    392 
    393 class ExactDivision : public Division {
    394    public:
    395     virtual int eval(int lhs, int rhs) const override {
    396         return (rhs == 0 || lhs % rhs != 0) ? kInvalidValue : lhs / rhs;
    397     }
    398     virtual const char* getName() const override { return "EXACT_DIV"; }
    399 };
    400 
    401 class Modulo : public IRandomVariableOp {
    402    public:
    403     virtual int eval(int lhs, int rhs) const override {
    404         return rhs == 0 ? kInvalidValue : lhs % rhs;
    405     }
    406     virtual RandomVariableRange getInitRange(const RandomVariableRange&,
    407                                              const RandomVariableRange& rhs) const override {
    408         return RandomVariableRange(0, rhs.max());
    409     }
    410     virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
    411                       const std::set<int>* childIn, std::set<int>* parent1Out,
    412                       std::set<int>* parent2Out, std::set<int>* childOut) const override {
    413         if (*childIn->begin() != 0 || childIn->size() != 1u) {
    414             IRandomVariableOp::eval(parent1In, parent2In, childIn, parent1Out, parent2Out,
    415                                     childOut);
    416         } else {
    417             // For the special case that child is a const 0, it would be faster if the range for
    418             // parents are evaluated separately.
    419 
    420             // Evalute parent1 directly.
    421             for (auto i : *parent1In) {
    422                 for (auto j : *parent2In) {
    423                     if (i % j == 0) {
    424                         parent1Out->insert(i);
    425                         break;
    426                     }
    427                 }
    428             }
    429             // Evalute parent2, see if a multiple of parent2 value can be found in parent1.
    430             int parent1Max = *parent1In->rbegin();
    431             for (auto i : *parent2In) {
    432                 int jMax = parent1Max / i;
    433                 for (int j = 1; j <= jMax; j++) {
    434                     if (parent1In->find(i * j) != parent1In->end()) {
    435                         parent2Out->insert(i);
    436                         break;
    437                     }
    438                 }
    439             }
    440             if (!parent1Out->empty()) childOut->insert(0);
    441         }
    442     }
    443     virtual const char* getName() const override { return "MOD"; }
    444 };
    445 
    446 class Maximum : public IRandomVariableOp {
    447    public:
    448     virtual int eval(int lhs, int rhs) const override { return std::max(lhs, rhs); }
    449     virtual const char* getName() const override { return "MAX"; }
    450 };
    451 
    452 class Minimum : public IRandomVariableOp {
    453    public:
    454     virtual int eval(int lhs, int rhs) const override { return std::min(lhs, rhs); }
    455     virtual const char* getName() const override { return "MIN"; }
    456 };
    457 
    458 class Square : public IUnaryOp {
    459    public:
    460     virtual int eval(int val) const override { return val * val; }
    461     virtual const char* getName() const override { return "SQUARE"; }
    462 };
    463 
    464 class UnaryEqual : public IUnaryOp {
    465    public:
    466     virtual int eval(int val) const override { return val; }
    467     virtual const char* getName() const override { return "UNARY_EQUAL"; }
    468 };
    469 
    470 class Equal : public IConstraintOp {
    471    public:
    472     virtual bool check(int lhs, int rhs) const override { return lhs == rhs; }
    473     virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
    474                       const std::set<int>* childIn, std::set<int>* parent1Out,
    475                       std::set<int>* parent2Out, std::set<int>* childOut) const override {
    476         NN_FUZZER_CHECK(childIn->size() == 1u && *childIn->begin() == 0);
    477         // The intersection of two sets can be found in O(n).
    478         std::set_intersection(parent1In->begin(), parent1In->end(), parent2In->begin(),
    479                               parent2In->end(), std::inserter(*parent1Out, parent1Out->begin()));
    480         *parent2Out = *parent1Out;
    481         childOut->insert(0);
    482     }
    483     virtual const char* getName() const override { return "EQUAL"; }
    484 };
    485 
    486 class GreaterThan : public IConstraintOp {
    487    public:
    488     virtual bool check(int lhs, int rhs) const override { return lhs > rhs; }
    489     virtual const char* getName() const override { return "GREATER_THAN"; }
    490 };
    491 
    492 class GreaterEqual : public IConstraintOp {
    493    public:
    494     virtual bool check(int lhs, int rhs) const override { return lhs >= rhs; }
    495     virtual const char* getName() const override { return "GREATER_EQUAL"; }
    496 };
    497 
    498 class FloatMultiplication : public IUnaryOp {
    499    public:
    500     FloatMultiplication(float multiplicand) : mMultiplicand(multiplicand) {}
    501     virtual int eval(int val) const override {
    502         return static_cast<int>(std::floor(static_cast<float>(val) * mMultiplicand));
    503     }
    504     virtual const char* getName() const override { return "MUL_FLOAT"; }
    505 
    506    private:
    507     float mMultiplicand;
    508 };
    509 
    510 // Arithmetic operators and methods on RandomVariables will create OP RandomVariableNodes.
    511 // Since there must be at most one edge between two RandomVariableNodes, we have to do something
    512 // special when both sides are refering to the same node.
    513 
    514 RandomVariable operator+(const RandomVariable& lhs, const RandomVariable& rhs) {
    515     return lhs.get() == rhs.get() ? RandomVariable(lhs, 2, Singleton<Multiplication>::get())
    516                                   : RandomVariable(lhs, rhs, Singleton<Addition>::get());
    517 }
    518 RandomVariable operator-(const RandomVariable& lhs, const RandomVariable& rhs) {
    519     return lhs.get() == rhs.get() ? RandomVariable(0)
    520                                   : RandomVariable(lhs, rhs, Singleton<Subtraction>::get());
    521 }
    522 RandomVariable operator*(const RandomVariable& lhs, const RandomVariable& rhs) {
    523     return lhs.get() == rhs.get() ? RandomVariable(lhs, RandomVariable(), Singleton<Square>::get())
    524                                   : RandomVariable(lhs, rhs, Singleton<Multiplication>::get());
    525 }
    526 RandomVariable operator*(const RandomVariable& lhs, const float& rhs) {
    527     return RandomVariable(lhs, RandomVariable(), std::make_shared<FloatMultiplication>(rhs));
    528 }
    529 RandomVariable operator/(const RandomVariable& lhs, const RandomVariable& rhs) {
    530     return lhs.get() == rhs.get() ? RandomVariable(1)
    531                                   : RandomVariable(lhs, rhs, Singleton<Division>::get());
    532 }
    533 RandomVariable operator%(const RandomVariable& lhs, const RandomVariable& rhs) {
    534     return lhs.get() == rhs.get() ? RandomVariable(0)
    535                                   : RandomVariable(lhs, rhs, Singleton<Modulo>::get());
    536 }
    537 RandomVariable max(const RandomVariable& lhs, const RandomVariable& rhs) {
    538     return lhs.get() == rhs.get() ? lhs : RandomVariable(lhs, rhs, Singleton<Maximum>::get());
    539 }
    540 RandomVariable min(const RandomVariable& lhs, const RandomVariable& rhs) {
    541     return lhs.get() == rhs.get() ? lhs : RandomVariable(lhs, rhs, Singleton<Minimum>::get());
    542 }
    543 
    544 RandomVariable RandomVariable::exactDiv(const RandomVariable& other) {
    545     return mVar == other.get() ? RandomVariable(1)
    546                                : RandomVariable(*this, other, Singleton<ExactDivision>::get());
    547 }
    548 
// Constrains this variable and `other` to be equal.
// Returns:
//   - an empty RandomVariable() if the two are already known to be equal (same node, or one is
//     already a UNARY_EQUAL child of the other);
//   - `other`, after turning it into a UNARY_EQUAL child of this variable, when
//     isSubordinate(this, other) holds;
//   - `*this`, after turning it into a UNARY_EQUAL child of `other`, when
//     isSubordinate(other, this) holds;
//   - otherwise a new EQUAL constraint variable over both.
RandomVariable RandomVariable::setEqual(const RandomVariable& other) const {
    RandomVariableNode node1 = mVar, node2 = other.get();
    NN_FUZZER_LOG << "Set equality of var" << node1->index << " and var" << node2->index;

    // Do not setEqual on the same pair twice.
    if (node1 == node2 || (node1->op == Singleton<UnaryEqual>::get() && node1->parent1 == node2) ||
        (node2->op == Singleton<UnaryEqual>::get() && node2->parent1 == node1)) {
        NN_FUZZER_LOG << "Already equal. Return.";
        return RandomVariable();
    }

    // If possible, always try UnaryEqual first to reduce the search space.
    // UnaryEqual can be used if node B is FREE and is evaluated later than node A.
    // TODO: Reduce code duplication.
    if (RandomVariableNetwork::get()->isSubordinate(node1, node2)) {
        NN_FUZZER_LOG << "  Make var" << node2->index << " a child of var" << node1->index;
        // Convert the FREE node2 in place into node2 = UNARY_EQUAL(node1).
        node2->type = RandomVariableType::OP;
        node2->parent1 = node1;
        node2->op = Singleton<UnaryEqual>::get();
        node1->children.push_back(node2);
        RandomVariableNetwork::get()->join(node1, node2);
        node1->updateTimestamp();
        return other;
    }
    if (RandomVariableNetwork::get()->isSubordinate(node2, node1)) {
        NN_FUZZER_LOG << "  Make var" << node1->index << " a child of var" << node2->index;
        // Convert the FREE node1 in place into node1 = UNARY_EQUAL(node2).
        node1->type = RandomVariableType::OP;
        node1->parent1 = node2;
        node1->op = Singleton<UnaryEqual>::get();
        node2->children.push_back(node1);
        RandomVariableNetwork::get()->join(node2, node1);
        // NOTE(review): the branch above stamps the parent (node1); this one also stamps node1,
        // which here is the child. Confirm whether node2 was intended.
        node1->updateTimestamp();
        return *this;
    }
    // Neither node can simply mirror the other: fall back to an explicit EQUAL constraint.
    return RandomVariable(*this, other, Singleton<Equal>::get());
}
    585 
    586 RandomVariable RandomVariable::setGreaterThan(const RandomVariable& other) const {
    587     NN_FUZZER_CHECK(mVar != other.get());
    588     return RandomVariable(*this, other, Singleton<GreaterThan>::get());
    589 }
    590 RandomVariable RandomVariable::setGreaterEqual(const RandomVariable& other) const {
    591     return mVar == other.get() ? *this
    592                                : RandomVariable(*this, other, Singleton<GreaterEqual>::get());
    593 }
    594 
    595 void DisjointNetwork::add(const RandomVariableNode& var) {
    596     // Find the subnet index of the parents and decide the index for var.
    597     int ind1 = var->parent1 == nullptr ? -1 : mIndexMap[var->parent1];
    598     int ind2 = var->parent2 == nullptr ? -1 : mIndexMap[var->parent2];
    599     int ind = join(ind1, ind2);
    600     // If no parent, put it into a new subnet component.
    601     if (ind == -1) ind = mNextIndex++;
    602     NN_FUZZER_LOG << "Add RandomVariable var" << var->index << " to network #" << ind;
    603     mIndexMap[var] = ind;
    604     mEvalOrderMap[ind].push_back(var);
    605 }
    606 
    607 int DisjointNetwork::join(int ind1, int ind2) {
    608     if (ind1 == -1) return ind2;
    609     if (ind2 == -1) return ind1;
    610     if (ind1 == ind2) return ind1;
    611     NN_FUZZER_LOG << "Join network #" << ind1 << " and #" << ind2;
    612     auto &order1 = mEvalOrderMap[ind1], &order2 = mEvalOrderMap[ind2];
    613     // Append every node in ind2 to the end of ind1
    614     for (const auto& var : order2) {
    615         order1.push_back(var);
    616         mIndexMap[var] = ind1;
    617     }
    618     // Remove ind2 from mEvalOrderMap.
    619     mEvalOrderMap.erase(mEvalOrderMap.find(ind2));
    620     return ind1;
    621 }
    622 
    623 RandomVariableNetwork* RandomVariableNetwork::get() {
    624     static RandomVariableNetwork instance;
    625     return &instance;
    626 }
    627 
    628 void RandomVariableNetwork::initialize(int defaultValue) {
    629     RandomVariableBase::globalIndex = 0;
    630     RandomVariable::defaultValue = defaultValue;
    631     mIndexMap.clear();
    632     mEvalOrderMap.clear();
    633     mDimProd.clear();
    634     mNextIndex = 0;
    635     mGlobalTime = 0;
    636     mTimestamp = -1;
    637 }
    638 
    639 bool RandomVariableNetwork::isSubordinate(const RandomVariableNode& node1,
    640                                           const RandomVariableNode& node2) {
    641     if (node2->type != RandomVariableType::FREE) return false;
    642     int ind1 = mIndexMap[node1];
    643     // node2 is of a different subnet.
    644     if (ind1 != mIndexMap[node2]) return true;
    645     for (const auto& node : mEvalOrderMap[ind1]) {
    646         if (node == node2) return false;
    647         // node2 is of the same subnet but evaluated later than node1.
    648         if (node == node1) return true;
    649     }
    650     NN_FUZZER_CHECK(false) << "Code executed in non-reachable region.";
    651     return false;
    652 }
    653 
    654 struct EvalInfo {
    655     // The RandomVariableNode that this EvalInfo is associated with.
    656     // var->value is the current value during evaluation.
    657     RandomVariableNode var;
    658 
    659     // The RandomVariable value is staged when a valid combination is found.
    660     std::set<int> staging;
    661 
    662     // The staging values are committed after a subnet evaluation.
    663     std::set<int> committed;
    664 
    665     // Keeps track of the latest timestamp that committed is updated.
    666     int timestamp;
    667 
    668     // For evalSubnetWithLocalNetwork.
    669     RandomVariableType originalType;
    670 
    671     // Should only invoke eval on OP RandomVariable.
    672     bool eval() {
    673         NN_FUZZER_CHECK(var->type == RandomVariableType::OP);
    674         var->value = var->op->eval(var->parent1->value,
    675                                    var->parent2 == nullptr ? 0 : var->parent2->value);
    676         if (var->value == kInvalidValue) return false;
    677         return committed.find(var->value) != committed.end();
    678     }
    679     void stage() { staging.insert(var->value); }
    680     void commit() {
    681         // Only update committed and timestamp if the range is *indeed* changed.
    682         if (staging.size() != committed.size()) {
    683             committed = std::move(staging);
    684             timestamp = RandomVariableNetwork::get()->getGlobalTime();
    685         }
    686         staging.clear();
    687     }
    688     void updateRange() {
    689         // Only update range and timestamp if the range is *indeed* changed.
    690         if (committed.size() != var->range.size()) {
    691             var->range = RandomVariableRange(committed);
    692             var->timestamp = timestamp;
    693         }
    694         committed.clear();
    695     }
    696 
    697     EvalInfo(const RandomVariableNode& var)
    698         : var(var),
    699           committed(var->range.getChoices().begin(), var->range.getChoices().end()),
    700           timestamp(var->timestamp) {}
    701 };
    702 using EvalContext = std::unordered_map<RandomVariableNode, EvalInfo>;
    703 
    704 // For logging only.
    705 inline std::string toString(const RandomVariableNode& var, EvalContext* context) {
    706     std::stringstream ss;
    707     ss << "var" << var->index << " = ";
    708     const auto& committed = context->at(var).committed;
    709     switch (var->type) {
    710         case RandomVariableType::FREE:
    711             ss << "FREE ["
    712                << joinStr(", ", 20, std::vector<int>(committed.begin(), committed.end())) << "]";
    713             break;
    714         case RandomVariableType::CONST:
    715             ss << "CONST " << toString(var->value);
    716             break;
    717         case RandomVariableType::OP:
    718             ss << "var" << var->parent1->index << " " << var->op->getName();
    719             if (var->parent2 != nullptr) ss << " var" << var->parent2->index;
    720             ss << ", [" << joinStr(", ", 20, std::vector<int>(committed.begin(), committed.end()))
    721                << "]";
    722             break;
    723         default:
    724             NN_FUZZER_CHECK(false);
    725     }
    726     ss << ", timestamp = " << context->at(var).timestamp;
    727     return ss.str();
    728 }
    729 
    730 // Check if the subnet needs to be re-evaluated by comparing the timestamps.
    731 static inline bool needEvaluate(const EvaluationOrder& evalOrder, int subnetTime,
    732                                 EvalContext* context = nullptr) {
    733     for (const auto& var : evalOrder) {
    734         int timestamp = context == nullptr ? var->timestamp : context->at(var).timestamp;
    735         // If we find a node that has been modified since last evaluation, the subnet needs to be
    736         // re-evaluated.
    737         if (timestamp > subnetTime) return true;
    738     }
    739     return false;
    740 }
    741 
    742 // Helper function to evaluate the subnet recursively.
    743 // Iterate through all combinations of FREE RandomVariables choices.
    744 static void evalSubnetHelper(const EvaluationOrder& evalOrder, EvalContext* context, size_t i = 0) {
    745     if (i == evalOrder.size()) {
    746         // Reach the end of the evaluation, find a valid combination.
    747         for (auto& var : evalOrder) context->at(var).stage();
    748         return;
    749     }
    750     const auto& var = evalOrder[i];
    751     if (var->type == RandomVariableType::FREE) {
    752         // For FREE RandomVariable, iterate through all valid choices.
    753         for (int val : context->at(var).committed) {
    754             var->value = val;
    755             evalSubnetHelper(evalOrder, context, i + 1);
    756         }
    757         return;
    758     } else if (var->type == RandomVariableType::OP) {
    759         // For OP RandomVariable, evaluate from parents and terminate if the result is invalid.
    760         if (!context->at(var).eval()) return;
    761     }
    762     evalSubnetHelper(evalOrder, context, i + 1);
    763 }
    764 
    765 // Check if the subnet has only one single OP RandomVariable.
    766 static inline bool isSingleOpSubnet(const EvaluationOrder& evalOrder) {
    767     int numOp = 0;
    768     for (const auto& var : evalOrder) {
    769         if (var->type == RandomVariableType::OP) numOp++;
    770         if (numOp > 1) return false;
    771     }
    772     return numOp != 0;
    773 }
    774 
    775 // Evaluate with a potentially faster approach provided by IRandomVariableOp.
    776 static inline void evalSubnetSingleOpHelper(const EvaluationOrder& evalOrder,
    777                                             EvalContext* context) {
    778     NN_FUZZER_LOG << "Identified as single op subnet";
    779     const auto& var = evalOrder.back();
    780     NN_FUZZER_CHECK(var->type == RandomVariableType::OP);
    781     var->op->eval(&context->at(var->parent1).committed,
    782                   var->parent2 == nullptr ? nullptr : &context->at(var->parent2).committed,
    783                   &context->at(var).committed, &context->at(var->parent1).staging,
    784                   var->parent2 == nullptr ? nullptr : &context->at(var->parent2).staging,
    785                   &context->at(var).staging);
    786 }
    787 
    788 // Check if the number of combinations of FREE RandomVariables exceeds the limit.
    789 static inline uint64_t getNumCombinations(const EvaluationOrder& evalOrder,
    790                                           EvalContext* context = nullptr) {
    791     constexpr uint64_t kLimit = 1e8;
    792     uint64_t numCombinations = 1;
    793     for (const auto& var : evalOrder) {
    794         if (var->type == RandomVariableType::FREE) {
    795             size_t size =
    796                     context == nullptr ? var->range.size() : context->at(var).committed.size();
    797             numCombinations *= size;
    798             // To prevent overflow.
    799             if (numCombinations > kLimit) return kLimit;
    800         }
    801     }
    802     return numCombinations;
    803 }
    804 
    805 // Evaluate the subnet recursively. Will return fail if the number of combinations of FREE
    806 // RandomVariable exceeds the threshold kMaxNumCombinations.
    807 static bool evalSubnetWithBruteForce(const EvaluationOrder& evalOrder, EvalContext* context) {
    808     constexpr uint64_t kMaxNumCombinations = 1e7;
    809     NN_FUZZER_LOG << "Evaluate with brute force";
    810     if (isSingleOpSubnet(evalOrder)) {
    811         // If the network only have one single OP, dispatch to a faster evaluation.
    812         evalSubnetSingleOpHelper(evalOrder, context);
    813     } else {
    814         if (getNumCombinations(evalOrder, context) > kMaxNumCombinations) {
    815             NN_FUZZER_LOG << "Terminate the evaluation because of large search range";
    816             std::cout << "[          ]   Terminate the evaluation because of large search range"
    817                       << std::endl;
    818             return false;
    819         }
    820         evalSubnetHelper(evalOrder, context);
    821     }
    822     for (auto& var : evalOrder) {
    823         if (context->at(var).staging.empty()) {
    824             NN_FUZZER_LOG << "Evaluation failed at " << toString(var, context);
    825             return false;
    826         }
    827         context->at(var).commit();
    828     }
    829     return true;
    830 }
    831 
    832 struct LocalNetwork {
    833     EvaluationOrder evalOrder;
    834     std::vector<RandomVariableNode> bridgeNodes;
    835     int timestamp = 0;
    836 
    837     bool eval(EvalContext* context) {
    838         NN_FUZZER_LOG << "Evaluate local network with timestamp = " << timestamp;
    839         // Temporarily treat bridge nodes as FREE RandomVariables.
    840         for (const auto& var : bridgeNodes) {
    841             context->at(var).originalType = var->type;
    842             var->type = RandomVariableType::FREE;
    843         }
    844         for (const auto& var : evalOrder) {
    845             context->at(var).staging.clear();
    846             NN_FUZZER_LOG << "  - " << toString(var, context);
    847         }
    848         bool success = evalSubnetWithBruteForce(evalOrder, context);
    849         // Reset the RandomVariable types for bridge nodes.
    850         for (const auto& var : bridgeNodes) var->type = context->at(var).originalType;
    851         return success;
    852     }
    853 };
    854 
    855 // Partition the network further into LocalNetworks based on the result from bridge annotation
    856 // algorithm.
    857 class GraphPartitioner : public DisjointNetwork {
    858    public:
    859     GraphPartitioner() = default;
    860 
    861     std::vector<LocalNetwork> partition(const EvaluationOrder& evalOrder, int timestamp) {
    862         annotateBridge(evalOrder);
    863         for (const auto& var : evalOrder) add(var);
    864         return get(timestamp);
    865     }
    866 
    867    private:
    868     GraphPartitioner(const GraphPartitioner&) = delete;
    869     GraphPartitioner& operator=(const GraphPartitioner&) = delete;
    870 
    871     // Find the parent-child relationship between var1 and var2, and reset the bridge.
    872     void setBridgeFlag(const RandomVariableNode& var1, const RandomVariableNode& var2) {
    873         if (var1->parent1 == var2) {
    874             mBridgeInfo[var1].isParent1Bridge = true;
    875         } else if (var1->parent2 == var2) {
    876             mBridgeInfo[var1].isParent2Bridge = true;
    877         } else {
    878             setBridgeFlag(var2, var1);
    879         }
    880     }
    881 
    882     // Annoate the bridges with DFS -- an edge [u, v] is a bridge if none of u's ancestor is
    883     // reachable from a node in the subtree of b. The complexity is O(V + E).
    884     // discoveryTime: The timestamp a node is visited
    885     // lowTime: The min discovery time of all reachable nodes from the subtree of the node.
    886     void annotateBridgeHelper(const RandomVariableNode& var, int* time) {
    887         mBridgeInfo[var].visited = true;
    888         mBridgeInfo[var].discoveryTime = mBridgeInfo[var].lowTime = (*time)++;
    889 
    890         // The algorithm operates on undirected graph. First find all adjacent nodes.
    891         auto adj = var->children;
    892         if (var->parent1 != nullptr) adj.push_back(var->parent1);
    893         if (var->parent2 != nullptr) adj.push_back(var->parent2);
    894 
    895         for (const auto& child : adj) {
    896             if (mBridgeInfo.find(child) == mBridgeInfo.end()) continue;
    897             if (!mBridgeInfo[child].visited) {
    898                 mBridgeInfo[child].parent = var;
    899                 annotateBridgeHelper(child, time);
    900 
    901                 // If none of nodes in the subtree of child is connected to any ancestors of var,
    902                 // then it is a bridge.
    903                 mBridgeInfo[var].lowTime =
    904                         std::min(mBridgeInfo[var].lowTime, mBridgeInfo[child].lowTime);
    905                 if (mBridgeInfo[child].lowTime > mBridgeInfo[var].discoveryTime)
    906                     setBridgeFlag(var, child);
    907             } else if (mBridgeInfo[var].parent != child) {
    908                 mBridgeInfo[var].lowTime =
    909                         std::min(mBridgeInfo[var].lowTime, mBridgeInfo[child].discoveryTime);
    910             }
    911         }
    912     }
    913 
    914     // Find all bridges in the subnet with DFS.
    915     void annotateBridge(const EvaluationOrder& evalOrder) {
    916         for (const auto& var : evalOrder) mBridgeInfo[var];
    917         int time = 0;
    918         for (const auto& var : evalOrder) {
    919             if (!mBridgeInfo[var].visited) annotateBridgeHelper(var, &time);
    920         }
    921     }
    922 
    923     // Re-partition the network by treating bridges as no edge.
    924     void add(const RandomVariableNode& var) {
    925         auto parent1 = var->parent1;
    926         auto parent2 = var->parent2;
    927         if (mBridgeInfo[var].isParent1Bridge) var->parent1 = nullptr;
    928         if (mBridgeInfo[var].isParent2Bridge) var->parent2 = nullptr;
    929         DisjointNetwork::add(var);
    930         var->parent1 = parent1;
    931         var->parent2 = parent2;
    932     }
    933 
    934     // Add bridge nodes to the local network and remove single node subnet.
    935     std::vector<LocalNetwork> get(int timestamp) {
    936         std::vector<LocalNetwork> res;
    937         for (auto& pair : mEvalOrderMap) {
    938             // We do not need to evaluate subnet with only a single node.
    939             if (pair.second.size() == 1 && pair.second[0]->parent1 == nullptr) continue;
    940             res.emplace_back();
    941             for (const auto& var : pair.second) {
    942                 if (mBridgeInfo[var].isParent1Bridge) {
    943                     res.back().evalOrder.push_back(var->parent1);
    944                     res.back().bridgeNodes.push_back(var->parent1);
    945                 }
    946                 if (mBridgeInfo[var].isParent2Bridge) {
    947                     res.back().evalOrder.push_back(var->parent2);
    948                     res.back().bridgeNodes.push_back(var->parent2);
    949                 }
    950                 res.back().evalOrder.push_back(var);
    951             }
    952             res.back().timestamp = timestamp;
    953         }
    954         return res;
    955     }
    956 
    957     // For bridge discovery algorithm.
    958     struct BridgeInfo {
    959         bool isParent1Bridge = false;
    960         bool isParent2Bridge = false;
    961         int discoveryTime = 0;
    962         int lowTime = 0;
    963         bool visited = false;
    964         std::shared_ptr<RandomVariableBase> parent = nullptr;
    965     };
    966     std::unordered_map<RandomVariableNode, BridgeInfo> mBridgeInfo;
    967 };
    968 
    969 // Evaluate subnets repeatedly until converge.
    970 // Class T_Subnet must have member evalOrder, timestamp, and member function eval.
    971 template <class T_Subnet>
    972 inline bool evalSubnetsRepeatedly(std::vector<T_Subnet>* subnets, EvalContext* context) {
    973     bool terminate = false;
    974     while (!terminate) {
    975         terminate = true;
    976         for (auto& subnet : *subnets) {
    977             if (needEvaluate(subnet.evalOrder, subnet.timestamp, context)) {
    978                 if (!subnet.eval(context)) return false;
    979                 subnet.timestamp = RandomVariableNetwork::get()->getGlobalTime();
    980                 terminate = false;
    981             }
    982         }
    983     }
    984     return true;
    985 }
    986 
    987 // Evaluate the subnet by first partitioning it further into LocalNetworks.
    988 static bool evalSubnetWithLocalNetwork(const EvaluationOrder& evalOrder, int timestamp,
    989                                        EvalContext* context) {
    990     NN_FUZZER_LOG << "Evaluate with local network";
    991     auto localNetworks = GraphPartitioner().partition(evalOrder, timestamp);
    992     return evalSubnetsRepeatedly(&localNetworks, context);
    993 }
    994 
    995 struct LeafNetwork {
    996     EvaluationOrder evalOrder;
    997     int timestamp = 0;
    998     LeafNetwork(const RandomVariableNode& var, int timestamp) : timestamp(timestamp) {
    999         std::set<RandomVariableNode> visited;
   1000         constructorHelper(var, &visited);
   1001     }
   1002     // Construct the leaf network by recursively including parent nodes.
   1003     void constructorHelper(const RandomVariableNode& var, std::set<RandomVariableNode>* visited) {
   1004         if (var == nullptr || visited->find(var) != visited->end()) return;
   1005         constructorHelper(var->parent1, visited);
   1006         constructorHelper(var->parent2, visited);
   1007         visited->insert(var);
   1008         evalOrder.push_back(var);
   1009     }
   1010     bool eval(EvalContext* context) {
   1011         return evalSubnetWithLocalNetwork(evalOrder, timestamp, context);
   1012     }
   1013 };
   1014 
   1015 // Evaluate the subnet by leaf network.
   1016 // NOTE: This algorithm will only produce correct result for *most* of the time (> 99%).
   1017 //       The random graph generator is expected to retry if it fails.
   1018 static bool evalSubnetWithLeafNetwork(const EvaluationOrder& evalOrder, int timestamp,
   1019                                       EvalContext* context) {
   1020     NN_FUZZER_LOG << "Evaluate with leaf network";
   1021     // Construct leaf networks.
   1022     std::vector<LeafNetwork> leafNetworks;
   1023     for (const auto& var : evalOrder) {
   1024         if (var->children.empty()) {
   1025             NN_FUZZER_LOG << "Found leaf " << toString(var, context);
   1026             leafNetworks.emplace_back(var, timestamp);
   1027         }
   1028     }
   1029     return evalSubnetsRepeatedly(&leafNetworks, context);
   1030 }
   1031 
   1032 void RandomVariableNetwork::addDimensionProd(const std::vector<RandomVariable>& dims) {
   1033     if (dims.size() <= 1) return;
   1034     EvaluationOrder order;
   1035     for (const auto& dim : dims) order.push_back(dim.get());
   1036     mDimProd.push_back(order);
   1037 }
   1038 
   1039 bool enforceDimProd(const std::vector<EvaluationOrder>& mDimProd,
   1040                     const std::unordered_map<RandomVariableNode, int>& indexMap,
   1041                     EvalContext* context, std::unordered_set<int>* dirtySubnets) {
   1042     for (auto& evalOrder : mDimProd) {
   1043         NN_FUZZER_LOG << "  Dimension product network size = " << evalOrder.size();
   1044         // Initialize EvalInfo of each RandomVariable.
   1045         for (auto& var : evalOrder) {
   1046             if (context->find(var) == context->end()) context->emplace(var, var);
   1047             NN_FUZZER_LOG << "  - " << toString(var, context);
   1048         }
   1049 
   1050         // Enforce the product of the dimension values below kMaxValue:
   1051         // max(dimA) = kMaxValue / (min(dimB) * min(dimC) * ...)
   1052         int prod = 1;
   1053         for (const auto& var : evalOrder) prod *= (*context->at(var).committed.begin());
   1054         for (auto& var : evalOrder) {
   1055             auto& committed = context->at(var).committed;
   1056             int maxValue = kMaxValue / (prod / *committed.begin());
   1057             auto it = committed.upper_bound(maxValue);
   1058             // var has empty range -> no solution.
   1059             if (it == committed.begin()) return false;
   1060             // The range is not modified -> continue.
   1061             if (it == committed.end()) continue;
   1062             // The range is modified -> the subnet of var is dirty, i.e. needs re-evaluation.
   1063             committed.erase(it, committed.end());
   1064             context->at(var).timestamp = RandomVariableNetwork::get()->getGlobalTime();
   1065             dirtySubnets->insert(indexMap.at(var));
   1066         }
   1067     }
   1068     return true;
   1069 }
   1070 
   1071 bool RandomVariableNetwork::evalRange() {
   1072     constexpr uint64_t kMaxNumCombinationsWithBruteForce = 500;
   1073     constexpr uint64_t kMaxNumCombinationsWithLocalNetwork = 1e5;
   1074     NN_FUZZER_LOG << "Evaluate on " << mEvalOrderMap.size() << " sub-networks";
   1075     EvalContext context;
   1076     std::unordered_set<int> dirtySubnets;  // Which subnets needs evaluation.
   1077     for (auto& pair : mEvalOrderMap) {
   1078         const auto& evalOrder = pair.second;
   1079         // Decide whether needs evaluation by timestamp -- if no range has changed after the last
   1080         // evaluation, then the subnet does not need re-evaluation.
   1081         if (evalOrder.size() == 1 || !needEvaluate(evalOrder, mTimestamp)) continue;
   1082         dirtySubnets.insert(pair.first);
   1083     }
   1084     if (!enforceDimProd(mDimProd, mIndexMap, &context, &dirtySubnets)) return false;
   1085 
   1086     // Repeat until the ranges converge.
   1087     while (!dirtySubnets.empty()) {
   1088         for (int ind : dirtySubnets) {
   1089             const auto& evalOrder = mEvalOrderMap[ind];
   1090             NN_FUZZER_LOG << "  Sub-network #" << ind << " size = " << evalOrder.size();
   1091 
   1092             // Initialize EvalInfo of each RandomVariable.
   1093             for (auto& var : evalOrder) {
   1094                 if (context.find(var) == context.end()) context.emplace(var, var);
   1095                 NN_FUZZER_LOG << "  - " << toString(var, &context);
   1096             }
   1097 
   1098             // Dispatch to different algorithm according to search range.
   1099             bool success;
   1100             uint64_t numCombinations = getNumCombinations(evalOrder);
   1101             if (numCombinations <= kMaxNumCombinationsWithBruteForce) {
   1102                 success = evalSubnetWithBruteForce(evalOrder, &context);
   1103             } else if (numCombinations <= kMaxNumCombinationsWithLocalNetwork) {
   1104                 success = evalSubnetWithLocalNetwork(evalOrder, mTimestamp, &context);
   1105             } else {
   1106                 success = evalSubnetWithLeafNetwork(evalOrder, mTimestamp, &context);
   1107             }
   1108             if (!success) return false;
   1109         }
   1110         dirtySubnets.clear();
   1111         if (!enforceDimProd(mDimProd, mIndexMap, &context, &dirtySubnets)) return false;
   1112     }
   1113     // A successful evaluation, update RandomVariables from EvalContext.
   1114     for (auto& pair : context) pair.second.updateRange();
   1115     mTimestamp = getGlobalTime();
   1116     NN_FUZZER_LOG << "Finish range evaluation";
   1117     return true;
   1118 }
   1119 
   1120 static void unsetEqual(const RandomVariableNode& node) {
   1121     if (node == nullptr) return;
   1122     NN_FUZZER_LOG << "Unset equality of var" << node->index;
   1123     RandomVariableNode parent1 = node->parent1, parent2 = node->parent2;
   1124     parent1->children.erase(std::find(parent1->children.begin(), parent1->children.end(), node));
   1125     node->parent1 = nullptr;
   1126     if (parent2 != nullptr) {
   1127         // For Equal.
   1128         parent2->children.erase(
   1129                 std::find(parent2->children.begin(), parent2->children.end(), node));
   1130         node->parent2 = nullptr;
   1131     } else {
   1132         // For UnaryEqual.
   1133         node->type = RandomVariableType::FREE;
   1134         node->op = nullptr;
   1135     }
   1136 }
   1137 
   1138 // A class to revert all the changes made to RandomVariableNetwork since the Reverter object is
   1139 // constructed. Only used when setEqualIfCompatible results in incompatible.
   1140 class RandomVariableNetwork::Reverter {
   1141    public:
   1142     // Take a snapshot of RandomVariableNetwork when Reverter is constructed.
   1143     Reverter() : mSnapshot(*RandomVariableNetwork::get()) {}
   1144     // Add constraint (Equal) nodes to the reverter.
   1145     void addNode(const RandomVariableNode& node) { mEqualNodes.push_back(node); }
   1146     void revert() {
   1147         NN_FUZZER_LOG << "Revert RandomVariableNetwork";
   1148         // Release the constraints.
   1149         for (const auto& node : mEqualNodes) unsetEqual(node);
   1150         // Reset all member variables.
   1151         *RandomVariableNetwork::get() = std::move(mSnapshot);
   1152     }
   1153 
   1154    private:
   1155     Reverter(const Reverter&) = delete;
   1156     Reverter& operator=(const Reverter&) = delete;
   1157     RandomVariableNetwork mSnapshot;
   1158     std::vector<RandomVariableNode> mEqualNodes;
   1159 };
   1160 
   1161 bool RandomVariableNetwork::setEqualIfCompatible(const std::vector<RandomVariable>& lhs,
   1162                                                  const std::vector<RandomVariable>& rhs) {
   1163     NN_FUZZER_LOG << "Check compatibility of {" << joinStr(", ", lhs) << "} and {"
   1164                   << joinStr(", ", rhs) << "}";
   1165     if (lhs.size() != rhs.size()) return false;
   1166     Reverter reverter;
   1167     bool result = true;
   1168     for (size_t i = 0; i < lhs.size(); i++) {
   1169         auto node = lhs[i].setEqual(rhs[i]).get();
   1170         reverter.addNode(node);
   1171         // Early terminate if there is no common choice between two ranges.
   1172         if (node != nullptr && node->range.empty()) result = false;
   1173     }
   1174     result = result && evalRange();
   1175     if (!result) reverter.revert();
   1176     NN_FUZZER_LOG << "setEqualIfCompatible: " << (result ? "[COMPATIBLE]" : "[INCOMPATIBLE]");
   1177     return result;
   1178 }
   1179 
   1180 bool RandomVariableNetwork::freeze() {
   1181     NN_FUZZER_LOG << "Freeze the random network";
   1182     if (!evalRange()) return false;
   1183     for (const auto& pair : mEvalOrderMap) {
   1184         // Find all FREE RandomVariables in the subnet.
   1185         std::vector<RandomVariableNode> nodes;
   1186         for (const auto& var : pair.second) {
   1187             if (var->type == RandomVariableType::FREE) nodes.push_back(var);
   1188         }
   1189         // Randomly shuffle the order, this is for a more uniform randomness.
   1190         randomShuffle(&nodes);
   1191         // An inefficient algorithm that does freeze -> re-evaluate for every FREE RandomVariable.
   1192         // TODO: Might be able to optimize this.
   1193         for (const auto& var : nodes) {
   1194             size_t size = var->range.size();
   1195             NN_FUZZER_LOG << "Freeze " << toString(var);
   1196             var->freeze();
   1197             NN_FUZZER_LOG << "  " << toString(var);
   1198             // There is no need to re-evaluate if the FREE RandomVariable have only one choice.
   1199             if (size > 1) {
   1200                 var->updateTimestamp();
   1201                 if (!evalRange()) {
   1202                     NN_FUZZER_LOG << "Freeze failed at " << toString(var);
   1203                     return false;
   1204                 }
   1205             }
   1206         }
   1207     }
   1208     NN_FUZZER_LOG << "Finish freezing the random network";
   1209     return true;
   1210 }
   1211 
   1212 }  // namespace fuzzing_test
   1213 }  // namespace nn
   1214 }  // namespace android
   1215