Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include "tensorflow/lite/c/builtin_op_data.h"
     16 #include "tensorflow/lite/c/c_api_internal.h"
     17 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
     18 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
     19 #include "tensorflow/lite/kernels/internal/tensor.h"
     20 #include "tensorflow/lite/kernels/kernel_util.h"
     21 #include "tensorflow/lite/kernels/op_macros.h"
     22 
     23 namespace tflite {
     24 namespace ops {
     25 namespace builtin {
     26 namespace resize_bilinear {
     27 
// This file has three implementations of RESIZE_BILINEAR.
enum KernelType {
  kReference,         // Portable reference implementation.
  kGenericOptimized,  // Neon-free
  kNeonOptimized,     // Selected by Register_RESIZE_BILINEAR when USE_NEON is set.
};
     34 
// Indices of this op's tensors within the node's input/output lists.
constexpr int kInputTensor = 0;   // 4-D tensor to be resized.
constexpr int kSizeTensor = 1;    // 1-D int32 tensor; presumably [new_height, new_width].
constexpr int kOutputTensor = 0;  // Resized result, same element type as input.
     38 
     39 TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
     40                                 const TfLiteTensor* input,
     41                                 const TfLiteTensor* size,
     42                                 TfLiteTensor* output) {
     43   TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
     44   output_size->data[0] = input->dims->data[0];
     45   const int32* size_data = GetTensorData<int32>(size);
     46   output_size->data[1] = size_data[0];
     47   output_size->data[2] = size_data[1];
     48   output_size->data[3] = input->dims->data[3];
     49   return context->ResizeTensor(context, output, output_size);
     50 }
     51 
     52 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     53   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
     54   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
     55 
     56   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
     57   const TfLiteTensor* size = GetInput(context, node, kSizeTensor);
     58   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
     59 
     60   // TODO(ahentz): Our current implementations rely on the inputs being 4D.
     61   TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
     62   TF_LITE_ENSURE_EQ(context, NumDimensions(size), 1);
     63 
     64   TF_LITE_ENSURE_EQ(context, size->type, kTfLiteInt32);
     65   // ResizeBilinear creates a float tensor even when the input is made of
     66   // integers.
     67   output->type = input->type;
     68 
     69   if (!IsConstantTensor(size)) {
     70     SetTensorToDynamic(output);
     71     return kTfLiteOk;
     72   }
     73   return ResizeOutputTensor(context, input, size, output);
     74 }
     75 
     76 template <KernelType kernel_type>
     77 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
     78   auto* params =
     79       reinterpret_cast<TfLiteResizeBilinearParams*>(node->builtin_data);
     80 
     81   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
     82   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
     83   const TfLiteTensor* size = GetInput(context, node, kSizeTensor);
     84 
     85   if (IsDynamicTensor(output)) {
     86     TF_LITE_ENSURE_OK(context,
     87                       ResizeOutputTensor(context, input, size, output));
     88   }
     89 
     90   if (output->type == kTfLiteFloat32) {
     91 #define TF_LITE_RESIZE_BILINEAR(type, datatype)                              \
     92   tflite::ResizeBilinearParams op_params;                                    \
     93   op_params.align_corners = params->align_corners;                           \
     94   type::ResizeBilinear(op_params, GetTensorShape(input),                     \
     95                        GetTensorData<datatype>(input), GetTensorShape(size), \
     96                        GetTensorData<int32>(size), GetTensorShape(output),   \
     97                        GetTensorData<datatype>(output))
     98 
     99     if (kernel_type == kReference) {
    100       TF_LITE_RESIZE_BILINEAR(reference_ops, float);
    101     }
    102     if (kernel_type == kGenericOptimized || kernel_type == kNeonOptimized) {
    103       TF_LITE_RESIZE_BILINEAR(optimized_ops, float);
    104     }
    105   } else if (output->type == kTfLiteUInt8) {
    106     if (kernel_type == kReference) {
    107       TF_LITE_RESIZE_BILINEAR(reference_ops, uint8_t);
    108     }
    109     if (kernel_type == kGenericOptimized || kernel_type == kNeonOptimized) {
    110       TF_LITE_RESIZE_BILINEAR(optimized_ops, uint8_t);
    111     }
    112   } else if (output->type == kTfLiteInt8) {
    113     TF_LITE_RESIZE_BILINEAR(reference_ops, int8_t);
    114 #undef TF_LITE_RESIZE_BILINEAR
    115   } else {
    116     context->ReportError(context, "Output type is %d, requires float.",
    117                          output->type);
    118     return kTfLiteError;
    119   }
    120 
    121   return kTfLiteOk;
    122 }
    123 
    124 }  // namespace resize_bilinear
    125 
    126 TfLiteRegistration* Register_RESIZE_BILINEAR_REF() {
    127   static TfLiteRegistration r = {
    128       nullptr, nullptr, resize_bilinear::Prepare,
    129       resize_bilinear::Eval<resize_bilinear::kReference>};
    130   return &r;
    131 }
    132 
    133 TfLiteRegistration* Register_RESIZE_BILINEAR_GENERIC_OPT() {
    134   static TfLiteRegistration r = {
    135       nullptr, nullptr, resize_bilinear::Prepare,
    136       resize_bilinear::Eval<resize_bilinear::kGenericOptimized>};
    137   return &r;
    138 }
    139 
    140 TfLiteRegistration* Register_RESIZE_BILINEAR_NEON_OPT() {
    141   static TfLiteRegistration r = {
    142       nullptr, nullptr, resize_bilinear::Prepare,
    143       resize_bilinear::Eval<resize_bilinear::kNeonOptimized>};
    144   return &r;
    145 }
    146 
// Default registration: selects the NEON kernel when the build defines
// USE_NEON, otherwise falls back to the generic (Neon-free) optimized kernel.
TfLiteRegistration* Register_RESIZE_BILINEAR() {
#ifdef USE_NEON
  return Register_RESIZE_BILINEAR_NEON_OPT();
#else
  return Register_RESIZE_BILINEAR_GENERIC_OPT();
#endif
}
    154 
    155 }  // namespace builtin
    156 }  // namespace ops
    157 }  // namespace tflite
    158