      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include <cassert>
     16 #include <cmath>
     17 #include <cstdio>
     18 #include <cstdlib>
     19 #include <iostream>
     20 #include <limits>
     21 
     22 #include "tensorflow/lite/c/builtin_op_data.h"
     23 #include "tensorflow/lite/c/c_api_internal.h"
     24 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
     25 #include "tensorflow/lite/kernels/internal/quantization_util.h"
     26 #include "tensorflow/lite/kernels/internal/reference/integer_ops/log_softmax.h"
     27 #include "tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h"
     28 #include "tensorflow/lite/kernels/internal/reference/integer_ops/softmax.h"
     29 #include "tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h"
     30 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
     31 #include "tensorflow/lite/kernels/internal/tensor.h"
     32 #include "tensorflow/lite/kernels/kernel_util.h"
     33 #include "tensorflow/lite/kernels/op_macros.h"
     34 
     35 namespace tflite {
     36 namespace ops {
     37 namespace builtin {
     38 namespace activations {
     39 
     40 enum KernelType {
     41   kReference,
     42   kGenericOptimized,
     43 };
     44 
     45 struct OpData {
     46   int32_t input_multiplier = 0;
     47   int input_left_shift = 0;
     48   int32_t input_range_radius = 0;
     49   int diff_min = 0;
     50 };
     51 
     52 struct LogSoftmaxOpData : public OpData {
     53   int32_t reverse_scaling_divisor = 0;
     54   int32_t reverse_scaling_right_shift = 0;
     55 };
     56 
     57 struct PreluOpData : public OpData {
     58   int32_t output_multiplier = 0;
     59   int output_shift = 0;
     60 };
     61 
     62 namespace {
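         // Softmax and Logistic produce outputs in [0, 1), so a fixed output scale of
         // 1/256 with the zero point at the type's minimum lets the 8-bit grid cover
         // that range at full resolution.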
     63 TfLiteStatus CheckOutputQuantParams(TfLiteContext* context,
     64                                     const TfLiteTensor* input,
     65                                     const TfLiteTensor* output) {
     66   TF_LITE_ENSURE(context, output->params.scale == 1. / 256);
     67   if (input->type == kTfLiteUInt8) {
     68     TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
     69   } else {
     70     TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128);
     71   }
     72   return kTfLiteOk;
     73 }
     74 }  // namespace
     75 
     76 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
     77   // This is a builtin op, so we don't use the contents in 'buffer', if any.
     78   // Instead, we allocate a new object to carry information from Prepare() to
     79   // Eval().
     80   return new OpData;
     81 }
     82 
     83 void* LogSoftmaxInit(TfLiteContext* context, const char* buffer,
     84                      size_t length) {
     85   return new LogSoftmaxOpData;
     86 }
     87 
     88 void* PreluInit(TfLiteContext* context, const char* buffer, size_t length) {
     89   return new PreluOpData;
     90 }
     91 
     92 void Free(TfLiteContext* context, void* buffer) {
     93   delete reinterpret_cast<OpData*>(buffer);
     94 }
     95 
     96 void LogSoftmaxFree(TfLiteContext* context, void* buffer) {
     97   delete reinterpret_cast<LogSoftmaxOpData*>(buffer);
     98 }
     99 
    100 void PreluFree(TfLiteContext* context, void* buffer) {
    101   delete reinterpret_cast<PreluOpData*>(buffer);
    102 }
    103 
    104 TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
    105   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
    106   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
    107   const TfLiteTensor* input = GetInput(context, node, 0);
    108   TfLiteTensor* output = GetOutput(context, node, 0);
    109   TF_LITE_ENSURE_EQ(context, input->type, output->type);
    110 
    111   return context->ResizeTensor(context, output,
    112                                TfLiteIntArrayCopy(input->dims));
    113 }
    114 
    115 TfLiteStatus TanhPrepare(TfLiteContext* context, TfLiteNode* node) {
    116   OpData* data = reinterpret_cast<OpData*>(node->user_data);
    117 
    118   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
    119   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
    120   const TfLiteTensor* input = GetInput(context, node, 0);
    121   TfLiteTensor* output = GetOutput(context, node, 0);
    122   TF_LITE_ENSURE_EQ(context, input->type, output->type);
    123 
    124   if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
    125     static constexpr int kInputIntegerBits = 4;
    126 
    127     const double input_real_multiplier =
    128         input->params.scale *
    129         static_cast<double>(1 << (31 - kInputIntegerBits));
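             // The combined multiplier maps (quantized input - zero_point) into the
             // fixed-point format with kInputIntegerBits integer bits used by the
             // fixed-point tanh; with 4 integer bits the factor includes 2^(31 - 4) = 2^27.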
    130 
    131     QuantizeMultiplierGreaterThanOne(input_real_multiplier,
    132                                      &data->input_multiplier,
    133                                      &data->input_left_shift);
    134     data->input_range_radius =
    135         CalculateInputRadius(kInputIntegerBits, data->input_left_shift);
    136   } else if (input->type == kTfLiteInt16) {
    137     static constexpr int kInputIntegerBits = 3;
    138     static constexpr int kOutputFractionalBits = 15;
    139 
    140     // These operators are implemented in fixed-point arithmetic,
    141     // which intrinsically wants symmetric ranges (zero_point==0)
    142     // and power-of-two scales (power-of-two is abbreviated below as POT).
    143     // While more general support would be possible by means of rescaling,
     144     // that would add some overhead and some loss of accuracy, and it is not
     145     // needed at the moment since current quantized LSTM applications are
     146     // happy with symmetric, power-of-two-scale quantization. So we only
     147     // implement that narrow case for now.
    148 
    149     TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
    150     TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
    151 
    152     int input_scale_log2_rounded;
    153     TF_LITE_ENSURE(context,
    154                    CheckedLog2(input->params.scale, &input_scale_log2_rounded));
    155 
    156     int output_scale_log2_rounded;
    157     TF_LITE_ENSURE(
    158         context, CheckedLog2(output->params.scale, &output_scale_log2_rounded));
    159     TF_LITE_ENSURE_EQ(context, output_scale_log2_rounded,
    160                       -kOutputFractionalBits);
    161 
    162     data->input_left_shift =
    163         (15 - kInputIntegerBits) + input_scale_log2_rounded;
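             // e.g. an int16 input in Q3.12 format has scale 2^-12, so
             // input_scale_log2_rounded == -12 and the shift is (15 - 3) - 12 == 0.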
    164     // Support for shifts is limited until we have a parameterized version of
    165     // SaturatingRoundingMultiplyByPOT().
    166     TF_LITE_ENSURE(context, data->input_left_shift >= 0);
    167     TF_LITE_ENSURE(context, data->input_left_shift <= 1);
    168   }
    169 
    170   return context->ResizeTensor(context, output,
    171                                TfLiteIntArrayCopy(input->dims));
    172 }
    173 
    174 TfLiteStatus SigmoidPrepare(TfLiteContext* context, TfLiteNode* node) {
    175   OpData* data = reinterpret_cast<OpData*>(node->user_data);
    176 
    177   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
    178   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
    179   const TfLiteTensor* input = GetInput(context, node, 0);
    180   TfLiteTensor* output = GetOutput(context, node, 0);
    181   TF_LITE_ENSURE_EQ(context, input->type, output->type);
    182 
    183   if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
    184     if (input->type == kTfLiteUInt8) {
    185       TF_LITE_ENSURE_EQ(context, output->params.zero_point,
    186                         std::numeric_limits<uint8_t>::min());
    187     }
    188     if (input->type == kTfLiteInt8) {
    189       TF_LITE_ENSURE_EQ(context, output->params.zero_point,
    190                         std::numeric_limits<int8_t>::min());
    191     }
    192     TF_LITE_ENSURE(context, output->params.scale == 1. / 256);
    193 
    194     static constexpr int kInputIntegerBits = 4;
    195 
    196     const double input_real_multiplier =
    197         input->params.scale *
    198         static_cast<double>(1 << (31 - kInputIntegerBits));
    199 
    200     QuantizeMultiplierGreaterThanOne(input_real_multiplier,
    201                                      &data->input_multiplier,
    202                                      &data->input_left_shift);
    203     data->input_range_radius =
    204         CalculateInputRadius(kInputIntegerBits, data->input_left_shift);
    205   } else if (input->type == kTfLiteInt16) {
    206     static constexpr int kInputIntegerBits = 3;
    207     static constexpr int kOutputFractionalBits = 15;
    208 
    209     // See comments in TanhPrepare about requiring zero_point==0
    210     // and a power-of-two ("POT") scale.
    211 
    212     TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
    213     TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
    214 
    215     int input_scale_log2_rounded;
    216     TF_LITE_ENSURE(context,
    217                    CheckedLog2(input->params.scale, &input_scale_log2_rounded));
    218 
    219     int output_scale_log2_rounded;
    220     TF_LITE_ENSURE(
    221         context, CheckedLog2(output->params.scale, &output_scale_log2_rounded));
    222     TF_LITE_ENSURE_EQ(context, output_scale_log2_rounded,
    223                       -kOutputFractionalBits);
    224 
    225     data->input_left_shift =
    226         (15 - kInputIntegerBits) + input_scale_log2_rounded;
    227     // The int16 logistic implementation does not support shifting of the input.
    228     TF_LITE_ENSURE_EQ(context, data->input_left_shift, 0);
    229   }
    230 
    231   return context->ResizeTensor(context, output,
    232                                TfLiteIntArrayCopy(input->dims));
    233 }
    234 
    235 TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
    236   auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
    237   OpData* data = reinterpret_cast<OpData*>(node->user_data);
    238 
    239   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
    240   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
    241   const TfLiteTensor* input = GetInput(context, node, 0);
    242   TfLiteTensor* output = GetOutput(context, node, 0);
    243   TF_LITE_ENSURE_EQ(context, input->type, output->type);
    244 
    245   const int num_dims = NumDimensions(input);
    246   TF_LITE_ENSURE(context, num_dims >= 1 && num_dims <= 4);
    247 
    248   if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
    249     if (CheckOutputQuantParams(context, input, output) == kTfLiteError) {
    250       return kTfLiteError;
    251     }
    252 
    253     static const int kScaledDiffIntegerBits = 5;
    254     tflite::PreprocessSoftmaxScaling(
    255         params->beta, input->params.scale, kScaledDiffIntegerBits,
    256         &data->input_multiplier, &data->input_left_shift);
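             // diff_min is the most negative rescaled (input - max) difference still
             // processed; smaller differences contribute a negligible exp() term and are
             // treated as zero by the quantized softmax implementation.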
    257     data->diff_min = -1.0 * tflite::CalculateInputRadius(
    258                                 kScaledDiffIntegerBits, data->input_left_shift);
    259   }
    260 
    261   return context->ResizeTensor(context, output,
    262                                TfLiteIntArrayCopy(input->dims));
    263 }
    264 
    265 TfLiteStatus LogSoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
    266   LogSoftmaxOpData* data = reinterpret_cast<LogSoftmaxOpData*>(node->user_data);
    267 
    268   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
    269   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
    270   const TfLiteTensor* input = GetInput(context, node, 0);
    271   TfLiteTensor* output = GetOutput(context, node, 0);
    272   TF_LITE_ENSURE_EQ(context, input->type, output->type);
    273 
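           // Log-softmax outputs are non-positive. The fixed output quantization checked
           // below (scale 16/256, zero point at the top of the type's range) maps the
           // 8-bit grid onto roughly [-16, 0] with a resolution of 1/16.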
    274   if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
    275     if (input->type == kTfLiteUInt8) {
    276       TF_LITE_ENSURE_EQ(context, output->params.zero_point, 255);
    277     }
    278     if (input->type == kTfLiteInt8) {
    279       TF_LITE_ENSURE_EQ(context, output->params.zero_point, 127);
    280     }
    281     TF_LITE_ENSURE_EQ(context, output->params.scale, 16.0 / 256);
    282 
    283     static const double kBeta = 1.0;
    284     static const int kScaledDiffIntegerBits = 5;
    285     tflite::PreprocessLogSoftmaxScalingExp(
    286         kBeta, input->params.scale, kScaledDiffIntegerBits,
    287         &data->input_multiplier, &data->input_left_shift,
    288         &data->reverse_scaling_divisor, &data->reverse_scaling_right_shift);
    289     data->reverse_scaling_right_shift *= -1;
    290     data->diff_min = -1.0 * tflite::CalculateInputRadius(
    291                                 kScaledDiffIntegerBits, data->input_left_shift);
    292   }
    293 
    294   return context->ResizeTensor(context, output,
    295                                TfLiteIntArrayCopy(input->dims));
    296 }
    297 
    298 TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) {
    299   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
    300   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
    301   const TfLiteTensor* input = GetInput(context, node, 0);
    302   TfLiteTensor* output = GetOutput(context, node, 0);
    303   const TfLiteTensor* alpha = GetInput(context, node, 1);
    304   PreluOpData* data = reinterpret_cast<PreluOpData*>(node->user_data);
    305 
    306   TF_LITE_ENSURE_EQ(context, input->type, alpha->type);
    307   output->type = input->type;
    308 
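           // In the quantized path the negative-slope product is computed as
           // (q_input - input_zp) * (q_alpha - alpha_zp), which carries a scale of
           // input_scale * alpha_scale; dividing by output_scale yields the single
           // multiplier used to rescale that product into the output quantization.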
    309   if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) {
    310     double real_multiplier =
    311         input->params.scale * alpha->params.scale / output->params.scale;
    312     QuantizeMultiplierSmallerThanOneExp(
    313         real_multiplier, &data->output_multiplier, &data->output_shift);
    314   }
    315 
     316   // PRelu (parametric ReLU) shares the same alpha value along the "shared axis",
     317   // so the alpha values always need to be broadcast against the input.
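           // A typical case is an input of shape [1, H, W, C] with a per-channel alpha
           // of shape [1, 1, 1, C]; after broadcasting, the output shape matches the
           // input shape.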
    318   TfLiteIntArray* output_size = nullptr;
    319   TF_LITE_ENSURE_OK(
    320       context, CalculateShapeForBroadcast(context, input, alpha, &output_size));
    321 
    322   TF_LITE_ENSURE_OK(context,
    323                     context->ResizeTensor(context, output, output_size));
    324   // After broadcasting, the output shape should always be the same as the
    325   // input shape.
    326   TF_LITE_ENSURE(context, HaveSameShapes(input, output));
    327 
    328   return kTfLiteOk;
    329 }
    330 
    331 TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
    332   const TfLiteTensor* input = GetInput(context, node, 0);
    333   TfLiteTensor* output = GetOutput(context, node, 0);
    334   switch (input->type) {
    335     case kTfLiteFloat32: {
    336       size_t elements = input->bytes / sizeof(float);
    337       float* in = input->data.f;
    338       float* in_end = in + elements;
    339       float* out = output->data.f;
    340       for (; in < in_end; in++, out++) *out = std::max(0.f, *in);
    341       return kTfLiteOk;
    342     } break;
    343     default:
    344       context->ReportError(context, "Only float32 supported currently, got %s.",
    345                            TfLiteTypeGetName(input->type));
    346       return kTfLiteError;
    347   }
    348 }
    349 
    350 TfLiteStatus Relu1Eval(TfLiteContext* context, TfLiteNode* node) {
    351   const TfLiteTensor* input = GetInput(context, node, 0);
    352   TfLiteTensor* output = GetOutput(context, node, 0);
    353   switch (input->type) {
    354     case kTfLiteFloat32: {
    355       size_t elements = input->bytes / sizeof(float);
    356       float* in = input->data.f;
    357       float* in_end = in + elements;
    358       float* out = output->data.f;
    359       for (; in < in_end; in++, out++) {
    360         *out = std::min(std::max(-1.f, *in), 1.f);
    361       }
    362       return kTfLiteOk;
    363     } break;
    364     default:
    365       context->ReportError(context, "Only float32 supported currently, got %s.",
    366                            TfLiteTypeGetName(input->type));
    367       return kTfLiteError;
    368   }
    369 }
    370 
    371 namespace {
    372 template <typename T>
    373 void QuantizedRelu6(const TfLiteTensor* input, TfLiteTensor* output) {
    374   ActivationParams params;
    375   params.activation_type = FusedActivationFunctionType::kRelu6;
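           // Clamp to the quantized representation of the real interval [0, 6]; e.g.
           // with an output scale of 6/255 and zero_point 0 the bounds come out as
           // [0, 255], further clipped to the numeric limits of the tensor type T.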
    376   params.quantized_activation_min =
    377       std::max(static_cast<int32_t>(std::numeric_limits<T>::min()),
    378                output->params.zero_point +
    379                    static_cast<int32>(roundf(0.f / output->params.scale)));
    380   params.quantized_activation_max =
    381       std::min(static_cast<int32_t>(std::numeric_limits<T>::max()),
    382                output->params.zero_point +
    383                    static_cast<int32>(roundf(6.f / output->params.scale)));
    384   optimized_ops::ReluX(params, GetTensorShape(input), GetTensorData<T>(input),
    385                        GetTensorShape(output), GetTensorData<T>(output));
    386 }
    387 }  // namespace
    388 
    389 TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
    390   const TfLiteTensor* input = GetInput(context, node, 0);
    391   TfLiteTensor* output = GetOutput(context, node, 0);
    392   switch (input->type) {
    393     case kTfLiteFloat32: {
    394       size_t elements = input->bytes / sizeof(float);
    395       float* in = input->data.f;
    396       float* in_end = in + elements;
    397       float* out = output->data.f;
    398       for (; in < in_end; in++, out++) *out = std::min(std::max(0.f, *in), 6.f);
    399       return kTfLiteOk;
    400     } break;
    401     case kTfLiteUInt8:
    402       QuantizedRelu6<uint8_t>(input, output);
    403       return kTfLiteOk;
    404     case kTfLiteInt8: {
    405       QuantizedRelu6<int8_t>(input, output);
    406       return kTfLiteOk;
    407     } break;
    408     default:
    409       context->ReportError(
    410           context, "Only float32, uint8 and int8 supported currently, got %s.",
    411           TfLiteTypeGetName(input->type));
    412       return kTfLiteError;
    413   }
    414 }
    415 
    416 template <KernelType kernel_type>
    417 TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
    418   OpData* data = reinterpret_cast<OpData*>(node->user_data);
    419   const TfLiteTensor* input = GetInput(context, node, 0);
    420   TfLiteTensor* output = GetOutput(context, node, 0);
    421   switch (input->type) {
    422     case kTfLiteFloat32: {
    423       if (kernel_type == kGenericOptimized) {
    424         optimized_ops::Tanh(GetTensorShape(input), GetTensorData<float>(input),
    425                             GetTensorShape(output),
    426                             GetTensorData<float>(output));
    427       } else {
    428         reference_ops::Tanh(GetTensorShape(input), GetTensorData<float>(input),
    429                             GetTensorShape(output),
    430                             GetTensorData<float>(output));
    431       }
    432       return kTfLiteOk;
    433     } break;
    434     case kTfLiteInt16: {
    435       TanhParams params;
    436       params.input_left_shift = data->input_left_shift;
    437       if (kernel_type == kGenericOptimized) {
    438         optimized_ops::Tanh(
    439             params, GetTensorShape(input), GetTensorData<int16_t>(input),
    440             GetTensorShape(output), GetTensorData<int16_t>(output));
    441       } else {
    442         reference_ops::Tanh(
    443             params, GetTensorShape(input), GetTensorData<int16_t>(input),
    444             GetTensorShape(output), GetTensorData<int16_t>(output));
    445       }
    446       return kTfLiteOk;
    447     } break;
    448     case kTfLiteUInt8: {
    449       TanhParams params;
    450       params.input_zero_point = input->params.zero_point;
    451       params.input_range_radius = data->input_range_radius;
    452       params.input_multiplier = data->input_multiplier;
    453       params.input_left_shift = data->input_left_shift;
    454       if (kernel_type == kGenericOptimized) {
    455         optimized_ops::Tanh(
    456             params, GetTensorShape(input), GetTensorData<uint8_t>(input),
    457             GetTensorShape(output), GetTensorData<uint8_t>(output));
    458       } else {
    459         reference_ops::Tanh(
    460             params, GetTensorShape(input), GetTensorData<uint8_t>(input),
    461             GetTensorShape(output), GetTensorData<uint8_t>(output));
    462       }
    463       return kTfLiteOk;
    464     } break;
    465     case kTfLiteInt8: {
    466       const auto input_shape = GetTensorShape(input);
    467       const auto output_shape = GetTensorShape(output);
    468       const int size = MatchingFlatSize(input_shape, output_shape);
    469       reference_integer_ops::Tanh(
    470           input->params.zero_point, data->input_range_radius,
    471           data->input_multiplier, data->input_left_shift, size,
    472           GetTensorData<int8_t>(input), GetTensorData<int8_t>(output));
    473       return kTfLiteOk;
    474     } break;
    475     default:
    476       context->ReportError(context, "Only float32 supported currently, got %s.",
    477                            TfLiteTypeGetName(input->type));
    478       return kTfLiteError;
    479   }
    480 }
    481 
     482 // Sigmoid is also known as "Logistic".
    483 template <KernelType kernel_type>
    484 TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) {
    485   OpData* data = reinterpret_cast<OpData*>(node->user_data);
    486 
    487   const TfLiteTensor* input = GetInput(context, node, 0);
    488   TfLiteTensor* output = GetOutput(context, node, 0);
    489   switch (input->type) {
    490     case kTfLiteFloat32: {
    491       if (kernel_type == kGenericOptimized) {
    492         optimized_ops::Logistic(
    493             GetTensorShape(input), GetTensorData<float>(input),
    494             GetTensorShape(output), GetTensorData<float>(output));
    495       } else {
    496         reference_ops::Logistic(
    497             GetTensorShape(input), GetTensorData<float>(input),
    498             GetTensorShape(output), GetTensorData<float>(output));
    499       }
    500       break;
    501     }
    502     case kTfLiteInt16: {
    503       LogisticParams params;
    504       if (kernel_type == kGenericOptimized) {
    505         optimized_ops::Logistic(
    506             params, GetTensorShape(input), GetTensorData<int16_t>(input),
    507             GetTensorShape(output), GetTensorData<int16_t>(output));
    508       } else {
    509         reference_ops::Logistic(
    510             params, GetTensorShape(input), GetTensorData<int16_t>(input),
    511             GetTensorShape(output), GetTensorData<int16_t>(output));
    512       }
    513       break;
    514     }
    515     case kTfLiteUInt8: {
    516       LogisticParams params;
    517       params.input_zero_point = input->params.zero_point;
    518       params.input_range_radius = data->input_range_radius;
    519       params.input_multiplier = data->input_multiplier;
    520       params.input_left_shift = data->input_left_shift;
    521       if (kernel_type == kGenericOptimized) {
    522         optimized_ops::Logistic(
    523             params, GetTensorShape(input), GetTensorData<uint8_t>(input),
    524             GetTensorShape(output), GetTensorData<uint8_t>(output));
    525       } else {
    526         reference_ops::Logistic(
    527             params, GetTensorShape(input), GetTensorData<uint8_t>(input),
    528             GetTensorShape(output), GetTensorData<uint8_t>(output));
    529       }
    530       break;
    531     }
    532     case kTfLiteInt8: {
    533       const int input_size =
    534           MatchingFlatSize(GetTensorShape(input), GetTensorShape(output));
    535       reference_integer_ops::Logistic(
    536           input->params.zero_point, data->input_range_radius,
    537           data->input_multiplier, data->input_left_shift, input_size,
    538           GetTensorData<int8_t>(input), GetTensorData<int8_t>(output));
    539       break;
    540     }
    541     default:
    542       context->ReportError(context, "Only float32 supported currently, got %s.",
    543                            TfLiteTypeGetName(input->type));
    544   }
    545   return kTfLiteOk;
    546 }
    547 
    548 // Performs softmax along the input of size (input_size * batch_size).
    549 void Softmax(const float* in, const int input_size, const int batch_size,
    550              const float beta, float* out) {
    551   TF_LITE_ASSERT(input_size > 0);
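           // Numerically stable formulation: softmax is invariant to adding a constant
           // to every input, so the per-batch maximum is subtracted before
           // exponentiation to keep std::exp from overflowing.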
    552 
    553   // For each batch
    554   for (int b = 0; b < batch_size; b++) {
    555     // Find the max coeff.
    556     float max_coeff = in[0];
    557     for (int i = 1; i < input_size; i++) {
    558       if (in[i] > max_coeff) max_coeff = in[i];
    559     }
    560 
    561     // Compute the normalized sum of exps.
    562     float exp_sum = 0.0;
    563     for (int i = 0; i < input_size; i++) {
    564       out[i] = std::exp((in[i] - max_coeff) * beta);
    565       exp_sum += out[i];
    566     }
    567 
    568     // Divide by the sum of exps.
    569     float reciprocal_sum_exp = 1.f / exp_sum;
    570     for (int i = 0; i < input_size; i++) {
    571       out[i] *= reciprocal_sum_exp;
    572     }
    573 
    574     // Advance in and out pointers for the next batch.
    575     in += input_size;
    576     out += input_size;
    577   }
    578 }
    579 
    580 // Takes a 1D tensor and performs softmax along it.
    581 void Softmax1DFloat(const TfLiteTensor* input, TfLiteTensor* output,
    582                     TfLiteSoftmaxParams* params) {
    583   const int input_size = input->dims->data[0];
    584   Softmax(input->data.f, input_size, 1, params->beta, output->data.f);
    585 }
    586 
     587 // Takes a 2D tensor and performs softmax along the last dimension.
    588 void Softmax2DFloat(const TfLiteTensor* input, TfLiteTensor* output,
    589                     TfLiteSoftmaxParams* params) {
    590   const int batch_size = input->dims->data[0];
    591   const int input_size = input->dims->data[1];
    592   Softmax(input->data.f, input_size, batch_size, params->beta, output->data.f);
    593 }
    594 
     595 // Takes a 3D tensor and performs softmax along the last dimension.
    596 void Softmax3DFloat(const TfLiteTensor* input, TfLiteTensor* output,
    597                     TfLiteSoftmaxParams* params) {
    598   const int batch_size = input->dims->data[0];
    599   const int intermediate_size = input->dims->data[1];
    600   const int input_size = input->dims->data[2];
    601   SoftmaxParams op_params;
    602   op_params.beta = params->beta;
    603   optimized_ops::Softmax(
    604       op_params, GetTensorShape({batch_size, intermediate_size, 1, input_size}),
    605       GetTensorData<float>(input),
    606       GetTensorShape({batch_size, intermediate_size, 1, input_size}),
    607       GetTensorData<float>(output));
    608 }
    609 
    610 void Softmax1DQuantizedUint8(const TfLiteTensor* input, TfLiteTensor* output,
    611                              TfLiteSoftmaxParams* params, OpData* data) {
    612   // TODO(ahentz): this is arguably a dirty trick. Since the implementation
    613   // always traverses the last dimension of a 4D tensor, we will pretend our 1D
    614   // tensor is 4D in a special way. We will convert a (Y) shape into a (1,
    615   // 1, 1, Y) shape.
    616   const int input_size = input->dims->data[0];
    617   SoftmaxParams op_params;
    618   op_params.input_multiplier = data->input_multiplier;
    619   op_params.input_left_shift = data->input_left_shift;
    620   op_params.diff_min = data->diff_min;
    621   optimized_ops::Softmax(op_params, GetTensorShape({1, 1, 1, input_size}),
    622                          GetTensorData<uint8_t>(input),
    623                          GetTensorShape({1, 1, 1, input_size}),
    624                          GetTensorData<uint8_t>(output));
    625 }
    626 void Softmax2DQuantizedUint8(const TfLiteTensor* input, TfLiteTensor* output,
    627                              TfLiteSoftmaxParams* params, OpData* data) {
    628   // TODO(ahentz): this is arguably a dirty trick. Since the implementation
    629   // always traverses the last dimension of a 4D tensor, we will pretend our 2D
    630   // tensor is 4D in a special way. We will convert a (X, Y) shape into a (X,
    631   // 1, 1, Y) shape.
    632   const int batch_size = input->dims->data[0];
    633   const int input_size = input->dims->data[1];
    634   SoftmaxParams op_params;
    635   op_params.input_multiplier = data->input_multiplier;
    636   op_params.input_left_shift = data->input_left_shift;
    637   op_params.diff_min = data->diff_min;
    638   optimized_ops::Softmax(op_params,
    639                          GetTensorShape({batch_size, 1, 1, input_size}),
    640                          GetTensorData<uint8_t>(input),
    641                          GetTensorShape({batch_size, 1, 1, input_size}),
    642                          GetTensorData<uint8_t>(output));
    643 }
    644 
    645 void Softmax3DQuantizedUint8(const TfLiteTensor* input, TfLiteTensor* output,
    646                              TfLiteSoftmaxParams* params, OpData* data) {
    647   const int batch_size = input->dims->data[0];
    648   const int intermediate_size = input->dims->data[1];
    649   const int input_size = input->dims->data[2];
    650   SoftmaxParams op_params;
    651   op_params.input_multiplier = data->input_multiplier;
    652   op_params.input_left_shift = data->input_left_shift;
    653   op_params.diff_min = data->diff_min;
    654   optimized_ops::Softmax(
    655       op_params, GetTensorShape({batch_size, intermediate_size, 1, input_size}),
    656       GetTensorData<uint8_t>(input),
    657       GetTensorShape({batch_size, intermediate_size, 1, input_size}),
    658       GetTensorData<uint8_t>(output));
    659 }
    660 
     661 // Takes a 4D tensor and performs softmax along the fourth (last) dimension.
    662 void Softmax4DFloat(const TfLiteTensor* input, TfLiteTensor* output,
    663                     TfLiteSoftmaxParams* params) {
    664   SoftmaxParams op_params;
    665   op_params.beta = params->beta;
    666   optimized_ops::Softmax(op_params, GetTensorShape(input),
    667                          GetTensorData<float>(input), GetTensorShape(output),
    668                          GetTensorData<float>(output));
    669 }
    670 
    671 void Softmax4DQuantizedUint8(const TfLiteTensor* input, TfLiteTensor* output,
    672                              TfLiteSoftmaxParams* params, OpData* data) {
    673   SoftmaxParams op_params;
    674   op_params.input_multiplier = data->input_multiplier;
    675   op_params.input_left_shift = data->input_left_shift;
    676   op_params.diff_min = data->diff_min;
    677   optimized_ops::Softmax(op_params, GetTensorShape(input),
    678                          GetTensorData<uint8_t>(input), GetTensorShape(output),
    679                          GetTensorData<uint8_t>(output));
    680 }
    681 
    682 // TODO(jianlijianli): Try merging Softmax<n>DQuantizedInt8 with
    683 // Softmax<n>DQuantized, which needs a larger refactor.
    684 void Softmax1DQuantizedInt8(const TfLiteTensor* input, TfLiteTensor* output,
    685                             TfLiteSoftmaxParams* params, OpData* data) {
    686   const int input_size = input->dims->data[0];
    687   SoftmaxParams op_params;
    688   op_params.input_multiplier = data->input_multiplier;
    689   op_params.input_left_shift = data->input_left_shift;
    690   op_params.diff_min = data->diff_min;
    691   reference_integer_ops::Softmax(
    692       op_params, GetTensorShape({1, 1, 1, input_size}),
    693       GetTensorData<int8_t>(input), GetTensorShape({1, 1, 1, input_size}),
    694       GetTensorData<int8_t>(output));
    695 }
    696 
    697 void Softmax2DQuantizedInt8(const TfLiteTensor* input, TfLiteTensor* output,
    698                             TfLiteSoftmaxParams* params, OpData* data) {
    699   const int batch_size = input->dims->data[0];
    700   const int input_size = input->dims->data[1];
    701   SoftmaxParams op_params;
    702   op_params.input_multiplier = data->input_multiplier;
    703   op_params.input_left_shift = data->input_left_shift;
    704   op_params.diff_min = data->diff_min;
    705   reference_integer_ops::Softmax(op_params,
    706                                  GetTensorShape({batch_size, 1, 1, input_size}),
    707                                  GetTensorData<int8_t>(input),
    708                                  GetTensorShape({batch_size, 1, 1, input_size}),
    709                                  GetTensorData<int8_t>(output));
    710 }
    711 
    712 void Softmax3DQuantizedInt8(const TfLiteTensor* input, TfLiteTensor* output,
    713                             TfLiteSoftmaxParams* params, OpData* data) {
    714   const int batch_size = input->dims->data[0];
    715   const int intermediate_size = input->dims->data[1];
    716   const int input_size = input->dims->data[2];
    717   SoftmaxParams op_params;
    718   op_params.input_multiplier = data->input_multiplier;
    719   op_params.input_left_shift = data->input_left_shift;
    720   op_params.diff_min = data->diff_min;
    721   reference_integer_ops::Softmax(
    722       op_params, GetTensorShape({batch_size, intermediate_size, 1, input_size}),
    723       GetTensorData<int8_t>(input),
    724       GetTensorShape({batch_size, intermediate_size, 1, input_size}),
    725       GetTensorData<int8_t>(output));
    726 }
    727 
    728 void Softmax4DQuantizedInt8(const TfLiteTensor* input, TfLiteTensor* output,
    729                             TfLiteSoftmaxParams* params, OpData* data) {
    730   SoftmaxParams op_params;
    731   op_params.input_multiplier = data->input_multiplier;
    732   op_params.input_left_shift = data->input_left_shift;
    733   op_params.diff_min = data->diff_min;
    734   reference_integer_ops::Softmax(
    735       op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
    736       GetTensorShape(output), GetTensorData<int8_t>(output));
    737 }
    738 
    739 TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
    740   auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
    741   OpData* data = reinterpret_cast<OpData*>(node->user_data);
    742 
    743   const TfLiteTensor* input = GetInput(context, node, 0);
    744   TfLiteTensor* output = GetOutput(context, node, 0);
    745 
    746   // TODO(ahentz): consider an implementation that works for many (all?)
    747   // dimensions.
    748   switch (input->type) {
    749     case kTfLiteFloat32: {
    750       if (NumDimensions(input) == 1) {
    751         Softmax1DFloat(input, output, params);
    752         return kTfLiteOk;
    753       }
    754       if (NumDimensions(input) == 2) {
    755         Softmax2DFloat(input, output, params);
    756         return kTfLiteOk;
    757       }
    758       if (NumDimensions(input) == 3) {
    759         Softmax3DFloat(input, output, params);
    760         return kTfLiteOk;
    761       }
    762       if (NumDimensions(input) == 4) {
    763         Softmax4DFloat(input, output, params);
    764         return kTfLiteOk;
    765       }
    766       context->ReportError(
    767           context, "Only 1D, 2D and 4D tensors supported currently, got %dD.",
    768           NumDimensions(input));
    769       return kTfLiteError;
    770     }
    771     case kTfLiteUInt8: {
    772       if (NumDimensions(input) == 1) {
    773         Softmax1DQuantizedUint8(input, output, params, data);
    774         return kTfLiteOk;
    775       }
    776       if (NumDimensions(input) == 2) {
    777         Softmax2DQuantizedUint8(input, output, params, data);
    778         return kTfLiteOk;
    779       }
    780       if (NumDimensions(input) == 3) {
    781         Softmax3DQuantizedUint8(input, output, params, data);
    782         return kTfLiteOk;
    783       }
    784       if (NumDimensions(input) == 4) {
    785         Softmax4DQuantizedUint8(input, output, params, data);
    786         return kTfLiteOk;
    787       }
    788       context->ReportError(
    789           context, "Only 2D and 4D tensors supported currently, got %dD.",
    790           NumDimensions(input));
    791       return kTfLiteError;
    792     }
    793     case kTfLiteInt8: {
    794       if (NumDimensions(input) == 1) {
    795         Softmax1DQuantizedInt8(input, output, params, data);
    796         return kTfLiteOk;
    797       }
    798       if (NumDimensions(input) == 2) {
    799         Softmax2DQuantizedInt8(input, output, params, data);
    800         return kTfLiteOk;
    801       }
    802       if (NumDimensions(input) == 3) {
    803         Softmax3DQuantizedInt8(input, output, params, data);
    804         return kTfLiteOk;
    805       }
    806       if (NumDimensions(input) == 4) {
    807         Softmax4DQuantizedInt8(input, output, params, data);
    808         return kTfLiteOk;
    809       }
    810       context->ReportError(
    811           context,
    812           "Only 4D tensors supported currently for Int8 kernels, got %dD.",
    813           NumDimensions(input));
    814       return kTfLiteError;
    815     }
    816 
    817     default:
    818       context->ReportError(
    819           context, "Only float32 and uint8_t supported currently, got %s.",
    820           TfLiteTypeGetName(input->type));
    821       return kTfLiteError;
    822   }
    823 }
    824 
    825 template <KernelType kernel_type>
    826 TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
    827   const LogSoftmaxOpData* data =
    828       reinterpret_cast<LogSoftmaxOpData*>(node->user_data);
    829   const TfLiteTensor* input = GetInput(context, node, 0);
    830   TfLiteTensor* output = GetOutput(context, node, 0);
    831   switch (input->type) {
    832     case kTfLiteFloat32: {
    833       SoftmaxParams op_params;
    834       if (kernel_type == kGenericOptimized) {
    835         optimized_ops::LogSoftmax(
    836             op_params, GetTensorShape(input), GetTensorData<float>(input),
    837             GetTensorShape(output), GetTensorData<float>(output));
    838       } else {
    839         reference_ops::LogSoftmax(
    840             op_params, GetTensorShape(input), GetTensorData<float>(input),
    841             GetTensorShape(output), GetTensorData<float>(output));
    842       }
    843       return kTfLiteOk;
    844     }
    845     case kTfLiteUInt8: {
    846       SoftmaxParams op_params;
    847       op_params.input_multiplier = data->input_multiplier;
    848       op_params.input_left_shift = data->input_left_shift;
    849       op_params.reverse_scaling_divisor = data->reverse_scaling_divisor;
    850       op_params.reverse_scaling_right_shift = data->reverse_scaling_right_shift;
    851       op_params.diff_min = data->diff_min;
    852       if (kernel_type == kGenericOptimized) {
    853         optimized_ops::LogSoftmax(
    854             op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
    855             GetTensorShape(output), GetTensorData<uint8_t>(output));
    856       } else {
    857         reference_ops::LogSoftmax(
    858             op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
    859             GetTensorShape(output), GetTensorData<uint8_t>(output));
    860       }
    861       return kTfLiteOk;
    862     }
    863     case kTfLiteInt8: {
    864       const auto input_shape = GetTensorShape(input);
    865       const auto output_shape = GetTensorShape(output);
    866       const int trailing_dim = input_shape.DimensionsCount() - 1;
    867       const int outer_size =
    868           MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
    869       const int depth =
    870           MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
    871       reference_integer_ops::LogSoftmax(
    872           data->input_multiplier, data->input_left_shift,
    873           data->reverse_scaling_divisor, data->reverse_scaling_right_shift,
    874           data->diff_min, outer_size, depth, GetTensorData<int8_t>(input),
    875           GetTensorData<int8_t>(output));
    876       return kTfLiteOk;
    877     }
    878     default:
    879       context->ReportError(context, "Only float32 supported currently., got %s",
    880                            TfLiteTypeGetName(input->type));
    881       return kTfLiteError;
    882   }
    883 }
    884 
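         // Element-wise PRelu: returns the input unchanged when it is non-negative and
         // scales it by the (broadcast) alpha value otherwise.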
    885 template <typename T>
    886 T ApplyPrelu(T input, T alpha) {
    887   return input >= 0.0 ? input : input * alpha;
    888 }
    889 
    890 TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
    891   const TfLiteTensor* input = GetInput(context, node, 0);
    892   const TfLiteTensor* alpha = GetInput(context, node, 1);
    893   TfLiteTensor* output = GetOutput(context, node, 0);
    894   const PreluOpData* data = reinterpret_cast<PreluOpData*>(node->user_data);
    895   switch (input->type) {
    896     case kTfLiteFloat32: {
    897       reference_ops::BroadcastBinaryFunction4DSlow<float, float, float>(
    898           GetTensorShape(input), GetTensorData<float>(input),
    899           GetTensorShape(alpha), GetTensorData<float>(alpha),
    900           GetTensorShape(output), GetTensorData<float>(output),
    901           ApplyPrelu<float>);
    902       return kTfLiteOk;
    903     } break;
    904     case kTfLiteUInt8: {
    905       PreluParams op_params;
    906       op_params.input_offset = -input->params.zero_point;
    907       op_params.alpha_offset = -alpha->params.zero_point;
    908       op_params.output_offset = output->params.zero_point;
    909       op_params.output_multiplier = data->output_multiplier;
    910       op_params.output_shift = data->output_shift;
    911       reference_ops::BroadcastPrelu4DSlow(
    912           op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
    913           GetTensorShape(alpha), GetTensorData<uint8_t>(alpha),
    914           GetTensorShape(output), GetTensorData<uint8_t>(output));
    915       return kTfLiteOk;
    916     } break;
    917     default:
    918       context->ReportError(context,
    919                            "Only float32, uint8 supported currently, got %d.",
    920                            TfLiteTypeGetName(input->type));
    921       return kTfLiteError;
    922   }
    923 }
    924 
    925 TfLiteStatus LeakyReluEval(TfLiteContext* context, TfLiteNode* node) {
    926   const TfLiteTensor* input = GetInput(context, node, 0);
    927   TfLiteTensor* output = GetOutput(context, node, 0);
    928   const auto* params =
    929       reinterpret_cast<TfLiteLeakyReluParams*>(node->builtin_data);
    930 
    931   LeakyReluParams op_params;
    932   op_params.alpha = params->alpha;
    933   switch (input->type) {
    934     case kTfLiteFloat32: {
    935       optimized_ops::LeakyRelu(
    936           op_params, GetTensorShape(input), GetTensorData<float>(input),
    937           GetTensorShape(output), GetTensorData<float>(output));
    938       return kTfLiteOk;
    939     } break;
    940     default:
    941       context->ReportError(context, "Only float32 supported currently, got %s.",
    942                            TfLiteTypeGetName(input->type));
    943       return kTfLiteError;
    944   }
    945 }
    946 
    947 TfLiteStatus EluEval(TfLiteContext* context, TfLiteNode* node) {
    948   const TfLiteTensor* input = GetInput(context, node, 0);
    949   TfLiteTensor* output = GetOutput(context, node, 0);
    950   switch (input->type) {
    951     case kTfLiteFloat32: {
    952       optimized_ops::Elu(GetTensorShape(input), GetTensorData<float>(input),
    953                          GetTensorShape(output), GetTensorData<float>(output));
    954       return kTfLiteOk;
    955     } break;
    956     default:
    957       context->ReportError(context, "Only float32 supported currently, got %s.",
    958                            TfLiteTypeGetName(input->type));
    959       return kTfLiteError;
    960   }
    961 }
    962 
    963 }  // namespace activations
    964 
    965 TfLiteRegistration* Register_ELU() {
    966   static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
    967                                  activations::GenericPrepare,
    968                                  activations::EluEval};
    969   return &r;
    970 }
    971 
    972 TfLiteRegistration* Register_RELU() {
    973   static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
    974                                  activations::GenericPrepare,
    975                                  activations::ReluEval};
    976   return &r;
    977 }
    978 
    979 TfLiteRegistration* Register_RELU_N1_TO_1() {
    980   static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
    981                                  activations::GenericPrepare,
    982                                  activations::Relu1Eval};
    983   return &r;
    984 }
    985 
    986 TfLiteRegistration* Register_RELU6() {
    987   static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
    988                                  activations::GenericPrepare,
    989                                  activations::Relu6Eval};
    990   return &r;
    991 }
    992 
    993 TfLiteRegistration* Register_TANH_REF() {
    994   static TfLiteRegistration r = {
    995       activations::Init, activations::Free, activations::TanhPrepare,
    996       activations::TanhEval<activations::kReference>};
    997   return &r;
    998 }
    999 
   1000 TfLiteRegistration* Register_TANH() {
   1001   static TfLiteRegistration r = {
   1002       activations::Init, activations::Free, activations::TanhPrepare,
   1003       activations::TanhEval<activations::kGenericOptimized>};
   1004   return &r;
   1005 }
   1006 
   1007 TfLiteRegistration* Register_LOGISTIC_REF() {
   1008   static TfLiteRegistration r = {
   1009       activations::Init, activations::Free, activations::SigmoidPrepare,
   1010       activations::SigmoidEval<activations::kReference>};
   1011   return &r;
   1012 }
   1013 
   1014 TfLiteRegistration* Register_LOGISTIC() {
   1015   static TfLiteRegistration r = {
   1016       activations::Init, activations::Free, activations::SigmoidPrepare,
   1017       activations::SigmoidEval<activations::kGenericOptimized>};
   1018   return &r;
   1019 }
   1020 
   1021 TfLiteRegistration* Register_SOFTMAX() {
   1022   static TfLiteRegistration r = {activations::Init, activations::Free,
   1023                                  activations::SoftmaxPrepare,
   1024                                  activations::SoftmaxEval};
   1025   return &r;
   1026 }
   1027 
   1028 TfLiteRegistration* Register_LOG_SOFTMAX_REF() {
   1029   static TfLiteRegistration r = {
   1030       activations::LogSoftmaxInit, activations::LogSoftmaxFree,
   1031       activations::LogSoftmaxPrepare,
   1032       activations::LogSoftmaxEval<activations::kReference>};
   1033   return &r;
   1034 }
   1035 
   1036 TfLiteRegistration* Register_LOG_SOFTMAX() {
   1037   static TfLiteRegistration r = {
   1038       activations::LogSoftmaxInit, activations::LogSoftmaxFree,
   1039       activations::LogSoftmaxPrepare,
   1040       activations::LogSoftmaxEval<activations::kGenericOptimized>};
   1041   return &r;
   1042 }
   1043 
   1044 TfLiteRegistration* Register_PRELU() {
   1045   static TfLiteRegistration r = {activations::PreluInit, activations::PreluFree,
   1046                                  activations::PreluPrepare,
   1047                                  activations::PreluEval};
   1048   return &r;
   1049 }
   1050 
   1051 TfLiteRegistration* Register_LEAKY_RELU() {
   1052   static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
   1053                                  activations::GenericPrepare,
   1054                                  activations::LeakyReluEval};
   1055   return &r;
   1056 }
   1057 
   1058 }  // namespace builtin
   1059 }  // namespace ops
   1060 }  // namespace tflite
   1061