1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include <string.h> 16 #include <vector> 17 #include "tensorflow/contrib/lite/builtin_op_data.h" 18 #include "tensorflow/contrib/lite/context.h" 19 #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" 20 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" 21 #include "tensorflow/contrib/lite/kernels/internal/tensor.h" 22 #include "tensorflow/contrib/lite/kernels/kernel_util.h" 23 #include "tensorflow/contrib/lite/kernels/op_macros.h" 24 25 namespace tflite { 26 namespace ops { 27 namespace builtin { 28 namespace space_to_batch_nd { 29 30 // This file has two implementations of SpaceToBatchND. 31 enum KernelType { 32 kReference, 33 kGenericOptimized, 34 }; 35 36 struct SpaceToBatchNDContext { 37 SpaceToBatchNDContext(TfLiteContext* context, TfLiteNode* node) { 38 input = GetInput(context, node, 0); 39 block_shape = GetInput(context, node, 1); 40 paddings = GetInput(context, node, 2); 41 output = GetOutput(context, node, 0); 42 } 43 TfLiteTensor* input; 44 TfLiteTensor* block_shape; 45 TfLiteTensor* paddings; 46 TfLiteTensor* output; 47 }; 48 49 // Currently, only 4D NHWC input/output op_context are supported. 50 // The 4D array need to have exactly 2 spatial dimensions. 51 // TODO(nupurgarg): Support arbitrary dimension in SpaceToBatchND. 52 const int kInputDimensionNum = 4; 53 const int kBlockSizeDimensionNum = 1; 54 const int kSpatialDimensionNum = 2; 55 56 TfLiteStatus ResizeOutputTensor(TfLiteContext* context, 57 SpaceToBatchNDContext* op_context) { 58 TfLiteIntArray* input_size = op_context->input->dims; 59 const int32* block_shape = GetTensorData<int32>(op_context->block_shape); 60 const int32* paddings_data = GetTensorData<int32>(op_context->paddings); 61 62 TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->block_shape), 63 kBlockSizeDimensionNum); 64 TF_LITE_ENSURE_EQ(context, op_context->block_shape->dims->data[0], 65 kSpatialDimensionNum); 66 TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->paddings), 67 kSpatialDimensionNum); 68 69 TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size); 70 71 // Ensures the input height and width (with padding) is a multiple of block 72 // shape height and width. 73 for (int dim = 0; dim < kSpatialDimensionNum; ++dim) { 74 int final_dim_size = (input_size->data[dim + 1] + paddings_data[dim * 2] + 75 paddings_data[dim * 2 + 1]); 76 TF_LITE_ENSURE_EQ(context, final_dim_size % block_shape[dim], 0); 77 output_size->data[dim + 1] = final_dim_size / block_shape[dim]; 78 } 79 80 const int output_batch_size = 81 input_size->data[0] * block_shape[0] * block_shape[1]; 82 const int output_channel_size = input_size->data[3]; 83 84 output_size->data[0] = output_batch_size; 85 output_size->data[3] = output_channel_size; 86 87 return context->ResizeTensor(context, op_context->output, output_size); 88 } 89 90 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { 91 TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); 92 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); 93 94 SpaceToBatchNDContext op_context(context, node); 95 TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.input), 96 kInputDimensionNum); 97 TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); 98 99 if (!IsConstantTensor(op_context.block_shape) || 100 !IsConstantTensor(op_context.paddings)) { 101 SetTensorToDynamic(op_context.output); 102 return kTfLiteOk; 103 } 104 return ResizeOutputTensor(context, &op_context); 105 } 106 107 template <KernelType kernel_type> 108 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { 109 SpaceToBatchNDContext op_context(context, node); 110 111 // Resize the output tensor if the output tensor is dynamic. 112 if (IsDynamicTensor(op_context.output)) { 113 TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); 114 } 115 116 #define TF_LITE_SPACE_TO_BATCH_ND(type, scalar) \ 117 type::SpaceToBatchND(GetTensorData<scalar>(op_context.input), \ 118 GetTensorDims(op_context.input), \ 119 GetTensorData<int32_t>(op_context.block_shape), \ 120 GetTensorDims(op_context.block_shape), \ 121 GetTensorData<int32_t>(op_context.paddings), \ 122 GetTensorDims(op_context.paddings), \ 123 GetTensorData<scalar>(op_context.output), \ 124 GetTensorDims(op_context.output)) 125 switch (op_context.input->type) { // Already know in/out types are same. 126 case kTfLiteFloat32: 127 if (kernel_type == kReference) { 128 TF_LITE_SPACE_TO_BATCH_ND(reference_ops, float); 129 } else { 130 TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, float); 131 } 132 break; 133 case kTfLiteUInt8: 134 if (kernel_type == kReference) { 135 TF_LITE_SPACE_TO_BATCH_ND(reference_ops, uint8_t); 136 } else { 137 TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, uint8_t); 138 } 139 break; 140 case kTfLiteInt32: 141 if (kernel_type == kReference) { 142 TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int32_t); 143 } else { 144 TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int32_t); 145 } 146 break; 147 case kTfLiteInt64: 148 if (kernel_type == kReference) { 149 TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int64_t); 150 } else { 151 TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int64_t); 152 } 153 break; 154 default: 155 context->ReportError(context, 156 "Type is currently not supported by SpaceToBatch."); 157 return kTfLiteError; 158 } 159 #undef TF_LITE_SPACE_TO_BATCH_ND 160 return kTfLiteOk; 161 } 162 163 } // namespace space_to_batch_nd 164 165 TfLiteRegistration* Register_SPACE_TO_BATCH_ND_REF() { 166 static TfLiteRegistration r = { 167 nullptr, nullptr, space_to_batch_nd::Prepare, 168 space_to_batch_nd::Eval<space_to_batch_nd::kReference>}; 169 return &r; 170 } 171 172 TfLiteRegistration* Register_SPACE_TO_BATCH_ND_GENERIC_OPT() { 173 static TfLiteRegistration r = { 174 nullptr, nullptr, space_to_batch_nd::Prepare, 175 space_to_batch_nd::Eval<space_to_batch_nd::kGenericOptimized>}; 176 return &r; 177 } 178 179 TfLiteRegistration* Register_SPACE_TO_BATCH_ND() { 180 // return Register_SPACE_TO_BATCH_ND_REF(); 181 return Register_SPACE_TO_BATCH_ND_GENERIC_OPT(); 182 } 183 184 } // namespace builtin 185 } // namespace ops 186 } // namespace tflite 187