Home | History | Annotate | Download | only in test
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "ExecutionPlan.h"
     18 #include "HalInterfaces.h"
     19 #include "Manager.h"
     20 #include "ModelBuilder.h"
     21 #include "NeuralNetworks.h"
     22 #include "NeuralNetworksWrapper.h"
     23 #include "Utils.h"
     24 
     25 #include <gtest/gtest.h>
     26 
     27 #include <map>
     28 #include <queue>
     29 
     30 // Uncomment the following line to generate some debugging output that
     31 // may be useful when analyzing failures:
     32 //
     33 // #define VERBOSE VERBOSE
     34 
     35 // These tests do whitebox testing of the graph partitioning
     36 // algorithm.  It is "whitebox" in the sense that we're not evaluating
     37 // whether a particular partitioning is legal, or "good enough"
     38 // according to some metric, but whether it exactly matches the
     39 // expected behavior of the current partitioning algorithm.
     40 //
     41 // A key part of the current partitioning algorithm is to determine
     42 // which device among the available devices should be the one to
     43 // execute a particular operation from the graph.  This determination
     44 // is made "locally" -- i.e., it does not depend on the graph
     45 // topology, only on the properties of the operation in question.
     46 // IDevice::getSupportedOperations() indicates which operations in a
     47 // graph can be executed on a device, and IDevice::getCapabilities()
     48 // indicates how "good" that device is for executing particular kinds
     49 // of operations.  For each operation, the partitioning algorithm
     50 // picks the "best" device that is capable of executing that
     51 // operation; if no device can do so, then the algorithm picks the
     52 // cpu.
     53 //
     54 // As part of this testing approach, we want to make it easy to
     55 // specify which operations in a test graph can be executed on which
     56 // devices.  We accomplish this with an abstraction: There are eight
     57 // different kinds of operations (each of which has two inputs and one
     58 // output), and when we instantiate a device for testing purposes, we
     59 // specify what subset of those eight kinds of operations the device
     60 // is able to execute.
     61 //
     62 // The eight kinds of operations are represented in the graph as ADD
     63 // or MUL with a particular activation function -- two opcodes times
     64 // four activation functions means eight available operation kinds.
     65 // This is a low-level representation detail -- when we specify the
     66 // behavior of the device or build a graph, we do so in terms of
     67 // operation encodings 0..7.
     68 //
     69 // In order to determine whether or not a partitioning matches the
     70 // expected partitioning, we check the number of partitions, check
     71 // which device each partition targets, and compare each partition's
     72 // subgraph, model inputs, model outputs, submodel inputs, and
     73 // submodel outputs against what is expected.  In order to perform
     74 // that comparison, we build a model to compare against a partition's
     75 // submodel and run a graph comparison algorithm on it.  The graph
     76 // comparison and the inputs and outputs comparisons are syntactic
     77 // rather than semantic comparisons -- they don't allow for
     78 // reorderings of inputs and outputs.  Because of this, we need to
     79 // know exactly how the partitioning algorithm orders inputs and
     80 // outputs in order to construct the models and operand lists to
     81 // compare against.  Here are some relevant behaviors of the
     82 // partitioning algorithm:
     83 //
     84 // - It builds a subgraph by walking operations in forward topological
     85 //   order, and adding each operation's input operands and output
     86 //   operands in index order (input followed by output) when that
     87 //   operation is added.  (It does not add an input that has already
     88 //   been added.)
     89 // - It finds model inputs, model outputs, and submodel inputs in
     90 //   the order the corresponding operands were added to the subgraph
     91 //   (see ExecutionStep methods getModelInputs(), getModelOutputs(),
     92 //   getSubModelInputs()).
     93 // - It finds submodel outputs in numerical order of corresponding
     94 //   operand number in the original model (see ExecutionStep method
     95 //   getSubModelOutputs()).
     96 // - When it calls identifyInputsAndOutputs() on the submodel, it
     97 //   passes inputs from getModelInputs() in order followed by submodel
     98 //   inputs from getSubModelInputs() in order; and it passes outputs
     99 //   from getModelOutputs() in order followed by submodel outputs from
    100 //   getSubModelOutputs() in order.
    101 //
    102 // TODO: Maybe the logic for comparing a partition to an expected
    103 //       model should be changed to tolerate reorderings of inputs and
    104 //       outputs, so that when we build models and lists to compare
    105 //       against, we don't need to worry about input and output
    106 //       orderings.  But is there a way to do this that still lets us
    107 //       verify that we have the correct relationships between
    108 //       an (original) model's inputs and outputs and each submodel's
    109 //       inputs and outputs, as well as the correct relationship
    110 //       between submodel inputs and outputs across partitions?
    111 
    112 namespace {
    113 
    114 using Device = ::android::nn::Device;
    115 using ExecutePreference = ::android::nn::wrapper::ExecutePreference;
    116 using ExecutionPlan = ::android::nn::ExecutionPlan;
    117 using ExecutionStep = ::android::nn::ExecutionStep;
    118 using HidlModel = ::android::hardware::neuralnetworks::V1_0::Model;
    119 using ModelBuilder = ::android::nn::ModelBuilder;
    120 using WrapperModel = ::android::nn::wrapper::Model;
    121 using WrapperOperandType = ::android::nn::wrapper::OperandType;
    122 using WrapperType = ::android::nn::wrapper::Type;
    123 
    124 template <typename T> using sp = ::android::sp<T>;
    125 
    126 // We employ an operation numbering scheme:
    127 // - 0..FuseCode-1 = ADD with the appropriate activation function
    128 // - FuseCode..2*FuseCode-1 = MUL with the appropriate activation function
    129 const uint32_t kNumFuseCodes = 4;
    130 const uint32_t kBadOperation = ~0;
    131 
    132 // Look up the operation with the specified index in a graph, and
    133 // return the operation encoding -- 0..7; or, if for some reason this
    134 // is not one of the encoded operations, then return kBadOperation.
    135 uint32_t lookupOperation(std::function<const Operation&(uint32_t)> getOperation,
    136                          std::function<const Operand&(uint32_t)> getOperand,
    137                          std::function<const uint8_t*(uint32_t)> getValue,
    138                          uint32_t operationIndex) {
    139     const Operation& operation = getOperation(operationIndex);
    140     switch (operation.type) {
    141         case OperationType::ADD:
    142         case OperationType::MUL: {
    143             // input2 is the fused activation function
    144             const Operand& input2 = getOperand(operation.inputs[2]);
    145             if ((input2.type == OperandType::INT32) &&
    146                 (input2.lifetime == OperandLifeTime::CONSTANT_COPY)) {
    147                 int32_t value;
    148                 memcpy(&value,
    149                        getValue(input2.location.offset),
    150                        input2.location.length);
    151                 if (operation.type == OperationType::MUL) {
    152                     value += kNumFuseCodes;
    153                 }
    154                 return value;
    155             }
    156             break;
    157         }
    158         default:
    159             break;
    160     }
    161     return kBadOperation;
    162 }
    163 
    164 uint32_t lookupOperation(const HidlModel& model, uint32_t operationIndex) {
    165     return lookupOperation(
    166         [&model](uint32_t index) -> const Operation& {
    167             return model.operations[index];
    168         },
    169         [&model](uint32_t index) -> const Operand& {
    170             return model.operands[index];
    171         },
    172         [&model](uint32_t offset) {return &model.operandValues[offset];},
    173         operationIndex);
    174 }
    175 
#ifdef VERBOSE
// This is a debugging utility function
// Prints the HIDL form of the given model, its input/output operand
// indexes, and each operation, to standard output.  Compiled only
// when VERBOSE is defined (see the comment near the top of the file).
// NOTE(review): relies on std::cout, but <iostream> is not included
// here -- presumably provided transitively; confirm before enabling
// VERBOSE.
void dump(const char* name, const ModelBuilder* model) {
    HidlModel hidlModel;
    model->setHidlModel(&hidlModel);
    std::cout << name << ": " << toString(hidlModel) << std::endl;
    std::cout << "inputs: " << toString(hidlModel.inputIndexes) << std::endl;
    std::cout << "outputs: " << toString(hidlModel.outputIndexes) << std::endl;
    for (size_t i = 0, e = hidlModel.operations.size(); i < e; i++) {
        std::cout << "operation[" << i << "]: " << toString(hidlModel.operations[i]) << std::endl;
    }
}
#endif
    189 
    190 // This is an IDevice for testing purposes.  It only has two
    191 // interesting properties, both of which are specified as constructor
    192 // arguments: device capabilities, and which subset of operation kinds
    193 // (0..7) does the device support.  The subset is represented with a
    194 // bitmask, in which operation kind K corresponds to the bit (1 << K).
    195 class PartitioningIDevice : public IDevice {
    196 public:
    197     PartitioningIDevice(Capabilities capabilities, uint32_t operationMask) :
    198             mCapabilities(capabilities), mOperationMask(operationMask) {}
    199     ~PartitioningIDevice() override {}
    200 
    201     Return<ErrorStatus> prepareModel(const HidlModel&,
    202                                      const sp<IPreparedModelCallback>& cb) override {
    203         cb->notify(ErrorStatus::NONE, nullptr);
    204         return ErrorStatus::NONE;
    205     }
    206     Return<DeviceStatus> getStatus() override {
    207         return DeviceStatus::AVAILABLE;
    208     }
    209 
    210     Return<void> getCapabilities(getCapabilities_cb cb) override {
    211         cb(ErrorStatus::NONE, mCapabilities);
    212         return Void();
    213     }
    214     Return<void> getSupportedOperations(const HidlModel& model,
    215                                         getSupportedOperations_cb cb) override {
    216         if (!android::nn::validateModel(model)) {
    217             cb(ErrorStatus::INVALID_ARGUMENT, std::vector<bool>());
    218             return Void();
    219         }
    220 
    221         const size_t count = model.operations.size();
    222         std::vector<bool> supported(count);
    223         for (size_t i = 0; i < count; i++) {
    224             supported[i] = false;
    225             uint32_t operation = lookupOperation(model, i);
    226             if ((operation != kBadOperation) && (mOperationMask & (1 << operation))) {
    227                 supported[i] = true;
    228             }
    229         }
    230         cb(ErrorStatus::NONE, supported);
    231         return Void();
    232     }
    233 private:
    234     Capabilities mCapabilities;
    235     uint32_t mOperationMask;
    236 };
    237 
    238 // This class adds some simple abstractions and utilities on top of
    239 // ::android::nn::wrapper::Model.  For example, it provides methods
    240 // that work in terms of operation kind (0..7); and because we care
    241 // about graph topology rather than details of operand types and
    242 // values, it greatly simplifies the process of creating operands.
    243 class PartitioningModel : public WrapperModel {
    244 public:
    245     // Create a tensor operand of the specified type, and return the
    246     // corresponding operand index.
    247     uint32_t addFloatOperand() {
    248         static const WrapperOperandType type(WrapperType::TENSOR_FLOAT32, { 1 });
    249         return addOperand(&type);
    250     }
    251     uint32_t addQuantOperand() {
    252         static const WrapperOperandType type(WrapperType::TENSOR_QUANT8_ASYMM, { 1 });
    253         return addOperand(&type);
    254     }
    255 
    256     // Create an operation with two inputs and one output, specifying
    257     // the operation kind (0..7) and the input operand indexes.
    258     // Returns the output operand index.
    259     uint32_t addOperation2To1(uint32_t operation, const uint32_t input0, const uint32_t input1) {
    260         ANeuralNetworksOperationType type =
    261                 (operation < kNumFuseCodes ? ANEURALNETWORKS_ADD : ANEURALNETWORKS_MUL);
    262         int32_t fuseCode = (operation < kNumFuseCodes ? operation : operation - kNumFuseCodes);
    263         uint32_t input2 = addIntOperand(fuseCode);
    264         uint32_t output = addOperandOfSameType(input0);
    265         addOperation(type, { input0, input1, input2 }, { output });
    266         return output;
    267     }
    268 
    269     // Run the partitioning algorithm to create an ExecutionPlan.
    270     int partitionTheWork(const std::vector<std::shared_ptr<Device>>& devices,
    271                          ExecutePreference preference, ExecutionPlan* plan) {
    272         return reinterpret_cast<ModelBuilder*>(getHandle())->partitionTheWork(
    273             devices, static_cast<uint32_t>(preference), plan);
    274     }
    275 
    276 #ifdef VERBOSE
    277     // This is a debugging utility function.
    278     void dump(const char* name) const {
    279         const ModelBuilder* mb = reinterpret_cast<const ModelBuilder*>(getHandle());
    280         ::dump(name, mb);
    281     }
    282 #endif
    283 
    284 private:
    285 
    286     // Create a scalar integer operand of the specified value, and
    287     // return the corresponding operand index.
    288     uint32_t addIntOperand(int32_t value) {
    289         static const WrapperOperandType type(WrapperType::INT32, { });
    290         uint32_t operand = addOperand(&type);
    291         setOperandValue(operand, &value, sizeof(value));
    292         return operand;
    293     }
    294 
    295     // Create an operand of the same type as the specified operand,
    296     // and return the operand index of the new operand.
    297     uint32_t addOperandOfSameType(uint32_t operand) {
    298         const Operand& operandStruct =
    299                 reinterpret_cast<const ModelBuilder*>(getHandle())->getOperand(operand);
    300         WrapperOperandType type(static_cast<WrapperType>(operandStruct.type), { 1 });
    301         return addOperand(&type);
    302     }
    303 };
    304 
// RETURN_TRUE() / RETURN_FALSE(MESSAGE) are used by the graph
// comparison code below.  In VERBOSE builds they log the __LINE__ of
// the return site so a failing comparison can be traced; otherwise
// they are plain returns.
#ifdef VERBOSE
#define RETURN_TRUE()                                                          \
    {                                                                          \
        std::cerr << "returning true from " << __LINE__ << std::endl;          \
        return true;                                                           \
    }
#else
#define RETURN_TRUE()                                                          \
    {                                                                          \
        return true;                                                           \
    }
#endif
// MESSAGE is pasted directly after __LINE__ in the stream expression,
// so when non-empty it must begin with "<<" (e.g.
// RETURN_FALSE(<< " why")); it is ignored in non-VERBOSE builds.
#ifdef VERBOSE
#define RETURN_FALSE(MESSAGE)                                                  \
    {                                                                          \
        std::cerr << "returning false from " << __LINE__ MESSAGE << std::endl; \
        return false;                                                          \
    }
#else
#define RETURN_FALSE(MESSAGE)                                                  \
    {                                                                          \
        return false;                                                          \
    }
#endif
    329 
class PartitioningTest : public ::testing::Test {
protected:
    // workaround for private types in ExecutionStep
    using RemapVectorType = decltype(static_cast<ExecutionStep*>(nullptr)->getModelInputs());
    using SubModelOutputSetType = decltype(static_cast<ExecutionStep*>(nullptr)->getSubModelOutputs());

    virtual void SetUp() {
        // Nothing to do.
    }

    // From a vector of triples (tuples), each of the form (name,
    // capabilities, bitmask of supported operation kinds), create a
    // vector of Devices.
    // NOTE(review): the raw `new PartitioningIDevice` is presumably
    // adopted by Device via an sp<IDevice> -- confirm ownership.
    static std::vector<std::shared_ptr<Device>>
    makeDevices(std::vector<std::tuple<std::string, Capabilities, uint32_t>> specifications) {
        std::vector<std::shared_ptr<Device>> devices;
        for (const auto& specification : specifications) {
            devices.push_back(std::make_shared<Device>(
                std::get<0>(specification),
                new PartitioningIDevice(std::get<1>(specification), std::get<2>(specification))));
            devices.back()->initialize();
        }
        return devices;
    }

    /*-- Graph comparison -----------------------------------------------------------------*/

    // An operand with certain values for its lifetime does not have a
    // defining operation in the graph.  For the purposes of the graph
    // comparison algorithm, we encode the "defining operation" index of
    // such an operand as follows:
    // - NO_VALUE       kPseudoDefiningOperationNoValue
    // - MODEL_INPUT    kPseudoDefiningOperationModelInput0 + (position in list of inputs)
    // - CONSTANT_COPY  kPseudoDefiningOperationConstantCopy0 + (constant value)
    //                    Note: For the graphs we build in this test, we
    //                          only expect to see 4-byte constants within
    //                          a very restricted range, so we only make
    //                          room for such constants in our encoding
    //                          space.
    // We do not expect to see CONSTANT_REFERENCE, and so we do not handle
    // it.
    //
    // The encoding is intended to be relatively human readable; it is not
    // designed to represent some optimal balance of ranges for the items
    // within its scope (actual operations, inputs, constants).

    enum PseudoDefiningOperationEncodings : uint32_t {
        kPseudoDefiningOperationModelInput0   = 0x80000000U,
        kPseudoDefiningOperationConstantCopy0 = 0x90000000U,
        kPseudoDefiningOperationNoValue       = 0xeeeeeeeeU,

        // lowest value for special encoding
        kPseudoDefiningOperationBase          = 0x80000000U,

        // range of encoded input or constant
        kPseudoDefiningOperationRange         = 0x10000000U,
    };

    // Build a map from operand to defining operation.
    // Raises a fatal gtest failure on malformed input; callers check
    // HasFatalFailure() afterwards.
    // TODO: Replace map with vector?
    void buildDefinitionMap(const ModelBuilder* model,
                            std::map<uint32_t, uint32_t>* defMap) {
        // actual definitions
        ASSERT_LT(model->operationCount(), kPseudoDefiningOperationBase);
        for (uint32_t i = 0, e = model->operationCount(); i < e; i++) {
            const Operation& operation = model->getOperation(i);
            for (uint32_t output : operation.outputs) {
                (*defMap)[output] = i;
            }
        }
        // inputs
        ASSERT_LT(model->inputCount(), kPseudoDefiningOperationRange);
        for (uint32_t i = 0, e = model->inputCount(); i < e; i++) {
            (*defMap)[model->getInputOperandIndex(i)] = kPseudoDefiningOperationModelInput0 + i;
        }
        // look for NO_VALUE and CONSTANT_COPY
        for (uint32_t i = 0, e = model->operandCount(); i < e; i++) {
            const Operand& operand = model->getOperand(i);
            switch (operand.lifetime) {
                case OperandLifeTime::NO_VALUE:
                    (*defMap)[i] = kPseudoDefiningOperationNoValue;
                    break;
                case OperandLifeTime::CONSTANT_COPY: {
                    ASSERT_EQ(operand.location.length, sizeof(uint32_t));
                    uint32_t value;
                    memcpy(&value, model->getPointerToOperandValue(operand.location.offset), sizeof(uint32_t));
                    // The constant must fit in the encoding space
                    // reserved for CONSTANT_COPY (see comment above).
                    ASSERT_LT(value, kPseudoDefiningOperationNoValue);
                    (*defMap)[i] = kPseudoDefiningOperationConstantCopy0 + value;
                    break;
                }
                case OperandLifeTime::TEMPORARY_VARIABLE:
                case OperandLifeTime::MODEL_INPUT:
                case OperandLifeTime::MODEL_OUTPUT:
                    // already handled
                    break;
                default:
                    FAIL();
                    break;
            }
        }
        // sanity check
        ASSERT_EQ(model->operandCount(), defMap->size());
    }

#ifdef VERBOSE
    // Debugging utility: print a map of operand (or operation)
    // indexes, showing pseudo-definition encodings in hex.
    void dump(const char* name, const std::map<uint32_t, uint32_t>* aMap) {
        auto writeNum = [](uint32_t num) {
            if (num >= kPseudoDefiningOperationBase) {
                std::cout << "0x" << std::hex << num << std::dec;
            } else {
                std::cout << num;
            }
        };

        std::cout << name << ": { ";
        bool gotOne = false;
        for (const auto& entry : *aMap) {
            if (gotOne) {
                std::cout << ", ";
            } else {
                gotOne = true;
            }
            std::cout << "(";
            writeNum(entry.first);
            std::cout << ", ";
            writeNum(entry.second);
            std::cout << ")";
        }
        std::cout << " }" << std::endl;
    }
#endif

    // Structural comparison of two operands: type, dimensions,
    // consumer count, scale, and zeroPoint.  Lifetime and location
    // are not compared.
    bool compare(const Operand& operandA, const Operand& operandB) {
        if (operandA.type != operandB.type ||
            operandA.dimensions != operandB.dimensions ||
            operandA.numberOfConsumers != operandB.numberOfConsumers ||
            operandA.scale != operandB.scale ||
            operandA.zeroPoint != operandB.zeroPoint) {
            return false;
        }
        return true;
    }

    // Compare two graphs.  We ignore operand and operation indexes (i.e.,
    // two nodes can be the same even if they are numbered differently)
    // but we also ignore semantics (e.g., even if an operation kind is
    // such that the operand is commutative, we still pay attention to the
    // order of its input operands).
    //
    // The comparison algorithm works by walking modelA from outputs
    // towards inputs, along the edge from each operand to its
    // defining operation, and then along the edges to the operation's
    // input operands.  At each step along the way, we try to match up
    // operands and operations from modelA with equivalent operands
    // and operations from modelB.
    //
    // We start by assuming that modelA's outputs and modelB's outputs
    // match positionally (e.g., modelA's first output operand is
    // equivalent to modelB's first output operand).  Once we've
    // discovered two equivalent operands (such as those outputs), we
    // place them in a work queue.  We repeatedly pull operands off
    // the queue and compare their defining operations and those
    // operations' input operands, to discover more pairs of
    // equivalent operands.  If we ever find operations that do not
    // match (e.g., because operation kind differs), or operands that
    // do not match (e.g., because operand type differs); or if we
    // ever find a conflict (we've already decided that operand A's
    // equivalent operand is B0, but it looks like we need its
    // equivalent operand to be B1); then the graphs compare unequal.
    // Otherwise, we'll eventually exhaust the work queue, and
    // conclude that the graphs compare equal.
    bool compare(const ModelBuilder* modelA, const ModelBuilder* modelB) {
#ifdef VERBOSE
        ::dump("compare(A)", modelA);
        ::dump("compare(B)", modelB);
#endif

        // Cheap rejection: equivalent graphs must agree on all counts.
        if (modelA->operandCount()   != modelB->operandCount()   ||
            modelA->operationCount() != modelB->operationCount() ||
            modelA->inputCount()     != modelB->inputCount()     ||
            modelA->outputCount()    != modelB->outputCount()) {
            RETURN_FALSE();
        }

        // Maps from operand index to index of defining operation.
        std::map<uint32_t, uint32_t> defsA, defsB;
        buildDefinitionMap(modelA, &defsA);
        buildDefinitionMap(modelB, &defsB);
        if (HasFatalFailure()) return false;

        // Maps from operand index in modelA to equivalent operand index
        // in modelB; and from operation index in modelA to equivalent
        // operation index in modelB.
        std::map<uint32_t, uint32_t> equivalentOperandsAToB;
        std::map<uint32_t, uint32_t> equivalentOperationsAToB;

        // Queue of operand indexes from modelA, each of whose defining
        // operations are to be checked for equivalence with modelB.
        std::queue<uint32_t> workQueueOperandsA;

        // Seed operand equivalence map and work queue from model outputs.
        for (uint32_t i = 0, e = modelA->outputCount(); i < e; i++) {
            uint32_t outputA = modelA->getOutputOperandIndex(i);
            uint32_t outputB = modelB->getOutputOperandIndex(i);
            if (!compare(modelA->getOperand(outputA), modelB->getOperand(outputB))) {
                RETURN_FALSE();
            }
            equivalentOperandsAToB[outputA] = outputB;
            workQueueOperandsA.push(outputA);
        }

#ifdef VERBOSE
        dump("defsA", &defsA);
        dump("defsB", &defsB);
#endif

        // Process the queue.
        uint32_t pseudoDefinitionCount = 0;
        while (!workQueueOperandsA.empty()) {
#ifdef VERBOSE
            dump("equivalentOperandsAToB", &equivalentOperandsAToB);
            dump("equivalentOperationsAToB", &equivalentOperationsAToB);
#endif
            uint32_t operandIndexA = workQueueOperandsA.front();
#ifdef VERBOSE
            std::cout << "operandIndexA: " << operandIndexA << std::endl;
#endif
            workQueueOperandsA.pop();
            uint32_t operandIndexB = equivalentOperandsAToB.at(operandIndexA);

            uint32_t operationIndexA = defsA.at(operandIndexA);
            uint32_t operationIndexB = defsB.at(operandIndexB);
            auto it = equivalentOperationsAToB.find(operationIndexA);
            if (it != equivalentOperationsAToB.end()) {
                // Already matched; just check for a conflict.
                if (it->second != operationIndexB) {
                    RETURN_FALSE();
                }
                continue;
            }

            // We haven't identified an equivalent operation for
            // operationIndexA.

            if ((operationIndexA >= kPseudoDefiningOperationBase) !=
                (operationIndexB >= kPseudoDefiningOperationBase)) {
                RETURN_FALSE();
            }
            // Either both operands have pseudo-definitions, or neither
            // does.
            if (operationIndexA >= kPseudoDefiningOperationBase) {
                // Both operands have pseudo-definitions.
                if (operationIndexA != operationIndexB) {
                    RETURN_FALSE();
                }
                equivalentOperationsAToB[operationIndexA] = operationIndexB;
                ++pseudoDefinitionCount;
                continue;
            }

            // If we get here, neither operation A nor operation B is a
            // pseudo-definition.

            const Operation& operationA = modelA->getOperation(operationIndexA);
            const Operation& operationB = modelB->getOperation(operationIndexB);
            if (operationA.type != operationB.type ||
                operationA.inputs.size() != operationB.inputs.size() ||
                operationA.outputs.size() != operationB.outputs.size()) {
                RETURN_FALSE();
            }
            equivalentOperationsAToB[operationIndexA] = operationIndexB;
            for (uint32_t i = 0, e = operationA.inputs.size(); i < e; i++) {
                uint32_t inputA = operationA.inputs[i];
                uint32_t inputB = operationB.inputs[i];
                auto it = equivalentOperandsAToB.find(inputA);
                if (it != equivalentOperandsAToB.end()) {
                    if (it->second != inputB) {
                        RETURN_FALSE();
                    }
                    continue;
                }
                // We haven't identified an equivalent operand for inputA.
                if (!compare(modelA->getOperand(inputA), modelB->getOperand(inputB))) {
                    RETURN_FALSE();
                }
                equivalentOperandsAToB[inputA] = inputB;
                workQueueOperandsA.push(inputA);
            }
        }

        // Sanity check
        if (modelA->operandCount() != defsA.size() ||
            modelA->operandCount() != defsB.size() ||
            modelA->operandCount() != equivalentOperandsAToB.size() ||
            modelA->operationCount() + pseudoDefinitionCount != equivalentOperationsAToB.size()) {
            RETURN_FALSE();
        }

        RETURN_TRUE();
    }

    /*-------------------------------------------------------------------------------------*/

    // Returns true iff the execution step targets the given device and
    // the step's submodel is graph-equivalent to the given model.
    bool compare(std::shared_ptr<const ExecutionStep> step,
                 const WrapperModel* model, std::shared_ptr<Device> device) {
        return (step->getDevice() == device) &&
                compare(step->getSubModel().get(),
                        reinterpret_cast<const ModelBuilder*>(model->getHandle()));
    }
};
    638 
TEST_F(PartitioningTest, SimpleModel) {
    // Build a two-operation chain:
    //   opnd2 = op0(opnd0, opnd1)
    //   opnd4 = op1(opnd2, opnd3)
    PartitioningModel model;
    uint32_t opnd0 = model.addFloatOperand();
    uint32_t opnd1 = model.addFloatOperand();
    uint32_t opnd2 = model.addOperation2To1(0, opnd0, opnd1);
    uint32_t opnd3 = model.addFloatOperand();
    uint32_t opnd4 = model.addOperation2To1(1, opnd2, opnd3);
    model.identifyInputsAndOutputs({ opnd0, opnd1, opnd3 }, { opnd4 });
    model.finish();
    ASSERT_TRUE(model.isValid());

    // Simple partition (two devices are each capable of everything, one is
    // the best).  The whole model should go to the single best device,
    // producing a SIMPLE plan.  The last tuple element (~0) is the mask of
    // operation types the device supports: all of them.
    const auto devicesA = makeDevices(
        {
            {"bad", { .float32Performance = { .execTime = 1.5, .powerUsage = 1.5 },
                            .quantized8Performance = { .execTime = 1.5, .powerUsage = 1.5 } }, ~0},
            {"good", { .float32Performance = { .execTime = 0.5, .powerUsage = 0.5 },
                            .quantized8Performance = { .execTime = 0.5, .powerUsage = 0.5 } }, ~0}
        });
    ExecutionPlan planA;
    ASSERT_EQ(model.partitionTheWork(devicesA, ExecutePreference::PREFER_LOW_POWER, &planA),
              ANEURALNETWORKS_NO_ERROR);
    ASSERT_EQ(planA.forTest_getKind(), ExecutionPlan::Kind::SIMPLE);
    ASSERT_EQ(planA.forTest_simpleGetDevice()->getName(), "good");

    // Compound partition (two devices, each is capable of one of the
    // two operations).  We could do more extensive checking here --
    // for example, verify that each step within the plan has the
    // correct (model and submodel)x(inputs and outputs).
    // Device "0" supports only operation type 0; device "1" only type 1.
    const auto devicesB = makeDevices(
        {
            {"0", { .float32Performance = { .execTime = 1.5, .powerUsage = 1.5 },
                            .quantized8Performance = { .execTime = 1.5, .powerUsage = 1.5 } }, 1<<0},
            {"1", { .float32Performance = { .execTime = 0.5, .powerUsage = 0.5 },
                            .quantized8Performance = { .execTime = 0.5, .powerUsage = 0.5 } }, 1<<1}
        });
    ExecutionPlan planB;
    ASSERT_EQ(model.partitionTheWork(devicesB, ExecutePreference::PREFER_LOW_POWER, &planB),
              ANEURALNETWORKS_NO_ERROR);
    ASSERT_EQ(planB.forTest_getKind(), ExecutionPlan::Kind::COMPOUND);
    const auto& stepsB = planB.forTest_compoundGetSteps();
    ASSERT_EQ(stepsB.size(), size_t(2));
    {
        // Build a model to compare against the submodel from stepsB[0].
        // stepsB[0] should contain just the type-0 operation, running on
        // device "0", with opnd2 produced as a submodel output for the
        // next step to consume.
        PartitioningModel modelB0;
        uint32_t b0Opnd0 = modelB0.addFloatOperand();
        uint32_t b0Opnd1 = modelB0.addFloatOperand();
        uint32_t b0Opnd2 = modelB0.addOperation2To1(0, b0Opnd0, b0Opnd1);
        modelB0.identifyInputsAndOutputs({ b0Opnd0, b0Opnd1 }, { b0Opnd2 });
        modelB0.finish();
        ASSERT_TRUE(modelB0.isValid());
        ASSERT_NO_FATAL_FAILURE(ASSERT_TRUE(compare(stepsB[0], &modelB0, devicesB[0])));
        ASSERT_EQ(stepsB[0]->getModelInputs(),
                  (RemapVectorType{ { opnd0, b0Opnd0 }, { opnd1, b0Opnd1 } }));
        ASSERT_EQ(stepsB[0]->getModelOutputs(),
                  (RemapVectorType{}));
        ASSERT_EQ(stepsB[0]->getSubModelInputs(),
                  (RemapVectorType{}));
        ASSERT_EQ(stepsB[0]->getSubModelOutputs(),
                  (SubModelOutputSetType{ { opnd2, b0Opnd2 } }));
    }
    {
        // Build a model to compare against the submodel from stepsB[1].
        // stepsB[1] should contain just the type-1 operation, running on
        // device "1", consuming opnd2 from the previous step.
        PartitioningModel modelB1;
        uint32_t b1Opnd2 = modelB1.addFloatOperand();
        uint32_t b1Opnd3 = modelB1.addFloatOperand();
        uint32_t b1Opnd4 = modelB1.addOperation2To1(1, b1Opnd2, b1Opnd3);
        // Note: In the partitioning algorithm, submodel inputs follow
        // model inputs.  In the original model "model", opnd2 is not
        // an input; so in the submodel "modelB1", the corresponding
        // input b1Opnd2 is a submodel input, and must follow the
        // model input b1Opnd3.
        modelB1.identifyInputsAndOutputs({ b1Opnd3, b1Opnd2 }, { b1Opnd4 });
        modelB1.finish();
        ASSERT_TRUE(modelB1.isValid());
        ASSERT_NO_FATAL_FAILURE(ASSERT_TRUE(compare(stepsB[1], &modelB1, devicesB[1])));
        ASSERT_EQ(stepsB[1]->getModelInputs(),
                  (RemapVectorType{ { opnd3, b1Opnd3 } }));
        ASSERT_EQ(stepsB[1]->getModelOutputs(),
                  (RemapVectorType{ { opnd4, b1Opnd4 } }));
        ASSERT_EQ(stepsB[1]->getSubModelInputs(),
                  (RemapVectorType{ { opnd2, b1Opnd2 } }));
        ASSERT_EQ(stepsB[1]->getSubModelOutputs(),
                  (SubModelOutputSetType{}));
    }
}
    725 
TEST_F(PartitioningTest, Cpu) {
    // Here's a model where some operations execute only on the Cpu.
    // To make things interesting, we produce three partitions --
    // device, cpu, same-device.

    static const uint32_t kCpuOp = 1;
    static const uint32_t kDevOp = 2;

    // A single accelerator that supports only kDevOp; kCpuOp operations
    // must therefore fall back to the CPU (device == nullptr in a step).
    const auto devices = makeDevices(
        {
            {"1", { .float32Performance = { .execTime = 0.5, .powerUsage = 0.5 },
                    .quantized8Performance = { .execTime = 0.5, .powerUsage = 0.5 } }, 1<<kDevOp}
        });

    PartitioningModel model;

    uint32_t opnd0 = model.addFloatOperand();
    uint32_t opnd1 = model.addFloatOperand();

    // Partition 1 (device): two kDevOp operations.
    uint32_t opnd2 = model.addOperation2To1(kDevOp, opnd0, opnd1);
    uint32_t opnd3 = model.addOperation2To1(kDevOp, opnd0, opnd2);

    // Partition 2 (cpu): two kCpuOp operations consuming partition 1's
    // outputs; opnd4 is also a model output.
    uint32_t opnd4 = model.addOperation2To1(kCpuOp, opnd0, opnd3);
    uint32_t opnd5 = model.addOperation2To1(kCpuOp, opnd2, opnd4);

    uint32_t opnd6 = model.addFloatOperand();

    // Partition 3 (same device again): two more kDevOp operations that
    // depend on the cpu partition, forcing a third step.
    uint32_t opnd7 = model.addOperation2To1(kDevOp, opnd3, opnd5);
    uint32_t opnd8 = model.addOperation2To1(kDevOp, opnd6, opnd7);

    model.identifyInputsAndOutputs({ opnd0, opnd1, opnd6 }, { opnd4, opnd8 });
    model.finish();
    ASSERT_TRUE(model.isValid());

    ExecutionPlan plan;
    ASSERT_EQ(model.partitionTheWork(devices, ExecutePreference::PREFER_LOW_POWER, &plan),
              ANEURALNETWORKS_NO_ERROR);
    ASSERT_EQ(plan.forTest_getKind(), ExecutionPlan::Kind::COMPOUND);
    const auto& steps = plan.forTest_compoundGetSteps();
    ASSERT_EQ(steps.size(), size_t(3));
    {
        const auto& step0 = steps[0];

        // Build a model to compare against the submodel from steps[0].
        // Expected: the two kDevOp operations producing opnd2/opnd3, both
        // of which become submodel outputs for later steps.
        PartitioningModel model0;
        uint32_t m0Opnd0 = model0.addFloatOperand();
        uint32_t m0Opnd1 = model0.addFloatOperand();
        uint32_t m0Opnd2 = model0.addOperation2To1(kDevOp, m0Opnd0, m0Opnd1);
        uint32_t m0Opnd3 = model0.addOperation2To1(kDevOp, m0Opnd0, m0Opnd2);
        model0.identifyInputsAndOutputs({ m0Opnd0, m0Opnd1 }, { m0Opnd2, m0Opnd3 });
        model0.finish();
        ASSERT_TRUE(model0.isValid());
        ASSERT_NO_FATAL_FAILURE(ASSERT_TRUE(compare(step0, &model0, devices[0])));
        ASSERT_EQ(step0->getModelInputs(),
                  (RemapVectorType{ { opnd0, m0Opnd0 }, { opnd1, m0Opnd1 } }));
        ASSERT_EQ(step0->getModelOutputs(),
                  (RemapVectorType{}));
        ASSERT_EQ(step0->getSubModelInputs(),
                  (RemapVectorType{}));
        ASSERT_EQ(step0->getSubModelOutputs(),
                  (SubModelOutputSetType{ { opnd2, m0Opnd2 }, { opnd3, m0Opnd3 } }));
    }
    {
        const auto& step1 = steps[1];

        // Build a model to compare against the submodel from steps[1].
        // Expected: the two kCpuOp operations, run on the CPU fallback
        // (device == nullptr); opnd4 is a model output, opnd5 a submodel
        // output feeding step 2.
        PartitioningModel model1;
        uint32_t m1Opnd0 = model1.addFloatOperand();
        uint32_t m1Opnd3 = model1.addFloatOperand();
        uint32_t m1Opnd4 = model1.addOperation2To1(kCpuOp, m1Opnd0, m1Opnd3);
        uint32_t m1Opnd2 = model1.addFloatOperand();
        uint32_t m1Opnd5 = model1.addOperation2To1(kCpuOp, m1Opnd2, m1Opnd4);
        // Submodel inputs (opnd3, opnd2) follow the model input opnd0.
        model1.identifyInputsAndOutputs({ m1Opnd0, m1Opnd3, m1Opnd2 }, { m1Opnd4, m1Opnd5 });
        model1.finish();
        ASSERT_TRUE(model1.isValid());
        ASSERT_NO_FATAL_FAILURE(ASSERT_TRUE(compare(step1, &model1, nullptr)));
        ASSERT_EQ(step1->getModelInputs(),
                  (RemapVectorType{ { opnd0, m1Opnd0 } }));
        ASSERT_EQ(step1->getModelOutputs(),
                  (RemapVectorType{ { opnd4, m1Opnd4 } }));
        ASSERT_EQ(step1->getSubModelInputs(),
                  (RemapVectorType{ { opnd3, m1Opnd3 }, { opnd2, m1Opnd2 } }));
        ASSERT_EQ(step1->getSubModelOutputs(),
                  (SubModelOutputSetType{ { opnd5, m1Opnd5 } }));
    }
    {
        const auto& step2 = steps[2];

        // Build a model to compare against the submodel from steps[2].
        // Expected: the final two kDevOp operations, back on the device,
        // producing the model output opnd8.
        PartitioningModel model2;
        uint32_t m2Opnd3 = model2.addFloatOperand();
        uint32_t m2Opnd5 = model2.addFloatOperand();
        uint32_t m2Opnd7 = model2.addOperation2To1(kDevOp, m2Opnd3, m2Opnd5);
        uint32_t m2Opnd6 = model2.addFloatOperand();
        uint32_t m2Opnd8 = model2.addOperation2To1(kDevOp, m2Opnd6, m2Opnd7);
        // Submodel inputs (opnd3, opnd5) follow the model input opnd6.
        model2.identifyInputsAndOutputs({ m2Opnd6, m2Opnd3, m2Opnd5 }, { m2Opnd8 });
        model2.finish();
        ASSERT_TRUE(model2.isValid());
        ASSERT_NO_FATAL_FAILURE(ASSERT_TRUE(compare(step2, &model2, devices[0])));
        ASSERT_EQ(step2->getModelInputs(),
                  (RemapVectorType{ { opnd6, m2Opnd6 } }));
        ASSERT_EQ(step2->getModelOutputs(),
                  (RemapVectorType{ { opnd8, m2Opnd8 } }));
        ASSERT_EQ(step2->getSubModelInputs(),
                  (RemapVectorType{ { opnd3, m2Opnd3 }, { opnd5, m2Opnd5 } }));
        ASSERT_EQ(step2->getSubModelOutputs(),
                  (SubModelOutputSetType{}));
    }
}
    835 
    836 }  // namespace
    837