Home | History | Annotate | Download | only in common
      1 /*
      2  * Copyright (C) 2019 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #define LOG_TAG "ExecutionBurstController"
     18 
     19 #include "ExecutionBurstController.h"
     20 
#include <android-base/logging.h>

#include <algorithm>
#include <cstring>
#include <functional>
#include <limits>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "Tracing.h"
     26 
     27 namespace android::nn {
     28 namespace {
     29 
     30 using ::android::hardware::MQDescriptorSync;
     31 using FmqRequestDescriptor = MQDescriptorSync<FmqRequestDatum>;
     32 using FmqResultDescriptor = MQDescriptorSync<FmqResultDatum>;
     33 
     34 constexpr Timing kNoTiming = {std::numeric_limits<uint64_t>::max(),
     35                               std::numeric_limits<uint64_t>::max()};
     36 
     37 class BurstContextDeathHandler : public hardware::hidl_death_recipient {
     38    public:
     39     using Callback = std::function<void()>;
     40 
     41     BurstContextDeathHandler(const Callback& onDeathCallback) : mOnDeathCallback(onDeathCallback) {
     42         CHECK(onDeathCallback != nullptr);
     43     }
     44 
     45     void serviceDied(uint64_t /*cookie*/, const wp<hidl::base::V1_0::IBase>& /*who*/) override {
     46         LOG(ERROR) << "BurstContextDeathHandler::serviceDied -- service unexpectedly died!";
     47         mOnDeathCallback();
     48     }
     49 
     50    private:
     51     const Callback mOnDeathCallback;
     52 };
     53 
     54 }  // anonymous namespace
     55 
     56 // serialize a request into a packet
     57 std::vector<FmqRequestDatum> serialize(const Request& request, MeasureTiming measure,
     58                                        const std::vector<int32_t>& slots) {
     59     // count how many elements need to be sent for a request
     60     size_t count = 2 + request.inputs.size() + request.outputs.size() + request.pools.size();
     61     for (const auto& input : request.inputs) {
     62         count += input.dimensions.size();
     63     }
     64     for (const auto& output : request.outputs) {
     65         count += output.dimensions.size();
     66     }
     67 
     68     // create buffer to temporarily store elements
     69     std::vector<FmqRequestDatum> data;
     70     data.reserve(count);
     71 
     72     // package packetInfo
     73     {
     74         FmqRequestDatum datum;
     75         datum.packetInformation(
     76                 {/*.packetSize=*/static_cast<uint32_t>(count),
     77                  /*.numberOfInputOperands=*/static_cast<uint32_t>(request.inputs.size()),
     78                  /*.numberOfOutputOperands=*/static_cast<uint32_t>(request.outputs.size()),
     79                  /*.numberOfPools=*/static_cast<uint32_t>(request.pools.size())});
     80         data.push_back(datum);
     81     }
     82 
     83     // package input data
     84     for (const auto& input : request.inputs) {
     85         // package operand information
     86         FmqRequestDatum datum;
     87         datum.inputOperandInformation(
     88                 {/*.hasNoValue=*/input.hasNoValue,
     89                  /*.location=*/input.location,
     90                  /*.numberOfDimensions=*/static_cast<uint32_t>(input.dimensions.size())});
     91         data.push_back(datum);
     92 
     93         // package operand dimensions
     94         for (uint32_t dimension : input.dimensions) {
     95             FmqRequestDatum datum;
     96             datum.inputOperandDimensionValue(dimension);
     97             data.push_back(datum);
     98         }
     99     }
    100 
    101     // package output data
    102     for (const auto& output : request.outputs) {
    103         // package operand information
    104         FmqRequestDatum datum;
    105         datum.outputOperandInformation(
    106                 {/*.hasNoValue=*/output.hasNoValue,
    107                  /*.location=*/output.location,
    108                  /*.numberOfDimensions=*/static_cast<uint32_t>(output.dimensions.size())});
    109         data.push_back(datum);
    110 
    111         // package operand dimensions
    112         for (uint32_t dimension : output.dimensions) {
    113             FmqRequestDatum datum;
    114             datum.outputOperandDimensionValue(dimension);
    115             data.push_back(datum);
    116         }
    117     }
    118 
    119     // package pool identifier
    120     for (int32_t slot : slots) {
    121         FmqRequestDatum datum;
    122         datum.poolIdentifier(slot);
    123         data.push_back(datum);
    124     }
    125 
    126     // package measureTiming
    127     {
    128         FmqRequestDatum datum;
    129         datum.measureTiming(measure);
    130         data.push_back(datum);
    131     }
    132 
    133     // return packet
    134     return data;
    135 }
    136 
    137 // deserialize a packet into the result
    138 std::optional<std::tuple<ErrorStatus, std::vector<OutputShape>, Timing>> deserialize(
    139         const std::vector<FmqResultDatum>& data) {
    140     using discriminator = FmqResultDatum::hidl_discriminator;
    141 
    142     std::vector<OutputShape> outputShapes;
    143     size_t index = 0;
    144 
    145     // validate packet information
    146     if (data.size() == 0 || data[index].getDiscriminator() != discriminator::packetInformation) {
    147         LOG(ERROR) << "FMQ Result packet ill-formed";
    148         return std::nullopt;
    149     }
    150 
    151     // unpackage packet information
    152     const FmqResultDatum::PacketInformation& packetInfo = data[index].packetInformation();
    153     index++;
    154     const uint32_t packetSize = packetInfo.packetSize;
    155     const ErrorStatus errorStatus = packetInfo.errorStatus;
    156     const uint32_t numberOfOperands = packetInfo.numberOfOperands;
    157 
    158     // verify packet size
    159     if (data.size() != packetSize) {
    160         LOG(ERROR) << "FMQ Result packet ill-formed";
    161         return std::nullopt;
    162     }
    163 
    164     // unpackage operands
    165     for (size_t operand = 0; operand < numberOfOperands; ++operand) {
    166         // validate operand information
    167         if (data[index].getDiscriminator() != discriminator::operandInformation) {
    168             LOG(ERROR) << "FMQ Result packet ill-formed";
    169             return std::nullopt;
    170         }
    171 
    172         // unpackage operand information
    173         const FmqResultDatum::OperandInformation& operandInfo = data[index].operandInformation();
    174         index++;
    175         const bool isSufficient = operandInfo.isSufficient;
    176         const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
    177 
    178         // unpackage operand dimensions
    179         std::vector<uint32_t> dimensions;
    180         dimensions.reserve(numberOfDimensions);
    181         for (size_t i = 0; i < numberOfDimensions; ++i) {
    182             // validate dimension
    183             if (data[index].getDiscriminator() != discriminator::operandDimensionValue) {
    184                 LOG(ERROR) << "FMQ Result packet ill-formed";
    185                 return std::nullopt;
    186             }
    187 
    188             // unpackage dimension
    189             const uint32_t dimension = data[index].operandDimensionValue();
    190             index++;
    191 
    192             // store result
    193             dimensions.push_back(dimension);
    194         }
    195 
    196         // store result
    197         outputShapes.push_back({/*.dimensions=*/dimensions, /*.isSufficient=*/isSufficient});
    198     }
    199 
    200     // validate execution timing
    201     if (data[index].getDiscriminator() != discriminator::executionTiming) {
    202         LOG(ERROR) << "FMQ Result packet ill-formed";
    203         return std::nullopt;
    204     }
    205 
    206     // unpackage execution timing
    207     const Timing timing = data[index].executionTiming();
    208     index++;
    209 
    210     // validate packet information
    211     if (index != packetSize) {
    212         LOG(ERROR) << "FMQ Result packet ill-formed";
    213         return std::nullopt;
    214     }
    215 
    216     // return result
    217     return std::make_tuple(errorStatus, std::move(outputShapes), timing);
    218 }
    219 
    220 std::pair<std::unique_ptr<ResultChannelReceiver>, const FmqResultDescriptor*>
    221 ResultChannelReceiver::create(size_t channelLength, bool blocking) {
    222     std::unique_ptr<FmqResultChannel> fmqResultChannel =
    223             std::make_unique<FmqResultChannel>(channelLength, /*confEventFlag=*/blocking);
    224     if (!fmqResultChannel->isValid()) {
    225         LOG(ERROR) << "Unable to create ResultChannelReceiver";
    226         return {nullptr, nullptr};
    227     }
    228     const FmqResultDescriptor* descriptor = fmqResultChannel->getDesc();
    229     return std::make_pair(
    230             std::make_unique<ResultChannelReceiver>(std::move(fmqResultChannel), blocking),
    231             descriptor);
    232 }
    233 
    234 ResultChannelReceiver::ResultChannelReceiver(std::unique_ptr<FmqResultChannel> fmqResultChannel,
    235                                              bool blocking)
    236     : mFmqResultChannel(std::move(fmqResultChannel)), mBlocking(blocking) {}
    237 
    238 std::optional<std::tuple<ErrorStatus, std::vector<OutputShape>, Timing>>
    239 ResultChannelReceiver::getBlocking() {
    240     const auto packet = getPacketBlocking();
    241     if (!packet) {
    242         return std::nullopt;
    243     }
    244 
    245     return deserialize(*packet);
    246 }
    247 
    248 void ResultChannelReceiver::invalidate() {
    249     mValid = false;
    250 
    251     // force unblock
    252     // ExecutionBurstController waits on a result packet after sending a
    253     // request. If the driver containing ExecutionBurstServer crashes, the
    254     // controller will still be waiting on the futex (assuming mBlocking is
    255     // true). This force unblock wakes up any thread waiting on the futex.
    256     if (mBlocking) {
    257         // TODO: look for a different/better way to signal/notify the futex to
    258         // wake up any thread waiting on it
    259         FmqResultDatum datum;
    260         datum.packetInformation({/*.packetSize=*/0, /*.errorStatus=*/ErrorStatus::GENERAL_FAILURE,
    261                                  /*.numberOfOperands=*/0});
    262         mFmqResultChannel->writeBlocking(&datum, 1);
    263     }
    264 }
    265 
    266 std::optional<std::vector<FmqResultDatum>> ResultChannelReceiver::getPacketBlocking() {
    267     using discriminator = FmqResultDatum::hidl_discriminator;
    268 
    269     if (!mValid) {
    270         return std::nullopt;
    271     }
    272 
    273     // wait for result packet and read first element of result packet
    274     FmqResultDatum datum;
    275     bool success = true;
    276     if (mBlocking) {
    277         success = mFmqResultChannel->readBlocking(&datum, 1);
    278     } else {
    279         while ((success = mValid.load(std::memory_order_relaxed)) &&
    280                !mFmqResultChannel->read(&datum, 1)) {
    281         }
    282     }
    283 
    284     // retrieve remaining elements
    285     // NOTE: all of the data is already available at this point, so there's no
    286     // need to do a blocking wait to wait for more data. This is known because
    287     // in FMQ, all writes are published (made available) atomically. Currently,
    288     // the producer always publishes the entire packet in one function call, so
    289     // if the first element of the packet is available, the remaining elements
    290     // are also available.
    291     const size_t count = mFmqResultChannel->availableToRead();
    292     std::vector<FmqResultDatum> packet(count + 1);
    293     std::memcpy(&packet.front(), &datum, sizeof(datum));
    294     success &= mFmqResultChannel->read(packet.data() + 1, count);
    295 
    296     if (!mValid) {
    297         return std::nullopt;
    298     }
    299 
    300     // ensure packet was successfully received
    301     if (!success) {
    302         LOG(ERROR) << "Error receiving packet";
    303         return std::nullopt;
    304     }
    305 
    306     return std::make_optional(std::move(packet));
    307 }
    308 
    309 std::pair<std::unique_ptr<RequestChannelSender>, const FmqRequestDescriptor*>
    310 RequestChannelSender::create(size_t channelLength, bool blocking) {
    311     std::unique_ptr<FmqRequestChannel> fmqRequestChannel =
    312             std::make_unique<FmqRequestChannel>(channelLength, /*confEventFlag=*/blocking);
    313     if (!fmqRequestChannel->isValid()) {
    314         LOG(ERROR) << "Unable to create RequestChannelSender";
    315         return {nullptr, nullptr};
    316     }
    317     const FmqRequestDescriptor* descriptor = fmqRequestChannel->getDesc();
    318     return std::make_pair(
    319             std::make_unique<RequestChannelSender>(std::move(fmqRequestChannel), blocking),
    320             descriptor);
    321 }
    322 
    323 RequestChannelSender::RequestChannelSender(std::unique_ptr<FmqRequestChannel> fmqRequestChannel,
    324                                            bool blocking)
    325     : mFmqRequestChannel(std::move(fmqRequestChannel)), mBlocking(blocking) {}
    326 
    327 bool RequestChannelSender::send(const Request& request, MeasureTiming measure,
    328                                 const std::vector<int32_t>& slots) {
    329     const std::vector<FmqRequestDatum> serialized = serialize(request, measure, slots);
    330     return sendPacket(serialized);
    331 }
    332 
    333 bool RequestChannelSender::sendPacket(const std::vector<FmqRequestDatum>& packet) {
    334     if (!mValid) {
    335         return false;
    336     }
    337 
    338     if (packet.size() > mFmqRequestChannel->availableToWrite()) {
    339         LOG(ERROR)
    340                 << "RequestChannelSender::sendPacket -- packet size exceeds size available in FMQ";
    341         return false;
    342     }
    343 
    344     if (mBlocking) {
    345         return mFmqRequestChannel->writeBlocking(packet.data(), packet.size());
    346     } else {
    347         return mFmqRequestChannel->write(packet.data(), packet.size());
    348     }
    349 }
    350 
    351 void RequestChannelSender::invalidate() {
    352     mValid = false;
    353 }
    354 
    355 Return<void> ExecutionBurstController::ExecutionBurstCallback::getMemories(
    356         const hidl_vec<int32_t>& slots, getMemories_cb cb) {
    357     std::lock_guard<std::mutex> guard(mMutex);
    358 
    359     // get all memories
    360     hidl_vec<hidl_memory> memories(slots.size());
    361     std::transform(slots.begin(), slots.end(), memories.begin(), [this](int32_t slot) {
    362         return slot < mMemoryCache.size() ? mMemoryCache[slot] : hidl_memory{};
    363     });
    364 
    365     // ensure all memories are valid
    366     if (!std::all_of(memories.begin(), memories.end(),
    367                      [](const hidl_memory& memory) { return memory.valid(); })) {
    368         cb(ErrorStatus::INVALID_ARGUMENT, {});
    369         return Void();
    370     }
    371 
    372     // return successful
    373     cb(ErrorStatus::NONE, std::move(memories));
    374     return Void();
    375 }
    376 
    377 std::vector<int32_t> ExecutionBurstController::ExecutionBurstCallback::getSlots(
    378         const hidl_vec<hidl_memory>& memories, const std::vector<intptr_t>& keys) {
    379     std::lock_guard<std::mutex> guard(mMutex);
    380 
    381     // retrieve (or bind) all slots corresponding to memories
    382     std::vector<int32_t> slots;
    383     slots.reserve(memories.size());
    384     for (size_t i = 0; i < memories.size(); ++i) {
    385         slots.push_back(getSlotLocked(memories[i], keys[i]));
    386     }
    387     return slots;
    388 }
    389 
    390 std::pair<bool, int32_t> ExecutionBurstController::ExecutionBurstCallback::freeMemory(
    391         intptr_t key) {
    392     std::lock_guard<std::mutex> guard(mMutex);
    393 
    394     auto iter = mMemoryIdToSlot.find(key);
    395     if (iter == mMemoryIdToSlot.end()) {
    396         return {false, 0};
    397     }
    398     const int32_t slot = iter->second;
    399     mMemoryIdToSlot.erase(key);
    400     mMemoryCache[slot] = {};
    401     mFreeSlots.push(slot);
    402     return {true, slot};
    403 }
    404 
    405 int32_t ExecutionBurstController::ExecutionBurstCallback::getSlotLocked(const hidl_memory& memory,
    406                                                                         intptr_t key) {
    407     auto iter = mMemoryIdToSlot.find(key);
    408     if (iter == mMemoryIdToSlot.end()) {
    409         const int32_t slot = allocateSlotLocked();
    410         mMemoryIdToSlot[key] = slot;
    411         mMemoryCache[slot] = memory;
    412         return slot;
    413     } else {
    414         const int32_t slot = iter->second;
    415         return slot;
    416     }
    417 }
    418 
    419 int32_t ExecutionBurstController::ExecutionBurstCallback::allocateSlotLocked() {
    420     constexpr size_t kMaxNumberOfSlots = std::numeric_limits<int32_t>::max();
    421 
    422     // if there is a free slot, use it
    423     if (mFreeSlots.size() > 0) {
    424         const int32_t slot = mFreeSlots.top();
    425         mFreeSlots.pop();
    426         return slot;
    427     }
    428 
    429     // otherwise use a slot for the first time
    430     CHECK(mMemoryCache.size() < kMaxNumberOfSlots) << "Exceeded maximum number of slots!";
    431     const int32_t slot = static_cast<int32_t>(mMemoryCache.size());
    432     mMemoryCache.emplace_back();
    433 
    434     return slot;
    435 }
    436 
    437 std::unique_ptr<ExecutionBurstController> ExecutionBurstController::create(
    438         const sp<IPreparedModel>& preparedModel, bool blocking) {
    439     // check inputs
    440     if (preparedModel == nullptr) {
    441         LOG(ERROR) << "ExecutionBurstController::create passed a nullptr";
    442         return nullptr;
    443     }
    444 
    445     // create callback object
    446     sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
    447 
    448     // create FMQ objects
    449     auto [requestChannelSenderTemp, requestChannelDescriptor] =
    450             RequestChannelSender::create(kExecutionBurstChannelLength, blocking);
    451     auto [resultChannelReceiverTemp, resultChannelDescriptor] =
    452             ResultChannelReceiver::create(kExecutionBurstChannelLength, blocking);
    453     std::shared_ptr<RequestChannelSender> requestChannelSender =
    454             std::move(requestChannelSenderTemp);
    455     std::shared_ptr<ResultChannelReceiver> resultChannelReceiver =
    456             std::move(resultChannelReceiverTemp);
    457 
    458     // check FMQ objects
    459     if (!requestChannelSender || !resultChannelReceiver || !requestChannelDescriptor ||
    460         !resultChannelDescriptor) {
    461         LOG(ERROR) << "ExecutionBurstController::create failed to create FastMessageQueue";
    462         return nullptr;
    463     }
    464 
    465     // configure burst
    466     ErrorStatus errorStatus;
    467     sp<IBurstContext> burstContext;
    468     const Return<void> ret = preparedModel->configureExecutionBurst(
    469             callback, *requestChannelDescriptor, *resultChannelDescriptor,
    470             [&errorStatus, &burstContext](ErrorStatus status, const sp<IBurstContext>& context) {
    471                 errorStatus = status;
    472                 burstContext = context;
    473             });
    474 
    475     // check burst
    476     if (!ret.isOk()) {
    477         LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with description "
    478                    << ret.description();
    479         return nullptr;
    480     }
    481     if (errorStatus != ErrorStatus::NONE) {
    482         LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with status "
    483                    << toString(errorStatus);
    484         return nullptr;
    485     }
    486     if (burstContext == nullptr) {
    487         LOG(ERROR) << "IPreparedModel::configureExecutionBurst returned nullptr for burst";
    488         return nullptr;
    489     }
    490 
    491     // create death handler object
    492     BurstContextDeathHandler::Callback onDeathCallback = [requestChannelSender,
    493                                                           resultChannelReceiver] {
    494         requestChannelSender->invalidate();
    495         resultChannelReceiver->invalidate();
    496     };
    497     const sp<BurstContextDeathHandler> deathHandler = new BurstContextDeathHandler(onDeathCallback);
    498 
    499     // linkToDeath registers a callback that will be invoked on service death to
    500     // proactively handle service crashes. If the linkToDeath call fails,
    501     // asynchronous calls are susceptible to hangs if the service crashes before
    502     // providing the response.
    503     const Return<bool> deathHandlerRet = burstContext->linkToDeath(deathHandler, 0);
    504     if (!deathHandlerRet.isOk() || deathHandlerRet != true) {
    505         LOG(ERROR) << "ExecutionBurstController::create -- Failed to register a death recipient "
    506                       "for the IBurstContext object.";
    507         return nullptr;
    508     }
    509 
    510     // make and return controller
    511     return std::make_unique<ExecutionBurstController>(requestChannelSender, resultChannelReceiver,
    512                                                       burstContext, callback, deathHandler);
    513 }
    514 
    515 ExecutionBurstController::ExecutionBurstController(
    516         const std::shared_ptr<RequestChannelSender>& requestChannelSender,
    517         const std::shared_ptr<ResultChannelReceiver>& resultChannelReceiver,
    518         const sp<IBurstContext>& burstContext, const sp<ExecutionBurstCallback>& callback,
    519         const sp<hardware::hidl_death_recipient>& deathHandler)
    520     : mRequestChannelSender(requestChannelSender),
    521       mResultChannelReceiver(resultChannelReceiver),
    522       mBurstContext(burstContext),
    523       mMemoryCache(callback),
    524       mDeathHandler(deathHandler) {}
    525 
    526 ExecutionBurstController::~ExecutionBurstController() {
    527     // It is safe to ignore any errors resulting from this unlinkToDeath call
    528     // because the ExecutionBurstController object is already being destroyed
    529     // and its underlying IBurstContext object is no longer being used by the NN
    530     // runtime.
    531     if (mDeathHandler) {
    532         mBurstContext->unlinkToDeath(mDeathHandler).isOk();
    533     }
    534 }
    535 
    536 std::tuple<ErrorStatus, std::vector<OutputShape>, Timing> ExecutionBurstController::compute(
    537         const Request& request, MeasureTiming measure, const std::vector<intptr_t>& memoryIds) {
    538     auto [status, outputShapes, timing, fallback] = tryCompute(request, measure, memoryIds);
    539     (void)fallback;  // ignore fallback field
    540     return {status, std::move(outputShapes), timing};
    541 }
    542 
    543 std::tuple<ErrorStatus, std::vector<OutputShape>, Timing, bool>
    544 ExecutionBurstController::tryCompute(const Request& request, MeasureTiming measure,
    545                                      const std::vector<intptr_t>& memoryIds) {
    546     NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstController::compute");
    547 
    548     std::lock_guard<std::mutex> guard(mMutex);
    549 
    550     // send request packet
    551     const std::vector<int32_t> slots = mMemoryCache->getSlots(request.pools, memoryIds);
    552     const bool success = mRequestChannelSender->send(request, measure, slots);
    553     if (!success) {
    554         LOG(ERROR) << "Error sending FMQ packet";
    555         // only use fallback execution path if the packet could not be sent
    556         return {ErrorStatus::GENERAL_FAILURE, {}, kNoTiming, /*fallback=*/true};
    557     }
    558 
    559     // get result packet
    560     const auto result = mResultChannelReceiver->getBlocking();
    561     if (!result) {
    562         LOG(ERROR) << "Error retrieving FMQ packet";
    563         // only use fallback execution path if the packet could not be sent
    564         return {ErrorStatus::GENERAL_FAILURE, {}, kNoTiming, /*fallback=*/false};
    565     }
    566 
    567     // unpack results and return (only use fallback execution path if the
    568     // packet could not be sent)
    569     auto [status, outputShapes, timing] = std::move(*result);
    570     return {status, std::move(outputShapes), timing, /*fallback=*/false};
    571 }
    572 
    573 void ExecutionBurstController::freeMemory(intptr_t key) {
    574     std::lock_guard<std::mutex> guard(mMutex);
    575 
    576     bool valid;
    577     int32_t slot;
    578     std::tie(valid, slot) = mMemoryCache->freeMemory(key);
    579     if (valid) {
    580         mBurstContext->freeMemory(slot).isOk();
    581     }
    582 }
    583 
    584 }  // namespace android::nn
    585