1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include <string.h> 16 #include <cmath> 17 #include <vector> 18 #include "tensorflow/contrib/lite/builtin_op_data.h" 19 #include "tensorflow/contrib/lite/context.h" 20 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" 21 #include "tensorflow/contrib/lite/kernels/internal/tensor.h" 22 #include "tensorflow/contrib/lite/kernels/kernel_util.h" 23 #include "tensorflow/contrib/lite/kernels/op_macros.h" 24 25 namespace tflite { 26 namespace ops { 27 namespace builtin { 28 namespace strided_slice { 29 30 enum KernelType { 31 kReference, 32 // TODO(soroosh): add kGenericOptimized 33 }; 34 35 constexpr int kInputTensor = 0; 36 constexpr int kBeginTensor = 1; 37 constexpr int kEndTensor = 2; 38 constexpr int kStridesTensor = 3; 39 constexpr int kOutputTensor = 0; 40 41 struct StridedSliceContext { 42 StridedSliceContext(TfLiteContext* context, TfLiteNode* node) { 43 params = reinterpret_cast<TfLiteStridedSliceParams*>(node->builtin_data); 44 input = GetInput(context, node, kInputTensor); 45 begin = GetInput(context, node, kBeginTensor); 46 end = GetInput(context, node, kEndTensor); 47 strides = GetInput(context, node, kStridesTensor); 48 output = GetOutput(context, node, kOutputTensor); 49 dims = NumDimensions(input); 50 } 51 TfLiteStridedSliceParams* params; 52 TfLiteTensor* input; 53 TfLiteTensor* begin; 54 TfLiteTensor* end; 55 TfLiteTensor* strides; 56 TfLiteTensor* output; 57 int dims; 58 }; 59 60 // Reverse order of bits in the mask to match the expected order in kernel 61 inline int ReverseMaskBits(int mask, int num_dimensions) { 62 int out = 0; 63 for (int dim = 0; dim < num_dimensions; dim++) { 64 out <<= 1; 65 out += (mask & 1); 66 mask >>= 1; 67 } 68 return out; 69 } 70 71 // This Op only supports 1-4D cases and since we use the reference 4D 72 // implementation, the 1-3D tensors are mapped to 4D. 73 const int kMaxDim = 4; 74 75 inline int32_t PositiveRemainder(int32_t dividend, int32_t divisor) { 76 return (divisor + (dividend % divisor)) % divisor; 77 } 78 79 inline int32_t ClampedIndex(int32_t index, int dim, bool pos_stride) { 80 return pos_stride 81 ? (index >= dim ? dim 82 : PositiveRemainder( 83 std::min(std::max(index, -dim), dim), dim)) 84 : (index < -dim 85 ? -1 86 : PositiveRemainder( 87 std::min(std::max(index, -dim), dim - 1), dim)); 88 } 89 90 inline int32_t GetBeginValueAtIndex(StridedSliceContext* op_context, int idx) { 91 const int dim = op_context->input->dims->data[idx]; 92 const bool pos_stride = GetTensorData<int32_t>(op_context->strides)[idx] > 0; 93 return op_context->params->begin_mask & (1 << idx) 94 ? pos_stride ? 0 : dim - 1 95 : ClampedIndex(GetTensorData<int32_t>(op_context->begin)[idx], dim, 96 pos_stride); 97 } 98 99 inline int32_t GetEndValueAtIndex(StridedSliceContext* op_context, int idx) { 100 const int dim = op_context->input->dims->data[idx]; 101 const bool pos_stride = GetTensorData<int32_t>(op_context->strides)[idx] > 0; 102 return op_context->params->end_mask & (1 << idx) 103 ? pos_stride ? dim : -1 104 : ClampedIndex(GetTensorData<int32_t>(op_context->end)[idx], dim, 105 pos_stride); 106 } 107 108 // Processes the indexing tensors (begin, end and strides) to resize the 109 // output tensor. This function is callable from both Prepare() and Eval() as 110 // long as the caller ensures the indexing tensors are present. 111 TfLiteStatus ResizeOutputTensor(TfLiteContext* context, 112 StridedSliceContext* op_context) { 113 std::vector<int> output_shape_vector; 114 115 for (int idx = op_context->dims - 1; idx >= 0; --idx) { 116 int32_t stride = GetTensorData<int32_t>(op_context->strides)[idx]; 117 TF_LITE_ENSURE_MSG(context, stride != 0, "stride value has to be non-zero"); 118 119 int32_t begin = GetBeginValueAtIndex(op_context, idx); 120 int32_t end = GetEndValueAtIndex(op_context, idx); 121 122 // This is valid for both positive and negative strides 123 int32_t dim_shape = ceil((end - begin) / static_cast<float>(stride)); 124 dim_shape = dim_shape < 0 ? 0 : dim_shape; 125 if (!(op_context->params->shrink_axis_mask & (1 << idx))) { 126 output_shape_vector.push_back(dim_shape); 127 } 128 } 129 130 TfLiteIntArray* output_shape = 131 TfLiteIntArrayCreate(output_shape_vector.size()); 132 133 std::reverse_copy(output_shape_vector.begin(), output_shape_vector.end(), 134 output_shape->data); 135 136 TF_LITE_ENSURE_STATUS( 137 context->ResizeTensor(context, op_context->output, output_shape)); 138 139 return kTfLiteOk; 140 } 141 142 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { 143 TF_LITE_ENSURE_EQ(context, NumInputs(node), 4); 144 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); 145 146 StridedSliceContext op_context(context, node); 147 148 // Ensure validity of input tensor and its dimension 149 TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.begin), 1); 150 TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.end), 1); 151 TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.strides), 1); 152 TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); 153 // Only INT32 begin/end/strides are supported 154 // TODO(soroosh) add support for INT64 155 TF_LITE_ENSURE_EQ(context, op_context.begin->type, kTfLiteInt32); 156 TF_LITE_ENSURE_EQ(context, op_context.end->type, kTfLiteInt32); 157 TF_LITE_ENSURE_EQ(context, op_context.strides->type, kTfLiteInt32); 158 TF_LITE_ENSURE_MSG(context, op_context.dims <= 4, 159 "StridedSlice op only supports 1D-4D input arrays."); 160 161 // TODO(soroosh): add the following missing functionalities 162 TF_LITE_ENSURE_MSG(context, op_context.params->ellipsis_mask == 0, 163 "ellipsis_mask is not implemented yet."); 164 TF_LITE_ENSURE_MSG(context, op_context.params->new_axis_mask == 0, 165 "new_axis_mask is not implemented yet."); 166 167 // Postpone allocation of output if any of the indexing tensors is not 168 // constant 169 if (!(IsConstantTensor(op_context.begin) && 170 IsConstantTensor(op_context.end) && 171 IsConstantTensor(op_context.strides))) { 172 SetTensorToDynamic(op_context.output); 173 return kTfLiteOk; 174 } 175 return ResizeOutputTensor(context, &op_context); 176 } 177 178 template <KernelType kernel_type> 179 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { 180 StridedSliceContext op_context(context, node); 181 182 if (IsDynamicTensor(op_context.output)) { 183 TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); 184 } 185 186 std::vector<int32_t> starts; 187 std::vector<int32_t> stops; 188 std::vector<int32_t> strides; 189 190 for (int idx = op_context.dims - 1; idx >= 0; --idx) { 191 starts.emplace_back(GetBeginValueAtIndex(&op_context, idx)); 192 stops.emplace_back(GetEndValueAtIndex(&op_context, idx)); 193 strides.emplace_back(GetTensorData<int32_t>(op_context.strides)[idx]); 194 } 195 196 for (int i = op_context.dims; i < kMaxDim; i++) { 197 starts.emplace_back(0); 198 stops.emplace_back(1); 199 strides.emplace_back(1); 200 } 201 202 op_context.params->begin_mask = 203 ReverseMaskBits(op_context.params->begin_mask, op_context.dims); 204 op_context.params->end_mask = 205 ReverseMaskBits(op_context.params->end_mask, op_context.dims); 206 op_context.params->shrink_axis_mask = 207 ReverseMaskBits(op_context.params->shrink_axis_mask, op_context.dims); 208 209 #define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \ 210 kernel_type::StridedSlice( \ 211 GetTensorData<data_type>(op_context.input), \ 212 GetTensorDims(op_context.input), op_context.params->begin_mask, \ 213 op_context.params->end_mask, op_context.params->shrink_axis_mask, \ 214 starts, stops, strides, GetTensorData<data_type>(op_context.output), \ 215 GetTensorDims(op_context.output)) 216 217 switch (op_context.input->type) { 218 case kTfLiteFloat32: 219 if (kernel_type == kReference) { 220 TF_LITE_STRIDED_SLICE(reference_ops, float); 221 } 222 break; 223 case kTfLiteInt32: 224 if (kernel_type == kReference) { 225 TF_LITE_STRIDED_SLICE(reference_ops, int32_t); 226 } 227 break; 228 case kTfLiteInt64: 229 if (kernel_type == kReference) { 230 TF_LITE_STRIDED_SLICE(reference_ops, int64_t); 231 } 232 break; 233 default: 234 context->ReportError(context, 235 "Type is currently not supported " 236 "by StridedSlice."); 237 return kTfLiteError; 238 } 239 #undef TF_LITE_STRIDED_SLICE 240 return kTfLiteOk; 241 } 242 243 } // namespace strided_slice 244 245 TfLiteRegistration* Register_STRIDED_SLICE_REF() { 246 static TfLiteRegistration r = { 247 nullptr, nullptr, strided_slice::Prepare, 248 strided_slice::Eval<strided_slice::kReference>}; 249 return &r; 250 } 251 252 // TODO(soroosh): add optimized 253 TfLiteRegistration* Register_STRIDED_SLICE() { 254 return Register_STRIDED_SLICE_REF(); 255 } 256 257 } // namespace builtin 258 } // namespace ops 259 } // namespace tflite 260