/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string.h>

#include <algorithm>
#include <cmath>
#include <vector>

#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace strided_slice {

enum KernelType {
  kReference,
  // TODO(soroosh): add kGenericOptimized
};

constexpr int kInputTensor = 0;
constexpr int kBeginTensor = 1;
constexpr int kEndTensor = 2;
constexpr int kStridesTensor = 3;
constexpr int kOutputTensor = 0;

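// Gathers the op's parameters and input/output tensors for convenient access.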
struct StridedSliceContext {
  StridedSliceContext(TfLiteContext* context, TfLiteNode* node) {
    params = reinterpret_cast<TfLiteStridedSliceParams*>(node->builtin_data);
    input = GetInput(context, node, kInputTensor);
    begin = GetInput(context, node, kBeginTensor);
    end = GetInput(context, node, kEndTensor);
    strides = GetInput(context, node, kStridesTensor);
    output = GetOutput(context, node, kOutputTensor);
    dims = NumDimensions(input);
  }
  TfLiteStridedSliceParams* params;
  TfLiteTensor* input;
  TfLiteTensor* begin;
  TfLiteTensor* end;
  TfLiteTensor* strides;
  TfLiteTensor* output;
  int dims;
};

// Reverses the 'num_dimensions' low-order bits of the mask to match the
// reversed index order expected by the reference kernel.
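// For example (illustrative values): ReverseMaskBits(0b0011, 4) yields
// 0b1100 -- the bits for axes 0 and 1 become the bits for axes 3 and 2.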
inline int ReverseMaskBits(int mask, int num_dimensions) {
  int out = 0;
  for (int dim = 0; dim < num_dimensions; dim++) {
    out <<= 1;
    out += (mask & 1);
    mask >>= 1;
  }
  return out;
}

// This op only supports 1-4D cases. Since the reference implementation is
// 4D, 1-3D tensors are padded out to 4D.
const int kMaxDim = 4;

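// Returns 'dividend' modulo 'divisor' with the result in [0, divisor), i.e.
// Python-style modulo: PositiveRemainder(-1, 5) == 4.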
inline int32_t PositiveRemainder(int32_t dividend, int32_t divisor) {
  return (divisor + (dividend % divisor)) % divisor;
}

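// Clamps 'index' to the valid range for a dimension of size 'dim', mapping
// negative indices to their positive equivalents. The result lies in
// [0, dim] for a positive stride and in [-1, dim - 1] for a negative one,
// so both ends of an iteration range stay representable. With dim == 5:
// ClampedIndex(-1, 5, true) == 4 and ClampedIndex(7, 5, true) == 5.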
inline int32_t ClampedIndex(int32_t index, int dim, bool pos_stride) {
  return pos_stride
             ? (index >= dim ? dim
                             : PositiveRemainder(
                                   std::min(std::max(index, -dim), dim), dim))
             : (index < -dim
                    ? -1
                    : PositiveRemainder(
                          std::min(std::max(index, -dim), dim - 1), dim));
}

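// Returns the begin index for axis 'idx'. When the corresponding begin_mask
// bit is set, the full range is selected: 0 for a positive stride, dim - 1
// for a negative one. Otherwise the user-supplied begin value is clamped.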
inline int32_t GetBeginValueAtIndex(StridedSliceContext* op_context, int idx) {
  const int dim = op_context->input->dims->data[idx];
  const bool pos_stride = GetTensorData<int32_t>(op_context->strides)[idx] > 0;
  return op_context->params->begin_mask & (1 << idx)
             ? pos_stride ? 0 : dim - 1
             : ClampedIndex(GetTensorData<int32_t>(op_context->begin)[idx], dim,
                            pos_stride);
}

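// Returns the (exclusive) end index for axis 'idx'. When the corresponding
// end_mask bit is set, the full range is selected: dim for a positive
// stride, -1 for a negative one. Otherwise the user-supplied end value is
// clamped.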
inline int32_t GetEndValueAtIndex(StridedSliceContext* op_context, int idx) {
  const int dim = op_context->input->dims->data[idx];
  const bool pos_stride = GetTensorData<int32_t>(op_context->strides)[idx] > 0;
  return op_context->params->end_mask & (1 << idx)
             ? pos_stride ? dim : -1
             : ClampedIndex(GetTensorData<int32_t>(op_context->end)[idx], dim,
                            pos_stride);
}

// Processes the indexing tensors (begin, end, and strides) to resize the
// output tensor. This function is callable from both Prepare() and Eval() as
// long as the caller ensures the indexing tensors are present.
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
                                StridedSliceContext* op_context) {
  std::vector<int> output_shape_vector;

  for (int idx = op_context->dims - 1; idx >= 0; --idx) {
    int32_t stride = GetTensorData<int32_t>(op_context->strides)[idx];
    TF_LITE_ENSURE_MSG(context, stride != 0, "stride value has to be non-zero");

    int32_t begin = GetBeginValueAtIndex(op_context, idx);
    int32_t end = GetEndValueAtIndex(op_context, idx);

    // This computation is valid for both positive and negative strides.
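    // E.g., begin = 1, end = 7, stride = 2 selects {1, 3, 5}:
    // ceil((7 - 1) / 2) == 3. With begin = 4, end = -1, stride = -2 the
    // slice is {4, 2, 0}: ceil((-1 - 4) / -2) == ceil(2.5) == 3.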
    int32_t dim_shape = ceil((end - begin) / static_cast<float>(stride));
    dim_shape = dim_shape < 0 ? 0 : dim_shape;
    if (!(op_context->params->shrink_axis_mask & (1 << idx))) {
      output_shape_vector.push_back(dim_shape);
    }
  }

  TfLiteIntArray* output_shape =
      TfLiteIntArrayCreate(output_shape_vector.size());

  std::reverse_copy(output_shape_vector.begin(), output_shape_vector.end(),
                    output_shape->data);

  TF_LITE_ENSURE_STATUS(
      context->ResizeTensor(context, op_context->output, output_shape));

  return kTfLiteOk;
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  StridedSliceContext op_context(context, node);

  // Ensure the indexing tensors are 1-D and that input and output types match.
  TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.begin), 1);
  TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.end), 1);
  TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.strides), 1);
  TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
  // Only INT32 begin/end/strides are supported.
  // TODO(soroosh): add support for INT64.
  TF_LITE_ENSURE_EQ(context, op_context.begin->type, kTfLiteInt32);
  TF_LITE_ENSURE_EQ(context, op_context.end->type, kTfLiteInt32);
  TF_LITE_ENSURE_EQ(context, op_context.strides->type, kTfLiteInt32);
  TF_LITE_ENSURE_MSG(context, op_context.dims <= 4,
                     "StridedSlice op only supports 1D-4D input arrays.");

  // TODO(soroosh): add support for the following missing features.
  TF_LITE_ENSURE_MSG(context, op_context.params->ellipsis_mask == 0,
                     "ellipsis_mask is not implemented yet.");
  TF_LITE_ENSURE_MSG(context, op_context.params->new_axis_mask == 0,
                     "new_axis_mask is not implemented yet.");

  // Postpone allocation of the output if any of the indexing tensors is not
  // constant; in that case ResizeOutputTensor() runs at Eval() time instead.
  if (!(IsConstantTensor(op_context.begin) &&
        IsConstantTensor(op_context.end) &&
        IsConstantTensor(op_context.strides))) {
    SetTensorToDynamic(op_context.output);
    return kTfLiteOk;
  }
  return ResizeOutputTensor(context, &op_context);
}

template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  StridedSliceContext op_context(context, node);

  if (IsDynamicTensor(op_context.output)) {
    TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
  }

  std::vector<int32_t> starts;
  std::vector<int32_t> stops;
  std::vector<int32_t> strides;

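  // Build starts/stops/strides in reverse axis order to match the index
  // order expected by the reference kernel (see ReverseMaskBits above).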
  for (int idx = op_context.dims - 1; idx >= 0; --idx) {
    starts.emplace_back(GetBeginValueAtIndex(&op_context, idx));
    stops.emplace_back(GetEndValueAtIndex(&op_context, idx));
    strides.emplace_back(GetTensorData<int32_t>(op_context.strides)[idx]);
  }

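  // Pad the remaining dimensions so 1-3D inputs run through the 4D kernel;
  // a [0, 1) range with stride 1 is the identity slice for a size-1 axis.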
  for (int i = op_context.dims; i < kMaxDim; i++) {
    starts.emplace_back(0);
    stops.emplace_back(1);
    strides.emplace_back(1);
  }

  // Reverse the mask bits into local copies to match the reversed index
  // order used above. Mutating the masks in params would flip them back
  // and forth across repeated calls to Eval().
  const int begin_mask =
      ReverseMaskBits(op_context.params->begin_mask, op_context.dims);
  const int end_mask =
      ReverseMaskBits(op_context.params->end_mask, op_context.dims);
  const int shrink_axis_mask =
      ReverseMaskBits(op_context.params->shrink_axis_mask, op_context.dims);

#define TF_LITE_STRIDED_SLICE(kernel_type, data_type)                 \
  kernel_type::StridedSlice(                                          \
      GetTensorData<data_type>(op_context.input),                     \
      GetTensorDims(op_context.input), begin_mask, end_mask,          \
      shrink_axis_mask, starts, stops, strides,                       \
      GetTensorData<data_type>(op_context.output),                    \
      GetTensorDims(op_context.output))

  switch (op_context.input->type) {
    case kTfLiteFloat32:
      if (kernel_type == kReference) {
        TF_LITE_STRIDED_SLICE(reference_ops, float);
      }
      break;
    case kTfLiteInt32:
      if (kernel_type == kReference) {
        TF_LITE_STRIDED_SLICE(reference_ops, int32_t);
      }
      break;
    case kTfLiteInt64:
      if (kernel_type == kReference) {
        TF_LITE_STRIDED_SLICE(reference_ops, int64_t);
      }
      break;
    default:
      context->ReportError(context,
                           "Type is currently not supported "
                           "by StridedSlice.");
      return kTfLiteError;
  }
#undef TF_LITE_STRIDED_SLICE
  return kTfLiteOk;
}

}  // namespace strided_slice

TfLiteRegistration* Register_STRIDED_SLICE_REF() {
  static TfLiteRegistration r = {
      nullptr, nullptr, strided_slice::Prepare,
      strided_slice::Eval<strided_slice::kReference>};
  return &r;
}

// TODO(soroosh): add optimized
TfLiteRegistration* Register_STRIDED_SLICE() {
  return Register_STRIDED_SLICE_REF();
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite