Home | History | Annotate | Download | only in functional
      1 /*
      2  * Copyright (C) 2019 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #define LOG_TAG "neuralnetworks_hidl_hal_test"
     18 
     19 #include "VtsHalNeuralnetworks.h"
     20 
     21 #include "Callbacks.h"
     22 #include "ExecutionBurstController.h"
     23 #include "ExecutionBurstServer.h"
     24 #include "TestHarness.h"
     25 #include "Utils.h"
     26 
     27 #include <android-base/logging.h>
     28 #include <cstring>
     29 
     30 namespace android {
     31 namespace hardware {
     32 namespace neuralnetworks {
     33 namespace V1_2 {
     34 namespace vts {
     35 namespace functional {
     36 
     37 using ::android::nn::ExecutionBurstController;
     38 using ::android::nn::RequestChannelSender;
     39 using ::android::nn::ResultChannelReceiver;
     40 using ExecutionBurstCallback = ::android::nn::ExecutionBurstController::ExecutionBurstCallback;
     41 
     42 // This constant value represents the length of an FMQ that is large enough to
     43 // return a result from a burst execution for all of the generated test cases.
     44 constexpr size_t kExecutionBurstChannelLength = 1024;
     45 
     46 // This constant value represents a length of an FMQ that is not large enough
     47 // to return a result from a burst execution for some of the generated test
     48 // cases.
     49 constexpr size_t kExecutionBurstChannelSmallLength = 8;
     50 
     51 ///////////////////////// UTILITY FUNCTIONS /////////////////////////
     52 
     53 static bool badTiming(Timing timing) {
     54     return timing.timeOnDevice == UINT64_MAX && timing.timeInDriver == UINT64_MAX;
     55 }
     56 
     57 static void createBurst(const sp<IPreparedModel>& preparedModel, const sp<IBurstCallback>& callback,
     58                         std::unique_ptr<RequestChannelSender>* sender,
     59                         std::unique_ptr<ResultChannelReceiver>* receiver,
     60                         sp<IBurstContext>* context,
     61                         size_t resultChannelLength = kExecutionBurstChannelLength) {
     62     ASSERT_NE(nullptr, preparedModel.get());
     63     ASSERT_NE(nullptr, sender);
     64     ASSERT_NE(nullptr, receiver);
     65     ASSERT_NE(nullptr, context);
     66 
     67     // create FMQ objects
     68     auto [fmqRequestChannel, fmqRequestDescriptor] =
     69             RequestChannelSender::create(kExecutionBurstChannelLength, /*blocking=*/true);
     70     auto [fmqResultChannel, fmqResultDescriptor] =
     71             ResultChannelReceiver::create(resultChannelLength, /*blocking=*/true);
     72     ASSERT_NE(nullptr, fmqRequestChannel.get());
     73     ASSERT_NE(nullptr, fmqResultChannel.get());
     74     ASSERT_NE(nullptr, fmqRequestDescriptor);
     75     ASSERT_NE(nullptr, fmqResultDescriptor);
     76 
     77     // configure burst
     78     ErrorStatus errorStatus;
     79     sp<IBurstContext> burstContext;
     80     const Return<void> ret = preparedModel->configureExecutionBurst(
     81             callback, *fmqRequestDescriptor, *fmqResultDescriptor,
     82             [&errorStatus, &burstContext](ErrorStatus status, const sp<IBurstContext>& context) {
     83                 errorStatus = status;
     84                 burstContext = context;
     85             });
     86     ASSERT_TRUE(ret.isOk());
     87     ASSERT_EQ(ErrorStatus::NONE, errorStatus);
     88     ASSERT_NE(nullptr, burstContext.get());
     89 
     90     // return values
     91     *sender = std::move(fmqRequestChannel);
     92     *receiver = std::move(fmqResultChannel);
     93     *context = burstContext;
     94 }
     95 
     96 static void createBurstWithResultChannelLength(
     97         const sp<IPreparedModel>& preparedModel, size_t resultChannelLength,
     98         std::shared_ptr<ExecutionBurstController>* controller) {
     99     ASSERT_NE(nullptr, preparedModel.get());
    100     ASSERT_NE(nullptr, controller);
    101 
    102     // create FMQ objects
    103     std::unique_ptr<RequestChannelSender> sender;
    104     std::unique_ptr<ResultChannelReceiver> receiver;
    105     sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
    106     sp<IBurstContext> context;
    107     ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context,
    108                                         resultChannelLength));
    109     ASSERT_NE(nullptr, sender.get());
    110     ASSERT_NE(nullptr, receiver.get());
    111     ASSERT_NE(nullptr, context.get());
    112 
    113     // return values
    114     *controller = std::make_shared<ExecutionBurstController>(std::move(sender), std::move(receiver),
    115                                                              context, callback);
    116 }
    117 
    118 // Primary validation function. This function will take a valid serialized
    119 // request, apply a mutation to it to invalidate the serialized request, then
    120 // pass it to interface calls that use the serialized request. Note that the
    121 // serialized request here is passed by value, and any mutation to the
    122 // serialized request does not leave this function.
    123 static void validate(RequestChannelSender* sender, ResultChannelReceiver* receiver,
    124                      const std::string& message, std::vector<FmqRequestDatum> serialized,
    125                      const std::function<void(std::vector<FmqRequestDatum>*)>& mutation) {
    126     mutation(&serialized);
    127 
    128     // skip if packet is too large to send
    129     if (serialized.size() > kExecutionBurstChannelLength) {
    130         return;
    131     }
    132 
    133     SCOPED_TRACE(message);
    134 
    135     // send invalid packet
    136     ASSERT_TRUE(sender->sendPacket(serialized));
    137 
    138     // receive error
    139     auto results = receiver->getBlocking();
    140     ASSERT_TRUE(results.has_value());
    141     const auto [status, outputShapes, timing] = std::move(*results);
    142     EXPECT_NE(ErrorStatus::NONE, status);
    143     EXPECT_EQ(0u, outputShapes.size());
    144     EXPECT_TRUE(badTiming(timing));
    145 }
    146 
    147 // For validation, valid packet entries are mutated to invalid packet entries,
    148 // or invalid packet entries are inserted into valid packets. This function
    149 // creates pre-set invalid packet entries for convenience.
    150 static std::vector<FmqRequestDatum> createBadRequestPacketEntries() {
    151     const FmqRequestDatum::PacketInformation packetInformation = {
    152             /*.packetSize=*/10, /*.numberOfInputOperands=*/10, /*.numberOfOutputOperands=*/10,
    153             /*.numberOfPools=*/10};
    154     const FmqRequestDatum::OperandInformation operandInformation = {
    155             /*.hasNoValue=*/false, /*.location=*/{}, /*.numberOfDimensions=*/10};
    156     const int32_t invalidPoolIdentifier = std::numeric_limits<int32_t>::max();
    157     std::vector<FmqRequestDatum> bad(7);
    158     bad[0].packetInformation(packetInformation);
    159     bad[1].inputOperandInformation(operandInformation);
    160     bad[2].inputOperandDimensionValue(0);
    161     bad[3].outputOperandInformation(operandInformation);
    162     bad[4].outputOperandDimensionValue(0);
    163     bad[5].poolIdentifier(invalidPoolIdentifier);
    164     bad[6].measureTiming(MeasureTiming::YES);
    165     return bad;
    166 }
    167 
    168 // For validation, valid packet entries are mutated to invalid packet entries,
    169 // or invalid packet entries are inserted into valid packets. This function
    170 // retrieves pre-set invalid packet entries for convenience. This function
    171 // caches these data so they can be reused on subsequent validation checks.
    172 static const std::vector<FmqRequestDatum>& getBadRequestPacketEntries() {
    173     static const std::vector<FmqRequestDatum> bad = createBadRequestPacketEntries();
    174     return bad;
    175 }
    176 
    177 ///////////////////////// REMOVE DATUM ////////////////////////////////////
    178 
    179 static void removeDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
    180                             const std::vector<FmqRequestDatum>& serialized) {
    181     for (size_t index = 0; index < serialized.size(); ++index) {
    182         const std::string message = "removeDatum: removed datum at index " + std::to_string(index);
    183         validate(sender, receiver, message, serialized,
    184                  [index](std::vector<FmqRequestDatum>* serialized) {
    185                      serialized->erase(serialized->begin() + index);
    186                  });
    187     }
    188 }
    189 
    190 ///////////////////////// ADD DATUM ////////////////////////////////////
    191 
    192 static void addDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
    193                          const std::vector<FmqRequestDatum>& serialized) {
    194     const std::vector<FmqRequestDatum>& extra = getBadRequestPacketEntries();
    195     for (size_t index = 0; index <= serialized.size(); ++index) {
    196         for (size_t type = 0; type < extra.size(); ++type) {
    197             const std::string message = "addDatum: added datum type " + std::to_string(type) +
    198                                         " at index " + std::to_string(index);
    199             validate(sender, receiver, message, serialized,
    200                      [index, type, &extra](std::vector<FmqRequestDatum>* serialized) {
    201                          serialized->insert(serialized->begin() + index, extra[type]);
    202                      });
    203         }
    204     }
    205 }
    206 
    207 ///////////////////////// MUTATE DATUM ////////////////////////////////////
    208 
    209 static bool interestingCase(const FmqRequestDatum& lhs, const FmqRequestDatum& rhs) {
    210     using Discriminator = FmqRequestDatum::hidl_discriminator;
    211 
    212     const bool differentValues = (lhs != rhs);
    213     const bool sameDiscriminator = (lhs.getDiscriminator() == rhs.getDiscriminator());
    214     const auto discriminator = rhs.getDiscriminator();
    215     const bool isDimensionValue = (discriminator == Discriminator::inputOperandDimensionValue ||
    216                                    discriminator == Discriminator::outputOperandDimensionValue);
    217 
    218     return differentValues && !(sameDiscriminator && isDimensionValue);
    219 }
    220 
    221 static void mutateDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
    222                             const std::vector<FmqRequestDatum>& serialized) {
    223     const std::vector<FmqRequestDatum>& change = getBadRequestPacketEntries();
    224     for (size_t index = 0; index < serialized.size(); ++index) {
    225         for (size_t type = 0; type < change.size(); ++type) {
    226             if (interestingCase(serialized[index], change[type])) {
    227                 const std::string message = "mutateDatum: changed datum at index " +
    228                                             std::to_string(index) + " to datum type " +
    229                                             std::to_string(type);
    230                 validate(sender, receiver, message, serialized,
    231                          [index, type, &change](std::vector<FmqRequestDatum>* serialized) {
    232                              (*serialized)[index] = change[type];
    233                          });
    234             }
    235         }
    236     }
    237 }
    238 
///////////////////////// BURST VALIDATION TESTS ////////////////////////////////////
    240 
    241 static void validateBurstSerialization(const sp<IPreparedModel>& preparedModel,
    242                                        const std::vector<Request>& requests) {
    243     // create burst
    244     std::unique_ptr<RequestChannelSender> sender;
    245     std::unique_ptr<ResultChannelReceiver> receiver;
    246     sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
    247     sp<IBurstContext> context;
    248     ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context));
    249     ASSERT_NE(nullptr, sender.get());
    250     ASSERT_NE(nullptr, receiver.get());
    251     ASSERT_NE(nullptr, context.get());
    252 
    253     // validate each request
    254     for (const Request& request : requests) {
    255         // load memory into callback slots
    256         std::vector<intptr_t> keys;
    257         keys.reserve(request.pools.size());
    258         std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(keys),
    259                        [](const auto& pool) { return reinterpret_cast<intptr_t>(&pool); });
    260         const std::vector<int32_t> slots = callback->getSlots(request.pools, keys);
    261 
    262         // ensure slot std::numeric_limits<int32_t>::max() doesn't exist (for
    263         // subsequent slot validation testing)
    264         ASSERT_TRUE(std::all_of(slots.begin(), slots.end(), [](int32_t slot) {
    265             return slot != std::numeric_limits<int32_t>::max();
    266         }));
    267 
    268         // serialize the request
    269         const auto serialized = ::android::nn::serialize(request, MeasureTiming::YES, slots);
    270 
    271         // validations
    272         removeDatumTest(sender.get(), receiver.get(), serialized);
    273         addDatumTest(sender.get(), receiver.get(), serialized);
    274         mutateDatumTest(sender.get(), receiver.get(), serialized);
    275     }
    276 }
    277 
    278 // This test validates that when the Result message size exceeds length of the
    279 // result FMQ, the service instance gracefully fails and returns an error.
    280 static void validateBurstFmqLength(const sp<IPreparedModel>& preparedModel,
    281                                    const std::vector<Request>& requests) {
    282     // create regular burst
    283     std::shared_ptr<ExecutionBurstController> controllerRegular;
    284     ASSERT_NO_FATAL_FAILURE(createBurstWithResultChannelLength(
    285             preparedModel, kExecutionBurstChannelLength, &controllerRegular));
    286     ASSERT_NE(nullptr, controllerRegular.get());
    287 
    288     // create burst with small output channel
    289     std::shared_ptr<ExecutionBurstController> controllerSmall;
    290     ASSERT_NO_FATAL_FAILURE(createBurstWithResultChannelLength(
    291             preparedModel, kExecutionBurstChannelSmallLength, &controllerSmall));
    292     ASSERT_NE(nullptr, controllerSmall.get());
    293 
    294     // validate each request
    295     for (const Request& request : requests) {
    296         // load memory into callback slots
    297         std::vector<intptr_t> keys(request.pools.size());
    298         for (size_t i = 0; i < keys.size(); ++i) {
    299             keys[i] = reinterpret_cast<intptr_t>(&request.pools[i]);
    300         }
    301 
    302         // collect serialized result by running regular burst
    303         const auto [statusRegular, outputShapesRegular, timingRegular] =
    304                 controllerRegular->compute(request, MeasureTiming::NO, keys);
    305 
    306         // skip test if regular burst output isn't useful for testing a failure
    307         // caused by having too small of a length for the result FMQ
    308         const std::vector<FmqResultDatum> serialized =
    309                 ::android::nn::serialize(statusRegular, outputShapesRegular, timingRegular);
    310         if (statusRegular != ErrorStatus::NONE ||
    311             serialized.size() <= kExecutionBurstChannelSmallLength) {
    312             continue;
    313         }
    314 
    315         // by this point, execution should fail because the result channel isn't
    316         // large enough to return the serialized result
    317         const auto [statusSmall, outputShapesSmall, timingSmall] =
    318                 controllerSmall->compute(request, MeasureTiming::NO, keys);
    319         EXPECT_NE(ErrorStatus::NONE, statusSmall);
    320         EXPECT_EQ(0u, outputShapesSmall.size());
    321         EXPECT_TRUE(badTiming(timingSmall));
    322     }
    323 }
    324 
    325 static bool isSanitized(const FmqResultDatum& datum) {
    326     using Discriminator = FmqResultDatum::hidl_discriminator;
    327 
    328     // check to ensure the padding values in the returned
    329     // FmqResultDatum::OperandInformation are initialized to 0
    330     if (datum.getDiscriminator() == Discriminator::operandInformation) {
    331         static_assert(
    332                 offsetof(FmqResultDatum::OperandInformation, isSufficient) == 0,
    333                 "unexpected value for offset of FmqResultDatum::OperandInformation::isSufficient");
    334         static_assert(
    335                 sizeof(FmqResultDatum::OperandInformation::isSufficient) == 1,
    336                 "unexpected value for size of FmqResultDatum::OperandInformation::isSufficient");
    337         static_assert(offsetof(FmqResultDatum::OperandInformation, numberOfDimensions) == 4,
    338                       "unexpected value for offset of "
    339                       "FmqResultDatum::OperandInformation::numberOfDimensions");
    340         static_assert(sizeof(FmqResultDatum::OperandInformation::numberOfDimensions) == 4,
    341                       "unexpected value for size of "
    342                       "FmqResultDatum::OperandInformation::numberOfDimensions");
    343         static_assert(sizeof(FmqResultDatum::OperandInformation) == 8,
    344                       "unexpected value for size of "
    345                       "FmqResultDatum::OperandInformation");
    346 
    347         constexpr size_t paddingOffset =
    348                 offsetof(FmqResultDatum::OperandInformation, isSufficient) +
    349                 sizeof(FmqResultDatum::OperandInformation::isSufficient);
    350         constexpr size_t paddingSize =
    351                 offsetof(FmqResultDatum::OperandInformation, numberOfDimensions) - paddingOffset;
    352 
    353         FmqResultDatum::OperandInformation initialized{};
    354         std::memset(&initialized, 0, sizeof(initialized));
    355 
    356         const char* initializedPaddingStart =
    357                 reinterpret_cast<const char*>(&initialized) + paddingOffset;
    358         const char* datumPaddingStart =
    359                 reinterpret_cast<const char*>(&datum.operandInformation()) + paddingOffset;
    360 
    361         return std::memcmp(datumPaddingStart, initializedPaddingStart, paddingSize) == 0;
    362     }
    363 
    364     // there are no other padding initialization checks required, so return true
    365     // for any sum-type that isn't FmqResultDatum::OperandInformation
    366     return true;
    367 }
    368 
    369 static void validateBurstSanitized(const sp<IPreparedModel>& preparedModel,
    370                                    const std::vector<Request>& requests) {
    371     // create burst
    372     std::unique_ptr<RequestChannelSender> sender;
    373     std::unique_ptr<ResultChannelReceiver> receiver;
    374     sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
    375     sp<IBurstContext> context;
    376     ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context));
    377     ASSERT_NE(nullptr, sender.get());
    378     ASSERT_NE(nullptr, receiver.get());
    379     ASSERT_NE(nullptr, context.get());
    380 
    381     // validate each request
    382     for (const Request& request : requests) {
    383         // load memory into callback slots
    384         std::vector<intptr_t> keys;
    385         keys.reserve(request.pools.size());
    386         std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(keys),
    387                        [](const auto& pool) { return reinterpret_cast<intptr_t>(&pool); });
    388         const std::vector<int32_t> slots = callback->getSlots(request.pools, keys);
    389 
    390         // send valid request
    391         ASSERT_TRUE(sender->send(request, MeasureTiming::YES, slots));
    392 
    393         // receive valid result
    394         auto serialized = receiver->getPacketBlocking();
    395         ASSERT_TRUE(serialized.has_value());
    396 
    397         // sanitize result
    398         ASSERT_TRUE(std::all_of(serialized->begin(), serialized->end(), isSanitized))
    399                 << "The result serialized data is not properly sanitized";
    400     }
    401 }
    402 
    403 ///////////////////////////// ENTRY POINT //////////////////////////////////
    404 
    405 void ValidationTest::validateBurst(const sp<IPreparedModel>& preparedModel,
    406                                    const std::vector<Request>& requests) {
    407     ASSERT_NO_FATAL_FAILURE(validateBurstSerialization(preparedModel, requests));
    408     ASSERT_NO_FATAL_FAILURE(validateBurstFmqLength(preparedModel, requests));
    409     ASSERT_NO_FATAL_FAILURE(validateBurstSanitized(preparedModel, requests));
    410 }
    411 
    412 }  // namespace functional
    413 }  // namespace vts
    414 }  // namespace V1_2
    415 }  // namespace neuralnetworks
    416 }  // namespace hardware
    417 }  // namespace android
    418