Home | History | Annotate | Download | only in include
      1 /*
      2  * Copyright (C) 2019 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef ANDROID_ML_NN_COMMON_EXECUTION_BURST_SERVER_H
     18 #define ANDROID_ML_NN_COMMON_EXECUTION_BURST_SERVER_H
     19 
#include "HalInterfaces.h"

#include <android-base/macros.h>
#include <fmq/MessageQueue.h>
#include <hidl/MQDescriptor.h>

#include <atomic>
#include <cstdint>
#include <memory>
#include <mutex>
#include <optional>
#include <thread>
#include <tuple>
#include <vector>
     31 
     32 namespace android::nn {
     33 
     34 using ::android::hardware::MQDescriptorSync;
     35 using FmqRequestDescriptor = MQDescriptorSync<FmqRequestDatum>;
     36 using FmqResultDescriptor = MQDescriptorSync<FmqResultDatum>;
     37 
     38 /**
     39  * Function to serialize results.
     40  *
     41  * Prefer calling ResultChannelSender::send.
     42  *
     43  * @param errorStatus Status of the execution.
     44  * @param outputShapes Dynamic shapes of the output tensors.
     45  * @param timing Timing information of the execution.
     46  * @return Serialized FMQ result data.
     47  */
     48 std::vector<FmqResultDatum> serialize(ErrorStatus errorStatus,
     49                                       const std::vector<OutputShape>& outputShapes, Timing timing);
     50 
     51 /**
     52  * Deserialize the FMQ request data.
     53  *
     54  * The three resulting fields are the Request object (where Request::pools is
     55  * empty), slot identifiers (which are stand-ins for Request::pools), and
     56  * whether timing information must be collected for the run.
     57  *
     58  * @param data Serialized FMQ request data.
     59  * @return Request object if successfully deserialized, std::nullopt otherwise.
     60  */
     61 std::optional<std::tuple<Request, std::vector<int32_t>, MeasureTiming>> deserialize(
     62         const std::vector<FmqRequestDatum>& data);
     63 
     64 /**
     65  * RequestChannelReceiver is responsible for waiting on the channel until the
     66  * packet is available, extracting the packet from the channel, and
     67  * deserializing the packet.
     68  *
     69  * Because the receiver can wait on a packet that may never come (e.g., because
     70  * the sending side of the packet has been closed), this object can be
     71  * invalidating, unblocking the receiver.
     72  */
     73 class RequestChannelReceiver {
     74     using FmqRequestChannel =
     75             hardware::MessageQueue<FmqRequestDatum, hardware::kSynchronizedReadWrite>;
     76 
     77    public:
     78     /**
     79      * Create the receiving end of a request channel.
     80      *
     81      * Prefer this call over the constructor.
     82      *
     83      * @param requestChannel Descriptor for the request channel.
     84      * @return RequestChannelReceiver on successful creation, nullptr otherwise.
     85      */
     86     static std::unique_ptr<RequestChannelReceiver> create(
     87             const FmqRequestDescriptor& requestChannel);
     88 
     89     /**
     90      * Get the request from the channel.
     91      *
     92      * This method will block until either:
     93      * 1) The packet has been retrieved, or
     94      * 2) The receiver has been invalidated
     95      *
     96      * @return Request object if successfully received, std::nullopt if error or
     97      *     if the receiver object was invalidated.
     98      */
     99     std::optional<std::tuple<Request, std::vector<int32_t>, MeasureTiming>> getBlocking();
    100 
    101     /**
    102      * Method to mark the channel as invalid, unblocking any current or future
    103      * calls to RequestChannelReceiver::getBlocking.
    104      */
    105     void invalidate();
    106 
    107     RequestChannelReceiver(std::unique_ptr<FmqRequestChannel> fmqRequestChannel, bool blocking);
    108 
    109    private:
    110     std::optional<std::vector<FmqRequestDatum>> getPacketBlocking();
    111 
    112     const std::unique_ptr<FmqRequestChannel> mFmqRequestChannel;
    113     std::atomic<bool> mTeardown{false};
    114     const bool mBlocking;
    115 };
    116 
    117 /**
    118  * ResultChannelSender is responsible for serializing the result packet of
    119  * information, sending it on the result channel, and signaling that the data is
    120  * available.
    121  */
    122 class ResultChannelSender {
    123     using FmqResultChannel =
    124             hardware::MessageQueue<FmqResultDatum, hardware::kSynchronizedReadWrite>;
    125 
    126    public:
    127     /**
    128      * Create the sending end of a result channel.
    129      *
    130      * Prefer this call over the constructor.
    131      *
    132      * @param resultChannel Descriptor for the result channel.
    133      * @return ResultChannelSender on successful creation, nullptr otherwise.
    134      */
    135     static std::unique_ptr<ResultChannelSender> create(const FmqResultDescriptor& resultChannel);
    136 
    137     /**
    138      * Send the result to the channel.
    139      *
    140      * @param errorStatus Status of the execution.
    141      * @param outputShapes Dynamic shapes of the output tensors.
    142      * @param timing Timing information of the execution.
    143      * @return 'true' on successful send, 'false' otherwise.
    144      */
    145     bool send(ErrorStatus errorStatus, const std::vector<OutputShape>& outputShapes, Timing timing);
    146 
    147     // prefer calling ResultChannelSender::send
    148     bool sendPacket(const std::vector<FmqResultDatum>& packet);
    149 
    150     ResultChannelSender(std::unique_ptr<FmqResultChannel> fmqResultChannel, bool blocking);
    151 
    152    private:
    153     const std::unique_ptr<FmqResultChannel> mFmqResultChannel;
    154     const bool mBlocking;
    155 };
    156 
    157 /**
    158  * The ExecutionBurstServer class is responsible for waiting for and
    159  * deserializing a request object from a FMQ, performing the inference, and
    160  * serializing the result back across another FMQ.
    161  */
    162 class ExecutionBurstServer : public IBurstContext {
    163     DISALLOW_IMPLICIT_CONSTRUCTORS(ExecutionBurstServer);
    164 
    165    public:
    166     /**
    167      * IBurstExecutorWithCache is a callback object passed to
    168      * ExecutionBurstServer's factory function that is used to perform an
    169      * execution. Because some memory resources are needed across multiple
    170      * executions, this object also contains a local cache that can directly be
    171      * used in the execution.
    172      *
    173      * ExecutionBurstServer will never access its IBurstExecutorWithCache object
    174      * with concurrent calls.
    175      */
    176     class IBurstExecutorWithCache {
    177         DISALLOW_COPY_AND_ASSIGN(IBurstExecutorWithCache);
    178 
    179        public:
    180         IBurstExecutorWithCache() = default;
    181         virtual ~IBurstExecutorWithCache() = default;
    182 
    183         /**
    184          * Checks if a cache entry specified by a slot is present in the cache.
    185          *
    186          * @param slot Identifier of the cache entry.
    187          * @return 'true' if the cache entry is present in the cache, 'false'
    188          *     otherwise.
    189          */
    190         virtual bool isCacheEntryPresent(int32_t slot) const = 0;
    191 
    192         /**
    193          * Adds an entry specified by a slot to the cache.
    194          *
    195          * The caller of this function must ensure that the cache entry that is
    196          * being added is not already present in the cache. This can be checked
    197          * via isCacheEntryPresent.
    198          *
    199          * @param memory Memory resource to be cached.
    200          * @param slot Slot identifier corresponding to the memory resource.
    201          */
    202         virtual void addCacheEntry(const hidl_memory& memory, int32_t slot) = 0;
    203 
    204         /**
    205          * Removes an entry specified by a slot from the cache.
    206          *
    207          * If the cache entry corresponding to the slot number does not exist,
    208          * the call does nothing.
    209          *
    210          * @param slot Slot identifier corresponding to the memory resource.
    211          */
    212         virtual void removeCacheEntry(int32_t slot) = 0;
    213 
    214         /**
    215          * Perform an execution.
    216          *
    217          * @param request Request object with inputs and outputs specified.
    218          *     Request::pools is empty, and DataLocation::poolIndex instead
    219          *     refers to the 'slots' argument as if it were Request::pools.
    220          * @param slots Slots corresponding to the cached memory entries to be
    221          *     used.
    222          * @param measure Whether timing information is requested for the
    223          *     execution.
    224          * @return Result of the execution, including the status of the
    225          *     execution, dynamic output shapes, and any timing information.
    226          */
    227         virtual std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing> execute(
    228                 const Request& request, const std::vector<int32_t>& slots,
    229                 MeasureTiming measure) = 0;
    230     };
    231 
    232     /**
    233      * Create automated context to manage FMQ-based executions.
    234      *
    235      * This function is intended to be used by a service to automatically:
    236      * 1) Receive data from a provided FMQ
    237      * 2) Execute a model with the given information
    238      * 3) Send the result to the created FMQ
    239      *
    240      * @param callback Callback used to retrieve memories corresponding to
    241      *     unrecognized slots.
    242      * @param requestChannel Input FMQ channel through which the client passes the
    243      *     request to the service.
    244      * @param resultChannel Output FMQ channel from which the client can retrieve
    245      *     the result of the execution.
    246      * @param executorWithCache Object which maintains a local cache of the
    247      *     memory pools and executes using the cached memory pools.
    248      * @result IBurstContext Handle to the burst context.
    249      */
    250     static sp<ExecutionBurstServer> create(
    251             const sp<IBurstCallback>& callback, const FmqRequestDescriptor& requestChannel,
    252             const FmqResultDescriptor& resultChannel,
    253             std::shared_ptr<IBurstExecutorWithCache> executorWithCache);
    254 
    255     /**
    256      * Create automated context to manage FMQ-based executions.
    257      *
    258      * This function is intended to be used by a service to automatically:
    259      * 1) Receive data from a provided FMQ
    260      * 2) Execute a model with the given information
    261      * 3) Send the result to the created FMQ
    262      *
    263      * @param callback Callback used to retrieve memories corresponding to
    264      *     unrecognized slots.
    265      * @param requestChannel Input FMQ channel through which the client passes the
    266      *     request to the service.
    267      * @param resultChannel Output FMQ channel from which the client can retrieve
    268      *     the result of the execution.
    269      * @param preparedModel PreparedModel that the burst object was created from.
    270      *     IPreparedModel::executeSynchronously will be used to perform the
    271      *     execution.
    272      * @result IBurstContext Handle to the burst context.
    273      */
    274     static sp<ExecutionBurstServer> create(const sp<IBurstCallback>& callback,
    275                                            const FmqRequestDescriptor& requestChannel,
    276                                            const FmqResultDescriptor& resultChannel,
    277                                            IPreparedModel* preparedModel);
    278 
    279     ExecutionBurstServer(const sp<IBurstCallback>& callback,
    280                          std::unique_ptr<RequestChannelReceiver> requestChannel,
    281                          std::unique_ptr<ResultChannelSender> resultChannel,
    282                          std::shared_ptr<IBurstExecutorWithCache> cachedExecutor);
    283     ~ExecutionBurstServer();
    284 
    285     // Used by the NN runtime to preemptively remove any stored memory.
    286     Return<void> freeMemory(int32_t slot) override;
    287 
    288    private:
    289     // Ensures all cache entries contained in mExecutorWithCache are present in
    290     // the cache. If they are not present, they are retrieved (via
    291     // IBurstCallback::getMemories) and added to mExecutorWithCache.
    292     //
    293     // This method is locked via mMutex when it is called.
    294     void ensureCacheEntriesArePresentLocked(const std::vector<int32_t>& slots);
    295 
    296     // Work loop that will continue processing execution requests until the
    297     // ExecutionBurstServer object is freed.
    298     void task();
    299 
    300     std::thread mWorker;
    301     std::mutex mMutex;
    302     std::atomic<bool> mTeardown{false};
    303     const sp<IBurstCallback> mCallback;
    304     const std::unique_ptr<RequestChannelReceiver> mRequestChannelReceiver;
    305     const std::unique_ptr<ResultChannelSender> mResultChannelSender;
    306     const std::shared_ptr<IBurstExecutorWithCache> mExecutorWithCache;
    307 };
    308 
    309 }  // namespace android::nn
    310 
    311 #endif  // ANDROID_ML_NN_COMMON_EXECUTION_BURST_SERVER_H
    312