/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace space_to_batch_nd {

// This file has two implementations of SpaceToBatchND.
enum KernelType {
  kReference,
  kGenericOptimized,
};

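// Gathers the node's input/output tensors once so Prepare and Eval can access
// them without repeating the index lookups.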
struct SpaceToBatchNDContext {
  SpaceToBatchNDContext(TfLiteContext* context, TfLiteNode* node) {
    input = GetInput(context, node, 0);
    block_shape = GetInput(context, node, 1);
    paddings = GetInput(context, node, 2);
    output = GetOutput(context, node, 0);
  }
  TfLiteTensor* input;
  TfLiteTensor* block_shape;
  TfLiteTensor* paddings;
  TfLiteTensor* output;
};

// Currently, only 4D NHWC input/output tensors are supported.
// The 4D array needs to have exactly 2 spatial dimensions.
// TODO(nupurgarg): Support arbitrary dimension in SpaceToBatchND.
const int kInputDimensionNum = 4;
const int kBlockSizeDimensionNum = 1;
const int kSpatialDimensionNum = 2;

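// Shape computation (illustrative example): for an input of shape
// [1, 4, 4, 3] with block_shape [2, 2] and zero paddings, the output shape is
// [1 * 2 * 2, 4 / 2, 4 / 2, 3] = [4, 2, 2, 3]. In general, the output batch is
// input_batch * block_height * block_width, and each padded spatial dimension
// is divided by the corresponding block size.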
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
                                SpaceToBatchNDContext* op_context) {
  TfLiteIntArray* input_size = op_context->input->dims;
  const int32* block_shape = GetTensorData<int32>(op_context->block_shape);
  const int32* paddings_data = GetTensorData<int32>(op_context->paddings);

  TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->block_shape),
                    kBlockSizeDimensionNum);
  TF_LITE_ENSURE_EQ(context, op_context->block_shape->dims->data[0],
                    kSpatialDimensionNum);
  TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->paddings),
                    kSpatialDimensionNum);

  TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size);

  // Ensures the input height and width (with padding) are multiples of the
  // block shape height and width.
  for (int dim = 0; dim < kSpatialDimensionNum; ++dim) {
    int final_dim_size = (input_size->data[dim + 1] + paddings_data[dim * 2] +
                          paddings_data[dim * 2 + 1]);
    TF_LITE_ENSURE_EQ(context, final_dim_size % block_shape[dim], 0);
    output_size->data[dim + 1] = final_dim_size / block_shape[dim];
  }

  const int output_batch_size =
      input_size->data[0] * block_shape[0] * block_shape[1];
  const int output_channel_size = input_size->data[3];

  output_size->data[0] = output_batch_size;
  output_size->data[3] = output_channel_size;

  return context->ResizeTensor(context, op_context->output, output_size);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  SpaceToBatchNDContext op_context(context, node);
  TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.input),
                    kInputDimensionNum);
  TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);

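  // If block_shape or paddings are not constant at Prepare time, the output
  // shape is unknown here; mark the output dynamic and resize it in Eval.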
  if (!IsConstantTensor(op_context.block_shape) ||
      !IsConstantTensor(op_context.paddings)) {
    SetTensorToDynamic(op_context.output);
    return kTfLiteOk;
  }
  return ResizeOutputTensor(context, &op_context);
}

template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  SpaceToBatchNDContext op_context(context, node);

  // Resize the output tensor if it is dynamic.
  if (IsDynamicTensor(op_context.output)) {
    TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
  }

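// Expands to a call into either the reference or the optimized implementation,
// parameterized by the scalar element type; the switch below selects the
// appropriate instantiation based on the input tensor's type.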
#define TF_LITE_SPACE_TO_BATCH_ND(type, scalar)                        \
  type::SpaceToBatchND(GetTensorData<scalar>(op_context.input),        \
                       GetTensorDims(op_context.input),                \
                       GetTensorData<int32_t>(op_context.block_shape), \
                       GetTensorDims(op_context.block_shape),          \
                       GetTensorData<int32_t>(op_context.paddings),    \
                       GetTensorDims(op_context.paddings),             \
                       GetTensorData<scalar>(op_context.output),       \
                       GetTensorDims(op_context.output))
  switch (op_context.input->type) {  // Already know in/out types are same.
    case kTfLiteFloat32:
      if (kernel_type == kReference) {
        TF_LITE_SPACE_TO_BATCH_ND(reference_ops, float);
      } else {
        TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, float);
      }
      break;
    case kTfLiteUInt8:
      if (kernel_type == kReference) {
        TF_LITE_SPACE_TO_BATCH_ND(reference_ops, uint8_t);
      } else {
        TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, uint8_t);
      }
      break;
    case kTfLiteInt32:
      if (kernel_type == kReference) {
        TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int32_t);
      } else {
        TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int32_t);
      }
      break;
    case kTfLiteInt64:
      if (kernel_type == kReference) {
        TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int64_t);
      } else {
        TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int64_t);
      }
      break;
    default:
      context->ReportError(
          context, "Type is currently not supported by SpaceToBatchND.");
      return kTfLiteError;
  }
#undef TF_LITE_SPACE_TO_BATCH_ND
  return kTfLiteOk;
}

}  // namespace space_to_batch_nd

TfLiteRegistration* Register_SPACE_TO_BATCH_ND_REF() {
  static TfLiteRegistration r = {
      nullptr, nullptr, space_to_batch_nd::Prepare,
      space_to_batch_nd::Eval<space_to_batch_nd::kReference>};
  return &r;
}

TfLiteRegistration* Register_SPACE_TO_BATCH_ND_GENERIC_OPT() {
  static TfLiteRegistration r = {
      nullptr, nullptr, space_to_batch_nd::Prepare,
      space_to_batch_nd::Eval<space_to_batch_nd::kGenericOptimized>};
  return &r;
}

TfLiteRegistration* Register_SPACE_TO_BATCH_ND() {
  // return Register_SPACE_TO_BATCH_ND_REF();
  return Register_SPACE_TO_BATCH_ND_GENERIC_OPT();
}
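
// Illustrative usage (not part of this file; assumes the standard interpreter
// API of this TFLite release): the default registration above is picked up
// automatically when an interpreter is built with the builtin op resolver,
// e.g.
//   tflite::ops::builtin::BuiltinOpResolver resolver;
//   tflite::InterpreterBuilder(*model, resolver)(&interpreter);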

}  // namespace builtin
}  // namespace ops
}  // namespace tflite