Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <cmath>
     17 #include "tensorflow/lite/c/c_api_internal.h"
     18 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
     19 #include "tensorflow/lite/kernels/internal/tensor.h"
     20 #include "tensorflow/lite/kernels/kernel_util.h"
     21 
     22 namespace tflite {
     23 namespace ops {
     24 namespace builtin {
     25 namespace elementwise {
     26 namespace {
     27 
     28 bool IsNumericSupportedType(const TfLiteType type) {
     29   return type == kTfLiteFloat32;
     30 }
     31 
     32 bool IsLogicalSupportedType(const TfLiteType type) {
     33   return type == kTfLiteBool;
     34 }
     35 
     36 typedef bool (*IsSupportedType)(TfLiteType);
     37 template <IsSupportedType>
     38 TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
     39   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
     40   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
     41   const TfLiteTensor* input = GetInput(context, node, 0);
     42   TfLiteTensor* output = GetOutput(context, node, 0);
     43   TF_LITE_ENSURE_EQ(context, input->type, output->type);
     44   if (!IsSupportedType(input->type)) {
     45     context->ReportError(context, "Current data type %d is not supported.",
     46                          input->type);
     47     return kTfLiteError;
     48   }
     49   return context->ResizeTensor(context, output,
     50                                TfLiteIntArrayCopy(input->dims));
     51 }
     52 
     53 template <typename T>
     54 inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
     55                              T func(T), TfLiteType expected_type) {
     56   const TfLiteTensor* input = GetInput(context, node, 0);
     57   TfLiteTensor* output = GetOutput(context, node, 0);
     58   TF_LITE_ENSURE_EQ(context, input->type, expected_type);
     59   const int64_t num_elements = NumElements(input);
     60   const T* in_data = GetTensorData<T>(input);
     61   T* out_data = GetTensorData<T>(output);
     62   for (int64_t i = 0; i < num_elements; ++i) {
     63     out_data[i] = func(in_data[i]);
     64   }
     65   return kTfLiteOk;
     66 }
     67 
     68 inline TfLiteStatus EvalNumeric(TfLiteContext* context, TfLiteNode* node,
     69                                 float float_func(float)) {
     70   return EvalImpl<float>(context, node, float_func, kTfLiteFloat32);
     71 }
     72 
     73 inline TfLiteStatus EvalLogical(TfLiteContext* context, TfLiteNode* node,
     74                                 bool bool_func(bool)) {
     75   return EvalImpl<bool>(context, node, bool_func, kTfLiteBool);
     76 }
     77 
     78 TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) {
     79   return EvalNumeric(context, node, std::abs);
     80 }
     81 
     82 TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
     83   return EvalNumeric(context, node, std::sin);
     84 }
     85 
     86 TfLiteStatus CosEval(TfLiteContext* context, TfLiteNode* node) {
     87   return EvalNumeric(context, node, std::cos);
     88 }
     89 
     90 TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) {
     91   return EvalNumeric(context, node, std::log);
     92 }
     93 
     94 TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) {
     95   return EvalNumeric(context, node, std::sqrt);
     96 }
     97 
     98 TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
     99   return EvalNumeric(context, node, [](float f) { return 1.f / std::sqrt(f); });
    100 }
    101 
    102 TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) {
    103   return EvalNumeric(context, node, [](float f) { return f * f; });
    104 }
    105 
    106 TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
    107   return EvalLogical(context, node, [](bool v) { return !v; });
    108 }
    109 
    110 }  // namespace
    111 }  // namespace elementwise
    112 
    113 TfLiteRegistration* Register_ABS() {
    114   static TfLiteRegistration r = {
    115       /*init=*/nullptr, /*free=*/nullptr,
    116       elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
    117       elementwise::AbsEval};
    118   return &r;
    119 }
    120 
    121 TfLiteRegistration* Register_SIN() {
    122   static TfLiteRegistration r = {
    123       /*init=*/nullptr, /*free=*/nullptr,
    124       elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
    125       elementwise::SinEval};
    126   return &r;
    127 }
    128 
    129 TfLiteRegistration* Register_COS() {
    130   static TfLiteRegistration r = {
    131       /*init=*/nullptr, /*free=*/nullptr,
    132       elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
    133       elementwise::CosEval};
    134   return &r;
    135 }
    136 
    137 TfLiteRegistration* Register_LOG() {
    138   static TfLiteRegistration r = {
    139       /*init=*/nullptr, /*free=*/nullptr,
    140       elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
    141       elementwise::LogEval};
    142   return &r;
    143 }
    144 
    145 TfLiteRegistration* Register_SQRT() {
    146   static TfLiteRegistration r = {
    147       /*init=*/nullptr, /*free=*/nullptr,
    148       elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
    149       elementwise::SqrtEval};
    150   return &r;
    151 }
    152 
    153 TfLiteRegistration* Register_RSQRT() {
    154   static TfLiteRegistration r = {
    155       /*init=*/nullptr, /*free=*/nullptr,
    156       elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
    157       elementwise::RsqrtEval};
    158   return &r;
    159 }
    160 
    161 TfLiteRegistration* Register_SQUARE() {
    162   static TfLiteRegistration r = {
    163       /*init=*/nullptr, /*free=*/nullptr,
    164       elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
    165       elementwise::SquareEval};
    166   return &r;
    167 }
    168 
    169 TfLiteRegistration* Register_LOGICAL_NOT() {
    170   static TfLiteRegistration r = {
    171       /*init=*/nullptr, /*free=*/nullptr,
    172       elementwise::GenericPrepare<elementwise::IsLogicalSupportedType>,
    173       elementwise::LogicalNotEval};
    174   return &r;
    175 }
    176 
    177 }  // namespace builtin
    178 }  // namespace ops
    179 }  // namespace tflite
    180