1 /* 2 * Copyright (C) 2019 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #define LOG_TAG "neuralnetworks_hidl_hal_test" 18 19 #include "VtsHalNeuralnetworks.h" 20 21 #include "Callbacks.h" 22 #include "ExecutionBurstController.h" 23 #include "ExecutionBurstServer.h" 24 #include "TestHarness.h" 25 #include "Utils.h" 26 27 #include <android-base/logging.h> 28 #include <cstring> 29 30 namespace android { 31 namespace hardware { 32 namespace neuralnetworks { 33 namespace V1_2 { 34 namespace vts { 35 namespace functional { 36 37 using ::android::nn::ExecutionBurstController; 38 using ::android::nn::RequestChannelSender; 39 using ::android::nn::ResultChannelReceiver; 40 using ExecutionBurstCallback = ::android::nn::ExecutionBurstController::ExecutionBurstCallback; 41 42 // This constant value represents the length of an FMQ that is large enough to 43 // return a result from a burst execution for all of the generated test cases. 44 constexpr size_t kExecutionBurstChannelLength = 1024; 45 46 // This constant value represents a length of an FMQ that is not large enough 47 // to return a result from a burst execution for some of the generated test 48 // cases. 49 constexpr size_t kExecutionBurstChannelSmallLength = 8; 50 51 ///////////////////////// UTILITY FUNCTIONS ///////////////////////// 52 53 static bool badTiming(Timing timing) { 54 return timing.timeOnDevice == UINT64_MAX && timing.timeInDriver == UINT64_MAX; 55 } 56 57 static void createBurst(const sp<IPreparedModel>& preparedModel, const sp<IBurstCallback>& callback, 58 std::unique_ptr<RequestChannelSender>* sender, 59 std::unique_ptr<ResultChannelReceiver>* receiver, 60 sp<IBurstContext>* context, 61 size_t resultChannelLength = kExecutionBurstChannelLength) { 62 ASSERT_NE(nullptr, preparedModel.get()); 63 ASSERT_NE(nullptr, sender); 64 ASSERT_NE(nullptr, receiver); 65 ASSERT_NE(nullptr, context); 66 67 // create FMQ objects 68 auto [fmqRequestChannel, fmqRequestDescriptor] = 69 RequestChannelSender::create(kExecutionBurstChannelLength, /*blocking=*/true); 70 auto [fmqResultChannel, fmqResultDescriptor] = 71 ResultChannelReceiver::create(resultChannelLength, /*blocking=*/true); 72 ASSERT_NE(nullptr, fmqRequestChannel.get()); 73 ASSERT_NE(nullptr, fmqResultChannel.get()); 74 ASSERT_NE(nullptr, fmqRequestDescriptor); 75 ASSERT_NE(nullptr, fmqResultDescriptor); 76 77 // configure burst 78 ErrorStatus errorStatus; 79 sp<IBurstContext> burstContext; 80 const Return<void> ret = preparedModel->configureExecutionBurst( 81 callback, *fmqRequestDescriptor, *fmqResultDescriptor, 82 [&errorStatus, &burstContext](ErrorStatus status, const sp<IBurstContext>& context) { 83 errorStatus = status; 84 burstContext = context; 85 }); 86 ASSERT_TRUE(ret.isOk()); 87 ASSERT_EQ(ErrorStatus::NONE, errorStatus); 88 ASSERT_NE(nullptr, burstContext.get()); 89 90 // return values 91 *sender = std::move(fmqRequestChannel); 92 *receiver = std::move(fmqResultChannel); 93 *context = burstContext; 94 } 95 96 static void createBurstWithResultChannelLength( 97 const sp<IPreparedModel>& preparedModel, size_t resultChannelLength, 98 std::shared_ptr<ExecutionBurstController>* controller) { 99 ASSERT_NE(nullptr, preparedModel.get()); 100 ASSERT_NE(nullptr, controller); 101 102 // create FMQ objects 103 std::unique_ptr<RequestChannelSender> sender; 104 std::unique_ptr<ResultChannelReceiver> receiver; 105 sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback(); 106 sp<IBurstContext> context; 107 ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context, 108 resultChannelLength)); 109 ASSERT_NE(nullptr, sender.get()); 110 ASSERT_NE(nullptr, receiver.get()); 111 ASSERT_NE(nullptr, context.get()); 112 113 // return values 114 *controller = std::make_shared<ExecutionBurstController>(std::move(sender), std::move(receiver), 115 context, callback); 116 } 117 118 // Primary validation function. This function will take a valid serialized 119 // request, apply a mutation to it to invalidate the serialized request, then 120 // pass it to interface calls that use the serialized request. Note that the 121 // serialized request here is passed by value, and any mutation to the 122 // serialized request does not leave this function. 123 static void validate(RequestChannelSender* sender, ResultChannelReceiver* receiver, 124 const std::string& message, std::vector<FmqRequestDatum> serialized, 125 const std::function<void(std::vector<FmqRequestDatum>*)>& mutation) { 126 mutation(&serialized); 127 128 // skip if packet is too large to send 129 if (serialized.size() > kExecutionBurstChannelLength) { 130 return; 131 } 132 133 SCOPED_TRACE(message); 134 135 // send invalid packet 136 ASSERT_TRUE(sender->sendPacket(serialized)); 137 138 // receive error 139 auto results = receiver->getBlocking(); 140 ASSERT_TRUE(results.has_value()); 141 const auto [status, outputShapes, timing] = std::move(*results); 142 EXPECT_NE(ErrorStatus::NONE, status); 143 EXPECT_EQ(0u, outputShapes.size()); 144 EXPECT_TRUE(badTiming(timing)); 145 } 146 147 // For validation, valid packet entries are mutated to invalid packet entries, 148 // or invalid packet entries are inserted into valid packets. This function 149 // creates pre-set invalid packet entries for convenience. 150 static std::vector<FmqRequestDatum> createBadRequestPacketEntries() { 151 const FmqRequestDatum::PacketInformation packetInformation = { 152 /*.packetSize=*/10, /*.numberOfInputOperands=*/10, /*.numberOfOutputOperands=*/10, 153 /*.numberOfPools=*/10}; 154 const FmqRequestDatum::OperandInformation operandInformation = { 155 /*.hasNoValue=*/false, /*.location=*/{}, /*.numberOfDimensions=*/10}; 156 const int32_t invalidPoolIdentifier = std::numeric_limits<int32_t>::max(); 157 std::vector<FmqRequestDatum> bad(7); 158 bad[0].packetInformation(packetInformation); 159 bad[1].inputOperandInformation(operandInformation); 160 bad[2].inputOperandDimensionValue(0); 161 bad[3].outputOperandInformation(operandInformation); 162 bad[4].outputOperandDimensionValue(0); 163 bad[5].poolIdentifier(invalidPoolIdentifier); 164 bad[6].measureTiming(MeasureTiming::YES); 165 return bad; 166 } 167 168 // For validation, valid packet entries are mutated to invalid packet entries, 169 // or invalid packet entries are inserted into valid packets. This function 170 // retrieves pre-set invalid packet entries for convenience. This function 171 // caches these data so they can be reused on subsequent validation checks. 172 static const std::vector<FmqRequestDatum>& getBadRequestPacketEntries() { 173 static const std::vector<FmqRequestDatum> bad = createBadRequestPacketEntries(); 174 return bad; 175 } 176 177 ///////////////////////// REMOVE DATUM //////////////////////////////////// 178 179 static void removeDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver, 180 const std::vector<FmqRequestDatum>& serialized) { 181 for (size_t index = 0; index < serialized.size(); ++index) { 182 const std::string message = "removeDatum: removed datum at index " + std::to_string(index); 183 validate(sender, receiver, message, serialized, 184 [index](std::vector<FmqRequestDatum>* serialized) { 185 serialized->erase(serialized->begin() + index); 186 }); 187 } 188 } 189 190 ///////////////////////// ADD DATUM //////////////////////////////////// 191 192 static void addDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver, 193 const std::vector<FmqRequestDatum>& serialized) { 194 const std::vector<FmqRequestDatum>& extra = getBadRequestPacketEntries(); 195 for (size_t index = 0; index <= serialized.size(); ++index) { 196 for (size_t type = 0; type < extra.size(); ++type) { 197 const std::string message = "addDatum: added datum type " + std::to_string(type) + 198 " at index " + std::to_string(index); 199 validate(sender, receiver, message, serialized, 200 [index, type, &extra](std::vector<FmqRequestDatum>* serialized) { 201 serialized->insert(serialized->begin() + index, extra[type]); 202 }); 203 } 204 } 205 } 206 207 ///////////////////////// MUTATE DATUM //////////////////////////////////// 208 209 static bool interestingCase(const FmqRequestDatum& lhs, const FmqRequestDatum& rhs) { 210 using Discriminator = FmqRequestDatum::hidl_discriminator; 211 212 const bool differentValues = (lhs != rhs); 213 const bool sameDiscriminator = (lhs.getDiscriminator() == rhs.getDiscriminator()); 214 const auto discriminator = rhs.getDiscriminator(); 215 const bool isDimensionValue = (discriminator == Discriminator::inputOperandDimensionValue || 216 discriminator == Discriminator::outputOperandDimensionValue); 217 218 return differentValues && !(sameDiscriminator && isDimensionValue); 219 } 220 221 static void mutateDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver, 222 const std::vector<FmqRequestDatum>& serialized) { 223 const std::vector<FmqRequestDatum>& change = getBadRequestPacketEntries(); 224 for (size_t index = 0; index < serialized.size(); ++index) { 225 for (size_t type = 0; type < change.size(); ++type) { 226 if (interestingCase(serialized[index], change[type])) { 227 const std::string message = "mutateDatum: changed datum at index " + 228 std::to_string(index) + " to datum type " + 229 std::to_string(type); 230 validate(sender, receiver, message, serialized, 231 [index, type, &change](std::vector<FmqRequestDatum>* serialized) { 232 (*serialized)[index] = change[type]; 233 }); 234 } 235 } 236 } 237 } 238 239 ///////////////////////// BURST VALIATION TESTS //////////////////////////////////// 240 241 static void validateBurstSerialization(const sp<IPreparedModel>& preparedModel, 242 const std::vector<Request>& requests) { 243 // create burst 244 std::unique_ptr<RequestChannelSender> sender; 245 std::unique_ptr<ResultChannelReceiver> receiver; 246 sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback(); 247 sp<IBurstContext> context; 248 ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context)); 249 ASSERT_NE(nullptr, sender.get()); 250 ASSERT_NE(nullptr, receiver.get()); 251 ASSERT_NE(nullptr, context.get()); 252 253 // validate each request 254 for (const Request& request : requests) { 255 // load memory into callback slots 256 std::vector<intptr_t> keys; 257 keys.reserve(request.pools.size()); 258 std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(keys), 259 [](const auto& pool) { return reinterpret_cast<intptr_t>(&pool); }); 260 const std::vector<int32_t> slots = callback->getSlots(request.pools, keys); 261 262 // ensure slot std::numeric_limits<int32_t>::max() doesn't exist (for 263 // subsequent slot validation testing) 264 ASSERT_TRUE(std::all_of(slots.begin(), slots.end(), [](int32_t slot) { 265 return slot != std::numeric_limits<int32_t>::max(); 266 })); 267 268 // serialize the request 269 const auto serialized = ::android::nn::serialize(request, MeasureTiming::YES, slots); 270 271 // validations 272 removeDatumTest(sender.get(), receiver.get(), serialized); 273 addDatumTest(sender.get(), receiver.get(), serialized); 274 mutateDatumTest(sender.get(), receiver.get(), serialized); 275 } 276 } 277 278 // This test validates that when the Result message size exceeds length of the 279 // result FMQ, the service instance gracefully fails and returns an error. 280 static void validateBurstFmqLength(const sp<IPreparedModel>& preparedModel, 281 const std::vector<Request>& requests) { 282 // create regular burst 283 std::shared_ptr<ExecutionBurstController> controllerRegular; 284 ASSERT_NO_FATAL_FAILURE(createBurstWithResultChannelLength( 285 preparedModel, kExecutionBurstChannelLength, &controllerRegular)); 286 ASSERT_NE(nullptr, controllerRegular.get()); 287 288 // create burst with small output channel 289 std::shared_ptr<ExecutionBurstController> controllerSmall; 290 ASSERT_NO_FATAL_FAILURE(createBurstWithResultChannelLength( 291 preparedModel, kExecutionBurstChannelSmallLength, &controllerSmall)); 292 ASSERT_NE(nullptr, controllerSmall.get()); 293 294 // validate each request 295 for (const Request& request : requests) { 296 // load memory into callback slots 297 std::vector<intptr_t> keys(request.pools.size()); 298 for (size_t i = 0; i < keys.size(); ++i) { 299 keys[i] = reinterpret_cast<intptr_t>(&request.pools[i]); 300 } 301 302 // collect serialized result by running regular burst 303 const auto [statusRegular, outputShapesRegular, timingRegular] = 304 controllerRegular->compute(request, MeasureTiming::NO, keys); 305 306 // skip test if regular burst output isn't useful for testing a failure 307 // caused by having too small of a length for the result FMQ 308 const std::vector<FmqResultDatum> serialized = 309 ::android::nn::serialize(statusRegular, outputShapesRegular, timingRegular); 310 if (statusRegular != ErrorStatus::NONE || 311 serialized.size() <= kExecutionBurstChannelSmallLength) { 312 continue; 313 } 314 315 // by this point, execution should fail because the result channel isn't 316 // large enough to return the serialized result 317 const auto [statusSmall, outputShapesSmall, timingSmall] = 318 controllerSmall->compute(request, MeasureTiming::NO, keys); 319 EXPECT_NE(ErrorStatus::NONE, statusSmall); 320 EXPECT_EQ(0u, outputShapesSmall.size()); 321 EXPECT_TRUE(badTiming(timingSmall)); 322 } 323 } 324 325 static bool isSanitized(const FmqResultDatum& datum) { 326 using Discriminator = FmqResultDatum::hidl_discriminator; 327 328 // check to ensure the padding values in the returned 329 // FmqResultDatum::OperandInformation are initialized to 0 330 if (datum.getDiscriminator() == Discriminator::operandInformation) { 331 static_assert( 332 offsetof(FmqResultDatum::OperandInformation, isSufficient) == 0, 333 "unexpected value for offset of FmqResultDatum::OperandInformation::isSufficient"); 334 static_assert( 335 sizeof(FmqResultDatum::OperandInformation::isSufficient) == 1, 336 "unexpected value for size of FmqResultDatum::OperandInformation::isSufficient"); 337 static_assert(offsetof(FmqResultDatum::OperandInformation, numberOfDimensions) == 4, 338 "unexpected value for offset of " 339 "FmqResultDatum::OperandInformation::numberOfDimensions"); 340 static_assert(sizeof(FmqResultDatum::OperandInformation::numberOfDimensions) == 4, 341 "unexpected value for size of " 342 "FmqResultDatum::OperandInformation::numberOfDimensions"); 343 static_assert(sizeof(FmqResultDatum::OperandInformation) == 8, 344 "unexpected value for size of " 345 "FmqResultDatum::OperandInformation"); 346 347 constexpr size_t paddingOffset = 348 offsetof(FmqResultDatum::OperandInformation, isSufficient) + 349 sizeof(FmqResultDatum::OperandInformation::isSufficient); 350 constexpr size_t paddingSize = 351 offsetof(FmqResultDatum::OperandInformation, numberOfDimensions) - paddingOffset; 352 353 FmqResultDatum::OperandInformation initialized{}; 354 std::memset(&initialized, 0, sizeof(initialized)); 355 356 const char* initializedPaddingStart = 357 reinterpret_cast<const char*>(&initialized) + paddingOffset; 358 const char* datumPaddingStart = 359 reinterpret_cast<const char*>(&datum.operandInformation()) + paddingOffset; 360 361 return std::memcmp(datumPaddingStart, initializedPaddingStart, paddingSize) == 0; 362 } 363 364 // there are no other padding initialization checks required, so return true 365 // for any sum-type that isn't FmqResultDatum::OperandInformation 366 return true; 367 } 368 369 static void validateBurstSanitized(const sp<IPreparedModel>& preparedModel, 370 const std::vector<Request>& requests) { 371 // create burst 372 std::unique_ptr<RequestChannelSender> sender; 373 std::unique_ptr<ResultChannelReceiver> receiver; 374 sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback(); 375 sp<IBurstContext> context; 376 ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context)); 377 ASSERT_NE(nullptr, sender.get()); 378 ASSERT_NE(nullptr, receiver.get()); 379 ASSERT_NE(nullptr, context.get()); 380 381 // validate each request 382 for (const Request& request : requests) { 383 // load memory into callback slots 384 std::vector<intptr_t> keys; 385 keys.reserve(request.pools.size()); 386 std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(keys), 387 [](const auto& pool) { return reinterpret_cast<intptr_t>(&pool); }); 388 const std::vector<int32_t> slots = callback->getSlots(request.pools, keys); 389 390 // send valid request 391 ASSERT_TRUE(sender->send(request, MeasureTiming::YES, slots)); 392 393 // receive valid result 394 auto serialized = receiver->getPacketBlocking(); 395 ASSERT_TRUE(serialized.has_value()); 396 397 // sanitize result 398 ASSERT_TRUE(std::all_of(serialized->begin(), serialized->end(), isSanitized)) 399 << "The result serialized data is not properly sanitized"; 400 } 401 } 402 403 ///////////////////////////// ENTRY POINT ////////////////////////////////// 404 405 void ValidationTest::validateBurst(const sp<IPreparedModel>& preparedModel, 406 const std::vector<Request>& requests) { 407 ASSERT_NO_FATAL_FAILURE(validateBurstSerialization(preparedModel, requests)); 408 ASSERT_NO_FATAL_FAILURE(validateBurstFmqLength(preparedModel, requests)); 409 ASSERT_NO_FATAL_FAILURE(validateBurstSanitized(preparedModel, requests)); 410 } 411 412 } // namespace functional 413 } // namespace vts 414 } // namespace V1_2 415 } // namespace neuralnetworks 416 } // namespace hardware 417 } // namespace android 418