Home | History | Annotate | Download | only in sample
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #define LOG_TAG "SampleDriver"
     18 
     19 #include "SampleDriver.h"
     20 
     21 #include "CpuExecutor.h"
     22 #include "ExecutionBurstServer.h"
     23 #include "HalInterfaces.h"
     24 #include "Tracing.h"
     25 #include "ValidateHal.h"
     26 
     27 #include <android-base/logging.h>
     28 #include <hidl/LegacySupport.h>
     29 #include <chrono>
     30 #include <optional>
     31 #include <thread>
     32 
     33 namespace android {
     34 namespace nn {
     35 namespace sample_driver {
     36 
     37 namespace {
     38 
     39 using time_point = std::chrono::steady_clock::time_point;
     40 
     41 auto now() {
     42     return std::chrono::steady_clock::now();
     43 };
     44 
     45 auto microsecondsDuration(decltype(now()) end, decltype(now()) start) {
     46     return std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();
     47 };
     48 
     49 }  // namespace
     50 
     51 static const Timing kNoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
     52 
     53 Return<void> SampleDriver::getCapabilities(getCapabilities_cb cb) {
     54     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
     55                  "SampleDriver::getCapabilities");
     56     return getCapabilities_1_2([&](ErrorStatus error, const V1_2::Capabilities& capabilities) {
     57         // TODO(dgross): Do we need to check compliantWithV1_0(capabilities)?
     58         cb(error, convertToV1_0(capabilities));
     59     });
     60 }
     61 
     62 Return<void> SampleDriver::getCapabilities_1_1(getCapabilities_1_1_cb cb) {
     63     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
     64                  "SampleDriver::getCapabilities_1_1");
     65     return getCapabilities_1_2([&](ErrorStatus error, const V1_2::Capabilities& capabilities) {
     66         // TODO(dgross): Do we need to check compliantWithV1_1(capabilities)?
     67         cb(error, convertToV1_1(capabilities));
     68     });
     69 }
     70 
     71 Return<void> SampleDriver::getVersionString(getVersionString_cb cb) {
     72     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
     73                  "SampleDriver::getVersionString");
     74     cb(ErrorStatus::NONE, "JUST_AN_EXAMPLE");
     75     return Void();
     76 }
     77 
     78 Return<void> SampleDriver::getType(getType_cb cb) {
     79     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION, "SampleDriver::getType");
     80     cb(ErrorStatus::NONE, V1_2::DeviceType::CPU);
     81     return Void();
     82 }
     83 
     84 Return<void> SampleDriver::getSupportedExtensions(getSupportedExtensions_cb cb) {
     85     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
     86                  "SampleDriver::getSupportedExtensions");
     87     cb(ErrorStatus::NONE, {/* No extensions. */});
     88     return Void();
     89 }
     90 
     91 Return<void> SampleDriver::getSupportedOperations(const V1_0::Model& model,
     92                                                   getSupportedOperations_cb cb) {
     93     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
     94                  "SampleDriver::getSupportedOperations");
     95     if (!validateModel(model)) {
     96         VLOG(DRIVER) << "getSupportedOperations";
     97         std::vector<bool> supported;
     98         cb(ErrorStatus::INVALID_ARGUMENT, supported);
     99         return Void();
    100     }
    101     return getSupportedOperations_1_2(convertToV1_2(model), cb);
    102 }
    103 
    104 Return<void> SampleDriver::getSupportedOperations_1_1(const V1_1::Model& model,
    105                                                       getSupportedOperations_1_1_cb cb) {
    106     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
    107                  "SampleDriver::getSupportedOperations_1_1");
    108     if (!validateModel(model)) {
    109         VLOG(DRIVER) << "getSupportedOperations_1_1";
    110         std::vector<bool> supported;
    111         cb(ErrorStatus::INVALID_ARGUMENT, supported);
    112         return Void();
    113     }
    114     return getSupportedOperations_1_2(convertToV1_2(model), cb);
    115 }
    116 
    117 Return<void> SampleDriver::getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb) {
    118     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
    119                  "SampleDriver::getNumberOfCacheFilesNeeded");
    120     // Set both numbers to be 0 for cache not supported.
    121     cb(ErrorStatus::NONE, /*numModelCache=*/0, /*numDataCache=*/0);
    122     return Void();
    123 }
    124 
    125 static void notify(const sp<V1_0::IPreparedModelCallback>& callback, const ErrorStatus& status,
    126                    const sp<SamplePreparedModel>& preparedModel) {
    127     callback->notify(status, preparedModel);
    128 }
    129 
    130 static void notify(const sp<V1_2::IPreparedModelCallback>& callback, const ErrorStatus& status,
    131                    const sp<SamplePreparedModel>& preparedModel) {
    132     callback->notify_1_2(status, preparedModel);
    133 }
    134 
    135 template <typename T_Model, typename T_IPreparedModelCallback>
    136 Return<ErrorStatus> prepareModelBase(const T_Model& model, const SampleDriver* driver,
    137                                      ExecutionPreference preference,
    138                                      const sp<T_IPreparedModelCallback>& callback) {
    139     if (callback.get() == nullptr) {
    140         LOG(ERROR) << "invalid callback passed to prepareModelBase";
    141         return ErrorStatus::INVALID_ARGUMENT;
    142     }
    143     if (VLOG_IS_ON(DRIVER)) {
    144         VLOG(DRIVER) << "prepareModelBase";
    145         logModelToInfo(model);
    146     }
    147     if (!validateModel(model) || !validateExecutionPreference(preference)) {
    148         notify(callback, ErrorStatus::INVALID_ARGUMENT, nullptr);
    149         return ErrorStatus::INVALID_ARGUMENT;
    150     }
    151 
    152     // TODO: make asynchronous later
    153     sp<SamplePreparedModel> preparedModel = new SamplePreparedModel(convertToV1_2(model), driver);
    154     if (!preparedModel->initialize()) {
    155         notify(callback, ErrorStatus::INVALID_ARGUMENT, nullptr);
    156         return ErrorStatus::INVALID_ARGUMENT;
    157     }
    158     notify(callback, ErrorStatus::NONE, preparedModel);
    159     return ErrorStatus::NONE;
    160 }
    161 
    162 Return<ErrorStatus> SampleDriver::prepareModel(const V1_0::Model& model,
    163                                                const sp<V1_0::IPreparedModelCallback>& callback) {
    164     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel");
    165     return prepareModelBase(model, this, ExecutionPreference::FAST_SINGLE_ANSWER, callback);
    166 }
    167 
    168 Return<ErrorStatus> SampleDriver::prepareModel_1_1(
    169         const V1_1::Model& model, ExecutionPreference preference,
    170         const sp<V1_0::IPreparedModelCallback>& callback) {
    171     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel_1_1");
    172     return prepareModelBase(model, this, preference, callback);
    173 }
    174 
    175 Return<ErrorStatus> SampleDriver::prepareModel_1_2(
    176         const V1_2::Model& model, ExecutionPreference preference, const hidl_vec<hidl_handle>&,
    177         const hidl_vec<hidl_handle>&, const HidlToken&,
    178         const sp<V1_2::IPreparedModelCallback>& callback) {
    179     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel_1_2");
    180     return prepareModelBase(model, this, preference, callback);
    181 }
    182 
    183 Return<ErrorStatus> SampleDriver::prepareModelFromCache(
    184         const hidl_vec<hidl_handle>&, const hidl_vec<hidl_handle>&, const HidlToken&,
    185         const sp<V1_2::IPreparedModelCallback>& callback) {
    186     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
    187                  "SampleDriver::prepareModelFromCache");
    188     callback->notify_1_2(ErrorStatus::GENERAL_FAILURE, nullptr);
    189     return ErrorStatus::GENERAL_FAILURE;
    190 }
    191 
    192 Return<DeviceStatus> SampleDriver::getStatus() {
    193     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_UNSPECIFIED,
    194                  "SampleDriver::getStatus");
    195     VLOG(DRIVER) << "getStatus()";
    196     return DeviceStatus::AVAILABLE;
    197 }
    198 
    199 int SampleDriver::run() {
    200     android::hardware::configureRpcThreadpool(4, true);
    201     if (registerAsService(mName) != android::OK) {
    202         LOG(ERROR) << "Could not register service";
    203         return 1;
    204     }
    205     android::hardware::joinRpcThreadpool();
    206     LOG(ERROR) << "Service exited!";
    207     return 1;
    208 }
    209 
    210 bool SamplePreparedModel::initialize() {
    211     return setRunTimePoolInfosFromHidlMemories(&mPoolInfos, mModel.pools);
    212 }
    213 
    214 static Return<void> notify(const sp<V1_0::IExecutionCallback>& callback, const ErrorStatus& status,
    215                            const hidl_vec<OutputShape>&, Timing) {
    216     return callback->notify(status);
    217 }
    218 
    219 static Return<void> notify(const sp<V1_2::IExecutionCallback>& callback, const ErrorStatus& status,
    220                            const hidl_vec<OutputShape>& outputShapes, Timing timing) {
    221     return callback->notify_1_2(status, outputShapes, timing);
    222 }
    223 
    224 // TODO(xusongw): Let callback notify actual output shape once dynamic output shape
    225 //                is supported in CpuExecutor.
    226 template <typename T_IExecutionCallback>
    227 void asyncExecute(const Request& request, MeasureTiming measure, time_point driverStart,
    228                   const Model& model, const SampleDriver& driver,
    229                   const std::vector<RunTimePoolInfo>& poolInfos,
    230                   const sp<T_IExecutionCallback>& callback) {
    231     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
    232                  "SampleDriver::asyncExecute");
    233     std::vector<RunTimePoolInfo> requestPoolInfos;
    234     if (!setRunTimePoolInfosFromHidlMemories(&requestPoolInfos, request.pools)) {
    235         notify(callback, ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
    236         return;
    237     }
    238 
    239     NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
    240                         "SampleDriver::asyncExecute");
    241     CpuExecutor executor = driver.getExecutor();
    242     time_point driverEnd, deviceStart, deviceEnd;
    243     if (measure == MeasureTiming::YES) deviceStart = now();
    244     int n = executor.run(model, request, poolInfos, requestPoolInfos);
    245     if (measure == MeasureTiming::YES) deviceEnd = now();
    246     VLOG(DRIVER) << "executor.run returned " << n;
    247     ErrorStatus executionStatus = convertResultCodeToErrorStatus(n);
    248     hidl_vec<OutputShape> outputShapes = executor.getOutputShapes();
    249     Return<void> returned;
    250     if (measure == MeasureTiming::YES && executionStatus == ErrorStatus::NONE) {
    251         driverEnd = now();
    252         Timing timing = {.timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
    253                          .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
    254         VLOG(DRIVER) << "SampleDriver::asyncExecute timing = " << toString(timing);
    255         returned = notify(callback, executionStatus, outputShapes, timing);
    256     } else {
    257         returned = notify(callback, executionStatus, outputShapes, kNoTiming);
    258     }
    259     if (!returned.isOk()) {
    260         LOG(ERROR) << " hidl callback failed to return properly: " << returned.description();
    261     }
    262 }
    263 
    264 template <typename T_IExecutionCallback>
    265 Return<ErrorStatus> executeBase(const Request& request, MeasureTiming measure, const Model& model,
    266                                 const SampleDriver& driver,
    267                                 const std::vector<RunTimePoolInfo>& poolInfos,
    268                                 const sp<T_IExecutionCallback>& callback) {
    269     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION, "SampleDriver::executeBase");
    270     VLOG(DRIVER) << "executeBase(" << SHOW_IF_DEBUG(toString(request)) << ")";
    271 
    272     time_point driverStart;
    273     if (measure == MeasureTiming::YES) driverStart = now();
    274 
    275     if (callback.get() == nullptr) {
    276         LOG(ERROR) << "invalid callback passed to executeBase";
    277         return ErrorStatus::INVALID_ARGUMENT;
    278     }
    279     if (!validateRequest(request, model)) {
    280         notify(callback, ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming);
    281         return ErrorStatus::INVALID_ARGUMENT;
    282     }
    283 
    284     // This thread is intentionally detached because the sample driver service
    285     // is expected to live forever.
    286     std::thread([&model, &driver, &poolInfos, request, measure, driverStart, callback] {
    287         asyncExecute(request, measure, driverStart, model, driver, poolInfos, callback);
    288     })
    289             .detach();
    290 
    291     return ErrorStatus::NONE;
    292 }
    293 
    294 Return<ErrorStatus> SamplePreparedModel::execute(const Request& request,
    295                                                  const sp<V1_0::IExecutionCallback>& callback) {
    296     return executeBase(request, MeasureTiming::NO, mModel, *mDriver, mPoolInfos, callback);
    297 }
    298 
    299 Return<ErrorStatus> SamplePreparedModel::execute_1_2(const Request& request, MeasureTiming measure,
    300                                                      const sp<V1_2::IExecutionCallback>& callback) {
    301     return executeBase(request, measure, mModel, *mDriver, mPoolInfos, callback);
    302 }
    303 
    304 Return<void> SamplePreparedModel::executeSynchronously(const Request& request,
    305                                                        MeasureTiming measure,
    306                                                        executeSynchronously_cb cb) {
    307     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
    308                  "SampleDriver::executeSynchronously");
    309     VLOG(DRIVER) << "executeSynchronously(" << SHOW_IF_DEBUG(toString(request)) << ")";
    310 
    311     time_point driverStart, driverEnd, deviceStart, deviceEnd;
    312     if (measure == MeasureTiming::YES) driverStart = now();
    313 
    314     if (!validateRequest(request, mModel)) {
    315         cb(ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming);
    316         return Void();
    317     }
    318 
    319     NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
    320                         "SampleDriver::executeSynchronously");
    321     std::vector<RunTimePoolInfo> requestPoolInfos;
    322     if (!setRunTimePoolInfosFromHidlMemories(&requestPoolInfos, request.pools)) {
    323         cb(ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
    324         return Void();
    325     }
    326 
    327     NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
    328                         "SampleDriver::executeSynchronously");
    329     CpuExecutor executor = mDriver->getExecutor();
    330     if (measure == MeasureTiming::YES) deviceStart = now();
    331     int n = executor.run(mModel, request, mPoolInfos, requestPoolInfos);
    332     if (measure == MeasureTiming::YES) deviceEnd = now();
    333     VLOG(DRIVER) << "executor.run returned " << n;
    334     ErrorStatus executionStatus = convertResultCodeToErrorStatus(n);
    335     hidl_vec<OutputShape> outputShapes = executor.getOutputShapes();
    336     if (measure == MeasureTiming::YES && executionStatus == ErrorStatus::NONE) {
    337         driverEnd = now();
    338         Timing timing = {.timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
    339                          .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
    340         VLOG(DRIVER) << "executeSynchronously timing = " << toString(timing);
    341         cb(executionStatus, outputShapes, timing);
    342     } else {
    343         cb(executionStatus, outputShapes, kNoTiming);
    344     }
    345     return Void();
    346 }
    347 
    348 // BurstExecutorWithCache maps hidl_memory when it is first seen, and preserves
    349 // the mapping until either (1) the memory is freed in the runtime, or (2) the
    350 // burst object is destroyed. This allows for subsequent executions operating on
    351 // pools that have been used before to reuse the mapping instead of mapping and
    352 // unmapping the memory on each execution.
    353 class BurstExecutorWithCache : public ExecutionBurstServer::IBurstExecutorWithCache {
    354    public:
    355     BurstExecutorWithCache(const Model& model, const SampleDriver* driver,
    356                            const std::vector<RunTimePoolInfo>& poolInfos)
    357         : mModel(model), mDriver(driver), mModelPoolInfos(poolInfos) {}
    358 
    359     bool isCacheEntryPresent(int32_t slot) const override {
    360         const auto it = mMemoryCache.find(slot);
    361         return (it != mMemoryCache.end()) && it->second.has_value();
    362     }
    363 
    364     void addCacheEntry(const hidl_memory& memory, int32_t slot) override {
    365         mMemoryCache[slot] = RunTimePoolInfo::createFromHidlMemory(memory);
    366     }
    367 
    368     void removeCacheEntry(int32_t slot) override { mMemoryCache.erase(slot); }
    369 
    370     std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing> execute(
    371             const Request& request, const std::vector<int32_t>& slots,
    372             MeasureTiming measure) override {
    373         NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
    374                      "BurstExecutorWithCache::execute");
    375 
    376         time_point driverStart, driverEnd, deviceStart, deviceEnd;
    377         if (measure == MeasureTiming::YES) driverStart = now();
    378 
    379         // ensure all relevant pools are valid
    380         if (!std::all_of(slots.begin(), slots.end(),
    381                          [this](int32_t slot) { return isCacheEntryPresent(slot); })) {
    382             return {ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming};
    383         }
    384 
    385         // finish the request object (for validation)
    386         hidl_vec<hidl_memory> pools(slots.size());
    387         std::transform(slots.begin(), slots.end(), pools.begin(),
    388                        [this](int32_t slot) { return mMemoryCache[slot]->getHidlMemory(); });
    389         Request fullRequest = request;
    390         fullRequest.pools = std::move(pools);
    391 
    392         // validate request object against the model
    393         if (!validateRequest(fullRequest, mModel)) {
    394             return {ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming};
    395         }
    396 
    397         // select relevant entries from cache
    398         std::vector<RunTimePoolInfo> requestPoolInfos;
    399         requestPoolInfos.reserve(slots.size());
    400         std::transform(slots.begin(), slots.end(), std::back_inserter(requestPoolInfos),
    401                        [this](int32_t slot) { return *mMemoryCache[slot]; });
    402 
    403         // execution
    404         CpuExecutor executor = mDriver->getExecutor();
    405         if (measure == MeasureTiming::YES) deviceStart = now();
    406         int n = executor.run(mModel, request, mModelPoolInfos, requestPoolInfos);
    407         if (measure == MeasureTiming::YES) deviceEnd = now();
    408         VLOG(DRIVER) << "executor.run returned " << n;
    409         ErrorStatus executionStatus = convertResultCodeToErrorStatus(n);
    410         hidl_vec<OutputShape> outputShapes = executor.getOutputShapes();
    411         if (measure == MeasureTiming::YES && executionStatus == ErrorStatus::NONE) {
    412             driverEnd = now();
    413             Timing timing = {
    414                     .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
    415                     .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
    416             VLOG(DRIVER) << "BurstExecutorWithCache::execute timing = " << toString(timing);
    417             return std::make_tuple(executionStatus, outputShapes, timing);
    418         } else {
    419             return std::make_tuple(executionStatus, outputShapes, kNoTiming);
    420         }
    421     }
    422 
    423    private:
    424     const Model mModel;
    425     const SampleDriver* const mDriver;
    426     const std::vector<RunTimePoolInfo> mModelPoolInfos;
    427     std::map<int32_t, std::optional<RunTimePoolInfo>> mMemoryCache;  // cached requestPoolInfos
    428 };
    429 
    430 Return<void> SamplePreparedModel::configureExecutionBurst(
    431         const sp<V1_2::IBurstCallback>& callback,
    432         const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
    433         const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
    434         configureExecutionBurst_cb cb) {
    435     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
    436                  "SampleDriver::configureExecutionBurst");
    437 
    438     // Alternatively, the burst could be configured via:
    439     // const sp<V1_2::IBurstContext> burst =
    440     //         ExecutionBurstServer::create(callback, requestChannel,
    441     //                                      resultChannel, this);
    442     //
    443     // However, this alternative representation does not include a memory map
    444     // caching optimization, and adds overhead.
    445     const std::shared_ptr<BurstExecutorWithCache> executorWithCache =
    446             std::make_shared<BurstExecutorWithCache>(mModel, mDriver, mPoolInfos);
    447     const sp<V1_2::IBurstContext> burst = ExecutionBurstServer::create(
    448             callback, requestChannel, resultChannel, executorWithCache);
    449 
    450     if (burst == nullptr) {
    451         cb(ErrorStatus::GENERAL_FAILURE, {});
    452     } else {
    453         cb(ErrorStatus::NONE, burst);
    454     }
    455 
    456     return Void();
    457 }
    458 
    459 } // namespace sample_driver
    460 } // namespace nn
    461 } // namespace android
    462