/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace pow {
namespace {

// Input/output tensor indices.
constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;

// Op data for the pow op.
struct OpData {
  bool requires_broadcast;
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* data = new OpData;
  data->requires_broadcast = false;
  return data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}
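
// Commentary (added, not in the original file): in the TfLite kernel API,
// Init() runs once per node when the graph is set up, and the runtime stores
// the returned pointer in node->user_data; Free() later receives that same
// pointer when the node is destroyed, so this Init/Free pair fully owns the
// OpData lifetime.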

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  TF_LITE_ENSURE_EQ(context, input1->type, input2->type);

  const TfLiteType type = input1->type;
  if (type != kTfLiteInt32 && type != kTfLiteFloat32) {
    context->ReportError(context, "Unsupported data type %d.", type);
    return kTfLiteError;
  }
  output->type = type;

  data->requires_broadcast = !HaveSameShapes(input1, input2);

  TfLiteIntArray* output_size = nullptr;
  if (data->requires_broadcast) {
    TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
                                   context, input1, input2, &output_size));
  } else {
    output_size = TfLiteIntArrayCopy(input1->dims);
  }

  return context->ResizeTensor(context, output, output_size);
}
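
// Commentary (added): CalculateShapeForBroadcast follows NumPy-style
// broadcasting, aligning dimensions from the right. A worked example:
//   input1 shape [2, 1, 3], input2 shape [4, 3]
//   align right -> [2, 1, 3] vs [1, 4, 3]
//   output shape -> [2, 4, 3]
// Shapes that disagree in a non-1 dimension make Prepare return an error.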

template <typename T>
void PowImpl(const TfLiteTensor* input1, const TfLiteTensor* input2,
             TfLiteTensor* output, bool requires_broadcast) {
  if (requires_broadcast) {
    reference_ops::BroadcastPow4DSlow(
        GetTensorShape(input1), GetTensorData<T>(input1),
        GetTensorShape(input2), GetTensorData<T>(input2),
        GetTensorShape(output), GetTensorData<T>(output));
  } else {
    reference_ops::Pow(GetTensorShape(input1), GetTensorData<T>(input1),
                       GetTensorShape(input2), GetTensorData<T>(input2),
                       GetTensorShape(output), GetTensorData<T>(output));
  }
}
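
// Commentary (added): both paths compute pow element by element; the
// broadcast path ("4DSlow") is a reference implementation that extends both
// shapes to four dimensions and loops over them, so it is correct but not
// fast. E.g. element-wise pow([2, 3], [3, 2]) yields [8, 9].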

TfLiteStatus CheckValue(TfLiteContext* context, const TfLiteTensor* input) {
  const int64_t num_elements = NumElements(input);
  const int32_t* data = GetTensorData<int32_t>(input);
  for (int64_t i = 0; i < num_elements; ++i) {
    if (data[i] < 0) {
      context->ReportError(context,
                           "POW does not support negative values for int32.");
      return kTfLiteError;
    }
  }
  return kTfLiteOk;
}
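
// Commentary (added): this mirrors TensorFlow's int32 Pow, which is defined
// only for non-negative exponents: a negative exponent would require a
// fractional result (e.g. 2^-1 = 0.5) that int32 cannot represent, so the
// kernel rejects it up front instead of producing a meaningless value.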

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  switch (output->type) {
    case kTfLiteInt32: {
      // TensorFlow does not support negative exponents for int32, so reject
      // them here as well.
      TF_LITE_ENSURE_OK(context, CheckValue(context, input2));
      PowImpl<int32_t>(input1, input2, output, data->requires_broadcast);
      break;
    }
    case kTfLiteFloat32: {
      PowImpl<float>(input1, input2, output, data->requires_broadcast);
      break;
    }
    default: {
      context->ReportError(context, "Unsupported data type: %d", output->type);
      return kTfLiteError;
    }
  }
  return kTfLiteOk;
}

}  // namespace
}  // namespace pow

TfLiteRegistration* Register_POW() {
  static TfLiteRegistration r = {pow::Init, pow::Free, pow::Prepare, pow::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite
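
// Usage sketch (added; not part of the original file). A minimal way to
// exercise this kernel through the public C++ API, assuming a model file
// whose single builtin op is POW with two scalar float inputs ("pow.tflite"
// is a placeholder name):
//
//   #include "tensorflow/lite/interpreter.h"
//   #include "tensorflow/lite/kernels/register.h"
//   #include "tensorflow/lite/model.h"
//
//   auto model = tflite::FlatBufferModel::BuildFromFile("pow.tflite");
//   tflite::ops::builtin::BuiltinOpResolver resolver;  // includes POW
//   std::unique_ptr<tflite::Interpreter> interpreter;
//   tflite::InterpreterBuilder(*model, resolver)(&interpreter);
//   interpreter->AllocateTensors();
//   interpreter->typed_input_tensor<float>(0)[0] = 2.f;  // base
//   interpreter->typed_input_tensor<float>(1)[0] = 3.f;  // exponent
//   interpreter->Invoke();
//   float out = interpreter->typed_output_tensor<float>(0)[0];  // 8.f
//
// Any model containing the POW builtin resolves to the registration above
// via Register_POW().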