/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Provides C++ classes to more easily use the Neural Networks API.

#ifndef ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H
#define ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H

#include "NeuralNetworks.h"

#include <math.h>
#include <utility>
#include <vector>

namespace android {
namespace nn {
namespace wrapper {

enum class Type {
    FLOAT32 = ANEURALNETWORKS_FLOAT32,
    INT32 = ANEURALNETWORKS_INT32,
    UINT32 = ANEURALNETWORKS_UINT32,
    TENSOR_FLOAT32 = ANEURALNETWORKS_TENSOR_FLOAT32,
    TENSOR_INT32 = ANEURALNETWORKS_TENSOR_INT32,
    TENSOR_QUANT8_ASYMM = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM,
};

enum class ExecutePreference {
    PREFER_LOW_POWER = ANEURALNETWORKS_PREFER_LOW_POWER,
    PREFER_FAST_SINGLE_ANSWER = ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER,
    PREFER_SUSTAINED_SPEED = ANEURALNETWORKS_PREFER_SUSTAINED_SPEED
};

enum class Result {
    NO_ERROR = ANEURALNETWORKS_NO_ERROR,
    OUT_OF_MEMORY = ANEURALNETWORKS_OUT_OF_MEMORY,
    INCOMPLETE = ANEURALNETWORKS_INCOMPLETE,
    UNEXPECTED_NULL = ANEURALNETWORKS_UNEXPECTED_NULL,
    BAD_DATA = ANEURALNETWORKS_BAD_DATA,
    OP_FAILED = ANEURALNETWORKS_OP_FAILED,
    UNMAPPABLE = ANEURALNETWORKS_UNMAPPABLE,
    BAD_STATE = ANEURALNETWORKS_BAD_STATE,
};

struct OperandType {
    ANeuralNetworksOperandType operandType;
    std::vector<uint32_t> dimensions;

    OperandType(Type type, std::vector<uint32_t> d, float scale = 0.0f, int32_t zeroPoint = 0)
            : dimensions(std::move(d)) {
        operandType = {
            .type = static_cast<int32_t>(type),
            .dimensionCount = static_cast<uint32_t>(dimensions.size()),
            .dimensions = dimensions.size() > 0 ? dimensions.data() : nullptr,
            .scale = scale,
            .zeroPoint = zeroPoint,
        };
    }
};

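// Example (illustrative sketch, not part of this header): constructing operand
// types for a scalar and for tensors. The variable names are hypothetical.
//
//     OperandType scalarType(Type::INT32, {});
//     OperandType matrixType(Type::TENSOR_FLOAT32, {2, 2});
//     OperandType quantType(Type::TENSOR_QUANT8_ASYMM, {1, 8}, 0.5f, 128);
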
class Memory {
public:
    Memory(size_t size, int protect, int fd, size_t offset) {
        mValid = ANeuralNetworksMemory_createFromFd(size, protect, fd, offset, &mMemory) ==
                 ANEURALNETWORKS_NO_ERROR;
    }

    ~Memory() { ANeuralNetworksMemory_free(mMemory); }

    // Disallow copy semantics to ensure the runtime object can only be freed
    // once. Copy semantics could be enabled if some sort of reference counting
    // or deep-copy system for runtime objects is added later.
    Memory(const Memory&) = delete;
    Memory& operator=(const Memory&) = delete;

    // Move semantics to remove access to the runtime object from the wrapper
    // object that is being moved. This ensures the runtime object will be
    // freed only once.
    Memory(Memory&& other) { *this = std::move(other); }
    Memory& operator=(Memory&& other) {
        if (this != &other) {
            ANeuralNetworksMemory_free(mMemory);
            mMemory = other.mMemory;
            mValid = other.mValid;
            other.mMemory = nullptr;
            other.mValid = false;
        }
        return *this;
    }

    ANeuralNetworksMemory* get() const { return mMemory; }
    bool isValid() const { return mValid; }

private:
    ANeuralNetworksMemory* mMemory = nullptr;
    bool mValid = true;
};

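// Example (illustrative sketch): wrapping a shared-memory region. It assumes
// ASharedMemory_create from <android/sharedmem.h> and the PROT_* flags from
// <sys/mman.h> are available; error handling is omitted.
//
//     int fd = ASharedMemory_create("input", 1024);
//     Memory memory(1024, PROT_READ | PROT_WRITE, fd, 0);
//     if (!memory.isValid()) { /* handle the error */ }
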
class Model {
public:
    Model() {
        // TODO handle the value returned by this call
        ANeuralNetworksModel_create(&mModel);
    }
    ~Model() { ANeuralNetworksModel_free(mModel); }

    // Disallow copy semantics to ensure the runtime object can only be freed
    // once. Copy semantics could be enabled if some sort of reference counting
    // or deep-copy system for runtime objects is added later.
    Model(const Model&) = delete;
    Model& operator=(const Model&) = delete;

    // Move semantics to remove access to the runtime object from the wrapper
    // object that is being moved. This ensures the runtime object will be
    // freed only once.
    Model(Model&& other) { *this = std::move(other); }
    Model& operator=(Model&& other) {
        if (this != &other) {
            ANeuralNetworksModel_free(mModel);
            mModel = other.mModel;
            mNextOperandId = other.mNextOperandId;
            mValid = other.mValid;
            other.mModel = nullptr;
            other.mNextOperandId = 0;
            other.mValid = false;
        }
        return *this;
    }

    Result finish() {
        if (mValid) {
            auto result = static_cast<Result>(ANeuralNetworksModel_finish(mModel));
            if (result != Result::NO_ERROR) {
                mValid = false;
            }
            return result;
        } else {
            return Result::BAD_STATE;
        }
    }

    uint32_t addOperand(const OperandType* type) {
        if (ANeuralNetworksModel_addOperand(mModel, &(type->operandType)) !=
            ANEURALNETWORKS_NO_ERROR) {
            mValid = false;
        }
        return mNextOperandId++;
    }

    void setOperandValue(uint32_t index, const void* buffer, size_t length) {
        if (ANeuralNetworksModel_setOperandValue(mModel, index, buffer, length) !=
            ANEURALNETWORKS_NO_ERROR) {
            mValid = false;
        }
    }

    void setOperandValueFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
                                   size_t length) {
        if (ANeuralNetworksModel_setOperandValueFromMemory(mModel, index, memory->get(), offset,
                                                           length) != ANEURALNETWORKS_NO_ERROR) {
            mValid = false;
        }
    }

    void addOperation(ANeuralNetworksOperationType type, const std::vector<uint32_t>& inputs,
                      const std::vector<uint32_t>& outputs) {
        if (ANeuralNetworksModel_addOperation(mModel, type, static_cast<uint32_t>(inputs.size()),
                                              inputs.data(), static_cast<uint32_t>(outputs.size()),
                                              outputs.data()) != ANEURALNETWORKS_NO_ERROR) {
            mValid = false;
        }
    }

    void identifyInputsAndOutputs(const std::vector<uint32_t>& inputs,
                                  const std::vector<uint32_t>& outputs) {
        if (ANeuralNetworksModel_identifyInputsAndOutputs(
                    mModel, static_cast<uint32_t>(inputs.size()), inputs.data(),
                    static_cast<uint32_t>(outputs.size()),
                    outputs.data()) != ANEURALNETWORKS_NO_ERROR) {
            mValid = false;
        }
    }

    void relaxComputationFloat32toFloat16(bool isRelax) {
        if (ANeuralNetworksModel_relaxComputationFloat32toFloat16(mModel, isRelax) ==
                ANEURALNETWORKS_NO_ERROR) {
            mRelaxed = isRelax;
        }
    }

    ANeuralNetworksModel* getHandle() const { return mModel; }
    bool isValid() const { return mValid; }
    bool isRelaxed() const { return mRelaxed; }

private:
    ANeuralNetworksModel* mModel = nullptr;
    // We keep track of the operand ID as a convenience to the caller.
    uint32_t mNextOperandId = 0;
    bool mValid = true;
    bool mRelaxed = false;
};

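// Example (illustrative sketch): building a model that adds two 2x2 float
// tensors. ANEURALNETWORKS_ADD and ANEURALNETWORKS_FUSED_NONE come from
// NeuralNetworks.h; the variable names are hypothetical and error handling
// is omitted.
//
//     OperandType tensorType(Type::TENSOR_FLOAT32, {2, 2});
//     OperandType activationType(Type::INT32, {});
//     Model model;
//     uint32_t a = model.addOperand(&tensorType);
//     uint32_t b = model.addOperand(&tensorType);
//     uint32_t act = model.addOperand(&activationType);
//     uint32_t sum = model.addOperand(&tensorType);
//     int32_t fuseCode = ANEURALNETWORKS_FUSED_NONE;
//     model.setOperandValue(act, &fuseCode, sizeof(fuseCode));
//     model.addOperation(ANEURALNETWORKS_ADD, {a, b, act}, {sum});
//     model.identifyInputsAndOutputs({a, b}, {sum});
//     model.finish();
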
class Event {
public:
    Event() {}
    ~Event() { ANeuralNetworksEvent_free(mEvent); }

    // Disallow copy semantics to ensure the runtime object can only be freed
    // once. Copy semantics could be enabled if some sort of reference counting
    // or deep-copy system for runtime objects is added later.
    Event(const Event&) = delete;
    Event& operator=(const Event&) = delete;

    // Move semantics to remove access to the runtime object from the wrapper
    // object that is being moved. This ensures the runtime object will be
    // freed only once.
    Event(Event&& other) { *this = std::move(other); }
    Event& operator=(Event&& other) {
        if (this != &other) {
            ANeuralNetworksEvent_free(mEvent);
            mEvent = other.mEvent;
            other.mEvent = nullptr;
        }
        return *this;
    }

    Result wait() { return static_cast<Result>(ANeuralNetworksEvent_wait(mEvent)); }

    // Only for use by Execution
    void set(ANeuralNetworksEvent* newEvent) {
        ANeuralNetworksEvent_free(mEvent);
        mEvent = newEvent;
    }

private:
    ANeuralNetworksEvent* mEvent = nullptr;
};

class Compilation {
public:
    Compilation(const Model* model) {
        int result = ANeuralNetworksCompilation_create(model->getHandle(), &mCompilation);
        if (result != 0) {
            // TODO Handle the error
        }
    }

    ~Compilation() { ANeuralNetworksCompilation_free(mCompilation); }

    // Disallow copy semantics to ensure the runtime object can only be freed
    // once. Copy semantics could be enabled if some sort of reference counting
    // or deep-copy system for runtime objects is added later.
    Compilation(const Compilation&) = delete;
    Compilation& operator=(const Compilation&) = delete;

    // Move semantics to remove access to the runtime object from the wrapper
    // object that is being moved. This ensures the runtime object will be
    // freed only once.
    Compilation(Compilation&& other) { *this = std::move(other); }
    Compilation& operator=(Compilation&& other) {
        if (this != &other) {
            ANeuralNetworksCompilation_free(mCompilation);
            mCompilation = other.mCompilation;
            other.mCompilation = nullptr;
        }
        return *this;
    }

    Result setPreference(ExecutePreference preference) {
        return static_cast<Result>(ANeuralNetworksCompilation_setPreference(
                    mCompilation, static_cast<int32_t>(preference)));
    }

    Result finish() { return static_cast<Result>(ANeuralNetworksCompilation_finish(mCompilation)); }

    ANeuralNetworksCompilation* getHandle() const { return mCompilation; }

private:
    ANeuralNetworksCompilation* mCompilation = nullptr;
};

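// Example (illustrative sketch): compiling a finished Model with an execution
// preference. `model` refers to the hypothetical Model built in the example
// above; error handling is omitted.
//
//     Compilation compilation(&model);
//     compilation.setPreference(ExecutePreference::PREFER_FAST_SINGLE_ANSWER);
//     compilation.finish();
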
class Execution {
public:
    Execution(const Compilation* compilation) {
        int result = ANeuralNetworksExecution_create(compilation->getHandle(), &mExecution);
        if (result != 0) {
            // TODO Handle the error
        }
    }

    ~Execution() { ANeuralNetworksExecution_free(mExecution); }

    // Disallow copy semantics to ensure the runtime object can only be freed
    // once. Copy semantics could be enabled if some sort of reference counting
    // or deep-copy system for runtime objects is added later.
    Execution(const Execution&) = delete;
    Execution& operator=(const Execution&) = delete;

    // Move semantics to remove access to the runtime object from the wrapper
    // object that is being moved. This ensures the runtime object will be
    // freed only once.
    Execution(Execution&& other) { *this = std::move(other); }
    Execution& operator=(Execution&& other) {
        if (this != &other) {
            ANeuralNetworksExecution_free(mExecution);
            mExecution = other.mExecution;
            other.mExecution = nullptr;
        }
        return *this;
    }

    Result setInput(uint32_t index, const void* buffer, size_t length,
                    const ANeuralNetworksOperandType* type = nullptr) {
        return static_cast<Result>(
                    ANeuralNetworksExecution_setInput(mExecution, index, type, buffer, length));
    }

    Result setInputFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
                              uint32_t length, const ANeuralNetworksOperandType* type = nullptr) {
        return static_cast<Result>(ANeuralNetworksExecution_setInputFromMemory(
                    mExecution, index, type, memory->get(), offset, length));
    }

    Result setOutput(uint32_t index, void* buffer, size_t length,
                     const ANeuralNetworksOperandType* type = nullptr) {
        return static_cast<Result>(
                    ANeuralNetworksExecution_setOutput(mExecution, index, type, buffer, length));
    }

    Result setOutputFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
                               uint32_t length, const ANeuralNetworksOperandType* type = nullptr) {
        return static_cast<Result>(ANeuralNetworksExecution_setOutputFromMemory(
                    mExecution, index, type, memory->get(), offset, length));
    }

    Result startCompute(Event* event) {
        ANeuralNetworksEvent* ev = nullptr;
        Result result = static_cast<Result>(ANeuralNetworksExecution_startCompute(mExecution, &ev));
        event->set(ev);
        return result;
    }

    Result compute() {
        ANeuralNetworksEvent* event = nullptr;
        Result result =
                    static_cast<Result>(ANeuralNetworksExecution_startCompute(mExecution, &event));
        if (result != Result::NO_ERROR) {
            return result;
        }
        // TODO: It is not clear how to manage the lifetime of events when there
        // are multiple waiters.
        result = static_cast<Result>(ANeuralNetworksEvent_wait(event));
        ANeuralNetworksEvent_free(event);
        return result;
    }

private:
    ANeuralNetworksExecution* mExecution = nullptr;
};

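// Example (illustrative sketch): running the hypothetical compilation from the
// example above on caller-provided buffers; error handling is omitted.
//
//     float inputA[4] = {1, 2, 3, 4};
//     float inputB[4] = {5, 6, 7, 8};
//     float output[4] = {};
//     Execution execution(&compilation);
//     execution.setInput(0, inputA, sizeof(inputA));
//     execution.setInput(1, inputB, sizeof(inputB));
//     execution.setOutput(0, output, sizeof(output));
//     if (execution.compute() != Result::NO_ERROR) { /* handle the error */ }
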
}  // namespace wrapper
}  // namespace nn
}  // namespace android

#endif  //  ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H