Home | History | Annotate | Download | only in operations
      1 /*
      2  * Copyright (C) 2019 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "CpuOperationUtils.h"
     18 #include "OperationResolver.h"
     19 
     20 #include "tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h"
     21 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
     22 
     23 #include "Tracing.h"
     24 
     25 namespace android {
     26 namespace nn {
     27 
     28 namespace softmax {
     29 
// Operation name used in log / validation-failure messages.
constexpr char kOperationName[] = "SOFTMAX";

// Input operand indices. The axis scalar (input 2) is optional: validate()
// and execute() also accept kNumInputs - 1 inputs, in which case the axis
// defaults to the last dimension.
constexpr uint32_t kNumInputs = 3;
constexpr uint32_t kInputTensor = 0;
constexpr uint32_t kBetaScalar = 1;
constexpr uint32_t kAxisScalar = 2;

// Output operand indices: a single tensor of the same shape as the input.
constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;
     39 
     40 namespace {
     41 
     42 inline bool softmaxSlowFloat32(const float* inputData, const Shape& inputShape, const float beta,
     43                                int32_t axis, float* outputData, const Shape& outputShape) {
     44     NNTRACE_TRANS("softmaxFloatSlow32");
     45     const uint32_t outerSize = getNumberOfElements(inputShape, 0, axis);
     46     const uint32_t axisSize = getSizeOfDimension(inputShape, axis);
     47     const uint32_t innerSize =
     48             getNumberOfElements(inputShape, axis + 1, getNumberOfDimensions(inputShape));
     49     for (uint32_t outer = 0; outer < outerSize; ++outer) {
     50         const float* inputBeg = inputData + outer * axisSize * innerSize;
     51         const float* inputEnd = inputBeg + axisSize * innerSize;
     52         float* outputBeg = outputData + outer * axisSize * innerSize;
     53         for (uint32_t inner = 0; inner < innerSize; ++inner, ++inputBeg, ++inputEnd, ++outputBeg) {
     54             // Find max
     55             float maxValue = -FLT_MAX;
     56             for (const float* p = inputBeg; p < inputEnd; p += innerSize) {
     57                 maxValue = std::max(maxValue, *p);
     58             }
     59             // Compute sum
     60             float sum = 0.0f;
     61             for (const float* p = inputBeg; p < inputEnd; p += innerSize) {
     62                 sum += std::exp((*p - maxValue) * beta);
     63             }
     64             // Compute result
     65             float* pOut = outputBeg;
     66             for (const float* p = inputBeg; p < inputEnd; p += innerSize, pOut += innerSize) {
     67                 *pOut = std::exp((*p - maxValue) * beta) / sum;
     68             }
     69         }
     70     }
     71     return true;
     72 }
     73 
     74 bool softmaxFloat32(const float* inputData, const Shape& inputShape, const float beta, int32_t axis,
     75                     float* outputData, const Shape& outputShape) {
     76     int32_t ndim = getNumberOfDimensions(inputShape);
     77     NN_CHECK(handleNegativeAxis(inputShape, &axis));
     78     // TFLite optimized implementation only supports computation along the last axis
     79     if (axis == ndim - 1) {
     80         NNTRACE_COMP("optimized_ops::Softmax::float");
     81         tflite::SoftmaxParams param = {.beta = beta};
     82         tflite::optimized_ops::Softmax(param, convertShapeToTflshape(inputShape), inputData,
     83                                        convertShapeToTflshape(outputShape), outputData);
     84         return true;
     85     } else {
     86         return softmaxSlowFloat32(inputData, inputShape, beta, axis, outputData, outputShape);
     87     }
     88 }
     89 
     90 bool softmaxFloat16(const _Float16* inputData, const Shape& inputShape, const float beta,
     91                     int32_t axis, _Float16* outputData, const Shape& outputShape) {
     92     NNTRACE_TRANS("softmaxFloat16");
     93     std::vector<float> inputData_float32(getNumberOfElements(inputShape));
     94     convertFloat16ToFloat32(inputData, &inputData_float32);
     95     std::vector<float> outputData_float32(getNumberOfElements(outputShape));
     96 
     97     softmaxFloat32(inputData_float32.data(), inputShape, beta, axis, outputData_float32.data(),
     98                    outputShape);
     99     convertFloat32ToFloat16(outputData_float32, outputData);
    100 
    101     return true;
    102 }
    103 
    104 bool softmaxQuant8Impl(const uint8_t* inputData, const Shape& inputShape, const float beta,
    105                        int32_t axis, int32_t inputMultiplier, int32_t inputLeftShift, float diffMin,
    106                        uint8_t* outputData, const Shape& outputShape) {
    107     NNTRACE_TRANS("softmaxQuant8");
    108     // The representation chosen for the input to the exp() function is Q5.26.
    109     // We need to leave extra space since values that we skip might be as large as
    110     // -32 before multiplying by input_beta_multiplier, and therefore as large as
    111     // -16 afterwards.  Note that exp(-8) is definitely not insignificant to
    112     // accumulation, but exp(-16) definitely is.
    113     static const int32_t kScaledDiffIntegerBits = 5;
    114     static const int kAccumulationIntegerBits = 12;
    115     using FixedPointScaledDiff = gemmlowp::FixedPoint<int32_t, kScaledDiffIntegerBits>;
    116     using FixedPointAccum = gemmlowp::FixedPoint<int32_t, kAccumulationIntegerBits>;
    117     using FixedPoint0 = gemmlowp::FixedPoint<int32_t, 0>;
    118 
    119     const uint32_t outerSize = getNumberOfElements(inputShape, 0, axis);
    120     const uint32_t axisSize = getSizeOfDimension(inputShape, axis);
    121     const uint32_t innerSize =
    122             getNumberOfElements(inputShape, axis + 1, getNumberOfDimensions(inputShape));
    123     for (uint32_t outer = 0; outer < outerSize; ++outer) {
    124         const uint8_t* inputBeg = inputData + outer * axisSize * innerSize;
    125         const uint8_t* inputEnd = inputBeg + axisSize * innerSize;
    126         uint8_t* outputBeg = outputData + outer * axisSize * innerSize;
    127         for (uint32_t inner = 0; inner < innerSize; ++inner, ++inputBeg, ++inputEnd, ++outputBeg) {
    128             // Find max
    129             uint8_t maxValue = 0;
    130             for (const uint8_t* p = inputBeg; p < inputEnd; p += innerSize) {
    131                 maxValue = std::max(maxValue, *p);
    132             }
    133 
    134             // Compute sum
    135             FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
    136             for (const uint8_t* p = inputBeg; p < inputEnd; p += innerSize) {
    137                 int32_t input_diff = static_cast<int32_t>(*p) - maxValue;
    138                 if (input_diff >= diffMin) {
    139                     const int32_t input_diff_rescaled =
    140                             tflite::MultiplyByQuantizedMultiplierGreaterThanOne(
    141                                     input_diff, inputMultiplier, inputLeftShift);
    142                     const auto scaled_diff_f8 = FixedPointScaledDiff::FromRaw(input_diff_rescaled);
    143                     sum_of_exps = sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
    144                                                         exp_on_negative_values(scaled_diff_f8));
    145                 }
    146             }
    147 
    148             uint32_t fixed_sum_of_exps = static_cast<uint32_t>(sum_of_exps.raw());
    149             int32_t headroom_plus_one = tflite::CountLeadingZeros(fixed_sum_of_exps);
    150             // This is the number of bits to the left of the binary point above 1.0.
    151             // Consider fixed_sum_of_exps=1.25.  In that case shifted_scale=0.8 and
    152             // no later adjustment will be needed.
    153             int32_t num_bits_over_unit = kAccumulationIntegerBits - headroom_plus_one;
    154             int32_t shifted_sum_minus_one = static_cast<int32_t>(
    155                     (fixed_sum_of_exps << headroom_plus_one) - (static_cast<uint32_t>(1) << 31));
    156 
    157             FixedPoint0 shifted_scale = gemmlowp::one_over_one_plus_x_for_x_in_0_1(
    158                     FixedPoint0::FromRaw(shifted_sum_minus_one));
    159 
    160             // Compute result
    161             uint8_t* pOut = outputBeg;
    162             for (const uint8_t* p = inputBeg; p < inputEnd; p += innerSize, pOut += innerSize) {
    163                 int32_t input_diff = static_cast<int32_t>(*p) - maxValue;
    164                 if (input_diff >= diffMin) {
    165                     const int32_t input_diff_rescaled =
    166                             tflite::MultiplyByQuantizedMultiplierGreaterThanOne(
    167                                     input_diff, inputMultiplier, inputLeftShift);
    168                     const auto scaled_diff_f8 = FixedPointScaledDiff::FromRaw(input_diff_rescaled);
    169 
    170                     FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
    171                     int32_t unsat_output = gemmlowp::RoundingDivideByPOT(
    172                             (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8);
    173 
    174                     *pOut = static_cast<uint8_t>(
    175                             std::max(std::min(unsat_output, static_cast<int32_t>(255)), 0));
    176 
    177                 } else {
    178                     *pOut = 0;
    179                 }
    180             }
    181         }
    182     }
    183     return true;
    184 }
    185 
    186 bool softmaxQuant8(const uint8_t* inputData, const Shape& inputShape, const float beta,
    187                    int32_t axis, uint8_t* outputData, const Shape& outputShape) {
    188     int32_t ndim = getNumberOfDimensions(inputShape);
    189     NN_CHECK(handleNegativeAxis(inputShape, &axis));
    190 
    191     if (outputShape.offset != 0 || outputShape.scale != 1.f / 256) {
    192         LOG(ERROR) << "incorrect scale / offset for output";
    193         return false;
    194     }
    195 
    196     static const int32_t kScaledDiffIntegerBits = 5;
    197     const double input_beta_real_multiplier =
    198             std::min(1.0 * beta * inputShape.scale * (1 << (31 - kScaledDiffIntegerBits)),
    199                      (1LL << 31) - 1.0);
    200 
    201     int32_t inputMultiplier = 0, inputLeftShift = 0;
    202     if (!QuantizeMultiplierGreaterThanOne(input_beta_real_multiplier, &inputMultiplier,
    203                                           &inputLeftShift)) {
    204         return false;
    205     }
    206     int32_t diffMin = -CalculateInputRadius(kScaledDiffIntegerBits, inputLeftShift);
    207 
    208     // TFLite optimized implementation only supports computation along the last axis
    209     if (axis == ndim - 1) {
    210         NNTRACE_COMP("optimized_ops::Softmax::uint8");
    211         tflite::SoftmaxParams param = {.beta = beta,
    212                                        .input_multiplier = inputMultiplier,
    213                                        .input_left_shift = inputLeftShift,
    214                                        .diff_min = diffMin};
    215         tflite::optimized_ops::Softmax(param, convertShapeToTflshape(inputShape), inputData,
    216                                        convertShapeToTflshape(outputShape), outputData);
    217         return true;
    218     } else {
    219         return softmaxQuant8Impl(inputData, inputShape, beta, axis, inputMultiplier, inputLeftShift,
    220                                  diffMin, outputData, outputShape);
    221     }
    222 }
    223 
    224 }  // namespace
    225 
    226 bool validate(const IOperationValidationContext* context) {
    227     NN_RET_CHECK(context->getNumInputs() == kNumInputs ||
    228                  context->getNumInputs() == kNumInputs - 1);
    229     NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    230     auto inputType = context->getInputType(kInputTensor);
    231     std::vector<OperandType> inExpectedTypes;
    232     if (inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_QUANT8_ASYMM) {
    233         NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0));
    234         inExpectedTypes = {inputType, OperandType::FLOAT32};
    235     } else if (inputType == OperandType::TENSOR_FLOAT16) {
    236         NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
    237         inExpectedTypes = {inputType, OperandType::FLOAT16};
    238     } else {
    239         NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
    240     }
    241     if (context->getNumInputs() == kNumInputs) {
    242         NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
    243         inExpectedTypes.push_back(OperandType::INT32);
    244     } else {
    245         const size_t ndim = context->getInputShape(kInputTensor).dimensions.size();
    246         if (ndim != 2 && ndim != 4 && ndim != 0) {
    247             NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
    248         }
    249     }
    250     return validateInputTypes(context, inExpectedTypes) &&
    251            validateOutputTypes(context, {inputType});
    252 }
    253 
    254 bool prepare(IOperationExecutionContext* context) {
    255     Shape input = context->getInputShape(kInputTensor);
    256     float beta = (input.type == OperandType::TENSOR_FLOAT16)
    257                          ? context->getInputValue<_Float16>(kBetaScalar)
    258                          : context->getInputValue<float>(kBetaScalar);
    259     NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
    260     NN_RET_CHECK_GT(beta, 0.0f);
    261     Shape output = context->getOutputShape(kOutputTensor);
    262     output.dimensions = input.dimensions;
    263     return context->setOutputShape(kOutputTensor, output);
    264 }
    265 
    266 bool execute(IOperationExecutionContext* context) {
    267     // Bypass execution in the case of zero-sized input.
    268     if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    269     int32_t axis = (context->getNumInputs() == kNumInputs)
    270                            ? context->getInputValue<int32_t>(kAxisScalar)
    271                            : -1;
    272     switch (context->getInputType(kInputTensor)) {
    273         case OperandType::TENSOR_FLOAT16:
    274             return softmaxFloat16(context->getInputBuffer<_Float16>(kInputTensor),
    275                                   context->getInputShape(kInputTensor),
    276                                   context->getInputValue<_Float16>(kBetaScalar), axis,
    277                                   context->getOutputBuffer<_Float16>(kOutputTensor),
    278                                   context->getOutputShape(kOutputTensor));
    279         case OperandType::TENSOR_FLOAT32:
    280             return softmaxFloat32(context->getInputBuffer<float>(kInputTensor),
    281                                   context->getInputShape(kInputTensor),
    282                                   context->getInputValue<float>(kBetaScalar), axis,
    283                                   context->getOutputBuffer<float>(kOutputTensor),
    284                                   context->getOutputShape(kOutputTensor));
    285         case OperandType::TENSOR_QUANT8_ASYMM:
    286             return softmaxQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
    287                                  context->getInputShape(kInputTensor),
    288                                  context->getInputValue<float>(kBetaScalar), axis,
    289                                  context->getOutputBuffer<uint8_t>(kOutputTensor),
    290                                  context->getOutputShape(kOutputTensor));
    291         default:
    292             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
    293     }
    294 }
    295 
    296 }  // namespace softmax
    297 
// Register SOFTMAX with the operation resolver. allowZeroSizedInput is set
// because execute() explicitly handles zero-element tensors as a no-op.
NN_REGISTER_OPERATION(SOFTMAX, "SOFTMAX", softmax::validate, softmax::prepare, softmax::execute,
                      .allowZeroSizedInput = true);
    300 
    301 }  // namespace nn
    302 }  // namespace android
    303