/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include "HalInterfaces.h"
#include "IndexedShapeWrapper.h"
#include "OperationResolver.h"
#include "OperationsUtils.h"
#include "Tracing.h"
#include "tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h"

namespace android {
namespace nn {
namespace prelu {

constexpr char kOperationName[] = "PRELU";

constexpr uint32_t kNumInputs = 2;
constexpr uint32_t kInputTensor = 0;
constexpr uint32_t kAlphaTensor = 1;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;

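// Applies the binary function `func` element-wise over two tensors whose shapes may
// differ but must be broadcastable. Each multi-dimensional output index is converted
// to a flat index into the output buffer and to broadcast-adjusted flat indices into
// both inputs, so values of the smaller tensor are implicitly repeated along the
// broadcast dimensions.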
template <typename T>
inline bool eval(const std::function<T(const T&, const T&)>& func, const T* aData,
                 const Shape& aShape, const T* bData, const Shape& bShape, T* outputData,
                 const Shape& outputShape) {
    IndexedShapeWrapper aShapeIndexed(aShape);
    IndexedShapeWrapper bShapeIndexed(bShape);
    IndexedShapeWrapper outputShapeIndexed(outputShape);
    std::vector<uint32_t> curIndex(outputShape.dimensions.size(), 0);
    bool lastIndex = false;
    do {
        uint32_t outputFlatIndex;
        NN_RET_CHECK(outputShapeIndexed.indexToFlatIndex(curIndex, &outputFlatIndex));
        uint32_t aFlatIndex;
        NN_RET_CHECK(aShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &aFlatIndex));
        uint32_t bFlatIndex;
        NN_RET_CHECK(bShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &bFlatIndex));

        outputData[outputFlatIndex] = func(aData[aFlatIndex], bData[bFlatIndex]);

        NN_RET_CHECK(outputShapeIndexed.nextIndexInplace(&curIndex, &lastIndex));
    } while (!lastIndex);
    return true;
}

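// Quantized (TENSOR_QUANT8_ASYMM) PReLU: out = x for x >= 0 and out = alpha * x
// otherwise, computed in integer arithmetic. Two real multipliers are precomputed and
// quantized with tflite::QuantizeMultiplier: the positive path rescales by
// inputScale / outputScale, the negative path by (inputScale * alphaScale) / outputScale,
// since the product x * alpha carries both input scales. Zero-point offsets are removed
// before the multiply, the output offset is added back, and the result is clamped to the
// uint8 range [0, 255].
// Illustrative example (values chosen for this comment only): with inputScale = 0.5,
// alphaScale = 0.25, and outputScale = 1.0, the positive multiplier quantizes 0.5 and
// the negative multiplier quantizes 0.125.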
bool evalQuant8(const uint8_t* aData, const Shape& aShape, const uint8_t* bData,
                const Shape& bShape, uint8_t* outputData, const Shape& outputShape) {
    const int32_t input_offset = -aShape.offset;
    const int32_t alpha_offset = -bShape.offset;
    const int32_t output_offset = outputShape.offset;
    const double input_product_scale = aShape.scale * bShape.scale;
    const double real_multiplier_pos = aShape.scale / outputShape.scale;
    const double real_multiplier_neg = input_product_scale / outputShape.scale;
    int32_t output_multiplier_pos, output_shift_pos;
    int32_t output_multiplier_neg, output_shift_neg;
    tflite::QuantizeMultiplier(real_multiplier_pos, &output_multiplier_pos, &output_shift_pos);
    tflite::QuantizeMultiplier(real_multiplier_neg, &output_multiplier_neg, &output_shift_neg);
    return eval<uint8_t>(
            [&](const uint8_t& val1, const uint8_t& val2) -> uint8_t {
                const int32_t input = input_offset + static_cast<int32_t>(val1);
                int32_t output_val;
                if (input >= 0) {
                    output_val =
                            output_offset + tflite::MultiplyByQuantizedMultiplier(
                                                    input, output_multiplier_pos, output_shift_pos);
                } else {
                    const int32_t alpha = alpha_offset + static_cast<int32_t>(val2);
                    output_val = output_offset +
                                 tflite::MultiplyByQuantizedMultiplier(
                                         input * alpha, output_multiplier_neg, output_shift_neg);
                }
                output_val = std::max(0, std::min(255, output_val));
                return static_cast<uint8_t>(output_val);
            },
            aData, aShape, bData, bShape, outputData, outputShape);
}

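// Validation: PRELU takes exactly two inputs (the value tensor and the alpha/slope
// tensor) and one output, all of the same tensor type, and requires HAL version 1.2
// or later.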
bool validate(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    auto inputType = context->getInputType(kInputTensor);
    NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 ||
                 inputType == OperandType::TENSOR_FLOAT32 ||
                 inputType == OperandType::TENSOR_QUANT8_ASYMM)
            << "Unsupported tensor type for operation " << kOperationName;
    NN_RET_CHECK(validateInputTypes(context, {inputType, inputType}));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    return validateHalVersion(context, HalVersion::V1_2);
}

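// Shape preparation: the output shape is the broadcast of the input and alpha shapes,
// so alpha does not have to match the input exactly (it is typically a per-channel
// slope vector that broadcasts against the input).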
bool prepare(IOperationExecutionContext* context) {
    Shape input = context->getInputShape(kInputTensor);
    Shape alpha = context->getInputShape(kAlphaTensor);
    NN_RET_CHECK(input.type == alpha.type);
    Shape output = context->getOutputShape(kOutputTensor);
    NN_RET_CHECK(calculateBroadcastedShape(input, alpha, &output));
    return context->setOutputShape(kOutputTensor, output);
}

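// Execution dispatches on the input tensor type: the float16/float32 paths apply the
// PReLU definition directly (x if x >= 0, else x * alpha) through the broadcast-aware
// eval template above; the quantized path delegates to evalQuant8.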
bool execute(IOperationExecutionContext* context) {
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return eval<_Float16>(
                    [](const _Float16& val1, const _Float16& val2) -> _Float16 {
                        return val1 >= 0.0f ? val1 : val1 * val2;
                    },
                    context->getInputBuffer<_Float16>(kInputTensor),
                    context->getInputShape(kInputTensor),
                    context->getInputBuffer<_Float16>(kAlphaTensor),
                    context->getInputShape(kAlphaTensor),
                    context->getOutputBuffer<_Float16>(kOutputTensor),
                    context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return eval<float>(
                    [](const float& val1, const float& val2) -> float {
                        return val1 >= 0.0f ? val1 : val1 * val2;
                    },
                    context->getInputBuffer<float>(kInputTensor),
                    context->getInputShape(kInputTensor),
                    context->getInputBuffer<float>(kAlphaTensor),
                    context->getInputShape(kAlphaTensor),
                    context->getOutputBuffer<float>(kOutputTensor),
                    context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM: {
            return evalQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getInputBuffer<uint8_t>(kAlphaTensor),
                              context->getInputShape(kAlphaTensor),
                              context->getOutputBuffer<uint8_t>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        }
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
    }
}

}  // namespace prelu

NN_REGISTER_OPERATION(PRELU, prelu::kOperationName, prelu::validate, prelu::prepare,
                      prelu::execute);

}  // namespace nn
}  // namespace android