1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #define LOG_TAG "neuralnetworks_hidl_hal_test" 18 19 #include "VtsHalNeuralnetworks.h" 20 21 #include "Callbacks.h" 22 23 namespace android { 24 namespace hardware { 25 namespace neuralnetworks { 26 namespace V1_0 { 27 namespace vts { 28 namespace functional { 29 30 using ::android::hardware::neuralnetworks::V1_2::implementation::ExecutionCallback; 31 using ::android::hardware::neuralnetworks::V1_2::implementation::PreparedModelCallback; 32 33 ///////////////////////// UTILITY FUNCTIONS ///////////////////////// 34 35 static void validateGetSupportedOperations(const sp<IDevice>& device, const std::string& message, 36 const V1_0::Model& model) { 37 SCOPED_TRACE(message + " [getSupportedOperations]"); 38 39 Return<void> ret = 40 device->getSupportedOperations(model, [&](ErrorStatus status, const hidl_vec<bool>&) { 41 EXPECT_EQ(ErrorStatus::INVALID_ARGUMENT, status); 42 }); 43 EXPECT_TRUE(ret.isOk()); 44 } 45 46 static void validatePrepareModel(const sp<IDevice>& device, const std::string& message, 47 const V1_0::Model& model) { 48 SCOPED_TRACE(message + " [prepareModel]"); 49 50 sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback(); 51 ASSERT_NE(nullptr, preparedModelCallback.get()); 52 Return<ErrorStatus> prepareLaunchStatus = device->prepareModel(model, preparedModelCallback); 53 ASSERT_TRUE(prepareLaunchStatus.isOk()); 54 ASSERT_EQ(ErrorStatus::INVALID_ARGUMENT, static_cast<ErrorStatus>(prepareLaunchStatus)); 55 56 preparedModelCallback->wait(); 57 ErrorStatus prepareReturnStatus = preparedModelCallback->getStatus(); 58 ASSERT_EQ(ErrorStatus::INVALID_ARGUMENT, prepareReturnStatus); 59 sp<IPreparedModel> preparedModel = preparedModelCallback->getPreparedModel(); 60 ASSERT_EQ(nullptr, preparedModel.get()); 61 } 62 63 // Primary validation function. This function will take a valid model, apply a 64 // mutation to it to invalidate the model, then pass it to interface calls that 65 // use the model. Note that the model here is passed by value, and any mutation 66 // to the model does not leave this function. 67 static void validate(const sp<IDevice>& device, const std::string& message, V1_0::Model model, 68 const std::function<void(Model*)>& mutation) { 69 mutation(&model); 70 validateGetSupportedOperations(device, message, model); 71 validatePrepareModel(device, message, model); 72 } 73 74 // Delete element from hidl_vec. hidl_vec doesn't support a "remove" operation, 75 // so this is efficiently accomplished by moving the element to the end and 76 // resizing the hidl_vec to one less. 77 template <typename Type> 78 static void hidl_vec_removeAt(hidl_vec<Type>* vec, uint32_t index) { 79 if (vec) { 80 std::rotate(vec->begin() + index, vec->begin() + index + 1, vec->end()); 81 vec->resize(vec->size() - 1); 82 } 83 } 84 85 template <typename Type> 86 static uint32_t hidl_vec_push_back(hidl_vec<Type>* vec, const Type& value) { 87 // assume vec is valid 88 const uint32_t index = vec->size(); 89 vec->resize(index + 1); 90 (*vec)[index] = value; 91 return index; 92 } 93 94 static uint32_t addOperand(Model* model) { 95 return hidl_vec_push_back(&model->operands, 96 { 97 .type = OperandType::INT32, 98 .dimensions = {}, 99 .numberOfConsumers = 0, 100 .scale = 0.0f, 101 .zeroPoint = 0, 102 .lifetime = OperandLifeTime::MODEL_INPUT, 103 .location = {.poolIndex = 0, .offset = 0, .length = 0}, 104 }); 105 } 106 107 static uint32_t addOperand(Model* model, OperandLifeTime lifetime) { 108 uint32_t index = addOperand(model); 109 model->operands[index].numberOfConsumers = 1; 110 model->operands[index].lifetime = lifetime; 111 return index; 112 } 113 114 ///////////////////////// VALIDATE MODEL OPERAND TYPE ///////////////////////// 115 116 static const int32_t invalidOperandTypes[] = { 117 static_cast<int32_t>(OperandType::FLOAT32) - 1, // lower bound fundamental 118 static_cast<int32_t>(OperandType::TENSOR_QUANT8_ASYMM) + 1, // upper bound fundamental 119 static_cast<int32_t>(OperandType::OEM) - 1, // lower bound OEM 120 static_cast<int32_t>(OperandType::TENSOR_OEM_BYTE) + 1, // upper bound OEM 121 }; 122 123 static void mutateOperandTypeTest(const sp<IDevice>& device, const V1_0::Model& model) { 124 for (size_t operand = 0; operand < model.operands.size(); ++operand) { 125 for (int32_t invalidOperandType : invalidOperandTypes) { 126 const std::string message = "mutateOperandTypeTest: operand " + 127 std::to_string(operand) + " set to value " + 128 std::to_string(invalidOperandType); 129 validate(device, message, model, [operand, invalidOperandType](Model* model) { 130 model->operands[operand].type = static_cast<OperandType>(invalidOperandType); 131 }); 132 } 133 } 134 } 135 136 ///////////////////////// VALIDATE OPERAND RANK ///////////////////////// 137 138 static uint32_t getInvalidRank(OperandType type) { 139 switch (type) { 140 case OperandType::FLOAT32: 141 case OperandType::INT32: 142 case OperandType::UINT32: 143 return 1; 144 case OperandType::TENSOR_FLOAT32: 145 case OperandType::TENSOR_INT32: 146 case OperandType::TENSOR_QUANT8_ASYMM: 147 return 0; 148 default: 149 return 0; 150 } 151 } 152 153 static void mutateOperandRankTest(const sp<IDevice>& device, const V1_0::Model& model) { 154 for (size_t operand = 0; operand < model.operands.size(); ++operand) { 155 const uint32_t invalidRank = getInvalidRank(model.operands[operand].type); 156 const std::string message = "mutateOperandRankTest: operand " + std::to_string(operand) + 157 " has rank of " + std::to_string(invalidRank); 158 validate(device, message, model, [operand, invalidRank](Model* model) { 159 model->operands[operand].dimensions = std::vector<uint32_t>(invalidRank, 0); 160 }); 161 } 162 } 163 164 ///////////////////////// VALIDATE OPERAND SCALE ///////////////////////// 165 166 static float getInvalidScale(OperandType type) { 167 switch (type) { 168 case OperandType::FLOAT32: 169 case OperandType::INT32: 170 case OperandType::UINT32: 171 case OperandType::TENSOR_FLOAT32: 172 return 1.0f; 173 case OperandType::TENSOR_INT32: 174 return -1.0f; 175 case OperandType::TENSOR_QUANT8_ASYMM: 176 return 0.0f; 177 default: 178 return 0.0f; 179 } 180 } 181 182 static void mutateOperandScaleTest(const sp<IDevice>& device, const V1_0::Model& model) { 183 for (size_t operand = 0; operand < model.operands.size(); ++operand) { 184 const float invalidScale = getInvalidScale(model.operands[operand].type); 185 const std::string message = "mutateOperandScaleTest: operand " + std::to_string(operand) + 186 " has scale of " + std::to_string(invalidScale); 187 validate(device, message, model, [operand, invalidScale](Model* model) { 188 model->operands[operand].scale = invalidScale; 189 }); 190 } 191 } 192 193 ///////////////////////// VALIDATE OPERAND ZERO POINT ///////////////////////// 194 195 static std::vector<int32_t> getInvalidZeroPoints(OperandType type) { 196 switch (type) { 197 case OperandType::FLOAT32: 198 case OperandType::INT32: 199 case OperandType::UINT32: 200 case OperandType::TENSOR_FLOAT32: 201 case OperandType::TENSOR_INT32: 202 return {1}; 203 case OperandType::TENSOR_QUANT8_ASYMM: 204 return {-1, 256}; 205 default: 206 return {}; 207 } 208 } 209 210 static void mutateOperandZeroPointTest(const sp<IDevice>& device, const V1_0::Model& model) { 211 for (size_t operand = 0; operand < model.operands.size(); ++operand) { 212 const std::vector<int32_t> invalidZeroPoints = 213 getInvalidZeroPoints(model.operands[operand].type); 214 for (int32_t invalidZeroPoint : invalidZeroPoints) { 215 const std::string message = "mutateOperandZeroPointTest: operand " + 216 std::to_string(operand) + " has zero point of " + 217 std::to_string(invalidZeroPoint); 218 validate(device, message, model, [operand, invalidZeroPoint](Model* model) { 219 model->operands[operand].zeroPoint = invalidZeroPoint; 220 }); 221 } 222 } 223 } 224 225 ///////////////////////// VALIDATE EXTRA ??? ///////////////////////// 226 227 // TODO: Operand::lifetime 228 // TODO: Operand::location 229 230 ///////////////////////// VALIDATE OPERATION OPERAND TYPE ///////////////////////// 231 232 static void mutateOperand(Operand* operand, OperandType type) { 233 Operand newOperand = *operand; 234 newOperand.type = type; 235 switch (type) { 236 case OperandType::FLOAT32: 237 case OperandType::INT32: 238 case OperandType::UINT32: 239 newOperand.dimensions = hidl_vec<uint32_t>(); 240 newOperand.scale = 0.0f; 241 newOperand.zeroPoint = 0; 242 break; 243 case OperandType::TENSOR_FLOAT32: 244 newOperand.dimensions = 245 operand->dimensions.size() > 0 ? operand->dimensions : hidl_vec<uint32_t>({1}); 246 newOperand.scale = 0.0f; 247 newOperand.zeroPoint = 0; 248 break; 249 case OperandType::TENSOR_INT32: 250 newOperand.dimensions = 251 operand->dimensions.size() > 0 ? operand->dimensions : hidl_vec<uint32_t>({1}); 252 newOperand.zeroPoint = 0; 253 break; 254 case OperandType::TENSOR_QUANT8_ASYMM: 255 newOperand.dimensions = 256 operand->dimensions.size() > 0 ? operand->dimensions : hidl_vec<uint32_t>({1}); 257 newOperand.scale = operand->scale != 0.0f ? operand->scale : 1.0f; 258 break; 259 case OperandType::OEM: 260 case OperandType::TENSOR_OEM_BYTE: 261 default: 262 break; 263 } 264 *operand = newOperand; 265 } 266 267 static bool mutateOperationOperandTypeSkip(size_t operand, const V1_0::Model& model) { 268 // LSH_PROJECTION's second argument is allowed to have any type. This is the 269 // only operation that currently has a type that can be anything independent 270 // from any other type. Changing the operand type to any other type will 271 // result in a valid model for LSH_PROJECTION. If this is the case, skip the 272 // test. 273 for (const Operation& operation : model.operations) { 274 if (operation.type == OperationType::LSH_PROJECTION && operand == operation.inputs[1]) { 275 return true; 276 } 277 } 278 return false; 279 } 280 281 static void mutateOperationOperandTypeTest(const sp<IDevice>& device, const V1_0::Model& model) { 282 for (size_t operand = 0; operand < model.operands.size(); ++operand) { 283 if (mutateOperationOperandTypeSkip(operand, model)) { 284 continue; 285 } 286 for (OperandType invalidOperandType : hidl_enum_range<OperandType>{}) { 287 // Do not test OEM types 288 if (invalidOperandType == model.operands[operand].type || 289 invalidOperandType == OperandType::OEM || 290 invalidOperandType == OperandType::TENSOR_OEM_BYTE) { 291 continue; 292 } 293 const std::string message = "mutateOperationOperandTypeTest: operand " + 294 std::to_string(operand) + " set to type " + 295 toString(invalidOperandType); 296 validate(device, message, model, [operand, invalidOperandType](Model* model) { 297 mutateOperand(&model->operands[operand], invalidOperandType); 298 }); 299 } 300 } 301 } 302 303 ///////////////////////// VALIDATE MODEL OPERATION TYPE ///////////////////////// 304 305 static const int32_t invalidOperationTypes[] = { 306 static_cast<int32_t>(OperationType::ADD) - 1, // lower bound fundamental 307 static_cast<int32_t>(OperationType::TANH) + 1, // upper bound fundamental 308 static_cast<int32_t>(OperationType::OEM_OPERATION) - 1, // lower bound OEM 309 static_cast<int32_t>(OperationType::OEM_OPERATION) + 1, // upper bound OEM 310 }; 311 312 static void mutateOperationTypeTest(const sp<IDevice>& device, const V1_0::Model& model) { 313 for (size_t operation = 0; operation < model.operations.size(); ++operation) { 314 for (int32_t invalidOperationType : invalidOperationTypes) { 315 const std::string message = "mutateOperationTypeTest: operation " + 316 std::to_string(operation) + " set to value " + 317 std::to_string(invalidOperationType); 318 validate(device, message, model, [operation, invalidOperationType](Model* model) { 319 model->operations[operation].type = 320 static_cast<OperationType>(invalidOperationType); 321 }); 322 } 323 } 324 } 325 326 ///////////////////////// VALIDATE MODEL OPERATION INPUT OPERAND INDEX ///////////////////////// 327 328 static void mutateOperationInputOperandIndexTest(const sp<IDevice>& device, 329 const V1_0::Model& model) { 330 for (size_t operation = 0; operation < model.operations.size(); ++operation) { 331 const uint32_t invalidOperand = model.operands.size(); 332 for (size_t input = 0; input < model.operations[operation].inputs.size(); ++input) { 333 const std::string message = "mutateOperationInputOperandIndexTest: operation " + 334 std::to_string(operation) + " input " + 335 std::to_string(input); 336 validate(device, message, model, [operation, input, invalidOperand](Model* model) { 337 model->operations[operation].inputs[input] = invalidOperand; 338 }); 339 } 340 } 341 } 342 343 ///////////////////////// VALIDATE MODEL OPERATION OUTPUT OPERAND INDEX ///////////////////////// 344 345 static void mutateOperationOutputOperandIndexTest(const sp<IDevice>& device, 346 const V1_0::Model& model) { 347 for (size_t operation = 0; operation < model.operations.size(); ++operation) { 348 const uint32_t invalidOperand = model.operands.size(); 349 for (size_t output = 0; output < model.operations[operation].outputs.size(); ++output) { 350 const std::string message = "mutateOperationOutputOperandIndexTest: operation " + 351 std::to_string(operation) + " output " + 352 std::to_string(output); 353 validate(device, message, model, [operation, output, invalidOperand](Model* model) { 354 model->operations[operation].outputs[output] = invalidOperand; 355 }); 356 } 357 } 358 } 359 360 ///////////////////////// REMOVE OPERAND FROM EVERYTHING ///////////////////////// 361 362 static void removeValueAndDecrementGreaterValues(hidl_vec<uint32_t>* vec, uint32_t value) { 363 if (vec) { 364 // remove elements matching "value" 365 auto last = std::remove(vec->begin(), vec->end(), value); 366 vec->resize(std::distance(vec->begin(), last)); 367 368 // decrement elements exceeding "value" 369 std::transform(vec->begin(), vec->end(), vec->begin(), 370 [value](uint32_t v) { return v > value ? v-- : v; }); 371 } 372 } 373 374 static void removeOperand(Model* model, uint32_t index) { 375 hidl_vec_removeAt(&model->operands, index); 376 for (Operation& operation : model->operations) { 377 removeValueAndDecrementGreaterValues(&operation.inputs, index); 378 removeValueAndDecrementGreaterValues(&operation.outputs, index); 379 } 380 removeValueAndDecrementGreaterValues(&model->inputIndexes, index); 381 removeValueAndDecrementGreaterValues(&model->outputIndexes, index); 382 } 383 384 static void removeOperandTest(const sp<IDevice>& device, const V1_0::Model& model) { 385 for (size_t operand = 0; operand < model.operands.size(); ++operand) { 386 const std::string message = "removeOperandTest: operand " + std::to_string(operand); 387 validate(device, message, model, 388 [operand](Model* model) { removeOperand(model, operand); }); 389 } 390 } 391 392 ///////////////////////// REMOVE OPERATION ///////////////////////// 393 394 static void removeOperation(Model* model, uint32_t index) { 395 for (uint32_t operand : model->operations[index].inputs) { 396 model->operands[operand].numberOfConsumers--; 397 } 398 hidl_vec_removeAt(&model->operations, index); 399 } 400 401 static void removeOperationTest(const sp<IDevice>& device, const V1_0::Model& model) { 402 for (size_t operation = 0; operation < model.operations.size(); ++operation) { 403 const std::string message = "removeOperationTest: operation " + std::to_string(operation); 404 validate(device, message, model, 405 [operation](Model* model) { removeOperation(model, operation); }); 406 } 407 } 408 409 ///////////////////////// REMOVE OPERATION INPUT ///////////////////////// 410 411 static void removeOperationInputTest(const sp<IDevice>& device, const V1_0::Model& model) { 412 for (size_t operation = 0; operation < model.operations.size(); ++operation) { 413 for (size_t input = 0; input < model.operations[operation].inputs.size(); ++input) { 414 const V1_0::Operation& op = model.operations[operation]; 415 // CONCATENATION has at least 2 inputs, with the last element being 416 // INT32. Skip this test if removing one of CONCATENATION's 417 // inputs still produces a valid model. 418 if (op.type == V1_0::OperationType::CONCATENATION && op.inputs.size() > 2 && 419 input != op.inputs.size() - 1) { 420 continue; 421 } 422 const std::string message = "removeOperationInputTest: operation " + 423 std::to_string(operation) + ", input " + 424 std::to_string(input); 425 validate(device, message, model, [operation, input](Model* model) { 426 uint32_t operand = model->operations[operation].inputs[input]; 427 model->operands[operand].numberOfConsumers--; 428 hidl_vec_removeAt(&model->operations[operation].inputs, input); 429 }); 430 } 431 } 432 } 433 434 ///////////////////////// REMOVE OPERATION OUTPUT ///////////////////////// 435 436 static void removeOperationOutputTest(const sp<IDevice>& device, const V1_0::Model& model) { 437 for (size_t operation = 0; operation < model.operations.size(); ++operation) { 438 for (size_t output = 0; output < model.operations[operation].outputs.size(); ++output) { 439 const std::string message = "removeOperationOutputTest: operation " + 440 std::to_string(operation) + ", output " + 441 std::to_string(output); 442 validate(device, message, model, [operation, output](Model* model) { 443 hidl_vec_removeAt(&model->operations[operation].outputs, output); 444 }); 445 } 446 } 447 } 448 449 ///////////////////////// MODEL VALIDATION ///////////////////////// 450 451 // TODO: remove model input 452 // TODO: remove model output 453 // TODO: add unused operation 454 455 ///////////////////////// ADD OPERATION INPUT ///////////////////////// 456 457 static void addOperationInputTest(const sp<IDevice>& device, const V1_0::Model& model) { 458 for (size_t operation = 0; operation < model.operations.size(); ++operation) { 459 const std::string message = "addOperationInputTest: operation " + std::to_string(operation); 460 validate(device, message, model, [operation](Model* model) { 461 uint32_t index = addOperand(model, OperandLifeTime::MODEL_INPUT); 462 hidl_vec_push_back(&model->operations[operation].inputs, index); 463 hidl_vec_push_back(&model->inputIndexes, index); 464 }); 465 } 466 } 467 468 ///////////////////////// ADD OPERATION OUTPUT ///////////////////////// 469 470 static void addOperationOutputTest(const sp<IDevice>& device, const V1_0::Model& model) { 471 for (size_t operation = 0; operation < model.operations.size(); ++operation) { 472 const std::string message = 473 "addOperationOutputTest: operation " + std::to_string(operation); 474 validate(device, message, model, [operation](Model* model) { 475 uint32_t index = addOperand(model, OperandLifeTime::MODEL_OUTPUT); 476 hidl_vec_push_back(&model->operations[operation].outputs, index); 477 hidl_vec_push_back(&model->outputIndexes, index); 478 }); 479 } 480 } 481 482 ////////////////////////// ENTRY POINT ////////////////////////////// 483 484 void ValidationTest::validateModel(const V1_0::Model& model) { 485 mutateOperandTypeTest(device, model); 486 mutateOperandRankTest(device, model); 487 mutateOperandScaleTest(device, model); 488 mutateOperandZeroPointTest(device, model); 489 mutateOperationOperandTypeTest(device, model); 490 mutateOperationTypeTest(device, model); 491 mutateOperationInputOperandIndexTest(device, model); 492 mutateOperationOutputOperandIndexTest(device, model); 493 removeOperandTest(device, model); 494 removeOperationTest(device, model); 495 removeOperationInputTest(device, model); 496 removeOperationOutputTest(device, model); 497 addOperationInputTest(device, model); 498 addOperationOutputTest(device, model); 499 } 500 501 } // namespace functional 502 } // namespace vts 503 } // namespace V1_0 504 } // namespace neuralnetworks 505 } // namespace hardware 506 } // namespace android 507