1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #define LOG_TAG "ValidateHal" 18 19 #include "ValidateHal.h" 20 #include "NeuralNetworks.h" 21 #include "OperationsUtils.h" 22 #include "Tracing.h" 23 #include "Utils.h" 24 25 #include <android-base/logging.h> 26 27 namespace android { 28 namespace nn { 29 30 template <class T_Model> 31 struct ModelToHalVersion; 32 template <> 33 struct ModelToHalVersion<V1_0::Model> { 34 static constexpr HalVersion version = HalVersion::V1_0; 35 }; 36 template <> 37 struct ModelToHalVersion<V1_1::Model> { 38 static constexpr HalVersion version = HalVersion::V1_1; 39 }; 40 template <> 41 struct ModelToHalVersion<V1_2::Model> { 42 static constexpr HalVersion version = HalVersion::V1_2; 43 }; 44 45 class MemoryAccessVerifier { 46 public: 47 MemoryAccessVerifier(const hidl_vec<hidl_memory>& pools) 48 : mPoolCount(pools.size()), mPoolSizes(mPoolCount) { 49 for (size_t i = 0; i < mPoolCount; i++) { 50 mPoolSizes[i] = pools[i].size(); 51 } 52 } 53 bool validate(const DataLocation& location) { 54 if (location.poolIndex >= mPoolCount) { 55 LOG(ERROR) << "Invalid poolIndex " << location.poolIndex << "/" << mPoolCount; 56 return false; 57 } 58 const size_t size = mPoolSizes[location.poolIndex]; 59 // Do the addition using size_t to avoid potential wrap-around problems. 60 if (static_cast<size_t>(location.offset) + location.length > size) { 61 LOG(ERROR) << "Reference to pool " << location.poolIndex << " with offset " 62 << location.offset << " and length " << location.length 63 << " exceeds pool size of " << size; 64 return false; 65 } 66 return true; 67 } 68 69 private: 70 size_t mPoolCount; 71 std::vector<size_t> mPoolSizes; 72 }; 73 74 static bool validateOperandExtraParams(const V1_2::Operand& operand, uint32_t index) { 75 switch (operand.type) { 76 case OperandType::FLOAT32: 77 case OperandType::INT32: 78 case OperandType::UINT32: 79 case OperandType::BOOL: 80 case OperandType::TENSOR_FLOAT32: 81 case OperandType::TENSOR_FLOAT16: 82 case OperandType::TENSOR_INT32: 83 case OperandType::TENSOR_QUANT8_ASYMM: 84 case OperandType::TENSOR_QUANT8_SYMM: 85 case OperandType::TENSOR_QUANT16_ASYMM: 86 case OperandType::TENSOR_QUANT16_SYMM: 87 case OperandType::TENSOR_BOOL8: { 88 NN_RET_CHECK(operand.extraParams.getDiscriminator() == 89 V1_2::Operand::ExtraParams::hidl_discriminator::none) 90 << "Operand " << index << ": Operand of type " 91 << getOperandTypeName(operand.type) 92 << " has incorrect extraParams: " << toString(operand.extraParams); 93 } break; 94 case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: { 95 NN_RET_CHECK(operand.extraParams.getDiscriminator() == 96 V1_2::Operand::ExtraParams::hidl_discriminator::channelQuant) 97 << "Operand " << index << ": Operand of type " 98 << getOperandTypeName(operand.type) << " without a Channel Quantization params"; 99 auto& channelQuant = operand.extraParams.channelQuant(); 100 101 size_t count = operand.dimensions.size(); 102 NN_RET_CHECK_LT(channelQuant.channelDim, count) 103 << "Operand " << index << ": Operand of type " 104 << getOperandTypeName(operand.type) 105 << " with an invalid channelQuant.channelDim " << channelQuant.channelDim 106 << ", must be valid dimension index in range [0, " << count << ")"; 107 uint32_t expected = operand.dimensions[channelQuant.channelDim]; 108 NN_RET_CHECK_EQ(channelQuant.scales.size(), expected) 109 << "Operand " << index << ": Operand of type " 110 << getOperandTypeName(operand.type) << " with a wrong-sized scales, " 111 << "expected " << expected << " was " << channelQuant.scales.size(); 112 NN_RET_CHECK_NE(expected, 0) 113 << "Operand " << index << ": Operand of type " 114 << getOperandTypeName(operand.type) << " channel dimension " 115 << channelQuant.channelDim << " is underspecified (can't be 0)"; 116 for (uint32_t i = 0; i < expected; ++i) { 117 NN_RET_CHECK_GT(channelQuant.scales[i], .0f) 118 << "Operand " << index << ": Operand of type " 119 << getOperandTypeName(operand.type) << " with a negative value in scales[" 120 << i << "]=" << channelQuant.scales[i]; 121 } 122 } break; 123 default: { 124 if (isExtensionOperandType(operand.type)) { 125 NN_RET_CHECK(operand.extraParams.getDiscriminator() == 126 V1_2::Operand::ExtraParams::hidl_discriminator::extension || 127 operand.extraParams.getDiscriminator() == 128 V1_2::Operand::ExtraParams::hidl_discriminator::none) 129 << "Operand " << index << ": Extension operand of type " 130 << getOperandTypeName(operand.type) 131 << " has incorrect extraParams: " << toString(operand.extraParams); 132 } 133 // No validation for OEM types. 134 } break; 135 } 136 return true; 137 } 138 139 template <typename VersionedOperand> 140 static bool validateOperands(const hidl_vec<VersionedOperand>& operands, 141 const hidl_vec<uint8_t>& operandValues, 142 const hidl_vec<hidl_memory>& pools, bool allowUnspecifiedRank) { 143 uint32_t index = 0; 144 MemoryAccessVerifier poolVerifier(pools); 145 for (auto& versionedOperand : operands) { 146 if (!validOperandType(versionedOperand.type)) { 147 LOG(ERROR) << "Operand is not supported by this version: " 148 << toString(versionedOperand.type); 149 return false; 150 } 151 // Once we are sure the operand is supported by its version, it is safe 152 // to convert it to the latest version for the rest of the validations. 153 V1_2::Operand operand = convertToV1_2(versionedOperand); 154 // Validate type and dimensions. 155 switch (operand.type) { 156 case OperandType::FLOAT16: 157 case OperandType::FLOAT32: 158 case OperandType::INT32: 159 case OperandType::UINT32: 160 case OperandType::BOOL: 161 case OperandType::OEM: { 162 size_t count = operand.dimensions.size(); 163 if (count != 0) { 164 LOG(ERROR) << "Operand " << index << ": Scalar data has dimensions of rank " 165 << count; 166 return false; 167 } 168 break; 169 } 170 case OperandType::TENSOR_FLOAT16: 171 case OperandType::TENSOR_FLOAT32: 172 case OperandType::TENSOR_INT32: 173 case OperandType::TENSOR_QUANT8_ASYMM: 174 case OperandType::TENSOR_QUANT8_SYMM: 175 case OperandType::TENSOR_QUANT16_ASYMM: 176 case OperandType::TENSOR_QUANT16_SYMM: 177 case OperandType::TENSOR_BOOL8: 178 case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: 179 case OperandType::TENSOR_OEM_BYTE: { 180 if ((!allowUnspecifiedRank || operand.lifetime == OperandLifeTime::CONSTANT_COPY || 181 operand.lifetime == OperandLifeTime::CONSTANT_REFERENCE) && 182 operand.dimensions.size() == 0) { 183 LOG(ERROR) << "Operand " << index << ": Tensor has dimensions of rank 0"; 184 return false; 185 } 186 break; 187 } 188 default: { 189 if (!isExtensionOperandType(operand.type)) { 190 LOG(ERROR) << "Operand " << index << ": Invalid operand type " 191 << toString(operand.type); 192 return false; 193 } 194 } break; 195 } 196 197 // TODO Validate the numberOfConsumers. 198 // TODO Since we have to validate it, there was no point in including it. For the next 199 // release, consider removing unless we have an additional process in system space 200 // that creates this value. In that case, it would not have to be validated. 201 202 // Validate the scale. 203 switch (operand.type) { 204 case OperandType::FLOAT16: 205 case OperandType::FLOAT32: 206 case OperandType::INT32: 207 case OperandType::UINT32: 208 case OperandType::BOOL: 209 case OperandType::TENSOR_FLOAT16: 210 case OperandType::TENSOR_FLOAT32: 211 case OperandType::TENSOR_BOOL8: 212 case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: 213 if (operand.scale != 0.f) { 214 LOG(ERROR) << "Operand " << index << ": Operand of type " 215 << getOperandTypeName(operand.type) << " with a non-zero scale (" 216 << operand.scale << ")"; 217 return false; 218 } 219 break; 220 case OperandType::TENSOR_INT32: 221 // TENSOR_INT32 may be used with or without scale, depending on the operation. 222 if (operand.scale < 0.f) { 223 LOG(ERROR) << "Operand " << index << ": Operand of type " 224 << getOperandTypeName(operand.type) << " with a negative scale"; 225 return false; 226 } 227 break; 228 case OperandType::TENSOR_QUANT8_ASYMM: 229 case OperandType::TENSOR_QUANT8_SYMM: 230 case OperandType::TENSOR_QUANT16_ASYMM: 231 case OperandType::TENSOR_QUANT16_SYMM: 232 if (operand.scale <= 0.f) { 233 LOG(ERROR) << "Operand " << index << ": Operand of type " 234 << getOperandTypeName(operand.type) << " with a non-positive scale"; 235 return false; 236 } 237 break; 238 default: 239 if (isExtensionOperandType(operand.type) && operand.scale != 0.f) { 240 LOG(ERROR) << "Operand " << index << ": Operand of type " 241 << getOperandTypeName(operand.type) << " with a non-zero scale (" 242 << operand.scale << ")"; 243 return false; 244 } 245 // No validation for OEM types. 246 // TODO(b/119869082) We should have a separate type for TENSOR_INT32 with a scale. 247 break; 248 } 249 250 // Validate the zeroPoint. 251 switch (operand.type) { 252 case OperandType::FLOAT16: 253 case OperandType::FLOAT32: 254 case OperandType::INT32: 255 case OperandType::UINT32: 256 case OperandType::BOOL: 257 case OperandType::TENSOR_FLOAT16: 258 case OperandType::TENSOR_FLOAT32: 259 case OperandType::TENSOR_INT32: 260 case OperandType::TENSOR_BOOL8: 261 case OperandType::TENSOR_QUANT8_SYMM: 262 case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: 263 if (operand.zeroPoint != 0) { 264 LOG(ERROR) << "Operand " << index << ": Operand of type " 265 << getOperandTypeName(operand.type) << " with a non-zero zeroPoint " 266 << operand.zeroPoint; 267 return false; 268 } 269 break; 270 case OperandType::TENSOR_QUANT8_ASYMM: 271 if (operand.zeroPoint < 0 || operand.zeroPoint > 255) { 272 LOG(ERROR) << "Operand " << index << ": Operand of type " 273 << getOperandTypeName(operand.type) << " with an invalid zeroPoint " 274 << operand.zeroPoint << ", must be in range [0, 255]"; 275 return false; 276 } 277 break; 278 case OperandType::TENSOR_QUANT16_ASYMM: 279 if (operand.zeroPoint < 0 || operand.zeroPoint > 65535) { 280 LOG(ERROR) << "Operand " << index << ": Operand of type " 281 << getOperandTypeName(operand.type) << " with an invalid zeroPoint " 282 << operand.zeroPoint << ", must be in range [0, 65535]"; 283 return false; 284 } 285 break; 286 case OperandType::TENSOR_QUANT16_SYMM: 287 if (operand.zeroPoint != 0) { 288 LOG(ERROR) << "Operand " << index << ": Operand of type " 289 << getOperandTypeName(operand.type) << " with a non-zero zeroPoint " 290 << operand.zeroPoint; 291 return false; 292 } 293 break; 294 default: 295 if (isExtensionOperandType(operand.type) && operand.zeroPoint != 0) { 296 LOG(ERROR) << "Operand " << index << ": Operand of type " 297 << getOperandTypeName(operand.type) << " with a non-zero zeroPoint " 298 << operand.zeroPoint; 299 return false; 300 } 301 // No validation for OEM types. 302 break; 303 } 304 305 NN_RET_CHECK(validateOperandExtraParams(operand, index)); 306 307 // Validate the lifetime and the location. 308 const DataLocation& location = operand.location; 309 switch (operand.lifetime) { 310 case OperandLifeTime::CONSTANT_COPY: 311 if (location.poolIndex != 0) { 312 LOG(ERROR) << "Operand " << index 313 << ": CONSTANT_COPY with a non-zero poolIndex " 314 << location.poolIndex; 315 return false; 316 } 317 // Do the addition using size_t to avoid potential wrap-around problems. 318 if (static_cast<size_t>(location.offset) + location.length > operandValues.size()) { 319 LOG(ERROR) << "Operand " << index 320 << ": OperandValue location out of range. Starts at " 321 << location.offset << ", length " << location.length << ", max " 322 << operandValues.size(); 323 return false; 324 } 325 break; 326 case OperandLifeTime::CONSTANT_REFERENCE: 327 if (!poolVerifier.validate(location)) { 328 return false; 329 } 330 break; 331 case OperandLifeTime::TEMPORARY_VARIABLE: 332 case OperandLifeTime::MODEL_INPUT: 333 case OperandLifeTime::MODEL_OUTPUT: 334 case OperandLifeTime::NO_VALUE: 335 if (location.poolIndex != 0 || location.offset != 0 || location.length != 0) { 336 LOG(ERROR) << "Operand " << index << ": Unexpected poolIndex " 337 << location.poolIndex << ", offset " << location.offset 338 << ", or length " << location.length << " for operand of lifetime " 339 << toString(operand.lifetime); 340 return false; 341 } 342 break; 343 default: 344 LOG(ERROR) << "Operand " << index << ": Invalid lifetime " 345 << toString(operand.lifetime); 346 return false; 347 } 348 349 // For constants, validate that the length is as expected. The other lifetimes 350 // expect the length to be 0. Don't validate for OEM types. 351 if (operand.lifetime == OperandLifeTime::CONSTANT_REFERENCE || 352 operand.lifetime == OperandLifeTime::CONSTANT_COPY) { 353 if (!isExtensionOperandType(operand.type) && operand.type != OperandType::OEM && 354 operand.type != OperandType::TENSOR_OEM_BYTE) { 355 uint32_t expectedLength = nonExtensionOperandSizeOfData(operand); 356 if (location.length != expectedLength) { 357 LOG(ERROR) << "Operand " << index << ": For operand " << toString(operand) 358 << " expected a size of " << expectedLength << " but got " 359 << location.length; 360 return false; 361 } 362 } 363 } 364 365 index++; 366 } 367 return true; 368 } 369 370 static HalVersion getHalVersion(const V1_0::Operation&) { 371 return HalVersion::V1_0; 372 } 373 374 static HalVersion getHalVersion(const V1_1::Operation&) { 375 return HalVersion::V1_1; 376 } 377 378 static HalVersion getHalVersion(const V1_2::Operation&) { 379 return HalVersion::V1_2; 380 } 381 382 template <typename VersionedOperation> 383 static bool validateOperations(const hidl_vec<VersionedOperation>& operations, 384 const hidl_vec<Operand>& operands) { 385 const size_t operandCount = operands.size(); 386 // This vector keeps track of whether there's an operation that writes to 387 // each operand. It is used to validate that temporary variables and 388 // model outputs will be written to. 389 std::vector<bool> writtenTo(operandCount, false); 390 for (auto& op : operations) { 391 // TODO Validate the shapes and any known values. This is currently 392 // done in CpuExecutor but should be done here for all drivers. 393 int error = validateOperation( 394 static_cast<int32_t>(op.type), op.inputs.size(), 395 op.inputs.size() > 0 ? op.inputs.data() : nullptr, op.outputs.size(), 396 op.outputs.size() > 0 ? op.outputs.data() : nullptr, operands, getHalVersion(op)); 397 if (error != ANEURALNETWORKS_NO_ERROR) { 398 LOG(ERROR) << "Invalid operation " << toString(op.type); 399 return false; 400 } 401 402 for (uint32_t i : op.outputs) { 403 const Operand& operand = operands[i]; 404 if (operand.lifetime != OperandLifeTime::TEMPORARY_VARIABLE && 405 operand.lifetime != OperandLifeTime::MODEL_OUTPUT) { 406 LOG(ERROR) << "Writing to an operand with incompatible lifetime " 407 << toString(operand.lifetime); 408 return false; 409 } 410 411 // Check that we only write once to an operand. 412 if (writtenTo[i]) { 413 LOG(ERROR) << "Operand " << i << " written a second time"; 414 return false; 415 } 416 writtenTo[i] = true; 417 } 418 } 419 for (size_t i = 0; i < operandCount; i++) { 420 if (!writtenTo[i]) { 421 const Operand& operand = operands[i]; 422 if (operand.lifetime == OperandLifeTime::TEMPORARY_VARIABLE || 423 operand.lifetime == OperandLifeTime::MODEL_OUTPUT) { 424 LOG(ERROR) << "Operand " << i << " with lifetime " << toString(operand.lifetime) 425 << " is not being written to."; 426 return false; 427 } 428 } 429 } 430 // TODO More whole graph verifications are possible, for example that an 431 // operand is not use as input & output for the same op, and more 432 // generally that it is acyclic. 433 return true; 434 } 435 436 bool validatePool(const hidl_memory& pool, HalVersion ver) { 437 const auto& name = pool.name(); 438 if (name != "ashmem" && name != "mmap_fd" && 439 ((ver < HalVersion::V1_2) || 440 (name != "hardware_buffer_blob" && name != "hardware_buffer"))) { 441 LOG(ERROR) << "Unsupported memory type " << name; 442 return false; 443 } 444 if (pool.handle() == nullptr) { 445 LOG(ERROR) << "Memory of type " << name << " is null"; 446 return false; 447 } 448 return true; 449 } 450 451 static bool validatePools(const hidl_vec<hidl_memory>& pools, HalVersion ver) { 452 return std::all_of(pools.begin(), pools.end(), 453 [ver](const hidl_memory& pool) { return validatePool(pool, ver); }); 454 } 455 456 static bool validateModelInputOutputs(const hidl_vec<uint32_t> indexes, 457 const hidl_vec<Operand>& operands, OperandLifeTime lifetime) { 458 const size_t operandCount = operands.size(); 459 for (uint32_t i : indexes) { 460 if (i >= operandCount) { 461 LOG(ERROR) << "Model input or output index out of range: " << i << "/" << operandCount; 462 return false; 463 } 464 const Operand& operand = operands[i]; 465 if (operand.lifetime != lifetime) { 466 LOG(ERROR) << "Model input or output has lifetime of " << toString(operand.lifetime) 467 << " instead of the expected " << toString(lifetime); 468 return false; 469 } 470 } 471 472 std::vector<uint32_t> sortedIndexes = indexes; 473 std::sort(sortedIndexes.begin(), sortedIndexes.end()); 474 auto adjacentI = std::adjacent_find(sortedIndexes.begin(), sortedIndexes.end()); 475 if (adjacentI != sortedIndexes.end()) { 476 LOG(ERROR) << "Model input or output occurs multiple times: " << *adjacentI; 477 return false; 478 } 479 return true; 480 } 481 482 template <class T_Model> 483 bool validateModel(const T_Model& model) { 484 NNTRACE_FULL(NNTRACE_LAYER_UTILITY, NNTRACE_PHASE_UNSPECIFIED, "validateModel"); 485 HalVersion version = ModelToHalVersion<T_Model>::version; 486 if (model.operations.size() == 0 || model.operands.size() == 0) { 487 LOG(ERROR) << "Invalid empty model."; 488 return false; 489 } 490 // We only need versioned operands for their validation. For all the other 491 // validations we can use operands upcasted to the latest version. 492 const hidl_vec<Operand> latestVersionOperands = convertToV1_2(model.operands); 493 return (validateOperands(model.operands, model.operandValues, model.pools, 494 /*allowUnspecifiedRank=*/version >= HalVersion::V1_2) && 495 validateOperations(model.operations, latestVersionOperands) && 496 validateModelInputOutputs(model.inputIndexes, latestVersionOperands, 497 OperandLifeTime::MODEL_INPUT) && 498 validateModelInputOutputs(model.outputIndexes, latestVersionOperands, 499 OperandLifeTime::MODEL_OUTPUT) && 500 validatePools(model.pools, version)); 501 } 502 503 template bool validateModel<V1_0::Model>(const V1_0::Model& model); 504 template bool validateModel<V1_1::Model>(const V1_1::Model& model); 505 template bool validateModel<V1_2::Model>(const V1_2::Model& model); 506 507 // Validates the arguments of a request. type is either "input" or "output" and is used 508 // for printing error messages. The operandIndexes is the appropriate array of input 509 // or output operand indexes that was passed to the ANeuralNetworksModel_identifyInputsAndOutputs. 510 static bool validateRequestArguments(const hidl_vec<RequestArgument>& requestArguments, 511 const hidl_vec<uint32_t>& operandIndexes, 512 const hidl_vec<Operand>& operands, 513 const hidl_vec<hidl_memory>& pools, bool allowUnspecified, 514 const char* type) { 515 MemoryAccessVerifier poolVerifier(pools); 516 // The request should specify as many arguments as were described in the model. 517 const size_t requestArgumentCount = requestArguments.size(); 518 if (requestArgumentCount != operandIndexes.size()) { 519 LOG(ERROR) << "Request specifies " << requestArgumentCount << " " << type 520 << "s but the model has " << operandIndexes.size(); 521 return false; 522 } 523 for (size_t requestArgumentIndex = 0; requestArgumentIndex < requestArgumentCount; 524 requestArgumentIndex++) { 525 const RequestArgument& requestArgument = requestArguments[requestArgumentIndex]; 526 const DataLocation& location = requestArgument.location; 527 // Get the operand index for this argument. We extract it from the list 528 // that was provided in the call to ANeuralNetworksModel_identifyInputsAndOutputs. 529 // We assume in this function that the model has been validated already. 530 const uint32_t operandIndex = operandIndexes[requestArgumentIndex]; 531 const Operand& operand = operands[operandIndex]; 532 if (requestArgument.hasNoValue) { 533 if (location.poolIndex != 0 || location.offset != 0 || location.length != 0 || 534 requestArgument.dimensions.size() != 0) { 535 LOG(ERROR) << "Request " << type << " " << requestArgumentIndex 536 << " has no value yet has details."; 537 return false; 538 } 539 } else { 540 // Validate the location. 541 if (!poolVerifier.validate(location)) { 542 return false; 543 } 544 // If the argument specified a dimension, validate it. 545 uint32_t rank = requestArgument.dimensions.size(); 546 if (rank == 0) { 547 if (!allowUnspecified) { 548 // Validate that all the dimensions are specified in the model. 549 for (size_t i = 0; i < operand.dimensions.size(); i++) { 550 if (operand.dimensions[i] == 0) { 551 LOG(ERROR) << "Model has dimension " << i 552 << " set to 0 but the request does specify the dimension."; 553 return false; 554 } 555 } 556 } 557 } else { 558 if (rank != operand.dimensions.size()) { 559 LOG(ERROR) << "Request " << type << " " << requestArgumentIndex 560 << " has number of dimensions (" << rank 561 << ") different than the model's (" << operand.dimensions.size() 562 << ")"; 563 return false; 564 } 565 for (size_t i = 0; i < rank; i++) { 566 if (requestArgument.dimensions[i] != operand.dimensions[i] && 567 operand.dimensions[i] != 0) { 568 LOG(ERROR) << "Request " << type << " " << requestArgumentIndex 569 << " has dimension " << i << " of " 570 << requestArgument.dimensions[i] 571 << " different than the model's " << operand.dimensions[i]; 572 return false; 573 } 574 if (requestArgument.dimensions[i] == 0 && !allowUnspecified) { 575 LOG(ERROR) << "Request " << type << " " << requestArgumentIndex 576 << " has dimension " << i << " of zero"; 577 return false; 578 } 579 } 580 } 581 } 582 } 583 return true; 584 } 585 586 template <class T_Model> 587 bool validateRequest(const Request& request, const T_Model& model) { 588 HalVersion version = ModelToHalVersion<T_Model>::version; 589 return (validateRequestArguments(request.inputs, model.inputIndexes, 590 convertToV1_2(model.operands), request.pools, 591 /*allowUnspecified=*/false, "input") && 592 validateRequestArguments(request.outputs, model.outputIndexes, 593 convertToV1_2(model.operands), request.pools, 594 /*allowUnspecified=*/version >= HalVersion::V1_2, "output") && 595 validatePools(request.pools, version)); 596 } 597 598 template bool validateRequest<V1_0::Model>(const Request& request, const V1_0::Model& model); 599 template bool validateRequest<V1_1::Model>(const Request& request, const V1_1::Model& model); 600 template bool validateRequest<V1_2::Model>(const Request& request, const V1_2::Model& model); 601 602 bool validateExecutionPreference(ExecutionPreference preference) { 603 return preference == ExecutionPreference::LOW_POWER || 604 preference == ExecutionPreference::FAST_SINGLE_ANSWER || 605 preference == ExecutionPreference::SUSTAINED_SPEED; 606 } 607 608 bool validOperandType(V1_0::OperandType operandType) { 609 switch (operandType) { 610 case V1_0::OperandType::FLOAT32: 611 case V1_0::OperandType::INT32: 612 case V1_0::OperandType::UINT32: 613 case V1_0::OperandType::TENSOR_FLOAT32: 614 case V1_0::OperandType::TENSOR_INT32: 615 case V1_0::OperandType::TENSOR_QUANT8_ASYMM: 616 case V1_0::OperandType::OEM: 617 case V1_0::OperandType::TENSOR_OEM_BYTE: 618 return true; 619 default: 620 return false; 621 } 622 } 623 624 bool validOperandType(V1_2::OperandType operandType) { 625 switch (operandType) { 626 case V1_2::OperandType::FLOAT16: 627 case V1_2::OperandType::FLOAT32: 628 case V1_2::OperandType::INT32: 629 case V1_2::OperandType::UINT32: 630 case V1_2::OperandType::BOOL: 631 case V1_2::OperandType::TENSOR_FLOAT16: 632 case V1_2::OperandType::TENSOR_FLOAT32: 633 case V1_2::OperandType::TENSOR_INT32: 634 case V1_2::OperandType::TENSOR_QUANT8_ASYMM: 635 case V1_2::OperandType::TENSOR_QUANT8_SYMM: 636 case V1_2::OperandType::TENSOR_QUANT16_ASYMM: 637 case V1_2::OperandType::TENSOR_QUANT16_SYMM: 638 case V1_2::OperandType::TENSOR_BOOL8: 639 case V1_2::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: 640 case V1_2::OperandType::OEM: 641 case V1_2::OperandType::TENSOR_OEM_BYTE: 642 return true; 643 default: 644 return isExtensionOperandType(operandType); 645 } 646 } 647 648 } // namespace nn 649 } // namespace android 650