Home | History | Annotate | Download | only in sample
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef ANDROID_ML_NN_SAMPLE_DRIVER_SAMPLE_DRIVER_H
     18 #define ANDROID_ML_NN_SAMPLE_DRIVER_SAMPLE_DRIVER_H
     19 
#include "CpuExecutor.h"
#include "HalInterfaces.h"
#include "NeuralNetworks.h"

#include <string>
#include <vector>
     25 
     26 namespace android {
     27 namespace nn {
     28 namespace sample_driver {
     29 
// Brings MQDescriptorSync into this namespace; it describes the request/result
// FMQ channels passed to SamplePreparedModel::configureExecutionBurst().
using ::android::hardware::MQDescriptorSync;
// Compilation-cache token passed to prepareModel_1_2() and prepareModelFromCache():
// a fixed-size array of ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN bytes.
using HidlToken = hidl_array<uint8_t, ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN>;
     32 
     33 // Base class used to create sample drivers for the NN HAL.  This class
     34 // provides some implementation of the more common functions.
     35 //
     36 // Since these drivers simulate hardware, they must run the computations
     37 // on the CPU.  An actual driver would not do that.
     38 class SampleDriver : public IDevice {
     39    public:
     40     SampleDriver(const char* name,
     41                  const IOperationResolver* operationResolver = BuiltinOperationResolver::get())
     42         : mName(name), mOperationResolver(operationResolver) {
     43         android::nn::initVLogMask();
     44     }
     45     ~SampleDriver() override {}
     46     Return<void> getCapabilities(getCapabilities_cb cb) override;
     47     Return<void> getCapabilities_1_1(getCapabilities_1_1_cb cb) override;
     48     Return<void> getVersionString(getVersionString_cb cb) override;
     49     Return<void> getType(getType_cb cb) override;
     50     Return<void> getSupportedExtensions(getSupportedExtensions_cb) override;
     51     Return<void> getSupportedOperations(const V1_0::Model& model,
     52                                         getSupportedOperations_cb cb) override;
     53     Return<void> getSupportedOperations_1_1(const V1_1::Model& model,
     54                                             getSupportedOperations_1_1_cb cb) override;
     55     Return<void> getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb) override;
     56     Return<ErrorStatus> prepareModel(const V1_0::Model& model,
     57                                      const sp<V1_0::IPreparedModelCallback>& callback) override;
     58     Return<ErrorStatus> prepareModel_1_1(const V1_1::Model& model, ExecutionPreference preference,
     59                                          const sp<V1_0::IPreparedModelCallback>& callback) override;
     60     Return<ErrorStatus> prepareModel_1_2(const V1_2::Model& model, ExecutionPreference preference,
     61                                          const hidl_vec<hidl_handle>& modelCache,
     62                                          const hidl_vec<hidl_handle>& dataCache,
     63                                          const HidlToken& token,
     64                                          const sp<V1_2::IPreparedModelCallback>& callback) override;
     65     Return<ErrorStatus> prepareModelFromCache(
     66             const hidl_vec<hidl_handle>& modelCache, const hidl_vec<hidl_handle>& dataCache,
     67             const HidlToken& token, const sp<V1_2::IPreparedModelCallback>& callback) override;
     68     Return<DeviceStatus> getStatus() override;
     69 
     70     // Starts and runs the driver service.  Typically called from main().
     71     // This will return only once the service shuts down.
     72     int run();
     73 
     74     CpuExecutor getExecutor() const { return CpuExecutor(mOperationResolver); }
     75 
     76    protected:
     77     std::string mName;
     78     const IOperationResolver* mOperationResolver;
     79 };
     80 
     81 class SamplePreparedModel : public IPreparedModel {
     82    public:
     83     SamplePreparedModel(const Model& model, const SampleDriver* driver)
     84         : mModel(model), mDriver(driver) {}
     85     ~SamplePreparedModel() override {}
     86     bool initialize();
     87     Return<ErrorStatus> execute(const Request& request,
     88                                 const sp<V1_0::IExecutionCallback>& callback) override;
     89     Return<ErrorStatus> execute_1_2(const Request& request, MeasureTiming measure,
     90                                     const sp<V1_2::IExecutionCallback>& callback) override;
     91     Return<void> executeSynchronously(const Request& request, MeasureTiming measure,
     92                                       executeSynchronously_cb cb) override;
     93     Return<void> configureExecutionBurst(
     94             const sp<V1_2::IBurstCallback>& callback,
     95             const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
     96             const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
     97             configureExecutionBurst_cb cb) override;
     98 
     99    private:
    100     Model mModel;
    101     const SampleDriver* mDriver;
    102     std::vector<RunTimePoolInfo> mPoolInfos;
    103 };
    104 
    105 }  // namespace sample_driver
    106 }  // namespace nn
    107 }  // namespace android
    108 
    109 #endif  // ANDROID_ML_NN_SAMPLE_DRIVER_SAMPLE_DRIVER_H
    110