Home | History | Annotate | Download | only in functional
      1 /*
      2  * Copyright (C) 2019 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #define LOG_TAG "neuralnetworks_hidl_hal_test"
     18 
     19 #include <android-base/logging.h>
     20 #include <android/hidl/memory/1.0/IMemory.h>
     21 #include <ftw.h>
     22 #include <gtest/gtest.h>
     23 #include <hidlmemory/mapping.h>
     24 #include <unistd.h>
     25 
     26 #include <cstdio>
     27 #include <cstdlib>
     28 #include <random>
     29 
     30 #include "Callbacks.h"
     31 #include "GeneratedTestHarness.h"
     32 #include "TestHarness.h"
     33 #include "Utils.h"
     34 #include "VtsHalNeuralnetworks.h"
     35 
     36 namespace android {
     37 namespace hardware {
     38 namespace neuralnetworks {
     39 namespace V1_2 {
     40 namespace vts {
     41 namespace functional {
     42 
     43 using ::android::hardware::neuralnetworks::V1_2::implementation::ExecutionCallback;
     44 using ::android::hardware::neuralnetworks::V1_2::implementation::PreparedModelCallback;
     45 using ::android::nn::allocateSharedMemory;
     46 using ::test_helper::MixedTypedExample;
     47 
     48 namespace float32_model {
     49 
     50 // In frameworks/ml/nn/runtime/test/generated/, creates a hidl model of float32 mobilenet.
     51 #include "examples/mobilenet_224_gender_basic_fixed.example.cpp"
     52 #include "vts_models/mobilenet_224_gender_basic_fixed.model.cpp"
     53 
     54 // Prevent the compiler from complaining about an otherwise unused function.
     55 [[maybe_unused]] auto dummy_createTestModel = createTestModel_dynamic_output_shape;
     56 [[maybe_unused]] auto dummy_get_examples = get_examples_dynamic_output_shape;
     57 
     58 // MixedTypedExample is defined in frameworks/ml/nn/tools/test_generator/include/TestHarness.h.
     59 // This function assumes the operation is always ADD.
     60 std::vector<MixedTypedExample> getLargeModelExamples(uint32_t len) {
     61     float outputValue = 1.0f + static_cast<float>(len);
     62     return {{.operands = {
     63                      // Input
     64                      {.operandDimensions = {{0, {1}}}, .float32Operands = {{0, {1.0f}}}},
     65                      // Output
     66                      {.operandDimensions = {{0, {1}}}, .float32Operands = {{0, {outputValue}}}}}}};
     67 }
     68 
     69 }  // namespace float32_model
     70 
     71 namespace quant8_model {
     72 
     73 // In frameworks/ml/nn/runtime/test/generated/, creates a hidl model of quant8 mobilenet.
     74 #include "examples/mobilenet_quantized.example.cpp"
     75 #include "vts_models/mobilenet_quantized.model.cpp"
     76 
     77 // Prevent the compiler from complaining about an otherwise unused function.
     78 [[maybe_unused]] auto dummy_createTestModel = createTestModel_dynamic_output_shape;
     79 [[maybe_unused]] auto dummy_get_examples = get_examples_dynamic_output_shape;
     80 
     81 // MixedTypedExample is defined in frameworks/ml/nn/tools/test_generator/include/TestHarness.h.
     82 // This function assumes the operation is always ADD.
     83 std::vector<MixedTypedExample> getLargeModelExamples(uint32_t len) {
     84     uint8_t outputValue = 1 + static_cast<uint8_t>(len);
     85     return {{.operands = {// Input
     86                           {.operandDimensions = {{0, {1}}}, .quant8AsymmOperands = {{0, {1}}}},
     87                           // Output
     88                           {.operandDimensions = {{0, {1}}},
     89                            .quant8AsymmOperands = {{0, {outputValue}}}}}}};
     90 }
     91 
     92 }  // namespace quant8_model
     93 
     94 namespace {
     95 
     96 enum class AccessMode { READ_WRITE, READ_ONLY, WRITE_ONLY };
     97 
     98 // Creates cache handles based on provided file groups.
     99 // The outer vector corresponds to handles and the inner vector is for fds held by each handle.
    100 void createCacheHandles(const std::vector<std::vector<std::string>>& fileGroups,
    101                         const std::vector<AccessMode>& mode, hidl_vec<hidl_handle>* handles) {
    102     handles->resize(fileGroups.size());
    103     for (uint32_t i = 0; i < fileGroups.size(); i++) {
    104         std::vector<int> fds;
    105         for (const auto& file : fileGroups[i]) {
    106             int fd;
    107             if (mode[i] == AccessMode::READ_ONLY) {
    108                 fd = open(file.c_str(), O_RDONLY);
    109             } else if (mode[i] == AccessMode::WRITE_ONLY) {
    110                 fd = open(file.c_str(), O_WRONLY | O_CREAT, S_IRUSR | S_IWUSR);
    111             } else if (mode[i] == AccessMode::READ_WRITE) {
    112                 fd = open(file.c_str(), O_RDWR | O_CREAT, S_IRUSR | S_IWUSR);
    113             } else {
    114                 FAIL();
    115             }
    116             ASSERT_GE(fd, 0);
    117             fds.push_back(fd);
    118         }
    119         native_handle_t* cacheNativeHandle = native_handle_create(fds.size(), 0);
    120         ASSERT_NE(cacheNativeHandle, nullptr);
    121         std::copy(fds.begin(), fds.end(), &cacheNativeHandle->data[0]);
    122         (*handles)[i].setTo(cacheNativeHandle, /*shouldOwn=*/true);
    123     }
    124 }
    125 
    126 void createCacheHandles(const std::vector<std::vector<std::string>>& fileGroups, AccessMode mode,
    127                         hidl_vec<hidl_handle>* handles) {
    128     createCacheHandles(fileGroups, std::vector<AccessMode>(fileGroups.size(), mode), handles);
    129 }
    130 
    131 // Create a chain of broadcast operations. The second operand is always constant tensor [1].
    132 // For simplicity, activation scalar is shared. The second operand is not shared
    133 // in the model to let driver maintain a non-trivial size of constant data and the corresponding
    134 // data locations in cache.
    135 //
    136 //                --------- activation --------
    137 //                                         
    138 // E.g. input -> ADD -> ADD -> ADD -> ... -> ADD -> output
    139 //                                         
    140 //               [1]    [1]    [1]           [1]
    141 //
    142 // This function assumes the operation is either ADD or MUL.
    143 template <typename CppType, OperandType operandType>
    144 Model createLargeTestModelImpl(OperationType op, uint32_t len) {
    145     EXPECT_TRUE(op == OperationType::ADD || op == OperationType::MUL);
    146 
    147     // Model operations and operands.
    148     std::vector<Operation> operations(len);
    149     std::vector<Operand> operands(len * 2 + 2);
    150 
    151     // The constant buffer pool. This contains the activation scalar, followed by the
    152     // per-operation constant operands.
    153     std::vector<uint8_t> operandValues(sizeof(int32_t) + len * sizeof(CppType));
    154 
    155     // The activation scalar, value = 0.
    156     operands[0] = {
    157             .type = OperandType::INT32,
    158             .dimensions = {},
    159             .numberOfConsumers = len,
    160             .scale = 0.0f,
    161             .zeroPoint = 0,
    162             .lifetime = OperandLifeTime::CONSTANT_COPY,
    163             .location = {.poolIndex = 0, .offset = 0, .length = sizeof(int32_t)},
    164     };
    165     memset(operandValues.data(), 0, sizeof(int32_t));
    166 
    167     // The buffer value of the constant second operand. The logical value is always 1.0f.
    168     CppType bufferValue;
    169     // The scale of the first and second operand.
    170     float scale1, scale2;
    171     if (operandType == OperandType::TENSOR_FLOAT32) {
    172         bufferValue = 1.0f;
    173         scale1 = 0.0f;
    174         scale2 = 0.0f;
    175     } else if (op == OperationType::ADD) {
    176         bufferValue = 1;
    177         scale1 = 1.0f;
    178         scale2 = 1.0f;
    179     } else {
    180         // To satisfy the constraint on quant8 MUL: input0.scale * input1.scale < output.scale,
    181         // set input1 to have scale = 0.5f and bufferValue = 2, i.e. 1.0f in floating point.
    182         bufferValue = 2;
    183         scale1 = 1.0f;
    184         scale2 = 0.5f;
    185     }
    186 
    187     for (uint32_t i = 0; i < len; i++) {
    188         const uint32_t firstInputIndex = i * 2 + 1;
    189         const uint32_t secondInputIndex = firstInputIndex + 1;
    190         const uint32_t outputIndex = secondInputIndex + 1;
    191 
    192         // The first operation input.
    193         operands[firstInputIndex] = {
    194                 .type = operandType,
    195                 .dimensions = {1},
    196                 .numberOfConsumers = 1,
    197                 .scale = scale1,
    198                 .zeroPoint = 0,
    199                 .lifetime = (i == 0 ? OperandLifeTime::MODEL_INPUT
    200                                     : OperandLifeTime::TEMPORARY_VARIABLE),
    201                 .location = {},
    202         };
    203 
    204         // The second operation input, value = 1.
    205         operands[secondInputIndex] = {
    206                 .type = operandType,
    207                 .dimensions = {1},
    208                 .numberOfConsumers = 1,
    209                 .scale = scale2,
    210                 .zeroPoint = 0,
    211                 .lifetime = OperandLifeTime::CONSTANT_COPY,
    212                 .location = {.poolIndex = 0,
    213                              .offset = static_cast<uint32_t>(i * sizeof(CppType) + sizeof(int32_t)),
    214                              .length = sizeof(CppType)},
    215         };
    216         memcpy(operandValues.data() + sizeof(int32_t) + i * sizeof(CppType), &bufferValue,
    217                sizeof(CppType));
    218 
    219         // The operation. All operations share the same activation scalar.
    220         // The output operand is created as an input in the next iteration of the loop, in the case
    221         // of all but the last member of the chain; and after the loop as a model output, in the
    222         // case of the last member of the chain.
    223         operations[i] = {
    224                 .type = op,
    225                 .inputs = {firstInputIndex, secondInputIndex, /*activation scalar*/ 0},
    226                 .outputs = {outputIndex},
    227         };
    228     }
    229 
    230     // The model output.
    231     operands.back() = {
    232             .type = operandType,
    233             .dimensions = {1},
    234             .numberOfConsumers = 0,
    235             .scale = scale1,
    236             .zeroPoint = 0,
    237             .lifetime = OperandLifeTime::MODEL_OUTPUT,
    238             .location = {},
    239     };
    240 
    241     const std::vector<uint32_t> inputIndexes = {1};
    242     const std::vector<uint32_t> outputIndexes = {len * 2 + 1};
    243     const std::vector<hidl_memory> pools = {};
    244 
    245     return {
    246             .operands = operands,
    247             .operations = operations,
    248             .inputIndexes = inputIndexes,
    249             .outputIndexes = outputIndexes,
    250             .operandValues = operandValues,
    251             .pools = pools,
    252     };
    253 }
    254 
    255 }  // namespace
    256 
    257 // Tag for the compilation caching tests.
    258 class CompilationCachingTestBase : public NeuralnetworksHidlTest {
    259   protected:
    260     CompilationCachingTestBase(OperandType type) : kOperandType(type) {}
    261 
    262     void SetUp() override {
    263         NeuralnetworksHidlTest::SetUp();
    264         ASSERT_NE(device.get(), nullptr);
    265 
    266         // Create cache directory. The cache directory and a temporary cache file is always created
    267         // to test the behavior of prepareModelFromCache, even when caching is not supported.
    268         char cacheDirTemp[] = "/data/local/tmp/TestCompilationCachingXXXXXX";
    269         char* cacheDir = mkdtemp(cacheDirTemp);
    270         ASSERT_NE(cacheDir, nullptr);
    271         mCacheDir = cacheDir;
    272         mCacheDir.push_back('/');
    273 
    274         Return<void> ret = device->getNumberOfCacheFilesNeeded(
    275                 [this](ErrorStatus status, uint32_t numModelCache, uint32_t numDataCache) {
    276                     EXPECT_EQ(ErrorStatus::NONE, status);
    277                     mNumModelCache = numModelCache;
    278                     mNumDataCache = numDataCache;
    279                 });
    280         EXPECT_TRUE(ret.isOk());
    281         mIsCachingSupported = mNumModelCache > 0 || mNumDataCache > 0;
    282 
    283         // Create empty cache files.
    284         mTmpCache = mCacheDir + "tmp";
    285         for (uint32_t i = 0; i < mNumModelCache; i++) {
    286             mModelCache.push_back({mCacheDir + "model" + std::to_string(i)});
    287         }
    288         for (uint32_t i = 0; i < mNumDataCache; i++) {
    289             mDataCache.push_back({mCacheDir + "data" + std::to_string(i)});
    290         }
    291         // Dummy handles, use AccessMode::WRITE_ONLY for createCacheHandles to create files.
    292         hidl_vec<hidl_handle> modelHandle, dataHandle, tmpHandle;
    293         createCacheHandles(mModelCache, AccessMode::WRITE_ONLY, &modelHandle);
    294         createCacheHandles(mDataCache, AccessMode::WRITE_ONLY, &dataHandle);
    295         createCacheHandles({{mTmpCache}}, AccessMode::WRITE_ONLY, &tmpHandle);
    296 
    297         if (!mIsCachingSupported) {
    298             LOG(INFO) << "NN VTS: Early termination of test because vendor service does not "
    299                          "support compilation caching.";
    300             std::cout << "[          ]   Early termination of test because vendor service does not "
    301                          "support compilation caching."
    302                       << std::endl;
    303         }
    304     }
    305 
    306     void TearDown() override {
    307         // If the test passes, remove the tmp directory.  Otherwise, keep it for debugging purposes.
    308         if (!::testing::Test::HasFailure()) {
    309             // Recursively remove the cache directory specified by mCacheDir.
    310             auto callback = [](const char* entry, const struct stat*, int, struct FTW*) {
    311                 return remove(entry);
    312             };
    313             nftw(mCacheDir.c_str(), callback, 128, FTW_DEPTH | FTW_MOUNT | FTW_PHYS);
    314         }
    315         NeuralnetworksHidlTest::TearDown();
    316     }
    317 
    318     // Model and examples creators. According to kOperandType, the following methods will return
    319     // either float32 model/examples or the quant8 variant.
    320     Model createTestModel() {
    321         if (kOperandType == OperandType::TENSOR_FLOAT32) {
    322             return float32_model::createTestModel();
    323         } else {
    324             return quant8_model::createTestModel();
    325         }
    326     }
    327 
    328     std::vector<MixedTypedExample> get_examples() {
    329         if (kOperandType == OperandType::TENSOR_FLOAT32) {
    330             return float32_model::get_examples();
    331         } else {
    332             return quant8_model::get_examples();
    333         }
    334     }
    335 
    336     Model createLargeTestModel(OperationType op, uint32_t len) {
    337         if (kOperandType == OperandType::TENSOR_FLOAT32) {
    338             return createLargeTestModelImpl<float, OperandType::TENSOR_FLOAT32>(op, len);
    339         } else {
    340             return createLargeTestModelImpl<uint8_t, OperandType::TENSOR_QUANT8_ASYMM>(op, len);
    341         }
    342     }
    343 
    344     std::vector<MixedTypedExample> getLargeModelExamples(uint32_t len) {
    345         if (kOperandType == OperandType::TENSOR_FLOAT32) {
    346             return float32_model::getLargeModelExamples(len);
    347         } else {
    348             return quant8_model::getLargeModelExamples(len);
    349         }
    350     }
    351 
    352     // See if the service can handle the model.
    353     bool isModelFullySupported(const V1_2::Model& model) {
    354         bool fullySupportsModel = false;
    355         Return<void> supportedCall = device->getSupportedOperations_1_2(
    356                 model,
    357                 [&fullySupportsModel, &model](ErrorStatus status, const hidl_vec<bool>& supported) {
    358                     ASSERT_EQ(ErrorStatus::NONE, status);
    359                     ASSERT_EQ(supported.size(), model.operations.size());
    360                     fullySupportsModel = std::all_of(supported.begin(), supported.end(),
    361                                                      [](bool valid) { return valid; });
    362                 });
    363         EXPECT_TRUE(supportedCall.isOk());
    364         return fullySupportsModel;
    365     }
    366 
    367     void saveModelToCache(const V1_2::Model& model, const hidl_vec<hidl_handle>& modelCache,
    368                           const hidl_vec<hidl_handle>& dataCache,
    369                           sp<IPreparedModel>* preparedModel = nullptr) {
    370         if (preparedModel != nullptr) *preparedModel = nullptr;
    371 
    372         // Launch prepare model.
    373         sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
    374         ASSERT_NE(nullptr, preparedModelCallback.get());
    375         hidl_array<uint8_t, sizeof(mToken)> cacheToken(mToken);
    376         Return<ErrorStatus> prepareLaunchStatus =
    377                 device->prepareModel_1_2(model, ExecutionPreference::FAST_SINGLE_ANSWER, modelCache,
    378                                          dataCache, cacheToken, preparedModelCallback);
    379         ASSERT_TRUE(prepareLaunchStatus.isOk());
    380         ASSERT_EQ(static_cast<ErrorStatus>(prepareLaunchStatus), ErrorStatus::NONE);
    381 
    382         // Retrieve prepared model.
    383         preparedModelCallback->wait();
    384         ASSERT_EQ(preparedModelCallback->getStatus(), ErrorStatus::NONE);
    385         if (preparedModel != nullptr) {
    386             *preparedModel =
    387                     V1_2::IPreparedModel::castFrom(preparedModelCallback->getPreparedModel())
    388                             .withDefault(nullptr);
    389         }
    390     }
    391 
    392     bool checkEarlyTermination(ErrorStatus status) {
    393         if (status == ErrorStatus::GENERAL_FAILURE) {
    394             LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
    395                          "save the prepared model that it does not support.";
    396             std::cout << "[          ]   Early termination of test because vendor service cannot "
    397                          "save the prepared model that it does not support."
    398                       << std::endl;
    399             return true;
    400         }
    401         return false;
    402     }
    403 
    404     bool checkEarlyTermination(const V1_2::Model& model) {
    405         if (!isModelFullySupported(model)) {
    406             LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
    407                          "prepare model that it does not support.";
    408             std::cout << "[          ]   Early termination of test because vendor service cannot "
    409                          "prepare model that it does not support."
    410                       << std::endl;
    411             return true;
    412         }
    413         return false;
    414     }
    415 
    416     void prepareModelFromCache(const hidl_vec<hidl_handle>& modelCache,
    417                                const hidl_vec<hidl_handle>& dataCache,
    418                                sp<IPreparedModel>* preparedModel, ErrorStatus* status) {
    419         // Launch prepare model from cache.
    420         sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
    421         ASSERT_NE(nullptr, preparedModelCallback.get());
    422         hidl_array<uint8_t, sizeof(mToken)> cacheToken(mToken);
    423         Return<ErrorStatus> prepareLaunchStatus = device->prepareModelFromCache(
    424                 modelCache, dataCache, cacheToken, preparedModelCallback);
    425         ASSERT_TRUE(prepareLaunchStatus.isOk());
    426         if (static_cast<ErrorStatus>(prepareLaunchStatus) != ErrorStatus::NONE) {
    427             *preparedModel = nullptr;
    428             *status = static_cast<ErrorStatus>(prepareLaunchStatus);
    429             return;
    430         }
    431 
    432         // Retrieve prepared model.
    433         preparedModelCallback->wait();
    434         *status = preparedModelCallback->getStatus();
    435         *preparedModel = V1_2::IPreparedModel::castFrom(preparedModelCallback->getPreparedModel())
    436                                  .withDefault(nullptr);
    437     }
    438 
    439     // Absolute path to the temporary cache directory.
    440     std::string mCacheDir;
    441 
    442     // Groups of file paths for model and data cache in the tmp cache directory, initialized with
    443     // outer_size = mNum{Model|Data}Cache, inner_size = 1. The outer vector corresponds to handles
    444     // and the inner vector is for fds held by each handle.
    445     std::vector<std::vector<std::string>> mModelCache;
    446     std::vector<std::vector<std::string>> mDataCache;
    447 
    448     // A separate temporary file path in the tmp cache directory.
    449     std::string mTmpCache;
    450 
    451     uint8_t mToken[static_cast<uint32_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN)] = {};
    452     uint32_t mNumModelCache;
    453     uint32_t mNumDataCache;
    454     uint32_t mIsCachingSupported;
    455 
    456     // The primary data type of the testModel.
    457     const OperandType kOperandType;
    458 };
    459 
    460 // A parameterized fixture of CompilationCachingTestBase. Every test will run twice, with the first
    461 // pass running with float32 models and the second pass running with quant8 models.
    462 class CompilationCachingTest : public CompilationCachingTestBase,
    463                                public ::testing::WithParamInterface<OperandType> {
    464   protected:
    465     CompilationCachingTest() : CompilationCachingTestBase(GetParam()) {}
    466 };
    467 
    468 TEST_P(CompilationCachingTest, CacheSavingAndRetrieval) {
    469     // Create test HIDL model and compile.
    470     const Model testModel = createTestModel();
    471     if (checkEarlyTermination(testModel)) return;
    472     sp<IPreparedModel> preparedModel = nullptr;
    473 
    474     // Save the compilation to cache.
    475     {
    476         hidl_vec<hidl_handle> modelCache, dataCache;
    477         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
    478         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
    479         saveModelToCache(testModel, modelCache, dataCache);
    480     }
    481 
    482     // Retrieve preparedModel from cache.
    483     {
    484         preparedModel = nullptr;
    485         ErrorStatus status;
    486         hidl_vec<hidl_handle> modelCache, dataCache;
    487         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
    488         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
    489         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
    490         if (!mIsCachingSupported) {
    491             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
    492             ASSERT_EQ(preparedModel, nullptr);
    493             return;
    494         } else if (checkEarlyTermination(status)) {
    495             ASSERT_EQ(preparedModel, nullptr);
    496             return;
    497         } else {
    498             ASSERT_EQ(status, ErrorStatus::NONE);
    499             ASSERT_NE(preparedModel, nullptr);
    500         }
    501     }
    502 
    503     // Execute and verify results.
    504     generated_tests::EvaluatePreparedModel(preparedModel, [](int) { return false; }, get_examples(),
    505                                            testModel.relaxComputationFloat32toFloat16,
    506                                            /*testDynamicOutputShape=*/false);
    507 }
    508 
    509 TEST_P(CompilationCachingTest, CacheSavingAndRetrievalNonZeroOffset) {
    510     // Create test HIDL model and compile.
    511     const Model testModel = createTestModel();
    512     if (checkEarlyTermination(testModel)) return;
    513     sp<IPreparedModel> preparedModel = nullptr;
    514 
    515     // Save the compilation to cache.
    516     {
    517         hidl_vec<hidl_handle> modelCache, dataCache;
    518         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
    519         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
    520         uint8_t dummyBytes[] = {0, 0};
    521         // Write a dummy integer to the cache.
    522         // The driver should be able to handle non-empty cache and non-zero fd offset.
    523         for (uint32_t i = 0; i < modelCache.size(); i++) {
    524             ASSERT_EQ(write(modelCache[i].getNativeHandle()->data[0], &dummyBytes,
    525                             sizeof(dummyBytes)),
    526                       sizeof(dummyBytes));
    527         }
    528         for (uint32_t i = 0; i < dataCache.size(); i++) {
    529             ASSERT_EQ(
    530                     write(dataCache[i].getNativeHandle()->data[0], &dummyBytes, sizeof(dummyBytes)),
    531                     sizeof(dummyBytes));
    532         }
    533         saveModelToCache(testModel, modelCache, dataCache);
    534     }
    535 
    536     // Retrieve preparedModel from cache.
    537     {
    538         preparedModel = nullptr;
    539         ErrorStatus status;
    540         hidl_vec<hidl_handle> modelCache, dataCache;
    541         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
    542         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
    543         uint8_t dummyByte = 0;
    544         // Advance the offset of each handle by one byte.
    545         // The driver should be able to handle non-zero fd offset.
    546         for (uint32_t i = 0; i < modelCache.size(); i++) {
    547             ASSERT_GE(read(modelCache[i].getNativeHandle()->data[0], &dummyByte, 1), 0);
    548         }
    549         for (uint32_t i = 0; i < dataCache.size(); i++) {
    550             ASSERT_GE(read(dataCache[i].getNativeHandle()->data[0], &dummyByte, 1), 0);
    551         }
    552         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
    553         if (!mIsCachingSupported) {
    554             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
    555             ASSERT_EQ(preparedModel, nullptr);
    556             return;
    557         } else if (checkEarlyTermination(status)) {
    558             ASSERT_EQ(preparedModel, nullptr);
    559             return;
    560         } else {
    561             ASSERT_EQ(status, ErrorStatus::NONE);
    562             ASSERT_NE(preparedModel, nullptr);
    563         }
    564     }
    565 
    566     // Execute and verify results.
    567     generated_tests::EvaluatePreparedModel(preparedModel, [](int) { return false; }, get_examples(),
    568                                            testModel.relaxComputationFloat32toFloat16,
    569                                            /*testDynamicOutputShape=*/false);
    570 }
    571 
    572 TEST_P(CompilationCachingTest, SaveToCacheInvalidNumCache) {
    573     // Create test HIDL model and compile.
    574     const Model testModel = createTestModel();
    575     if (checkEarlyTermination(testModel)) return;
    576 
    577     // Test with number of model cache files greater than mNumModelCache.
    578     {
    579         hidl_vec<hidl_handle> modelCache, dataCache;
    580         // Pass an additional cache file for model cache.
    581         mModelCache.push_back({mTmpCache});
    582         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
    583         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
    584         mModelCache.pop_back();
    585         sp<IPreparedModel> preparedModel = nullptr;
    586         saveModelToCache(testModel, modelCache, dataCache, &preparedModel);
    587         ASSERT_NE(preparedModel, nullptr);
    588         // Execute and verify results.
    589         generated_tests::EvaluatePreparedModel(preparedModel, [](int) { return false; },
    590                                                get_examples(),
    591                                                testModel.relaxComputationFloat32toFloat16,
    592                                                /*testDynamicOutputShape=*/false);
    593         // Check if prepareModelFromCache fails.
    594         preparedModel = nullptr;
    595         ErrorStatus status;
    596         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
    597         if (status != ErrorStatus::INVALID_ARGUMENT) {
    598             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
    599         }
    600         ASSERT_EQ(preparedModel, nullptr);
    601     }
    602 
    603     // Test with number of model cache files smaller than mNumModelCache.
    604     if (mModelCache.size() > 0) {
    605         hidl_vec<hidl_handle> modelCache, dataCache;
    606         // Pop out the last cache file.
    607         auto tmp = mModelCache.back();
    608         mModelCache.pop_back();
    609         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
    610         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
    611         mModelCache.push_back(tmp);
    612         sp<IPreparedModel> preparedModel = nullptr;
    613         saveModelToCache(testModel, modelCache, dataCache, &preparedModel);
    614         ASSERT_NE(preparedModel, nullptr);
    615         // Execute and verify results.
    616         generated_tests::EvaluatePreparedModel(preparedModel, [](int) { return false; },
    617                                                get_examples(),
    618                                                testModel.relaxComputationFloat32toFloat16,
    619                                                /*testDynamicOutputShape=*/false);
    620         // Check if prepareModelFromCache fails.
    621         preparedModel = nullptr;
    622         ErrorStatus status;
    623         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
    624         if (status != ErrorStatus::INVALID_ARGUMENT) {
    625             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
    626         }
    627         ASSERT_EQ(preparedModel, nullptr);
    628     }
    629 
    630     // Test with number of data cache files greater than mNumDataCache.
    631     {
    632         hidl_vec<hidl_handle> modelCache, dataCache;
    633         // Pass an additional cache file for data cache.
    634         mDataCache.push_back({mTmpCache});
    635         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
    636         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
    637         mDataCache.pop_back();
    638         sp<IPreparedModel> preparedModel = nullptr;
    639         saveModelToCache(testModel, modelCache, dataCache, &preparedModel);
    640         ASSERT_NE(preparedModel, nullptr);
    641         // Execute and verify results.
    642         generated_tests::EvaluatePreparedModel(preparedModel, [](int) { return false; },
    643                                                get_examples(),
    644                                                testModel.relaxComputationFloat32toFloat16,
    645                                                /*testDynamicOutputShape=*/false);
    646         // Check if prepareModelFromCache fails.
    647         preparedModel = nullptr;
    648         ErrorStatus status;
    649         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
    650         if (status != ErrorStatus::INVALID_ARGUMENT) {
    651             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
    652         }
    653         ASSERT_EQ(preparedModel, nullptr);
    654     }
    655 
    656     // Test with number of data cache files smaller than mNumDataCache.
    657     if (mDataCache.size() > 0) {
    658         hidl_vec<hidl_handle> modelCache, dataCache;
    659         // Pop out the last cache file.
    660         auto tmp = mDataCache.back();
    661         mDataCache.pop_back();
    662         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
    663         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
    664         mDataCache.push_back(tmp);
    665         sp<IPreparedModel> preparedModel = nullptr;
    666         saveModelToCache(testModel, modelCache, dataCache, &preparedModel);
    667         ASSERT_NE(preparedModel, nullptr);
    668         // Execute and verify results.
    669         generated_tests::EvaluatePreparedModel(preparedModel, [](int) { return false; },
    670                                                get_examples(),
    671                                                testModel.relaxComputationFloat32toFloat16,
    672                                                /*testDynamicOutputShape=*/false);
    673         // Check if prepareModelFromCache fails.
    674         preparedModel = nullptr;
    675         ErrorStatus status;
    676         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
    677         if (status != ErrorStatus::INVALID_ARGUMENT) {
    678             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
    679         }
    680         ASSERT_EQ(preparedModel, nullptr);
    681     }
    682 }
    683 
    684 TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidNumCache) {
    685     // Create test HIDL model and compile.
    686     const Model testModel = createTestModel();
    687     if (checkEarlyTermination(testModel)) return;
    688 
    689     // Save the compilation to cache.
    690     {
    691         hidl_vec<hidl_handle> modelCache, dataCache;
    692         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
    693         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
    694         saveModelToCache(testModel, modelCache, dataCache);
    695     }
    696 
    697     // Test with number of model cache files greater than mNumModelCache.
    698     {
    699         sp<IPreparedModel> preparedModel = nullptr;
    700         ErrorStatus status;
    701         hidl_vec<hidl_handle> modelCache, dataCache;
    702         mModelCache.push_back({mTmpCache});
    703         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
    704         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
    705         mModelCache.pop_back();
    706         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
    707         if (status != ErrorStatus::GENERAL_FAILURE) {
    708             ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
    709         }
    710         ASSERT_EQ(preparedModel, nullptr);
    711     }
    712 
    713     // Test with number of model cache files smaller than mNumModelCache.
    714     if (mModelCache.size() > 0) {
    715         sp<IPreparedModel> preparedModel = nullptr;
    716         ErrorStatus status;
    717         hidl_vec<hidl_handle> modelCache, dataCache;
    718         auto tmp = mModelCache.back();
    719         mModelCache.pop_back();
    720         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
    721         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
    722         mModelCache.push_back(tmp);
    723         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
    724         if (status != ErrorStatus::GENERAL_FAILURE) {
    725             ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
    726         }
    727         ASSERT_EQ(preparedModel, nullptr);
    728     }
    729 
    730     // Test with number of data cache files greater than mNumDataCache.
    731     {
    732         sp<IPreparedModel> preparedModel = nullptr;
    733         ErrorStatus status;
    734         hidl_vec<hidl_handle> modelCache, dataCache;
    735         mDataCache.push_back({mTmpCache});
    736         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
    737         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
    738         mDataCache.pop_back();
    739         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
    740         if (status != ErrorStatus::GENERAL_FAILURE) {
    741             ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
    742         }
    743         ASSERT_EQ(preparedModel, nullptr);
    744     }
    745 
    746     // Test with number of data cache files smaller than mNumDataCache.
    747     if (mDataCache.size() > 0) {
    748         sp<IPreparedModel> preparedModel = nullptr;
    749         ErrorStatus status;
    750         hidl_vec<hidl_handle> modelCache, dataCache;
    751         auto tmp = mDataCache.back();
    752         mDataCache.pop_back();
    753         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
    754         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
    755         mDataCache.push_back(tmp);
    756         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
    757         if (status != ErrorStatus::GENERAL_FAILURE) {
    758             ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
    759         }
    760         ASSERT_EQ(preparedModel, nullptr);
    761     }
    762 }
    763 
    764 TEST_P(CompilationCachingTest, SaveToCacheInvalidNumFd) {
    765     // Create test HIDL model and compile.
    766     const Model testModel = createTestModel();
    767     if (checkEarlyTermination(testModel)) return;
    768 
    769     // Go through each handle in model cache, test with NumFd greater than 1.
    770     for (uint32_t i = 0; i < mNumModelCache; i++) {
    771         hidl_vec<hidl_handle> modelCache, dataCache;
    772         // Pass an invalid number of fds for handle i.
    773         mModelCache[i].push_back(mTmpCache);
    774         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
    775         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
    776         mModelCache[i].pop_back();
    777         sp<IPreparedModel> preparedModel = nullptr;
    778         saveModelToCache(testModel, modelCache, dataCache, &preparedModel);
    779         ASSERT_NE(preparedModel, nullptr);
    780         // Execute and verify results.
    781         generated_tests::EvaluatePreparedModel(preparedModel, [](int) { return false; },
    782                                                get_examples(),
    783                                                testModel.relaxComputationFloat32toFloat16,
    784                                                /*testDynamicOutputShape=*/false);
    785         // Check if prepareModelFromCache fails.
    786         preparedModel = nullptr;
    787         ErrorStatus status;
    788         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
    789         if (status != ErrorStatus::INVALID_ARGUMENT) {
    790             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
    791         }
    792         ASSERT_EQ(preparedModel, nullptr);
    793     }
    794 
    795     // Go through each handle in model cache, test with NumFd equal to 0.
    796     for (uint32_t i = 0; i < mNumModelCache; i++) {
    797         hidl_vec<hidl_handle> modelCache, dataCache;
    798         // Pass an invalid number of fds for handle i.
    799         auto tmp = mModelCache[i].back();
    800         mModelCache[i].pop_back();
    801         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
    802         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
    803         mModelCache[i].push_back(tmp);
    804         sp<IPreparedModel> preparedModel = nullptr;
    805         saveModelToCache(testModel, modelCache, dataCache, &preparedModel);
    806         ASSERT_NE(preparedModel, nullptr);
    807         // Execute and verify results.
    808         generated_tests::EvaluatePreparedModel(preparedModel, [](int) { return false; },
    809                                                get_examples(),
    810                                                testModel.relaxComputationFloat32toFloat16,
    811                                                /*testDynamicOutputShape=*/false);
    812         // Check if prepareModelFromCache fails.
    813         preparedModel = nullptr;
    814         ErrorStatus status;
    815         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
    816         if (status != ErrorStatus::INVALID_ARGUMENT) {
    817             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
    818         }
    819         ASSERT_EQ(preparedModel, nullptr);
    820     }
    821 
    822     // Go through each handle in data cache, test with NumFd greater than 1.
    823     for (uint32_t i = 0; i < mNumDataCache; i++) {
    824         hidl_vec<hidl_handle> modelCache, dataCache;
    825         // Pass an invalid number of fds for handle i.
    826         mDataCache[i].push_back(mTmpCache);
    827         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
    828         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
    829         mDataCache[i].pop_back();
    830         sp<IPreparedModel> preparedModel = nullptr;
    831         saveModelToCache(testModel, modelCache, dataCache, &preparedModel);
    832         ASSERT_NE(preparedModel, nullptr);
    833         // Execute and verify results.
    834         generated_tests::EvaluatePreparedModel(preparedModel, [](int) { return false; },
    835                                                get_examples(),
    836                                                testModel.relaxComputationFloat32toFloat16,
    837                                                /*testDynamicOutputShape=*/false);
    838         // Check if prepareModelFromCache fails.
    839         preparedModel = nullptr;
    840         ErrorStatus status;
    841         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
    842         if (status != ErrorStatus::INVALID_ARGUMENT) {
    843             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
    844         }
    845         ASSERT_EQ(preparedModel, nullptr);
    846     }
    847 
    848     // Go through each handle in data cache, test with NumFd equal to 0.
    849     for (uint32_t i = 0; i < mNumDataCache; i++) {
    850         hidl_vec<hidl_handle> modelCache, dataCache;
    851         // Pass an invalid number of fds for handle i.
    852         auto tmp = mDataCache[i].back();
    853         mDataCache[i].pop_back();
    854         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
    855         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
    856         mDataCache[i].push_back(tmp);
    857         sp<IPreparedModel> preparedModel = nullptr;
    858         saveModelToCache(testModel, modelCache, dataCache, &preparedModel);
    859         ASSERT_NE(preparedModel, nullptr);
    860         // Execute and verify results.
    861         generated_tests::EvaluatePreparedModel(preparedModel, [](int) { return false; },
    862                                                get_examples(),
    863                                                testModel.relaxComputationFloat32toFloat16,
    864                                                /*testDynamicOutputShape=*/false);
    865         // Check if prepareModelFromCache fails.
    866         preparedModel = nullptr;
    867         ErrorStatus status;
    868         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
    869         if (status != ErrorStatus::INVALID_ARGUMENT) {
    870             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
    871         }
    872         ASSERT_EQ(preparedModel, nullptr);
    873     }
    874 }
    875 
    876 TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidNumFd) {
    877     // Create test HIDL model and compile.
    878     const Model testModel = createTestModel();
    879     if (checkEarlyTermination(testModel)) return;
    880 
    881     // Save the compilation to cache.
    882     {
    883         hidl_vec<hidl_handle> modelCache, dataCache;
    884         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
    885         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
    886         saveModelToCache(testModel, modelCache, dataCache);
    887     }
    888 
    889     // Go through each handle in model cache, test with NumFd greater than 1.
    890     for (uint32_t i = 0; i < mNumModelCache; i++) {
    891         sp<IPreparedModel> preparedModel = nullptr;
    892         ErrorStatus status;
    893         hidl_vec<hidl_handle> modelCache, dataCache;
    894         mModelCache[i].push_back(mTmpCache);
    895         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
    896         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
    897         mModelCache[i].pop_back();
    898         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
    899         if (status != ErrorStatus::GENERAL_FAILURE) {
    900             ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
    901         }
    902         ASSERT_EQ(preparedModel, nullptr);
    903     }
    904 
    905     // Go through each handle in model cache, test with NumFd equal to 0.
    906     for (uint32_t i = 0; i < mNumModelCache; i++) {
    907         sp<IPreparedModel> preparedModel = nullptr;
    908         ErrorStatus status;
    909         hidl_vec<hidl_handle> modelCache, dataCache;
    910         auto tmp = mModelCache[i].back();
    911         mModelCache[i].pop_back();
    912         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
    913         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
    914         mModelCache[i].push_back(tmp);
    915         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
    916         if (status != ErrorStatus::GENERAL_FAILURE) {
    917             ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
    918         }
    919         ASSERT_EQ(preparedModel, nullptr);
    920     }
    921 
    922     // Go through each handle in data cache, test with NumFd greater than 1.
    923     for (uint32_t i = 0; i < mNumDataCache; i++) {
    924         sp<IPreparedModel> preparedModel = nullptr;
    925         ErrorStatus status;
    926         hidl_vec<hidl_handle> modelCache, dataCache;
    927         mDataCache[i].push_back(mTmpCache);
    928         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
    929         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
    930         mDataCache[i].pop_back();
    931         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
    932         if (status != ErrorStatus::GENERAL_FAILURE) {
    933             ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
    934         }
    935         ASSERT_EQ(preparedModel, nullptr);
    936     }
    937 
    938     // Go through each handle in data cache, test with NumFd equal to 0.
    939     for (uint32_t i = 0; i < mNumDataCache; i++) {
    940         sp<IPreparedModel> preparedModel = nullptr;
    941         ErrorStatus status;
    942         hidl_vec<hidl_handle> modelCache, dataCache;
    943         auto tmp = mDataCache[i].back();
    944         mDataCache[i].pop_back();
    945         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
    946         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
    947         mDataCache[i].push_back(tmp);
    948         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
    949         if (status != ErrorStatus::GENERAL_FAILURE) {
    950             ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
    951         }
    952         ASSERT_EQ(preparedModel, nullptr);
    953     }
    954 }
    955 
    956 TEST_P(CompilationCachingTest, SaveToCacheInvalidAccessMode) {
    957     // Create test HIDL model and compile.
    958     const Model testModel = createTestModel();
    959     if (checkEarlyTermination(testModel)) return;
    960     std::vector<AccessMode> modelCacheMode(mNumModelCache, AccessMode::READ_WRITE);
    961     std::vector<AccessMode> dataCacheMode(mNumDataCache, AccessMode::READ_WRITE);
    962 
    963     // Go through each handle in model cache, test with invalid access mode.
    964     for (uint32_t i = 0; i < mNumModelCache; i++) {
    965         hidl_vec<hidl_handle> modelCache, dataCache;
    966         modelCacheMode[i] = AccessMode::READ_ONLY;
    967         createCacheHandles(mModelCache, modelCacheMode, &modelCache);
    968         createCacheHandles(mDataCache, dataCacheMode, &dataCache);
    969         modelCacheMode[i] = AccessMode::READ_WRITE;
    970         sp<IPreparedModel> preparedModel = nullptr;
    971         saveModelToCache(testModel, modelCache, dataCache, &preparedModel);
    972         ASSERT_NE(preparedModel, nullptr);
    973         // Execute and verify results.
    974         generated_tests::EvaluatePreparedModel(preparedModel, [](int) { return false; },
    975                                                get_examples(),
    976                                                testModel.relaxComputationFloat32toFloat16,
    977                                                /*testDynamicOutputShape=*/false);
    978         // Check if prepareModelFromCache fails.
    979         preparedModel = nullptr;
    980         ErrorStatus status;
    981         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
    982         if (status != ErrorStatus::INVALID_ARGUMENT) {
    983             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
    984         }
    985         ASSERT_EQ(preparedModel, nullptr);
    986     }
    987 
    988     // Go through each handle in data cache, test with invalid access mode.
    989     for (uint32_t i = 0; i < mNumDataCache; i++) {
    990         hidl_vec<hidl_handle> modelCache, dataCache;
    991         dataCacheMode[i] = AccessMode::READ_ONLY;
    992         createCacheHandles(mModelCache, modelCacheMode, &modelCache);
    993         createCacheHandles(mDataCache, dataCacheMode, &dataCache);
    994         dataCacheMode[i] = AccessMode::READ_WRITE;
    995         sp<IPreparedModel> preparedModel = nullptr;
    996         saveModelToCache(testModel, modelCache, dataCache, &preparedModel);
    997         ASSERT_NE(preparedModel, nullptr);
    998         // Execute and verify results.
    999         generated_tests::EvaluatePreparedModel(preparedModel, [](int) { return false; },
   1000                                                get_examples(),
   1001                                                testModel.relaxComputationFloat32toFloat16,
   1002                                                /*testDynamicOutputShape=*/false);
   1003         // Check if prepareModelFromCache fails.
   1004         preparedModel = nullptr;
   1005         ErrorStatus status;
   1006         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
   1007         if (status != ErrorStatus::INVALID_ARGUMENT) {
   1008             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
   1009         }
   1010         ASSERT_EQ(preparedModel, nullptr);
   1011     }
   1012 }
   1013 
   1014 TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidAccessMode) {
   1015     // Create test HIDL model and compile.
   1016     const Model testModel = createTestModel();
   1017     if (checkEarlyTermination(testModel)) return;
   1018     std::vector<AccessMode> modelCacheMode(mNumModelCache, AccessMode::READ_WRITE);
   1019     std::vector<AccessMode> dataCacheMode(mNumDataCache, AccessMode::READ_WRITE);
   1020 
   1021     // Save the compilation to cache.
   1022     {
   1023         hidl_vec<hidl_handle> modelCache, dataCache;
   1024         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
   1025         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
   1026         saveModelToCache(testModel, modelCache, dataCache);
   1027     }
   1028 
   1029     // Go through each handle in model cache, test with invalid access mode.
   1030     for (uint32_t i = 0; i < mNumModelCache; i++) {
   1031         sp<IPreparedModel> preparedModel = nullptr;
   1032         ErrorStatus status;
   1033         hidl_vec<hidl_handle> modelCache, dataCache;
   1034         modelCacheMode[i] = AccessMode::WRITE_ONLY;
   1035         createCacheHandles(mModelCache, modelCacheMode, &modelCache);
   1036         createCacheHandles(mDataCache, dataCacheMode, &dataCache);
   1037         modelCacheMode[i] = AccessMode::READ_WRITE;
   1038         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
   1039         ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
   1040         ASSERT_EQ(preparedModel, nullptr);
   1041     }
   1042 
   1043     // Go through each handle in data cache, test with invalid access mode.
   1044     for (uint32_t i = 0; i < mNumDataCache; i++) {
   1045         sp<IPreparedModel> preparedModel = nullptr;
   1046         ErrorStatus status;
   1047         hidl_vec<hidl_handle> modelCache, dataCache;
   1048         dataCacheMode[i] = AccessMode::WRITE_ONLY;
   1049         createCacheHandles(mModelCache, modelCacheMode, &modelCache);
   1050         createCacheHandles(mDataCache, dataCacheMode, &dataCache);
   1051         dataCacheMode[i] = AccessMode::READ_WRITE;
   1052         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
   1053         ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
   1054         ASSERT_EQ(preparedModel, nullptr);
   1055     }
   1056 }
   1057 
   1058 // Copy file contents between file groups.
   1059 // The outer vector corresponds to handles and the inner vector is for fds held by each handle.
   1060 // The outer vector sizes must match and the inner vectors must have size = 1.
   1061 static void copyCacheFiles(const std::vector<std::vector<std::string>>& from,
   1062                            const std::vector<std::vector<std::string>>& to) {
   1063     constexpr size_t kBufferSize = 1000000;
   1064     uint8_t buffer[kBufferSize];
   1065 
   1066     ASSERT_EQ(from.size(), to.size());
   1067     for (uint32_t i = 0; i < from.size(); i++) {
   1068         ASSERT_EQ(from[i].size(), 1u);
   1069         ASSERT_EQ(to[i].size(), 1u);
   1070         int fromFd = open(from[i][0].c_str(), O_RDONLY);
   1071         int toFd = open(to[i][0].c_str(), O_WRONLY | O_CREAT, S_IRUSR | S_IWUSR);
   1072         ASSERT_GE(fromFd, 0);
   1073         ASSERT_GE(toFd, 0);
   1074 
   1075         ssize_t readBytes;
   1076         while ((readBytes = read(fromFd, &buffer, kBufferSize)) > 0) {
   1077             ASSERT_EQ(write(toFd, &buffer, readBytes), readBytes);
   1078         }
   1079         ASSERT_GE(readBytes, 0);
   1080 
   1081         close(fromFd);
   1082         close(toFd);
   1083     }
   1084 }
   1085 
   1086 // Number of operations in the large test model.
   1087 constexpr uint32_t kLargeModelSize = 100;
   1088 constexpr uint32_t kNumIterationsTOCTOU = 100;
   1089 
   1090 TEST_P(CompilationCachingTest, SaveToCache_TOCTOU) {
   1091     if (!mIsCachingSupported) return;
   1092 
   1093     // Create test models and check if fully supported by the service.
   1094     const Model testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
   1095     if (checkEarlyTermination(testModelMul)) return;
   1096     const Model testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
   1097     if (checkEarlyTermination(testModelAdd)) return;
   1098 
   1099     // Save the testModelMul compilation to cache.
   1100     auto modelCacheMul = mModelCache;
   1101     for (auto& cache : modelCacheMul) {
   1102         cache[0].append("_mul");
   1103     }
   1104     {
   1105         hidl_vec<hidl_handle> modelCache, dataCache;
   1106         createCacheHandles(modelCacheMul, AccessMode::READ_WRITE, &modelCache);
   1107         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
   1108         saveModelToCache(testModelMul, modelCache, dataCache);
   1109     }
   1110 
   1111     // Use a different token for testModelAdd.
   1112     mToken[0]++;
   1113 
   1114     // This test is probabilistic, so we run it multiple times.
   1115     for (uint32_t i = 0; i < kNumIterationsTOCTOU; i++) {
   1116         // Save the testModelAdd compilation to cache.
   1117         {
   1118             hidl_vec<hidl_handle> modelCache, dataCache;
   1119             createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
   1120             createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
   1121 
   1122             // Spawn a thread to copy the cache content concurrently while saving to cache.
   1123             std::thread thread(copyCacheFiles, std::cref(modelCacheMul), std::cref(mModelCache));
   1124             saveModelToCache(testModelAdd, modelCache, dataCache);
   1125             thread.join();
   1126         }
   1127 
   1128         // Retrieve preparedModel from cache.
   1129         {
   1130             sp<IPreparedModel> preparedModel = nullptr;
   1131             ErrorStatus status;
   1132             hidl_vec<hidl_handle> modelCache, dataCache;
   1133             createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
   1134             createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
   1135             prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
   1136 
   1137             // The preparation may fail or succeed, but must not crash. If the preparation succeeds,
   1138             // the prepared model must be executed with the correct result and not crash.
   1139             if (status != ErrorStatus::NONE) {
   1140                 ASSERT_EQ(preparedModel, nullptr);
   1141             } else {
   1142                 ASSERT_NE(preparedModel, nullptr);
   1143                 generated_tests::EvaluatePreparedModel(
   1144                         preparedModel, [](int) { return false; },
   1145                         getLargeModelExamples(kLargeModelSize),
   1146                         testModelAdd.relaxComputationFloat32toFloat16,
   1147                         /*testDynamicOutputShape=*/false);
   1148             }
   1149         }
   1150     }
   1151 }
   1152 
   1153 TEST_P(CompilationCachingTest, PrepareFromCache_TOCTOU) {
   1154     if (!mIsCachingSupported) return;
   1155 
   1156     // Create test models and check if fully supported by the service.
   1157     const Model testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
   1158     if (checkEarlyTermination(testModelMul)) return;
   1159     const Model testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
   1160     if (checkEarlyTermination(testModelAdd)) return;
   1161 
   1162     // Save the testModelMul compilation to cache.
   1163     auto modelCacheMul = mModelCache;
   1164     for (auto& cache : modelCacheMul) {
   1165         cache[0].append("_mul");
   1166     }
   1167     {
   1168         hidl_vec<hidl_handle> modelCache, dataCache;
   1169         createCacheHandles(modelCacheMul, AccessMode::READ_WRITE, &modelCache);
   1170         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
   1171         saveModelToCache(testModelMul, modelCache, dataCache);
   1172     }
   1173 
   1174     // Use a different token for testModelAdd.
   1175     mToken[0]++;
   1176 
   1177     // This test is probabilistic, so we run it multiple times.
   1178     for (uint32_t i = 0; i < kNumIterationsTOCTOU; i++) {
   1179         // Save the testModelAdd compilation to cache.
   1180         {
   1181             hidl_vec<hidl_handle> modelCache, dataCache;
   1182             createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
   1183             createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
   1184             saveModelToCache(testModelAdd, modelCache, dataCache);
   1185         }
   1186 
   1187         // Retrieve preparedModel from cache.
   1188         {
   1189             sp<IPreparedModel> preparedModel = nullptr;
   1190             ErrorStatus status;
   1191             hidl_vec<hidl_handle> modelCache, dataCache;
   1192             createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
   1193             createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
   1194 
   1195             // Spawn a thread to copy the cache content concurrently while preparing from cache.
   1196             std::thread thread(copyCacheFiles, std::cref(modelCacheMul), std::cref(mModelCache));
   1197             prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
   1198             thread.join();
   1199 
   1200             // The preparation may fail or succeed, but must not crash. If the preparation succeeds,
   1201             // the prepared model must be executed with the correct result and not crash.
   1202             if (status != ErrorStatus::NONE) {
   1203                 ASSERT_EQ(preparedModel, nullptr);
   1204             } else {
   1205                 ASSERT_NE(preparedModel, nullptr);
   1206                 generated_tests::EvaluatePreparedModel(
   1207                         preparedModel, [](int) { return false; },
   1208                         getLargeModelExamples(kLargeModelSize),
   1209                         testModelAdd.relaxComputationFloat32toFloat16,
   1210                         /*testDynamicOutputShape=*/false);
   1211             }
   1212         }
   1213     }
   1214 }
   1215 
   1216 TEST_P(CompilationCachingTest, ReplaceSecuritySensitiveCache) {
   1217     if (!mIsCachingSupported) return;
   1218 
   1219     // Create test models and check if fully supported by the service.
   1220     const Model testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
   1221     if (checkEarlyTermination(testModelMul)) return;
   1222     const Model testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
   1223     if (checkEarlyTermination(testModelAdd)) return;
   1224 
   1225     // Save the testModelMul compilation to cache.
   1226     auto modelCacheMul = mModelCache;
   1227     for (auto& cache : modelCacheMul) {
   1228         cache[0].append("_mul");
   1229     }
   1230     {
   1231         hidl_vec<hidl_handle> modelCache, dataCache;
   1232         createCacheHandles(modelCacheMul, AccessMode::READ_WRITE, &modelCache);
   1233         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
   1234         saveModelToCache(testModelMul, modelCache, dataCache);
   1235     }
   1236 
   1237     // Use a different token for testModelAdd.
   1238     mToken[0]++;
   1239 
   1240     // Save the testModelAdd compilation to cache.
   1241     {
   1242         hidl_vec<hidl_handle> modelCache, dataCache;
   1243         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
   1244         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
   1245         saveModelToCache(testModelAdd, modelCache, dataCache);
   1246     }
   1247 
   1248     // Replace the model cache of testModelAdd with testModelMul.
   1249     copyCacheFiles(modelCacheMul, mModelCache);
   1250 
   1251     // Retrieve the preparedModel from cache, expect failure.
   1252     {
   1253         sp<IPreparedModel> preparedModel = nullptr;
   1254         ErrorStatus status;
   1255         hidl_vec<hidl_handle> modelCache, dataCache;
   1256         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
   1257         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
   1258         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
   1259         ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
   1260         ASSERT_EQ(preparedModel, nullptr);
   1261     }
   1262 }
   1263 
// Operand types that the parameterized compilation-caching test suites are instantiated with.
static const auto kOperandTypeChoices =
        ::testing::Values(OperandType::TENSOR_FLOAT32, OperandType::TENSOR_QUANT8_ASYMM);

// Run every CompilationCachingTest once per operand type in kOperandTypeChoices.
INSTANTIATE_TEST_CASE_P(TestCompilationCaching, CompilationCachingTest, kOperandTypeChoices);
   1268 
   1269 class CompilationCachingSecurityTest
   1270     : public CompilationCachingTestBase,
   1271       public ::testing::WithParamInterface<std::tuple<OperandType, uint32_t>> {
   1272   protected:
   1273     CompilationCachingSecurityTest() : CompilationCachingTestBase(std::get<0>(GetParam())) {}
   1274 
   1275     void SetUp() {
   1276         CompilationCachingTestBase::SetUp();
   1277         generator.seed(kSeed);
   1278     }
   1279 
   1280     // Get a random integer within a closed range [lower, upper].
   1281     template <typename T>
   1282     T getRandomInt(T lower, T upper) {
   1283         std::uniform_int_distribution<T> dis(lower, upper);
   1284         return dis(generator);
   1285     }
   1286 
   1287     // Randomly flip one single bit of the cache entry.
   1288     void flipOneBitOfCache(const std::string& filename, bool* skip) {
   1289         FILE* pFile = fopen(filename.c_str(), "r+");
   1290         ASSERT_EQ(fseek(pFile, 0, SEEK_END), 0);
   1291         long int fileSize = ftell(pFile);
   1292         if (fileSize == 0) {
   1293             fclose(pFile);
   1294             *skip = true;
   1295             return;
   1296         }
   1297         ASSERT_EQ(fseek(pFile, getRandomInt(0l, fileSize - 1), SEEK_SET), 0);
   1298         int readByte = fgetc(pFile);
   1299         ASSERT_NE(readByte, EOF);
   1300         ASSERT_EQ(fseek(pFile, -1, SEEK_CUR), 0);
   1301         ASSERT_NE(fputc(static_cast<uint8_t>(readByte) ^ (1U << getRandomInt(0, 7)), pFile), EOF);
   1302         fclose(pFile);
   1303         *skip = false;
   1304     }
   1305 
   1306     // Randomly append bytes to the cache entry.
   1307     void appendBytesToCache(const std::string& filename, bool* skip) {
   1308         FILE* pFile = fopen(filename.c_str(), "a");
   1309         uint32_t appendLength = getRandomInt(1, 256);
   1310         for (uint32_t i = 0; i < appendLength; i++) {
   1311             ASSERT_NE(fputc(getRandomInt<uint8_t>(0, 255), pFile), EOF);
   1312         }
   1313         fclose(pFile);
   1314         *skip = false;
   1315     }
   1316 
   1317     enum class ExpectedResult { GENERAL_FAILURE, NOT_CRASH };
   1318 
   1319     // Test if the driver behaves as expected when given corrupted cache or token.
   1320     // The modifier will be invoked after save to cache but before prepare from cache.
   1321     // The modifier accepts one pointer argument "skip" as the returning value, indicating
   1322     // whether the test should be skipped or not.
   1323     void testCorruptedCache(ExpectedResult expected, std::function<void(bool*)> modifier) {
   1324         const Model testModel = createTestModel();
   1325         if (checkEarlyTermination(testModel)) return;
   1326 
   1327         // Save the compilation to cache.
   1328         {
   1329             hidl_vec<hidl_handle> modelCache, dataCache;
   1330             createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
   1331             createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
   1332             saveModelToCache(testModel, modelCache, dataCache);
   1333         }
   1334 
   1335         bool skip = false;
   1336         modifier(&skip);
   1337         if (skip) return;
   1338 
   1339         // Retrieve preparedModel from cache.
   1340         {
   1341             sp<IPreparedModel> preparedModel = nullptr;
   1342             ErrorStatus status;
   1343             hidl_vec<hidl_handle> modelCache, dataCache;
   1344             createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
   1345             createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
   1346             prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
   1347 
   1348             switch (expected) {
   1349                 case ExpectedResult::GENERAL_FAILURE:
   1350                     ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
   1351                     ASSERT_EQ(preparedModel, nullptr);
   1352                     break;
   1353                 case ExpectedResult::NOT_CRASH:
   1354                     ASSERT_EQ(preparedModel == nullptr, status != ErrorStatus::NONE);
   1355                     break;
   1356                 default:
   1357                     FAIL();
   1358             }
   1359         }
   1360     }
   1361 
   1362     const uint32_t kSeed = std::get<1>(GetParam());
   1363     std::mt19937 generator;
   1364 };
   1365 
   1366 TEST_P(CompilationCachingSecurityTest, CorruptedModelCache) {
   1367     if (!mIsCachingSupported) return;
   1368     for (uint32_t i = 0; i < mNumModelCache; i++) {
   1369         testCorruptedCache(ExpectedResult::GENERAL_FAILURE,
   1370                            [this, i](bool* skip) { flipOneBitOfCache(mModelCache[i][0], skip); });
   1371     }
   1372 }
   1373 
   1374 TEST_P(CompilationCachingSecurityTest, WrongLengthModelCache) {
   1375     if (!mIsCachingSupported) return;
   1376     for (uint32_t i = 0; i < mNumModelCache; i++) {
   1377         testCorruptedCache(ExpectedResult::GENERAL_FAILURE,
   1378                            [this, i](bool* skip) { appendBytesToCache(mModelCache[i][0], skip); });
   1379     }
   1380 }
   1381 
   1382 TEST_P(CompilationCachingSecurityTest, CorruptedDataCache) {
   1383     if (!mIsCachingSupported) return;
   1384     for (uint32_t i = 0; i < mNumDataCache; i++) {
   1385         testCorruptedCache(ExpectedResult::NOT_CRASH,
   1386                            [this, i](bool* skip) { flipOneBitOfCache(mDataCache[i][0], skip); });
   1387     }
   1388 }
   1389 
   1390 TEST_P(CompilationCachingSecurityTest, WrongLengthDataCache) {
   1391     if (!mIsCachingSupported) return;
   1392     for (uint32_t i = 0; i < mNumDataCache; i++) {
   1393         testCorruptedCache(ExpectedResult::NOT_CRASH,
   1394                            [this, i](bool* skip) { appendBytesToCache(mDataCache[i][0], skip); });
   1395     }
   1396 }
   1397 
   1398 TEST_P(CompilationCachingSecurityTest, WrongToken) {
   1399     if (!mIsCachingSupported) return;
   1400     testCorruptedCache(ExpectedResult::GENERAL_FAILURE, [this](bool* skip) {
   1401         // Randomly flip one single bit in mToken.
   1402         uint32_t ind =
   1403                 getRandomInt(0u, static_cast<uint32_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN) - 1);
   1404         mToken[ind] ^= (1U << getRandomInt(0, 7));
   1405         *skip = false;
   1406     });
   1407 }
   1408 
// Instantiate the security tests for every (operand type, seed) pair: each operand type from
// kOperandTypeChoices combined with seeds 0 through 9 (::testing::Range excludes its end value).
INSTANTIATE_TEST_CASE_P(TestCompilationCaching, CompilationCachingSecurityTest,
                        ::testing::Combine(kOperandTypeChoices, ::testing::Range(0U, 10U)));
   1411 
   1412 }  // namespace functional
   1413 }  // namespace vts
   1414 }  // namespace V1_2
   1415 }  // namespace neuralnetworks
   1416 }  // namespace hardware
   1417 }  // namespace android
   1418