Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
#include <cmath>
#include <functional>
#include <type_traits>
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
     22 
     23 // TODO(b/117523611): We should factor out a binary_op and put binary ops there.
     24 namespace tflite {
     25 namespace ops {
     26 namespace builtin {
     27 namespace floor_mod {
     28 namespace {
     29 
     30 // Input/output tensor index.
     31 constexpr int kInputTensor1 = 0;
     32 constexpr int kInputTensor2 = 1;
     33 constexpr int kOutputTensor = 0;
     34 
     35 // Op data for floor_mod op.
     36 struct OpData {
     37   bool requires_broadcast;
     38 };
     39 
     40 struct FloatMod {
     41   float operator()(const float lhs, const float rhs) const {
     42     return std::fmod(lhs, rhs);
     43   }
     44 };
     45 
     46 // TODO(b/117912007): Move the implementation to reference_ops.h
     47 // TODO(b/117912880): Support quantization.
     48 template <typename T>
     49 T FloorMod(T input1, T input2) {
     50   using ModFunc = typename std::conditional<std::is_integral<T>::value,
     51                                             std::modulus<T>, FloatMod>::type;
     52 
     53   ModFunc mod_func;
     54   T trunc_mod = mod_func(input1, input2);
     55   return (input1 < T(0)) == (input2 < T(0))
     56              ? trunc_mod
     57              : mod_func(trunc_mod + input2, input2);
     58 }
     59 
     60 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
     61   auto* data = new OpData;
     62   data->requires_broadcast = false;
     63   return data;
     64 }
     65 
     66 void Free(TfLiteContext* context, void* buffer) {
     67   delete reinterpret_cast<OpData*>(buffer);
     68 }
     69 
     70 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     71   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
     72   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
     73 
     74   // Reinterprete the opaque data provided by user.
     75   OpData* data = reinterpret_cast<OpData*>(node->user_data);
     76 
     77   const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
     78   const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
     79   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
     80 
     81   TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
     82 
     83   const TfLiteType type = input1->type;
     84   if (type != kTfLiteInt32 && type != kTfLiteFloat32 && type != kTfLiteInt64) {
     85     context->ReportError(context, "Type '%s' is not supported by floor_mod.",
     86                          TfLiteTypeGetName(type));
     87     return kTfLiteError;
     88   }
     89   output->type = type;
     90 
     91   data->requires_broadcast = !HaveSameShapes(input1, input2);
     92 
     93   TfLiteIntArray* output_size = nullptr;
     94   if (data->requires_broadcast) {
     95     TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
     96                                    context, input1, input2, &output_size));
     97   } else {
     98     output_size = TfLiteIntArrayCopy(input1->dims);
     99   }
    100 
    101   return context->ResizeTensor(context, output, output_size);
    102 }
    103 
    104 template <typename T>
    105 TfLiteStatus EvalImpl(TfLiteContext* context, bool requires_broadcast,
    106                       const TfLiteTensor* input1, const TfLiteTensor* input2,
    107                       TfLiteTensor* output) {
    108   const T* denominator_data = GetTensorData<T>(input2);
    109 
    110   if (input2->type == kTfLiteInt32 || input2->type == kTfLiteInt64) {
    111     // Validate the denominator only for integer.
    112     const int num_elements = NumElements(input2);
    113     for (int i = 0; i < num_elements; ++i) {
    114       if (denominator_data[i] == 0) {
    115         context->ReportError(context, "Division by 0");
    116         return kTfLiteError;
    117       }
    118     }
    119   }
    120   if (requires_broadcast) {
    121     reference_ops::BroadcastBinaryFunction4DSlow<T, T, T>(
    122         GetTensorShape(input1), GetTensorData<T>(input1),
    123         GetTensorShape(input2), denominator_data, GetTensorShape(output),
    124         GetTensorData<T>(output), FloorMod<T>);
    125   } else {
    126     reference_ops::BinaryFunction<T, T, T>(
    127         GetTensorShape(input1), GetTensorData<T>(input1),
    128         GetTensorShape(input2), GetTensorData<T>(input2),
    129         GetTensorShape(output), GetTensorData<T>(output), FloorMod<T>);
    130   }
    131 
    132   return kTfLiteOk;
    133 }
    134 
    135 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
    136   OpData* data = reinterpret_cast<OpData*>(node->user_data);
    137 
    138   const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
    139   const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
    140   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
    141 
    142   switch (input1->type) {
    143     case kTfLiteInt32: {
    144       return EvalImpl<int32_t>(context, data->requires_broadcast, input1,
    145                                input2, output);
    146     }
    147     case kTfLiteInt64: {
    148       return EvalImpl<int64_t>(context, data->requires_broadcast, input1,
    149                                input2, output);
    150     }
    151     case kTfLiteFloat32: {
    152       return EvalImpl<float>(context, data->requires_broadcast, input1, input2,
    153                              output);
    154     }
    155     default: {
    156       context->ReportError(context, "Type '%s' is not supported by floor_mod.",
    157                            TfLiteTypeGetName(input1->type));
    158       return kTfLiteError;
    159     }
    160   }
    161 }
    162 
    163 }  // namespace
    164 }  // namespace floor_mod
    165 
    166 TfLiteRegistration* Register_FLOOR_MOD() {
    167   // Init, Free, Prepare, Eval are satisfying the Interface required by
    168   // TfLiteRegistration.
    169   static TfLiteRegistration r = {floor_mod::Init, floor_mod::Free,
    170                                  floor_mod::Prepare, floor_mod::Eval};
    171   return &r;
    172 }
    173 
    174 }  // namespace builtin
    175 }  // namespace ops
    176 }  // namespace tflite
    177