Home | History | Annotate | Download | only in include
      1 /*
      2  * Copyright (C) 2019 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_EXTENSIONS_H
     18 #define ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_EXTENSIONS_H
     19 
     20 #include "NeuralNetworksExtensions.h"
     21 #include "NeuralNetworksWrapper.h"
     22 
     23 #include <variant>
     24 
     25 namespace android {
     26 namespace nn {
     27 namespace extension_wrapper {
     28 
     29 using wrapper::SymmPerChannelQuantParams;
     30 using wrapper::Type;
     31 
     32 struct ExtensionOperandParams {
     33     std::vector<uint8_t> data;
     34 
     35     ExtensionOperandParams(std::vector<uint8_t> data) : data(std::move(data)) {}
     36 
     37     template <typename T>
     38     ExtensionOperandParams(const T& data)
     39         : ExtensionOperandParams(
     40                   std::vector(reinterpret_cast<const uint8_t*>(&data),
     41                               reinterpret_cast<const uint8_t*>(&data) + sizeof(data))) {
     42         static_assert(std::is_trivially_copyable<T>::value, "data must be trivially copyable");
     43     }
     44 };
     45 
     46 struct OperandType {
     47     using ExtraParams =
     48             std::variant<std::monostate, SymmPerChannelQuantParams, ExtensionOperandParams>;
     49 
     50     ANeuralNetworksOperandType operandType;
     51     std::vector<uint32_t> dimensions;
     52     ExtraParams extraParams;
     53 
     54     OperandType(const OperandType& other)
     55         : operandType(other.operandType),
     56           dimensions(other.dimensions),
     57           extraParams(other.extraParams) {
     58         operandType.dimensions = dimensions.size() > 0 ? dimensions.data() : nullptr;
     59     }
     60 
     61     OperandType& operator=(const OperandType& other) {
     62         if (this != &other) {
     63             operandType = other.operandType;
     64             dimensions = other.dimensions;
     65             extraParams = other.extraParams;
     66             operandType.dimensions = dimensions.size() > 0 ? dimensions.data() : nullptr;
     67         }
     68         return *this;
     69     }
     70 
     71     OperandType(Type type, std::vector<uint32_t> d, float scale = 0.0f, int32_t zeroPoint = 0,
     72                 ExtraParams&& extraParams = std::monostate())
     73         : dimensions(std::move(d)), extraParams(std::move(extraParams)) {
     74         operandType = {
     75                 .type = static_cast<int32_t>(type),
     76                 .dimensionCount = static_cast<uint32_t>(dimensions.size()),
     77                 .dimensions = dimensions.size() > 0 ? dimensions.data() : nullptr,
     78                 .scale = scale,
     79                 .zeroPoint = zeroPoint,
     80         };
     81     }
     82 
     83     OperandType(Type type, std::vector<uint32_t> dimensions, float scale, int32_t zeroPoint,
     84                 SymmPerChannelQuantParams&& channelQuant)
     85         : OperandType(type, dimensions, scale, zeroPoint, ExtraParams(std::move(channelQuant))) {}
     86 
     87     OperandType(Type type, std::vector<uint32_t> dimensions, ExtraParams&& extraParams)
     88         : OperandType(type, dimensions, 0.0f, 0, std::move(extraParams)) {}
     89 };
     90 
     91 class Model : public wrapper::Model {
     92    public:
     93     using wrapper::Model::Model;  // Inherit constructors.
     94 
     95     int32_t getExtensionOperandType(const char* extensionName, uint16_t typeWithinExtension) {
     96         int32_t result;
     97         if (ANeuralNetworksModel_getExtensionOperandType(mModel, extensionName, typeWithinExtension,
     98                                                          &result) != ANEURALNETWORKS_NO_ERROR) {
     99             mValid = false;
    100         }
    101         return result;
    102     }
    103 
    104     ANeuralNetworksOperationType getExtensionOperationType(const char* extensionName,
    105                                                            uint16_t typeWithinExtension) {
    106         ANeuralNetworksOperationType result;
    107         if (ANeuralNetworksModel_getExtensionOperationType(mModel, extensionName,
    108                                                            typeWithinExtension,
    109                                                            &result) != ANEURALNETWORKS_NO_ERROR) {
    110             mValid = false;
    111         }
    112         return result;
    113     }
    114 
    115     uint32_t addOperand(const OperandType* type) {
    116         if (ANeuralNetworksModel_addOperand(mModel, &(type->operandType)) !=
    117             ANEURALNETWORKS_NO_ERROR) {
    118             mValid = false;
    119         }
    120         if (std::holds_alternative<SymmPerChannelQuantParams>(type->extraParams)) {
    121             const auto& channelQuant = std::get<SymmPerChannelQuantParams>(type->extraParams);
    122             if (ANeuralNetworksModel_setOperandSymmPerChannelQuantParams(
    123                         mModel, mNextOperandId, &channelQuant.params) != ANEURALNETWORKS_NO_ERROR) {
    124                 mValid = false;
    125             }
    126         } else if (std::holds_alternative<ExtensionOperandParams>(type->extraParams)) {
    127             const auto& extension = std::get<ExtensionOperandParams>(type->extraParams);
    128             if (ANeuralNetworksModel_setOperandExtensionData(
    129                         mModel, mNextOperandId, extension.data.data(), extension.data.size()) !=
    130                 ANEURALNETWORKS_NO_ERROR) {
    131                 mValid = false;
    132             }
    133         }
    134         return mNextOperandId++;
    135     }
    136 };
    137 
    138 }  // namespace extension_wrapper
    139 
    140 namespace wrapper {
    141 
    142 using ExtensionModel = extension_wrapper::Model;
    143 using ExtensionOperandType = extension_wrapper::OperandType;
    144 using ExtensionOperandParams = extension_wrapper::ExtensionOperandParams;
    145 
    146 }  // namespace wrapper
    147 }  // namespace nn
    148 }  // namespace android
    149 
    150 #endif  //  ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_EXTENSIONS_H
    151