/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "ValidateHal"

#include "ValidateHal.h"
#include "NeuralNetworks.h"
#include "OperationsUtils.h"
#include "Tracing.h"
#include "Utils.h"

#include <android-base/logging.h>

#include <algorithm>
#include <vector>

namespace android {
namespace nn {

template <class T_Model>
struct ModelToHalVersion;
template <>
struct ModelToHalVersion<V1_0::Model> {
    static constexpr HalVersion version = HalVersion::V1_0;
};
template <>
struct ModelToHalVersion<V1_1::Model> {
    static constexpr HalVersion version = HalVersion::V1_1;
};
template <>
struct ModelToHalVersion<V1_2::Model> {
    static constexpr HalVersion version = HalVersion::V1_2;
};

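// Checks that DataLocation references into a model's or request's memory
// pools stay within bounds. Pool sizes are captured once at construction.
//
// Usage sketch (hypothetical caller; identifiers are illustrative only):
//
//     MemoryAccessVerifier verifier(model.pools);
//     if (!verifier.validate(operand.location)) return false;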
class MemoryAccessVerifier {
public:
    explicit MemoryAccessVerifier(const hidl_vec<hidl_memory>& pools)
        : mPoolCount(pools.size()), mPoolSizes(mPoolCount) {
        for (size_t i = 0; i < mPoolCount; i++) {
            mPoolSizes[i] = pools[i].size();
        }
    }
    bool validate(const DataLocation& location) const {
        if (location.poolIndex >= mPoolCount) {
            LOG(ERROR) << "Invalid poolIndex " << location.poolIndex << "/" << mPoolCount;
            return false;
        }
        const size_t size = mPoolSizes[location.poolIndex];
        // Do the addition using size_t to avoid potential wrap-around problems.
        if (static_cast<size_t>(location.offset) + location.length > size) {
            LOG(ERROR) << "Reference to pool " << location.poolIndex << " with offset "
                       << location.offset << " and length " << location.length
                       << " exceeds pool size of " << size;
            return false;
        }
        return true;
    }

private:
    size_t mPoolCount;
    std::vector<size_t> mPoolSizes;
};

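// Checks that an operand's extraParams matches its type: channel-quantized
// tensors must carry channelQuant params, extension operand types may carry
// extension params or none, and all other types must carry none.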
static bool validateOperandExtraParams(const V1_2::Operand& operand, uint32_t index) {
    switch (operand.type) {
        case OperandType::FLOAT32:
        case OperandType::INT32:
        case OperandType::UINT32:
        case OperandType::BOOL:
        case OperandType::TENSOR_FLOAT32:
        case OperandType::TENSOR_FLOAT16:
        case OperandType::TENSOR_INT32:
        case OperandType::TENSOR_QUANT8_ASYMM:
        case OperandType::TENSOR_QUANT8_SYMM:
        case OperandType::TENSOR_QUANT16_ASYMM:
        case OperandType::TENSOR_QUANT16_SYMM:
        case OperandType::TENSOR_BOOL8: {
            NN_RET_CHECK(operand.extraParams.getDiscriminator() ==
                         V1_2::Operand::ExtraParams::hidl_discriminator::none)
                    << "Operand " << index << ": Operand of type "
                    << getOperandTypeName(operand.type)
                    << " has incorrect extraParams: " << toString(operand.extraParams);
        } break;
        case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: {
            NN_RET_CHECK(operand.extraParams.getDiscriminator() ==
                         V1_2::Operand::ExtraParams::hidl_discriminator::channelQuant)
                    << "Operand " << index << ": Operand of type "
                    << getOperandTypeName(operand.type)
                    << " is missing channel quantization params";
            const auto& channelQuant = operand.extraParams.channelQuant();

            size_t count = operand.dimensions.size();
            NN_RET_CHECK_LT(channelQuant.channelDim, count)
                    << "Operand " << index << ": Operand of type "
                    << getOperandTypeName(operand.type)
                    << " with an invalid channelQuant.channelDim " << channelQuant.channelDim
                    << ", must be a valid dimension index in range [0, " << count << ")";
            uint32_t expected = operand.dimensions[channelQuant.channelDim];
            NN_RET_CHECK_EQ(channelQuant.scales.size(), expected)
                    << "Operand " << index << ": Operand of type "
                    << getOperandTypeName(operand.type) << " with a wrong number of scales, "
                    << "expected " << expected << ", was " << channelQuant.scales.size();
            NN_RET_CHECK_NE(expected, 0)
                    << "Operand " << index << ": Operand of type "
                    << getOperandTypeName(operand.type) << " channel dimension "
                    << channelQuant.channelDim << " is underspecified (can't be 0)";
            for (uint32_t i = 0; i < expected; ++i) {
                NN_RET_CHECK_GT(channelQuant.scales[i], .0f)
                        << "Operand " << index << ": Operand of type "
                        << getOperandTypeName(operand.type)
                        << " with a non-positive value in scales[" << i
                        << "]=" << channelQuant.scales[i];
            }
        } break;
        default: {
            if (isExtensionOperandType(operand.type)) {
                NN_RET_CHECK(operand.extraParams.getDiscriminator() ==
                                     V1_2::Operand::ExtraParams::hidl_discriminator::extension ||
                             operand.extraParams.getDiscriminator() ==
                                     V1_2::Operand::ExtraParams::hidl_discriminator::none)
                        << "Operand " << index << ": Extension operand of type "
                        << getOperandTypeName(operand.type)
                        << " has incorrect extraParams: " << toString(operand.extraParams);
            }
            // No validation for OEM types.
        } break;
    }
    return true;
}

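// Validates every operand in the model: type, dimensions, scale, zeroPoint,
// extraParams, lifetime, and data location. Constant operands must also have
// a length consistent with their type and dimensions.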
template <typename VersionedOperand>
static bool validateOperands(const hidl_vec<VersionedOperand>& operands,
                             const hidl_vec<uint8_t>& operandValues,
                             const hidl_vec<hidl_memory>& pools, bool allowUnspecifiedRank) {
    uint32_t index = 0;
    MemoryAccessVerifier poolVerifier(pools);
    for (auto& versionedOperand : operands) {
        if (!validOperandType(versionedOperand.type)) {
            LOG(ERROR) << "Operand is not supported by this version: "
                       << toString(versionedOperand.type);
            return false;
        }
        // Once we are sure the operand is supported by its version, it is safe
        // to convert it to the latest version for the rest of the validations.
        V1_2::Operand operand = convertToV1_2(versionedOperand);
        // Validate type and dimensions.
        switch (operand.type) {
            case OperandType::FLOAT16:
            case OperandType::FLOAT32:
            case OperandType::INT32:
            case OperandType::UINT32:
            case OperandType::BOOL:
            case OperandType::OEM: {
                size_t count = operand.dimensions.size();
                if (count != 0) {
                    LOG(ERROR) << "Operand " << index << ": Scalar data has dimensions of rank "
                               << count;
                    return false;
                }
                break;
            }
            case OperandType::TENSOR_FLOAT16:
            case OperandType::TENSOR_FLOAT32:
            case OperandType::TENSOR_INT32:
            case OperandType::TENSOR_QUANT8_ASYMM:
            case OperandType::TENSOR_QUANT8_SYMM:
            case OperandType::TENSOR_QUANT16_ASYMM:
            case OperandType::TENSOR_QUANT16_SYMM:
            case OperandType::TENSOR_BOOL8:
            case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
            case OperandType::TENSOR_OEM_BYTE: {
                if ((!allowUnspecifiedRank || operand.lifetime == OperandLifeTime::CONSTANT_COPY ||
                     operand.lifetime == OperandLifeTime::CONSTANT_REFERENCE) &&
                    operand.dimensions.size() == 0) {
                    LOG(ERROR) << "Operand " << index << ": Tensor has dimensions of rank 0";
                    return false;
                }
                break;
            }
            default: {
                if (!isExtensionOperandType(operand.type)) {
                    LOG(ERROR) << "Operand " << index << ": Invalid operand type "
                               << toString(operand.type);
                    return false;
                }
            } break;
        }

        // TODO Validate the numberOfConsumers.
        // TODO Since we have to validate numberOfConsumers anyway, there was no point in
        // including it in the structure. For the next release, consider removing it unless an
        // additional process in system space creates this value, in which case it would not
        // need to be validated.

        // Validate the scale.
        switch (operand.type) {
            case OperandType::FLOAT16:
            case OperandType::FLOAT32:
            case OperandType::INT32:
            case OperandType::UINT32:
            case OperandType::BOOL:
            case OperandType::TENSOR_FLOAT16:
            case OperandType::TENSOR_FLOAT32:
            case OperandType::TENSOR_BOOL8:
            case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
                if (operand.scale != 0.f) {
                    LOG(ERROR) << "Operand " << index << ": Operand of type "
                               << getOperandTypeName(operand.type) << " with a non-zero scale ("
                               << operand.scale << ")";
                    return false;
                }
                break;
            case OperandType::TENSOR_INT32:
                // TENSOR_INT32 may be used with or without scale, depending on the operation.
                if (operand.scale < 0.f) {
                    LOG(ERROR) << "Operand " << index << ": Operand of type "
                               << getOperandTypeName(operand.type) << " with a negative scale";
                    return false;
                }
                break;
            case OperandType::TENSOR_QUANT8_ASYMM:
            case OperandType::TENSOR_QUANT8_SYMM:
            case OperandType::TENSOR_QUANT16_ASYMM:
            case OperandType::TENSOR_QUANT16_SYMM:
                if (operand.scale <= 0.f) {
                    LOG(ERROR) << "Operand " << index << ": Operand of type "
                               << getOperandTypeName(operand.type) << " with a non-positive scale";
                    return false;
                }
                break;
            default:
                if (isExtensionOperandType(operand.type) && operand.scale != 0.f) {
                    LOG(ERROR) << "Operand " << index << ": Operand of type "
                               << getOperandTypeName(operand.type) << " with a non-zero scale ("
                               << operand.scale << ")";
                    return false;
                }
                // No validation for OEM types.
                // TODO(b/119869082) We should have a separate type for TENSOR_INT32 with a scale.
                break;
        }

        // Validate the zeroPoint.
        switch (operand.type) {
            case OperandType::FLOAT16:
            case OperandType::FLOAT32:
            case OperandType::INT32:
            case OperandType::UINT32:
            case OperandType::BOOL:
            case OperandType::TENSOR_FLOAT16:
            case OperandType::TENSOR_FLOAT32:
            case OperandType::TENSOR_INT32:
            case OperandType::TENSOR_BOOL8:
            case OperandType::TENSOR_QUANT8_SYMM:
            case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
                if (operand.zeroPoint != 0) {
                    LOG(ERROR) << "Operand " << index << ": Operand of type "
                               << getOperandTypeName(operand.type) << " with a non-zero zeroPoint "
                               << operand.zeroPoint;
                    return false;
                }
                break;
            case OperandType::TENSOR_QUANT8_ASYMM:
                if (operand.zeroPoint < 0 || operand.zeroPoint > 255) {
                    LOG(ERROR) << "Operand " << index << ": Operand of type "
                               << getOperandTypeName(operand.type) << " with an invalid zeroPoint "
                               << operand.zeroPoint << ", must be in range [0, 255]";
                    return false;
                }
                break;
            case OperandType::TENSOR_QUANT16_ASYMM:
                if (operand.zeroPoint < 0 || operand.zeroPoint > 65535) {
                    LOG(ERROR) << "Operand " << index << ": Operand of type "
                               << getOperandTypeName(operand.type) << " with an invalid zeroPoint "
                               << operand.zeroPoint << ", must be in range [0, 65535]";
                    return false;
                }
                break;
            case OperandType::TENSOR_QUANT16_SYMM:
                if (operand.zeroPoint != 0) {
                    LOG(ERROR) << "Operand " << index << ": Operand of type "
                               << getOperandTypeName(operand.type) << " with a non-zero zeroPoint "
                               << operand.zeroPoint;
                    return false;
                }
                break;
            default:
                if (isExtensionOperandType(operand.type) && operand.zeroPoint != 0) {
                    LOG(ERROR) << "Operand " << index << ": Operand of type "
                               << getOperandTypeName(operand.type) << " with a non-zero zeroPoint "
                               << operand.zeroPoint;
                    return false;
                }
                // No validation for OEM types.
                break;
        }

        NN_RET_CHECK(validateOperandExtraParams(operand, index));

        // Validate the lifetime and the location.
        const DataLocation& location = operand.location;
        switch (operand.lifetime) {
            case OperandLifeTime::CONSTANT_COPY:
                if (location.poolIndex != 0) {
                    LOG(ERROR) << "Operand " << index
                               << ": CONSTANT_COPY with a non-zero poolIndex "
                               << location.poolIndex;
                    return false;
                }
                // Do the addition using size_t to avoid potential wrap-around problems.
                if (static_cast<size_t>(location.offset) + location.length > operandValues.size()) {
                    LOG(ERROR) << "Operand " << index
                               << ": OperandValue location out of range. Starts at "
                               << location.offset << ", length " << location.length << ", max "
                               << operandValues.size();
                    return false;
                }
                break;
            case OperandLifeTime::CONSTANT_REFERENCE:
                if (!poolVerifier.validate(location)) {
                    return false;
                }
                break;
            case OperandLifeTime::TEMPORARY_VARIABLE:
            case OperandLifeTime::MODEL_INPUT:
            case OperandLifeTime::MODEL_OUTPUT:
            case OperandLifeTime::NO_VALUE:
                if (location.poolIndex != 0 || location.offset != 0 || location.length != 0) {
                    LOG(ERROR) << "Operand " << index << ": Unexpected poolIndex "
                               << location.poolIndex << ", offset " << location.offset
                               << ", or length " << location.length << " for operand of lifetime "
                               << toString(operand.lifetime);
                    return false;
                }
                break;
            default:
                LOG(ERROR) << "Operand " << index << ": Invalid lifetime "
                           << toString(operand.lifetime);
                return false;
        }

        // For constants, validate that the length is as expected. The other lifetimes
        // expect the length to be 0. Don't validate for OEM types.
        if (operand.lifetime == OperandLifeTime::CONSTANT_REFERENCE ||
            operand.lifetime == OperandLifeTime::CONSTANT_COPY) {
            if (!isExtensionOperandType(operand.type) && operand.type != OperandType::OEM &&
                operand.type != OperandType::TENSOR_OEM_BYTE) {
                uint32_t expectedLength = nonExtensionOperandSizeOfData(operand);
                if (location.length != expectedLength) {
                    LOG(ERROR) << "Operand " << index << ": For operand " << toString(operand)
                               << " expected a size of " << expectedLength << " but got "
                               << location.length;
                    return false;
                }
            }
        }

        index++;
    }
    return true;
}

static HalVersion getHalVersion(const V1_0::Operation&) {
    return HalVersion::V1_0;
}

static HalVersion getHalVersion(const V1_1::Operation&) {
    return HalVersion::V1_1;
}

static HalVersion getHalVersion(const V1_2::Operation&) {
    return HalVersion::V1_2;
}

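// Validates each operation against the HAL version it was declared in and
// checks the dataflow: outputs may only be TEMPORARY_VARIABLE or MODEL_OUTPUT
// operands, every operand is written at most once, and every operand that
// must be written to actually is.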
template <typename VersionedOperation>
static bool validateOperations(const hidl_vec<VersionedOperation>& operations,
                               const hidl_vec<Operand>& operands) {
    const size_t operandCount = operands.size();
    // This vector keeps track of whether there's an operation that writes to
    // each operand. It is used to validate that temporary variables and
    // model outputs will be written to.
    std::vector<bool> writtenTo(operandCount, false);
    for (auto& op : operations) {
        // TODO Validate the shapes and any known values. This is currently
        // done in CpuExecutor but should be done here for all drivers.
        int error = validateOperation(
                static_cast<int32_t>(op.type), op.inputs.size(),
                op.inputs.size() > 0 ? op.inputs.data() : nullptr, op.outputs.size(),
                op.outputs.size() > 0 ? op.outputs.data() : nullptr, operands, getHalVersion(op));
        if (error != ANEURALNETWORKS_NO_ERROR) {
            LOG(ERROR) << "Invalid operation " << toString(op.type);
            return false;
        }

        for (uint32_t i : op.outputs) {
            const Operand& operand = operands[i];
            if (operand.lifetime != OperandLifeTime::TEMPORARY_VARIABLE &&
                operand.lifetime != OperandLifeTime::MODEL_OUTPUT) {
                LOG(ERROR) << "Writing to an operand with incompatible lifetime "
                           << toString(operand.lifetime);
                return false;
            }

            // Check that we only write once to an operand.
            if (writtenTo[i]) {
                LOG(ERROR) << "Operand " << i << " written a second time";
                return false;
            }
            writtenTo[i] = true;
        }
    }
    for (size_t i = 0; i < operandCount; i++) {
        if (!writtenTo[i]) {
            const Operand& operand = operands[i];
            if (operand.lifetime == OperandLifeTime::TEMPORARY_VARIABLE ||
                operand.lifetime == OperandLifeTime::MODEL_OUTPUT) {
                LOG(ERROR) << "Operand " << i << " with lifetime " << toString(operand.lifetime)
                           << " is not being written to.";
                return false;
            }
        }
    }
    // TODO More whole-graph verifications are possible, for example that an
    // operand is not used as both input and output of the same operation, and,
    // more generally, that the graph is acyclic.
    return true;
}

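// Checks that a memory pool uses a memory type supported by the given HAL
// version and carries a non-null handle. "hardware_buffer_blob" and
// "hardware_buffer" are only accepted starting with HAL version 1.2.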
bool validatePool(const hidl_memory& pool, HalVersion ver) {
    const auto& name = pool.name();
    if (name != "ashmem" && name != "mmap_fd" &&
        ((ver < HalVersion::V1_2) ||
         (name != "hardware_buffer_blob" && name != "hardware_buffer"))) {
        LOG(ERROR) << "Unsupported memory type " << name;
        return false;
    }
    if (pool.handle() == nullptr) {
        LOG(ERROR) << "Memory of type " << name << " is null";
        return false;
    }
    return true;
}

static bool validatePools(const hidl_vec<hidl_memory>& pools, HalVersion ver) {
    return std::all_of(pools.begin(), pools.end(),
                       [ver](const hidl_memory& pool) { return validatePool(pool, ver); });
}

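// Checks that every model input (or output) index is in range, that the
// referenced operand has the expected lifetime, and that no index appears
// more than once.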
static bool validateModelInputOutputs(const hidl_vec<uint32_t>& indexes,
                                      const hidl_vec<Operand>& operands, OperandLifeTime lifetime) {
    const size_t operandCount = operands.size();
    for (uint32_t i : indexes) {
        if (i >= operandCount) {
            LOG(ERROR) << "Model input or output index out of range: " << i << "/" << operandCount;
            return false;
        }
        const Operand& operand = operands[i];
        if (operand.lifetime != lifetime) {
            LOG(ERROR) << "Model input or output has lifetime of " << toString(operand.lifetime)
                       << " instead of the expected " << toString(lifetime);
            return false;
        }
    }

    std::vector<uint32_t> sortedIndexes = indexes;
    std::sort(sortedIndexes.begin(), sortedIndexes.end());
    auto adjacentI = std::adjacent_find(sortedIndexes.begin(), sortedIndexes.end());
    if (adjacentI != sortedIndexes.end()) {
        LOG(ERROR) << "Model input or output occurs multiple times: " << *adjacentI;
        return false;
    }
    return true;
}

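// Validates an entire model of the given HAL version: operands, operations,
// input and output indexes, and memory pools. Empty models are rejected, and
// unspecified tensor rank is only permitted starting with HAL version 1.2.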
template <class T_Model>
bool validateModel(const T_Model& model) {
    NNTRACE_FULL(NNTRACE_LAYER_UTILITY, NNTRACE_PHASE_UNSPECIFIED, "validateModel");
    HalVersion version = ModelToHalVersion<T_Model>::version;
    if (model.operations.size() == 0 || model.operands.size() == 0) {
        LOG(ERROR) << "Invalid empty model.";
        return false;
    }
    // We only need versioned operands for their validation. For all the other
    // validations we can use operands upcasted to the latest version.
    const hidl_vec<Operand> latestVersionOperands = convertToV1_2(model.operands);
    return (validateOperands(model.operands, model.operandValues, model.pools,
                             /*allowUnspecifiedRank=*/version >= HalVersion::V1_2) &&
            validateOperations(model.operations, latestVersionOperands) &&
            validateModelInputOutputs(model.inputIndexes, latestVersionOperands,
                                      OperandLifeTime::MODEL_INPUT) &&
            validateModelInputOutputs(model.outputIndexes, latestVersionOperands,
                                      OperandLifeTime::MODEL_OUTPUT) &&
            validatePools(model.pools, version));
}

template bool validateModel<V1_0::Model>(const V1_0::Model& model);
template bool validateModel<V1_1::Model>(const V1_1::Model& model);
template bool validateModel<V1_2::Model>(const V1_2::Model& model);
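
// Usage sketch (hypothetical driver code; identifiers other than
// validateModel are illustrative only, not part of this file's API):
//
//     Return<ErrorStatus> MyDriver::prepareModel_1_2(const V1_2::Model& model, ...) {
//         if (!validateModel(model)) {
//             return ErrorStatus::INVALID_ARGUMENT;
//         }
//         ...
//     }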

// Validates the arguments of a request. type is either "input" or "output" and is used
// for printing error messages. operandIndexes is the array of input or output operand
// indexes that was passed to ANeuralNetworksModel_identifyInputsAndOutputs.
static bool validateRequestArguments(const hidl_vec<RequestArgument>& requestArguments,
                                     const hidl_vec<uint32_t>& operandIndexes,
                                     const hidl_vec<Operand>& operands,
                                     const hidl_vec<hidl_memory>& pools, bool allowUnspecified,
                                     const char* type) {
    MemoryAccessVerifier poolVerifier(pools);
    // The request should specify as many arguments as were described in the model.
    const size_t requestArgumentCount = requestArguments.size();
    if (requestArgumentCount != operandIndexes.size()) {
        LOG(ERROR) << "Request specifies " << requestArgumentCount << " " << type
                   << "s but the model has " << operandIndexes.size();
        return false;
    }
    for (size_t requestArgumentIndex = 0; requestArgumentIndex < requestArgumentCount;
         requestArgumentIndex++) {
        const RequestArgument& requestArgument = requestArguments[requestArgumentIndex];
        const DataLocation& location = requestArgument.location;
        // Get the operand index for this argument. We extract it from the list
        // that was provided in the call to ANeuralNetworksModel_identifyInputsAndOutputs.
        // We assume in this function that the model has been validated already.
        const uint32_t operandIndex = operandIndexes[requestArgumentIndex];
        const Operand& operand = operands[operandIndex];
        if (requestArgument.hasNoValue) {
            if (location.poolIndex != 0 || location.offset != 0 || location.length != 0 ||
                requestArgument.dimensions.size() != 0) {
                LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
                           << " has no value yet has details.";
                return false;
            }
        } else {
            // Validate the location.
            if (!poolVerifier.validate(location)) {
                return false;
            }
            // If the argument specifies dimensions, validate them.
            uint32_t rank = requestArgument.dimensions.size();
            if (rank == 0) {
                if (!allowUnspecified) {
                    // Validate that all the dimensions are specified in the model.
                    for (size_t i = 0; i < operand.dimensions.size(); i++) {
                        if (operand.dimensions[i] == 0) {
                            LOG(ERROR) << "Model has dimension " << i
                                       << " set to 0 but the request does not specify the "
                                          "dimension.";
                            return false;
                        }
                    }
                }
            } else {
                if (rank != operand.dimensions.size()) {
                    LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
                               << " has number of dimensions (" << rank
                               << ") different than the model's (" << operand.dimensions.size()
                               << ")";
                    return false;
                }
                for (size_t i = 0; i < rank; i++) {
                    if (requestArgument.dimensions[i] != operand.dimensions[i] &&
                        operand.dimensions[i] != 0) {
                        LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
                                   << " has dimension " << i << " of "
                                   << requestArgument.dimensions[i]
                                   << " different than the model's " << operand.dimensions[i];
                        return false;
                    }
                    if (requestArgument.dimensions[i] == 0 && !allowUnspecified) {
                        LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
                                   << " has dimension " << i << " of zero";
                        return false;
                    }
                }
            }
        }
    }
    return true;
}

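// Validates a request against the model it will be executed with: inputs,
// outputs, and memory pools. Output dimensions may only be left unspecified
// starting with HAL version 1.2.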
template <class T_Model>
bool validateRequest(const Request& request, const T_Model& model) {
    HalVersion version = ModelToHalVersion<T_Model>::version;
    return (validateRequestArguments(request.inputs, model.inputIndexes,
                                     convertToV1_2(model.operands), request.pools,
                                     /*allowUnspecified=*/false, "input") &&
            validateRequestArguments(request.outputs, model.outputIndexes,
                                     convertToV1_2(model.operands), request.pools,
                                     /*allowUnspecified=*/version >= HalVersion::V1_2, "output") &&
            validatePools(request.pools, version));
}

template bool validateRequest<V1_0::Model>(const Request& request, const V1_0::Model& model);
template bool validateRequest<V1_1::Model>(const Request& request, const V1_1::Model& model);
template bool validateRequest<V1_2::Model>(const Request& request, const V1_2::Model& model);

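// Checks that the preference is one of the execution preferences defined by the HAL.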
bool validateExecutionPreference(ExecutionPreference preference) {
    return preference == ExecutionPreference::LOW_POWER ||
           preference == ExecutionPreference::FAST_SINGLE_ANSWER ||
           preference == ExecutionPreference::SUSTAINED_SPEED;
}

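// The following overloads return whether the operand type is defined in the
// given HAL version; the V1_2 overload also accepts extension operand types.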
bool validOperandType(V1_0::OperandType operandType) {
    switch (operandType) {
        case V1_0::OperandType::FLOAT32:
        case V1_0::OperandType::INT32:
        case V1_0::OperandType::UINT32:
        case V1_0::OperandType::TENSOR_FLOAT32:
        case V1_0::OperandType::TENSOR_INT32:
        case V1_0::OperandType::TENSOR_QUANT8_ASYMM:
        case V1_0::OperandType::OEM:
        case V1_0::OperandType::TENSOR_OEM_BYTE:
            return true;
        default:
            return false;
    }
}

bool validOperandType(V1_2::OperandType operandType) {
    switch (operandType) {
        case V1_2::OperandType::FLOAT16:
        case V1_2::OperandType::FLOAT32:
        case V1_2::OperandType::INT32:
        case V1_2::OperandType::UINT32:
        case V1_2::OperandType::BOOL:
        case V1_2::OperandType::TENSOR_FLOAT16:
        case V1_2::OperandType::TENSOR_FLOAT32:
        case V1_2::OperandType::TENSOR_INT32:
        case V1_2::OperandType::TENSOR_QUANT8_ASYMM:
        case V1_2::OperandType::TENSOR_QUANT8_SYMM:
        case V1_2::OperandType::TENSOR_QUANT16_ASYMM:
        case V1_2::OperandType::TENSOR_QUANT16_SYMM:
        case V1_2::OperandType::TENSOR_BOOL8:
        case V1_2::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
        case V1_2::OperandType::OEM:
        case V1_2::OperandType::TENSOR_OEM_BYTE:
            return true;
        default:
            return isExtensionOperandType(operandType);
    }
}

}  // namespace nn
}  // namespace android