/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "OperationsUtils.h"
#define LOG_TAG "Operations"

#include "HalInterfaces.h"
#include "IndexedShapeWrapper.h"
#include "OperationResolver.h"

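// DEQUANTIZE: converts a quantized 8-bit tensor (asymmetric, symmetric, or
// symmetric per-channel) into a TENSOR_FLOAT16 or TENSOR_FLOAT32 tensor.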
namespace android {
namespace nn {
namespace dequantize {

constexpr uint32_t kNumInputs = 1;
constexpr uint32_t kInputTensor = 0;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;

namespace {

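// Dequantizes every element with a single (scale, zeroPoint) pair:
// real_value = scale * (quantized_value - zeroPoint).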
template <typename InputType, typename OutputType>
bool compute(const InputType* inputData, const Shape& inputShape, OutputType* outputData) {
    const int numElements = getNumberOfElements(inputShape);
    const int32_t zeroPoint = inputShape.offset;
    const float scale = inputShape.scale;
    for (int i = 0; i < numElements; ++i) {
        const int32_t value = inputData[i];
        outputData[i] = static_cast<OutputType>(scale * (value - zeroPoint));
    }
    return true;
}

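// Dequantizes a TENSOR_QUANT8_SYMM_PER_CHANNEL tensor, looking up a separate
// scale for each slice along the channel dimension.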
template <typename OutputType>
bool computePerChannel(const int8_t* inputData, const Shape& inputShape, OutputType* outputData) {
    // First we calculate a stride, which is the number of elements we need to
    // skip to move by one index along the dimension that carries the
    // per-channel quantization scales.
    const int channelDim = inputShape.extraParams.channelQuant().channelDim;
    int stride = 1;
    for (int i = getNumberOfDimensions(inputShape) - 1; i > channelDim; --i) {
        stride *= getSizeOfDimension(inputShape, i);
    }

    const int numElements = getNumberOfElements(inputShape);
    const int32_t zeroPoint = inputShape.offset;

    for (int i = 0; i < numElements; ++i) {
        // To get the current index along the quantized dimension we calculate
        // how many whole |stride| blocks we have looped through and take that
        // number modulo the size of the dimension (so that we don't overflow
        // if the channelDim is not 0).
        const int scaleIndex = (i / stride) % getSizeOfDimension(inputShape, channelDim);
        const float scale = inputShape.extraParams.channelQuant().scales[scaleIndex];
        const int32_t value = inputData[i];
        outputData[i] = static_cast<OutputType>(scale * (value - zeroPoint));
    }
    return true;
}

}  // namespace

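// Validates operand counts, types, and the minimum HAL version:
// QUANT8_ASYMM -> FLOAT32 is available from HAL 1.0; every other supported
// input/output combination requires HAL 1.2.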
bool validate(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);

    const OperandType inputType = context->getInputType(kInputTensor);
    const OperandType outputType = context->getOutputType(kOutputTensor);

    if (inputType == OperandType::TENSOR_QUANT8_ASYMM &&
        outputType == OperandType::TENSOR_FLOAT32) {
        return validateHalVersion(context, HalVersion::V1_0);
    }

    NN_RET_CHECK(inputType == OperandType::TENSOR_QUANT8_ASYMM ||
                 inputType == OperandType::TENSOR_QUANT8_SYMM ||
                 inputType == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL)
            << "Unsupported input operand type for DEQUANTIZE op: " << toString(inputType);
    NN_RET_CHECK(outputType == OperandType::TENSOR_FLOAT16 ||
                 outputType == OperandType::TENSOR_FLOAT32)
            << "Unsupported output operand type for DEQUANTIZE op: " << toString(outputType);
    return validateHalVersion(context, HalVersion::V1_2);
}

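// The output tensor takes the dimensions of the input tensor.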
bool prepare(IOperationExecutionContext* context) {
    const Shape& input = context->getInputShape(kInputTensor);
    Shape output = context->getOutputShape(kOutputTensor);
    output.dimensions = input.dimensions;
    return context->setOutputShape(kOutputTensor, output);
}

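// Dispatches to compute() or computePerChannel() based on the input/output
// operand type combination.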
bool execute(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;

    const OperandType inputType = context->getInputType(kInputTensor);
    const OperandType outputType = context->getOutputType(kOutputTensor);

    const Shape& inputShape = context->getInputShape(kInputTensor);
    if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
        const uint8_t* inputBuffer = context->getInputBuffer<uint8_t>(kInputTensor);
        if (outputType == OperandType::TENSOR_FLOAT16) {
            return compute(inputBuffer, inputShape,
                           context->getOutputBuffer<_Float16>(kOutputTensor));
        } else if (outputType == OperandType::TENSOR_FLOAT32) {
            return compute(inputBuffer, inputShape, context->getOutputBuffer<float>(kOutputTensor));
        }
    } else if (inputType == OperandType::TENSOR_QUANT8_SYMM) {
        const int8_t* inputBuffer = context->getInputBuffer<int8_t>(kInputTensor);
        if (outputType == OperandType::TENSOR_FLOAT16) {
            return compute(inputBuffer, inputShape,
                           context->getOutputBuffer<_Float16>(kOutputTensor));
        } else if (outputType == OperandType::TENSOR_FLOAT32) {
            return compute(inputBuffer, inputShape, context->getOutputBuffer<float>(kOutputTensor));
        }
    } else if (inputType == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
        const int8_t* inputBuffer = context->getInputBuffer<int8_t>(kInputTensor);
        if (outputType == OperandType::TENSOR_FLOAT16) {
            return computePerChannel(inputBuffer, inputShape,
                                     context->getOutputBuffer<_Float16>(kOutputTensor));
        } else if (outputType == OperandType::TENSOR_FLOAT32) {
            return computePerChannel(inputBuffer, inputShape,
                                     context->getOutputBuffer<float>(kOutputTensor));
        }
    }
    NN_RET_CHECK_FAIL() << "Unsupported tensor type combination for dequantize op. (input type: "
                        << toString(inputType) << " output type: " << toString(outputType) << ")";
}

}  // namespace dequantize

NN_REGISTER_OPERATION(DEQUANTIZE, "DEQUANTIZE", dequantize::validate, dequantize::prepare,
                      dequantize::execute, .allowZeroSizedInput = true);

}  // namespace nn
}  // namespace android