Home | History | Annotate | Download | only in functional
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #define LOG_TAG "neuralnetworks_hidl_hal_test"
     18 
     19 #include "VtsHalNeuralnetworks.h"
     20 
     21 #include "Callbacks.h"
     22 #include "ExecutionBurstController.h"
     23 #include "TestHarness.h"
     24 #include "Utils.h"
     25 
     26 #include <android-base/logging.h>
     27 #include <android/hidl/memory/1.0/IMemory.h>
     28 #include <hidlmemory/mapping.h>
     29 
     30 namespace android {
     31 namespace hardware {
     32 namespace neuralnetworks {
     33 namespace V1_2 {
     34 namespace vts {
     35 namespace functional {
     36 
     37 using ::android::hardware::neuralnetworks::V1_2::implementation::ExecutionCallback;
     38 using ::android::hidl::memory::V1_0::IMemory;
     39 using test_helper::for_all;
     40 using test_helper::MixedTyped;
     41 using test_helper::MixedTypedExample;
     42 
     43 ///////////////////////// UTILITY FUNCTIONS /////////////////////////
     44 
     45 static bool badTiming(Timing timing) {
     46     return timing.timeOnDevice == UINT64_MAX && timing.timeInDriver == UINT64_MAX;
     47 }
     48 
     49 // Primary validation function. This function will take a valid request, apply a
     50 // mutation to it to invalidate the request, then pass it to interface calls
     51 // that use the request. Note that the request here is passed by value, and any
     52 // mutation to the request does not leave this function.
     53 static void validate(const sp<IPreparedModel>& preparedModel, const std::string& message,
     54                      Request request, const std::function<void(Request*)>& mutation) {
     55     mutation(&request);
     56 
     57     // We'd like to test both with timing requested and without timing
     58     // requested. Rather than running each test both ways, we'll decide whether
     59     // to request timing by hashing the message. We do not use std::hash because
     60     // it is not guaranteed stable across executions.
     61     char hash = 0;
     62     for (auto c : message) {
     63         hash ^= c;
     64     };
     65     MeasureTiming measure = (hash & 1) ? MeasureTiming::YES : MeasureTiming::NO;
     66 
     67     // asynchronous
     68     {
     69         SCOPED_TRACE(message + " [execute_1_2]");
     70 
     71         sp<ExecutionCallback> executionCallback = new ExecutionCallback();
     72         ASSERT_NE(nullptr, executionCallback.get());
     73         Return<ErrorStatus> executeLaunchStatus =
     74                 preparedModel->execute_1_2(request, measure, executionCallback);
     75         ASSERT_TRUE(executeLaunchStatus.isOk());
     76         ASSERT_EQ(ErrorStatus::INVALID_ARGUMENT, static_cast<ErrorStatus>(executeLaunchStatus));
     77 
     78         executionCallback->wait();
     79         ErrorStatus executionReturnStatus = executionCallback->getStatus();
     80         const auto& outputShapes = executionCallback->getOutputShapes();
     81         Timing timing = executionCallback->getTiming();
     82         ASSERT_EQ(ErrorStatus::INVALID_ARGUMENT, executionReturnStatus);
     83         ASSERT_EQ(outputShapes.size(), 0);
     84         ASSERT_TRUE(badTiming(timing));
     85     }
     86 
     87     // synchronous
     88     {
     89         SCOPED_TRACE(message + " [executeSynchronously]");
     90 
     91         Return<void> executeStatus = preparedModel->executeSynchronously(
     92                 request, measure,
     93                 [](ErrorStatus error, const hidl_vec<OutputShape>& outputShapes,
     94                    const Timing& timing) {
     95                     ASSERT_EQ(ErrorStatus::INVALID_ARGUMENT, error);
     96                     EXPECT_EQ(outputShapes.size(), 0);
     97                     EXPECT_TRUE(badTiming(timing));
     98                 });
     99         ASSERT_TRUE(executeStatus.isOk());
    100     }
    101 
    102     // burst
    103     {
    104         SCOPED_TRACE(message + " [burst]");
    105 
    106         // create burst
    107         std::shared_ptr<::android::nn::ExecutionBurstController> burst =
    108                 ::android::nn::ExecutionBurstController::create(preparedModel, /*blocking=*/true);
    109         ASSERT_NE(nullptr, burst.get());
    110 
    111         // create memory keys
    112         std::vector<intptr_t> keys(request.pools.size());
    113         for (size_t i = 0; i < keys.size(); ++i) {
    114             keys[i] = reinterpret_cast<intptr_t>(&request.pools[i]);
    115         }
    116 
    117         // execute and verify
    118         ErrorStatus error;
    119         std::vector<OutputShape> outputShapes;
    120         Timing timing;
    121         std::tie(error, outputShapes, timing) = burst->compute(request, measure, keys);
    122         EXPECT_EQ(ErrorStatus::INVALID_ARGUMENT, error);
    123         EXPECT_EQ(outputShapes.size(), 0);
    124         EXPECT_TRUE(badTiming(timing));
    125 
    126         // additional burst testing
    127         if (request.pools.size() > 0) {
    128             // valid free
    129             burst->freeMemory(keys.front());
    130 
    131             // negative test: invalid free of unknown (blank) memory
    132             burst->freeMemory(intptr_t{});
    133 
    134             // negative test: double free of memory
    135             burst->freeMemory(keys.front());
    136         }
    137     }
    138 }
    139 
    140 // Delete element from hidl_vec. hidl_vec doesn't support a "remove" operation,
    141 // so this is efficiently accomplished by moving the element to the end and
    142 // resizing the hidl_vec to one less.
    143 template <typename Type>
    144 static void hidl_vec_removeAt(hidl_vec<Type>* vec, uint32_t index) {
    145     if (vec) {
    146         std::rotate(vec->begin() + index, vec->begin() + index + 1, vec->end());
    147         vec->resize(vec->size() - 1);
    148     }
    149 }
    150 
    151 template <typename Type>
    152 static uint32_t hidl_vec_push_back(hidl_vec<Type>* vec, const Type& value) {
    153     // assume vec is valid
    154     const uint32_t index = vec->size();
    155     vec->resize(index + 1);
    156     (*vec)[index] = value;
    157     return index;
    158 }
    159 
    160 ///////////////////////// REMOVE INPUT ////////////////////////////////////
    161 
    162 static void removeInputTest(const sp<IPreparedModel>& preparedModel, const Request& request) {
    163     for (size_t input = 0; input < request.inputs.size(); ++input) {
    164         const std::string message = "removeInput: removed input " + std::to_string(input);
    165         validate(preparedModel, message, request,
    166                  [input](Request* request) { hidl_vec_removeAt(&request->inputs, input); });
    167     }
    168 }
    169 
    170 ///////////////////////// REMOVE OUTPUT ////////////////////////////////////
    171 
    172 static void removeOutputTest(const sp<IPreparedModel>& preparedModel, const Request& request) {
    173     for (size_t output = 0; output < request.outputs.size(); ++output) {
    174         const std::string message = "removeOutput: removed Output " + std::to_string(output);
    175         validate(preparedModel, message, request,
    176                  [output](Request* request) { hidl_vec_removeAt(&request->outputs, output); });
    177     }
    178 }
    179 
    180 ///////////////////////////// ENTRY POINT //////////////////////////////////
    181 
    182 std::vector<Request> createRequests(const std::vector<MixedTypedExample>& examples) {
    183     const uint32_t INPUT = 0;
    184     const uint32_t OUTPUT = 1;
    185 
    186     std::vector<Request> requests;
    187 
    188     for (auto& example : examples) {
    189         const MixedTyped& inputs = example.operands.first;
    190         const MixedTyped& outputs = example.operands.second;
    191 
    192         std::vector<RequestArgument> inputs_info, outputs_info;
    193         uint32_t inputSize = 0, outputSize = 0;
    194 
    195         // This function only partially specifies the metadata (vector of RequestArguments).
    196         // The contents are copied over below.
    197         for_all(inputs, [&inputs_info, &inputSize](int index, auto, auto s) {
    198             if (inputs_info.size() <= static_cast<size_t>(index)) inputs_info.resize(index + 1);
    199             RequestArgument arg = {
    200                 .location = {.poolIndex = INPUT, .offset = 0, .length = static_cast<uint32_t>(s)},
    201                 .dimensions = {},
    202             };
    203             RequestArgument arg_empty = {
    204                 .hasNoValue = true,
    205             };
    206             inputs_info[index] = s ? arg : arg_empty;
    207             inputSize += s;
    208         });
    209         // Compute offset for inputs 1 and so on
    210         {
    211             size_t offset = 0;
    212             for (auto& i : inputs_info) {
    213                 if (!i.hasNoValue) i.location.offset = offset;
    214                 offset += i.location.length;
    215             }
    216         }
    217 
    218         // Go through all outputs, initialize RequestArgument descriptors
    219         for_all(outputs, [&outputs_info, &outputSize](int index, auto, auto s) {
    220             if (outputs_info.size() <= static_cast<size_t>(index)) outputs_info.resize(index + 1);
    221             RequestArgument arg = {
    222                 .location = {.poolIndex = OUTPUT, .offset = 0, .length = static_cast<uint32_t>(s)},
    223                 .dimensions = {},
    224             };
    225             outputs_info[index] = arg;
    226             outputSize += s;
    227         });
    228         // Compute offset for outputs 1 and so on
    229         {
    230             size_t offset = 0;
    231             for (auto& i : outputs_info) {
    232                 i.location.offset = offset;
    233                 offset += i.location.length;
    234             }
    235         }
    236         std::vector<hidl_memory> pools = {nn::allocateSharedMemory(inputSize),
    237                                           nn::allocateSharedMemory(outputSize)};
    238         if (pools[INPUT].size() == 0 || pools[OUTPUT].size() == 0) {
    239             return {};
    240         }
    241 
    242         // map pool
    243         sp<IMemory> inputMemory = mapMemory(pools[INPUT]);
    244         if (inputMemory == nullptr) {
    245             return {};
    246         }
    247         char* inputPtr = reinterpret_cast<char*>(static_cast<void*>(inputMemory->getPointer()));
    248         if (inputPtr == nullptr) {
    249             return {};
    250         }
    251 
    252         // initialize pool
    253         inputMemory->update();
    254         for_all(inputs, [&inputs_info, inputPtr](int index, auto p, auto s) {
    255             char* begin = (char*)p;
    256             char* end = begin + s;
    257             // TODO: handle more than one input
    258             std::copy(begin, end, inputPtr + inputs_info[index].location.offset);
    259         });
    260         inputMemory->commit();
    261 
    262         requests.push_back({.inputs = inputs_info, .outputs = outputs_info, .pools = pools});
    263     }
    264 
    265     return requests;
    266 }
    267 
    268 void ValidationTest::validateRequests(const sp<IPreparedModel>& preparedModel,
    269                                       const std::vector<Request>& requests) {
    270     // validate each request
    271     for (const Request& request : requests) {
    272         removeInputTest(preparedModel, request);
    273         removeOutputTest(preparedModel, request);
    274     }
    275 }
    276 
    277 }  // namespace functional
    278 }  // namespace vts
    279 }  // namespace V1_2
    280 }  // namespace neuralnetworks
    281 }  // namespace hardware
    282 }  // namespace android
    283