1 /* 2 * Copyright (C) 2019 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #define LOG_TAG "ExecutionBurstController" 18 19 #include "ExecutionBurstController.h" 20 21 #include <android-base/logging.h> 22 #include <cstring> 23 #include <limits> 24 #include <string> 25 #include "Tracing.h" 26 27 namespace android::nn { 28 namespace { 29 30 using ::android::hardware::MQDescriptorSync; 31 using FmqRequestDescriptor = MQDescriptorSync<FmqRequestDatum>; 32 using FmqResultDescriptor = MQDescriptorSync<FmqResultDatum>; 33 34 constexpr Timing kNoTiming = {std::numeric_limits<uint64_t>::max(), 35 std::numeric_limits<uint64_t>::max()}; 36 37 class BurstContextDeathHandler : public hardware::hidl_death_recipient { 38 public: 39 using Callback = std::function<void()>; 40 41 BurstContextDeathHandler(const Callback& onDeathCallback) : mOnDeathCallback(onDeathCallback) { 42 CHECK(onDeathCallback != nullptr); 43 } 44 45 void serviceDied(uint64_t /*cookie*/, const wp<hidl::base::V1_0::IBase>& /*who*/) override { 46 LOG(ERROR) << "BurstContextDeathHandler::serviceDied -- service unexpectedly died!"; 47 mOnDeathCallback(); 48 } 49 50 private: 51 const Callback mOnDeathCallback; 52 }; 53 54 } // anonymous namespace 55 56 // serialize a request into a packet 57 std::vector<FmqRequestDatum> serialize(const Request& request, MeasureTiming measure, 58 const std::vector<int32_t>& slots) { 59 // count how many elements need to be sent for a request 60 size_t count = 2 + request.inputs.size() + request.outputs.size() + request.pools.size(); 61 for (const auto& input : request.inputs) { 62 count += input.dimensions.size(); 63 } 64 for (const auto& output : request.outputs) { 65 count += output.dimensions.size(); 66 } 67 68 // create buffer to temporarily store elements 69 std::vector<FmqRequestDatum> data; 70 data.reserve(count); 71 72 // package packetInfo 73 { 74 FmqRequestDatum datum; 75 datum.packetInformation( 76 {/*.packetSize=*/static_cast<uint32_t>(count), 77 /*.numberOfInputOperands=*/static_cast<uint32_t>(request.inputs.size()), 78 /*.numberOfOutputOperands=*/static_cast<uint32_t>(request.outputs.size()), 79 /*.numberOfPools=*/static_cast<uint32_t>(request.pools.size())}); 80 data.push_back(datum); 81 } 82 83 // package input data 84 for (const auto& input : request.inputs) { 85 // package operand information 86 FmqRequestDatum datum; 87 datum.inputOperandInformation( 88 {/*.hasNoValue=*/input.hasNoValue, 89 /*.location=*/input.location, 90 /*.numberOfDimensions=*/static_cast<uint32_t>(input.dimensions.size())}); 91 data.push_back(datum); 92 93 // package operand dimensions 94 for (uint32_t dimension : input.dimensions) { 95 FmqRequestDatum datum; 96 datum.inputOperandDimensionValue(dimension); 97 data.push_back(datum); 98 } 99 } 100 101 // package output data 102 for (const auto& output : request.outputs) { 103 // package operand information 104 FmqRequestDatum datum; 105 datum.outputOperandInformation( 106 {/*.hasNoValue=*/output.hasNoValue, 107 /*.location=*/output.location, 108 /*.numberOfDimensions=*/static_cast<uint32_t>(output.dimensions.size())}); 109 data.push_back(datum); 110 111 // package operand dimensions 112 for (uint32_t dimension : output.dimensions) { 113 FmqRequestDatum datum; 114 datum.outputOperandDimensionValue(dimension); 115 data.push_back(datum); 116 } 117 } 118 119 // package pool identifier 120 for (int32_t slot : slots) { 121 FmqRequestDatum datum; 122 datum.poolIdentifier(slot); 123 data.push_back(datum); 124 } 125 126 // package measureTiming 127 { 128 FmqRequestDatum datum; 129 datum.measureTiming(measure); 130 data.push_back(datum); 131 } 132 133 // return packet 134 return data; 135 } 136 137 // deserialize a packet into the result 138 std::optional<std::tuple<ErrorStatus, std::vector<OutputShape>, Timing>> deserialize( 139 const std::vector<FmqResultDatum>& data) { 140 using discriminator = FmqResultDatum::hidl_discriminator; 141 142 std::vector<OutputShape> outputShapes; 143 size_t index = 0; 144 145 // validate packet information 146 if (data.size() == 0 || data[index].getDiscriminator() != discriminator::packetInformation) { 147 LOG(ERROR) << "FMQ Result packet ill-formed"; 148 return std::nullopt; 149 } 150 151 // unpackage packet information 152 const FmqResultDatum::PacketInformation& packetInfo = data[index].packetInformation(); 153 index++; 154 const uint32_t packetSize = packetInfo.packetSize; 155 const ErrorStatus errorStatus = packetInfo.errorStatus; 156 const uint32_t numberOfOperands = packetInfo.numberOfOperands; 157 158 // verify packet size 159 if (data.size() != packetSize) { 160 LOG(ERROR) << "FMQ Result packet ill-formed"; 161 return std::nullopt; 162 } 163 164 // unpackage operands 165 for (size_t operand = 0; operand < numberOfOperands; ++operand) { 166 // validate operand information 167 if (data[index].getDiscriminator() != discriminator::operandInformation) { 168 LOG(ERROR) << "FMQ Result packet ill-formed"; 169 return std::nullopt; 170 } 171 172 // unpackage operand information 173 const FmqResultDatum::OperandInformation& operandInfo = data[index].operandInformation(); 174 index++; 175 const bool isSufficient = operandInfo.isSufficient; 176 const uint32_t numberOfDimensions = operandInfo.numberOfDimensions; 177 178 // unpackage operand dimensions 179 std::vector<uint32_t> dimensions; 180 dimensions.reserve(numberOfDimensions); 181 for (size_t i = 0; i < numberOfDimensions; ++i) { 182 // validate dimension 183 if (data[index].getDiscriminator() != discriminator::operandDimensionValue) { 184 LOG(ERROR) << "FMQ Result packet ill-formed"; 185 return std::nullopt; 186 } 187 188 // unpackage dimension 189 const uint32_t dimension = data[index].operandDimensionValue(); 190 index++; 191 192 // store result 193 dimensions.push_back(dimension); 194 } 195 196 // store result 197 outputShapes.push_back({/*.dimensions=*/dimensions, /*.isSufficient=*/isSufficient}); 198 } 199 200 // validate execution timing 201 if (data[index].getDiscriminator() != discriminator::executionTiming) { 202 LOG(ERROR) << "FMQ Result packet ill-formed"; 203 return std::nullopt; 204 } 205 206 // unpackage execution timing 207 const Timing timing = data[index].executionTiming(); 208 index++; 209 210 // validate packet information 211 if (index != packetSize) { 212 LOG(ERROR) << "FMQ Result packet ill-formed"; 213 return std::nullopt; 214 } 215 216 // return result 217 return std::make_tuple(errorStatus, std::move(outputShapes), timing); 218 } 219 220 std::pair<std::unique_ptr<ResultChannelReceiver>, const FmqResultDescriptor*> 221 ResultChannelReceiver::create(size_t channelLength, bool blocking) { 222 std::unique_ptr<FmqResultChannel> fmqResultChannel = 223 std::make_unique<FmqResultChannel>(channelLength, /*confEventFlag=*/blocking); 224 if (!fmqResultChannel->isValid()) { 225 LOG(ERROR) << "Unable to create ResultChannelReceiver"; 226 return {nullptr, nullptr}; 227 } 228 const FmqResultDescriptor* descriptor = fmqResultChannel->getDesc(); 229 return std::make_pair( 230 std::make_unique<ResultChannelReceiver>(std::move(fmqResultChannel), blocking), 231 descriptor); 232 } 233 234 ResultChannelReceiver::ResultChannelReceiver(std::unique_ptr<FmqResultChannel> fmqResultChannel, 235 bool blocking) 236 : mFmqResultChannel(std::move(fmqResultChannel)), mBlocking(blocking) {} 237 238 std::optional<std::tuple<ErrorStatus, std::vector<OutputShape>, Timing>> 239 ResultChannelReceiver::getBlocking() { 240 const auto packet = getPacketBlocking(); 241 if (!packet) { 242 return std::nullopt; 243 } 244 245 return deserialize(*packet); 246 } 247 248 void ResultChannelReceiver::invalidate() { 249 mValid = false; 250 251 // force unblock 252 // ExecutionBurstController waits on a result packet after sending a 253 // request. If the driver containing ExecutionBurstServer crashes, the 254 // controller will still be waiting on the futex (assuming mBlocking is 255 // true). This force unblock wakes up any thread waiting on the futex. 256 if (mBlocking) { 257 // TODO: look for a different/better way to signal/notify the futex to 258 // wake up any thread waiting on it 259 FmqResultDatum datum; 260 datum.packetInformation({/*.packetSize=*/0, /*.errorStatus=*/ErrorStatus::GENERAL_FAILURE, 261 /*.numberOfOperands=*/0}); 262 mFmqResultChannel->writeBlocking(&datum, 1); 263 } 264 } 265 266 std::optional<std::vector<FmqResultDatum>> ResultChannelReceiver::getPacketBlocking() { 267 using discriminator = FmqResultDatum::hidl_discriminator; 268 269 if (!mValid) { 270 return std::nullopt; 271 } 272 273 // wait for result packet and read first element of result packet 274 FmqResultDatum datum; 275 bool success = true; 276 if (mBlocking) { 277 success = mFmqResultChannel->readBlocking(&datum, 1); 278 } else { 279 while ((success = mValid.load(std::memory_order_relaxed)) && 280 !mFmqResultChannel->read(&datum, 1)) { 281 } 282 } 283 284 // retrieve remaining elements 285 // NOTE: all of the data is already available at this point, so there's no 286 // need to do a blocking wait to wait for more data. This is known because 287 // in FMQ, all writes are published (made available) atomically. Currently, 288 // the producer always publishes the entire packet in one function call, so 289 // if the first element of the packet is available, the remaining elements 290 // are also available. 291 const size_t count = mFmqResultChannel->availableToRead(); 292 std::vector<FmqResultDatum> packet(count + 1); 293 std::memcpy(&packet.front(), &datum, sizeof(datum)); 294 success &= mFmqResultChannel->read(packet.data() + 1, count); 295 296 if (!mValid) { 297 return std::nullopt; 298 } 299 300 // ensure packet was successfully received 301 if (!success) { 302 LOG(ERROR) << "Error receiving packet"; 303 return std::nullopt; 304 } 305 306 return std::make_optional(std::move(packet)); 307 } 308 309 std::pair<std::unique_ptr<RequestChannelSender>, const FmqRequestDescriptor*> 310 RequestChannelSender::create(size_t channelLength, bool blocking) { 311 std::unique_ptr<FmqRequestChannel> fmqRequestChannel = 312 std::make_unique<FmqRequestChannel>(channelLength, /*confEventFlag=*/blocking); 313 if (!fmqRequestChannel->isValid()) { 314 LOG(ERROR) << "Unable to create RequestChannelSender"; 315 return {nullptr, nullptr}; 316 } 317 const FmqRequestDescriptor* descriptor = fmqRequestChannel->getDesc(); 318 return std::make_pair( 319 std::make_unique<RequestChannelSender>(std::move(fmqRequestChannel), blocking), 320 descriptor); 321 } 322 323 RequestChannelSender::RequestChannelSender(std::unique_ptr<FmqRequestChannel> fmqRequestChannel, 324 bool blocking) 325 : mFmqRequestChannel(std::move(fmqRequestChannel)), mBlocking(blocking) {} 326 327 bool RequestChannelSender::send(const Request& request, MeasureTiming measure, 328 const std::vector<int32_t>& slots) { 329 const std::vector<FmqRequestDatum> serialized = serialize(request, measure, slots); 330 return sendPacket(serialized); 331 } 332 333 bool RequestChannelSender::sendPacket(const std::vector<FmqRequestDatum>& packet) { 334 if (!mValid) { 335 return false; 336 } 337 338 if (packet.size() > mFmqRequestChannel->availableToWrite()) { 339 LOG(ERROR) 340 << "RequestChannelSender::sendPacket -- packet size exceeds size available in FMQ"; 341 return false; 342 } 343 344 if (mBlocking) { 345 return mFmqRequestChannel->writeBlocking(packet.data(), packet.size()); 346 } else { 347 return mFmqRequestChannel->write(packet.data(), packet.size()); 348 } 349 } 350 351 void RequestChannelSender::invalidate() { 352 mValid = false; 353 } 354 355 Return<void> ExecutionBurstController::ExecutionBurstCallback::getMemories( 356 const hidl_vec<int32_t>& slots, getMemories_cb cb) { 357 std::lock_guard<std::mutex> guard(mMutex); 358 359 // get all memories 360 hidl_vec<hidl_memory> memories(slots.size()); 361 std::transform(slots.begin(), slots.end(), memories.begin(), [this](int32_t slot) { 362 return slot < mMemoryCache.size() ? mMemoryCache[slot] : hidl_memory{}; 363 }); 364 365 // ensure all memories are valid 366 if (!std::all_of(memories.begin(), memories.end(), 367 [](const hidl_memory& memory) { return memory.valid(); })) { 368 cb(ErrorStatus::INVALID_ARGUMENT, {}); 369 return Void(); 370 } 371 372 // return successful 373 cb(ErrorStatus::NONE, std::move(memories)); 374 return Void(); 375 } 376 377 std::vector<int32_t> ExecutionBurstController::ExecutionBurstCallback::getSlots( 378 const hidl_vec<hidl_memory>& memories, const std::vector<intptr_t>& keys) { 379 std::lock_guard<std::mutex> guard(mMutex); 380 381 // retrieve (or bind) all slots corresponding to memories 382 std::vector<int32_t> slots; 383 slots.reserve(memories.size()); 384 for (size_t i = 0; i < memories.size(); ++i) { 385 slots.push_back(getSlotLocked(memories[i], keys[i])); 386 } 387 return slots; 388 } 389 390 std::pair<bool, int32_t> ExecutionBurstController::ExecutionBurstCallback::freeMemory( 391 intptr_t key) { 392 std::lock_guard<std::mutex> guard(mMutex); 393 394 auto iter = mMemoryIdToSlot.find(key); 395 if (iter == mMemoryIdToSlot.end()) { 396 return {false, 0}; 397 } 398 const int32_t slot = iter->second; 399 mMemoryIdToSlot.erase(key); 400 mMemoryCache[slot] = {}; 401 mFreeSlots.push(slot); 402 return {true, slot}; 403 } 404 405 int32_t ExecutionBurstController::ExecutionBurstCallback::getSlotLocked(const hidl_memory& memory, 406 intptr_t key) { 407 auto iter = mMemoryIdToSlot.find(key); 408 if (iter == mMemoryIdToSlot.end()) { 409 const int32_t slot = allocateSlotLocked(); 410 mMemoryIdToSlot[key] = slot; 411 mMemoryCache[slot] = memory; 412 return slot; 413 } else { 414 const int32_t slot = iter->second; 415 return slot; 416 } 417 } 418 419 int32_t ExecutionBurstController::ExecutionBurstCallback::allocateSlotLocked() { 420 constexpr size_t kMaxNumberOfSlots = std::numeric_limits<int32_t>::max(); 421 422 // if there is a free slot, use it 423 if (mFreeSlots.size() > 0) { 424 const int32_t slot = mFreeSlots.top(); 425 mFreeSlots.pop(); 426 return slot; 427 } 428 429 // otherwise use a slot for the first time 430 CHECK(mMemoryCache.size() < kMaxNumberOfSlots) << "Exceeded maximum number of slots!"; 431 const int32_t slot = static_cast<int32_t>(mMemoryCache.size()); 432 mMemoryCache.emplace_back(); 433 434 return slot; 435 } 436 437 std::unique_ptr<ExecutionBurstController> ExecutionBurstController::create( 438 const sp<IPreparedModel>& preparedModel, bool blocking) { 439 // check inputs 440 if (preparedModel == nullptr) { 441 LOG(ERROR) << "ExecutionBurstController::create passed a nullptr"; 442 return nullptr; 443 } 444 445 // create callback object 446 sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback(); 447 448 // create FMQ objects 449 auto [requestChannelSenderTemp, requestChannelDescriptor] = 450 RequestChannelSender::create(kExecutionBurstChannelLength, blocking); 451 auto [resultChannelReceiverTemp, resultChannelDescriptor] = 452 ResultChannelReceiver::create(kExecutionBurstChannelLength, blocking); 453 std::shared_ptr<RequestChannelSender> requestChannelSender = 454 std::move(requestChannelSenderTemp); 455 std::shared_ptr<ResultChannelReceiver> resultChannelReceiver = 456 std::move(resultChannelReceiverTemp); 457 458 // check FMQ objects 459 if (!requestChannelSender || !resultChannelReceiver || !requestChannelDescriptor || 460 !resultChannelDescriptor) { 461 LOG(ERROR) << "ExecutionBurstController::create failed to create FastMessageQueue"; 462 return nullptr; 463 } 464 465 // configure burst 466 ErrorStatus errorStatus; 467 sp<IBurstContext> burstContext; 468 const Return<void> ret = preparedModel->configureExecutionBurst( 469 callback, *requestChannelDescriptor, *resultChannelDescriptor, 470 [&errorStatus, &burstContext](ErrorStatus status, const sp<IBurstContext>& context) { 471 errorStatus = status; 472 burstContext = context; 473 }); 474 475 // check burst 476 if (!ret.isOk()) { 477 LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with description " 478 << ret.description(); 479 return nullptr; 480 } 481 if (errorStatus != ErrorStatus::NONE) { 482 LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with status " 483 << toString(errorStatus); 484 return nullptr; 485 } 486 if (burstContext == nullptr) { 487 LOG(ERROR) << "IPreparedModel::configureExecutionBurst returned nullptr for burst"; 488 return nullptr; 489 } 490 491 // create death handler object 492 BurstContextDeathHandler::Callback onDeathCallback = [requestChannelSender, 493 resultChannelReceiver] { 494 requestChannelSender->invalidate(); 495 resultChannelReceiver->invalidate(); 496 }; 497 const sp<BurstContextDeathHandler> deathHandler = new BurstContextDeathHandler(onDeathCallback); 498 499 // linkToDeath registers a callback that will be invoked on service death to 500 // proactively handle service crashes. If the linkToDeath call fails, 501 // asynchronous calls are susceptible to hangs if the service crashes before 502 // providing the response. 503 const Return<bool> deathHandlerRet = burstContext->linkToDeath(deathHandler, 0); 504 if (!deathHandlerRet.isOk() || deathHandlerRet != true) { 505 LOG(ERROR) << "ExecutionBurstController::create -- Failed to register a death recipient " 506 "for the IBurstContext object."; 507 return nullptr; 508 } 509 510 // make and return controller 511 return std::make_unique<ExecutionBurstController>(requestChannelSender, resultChannelReceiver, 512 burstContext, callback, deathHandler); 513 } 514 515 ExecutionBurstController::ExecutionBurstController( 516 const std::shared_ptr<RequestChannelSender>& requestChannelSender, 517 const std::shared_ptr<ResultChannelReceiver>& resultChannelReceiver, 518 const sp<IBurstContext>& burstContext, const sp<ExecutionBurstCallback>& callback, 519 const sp<hardware::hidl_death_recipient>& deathHandler) 520 : mRequestChannelSender(requestChannelSender), 521 mResultChannelReceiver(resultChannelReceiver), 522 mBurstContext(burstContext), 523 mMemoryCache(callback), 524 mDeathHandler(deathHandler) {} 525 526 ExecutionBurstController::~ExecutionBurstController() { 527 // It is safe to ignore any errors resulting from this unlinkToDeath call 528 // because the ExecutionBurstController object is already being destroyed 529 // and its underlying IBurstContext object is no longer being used by the NN 530 // runtime. 531 if (mDeathHandler) { 532 mBurstContext->unlinkToDeath(mDeathHandler).isOk(); 533 } 534 } 535 536 std::tuple<ErrorStatus, std::vector<OutputShape>, Timing> ExecutionBurstController::compute( 537 const Request& request, MeasureTiming measure, const std::vector<intptr_t>& memoryIds) { 538 auto [status, outputShapes, timing, fallback] = tryCompute(request, measure, memoryIds); 539 (void)fallback; // ignore fallback field 540 return {status, std::move(outputShapes), timing}; 541 } 542 543 std::tuple<ErrorStatus, std::vector<OutputShape>, Timing, bool> 544 ExecutionBurstController::tryCompute(const Request& request, MeasureTiming measure, 545 const std::vector<intptr_t>& memoryIds) { 546 NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstController::compute"); 547 548 std::lock_guard<std::mutex> guard(mMutex); 549 550 // send request packet 551 const std::vector<int32_t> slots = mMemoryCache->getSlots(request.pools, memoryIds); 552 const bool success = mRequestChannelSender->send(request, measure, slots); 553 if (!success) { 554 LOG(ERROR) << "Error sending FMQ packet"; 555 // only use fallback execution path if the packet could not be sent 556 return {ErrorStatus::GENERAL_FAILURE, {}, kNoTiming, /*fallback=*/true}; 557 } 558 559 // get result packet 560 const auto result = mResultChannelReceiver->getBlocking(); 561 if (!result) { 562 LOG(ERROR) << "Error retrieving FMQ packet"; 563 // only use fallback execution path if the packet could not be sent 564 return {ErrorStatus::GENERAL_FAILURE, {}, kNoTiming, /*fallback=*/false}; 565 } 566 567 // unpack results and return (only use fallback execution path if the 568 // packet could not be sent) 569 auto [status, outputShapes, timing] = std::move(*result); 570 return {status, std::move(outputShapes), timing, /*fallback=*/false}; 571 } 572 573 void ExecutionBurstController::freeMemory(intptr_t key) { 574 std::lock_guard<std::mutex> guard(mMutex); 575 576 bool valid; 577 int32_t slot; 578 std::tie(valid, slot) = mMemoryCache->freeMemory(key); 579 if (valid) { 580 mBurstContext->freeMemory(slot).isOk(); 581 } 582 } 583 584 } // namespace android::nn 585