1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #define LOG_TAG "Operations" 18 19 #include "HalInterfaces.h" 20 #include "IndexedShapeWrapper.h" 21 #include "OperationResolver.h" 22 #include "OperationsUtils.h" 23 #include "Tracing.h" 24 #include "tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h" 25 26 namespace android { 27 namespace nn { 28 namespace prelu { 29 30 constexpr char kOperationName[] = "PRELU"; 31 32 constexpr uint32_t kNumInputs = 2; 33 constexpr uint32_t kInputTensor = 0; 34 constexpr uint32_t kAlphaTensor = 1; 35 36 constexpr uint32_t kNumOutputs = 1; 37 constexpr uint32_t kOutputTensor = 0; 38 39 template <typename T> 40 inline bool eval(const std::function<T(const T&, const T&)>& func, const T* aData, 41 const Shape& aShape, const T* bData, const Shape& bShape, T* outputData, 42 const Shape& outputShape) { 43 IndexedShapeWrapper aShapeIndexed(aShape); 44 IndexedShapeWrapper bShapeIndexed(bShape); 45 IndexedShapeWrapper outputShapeIndexed(outputShape); 46 std::vector<uint32_t> curIndex(outputShape.dimensions.size(), 0); 47 bool lastIndex = false; 48 do { 49 uint32_t outputFlatIndex; 50 NN_RET_CHECK(outputShapeIndexed.indexToFlatIndex(curIndex, &outputFlatIndex)); 51 uint32_t aFlatIndex; 52 NN_RET_CHECK(aShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &aFlatIndex)); 53 uint32_t bFlatIndex; 54 NN_RET_CHECK(bShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &bFlatIndex)); 55 56 outputData[outputFlatIndex] = func(aData[aFlatIndex], bData[bFlatIndex]); 57 58 NN_RET_CHECK(outputShapeIndexed.nextIndexInplace(&curIndex, &lastIndex)); 59 } while (!lastIndex); 60 return true; 61 } 62 63 bool evalQuant8(const uint8_t* aData, const Shape& aShape, const uint8_t* bData, 64 const Shape& bShape, uint8_t* outputData, const Shape& outputShape) { 65 const int32_t input_offset = -aShape.offset; 66 const int32_t alpha_offset = -bShape.offset; 67 const int32_t output_offset = outputShape.offset; 68 const double input_product_scale = aShape.scale * bShape.scale; 69 const double real_multiplier_pos = aShape.scale / outputShape.scale; 70 const double real_multiplier_neg = input_product_scale / outputShape.scale; 71 int32_t output_multiplier_pos, output_shift_pos; 72 int32_t output_multiplier_neg, output_shift_neg; 73 tflite::QuantizeMultiplier(real_multiplier_pos, &output_multiplier_pos, &output_shift_pos); 74 tflite::QuantizeMultiplier(real_multiplier_neg, &output_multiplier_neg, &output_shift_neg); 75 return eval<uint8_t>( 76 [&](const uint8_t& val1, const uint8_t& val2) -> uint8_t { 77 const int32_t input = input_offset + static_cast<int32_t>(val1); 78 int32_t output_val; 79 if (input >= 0) { 80 output_val = 81 output_offset + tflite::MultiplyByQuantizedMultiplier( 82 input, output_multiplier_pos, output_shift_pos); 83 } else { 84 const int32_t alpha = alpha_offset + static_cast<int32_t>(val2); 85 output_val = output_offset + 86 tflite::MultiplyByQuantizedMultiplier( 87 input * alpha, output_multiplier_neg, output_shift_neg); 88 } 89 output_val = std::max(0, std::min(255, output_val)); 90 return static_cast<uint8_t>(output_val); 91 }, 92 aData, aShape, bData, bShape, outputData, outputShape); 93 } 94 95 bool validate(const IOperationValidationContext* context) { 96 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); 97 NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); 98 auto inputType = context->getInputType(kInputTensor); 99 NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 || 100 inputType == OperandType::TENSOR_FLOAT32 || 101 inputType == OperandType::TENSOR_QUANT8_ASYMM) 102 << "Unsupported tensor type for operation " << kOperationName; 103 NN_RET_CHECK(validateInputTypes(context, {inputType, inputType})); 104 NN_RET_CHECK(validateOutputTypes(context, {inputType})); 105 return validateHalVersion(context, HalVersion::V1_2); 106 } 107 108 bool prepare(IOperationExecutionContext* context) { 109 Shape input = context->getInputShape(kInputTensor); 110 Shape alpha = context->getInputShape(kAlphaTensor); 111 NN_RET_CHECK(input.type == alpha.type); 112 Shape output = context->getOutputShape(kOutputTensor); 113 NN_RET_CHECK(calculateBroadcastedShape(input, alpha, &output)); 114 return context->setOutputShape(kOutputTensor, output); 115 } 116 117 bool execute(IOperationExecutionContext* context) { 118 switch (context->getInputType(kInputTensor)) { 119 case OperandType::TENSOR_FLOAT16: 120 return eval<_Float16>( 121 [](const _Float16& val1, const _Float16& val2) -> _Float16 { 122 return val1 >= 0.0f ? val1 : val1 * val2; 123 }, 124 context->getInputBuffer<_Float16>(kInputTensor), 125 context->getInputShape(kInputTensor), 126 context->getInputBuffer<_Float16>(kAlphaTensor), 127 context->getInputShape(kAlphaTensor), 128 context->getOutputBuffer<_Float16>(kOutputTensor), 129 context->getOutputShape(kOutputTensor)); 130 case OperandType::TENSOR_FLOAT32: 131 return eval<float>( 132 [](const float& val1, const float& val2) -> float { 133 return val1 >= 0.0f ? val1 : val1 * val2; 134 }, 135 context->getInputBuffer<float>(kInputTensor), 136 context->getInputShape(kInputTensor), 137 context->getInputBuffer<float>(kAlphaTensor), 138 context->getInputShape(kAlphaTensor), 139 context->getOutputBuffer<float>(kOutputTensor), 140 context->getOutputShape(kOutputTensor)); 141 case OperandType::TENSOR_QUANT8_ASYMM: { 142 return evalQuant8(context->getInputBuffer<uint8_t>(kInputTensor), 143 context->getInputShape(kInputTensor), 144 context->getInputBuffer<uint8_t>(kAlphaTensor), 145 context->getInputShape(kAlphaTensor), 146 context->getOutputBuffer<uint8_t>(kOutputTensor), 147 context->getOutputShape(kOutputTensor)); 148 } 149 default: 150 NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName; 151 } 152 } 153 154 } // namespace prelu 155 156 NN_REGISTER_OPERATION(PRELU, prelu::kOperationName, prelu::validate, prelu::prepare, 157 prelu::execute); 158 159 } // namespace nn 160 } // namespace android 161