Home | History | Annotate | Download | only in functional
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #define LOG_TAG "neuralnetworks_hidl_hal_test"
     18 
     19 #include "VtsHalNeuralnetworks.h"
     20 
     21 #include "Callbacks.h"
     22 
     23 namespace android {
     24 namespace hardware {
     25 namespace neuralnetworks {
     26 namespace V1_0 {
     27 namespace vts {
     28 namespace functional {
     29 
     30 using ::android::hardware::neuralnetworks::V1_2::implementation::ExecutionCallback;
     31 using ::android::hardware::neuralnetworks::V1_2::implementation::PreparedModelCallback;
     32 
     33 ///////////////////////// UTILITY FUNCTIONS /////////////////////////
     34 
     35 static void validateGetSupportedOperations(const sp<IDevice>& device, const std::string& message,
     36                                            const V1_0::Model& model) {
     37     SCOPED_TRACE(message + " [getSupportedOperations]");
     38 
     39     Return<void> ret =
     40         device->getSupportedOperations(model, [&](ErrorStatus status, const hidl_vec<bool>&) {
     41             EXPECT_EQ(ErrorStatus::INVALID_ARGUMENT, status);
     42         });
     43     EXPECT_TRUE(ret.isOk());
     44 }
     45 
     46 static void validatePrepareModel(const sp<IDevice>& device, const std::string& message,
     47                                  const V1_0::Model& model) {
     48     SCOPED_TRACE(message + " [prepareModel]");
     49 
     50     sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
     51     ASSERT_NE(nullptr, preparedModelCallback.get());
     52     Return<ErrorStatus> prepareLaunchStatus = device->prepareModel(model, preparedModelCallback);
     53     ASSERT_TRUE(prepareLaunchStatus.isOk());
     54     ASSERT_EQ(ErrorStatus::INVALID_ARGUMENT, static_cast<ErrorStatus>(prepareLaunchStatus));
     55 
     56     preparedModelCallback->wait();
     57     ErrorStatus prepareReturnStatus = preparedModelCallback->getStatus();
     58     ASSERT_EQ(ErrorStatus::INVALID_ARGUMENT, prepareReturnStatus);
     59     sp<IPreparedModel> preparedModel = preparedModelCallback->getPreparedModel();
     60     ASSERT_EQ(nullptr, preparedModel.get());
     61 }
     62 
     63 // Primary validation function. This function will take a valid model, apply a
     64 // mutation to it to invalidate the model, then pass it to interface calls that
     65 // use the model. Note that the model here is passed by value, and any mutation
     66 // to the model does not leave this function.
     67 static void validate(const sp<IDevice>& device, const std::string& message, V1_0::Model model,
     68                      const std::function<void(Model*)>& mutation) {
     69     mutation(&model);
     70     validateGetSupportedOperations(device, message, model);
     71     validatePrepareModel(device, message, model);
     72 }
     73 
     74 // Delete element from hidl_vec. hidl_vec doesn't support a "remove" operation,
     75 // so this is efficiently accomplished by moving the element to the end and
     76 // resizing the hidl_vec to one less.
     77 template <typename Type>
     78 static void hidl_vec_removeAt(hidl_vec<Type>* vec, uint32_t index) {
     79     if (vec) {
     80         std::rotate(vec->begin() + index, vec->begin() + index + 1, vec->end());
     81         vec->resize(vec->size() - 1);
     82     }
     83 }
     84 
     85 template <typename Type>
     86 static uint32_t hidl_vec_push_back(hidl_vec<Type>* vec, const Type& value) {
     87     // assume vec is valid
     88     const uint32_t index = vec->size();
     89     vec->resize(index + 1);
     90     (*vec)[index] = value;
     91     return index;
     92 }
     93 
     94 static uint32_t addOperand(Model* model) {
     95     return hidl_vec_push_back(&model->operands,
     96                               {
     97                                   .type = OperandType::INT32,
     98                                   .dimensions = {},
     99                                   .numberOfConsumers = 0,
    100                                   .scale = 0.0f,
    101                                   .zeroPoint = 0,
    102                                   .lifetime = OperandLifeTime::MODEL_INPUT,
    103                                   .location = {.poolIndex = 0, .offset = 0, .length = 0},
    104                               });
    105 }
    106 
    107 static uint32_t addOperand(Model* model, OperandLifeTime lifetime) {
    108     uint32_t index = addOperand(model);
    109     model->operands[index].numberOfConsumers = 1;
    110     model->operands[index].lifetime = lifetime;
    111     return index;
    112 }
    113 
    114 ///////////////////////// VALIDATE MODEL OPERAND TYPE /////////////////////////
    115 
    116 static const int32_t invalidOperandTypes[] = {
    117     static_cast<int32_t>(OperandType::FLOAT32) - 1,              // lower bound fundamental
    118     static_cast<int32_t>(OperandType::TENSOR_QUANT8_ASYMM) + 1,  // upper bound fundamental
    119     static_cast<int32_t>(OperandType::OEM) - 1,                  // lower bound OEM
    120     static_cast<int32_t>(OperandType::TENSOR_OEM_BYTE) + 1,      // upper bound OEM
    121 };
    122 
    123 static void mutateOperandTypeTest(const sp<IDevice>& device, const V1_0::Model& model) {
    124     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
    125         for (int32_t invalidOperandType : invalidOperandTypes) {
    126             const std::string message = "mutateOperandTypeTest: operand " +
    127                                         std::to_string(operand) + " set to value " +
    128                                         std::to_string(invalidOperandType);
    129             validate(device, message, model, [operand, invalidOperandType](Model* model) {
    130                 model->operands[operand].type = static_cast<OperandType>(invalidOperandType);
    131             });
    132         }
    133     }
    134 }
    135 
    136 ///////////////////////// VALIDATE OPERAND RANK /////////////////////////
    137 
    138 static uint32_t getInvalidRank(OperandType type) {
    139     switch (type) {
    140         case OperandType::FLOAT32:
    141         case OperandType::INT32:
    142         case OperandType::UINT32:
    143             return 1;
    144         case OperandType::TENSOR_FLOAT32:
    145         case OperandType::TENSOR_INT32:
    146         case OperandType::TENSOR_QUANT8_ASYMM:
    147             return 0;
    148         default:
    149             return 0;
    150     }
    151 }
    152 
    153 static void mutateOperandRankTest(const sp<IDevice>& device, const V1_0::Model& model) {
    154     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
    155         const uint32_t invalidRank = getInvalidRank(model.operands[operand].type);
    156         const std::string message = "mutateOperandRankTest: operand " + std::to_string(operand) +
    157                                     " has rank of " + std::to_string(invalidRank);
    158         validate(device, message, model, [operand, invalidRank](Model* model) {
    159             model->operands[operand].dimensions = std::vector<uint32_t>(invalidRank, 0);
    160         });
    161     }
    162 }
    163 
    164 ///////////////////////// VALIDATE OPERAND SCALE /////////////////////////
    165 
    166 static float getInvalidScale(OperandType type) {
    167     switch (type) {
    168         case OperandType::FLOAT32:
    169         case OperandType::INT32:
    170         case OperandType::UINT32:
    171         case OperandType::TENSOR_FLOAT32:
    172             return 1.0f;
    173         case OperandType::TENSOR_INT32:
    174             return -1.0f;
    175         case OperandType::TENSOR_QUANT8_ASYMM:
    176             return 0.0f;
    177         default:
    178             return 0.0f;
    179     }
    180 }
    181 
    182 static void mutateOperandScaleTest(const sp<IDevice>& device, const V1_0::Model& model) {
    183     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
    184         const float invalidScale = getInvalidScale(model.operands[operand].type);
    185         const std::string message = "mutateOperandScaleTest: operand " + std::to_string(operand) +
    186                                     " has scale of " + std::to_string(invalidScale);
    187         validate(device, message, model, [operand, invalidScale](Model* model) {
    188             model->operands[operand].scale = invalidScale;
    189         });
    190     }
    191 }
    192 
    193 ///////////////////////// VALIDATE OPERAND ZERO POINT /////////////////////////
    194 
    195 static std::vector<int32_t> getInvalidZeroPoints(OperandType type) {
    196     switch (type) {
    197         case OperandType::FLOAT32:
    198         case OperandType::INT32:
    199         case OperandType::UINT32:
    200         case OperandType::TENSOR_FLOAT32:
    201         case OperandType::TENSOR_INT32:
    202             return {1};
    203         case OperandType::TENSOR_QUANT8_ASYMM:
    204             return {-1, 256};
    205         default:
    206             return {};
    207     }
    208 }
    209 
    210 static void mutateOperandZeroPointTest(const sp<IDevice>& device, const V1_0::Model& model) {
    211     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
    212         const std::vector<int32_t> invalidZeroPoints =
    213             getInvalidZeroPoints(model.operands[operand].type);
    214         for (int32_t invalidZeroPoint : invalidZeroPoints) {
    215             const std::string message = "mutateOperandZeroPointTest: operand " +
    216                                         std::to_string(operand) + " has zero point of " +
    217                                         std::to_string(invalidZeroPoint);
    218             validate(device, message, model, [operand, invalidZeroPoint](Model* model) {
    219                 model->operands[operand].zeroPoint = invalidZeroPoint;
    220             });
    221         }
    222     }
    223 }
    224 
    225 ///////////////////////// VALIDATE EXTRA ??? /////////////////////////
    226 
    227 // TODO: Operand::lifetime
    228 // TODO: Operand::location
    229 
    230 ///////////////////////// VALIDATE OPERATION OPERAND TYPE /////////////////////////
    231 
    232 static void mutateOperand(Operand* operand, OperandType type) {
    233     Operand newOperand = *operand;
    234     newOperand.type = type;
    235     switch (type) {
    236         case OperandType::FLOAT32:
    237         case OperandType::INT32:
    238         case OperandType::UINT32:
    239             newOperand.dimensions = hidl_vec<uint32_t>();
    240             newOperand.scale = 0.0f;
    241             newOperand.zeroPoint = 0;
    242             break;
    243         case OperandType::TENSOR_FLOAT32:
    244             newOperand.dimensions =
    245                 operand->dimensions.size() > 0 ? operand->dimensions : hidl_vec<uint32_t>({1});
    246             newOperand.scale = 0.0f;
    247             newOperand.zeroPoint = 0;
    248             break;
    249         case OperandType::TENSOR_INT32:
    250             newOperand.dimensions =
    251                 operand->dimensions.size() > 0 ? operand->dimensions : hidl_vec<uint32_t>({1});
    252             newOperand.zeroPoint = 0;
    253             break;
    254         case OperandType::TENSOR_QUANT8_ASYMM:
    255             newOperand.dimensions =
    256                 operand->dimensions.size() > 0 ? operand->dimensions : hidl_vec<uint32_t>({1});
    257             newOperand.scale = operand->scale != 0.0f ? operand->scale : 1.0f;
    258             break;
    259         case OperandType::OEM:
    260         case OperandType::TENSOR_OEM_BYTE:
    261         default:
    262             break;
    263     }
    264     *operand = newOperand;
    265 }
    266 
    267 static bool mutateOperationOperandTypeSkip(size_t operand, const V1_0::Model& model) {
    268     // LSH_PROJECTION's second argument is allowed to have any type. This is the
    269     // only operation that currently has a type that can be anything independent
    270     // from any other type. Changing the operand type to any other type will
    271     // result in a valid model for LSH_PROJECTION. If this is the case, skip the
    272     // test.
    273     for (const Operation& operation : model.operations) {
    274         if (operation.type == OperationType::LSH_PROJECTION && operand == operation.inputs[1]) {
    275             return true;
    276         }
    277     }
    278     return false;
    279 }
    280 
    281 static void mutateOperationOperandTypeTest(const sp<IDevice>& device, const V1_0::Model& model) {
    282     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
    283         if (mutateOperationOperandTypeSkip(operand, model)) {
    284             continue;
    285         }
    286         for (OperandType invalidOperandType : hidl_enum_range<OperandType>{}) {
    287             // Do not test OEM types
    288             if (invalidOperandType == model.operands[operand].type ||
    289                 invalidOperandType == OperandType::OEM ||
    290                 invalidOperandType == OperandType::TENSOR_OEM_BYTE) {
    291                 continue;
    292             }
    293             const std::string message = "mutateOperationOperandTypeTest: operand " +
    294                                         std::to_string(operand) + " set to type " +
    295                                         toString(invalidOperandType);
    296             validate(device, message, model, [operand, invalidOperandType](Model* model) {
    297                 mutateOperand(&model->operands[operand], invalidOperandType);
    298             });
    299         }
    300     }
    301 }
    302 
    303 ///////////////////////// VALIDATE MODEL OPERATION TYPE /////////////////////////
    304 
    305 static const int32_t invalidOperationTypes[] = {
    306     static_cast<int32_t>(OperationType::ADD) - 1,            // lower bound fundamental
    307     static_cast<int32_t>(OperationType::TANH) + 1,           // upper bound fundamental
    308     static_cast<int32_t>(OperationType::OEM_OPERATION) - 1,  // lower bound OEM
    309     static_cast<int32_t>(OperationType::OEM_OPERATION) + 1,  // upper bound OEM
    310 };
    311 
    312 static void mutateOperationTypeTest(const sp<IDevice>& device, const V1_0::Model& model) {
    313     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
    314         for (int32_t invalidOperationType : invalidOperationTypes) {
    315             const std::string message = "mutateOperationTypeTest: operation " +
    316                                         std::to_string(operation) + " set to value " +
    317                                         std::to_string(invalidOperationType);
    318             validate(device, message, model, [operation, invalidOperationType](Model* model) {
    319                 model->operations[operation].type =
    320                     static_cast<OperationType>(invalidOperationType);
    321             });
    322         }
    323     }
    324 }
    325 
    326 ///////////////////////// VALIDATE MODEL OPERATION INPUT OPERAND INDEX /////////////////////////
    327 
    328 static void mutateOperationInputOperandIndexTest(const sp<IDevice>& device,
    329                                                  const V1_0::Model& model) {
    330     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
    331         const uint32_t invalidOperand = model.operands.size();
    332         for (size_t input = 0; input < model.operations[operation].inputs.size(); ++input) {
    333             const std::string message = "mutateOperationInputOperandIndexTest: operation " +
    334                                         std::to_string(operation) + " input " +
    335                                         std::to_string(input);
    336             validate(device, message, model, [operation, input, invalidOperand](Model* model) {
    337                 model->operations[operation].inputs[input] = invalidOperand;
    338             });
    339         }
    340     }
    341 }
    342 
    343 ///////////////////////// VALIDATE MODEL OPERATION OUTPUT OPERAND INDEX /////////////////////////
    344 
    345 static void mutateOperationOutputOperandIndexTest(const sp<IDevice>& device,
    346                                                   const V1_0::Model& model) {
    347     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
    348         const uint32_t invalidOperand = model.operands.size();
    349         for (size_t output = 0; output < model.operations[operation].outputs.size(); ++output) {
    350             const std::string message = "mutateOperationOutputOperandIndexTest: operation " +
    351                                         std::to_string(operation) + " output " +
    352                                         std::to_string(output);
    353             validate(device, message, model, [operation, output, invalidOperand](Model* model) {
    354                 model->operations[operation].outputs[output] = invalidOperand;
    355             });
    356         }
    357     }
    358 }
    359 
    360 ///////////////////////// REMOVE OPERAND FROM EVERYTHING /////////////////////////
    361 
    362 static void removeValueAndDecrementGreaterValues(hidl_vec<uint32_t>* vec, uint32_t value) {
    363     if (vec) {
    364         // remove elements matching "value"
    365         auto last = std::remove(vec->begin(), vec->end(), value);
    366         vec->resize(std::distance(vec->begin(), last));
    367 
    368         // decrement elements exceeding "value"
    369         std::transform(vec->begin(), vec->end(), vec->begin(),
    370                        [value](uint32_t v) { return v > value ? v-- : v; });
    371     }
    372 }
    373 
    374 static void removeOperand(Model* model, uint32_t index) {
    375     hidl_vec_removeAt(&model->operands, index);
    376     for (Operation& operation : model->operations) {
    377         removeValueAndDecrementGreaterValues(&operation.inputs, index);
    378         removeValueAndDecrementGreaterValues(&operation.outputs, index);
    379     }
    380     removeValueAndDecrementGreaterValues(&model->inputIndexes, index);
    381     removeValueAndDecrementGreaterValues(&model->outputIndexes, index);
    382 }
    383 
    384 static void removeOperandTest(const sp<IDevice>& device, const V1_0::Model& model) {
    385     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
    386         const std::string message = "removeOperandTest: operand " + std::to_string(operand);
    387         validate(device, message, model,
    388                  [operand](Model* model) { removeOperand(model, operand); });
    389     }
    390 }
    391 
    392 ///////////////////////// REMOVE OPERATION /////////////////////////
    393 
    394 static void removeOperation(Model* model, uint32_t index) {
    395     for (uint32_t operand : model->operations[index].inputs) {
    396         model->operands[operand].numberOfConsumers--;
    397     }
    398     hidl_vec_removeAt(&model->operations, index);
    399 }
    400 
    401 static void removeOperationTest(const sp<IDevice>& device, const V1_0::Model& model) {
    402     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
    403         const std::string message = "removeOperationTest: operation " + std::to_string(operation);
    404         validate(device, message, model,
    405                  [operation](Model* model) { removeOperation(model, operation); });
    406     }
    407 }
    408 
    409 ///////////////////////// REMOVE OPERATION INPUT /////////////////////////
    410 
    411 static void removeOperationInputTest(const sp<IDevice>& device, const V1_0::Model& model) {
    412     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
    413         for (size_t input = 0; input < model.operations[operation].inputs.size(); ++input) {
    414             const V1_0::Operation& op = model.operations[operation];
    415             // CONCATENATION has at least 2 inputs, with the last element being
    416             // INT32. Skip this test if removing one of CONCATENATION's
    417             // inputs still produces a valid model.
    418             if (op.type == V1_0::OperationType::CONCATENATION && op.inputs.size() > 2 &&
    419                 input != op.inputs.size() - 1) {
    420                 continue;
    421             }
    422             const std::string message = "removeOperationInputTest: operation " +
    423                                         std::to_string(operation) + ", input " +
    424                                         std::to_string(input);
    425             validate(device, message, model, [operation, input](Model* model) {
    426                 uint32_t operand = model->operations[operation].inputs[input];
    427                 model->operands[operand].numberOfConsumers--;
    428                 hidl_vec_removeAt(&model->operations[operation].inputs, input);
    429             });
    430         }
    431     }
    432 }
    433 
    434 ///////////////////////// REMOVE OPERATION OUTPUT /////////////////////////
    435 
    436 static void removeOperationOutputTest(const sp<IDevice>& device, const V1_0::Model& model) {
    437     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
    438         for (size_t output = 0; output < model.operations[operation].outputs.size(); ++output) {
    439             const std::string message = "removeOperationOutputTest: operation " +
    440                                         std::to_string(operation) + ", output " +
    441                                         std::to_string(output);
    442             validate(device, message, model, [operation, output](Model* model) {
    443                 hidl_vec_removeAt(&model->operations[operation].outputs, output);
    444             });
    445         }
    446     }
    447 }
    448 
    449 ///////////////////////// MODEL VALIDATION /////////////////////////
    450 
    451 // TODO: remove model input
    452 // TODO: remove model output
    453 // TODO: add unused operation
    454 
    455 ///////////////////////// ADD OPERATION INPUT /////////////////////////
    456 
    457 static void addOperationInputTest(const sp<IDevice>& device, const V1_0::Model& model) {
    458     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
    459         const std::string message = "addOperationInputTest: operation " + std::to_string(operation);
    460         validate(device, message, model, [operation](Model* model) {
    461             uint32_t index = addOperand(model, OperandLifeTime::MODEL_INPUT);
    462             hidl_vec_push_back(&model->operations[operation].inputs, index);
    463             hidl_vec_push_back(&model->inputIndexes, index);
    464         });
    465     }
    466 }
    467 
    468 ///////////////////////// ADD OPERATION OUTPUT /////////////////////////
    469 
    470 static void addOperationOutputTest(const sp<IDevice>& device, const V1_0::Model& model) {
    471     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
    472         const std::string message =
    473             "addOperationOutputTest: operation " + std::to_string(operation);
    474         validate(device, message, model, [operation](Model* model) {
    475             uint32_t index = addOperand(model, OperandLifeTime::MODEL_OUTPUT);
    476             hidl_vec_push_back(&model->operations[operation].outputs, index);
    477             hidl_vec_push_back(&model->outputIndexes, index);
    478         });
    479     }
    480 }
    481 
    482 ////////////////////////// ENTRY POINT //////////////////////////////
    483 
    484 void ValidationTest::validateModel(const V1_0::Model& model) {
    485     mutateOperandTypeTest(device, model);
    486     mutateOperandRankTest(device, model);
    487     mutateOperandScaleTest(device, model);
    488     mutateOperandZeroPointTest(device, model);
    489     mutateOperationOperandTypeTest(device, model);
    490     mutateOperationTypeTest(device, model);
    491     mutateOperationInputOperandIndexTest(device, model);
    492     mutateOperationOutputOperandIndexTest(device, model);
    493     removeOperandTest(device, model);
    494     removeOperationTest(device, model);
    495     removeOperationInputTest(device, model);
    496     removeOperationOutputTest(device, model);
    497     addOperationInputTest(device, model);
    498     addOperationOutputTest(device, model);
    499 }
    500 
    501 }  // namespace functional
    502 }  // namespace vts
    503 }  // namespace V1_0
    504 }  // namespace neuralnetworks
    505 }  // namespace hardware
    506 }  // namespace android
    507