Home | History | Annotate | Download | only in include
      1 /*
      2  * Copyright (C) 2019 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef ANDROID_ML_NN_RUNTIME_EXECUTION_BURST_CONTROLLER_H
     18 #define ANDROID_ML_NN_RUNTIME_EXECUTION_BURST_CONTROLLER_H
     19 
#include "HalInterfaces.h"

#include <android-base/macros.h>
#include <fmq/MessageQueue.h>
#include <hidl/MQDescriptor.h>

#include <atomic>
#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <stack>
#include <tuple>
#include <utility>
#include <vector>
     32 
     33 namespace android::nn {
     34 
     35 /**
     36  * Number of elements in the FMQ.
     37  */
     38 constexpr const size_t kExecutionBurstChannelLength = 1024;
     39 
/**
 * Function to serialize a request.
 *
 * Prefer calling RequestChannelSender::send.
 *
 * @param request Request object without the pool information.
 * @param measure Whether to collect timing information for the execution.
 * @param slots Slot identifiers corresponding to memory resources for the
 *     request.
 * @return Serialized FMQ request data.
 */
std::vector<FmqRequestDatum> serialize(const Request& request, MeasureTiming measure,
                                       const std::vector<int32_t>& slots);
     53 
/**
 * Deserialize the FMQ result data.
 *
 * The three resulting fields are, in tuple order, the status of the execution,
 * the dynamic shapes of the output tensors, and the timing information of the
 * execution.
 *
 * @param data Serialized FMQ result data.
 * @return Result object if successfully deserialized, std::nullopt otherwise.
 */
std::optional<std::tuple<ErrorStatus, std::vector<OutputShape>, Timing>> deserialize(
        const std::vector<FmqResultDatum>& data);
     65 
/**
 * ResultChannelReceiver is responsible for waiting on the channel until the
 * packet is available, extracting the packet from the channel, and
 * deserializing the packet.
 *
 * Because the receiver can wait on a packet that may never come (e.g., because
 * the sending side of the packet has been closed), this object can be
 * invalidated, unblocking the receiver.
 */
class ResultChannelReceiver {
    // Descriptor type for the synchronized result FMQ.
    using FmqResultDescriptor = ::android::hardware::MQDescriptorSync<FmqResultDatum>;
    // Synchronized read/write message queue carrying FmqResultDatum elements.
    using FmqResultChannel =
            hardware::MessageQueue<FmqResultDatum, hardware::kSynchronizedReadWrite>;

   public:
    /**
     * Create the receiving end of a result channel.
     *
     * Prefer this call over the constructor.
     *
     * @param channelLength Number of elements in the FMQ.
     * @param blocking 'true' if FMQ should use futex, 'false' if it should
     *     spin-wait.
     * @return A pair of ResultChannelReceiver and the FMQ descriptor on
     *     successful creation, both nullptr otherwise.
     */
    static std::pair<std::unique_ptr<ResultChannelReceiver>, const FmqResultDescriptor*> create(
            size_t channelLength, bool blocking);

    /**
     * Get the result from the channel.
     *
     * This method will block until either:
     * 1) The packet has been retrieved, or
     * 2) The receiver has been invalidated
     *
     * @return Result object if successfully received, std::nullopt if error or
     *     if the receiver object was invalidated. On success the tuple holds
     *     the execution status, the output shapes, and the timing information,
     *     in that order.
     */
    std::optional<std::tuple<ErrorStatus, std::vector<OutputShape>, Timing>> getBlocking();

    /**
     * Method to mark the channel as invalid, unblocking any current or future
     * calls to ResultChannelReceiver::getBlocking.
     */
    void invalidate();

    // Retrieves the raw (still serialized) result packet from the channel.
    // Prefer calling ResultChannelReceiver::getBlocking.
    // NOTE(review): presumably blocks and returns std::nullopt under the same
    // conditions as getBlocking -- confirm against the implementation.
    std::optional<std::vector<FmqResultDatum>> getPacketBlocking();

    // Takes ownership of the result FMQ. Prefer calling
    // ResultChannelReceiver::create over using this constructor directly.
    ResultChannelReceiver(std::unique_ptr<FmqResultChannel> fmqResultChannel, bool blocking);

   private:
    // Owned result message queue; the pointer itself is never reseated.
    const std::unique_ptr<FmqResultChannel> mFmqResultChannel;
    // Starts out 'true'. NOTE(review): presumably cleared by invalidate() --
    // confirm against the implementation.
    std::atomic<bool> mValid{true};
    // NOTE(review): presumably the 'blocking' constructor argument, i.e. futex
    // wait when 'true' and spin-wait when 'false' -- confirm.
    const bool mBlocking;
};
    123 
/**
 * RequestChannelSender is responsible for serializing the request packet of
 * information, sending it on the request channel, and signaling that the data
 * is available.
 */
class RequestChannelSender {
    // Descriptor type for the synchronized request FMQ.
    using FmqRequestDescriptor = ::android::hardware::MQDescriptorSync<FmqRequestDatum>;
    // Synchronized read/write message queue carrying FmqRequestDatum elements.
    using FmqRequestChannel =
            hardware::MessageQueue<FmqRequestDatum, hardware::kSynchronizedReadWrite>;

   public:
    /**
     * Create the sending end of a request channel.
     *
     * Prefer this call over the constructor.
     *
     * @param channelLength Number of elements in the FMQ.
     * @param blocking 'true' if FMQ should use futex, 'false' if it should
     *     spin-wait.
     * @return A pair of RequestChannelSender and the FMQ descriptor on
     *     successful creation, both nullptr otherwise.
     */
    static std::pair<std::unique_ptr<RequestChannelSender>, const FmqRequestDescriptor*> create(
            size_t channelLength, bool blocking);

    /**
     * Send the request to the channel.
     *
     * @param request Request object without the pool information.
     * @param measure Whether to collect timing information for the execution.
     * @param slots Slot identifiers corresponding to memory resources for
     *     the request.
     * @return 'true' on successful send, 'false' otherwise.
     */
    bool send(const Request& request, MeasureTiming measure, const std::vector<int32_t>& slots);

    /**
     * Method to mark the channel as invalid, causing all future calls to
     * RequestChannelSender::send to immediately return false without attempting
     * to send a message across the FMQ.
     */
    void invalidate();

    // Sends an already serialized request packet across the FMQ.
    // Prefer calling RequestChannelSender::send.
    // NOTE(review): presumably returns 'true' on a successful write and 'false'
    // otherwise, mirroring send -- confirm against the implementation.
    bool sendPacket(const std::vector<FmqRequestDatum>& packet);

    // Takes ownership of the request FMQ. Prefer calling
    // RequestChannelSender::create over using this constructor directly.
    RequestChannelSender(std::unique_ptr<FmqRequestChannel> fmqRequestChannel, bool blocking);

   private:
    // Owned request message queue; the pointer itself is never reseated.
    const std::unique_ptr<FmqRequestChannel> mFmqRequestChannel;
    // Starts out 'true'. NOTE(review): presumably cleared by invalidate() --
    // confirm against the implementation.
    std::atomic<bool> mValid{true};
    // NOTE(review): presumably the 'blocking' constructor argument, i.e. futex
    // wait when 'true' and spin-wait when 'false' -- confirm.
    const bool mBlocking;
};
    177 
/**
 * The ExecutionBurstController class manages both the serialization and
 * deserialization of data across FMQ, making it appear to the runtime as a
 * regular synchronous inference. Additionally, this class manages the burst's
 * memory cache.
 */
class ExecutionBurstController {
    DISALLOW_IMPLICIT_CONSTRUCTORS(ExecutionBurstController);

   public:
    /**
     * NN runtime burst callback object and memory cache.
     *
     * ExecutionBurstCallback associates a hidl_memory object with a slot number
     * to be passed across FMQ. The ExecutionBurstServer can use this callback
     * to retrieve this hidl_memory corresponding to the slot via HIDL.
     *
     * Whenever a hidl_memory object is copied, it will duplicate the underlying
     * file descriptor. Because the NN runtime currently copies the hidl_memory
     * on each execution, it is difficult to associate hidl_memory objects with
     * previously cached hidl_memory objects. For this reason, callers of this
     * class must pair each hidl_memory object with an associated key. For
     * efficiency, if two hidl_memory objects represent the same underlying
     * buffer, they must use the same key.
     */
    class ExecutionBurstCallback : public IBurstCallback {
        DISALLOW_COPY_AND_ASSIGN(ExecutionBurstCallback);

       public:
        ExecutionBurstCallback() = default;

        /**
         * IBurstCallback entry point through which the hidl_memory objects
         * bound to the requested slots are retrieved via HIDL.
         *
         * @param slots Slot identifiers whose memory resources are requested.
         * @param cb HIDL callback used to deliver the result.
         */
        Return<void> getMemories(const hidl_vec<int32_t>& slots, getMemories_cb cb) override;

        /**
         * This function performs one of two different actions:
         * 1) If a key corresponding to a memory resource is unrecognized by the
         *    ExecutionBurstCallback object, the ExecutionBurstCallback object
         *    will allocate a slot, bind the memory to the slot, and return the
         *    slot identifier.
         * 2) If a key corresponding to a memory resource is recognized by the
         *    ExecutionBurstCallback object, the ExecutionBurstCallback object
         *    will return the existing slot identifier.
         *
         * @param memories Memory resources used in an inference.
         * @param keys Unique identifiers where each element corresponds to a
         *     memory resource element in "memories".
         * @return Unique slot identifiers where each returned slot element
         *     corresponds to a memory resource element in "memories".
         */
        std::vector<int32_t> getSlots(const hidl_vec<hidl_memory>& memories,
                                      const std::vector<intptr_t>& keys);

        /**
         * This function performs two different actions:
         * 1) Removes an entry from the cache (if present), including the local
         *    storage of the hidl_memory object. Note that this call does not
         *    free any corresponding hidl_memory object in ExecutionBurstServer,
         *    which is separately freed via IBurstContext::freeMemory.
         * 2) Return whether a cache entry was removed and which slot was removed if
         *    found. If the key did not correspond to any entry in the cache, a
         *    slot number of 0 is returned. The slot number and whether the entry
         *    existed is useful so the same slot can be freed in the
         *    ExecutionBurstServer's cache via IBurstContext::freeMemory.
         *
         * @param key Key corresponding to the memory object.
         * @return A pair of:
         *     - whether a cache entry was removed
         *     - the slot that was removed, or 0 if no entry matched the key
         */
        std::pair<bool, int32_t> freeMemory(intptr_t key);

       private:
        // NOTE(review): the "Locked" suffix presumably means the caller must
        // already hold mMutex -- confirm against the implementation.
        int32_t getSlotLocked(const hidl_memory& memory, intptr_t key);
        int32_t allocateSlotLocked();

        // NOTE(review): presumably guards the three cache members below.
        std::mutex mMutex;
        // NOTE(review): presumably slots released by freeMemory and available
        // for reuse by allocateSlotLocked -- confirm.
        std::stack<int32_t, std::vector<int32_t>> mFreeSlots;
        // Maps a caller-supplied memory key to its slot identifier.
        std::map<intptr_t, int32_t> mMemoryIdToSlot;
        // Local storage of the cached hidl_memory objects.
        // NOTE(review): presumably indexed by slot identifier -- confirm.
        std::vector<hidl_memory> mMemoryCache;
    };

    /**
     * Creates a burst controller on a prepared model.
     *
     * Prefer this over ExecutionBurstController's constructor.
     *
     * @param preparedModel Model prepared for execution to execute on.
     * @param blocking 'true' if the FMQ should use a futex to perform blocking
     *     until data is available in a less responsive, but more energy
     *     efficient manner. 'false' if the FMQ should use spin-looping to
     *     wait until data is available in a more responsive, but less energy
     *     efficient manner.
     * @return ExecutionBurstController Execution burst controller object.
     */
    static std::unique_ptr<ExecutionBurstController> create(const sp<IPreparedModel>& preparedModel,
                                                            bool blocking);

    // prefer calling ExecutionBurstController::create
    ExecutionBurstController(const std::shared_ptr<RequestChannelSender>& requestChannelSender,
                             const std::shared_ptr<ResultChannelReceiver>& resultChannelReceiver,
                             const sp<IBurstContext>& burstContext,
                             const sp<ExecutionBurstCallback>& callback,
                             const sp<hardware::hidl_death_recipient>& deathHandler = nullptr);

    // explicit destructor to unregister the death recipient
    ~ExecutionBurstController();

    /**
     * Execute a request on a model.
     *
     * @param request Arguments to be executed on a model.
     * @param measure Whether to collect timing measurements, either YES or NO
     * @param memoryIds Identifiers corresponding to each memory object in the
     *     request's pools.
     * @return A tuple of:
     *     - status of the execution
     *     - dynamic output shapes from the execution
     *     - any execution time measurements of the execution
     */
    std::tuple<ErrorStatus, std::vector<OutputShape>, Timing> compute(
            const Request& request, MeasureTiming measure, const std::vector<intptr_t>& memoryIds);

    // TODO: combine "compute" and "tryCompute" back into a single function.
    // "tryCompute" was created later to return the "fallback" boolean. This
    // could not be done directly in "compute" because the VTS test cases (which
    // test burst using "compute") had already been locked down and could not be
    // changed.
    /**
     * Execute a request on a model.
     *
     * @param request Arguments to be executed on a model.
     * @param measure Whether to collect timing measurements, either YES or NO
     * @param memoryIds Identifiers corresponding to each memory object in the
     *     request's pools.
     * @return A tuple of:
     *     - status of the execution
     *     - dynamic output shapes from the execution
     *     - any execution time measurements of the execution
     *     - whether or not a failed burst execution should be re-run using a
     *       different path (e.g., IPreparedModel::executeSynchronously)
     */
    std::tuple<ErrorStatus, std::vector<OutputShape>, Timing, bool> tryCompute(
            const Request& request, MeasureTiming measure, const std::vector<intptr_t>& memoryIds);

    /**
     * Propagate a user's freeing of memory to the service.
     *
     * @param key Key corresponding to the memory object.
     */
    void freeMemory(intptr_t key);

   private:
    // NOTE(review): presumably serializes use of the channels below -- confirm
    // against the implementation.
    std::mutex mMutex;
    // Sending end of the request FMQ.
    const std::shared_ptr<RequestChannelSender> mRequestChannelSender;
    // Receiving end of the result FMQ.
    const std::shared_ptr<ResultChannelReceiver> mResultChannelReceiver;
    // Handle to the service-side burst object.
    const sp<IBurstContext> mBurstContext;
    // Burst callback object that doubles as the key-to-slot memory cache.
    const sp<ExecutionBurstCallback> mMemoryCache;
    // Death recipient that the destructor unregisters; may be nullptr (the
    // constructor's default).
    const sp<hardware::hidl_death_recipient> mDeathHandler;
};
    332 
    333 }  // namespace android::nn
    334 
    335 #endif  // ANDROID_ML_NN_RUNTIME_EXECUTION_BURST_CONTROLLER_H
    336