1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include <string.h> 16 #include <vector> 17 #include "tensorflow/lite/c/builtin_op_data.h" 18 #include "tensorflow/lite/c/c_api_internal.h" 19 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" 20 #include "tensorflow/lite/kernels/internal/tensor.h" 21 #include "tensorflow/lite/kernels/kernel_util.h" 22 #include "tensorflow/lite/kernels/op_macros.h" 23 namespace tflite { 24 namespace ops { 25 namespace builtin { 26 namespace tile { 27 28 constexpr int kInputTensor = 0; 29 constexpr int kInputMultipliers = 1; 30 constexpr int kOutputTensor = 0; 31 32 namespace { 33 template <typename T> 34 TfLiteIntArray* MultiplyShapeDims(const TfLiteIntArray& shape, 35 const TfLiteTensor* multipliers, 36 int num_dimensions) { 37 const T* multipliers_v = GetTensorData<T>(multipliers); 38 39 TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions); 40 for (int i = 0; i < num_dimensions; ++i) { 41 output_shape->data[i] = shape.data[i] * multipliers_v[i]; 42 } 43 return output_shape; 44 } 45 46 TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) { 47 const TfLiteTensor* input = GetInput(context, node, kInputTensor); 48 TfLiteTensor* output = GetOutput(context, node, kOutputTensor); 49 const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers); 50 51 const int num_dimensions = NumDimensions(input); 52 const int num_multipliers = NumElements(multipliers); 53 TF_LITE_ENSURE_EQ(context, num_dimensions, num_multipliers); 54 switch (multipliers->type) { 55 case kTfLiteInt32: 56 return context->ResizeTensor( 57 context, output, 58 MultiplyShapeDims<int32_t>(*input->dims, multipliers, 59 num_dimensions)); 60 case kTfLiteInt64: 61 return context->ResizeTensor( 62 context, output, 63 MultiplyShapeDims<int64_t>(*input->dims, multipliers, 64 num_dimensions)); 65 default: 66 context->ReportError( 67 context, "Multipliers of type '%s' are not supported by tile.", 68 TfLiteTypeGetName(multipliers->type)); 69 return kTfLiteError; 70 } 71 } 72 73 template <typename T> 74 void CopyMultipleTimes(const T* in_data, int32_t in_size, int32_t multiplier, 75 T* out_data) { 76 for (int i = 0; i < multiplier; ++i) { 77 const T* in_end = in_data + in_size; 78 T* new_out_data = std::copy(in_data, in_end, out_data); 79 in_data = out_data; 80 out_data = new_out_data; 81 } 82 } 83 84 template <typename T, typename M> 85 std::pair<int, int> TileOneDimension(const TfLiteIntArray& in_dimensions, 86 const T* in_data, const M* multipliers, 87 T* out_data, int dimension) { 88 const int dimension_size = in_dimensions.data[dimension]; 89 if (dimension == in_dimensions.size - 1) { 90 CopyMultipleTimes(in_data, dimension_size, multipliers[dimension], 91 out_data); 92 return std::make_pair( 93 dimension_size, 94 dimension_size * static_cast<int>(multipliers[dimension])); 95 } 96 int total_stride_size = 0, total_tiled_stride_size = 0; 97 const T* copy_from_data = in_data; 98 T* copy_to_data = out_data; 99 for (int i = 0; i < dimension_size; ++i) { 100 int stride_size = 0, tiled_stride_size = 0; 101 std::tie(stride_size, tiled_stride_size) = 102 TileOneDimension(in_dimensions, copy_from_data, multipliers, 103 copy_to_data, dimension + 1); 104 copy_from_data += stride_size; 105 copy_to_data += tiled_stride_size; 106 total_stride_size += stride_size; 107 total_tiled_stride_size += tiled_stride_size; 108 } 109 CopyMultipleTimes(out_data, total_tiled_stride_size, 110 multipliers[dimension] - 1, 111 out_data + total_tiled_stride_size); 112 return std::make_pair(total_stride_size, 113 total_tiled_stride_size * multipliers[dimension]); 114 } 115 116 template <typename T> 117 void Tile(const TfLiteIntArray& in_dimensions, const TfLiteTensor* in_data, 118 const TfLiteTensor* multipliers, TfLiteTensor* out_data) { 119 // Doing recursively tiling from top to down dimension. 120 switch (multipliers->type) { 121 case kTfLiteInt32: 122 TileOneDimension(in_dimensions, GetTensorData<T>(in_data), 123 GetTensorData<int32_t>(multipliers), 124 GetTensorData<T>(out_data), 0); 125 break; 126 case kTfLiteInt64: 127 TileOneDimension(in_dimensions, GetTensorData<T>(in_data), 128 GetTensorData<int64_t>(multipliers), 129 GetTensorData<T>(out_data), 0); 130 break; 131 default: 132 break; 133 } 134 } 135 } // namespace 136 137 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { 138 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); 139 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); 140 141 const TfLiteTensor* input = GetInput(context, node, kInputTensor); 142 143 TfLiteTensor* output = GetOutput(context, node, kOutputTensor); 144 TF_LITE_ENSURE_EQ(context, input->type, output->type); 145 146 const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers); 147 // Only int32 and int64 multipliers type is supported. 148 if (multipliers->type != kTfLiteInt32 && multipliers->type != kTfLiteInt64) { 149 context->ReportError(context, 150 "Multipliers of type '%s' are not supported by tile.", 151 TfLiteTypeGetName(multipliers->type)); 152 return kTfLiteError; 153 } 154 155 if (IsConstantTensor(multipliers)) { 156 TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); 157 } else { 158 SetTensorToDynamic(output); 159 } 160 return kTfLiteOk; 161 } 162 163 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { 164 const TfLiteTensor* input = GetInput(context, node, kInputTensor); 165 TfLiteTensor* output = GetOutput(context, node, kOutputTensor); 166 const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers); 167 168 if (IsDynamicTensor(output)) { 169 TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); 170 } 171 172 switch (output->type) { 173 case kTfLiteFloat32: 174 Tile<float>(*(input->dims), input, multipliers, output); 175 break; 176 case kTfLiteUInt8: 177 Tile<uint8_t>(*(input->dims), input, multipliers, output); 178 break; 179 case kTfLiteInt32: 180 Tile<int32_t>(*(input->dims), input, multipliers, output); 181 break; 182 case kTfLiteInt64: 183 Tile<int64_t>(*(input->dims), input, multipliers, output); 184 break; 185 case kTfLiteBool: 186 Tile<bool>(*(input->dims), input, multipliers, output); 187 break; 188 default: 189 context->ReportError(context, "Type '%s' is not supported by tile.", 190 TfLiteTypeGetName(output->type)); 191 return kTfLiteError; 192 } 193 return kTfLiteOk; 194 } 195 196 } // namespace tile 197 TfLiteRegistration* Register_TILE() { 198 static TfLiteRegistration r = {nullptr, nullptr, tile::Prepare, tile::Eval}; 199 return &r; 200 } 201 } // namespace builtin 202 } // namespace ops 203 } // namespace tflite 204