1 /* 2 * Copyright (C) 2019 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "OperationsUtils.h" 18 #define LOG_TAG "Operations" 19 20 #include "HalInterfaces.h" 21 #include "IndexedShapeWrapper.h" 22 #include "OperationResolver.h" 23 #include "Tracing.h" 24 25 #include <cmath> 26 27 namespace android { 28 namespace nn { 29 namespace quantize { 30 31 constexpr uint32_t kNumInputs = 1; 32 constexpr uint32_t kInputTensor = 0; 33 34 constexpr uint32_t kNumOutputs = 1; 35 constexpr uint32_t kOutputTensor = 0; 36 37 namespace { 38 39 bool quantizeFloat32ToQuant8(const float* inputData, uint8_t* outputData, 40 const Shape& outputShape) { 41 NNTRACE_COMP("quantizeFloat32ToQuant8"); 42 uint32_t size = getNumberOfElements(outputShape); 43 for (uint32_t i = 0; i < size; ++i) { 44 outputData[i] = static_cast<uint8_t>(std::max<float>( 45 0, std::min<float>(255, outputShape.offset + 46 std::round(inputData[i] / outputShape.scale)))); 47 } 48 return true; 49 } 50 51 bool quantizeFloat16ToQuant8(const _Float16* inputData, uint8_t* outputData, 52 const Shape& outputShape) { 53 NNTRACE_COMP("quantizeFloat16ToQuant8"); 54 uint32_t size = getNumberOfElements(outputShape); 55 for (uint32_t i = 0; i < size; ++i) { 56 outputData[i] = static_cast<uint8_t>(std::max<float>( 57 0, std::min<float>(255, outputShape.offset + 58 std::round(inputData[i] / outputShape.scale)))); 59 } 60 return true; 61 } 62 63 } // namespace 64 65 bool validate(const IOperationValidationContext* context) { 66 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); 67 NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); 68 69 const OperandType inputType = context->getInputType(kInputTensor); 70 const OperandType outputType = context->getOutputType(kOutputTensor); 71 72 NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 || 73 inputType == OperandType::TENSOR_FLOAT32) 74 << "Unsupported input operand type for QUANTIZE op: " << toString(inputType); 75 NN_RET_CHECK(outputType == OperandType::TENSOR_QUANT8_ASYMM) 76 << "Unsupported output operand type for QUANTIZE op: " << toString(outputType); 77 return validateHalVersion(context, HalVersion::V1_2); 78 } 79 80 bool prepare(IOperationExecutionContext* context) { 81 const Shape& input = context->getInputShape(kInputTensor); 82 Shape output = context->getOutputShape(kOutputTensor); 83 output.dimensions = input.dimensions; 84 return context->setOutputShape(kOutputTensor, output); 85 } 86 87 bool execute(IOperationExecutionContext* context) { 88 // Bypass execution in the case of zero-sized input. 89 if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true; 90 91 const OperandType inputType = context->getInputType(kInputTensor); 92 if (inputType == OperandType::TENSOR_FLOAT32) { 93 return quantizeFloat32ToQuant8(context->getInputBuffer<float>(kInputTensor), 94 context->getOutputBuffer<uint8_t>(kOutputTensor), 95 context->getOutputShape(kOutputTensor)); 96 } else if (inputType == OperandType::TENSOR_FLOAT16) { 97 return quantizeFloat16ToQuant8(context->getInputBuffer<_Float16>(kInputTensor), 98 context->getOutputBuffer<uint8_t>(kOutputTensor), 99 context->getOutputShape(kOutputTensor)); 100 } 101 NN_RET_CHECK_FAIL() << "Unsupported tensor types combination for QUANTIZE op. (input type: " 102 << toString(inputType) 103 << " output type: " << toString(context->getOutputType(kOutputTensor)) 104 << ")"; 105 } 106 107 } // namespace quantize 108 109 NN_REGISTER_OPERATION(QUANTIZE, "QUANTIZE", quantize::validate, quantize::prepare, 110 quantize::execute, .allowZeroSizedInput = true); 111 112 } // namespace nn 113 } // namespace android 114