Home | History | Annotate | Download | only in fibonacci_extension
      1 /*
      2  * Copyright (C) 2019 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #define LOG_TAG "FibonacciDriver"
     18 
     19 #include "FibonacciDriver.h"
     20 
     21 #include "HalInterfaces.h"
     22 #include "NeuralNetworksExtensions.h"
     23 #include "OperationResolver.h"
     24 #include "OperationsUtils.h"
     25 #include "Utils.h"
     26 #include "ValidateHal.h"
     27 
     28 #include "FibonacciExtension.h"
     29 
     30 namespace android {
     31 namespace nn {
     32 namespace sample_driver {
     33 namespace {
     34 
     35 const uint8_t kLowBitsType = static_cast<uint8_t>(Model::ExtensionTypeEncoding::LOW_BITS_TYPE);
     36 const uint32_t kTypeWithinExtensionMask = (1 << kLowBitsType) - 1;
     37 
     38 namespace fibonacci_op {
     39 
     40 constexpr char kOperationName[] = "TEST_VENDOR_FIBONACCI";
     41 
     42 constexpr uint32_t kNumInputs = 1;
     43 constexpr uint32_t kInputN = 0;
     44 
     45 constexpr uint32_t kNumOutputs = 1;
     46 constexpr uint32_t kOutputTensor = 0;
     47 
     48 bool getFibonacciExtensionPrefix(const Model& model, uint16_t* prefix) {
     49     NN_RET_CHECK_EQ(model.extensionNameToPrefix.size(), 1u);  // Assumes no other extensions in use.
     50     NN_RET_CHECK_EQ(model.extensionNameToPrefix[0].name, TEST_VENDOR_FIBONACCI_EXTENSION_NAME);
     51     *prefix = model.extensionNameToPrefix[0].prefix;
     52     return true;
     53 }
     54 
     55 bool isFibonacciOperation(const Operation& operation, const Model& model) {
     56     int32_t operationType = static_cast<int32_t>(operation.type);
     57     uint16_t prefix;
     58     NN_RET_CHECK(getFibonacciExtensionPrefix(model, &prefix));
     59     NN_RET_CHECK_EQ(operationType, (prefix << kLowBitsType) | TEST_VENDOR_FIBONACCI);
     60     return true;
     61 }
     62 
     63 bool validate(const Operation& operation, const Model& model) {
     64     NN_RET_CHECK(isFibonacciOperation(operation, model));
     65     NN_RET_CHECK_EQ(operation.inputs.size(), kNumInputs);
     66     NN_RET_CHECK_EQ(operation.outputs.size(), kNumOutputs);
     67     int32_t inputType = static_cast<int32_t>(model.operands[operation.inputs[0]].type);
     68     int32_t outputType = static_cast<int32_t>(model.operands[operation.outputs[0]].type);
     69     uint16_t prefix;
     70     NN_RET_CHECK(getFibonacciExtensionPrefix(model, &prefix));
     71     NN_RET_CHECK(inputType == ((prefix << kLowBitsType) | TEST_VENDOR_INT64) ||
     72                  inputType == ANEURALNETWORKS_TENSOR_FLOAT32);
     73     NN_RET_CHECK(outputType == ((prefix << kLowBitsType) | TEST_VENDOR_TENSOR_QUANT64_ASYMM) ||
     74                  outputType == ANEURALNETWORKS_TENSOR_FLOAT32);
     75     return true;
     76 }
     77 
     78 bool prepare(IOperationExecutionContext* context) {
     79     int64_t n;
     80     if (context->getInputType(kInputN) == OperandType::TENSOR_FLOAT32) {
     81         n = static_cast<int64_t>(context->getInputValue<float>(kInputN));
     82     } else {
     83         n = context->getInputValue<int64_t>(kInputN);
     84     }
     85     NN_RET_CHECK_GE(n, 1);
     86     Shape output = context->getOutputShape(kOutputTensor);
     87     output.dimensions = {static_cast<uint32_t>(n)};
     88     return context->setOutputShape(kOutputTensor, output);
     89 }
     90 
     91 template <typename ScaleT, typename ZeroPointT, typename OutputT>
     92 bool compute(int32_t n, ScaleT outputScale, ZeroPointT outputZeroPoint, OutputT* output) {
     93     // Compute the Fibonacci numbers.
     94     if (n >= 1) {
     95         output[0] = 1;
     96     }
     97     if (n >= 2) {
     98         output[1] = 1;
     99     }
    100     if (n >= 3) {
    101         for (int32_t i = 2; i < n; ++i) {
    102             output[i] = output[i - 1] + output[i - 2];
    103         }
    104     }
    105 
    106     // Quantize output.
    107     for (int32_t i = 0; i < n; ++i) {
    108         output[i] = output[i] / outputScale + outputZeroPoint;
    109     }
    110 
    111     return true;
    112 }
    113 
    114 bool execute(IOperationExecutionContext* context) {
    115     int64_t n;
    116     if (context->getInputType(kInputN) == OperandType::TENSOR_FLOAT32) {
    117         n = static_cast<int64_t>(context->getInputValue<float>(kInputN));
    118     } else {
    119         n = context->getInputValue<int64_t>(kInputN);
    120     }
    121     if (context->getOutputType(kOutputTensor) == OperandType::TENSOR_FLOAT32) {
    122         float* output = context->getOutputBuffer<float>(kOutputTensor);
    123         return compute(n, /*scale=*/1.0, /*zeroPoint=*/0, output);
    124     } else {
    125         uint64_t* output = context->getOutputBuffer<uint64_t>(kOutputTensor);
    126         Shape outputShape = context->getOutputShape(kOutputTensor);
    127         auto outputQuant = reinterpret_cast<const TestVendorQuant64AsymmParams*>(
    128                 outputShape.extraParams.extension().data());
    129         return compute(n, outputQuant->scale, outputQuant->zeroPoint, output);
    130     }
    131 }
    132 
    133 }  // namespace fibonacci_op
    134 }  // namespace
    135 
    136 const OperationRegistration* FibonacciOperationResolver::findOperation(
    137         OperationType operationType) const {
    138     // .validate is omitted because it's not used by the extension driver.
    139     static OperationRegistration operationRegistration(operationType, fibonacci_op::kOperationName,
    140                                                        nullptr, fibonacci_op::prepare,
    141                                                        fibonacci_op::execute, {});
    142     uint16_t prefix = static_cast<int32_t>(operationType) >> kLowBitsType;
    143     uint16_t typeWithinExtension = static_cast<int32_t>(operationType) & kTypeWithinExtensionMask;
    144     // Assumes no other extensions in use.
    145     return prefix != 0 && typeWithinExtension == TEST_VENDOR_FIBONACCI ? &operationRegistration
    146                                                                        : nullptr;
    147 }
    148 
    149 Return<void> FibonacciDriver::getSupportedExtensions(getSupportedExtensions_cb cb) {
    150     cb(ErrorStatus::NONE,
    151        {
    152                {
    153                        .name = TEST_VENDOR_FIBONACCI_EXTENSION_NAME,
    154                        .operandTypes =
    155                                {
    156                                        {
    157                                                .type = TEST_VENDOR_INT64,
    158                                                .isTensor = false,
    159                                                .byteSize = 8,
    160                                        },
    161                                        {
    162                                                .type = TEST_VENDOR_TENSOR_QUANT64_ASYMM,
    163                                                .isTensor = true,
    164                                                .byteSize = 8,
    165                                        },
    166                                },
    167                },
    168        });
    169     return Void();
    170 }
    171 
    172 Return<void> FibonacciDriver::getCapabilities_1_2(getCapabilities_1_2_cb cb) {
    173     android::nn::initVLogMask();
    174     VLOG(DRIVER) << "getCapabilities()";
    175     static const PerformanceInfo kPerf = {.execTime = 1.0f, .powerUsage = 1.0f};
    176     Capabilities capabilities = {.relaxedFloat32toFloat16PerformanceScalar = kPerf,
    177                                  .relaxedFloat32toFloat16PerformanceTensor = kPerf,
    178                                  .operandPerformance = nonExtensionOperandPerformance(kPerf)};
    179     cb(ErrorStatus::NONE, capabilities);
    180     return Void();
    181 }
    182 
    183 Return<void> FibonacciDriver::getSupportedOperations_1_2(const V1_2::Model& model,
    184                                                          getSupportedOperations_1_2_cb cb) {
    185     VLOG(DRIVER) << "getSupportedOperations()";
    186     if (!validateModel(model)) {
    187         cb(ErrorStatus::INVALID_ARGUMENT, {});
    188         return Void();
    189     }
    190     const size_t count = model.operations.size();
    191     std::vector<bool> supported(count);
    192     for (size_t i = 0; i < count; ++i) {
    193         const Operation& operation = model.operations[i];
    194         if (fibonacci_op::isFibonacciOperation(operation, model)) {
    195             if (!fibonacci_op::validate(operation, model)) {
    196                 cb(ErrorStatus::INVALID_ARGUMENT, {});
    197                 return Void();
    198             }
    199             supported[i] = true;
    200         }
    201     }
    202     cb(ErrorStatus::NONE, supported);
    203     return Void();
    204 }
    205 
    206 }  // namespace sample_driver
    207 }  // namespace nn
    208 }  // namespace android
    209