Home | History | Annotate | Download | only in operations
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #define LOG_TAG "Operations"
     18 
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"

#include "HalInterfaces.h"
#include "OperationResolver.h"
#include "OperationsUtils.h"
#include "Tracing.h"
     25 
     26 namespace android {
     27 namespace nn {
     28 namespace reduce {
     29 
     30 constexpr uint32_t kNumInputs = 3;
     31 constexpr uint32_t kInputTensor = 0;
     32 constexpr uint32_t kInputAxes = 1;
     33 constexpr uint32_t kInputKeepDims = 2;
     34 
     35 constexpr uint32_t kNumOutputs = 1;
     36 constexpr uint32_t kOutputTensor = 0;
     37 
     38 // Values from
     39 // https://en.wikipedia.org/wiki/Half-precision_floating-point_format#IEEE_754_half-precision_binary_floating-point_format:_binary16
     40 constexpr _Float16 kFloat16Max = 65504;
     41 constexpr _Float16 kFloat16Lowest = -kFloat16Max;
     42 
     43 namespace {
     44 
     45 template <typename T>
     46 inline bool compute(IOperationExecutionContext* context, T init, T func(T, T)) {
     47     const Shape inputShape = context->getInputShape(kInputTensor);
     48     const Shape axesShape = context->getInputShape(kInputAxes);
     49     const Shape outputShape = context->getOutputShape(kOutputTensor);
     50     const uint32_t inputRank = getNumberOfDimensions(inputShape);
     51     const uint32_t numAxes = getNumberOfElements(axesShape);
     52     std::vector<int> tempIndex(inputShape.dimensions.size());
     53     std::vector<int> tempAxes(numAxes);
     54     return tflite::reference_ops::ReduceGeneric<T>(
     55             context->getInputBuffer<T>(kInputTensor),
     56             reinterpret_cast<const int32_t*>(inputShape.dimensions.data()), inputRank,
     57             context->getOutputBuffer<T>(kOutputTensor),
     58             reinterpret_cast<const int32_t*>(outputShape.dimensions.data()),
     59             outputShape.dimensions.size(), context->getInputBuffer<int32_t>(kInputAxes), numAxes,
     60             context->getInputValue<bool8>(kInputKeepDims), tempIndex.data(), tempAxes.data(), init,
     61             func);
     62 }
     63 
     64 }  // namespace
     65 
     66 bool validateProdSum(const IOperationValidationContext* context) {
     67     NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
     68     NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
     69     OperandType inputType = context->getInputType(kInputTensor);
     70     NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 ||
     71                  inputType == OperandType::TENSOR_FLOAT32)
     72             << "Unsupported tensor type for REDUCE_PROD or REDUCE_SUM";
     73     NN_RET_CHECK(
     74             validateInputTypes(context, {inputType, OperandType::TENSOR_INT32, OperandType::BOOL}));
     75     NN_RET_CHECK(validateOutputTypes(context, {inputType}));
     76     return validateHalVersion(context, HalVersion::V1_2);
     77 }
     78 
     79 bool validateMaxMin(const IOperationValidationContext* context) {
     80     NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
     81     NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
     82     OperandType inputType = context->getInputType(kInputTensor);
     83     NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 ||
     84                  inputType == OperandType::TENSOR_FLOAT32 ||
     85                  inputType == OperandType::TENSOR_QUANT8_ASYMM)
     86             << "Unsupported tensor type for REDUCE_MAX or REDUCE_MIN";
     87     NN_RET_CHECK(
     88             validateInputTypes(context, {inputType, OperandType::TENSOR_INT32, OperandType::BOOL}));
     89     NN_RET_CHECK(validateOutputTypes(context, {inputType}));
     90     return validateHalVersion(context, HalVersion::V1_2);
     91 }
     92 
     93 bool validateLogical(const IOperationValidationContext* context) {
     94     NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
     95     NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
     96     OperandType inputType = context->getInputType(kInputTensor);
     97     NN_RET_CHECK(inputType == OperandType::TENSOR_BOOL8)
     98             << "Unsupported tensor type for REDUCE_ANY or REDUCE_ALL";
     99     NN_RET_CHECK(
    100             validateInputTypes(context, {inputType, OperandType::TENSOR_INT32, OperandType::BOOL}));
    101     NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    102     return validateHalVersion(context, HalVersion::V1_2);
    103 }
    104 
    105 bool prepare(IOperationExecutionContext* context) {
    106     Shape inputShape = context->getInputShape(kInputTensor);
    107     const uint32_t inputRank = getNumberOfDimensions(inputShape);
    108 
    109     std::vector<bool> shouldReduce(inputRank);
    110     const int32_t* axes = context->getInputBuffer<int32_t>(kInputAxes);
    111     Shape axesShape = context->getInputShape(kInputAxes);
    112     NN_RET_CHECK_EQ(getNumberOfDimensions(axesShape), 1u);
    113     const uint32_t numAxes = getNumberOfElements(axesShape);
    114     for (uint32_t i = 0; i < numAxes; ++i) {
    115         int32_t axis = axes[i];
    116         NN_RET_CHECK(handleNegativeAxis(inputRank, &axis));
    117         shouldReduce[axis] = true;
    118     }
    119 
    120     // Input and output must have the same quantization parameters, etc.
    121     Shape outputShape = inputShape;
    122     outputShape.dimensions.clear();
    123     bool keepDims = context->getInputValue<bool8>(kInputKeepDims);
    124     for (uint32_t axis = 0; axis < inputRank; ++axis) {
    125         if (shouldReduce[axis]) {
    126             if (keepDims) {
    127                 outputShape.dimensions.push_back(1);
    128             }
    129         } else {
    130             outputShape.dimensions.push_back(getSizeOfDimension(inputShape, axis));
    131         }
    132     }
    133 
    134     return context->setOutputShape(kOutputTensor, outputShape);
    135 }
    136 
    137 bool executeProd(IOperationExecutionContext* context) {
    138     switch (context->getInputType(kInputTensor)) {
    139         case OperandType::TENSOR_FLOAT16:
    140             return compute<_Float16>(context, 1, [](_Float16 a, _Float16 b) { return a * b; });
    141         case OperandType::TENSOR_FLOAT32:
    142             return compute<float>(context, 1, [](float a, float b) { return a * b; });
    143         default:
    144             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_PROD";
    145     }
    146 }
    147 
    148 bool executeSum(IOperationExecutionContext* context) {
    149     switch (context->getInputType(kInputTensor)) {
    150         case OperandType::TENSOR_FLOAT16:
    151             return compute<_Float16>(context, 0, [](_Float16 a, _Float16 b) { return a + b; });
    152         case OperandType::TENSOR_FLOAT32:
    153             return compute<float>(context, 0, [](float a, float b) { return a + b; });
    154         default:
    155             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_SUM";
    156     }
    157 }
    158 
    159 bool executeMax(IOperationExecutionContext* context) {
    160     switch (context->getInputType(kInputTensor)) {
    161         case OperandType::TENSOR_FLOAT16:
    162             return compute<_Float16>(context, kFloat16Lowest,
    163                                      [](_Float16 a, _Float16 b) { return std::max(a, b); });
    164         case OperandType::TENSOR_FLOAT32:
    165             return compute<float>(context, std::numeric_limits<float>::lowest(),
    166                                   [](float a, float b) { return std::max(a, b); });
    167         case OperandType::TENSOR_QUANT8_ASYMM:
    168             return compute<uint8_t>(context, std::numeric_limits<uint8_t>::lowest(),
    169                                     [](uint8_t a, uint8_t b) { return std::max(a, b); });
    170         default:
    171             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_MAX";
    172     }
    173 }
    174 
    175 bool executeMin(IOperationExecutionContext* context) {
    176     switch (context->getInputType(kInputTensor)) {
    177         case OperandType::TENSOR_FLOAT16:
    178             return compute<_Float16>(context, kFloat16Max,
    179                                      [](_Float16 a, _Float16 b) { return std::min(a, b); });
    180         case OperandType::TENSOR_FLOAT32:
    181             return compute<float>(context, std::numeric_limits<float>::max(),
    182                                   [](float a, float b) { return std::min(a, b); });
    183         case OperandType::TENSOR_QUANT8_ASYMM:
    184             return compute<uint8_t>(context, std::numeric_limits<uint8_t>::max(),
    185                                     [](uint8_t a, uint8_t b) { return std::min(a, b); });
    186         default:
    187             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_MIN";
    188     }
    189 }
    190 
    191 bool executeAny(IOperationExecutionContext* context) {
    192     switch (context->getInputType(kInputTensor)) {
    193         case OperandType::TENSOR_BOOL8:
    194             return compute<bool8>(context, false,
    195                                   [](bool8 a, bool8 b) { return static_cast<bool8>(a || b); });
    196         default:
    197             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_ANY";
    198     }
    199 }
    200 
    201 bool executeAll(IOperationExecutionContext* context) {
    202     switch (context->getInputType(kInputTensor)) {
    203         case OperandType::TENSOR_BOOL8:
    204             return compute<bool8>(context, true,
    205                                   [](bool8 a, bool8 b) { return static_cast<bool8>(a && b); });
    206         default:
    207             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_ALL";
    208     }
    209 }
    210 
    211 }  // namespace reduce
    212 
    213 NN_REGISTER_OPERATION(REDUCE_PROD, "REDUCE_PROD", reduce::validateProdSum, reduce::prepare,
    214                       reduce::executeProd);
    215 NN_REGISTER_OPERATION(REDUCE_SUM, "REDUCE_SUM", reduce::validateProdSum, reduce::prepare,
    216                       reduce::executeSum);
    217 NN_REGISTER_OPERATION(REDUCE_MAX, "REDUCE_MAX", reduce::validateMaxMin, reduce::prepare,
    218                       reduce::executeMax);
    219 NN_REGISTER_OPERATION(REDUCE_MIN, "REDUCE_MIN", reduce::validateMaxMin, reduce::prepare,
    220                       reduce::executeMin);
    221 NN_REGISTER_OPERATION(REDUCE_ANY, "REDUCE_ANY", reduce::validateLogical, reduce::prepare,
    222                       reduce::executeAny);
    223 NN_REGISTER_OPERATION(REDUCE_ALL, "REDUCE_ALL", reduce::validateLogical, reduce::prepare,
    224                       reduce::executeAll);
    225 
    226 }  // namespace nn
    227 }  // namespace android
    228