Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include <string.h>
     16 #include <vector>
     17 #include "tensorflow/lite/c/builtin_op_data.h"
     18 #include "tensorflow/lite/c/c_api_internal.h"
     19 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
     20 #include "tensorflow/lite/kernels/internal/tensor.h"
     21 #include "tensorflow/lite/kernels/kernel_util.h"
     22 #include "tensorflow/lite/kernels/op_macros.h"
     23 namespace tflite {
     24 namespace ops {
     25 namespace builtin {
     26 namespace tile {
     27 
     28 constexpr int kInputTensor = 0;
     29 constexpr int kInputMultipliers = 1;
     30 constexpr int kOutputTensor = 0;
     31 
     32 namespace {
     33 template <typename T>
     34 TfLiteIntArray* MultiplyShapeDims(const TfLiteIntArray& shape,
     35                                   const TfLiteTensor* multipliers,
     36                                   int num_dimensions) {
     37   const T* multipliers_v = GetTensorData<T>(multipliers);
     38 
     39   TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions);
     40   for (int i = 0; i < num_dimensions; ++i) {
     41     output_shape->data[i] = shape.data[i] * multipliers_v[i];
     42   }
     43   return output_shape;
     44 }
     45 
     46 TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
     47   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
     48   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
     49   const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers);
     50 
     51   const int num_dimensions = NumDimensions(input);
     52   const int num_multipliers = NumElements(multipliers);
     53   TF_LITE_ENSURE_EQ(context, num_dimensions, num_multipliers);
     54   switch (multipliers->type) {
     55     case kTfLiteInt32:
     56       return context->ResizeTensor(
     57           context, output,
     58           MultiplyShapeDims<int32_t>(*input->dims, multipliers,
     59                                      num_dimensions));
     60     case kTfLiteInt64:
     61       return context->ResizeTensor(
     62           context, output,
     63           MultiplyShapeDims<int64_t>(*input->dims, multipliers,
     64                                      num_dimensions));
     65     default:
     66       context->ReportError(
     67           context, "Multipliers of type '%s' are not supported by tile.",
     68           TfLiteTypeGetName(multipliers->type));
     69       return kTfLiteError;
     70   }
     71 }
     72 
     73 template <typename T>
     74 void CopyMultipleTimes(const T* in_data, int32_t in_size, int32_t multiplier,
     75                        T* out_data) {
     76   for (int i = 0; i < multiplier; ++i) {
     77     const T* in_end = in_data + in_size;
     78     T* new_out_data = std::copy(in_data, in_end, out_data);
     79     in_data = out_data;
     80     out_data = new_out_data;
     81   }
     82 }
     83 
     84 template <typename T, typename M>
     85 std::pair<int, int> TileOneDimension(const TfLiteIntArray& in_dimensions,
     86                                      const T* in_data, const M* multipliers,
     87                                      T* out_data, int dimension) {
     88   const int dimension_size = in_dimensions.data[dimension];
     89   if (dimension == in_dimensions.size - 1) {
     90     CopyMultipleTimes(in_data, dimension_size, multipliers[dimension],
     91                       out_data);
     92     return std::make_pair(
     93         dimension_size,
     94         dimension_size * static_cast<int>(multipliers[dimension]));
     95   }
     96   int total_stride_size = 0, total_tiled_stride_size = 0;
     97   const T* copy_from_data = in_data;
     98   T* copy_to_data = out_data;
     99   for (int i = 0; i < dimension_size; ++i) {
    100     int stride_size = 0, tiled_stride_size = 0;
    101     std::tie(stride_size, tiled_stride_size) =
    102         TileOneDimension(in_dimensions, copy_from_data, multipliers,
    103                          copy_to_data, dimension + 1);
    104     copy_from_data += stride_size;
    105     copy_to_data += tiled_stride_size;
    106     total_stride_size += stride_size;
    107     total_tiled_stride_size += tiled_stride_size;
    108   }
    109   CopyMultipleTimes(out_data, total_tiled_stride_size,
    110                     multipliers[dimension] - 1,
    111                     out_data + total_tiled_stride_size);
    112   return std::make_pair(total_stride_size,
    113                         total_tiled_stride_size * multipliers[dimension]);
    114 }
    115 
    116 template <typename T>
    117 void Tile(const TfLiteIntArray& in_dimensions, const TfLiteTensor* in_data,
    118           const TfLiteTensor* multipliers, TfLiteTensor* out_data) {
    119   // Doing recursively tiling from top to down dimension.
    120   switch (multipliers->type) {
    121     case kTfLiteInt32:
    122       TileOneDimension(in_dimensions, GetTensorData<T>(in_data),
    123                        GetTensorData<int32_t>(multipliers),
    124                        GetTensorData<T>(out_data), 0);
    125       break;
    126     case kTfLiteInt64:
    127       TileOneDimension(in_dimensions, GetTensorData<T>(in_data),
    128                        GetTensorData<int64_t>(multipliers),
    129                        GetTensorData<T>(out_data), 0);
    130       break;
    131     default:
    132       break;
    133   }
    134 }
    135 }  // namespace
    136 
    137 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    138   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
    139   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
    140 
    141   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
    142 
    143   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
    144   TF_LITE_ENSURE_EQ(context, input->type, output->type);
    145 
    146   const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers);
    147   // Only int32 and int64 multipliers type is supported.
    148   if (multipliers->type != kTfLiteInt32 && multipliers->type != kTfLiteInt64) {
    149     context->ReportError(context,
    150                          "Multipliers of type '%s' are not supported by tile.",
    151                          TfLiteTypeGetName(multipliers->type));
    152     return kTfLiteError;
    153   }
    154 
    155   if (IsConstantTensor(multipliers)) {
    156     TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
    157   } else {
    158     SetTensorToDynamic(output);
    159   }
    160   return kTfLiteOk;
    161 }
    162 
    163 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
    164   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
    165   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
    166   const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers);
    167 
    168   if (IsDynamicTensor(output)) {
    169     TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
    170   }
    171 
    172   switch (output->type) {
    173     case kTfLiteFloat32:
    174       Tile<float>(*(input->dims), input, multipliers, output);
    175       break;
    176     case kTfLiteUInt8:
    177       Tile<uint8_t>(*(input->dims), input, multipliers, output);
    178       break;
    179     case kTfLiteInt32:
    180       Tile<int32_t>(*(input->dims), input, multipliers, output);
    181       break;
    182     case kTfLiteInt64:
    183       Tile<int64_t>(*(input->dims), input, multipliers, output);
    184       break;
    185     case kTfLiteBool:
    186       Tile<bool>(*(input->dims), input, multipliers, output);
    187       break;
    188     default:
    189       context->ReportError(context, "Type '%s' is not supported by tile.",
    190                            TfLiteTypeGetName(output->type));
    191       return kTfLiteError;
    192   }
    193   return kTfLiteOk;
    194 }
    195 
    196 }  // namespace tile
    197 TfLiteRegistration* Register_TILE() {
    198   static TfLiteRegistration r = {nullptr, nullptr, tile::Prepare, tile::Eval};
    199   return &r;
    200 }
    201 }  // namespace builtin
    202 }  // namespace ops
    203 }  // namespace tflite
    204