/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "OperationsUtils.h"
#define LOG_TAG "Operations"

#include "HalInterfaces.h"
#include "IndexedShapeWrapper.h"
#include "OperationResolver.h"

namespace android {
namespace nn {
namespace dequantize {

constexpr uint32_t kNumInputs = 1;
constexpr uint32_t kInputTensor = 0;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;

namespace {

template <typename InputType, typename OutputType>
bool compute(const InputType* inputData, const Shape& inputShape, OutputType* outputData) {
    const int numElements = getNumberOfElements(inputShape);
    const int32_t zeroPoint = inputShape.offset;
    const float scale = inputShape.scale;
    for (int i = 0; i < numElements; ++i) {
        const int32_t value = inputData[i];
        outputData[i] = static_cast<OutputType>(scale * (value - zeroPoint));
    }
    return true;
}

template <typename OutputType>
bool computePerChannel(const int8_t* inputData, const Shape& inputShape, OutputType* outputData) {
    // First we calculate a stride which is the number of elements we need to
    // skip to change an index along a dimension with different quantization
    // scales.
    const int channelDim = inputShape.extraParams.channelQuant().channelDim;
    int stride = 1;
    for (int i = getNumberOfDimensions(inputShape) - 1; i > channelDim; --i) {
        stride *= getSizeOfDimension(inputShape, i);
    }

    const int numElements = getNumberOfElements(inputShape);
    const int32_t zeroPoint = inputShape.offset;

    for (int i = 0; i < numElements; ++i) {
        // To get the current index along the quantized dimension we calculate
        // how many even |strides| we looped through and take this number modulo
        // the size of the dimension (so that we don't have an overflow if the
        // channelDim is not 0).
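        // For example, with a shape of [2, 3, 4] and channelDim = 1, the
        // stride computed above is 4, so the flat index i = 7 (coordinates
        // (0, 1, 3)) picks scaleIndex = (7 / 4) % 3 = 1, the element's
        // coordinate along dimension 1.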
        const int scaleIndex = (i / stride) % getSizeOfDimension(inputShape, channelDim);
        const float scale = inputShape.extraParams.channelQuant().scales[scaleIndex];
        const int32_t value = inputData[i];
        outputData[i] = static_cast<OutputType>(scale * (value - zeroPoint));
    }
    return true;
}

}  // namespace

bool validate(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);

    const OperandType inputType = context->getInputType(kInputTensor);
    const OperandType outputType = context->getOutputType(kOutputTensor);

    if (inputType == OperandType::TENSOR_QUANT8_ASYMM &&
        outputType == OperandType::TENSOR_FLOAT32) {
        return validateHalVersion(context, HalVersion::V1_0);
    }

    NN_RET_CHECK(inputType == OperandType::TENSOR_QUANT8_ASYMM ||
                 inputType == OperandType::TENSOR_QUANT8_SYMM ||
                 inputType == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL)
            << "Unsupported input operand type for DEQUANTIZE op: " << toString(inputType);
    NN_RET_CHECK(outputType == OperandType::TENSOR_FLOAT16 ||
                 outputType == OperandType::TENSOR_FLOAT32)
            << "Unsupported output operand type for DEQUANTIZE op: " << toString(outputType);
    return validateHalVersion(context, HalVersion::V1_2);
}

bool prepare(IOperationExecutionContext* context) {
    const Shape& input = context->getInputShape(kInputTensor);
    Shape output = context->getOutputShape(kOutputTensor);
    output.dimensions = input.dimensions;
    return context->setOutputShape(kOutputTensor, output);
}

bool execute(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;

    const OperandType inputType = context->getInputType(kInputTensor);
    const OperandType outputType = context->getOutputType(kOutputTensor);

    const Shape& inputShape = context->getInputShape(kInputTensor);
    if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
        const uint8_t* inputBuffer = context->getInputBuffer<uint8_t>(kInputTensor);
        if (outputType == OperandType::TENSOR_FLOAT16) {
            return compute(inputBuffer, inputShape,
                           context->getOutputBuffer<_Float16>(kOutputTensor));
        } else if (outputType == OperandType::TENSOR_FLOAT32) {
            return compute(inputBuffer, inputShape, context->getOutputBuffer<float>(kOutputTensor));
        }
    } else if (inputType == OperandType::TENSOR_QUANT8_SYMM) {
        const int8_t* inputBuffer = context->getInputBuffer<int8_t>(kInputTensor);
        if (outputType == OperandType::TENSOR_FLOAT16) {
            return compute(inputBuffer, inputShape,
                           context->getOutputBuffer<_Float16>(kOutputTensor));
        } else if (outputType == OperandType::TENSOR_FLOAT32) {
            return compute(inputBuffer, inputShape, context->getOutputBuffer<float>(kOutputTensor));
        }
    } else if (inputType == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
        const int8_t* inputBuffer = context->getInputBuffer<int8_t>(kInputTensor);
        if (outputType == OperandType::TENSOR_FLOAT16) {
            return computePerChannel(inputBuffer, inputShape,
                                     context->getOutputBuffer<_Float16>(kOutputTensor));
        } else if (outputType == OperandType::TENSOR_FLOAT32) {
            return computePerChannel(inputBuffer, inputShape,
                                     context->getOutputBuffer<float>(kOutputTensor));
        }
    }
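    // No supported (input, output) type combination matched above. validate()
    // should have rejected such a model, so reaching this point indicates an
    // inconsistency; report it and fail.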
    NN_RET_CHECK_FAIL() << "Unsupported tensor types combination for dequantize op. (input type: "
                        << toString(inputType) << " output type: " << toString(outputType) << ")";
}

}  // namespace dequantize

NN_REGISTER_OPERATION(DEQUANTIZE, "DEQUANTIZE", dequantize::validate, dequantize::prepare,
                      dequantize::execute, .allowZeroSizedInput = true);

}  // namespace nn
}  // namespace android
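
// Illustrative note (editor's example, not part of the original file):
// DEQUANTIZE maps a quantized value back to floating point as
//
//     output[i] = scale * (input[i] - zeroPoint)
//
// e.g. a TENSOR_QUANT8_ASYMM input with scale = 0.5 and zeroPoint = 128 maps
// the stored value 200 to 0.5f * (200 - 128) = 36.0f. For
// TENSOR_QUANT8_SYMM_PER_CHANNEL inputs the zero point is 0 and the scale is
// looked up per channel, as in computePerChannel() above.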