1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #define LOG_TAG "Operations" 18 19 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" 20 21 #include "HalInterfaces.h" 22 #include "OperationResolver.h" 23 #include "OperationsUtils.h" 24 #include "Tracing.h" 25 26 namespace android { 27 namespace nn { 28 namespace reduce { 29 30 constexpr uint32_t kNumInputs = 3; 31 constexpr uint32_t kInputTensor = 0; 32 constexpr uint32_t kInputAxes = 1; 33 constexpr uint32_t kInputKeepDims = 2; 34 35 constexpr uint32_t kNumOutputs = 1; 36 constexpr uint32_t kOutputTensor = 0; 37 38 // Values from 39 // https://en.wikipedia.org/wiki/Half-precision_floating-point_format#IEEE_754_half-precision_binary_floating-point_format:_binary16 40 constexpr _Float16 kFloat16Max = 65504; 41 constexpr _Float16 kFloat16Lowest = -kFloat16Max; 42 43 namespace { 44 45 template <typename T> 46 inline bool compute(IOperationExecutionContext* context, T init, T func(T, T)) { 47 const Shape inputShape = context->getInputShape(kInputTensor); 48 const Shape axesShape = context->getInputShape(kInputAxes); 49 const Shape outputShape = context->getOutputShape(kOutputTensor); 50 const uint32_t inputRank = getNumberOfDimensions(inputShape); 51 const uint32_t numAxes = getNumberOfElements(axesShape); 52 std::vector<int> tempIndex(inputShape.dimensions.size()); 53 std::vector<int> tempAxes(numAxes); 54 return tflite::reference_ops::ReduceGeneric<T>( 55 context->getInputBuffer<T>(kInputTensor), 56 reinterpret_cast<const int32_t*>(inputShape.dimensions.data()), inputRank, 57 context->getOutputBuffer<T>(kOutputTensor), 58 reinterpret_cast<const int32_t*>(outputShape.dimensions.data()), 59 outputShape.dimensions.size(), context->getInputBuffer<int32_t>(kInputAxes), numAxes, 60 context->getInputValue<bool8>(kInputKeepDims), tempIndex.data(), tempAxes.data(), init, 61 func); 62 } 63 64 } // namespace 65 66 bool validateProdSum(const IOperationValidationContext* context) { 67 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); 68 NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); 69 OperandType inputType = context->getInputType(kInputTensor); 70 NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 || 71 inputType == OperandType::TENSOR_FLOAT32) 72 << "Unsupported tensor type for REDUCE_PROD or REDUCE_SUM"; 73 NN_RET_CHECK( 74 validateInputTypes(context, {inputType, OperandType::TENSOR_INT32, OperandType::BOOL})); 75 NN_RET_CHECK(validateOutputTypes(context, {inputType})); 76 return validateHalVersion(context, HalVersion::V1_2); 77 } 78 79 bool validateMaxMin(const IOperationValidationContext* context) { 80 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); 81 NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); 82 OperandType inputType = context->getInputType(kInputTensor); 83 NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 || 84 inputType == OperandType::TENSOR_FLOAT32 || 85 inputType == OperandType::TENSOR_QUANT8_ASYMM) 86 << "Unsupported tensor type for REDUCE_MAX or REDUCE_MIN"; 87 NN_RET_CHECK( 88 validateInputTypes(context, {inputType, OperandType::TENSOR_INT32, OperandType::BOOL})); 89 NN_RET_CHECK(validateOutputTypes(context, {inputType})); 90 return validateHalVersion(context, HalVersion::V1_2); 91 } 92 93 bool validateLogical(const IOperationValidationContext* context) { 94 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); 95 NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); 96 OperandType inputType = context->getInputType(kInputTensor); 97 NN_RET_CHECK(inputType == OperandType::TENSOR_BOOL8) 98 << "Unsupported tensor type for REDUCE_ANY or REDUCE_ALL"; 99 NN_RET_CHECK( 100 validateInputTypes(context, {inputType, OperandType::TENSOR_INT32, OperandType::BOOL})); 101 NN_RET_CHECK(validateOutputTypes(context, {inputType})); 102 return validateHalVersion(context, HalVersion::V1_2); 103 } 104 105 bool prepare(IOperationExecutionContext* context) { 106 Shape inputShape = context->getInputShape(kInputTensor); 107 const uint32_t inputRank = getNumberOfDimensions(inputShape); 108 109 std::vector<bool> shouldReduce(inputRank); 110 const int32_t* axes = context->getInputBuffer<int32_t>(kInputAxes); 111 Shape axesShape = context->getInputShape(kInputAxes); 112 NN_RET_CHECK_EQ(getNumberOfDimensions(axesShape), 1u); 113 const uint32_t numAxes = getNumberOfElements(axesShape); 114 for (uint32_t i = 0; i < numAxes; ++i) { 115 int32_t axis = axes[i]; 116 NN_RET_CHECK(handleNegativeAxis(inputRank, &axis)); 117 shouldReduce[axis] = true; 118 } 119 120 // Input and output must have the same quantization parameters, etc. 121 Shape outputShape = inputShape; 122 outputShape.dimensions.clear(); 123 bool keepDims = context->getInputValue<bool8>(kInputKeepDims); 124 for (uint32_t axis = 0; axis < inputRank; ++axis) { 125 if (shouldReduce[axis]) { 126 if (keepDims) { 127 outputShape.dimensions.push_back(1); 128 } 129 } else { 130 outputShape.dimensions.push_back(getSizeOfDimension(inputShape, axis)); 131 } 132 } 133 134 return context->setOutputShape(kOutputTensor, outputShape); 135 } 136 137 bool executeProd(IOperationExecutionContext* context) { 138 switch (context->getInputType(kInputTensor)) { 139 case OperandType::TENSOR_FLOAT16: 140 return compute<_Float16>(context, 1, [](_Float16 a, _Float16 b) { return a * b; }); 141 case OperandType::TENSOR_FLOAT32: 142 return compute<float>(context, 1, [](float a, float b) { return a * b; }); 143 default: 144 NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_PROD"; 145 } 146 } 147 148 bool executeSum(IOperationExecutionContext* context) { 149 switch (context->getInputType(kInputTensor)) { 150 case OperandType::TENSOR_FLOAT16: 151 return compute<_Float16>(context, 0, [](_Float16 a, _Float16 b) { return a + b; }); 152 case OperandType::TENSOR_FLOAT32: 153 return compute<float>(context, 0, [](float a, float b) { return a + b; }); 154 default: 155 NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_SUM"; 156 } 157 } 158 159 bool executeMax(IOperationExecutionContext* context) { 160 switch (context->getInputType(kInputTensor)) { 161 case OperandType::TENSOR_FLOAT16: 162 return compute<_Float16>(context, kFloat16Lowest, 163 [](_Float16 a, _Float16 b) { return std::max(a, b); }); 164 case OperandType::TENSOR_FLOAT32: 165 return compute<float>(context, std::numeric_limits<float>::lowest(), 166 [](float a, float b) { return std::max(a, b); }); 167 case OperandType::TENSOR_QUANT8_ASYMM: 168 return compute<uint8_t>(context, std::numeric_limits<uint8_t>::lowest(), 169 [](uint8_t a, uint8_t b) { return std::max(a, b); }); 170 default: 171 NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_MAX"; 172 } 173 } 174 175 bool executeMin(IOperationExecutionContext* context) { 176 switch (context->getInputType(kInputTensor)) { 177 case OperandType::TENSOR_FLOAT16: 178 return compute<_Float16>(context, kFloat16Max, 179 [](_Float16 a, _Float16 b) { return std::min(a, b); }); 180 case OperandType::TENSOR_FLOAT32: 181 return compute<float>(context, std::numeric_limits<float>::max(), 182 [](float a, float b) { return std::min(a, b); }); 183 case OperandType::TENSOR_QUANT8_ASYMM: 184 return compute<uint8_t>(context, std::numeric_limits<uint8_t>::max(), 185 [](uint8_t a, uint8_t b) { return std::min(a, b); }); 186 default: 187 NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_MIN"; 188 } 189 } 190 191 bool executeAny(IOperationExecutionContext* context) { 192 switch (context->getInputType(kInputTensor)) { 193 case OperandType::TENSOR_BOOL8: 194 return compute<bool8>(context, false, 195 [](bool8 a, bool8 b) { return static_cast<bool8>(a || b); }); 196 default: 197 NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_ANY"; 198 } 199 } 200 201 bool executeAll(IOperationExecutionContext* context) { 202 switch (context->getInputType(kInputTensor)) { 203 case OperandType::TENSOR_BOOL8: 204 return compute<bool8>(context, true, 205 [](bool8 a, bool8 b) { return static_cast<bool8>(a && b); }); 206 default: 207 NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_ALL"; 208 } 209 } 210 211 } // namespace reduce 212 213 NN_REGISTER_OPERATION(REDUCE_PROD, "REDUCE_PROD", reduce::validateProdSum, reduce::prepare, 214 reduce::executeProd); 215 NN_REGISTER_OPERATION(REDUCE_SUM, "REDUCE_SUM", reduce::validateProdSum, reduce::prepare, 216 reduce::executeSum); 217 NN_REGISTER_OPERATION(REDUCE_MAX, "REDUCE_MAX", reduce::validateMaxMin, reduce::prepare, 218 reduce::executeMax); 219 NN_REGISTER_OPERATION(REDUCE_MIN, "REDUCE_MIN", reduce::validateMaxMin, reduce::prepare, 220 reduce::executeMin); 221 NN_REGISTER_OPERATION(REDUCE_ANY, "REDUCE_ANY", reduce::validateLogical, reduce::prepare, 222 reduce::executeAny); 223 NN_REGISTER_OPERATION(REDUCE_ALL, "REDUCE_ALL", reduce::validateLogical, reduce::prepare, 224 reduce::executeAll); 225 226 } // namespace nn 227 } // namespace android 228