1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/util/strided_slice_op.h" 17 18 #include <array> 19 #include "tensorflow/core/kernels/bounds_check.h" 20 #include "tensorflow/core/lib/core/status.h" 21 22 namespace tensorflow { 23 namespace { 24 25 /// Constants 26 constexpr int32 kShrinkAxis = -1, kNewAxis = -2; 27 28 // Sparse slicing specification 29 // if one does foo[3:5, ..., -3], this will have 3 length tensors 30 struct StridedSliceSparseSpec { 31 int64 dims; 32 int32 num_add_axis_after_ellipsis; 33 const Tensor* begin_tensor; 34 const Tensor* end_tensor; 35 const Tensor& strides_tensor; 36 const int32 begin_mask, end_mask; 37 int32 ellipsis_mask; 38 const int32 new_axis_mask, shrink_axis_mask; 39 }; 40 41 // Dense slicing specification 42 // all ellipses and newaxis' are expanded out. So if 43 // foo[3:5, ..., -3] where foo is 10 dimensional, 44 // each inlinedVector will have 10 entries whereas the 45 // sparse had 3 length tensors. 46 struct StridedSliceDenseSpec { 47 const int64 dims; 48 int32 begin_mask; 49 int32 end_mask; 50 bool begin_valid; 51 bool end_valid; 52 gtl::InlinedVector<int64, 4>& begin; 53 gtl::InlinedVector<int64, 4>& end; 54 gtl::InlinedVector<int64, 4>& strides; 55 // This vector helps construct the final shape of the slice. 56 // The final tensor is reduced in rank whenever a single index e.g. foo[3] 57 // is called for. The final tensor increases in rank with tf.newaxis 58 // entries. If an index in this array is positive, the size of the dimension 59 // is obtained from canonical end-begin. Otherwise, if it is a kNewAxis, 60 // it will be 1. A shrunk dimension is skipped. 61 gtl::InlinedVector<int32, 4> final_shape_gather_indices; 62 // The dense indexed shrink mask is which processing dimensions 63 // should be shrunk. For example, if foo.shape = (10,10,10,10) 64 // foo[3, ..., 5] has sparse_shrink_axis_mask of 0x5 and 65 // dense_shrink_axis_mask of 0x9, yielding a final shape (10,10). 66 int32 shrink_axis_mask; 67 }; 68 69 } // namespace 70 71 template <class T> 72 static Status TF_MUST_USE_RESULT BuildDenseSpec( 73 const StridedSliceSparseSpec& sparse, StridedSliceDenseSpec* dense) { 74 // Build expanded begin, end, strides, begin_mask, end_mask 75 // to remove any ellipsis 76 dense->begin.resize(dense->dims); 77 dense->end.resize(dense->dims); 78 dense->strides.resize(dense->dims); 79 // What indices to get the final shape from. 80 dense->begin_mask = 0; 81 dense->end_mask = 0; 82 dense->shrink_axis_mask = 0; 83 { 84 int full_index = 0; 85 86 const auto& strides_flat = sparse.strides_tensor.flat<T>(); 87 dense->begin_valid = sparse.begin_tensor != nullptr; 88 dense->end_valid = sparse.end_tensor != nullptr; 89 90 for (int i = 0; i < sparse.dims; i++) { 91 if ((1 << i) & sparse.ellipsis_mask) { 92 // Expand the ellipsis into the appropriate indices 93 // NOTE: this only works because we guaranteed one ellipsis 94 int32 next_index = std::min(dense->dims - (sparse.dims - i) + 1 + 95 sparse.num_add_axis_after_ellipsis, 96 dense->dims); 97 for (; full_index < next_index; full_index++) { 98 // new_axis' aren't real axis so you have to skip 99 dense->begin[full_index] = dense->end[full_index] = 0; 100 dense->strides[full_index] = 1; 101 dense->begin_mask |= (1 << full_index); 102 dense->end_mask |= (1 << full_index); 103 dense->final_shape_gather_indices.push_back(full_index); 104 } 105 } else if ((1 << i) & sparse.new_axis_mask) { 106 dense->final_shape_gather_indices.push_back(kNewAxis); 107 } else { 108 if (full_index == dense->begin.size()) { 109 return errors::InvalidArgument("Index out of range using input dim ", 110 full_index, "; input has only ", 111 dense->dims, " dims"); 112 } 113 114 // Gather slicing spec into appropriate index 115 if (sparse.begin_tensor != nullptr) { 116 const auto& begin_flat = sparse.begin_tensor->flat<T>(); 117 dense->begin[full_index] = internal::SubtleMustCopy<T>(begin_flat(i)); 118 } 119 if (sparse.end_tensor != nullptr) { 120 const auto& end_flat = sparse.end_tensor->flat<T>(); 121 dense->end[full_index] = internal::SubtleMustCopy<T>(end_flat(i)); 122 } 123 dense->strides[full_index] = 124 internal::SubtleMustCopy<T>(strides_flat(i)); 125 if (sparse.begin_mask & (1 << i)) { 126 dense->begin_mask |= (1 << full_index); 127 } 128 if (sparse.end_mask & (1 << i)) { 129 dense->end_mask |= (1 << full_index); 130 } 131 // If shrink, record where to get the dimensionality from (i.e. 132 // new_axis creates a fake 1 size dimension. Also remember shrink 133 // axis (now in dense form) so we can ignore dense->end below. 134 if (sparse.shrink_axis_mask & (1 << i)) { 135 dense->final_shape_gather_indices.push_back(kShrinkAxis); 136 dense->shrink_axis_mask |= (1 << full_index); 137 } else { 138 dense->final_shape_gather_indices.push_back(full_index); 139 } 140 full_index++; 141 } 142 } 143 } 144 return Status::OK(); 145 } 146 147 Status ValidateStridedSliceOp( 148 const Tensor* begin_tensor, const Tensor* end_tensor, 149 const Tensor& strides_tensor, const PartialTensorShape& input_shape, 150 int32 begin_mask_spec, int32 end_mask_spec, const int32 ellipsis_mask, 151 int32 new_axis_mask, int32 shrink_axis_mask, 152 PartialTensorShape* processing_shape, PartialTensorShape* final_shape, 153 bool* is_identity, bool* is_simple_slice, bool* slice_dim0, 154 gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end, 155 gtl::InlinedVector<int64, 4>* strides) { 156 const bool begin_is_wrong = 157 begin_tensor != nullptr && 158 !(TensorShapeUtils::IsVector(begin_tensor->shape()) && 159 begin_tensor->NumElements() == strides_tensor.NumElements() && 160 begin_tensor->NumElements() < 32 /* using 32 bit masks */); 161 const bool end_is_wrong = 162 end_tensor != nullptr && 163 !(TensorShapeUtils::IsVector(end_tensor->shape()) && 164 end_tensor->NumElements() == strides_tensor.NumElements()); 165 if (begin_is_wrong || end_is_wrong || 166 !TensorShapeUtils::IsVector(strides_tensor.shape())) { 167 if (begin_tensor != nullptr && end_tensor != nullptr) { 168 return errors::InvalidArgument( 169 "Expected begin, end, and strides to be 1D equal size tensors, ", 170 "but got shapes ", begin_tensor->shape().DebugString(), ", ", 171 end_tensor->shape().DebugString(), ", and ", 172 strides_tensor.shape().DebugString(), " instead."); 173 } else { 174 return errors::InvalidArgument( 175 "Expected begin, end, and strides to be 1D equal size tensors, ", 176 "but got shape ", strides_tensor.shape().DebugString(), 177 " for strides."); 178 } 179 } 180 // Use bit compares to ensure ellipsis_mask is 0 or a power of 2 181 // i.e. there exists only no more than one ellipsis 182 if (ellipsis_mask && ((ellipsis_mask & (ellipsis_mask - 1)) != 0)) { 183 return errors::InvalidArgument( 184 "Multiple ellipses in slice spec not allowed"); 185 } 186 187 // Step 1: Account for ellipsis and new axis 188 // 189 // Check for ellipses and count how many non-newaxis' there are after 190 // TODO(aselle): Convert this to do a fast log2 followed by iteration 191 // counting ones in next guys 192 bool ellipsis_seen = false; 193 194 StridedSliceSparseSpec sparse_spec = {strides_tensor.NumElements(), 195 0, 196 begin_tensor, 197 end_tensor, 198 strides_tensor, 199 begin_mask_spec, 200 end_mask_spec, 201 ellipsis_mask, 202 new_axis_mask, 203 shrink_axis_mask}; 204 205 for (int32 i = 0; i < sparse_spec.dims; i++) { 206 if (ellipsis_seen && ((1 << i) & new_axis_mask) != 0) { 207 sparse_spec.num_add_axis_after_ellipsis++; 208 } 209 if ((1 << i) & ellipsis_mask) { 210 ellipsis_seen = true; 211 } 212 } 213 // If no ellipsis insert one at the end 214 if (!ellipsis_seen) { 215 sparse_spec.ellipsis_mask |= (1 << sparse_spec.dims); 216 sparse_spec.dims++; // this effects loop iteration below 217 } 218 219 // Step 2: Make a sparse spec into a full index spec 220 // 221 // The sparse spec does not correspond to the number of dimensions 222 // Make a dense spec that corresponds to the number of dimensions 223 // 224 // For example suppose foo[...,3:] on foo.shape=(2,2,3) then 225 // we need to produce the missing begin_mask for the first two 226 // dimensions i.e. from begin_mask_spec=0, end_mask_spec=2 227 // we achieve begin_mask=6, end_mask=7 228 StridedSliceDenseSpec dense_spec = {input_shape.dims(), 229 0 /* begin_mask */, 230 0 /* end_mask */, 231 false /* begin_valid */, 232 false /* end_valid */, 233 *begin, 234 *end, 235 *strides}; 236 237 if (strides_tensor.dtype() == DT_INT32) { 238 TF_RETURN_IF_ERROR(BuildDenseSpec<int32>(sparse_spec, &dense_spec)); 239 } else if (strides_tensor.dtype() == DT_INT64) { 240 TF_RETURN_IF_ERROR(BuildDenseSpec<int64>(sparse_spec, &dense_spec)); 241 } else { 242 LOG(FATAL) << "begin must be either int32 or int64"; 243 } 244 245 // Step 3: Make implicit ranges (non-zero begin_masks and end_masks) explicit 246 // and bounds check! 247 *is_identity = true; 248 *slice_dim0 = true; 249 *is_simple_slice = true; 250 processing_shape->Clear(); 251 for (int i = 0; i < input_shape.dims(); ++i) { 252 int64& begin_i = (*begin)[i]; 253 int64& end_i = (*end)[i]; 254 int64& stride_i = (*strides)[i]; 255 int64 dim_i = input_shape.dim_size(i); 256 if (stride_i == 0) { 257 return errors::InvalidArgument("strides[", i, "] must be non-zero"); 258 } 259 bool shrink_i = (dense_spec.shrink_axis_mask & (1 << i)); 260 if (dim_i == -1) { 261 processing_shape->AddDim(shrink_i ? 1 : -1); 262 continue; 263 } 264 265 const std::array<int64, 2> masks = { 266 {dense_spec.begin_mask & (1 << i), dense_spec.end_mask & (1 << i)}}; 267 const std::array<int64, 2> valid_range = { 268 {stride_i > 0 ? 0 : -1, stride_i > 0 ? dim_i : dim_i - 1}}; 269 270 auto canonical = [stride_i, i, dim_i, masks, valid_range](int64 x, int c) { 271 if (masks[c]) { 272 return stride_i > 0 ? valid_range[c] : valid_range[(c + 1) & 1]; 273 } else { 274 int64 x_fwd = x < 0 ? dim_i + x : x; // make negative indices positive 275 return x_fwd < valid_range[0] 276 ? valid_range[0] 277 : x_fwd > valid_range[1] ? valid_range[1] : x_fwd; 278 } 279 }; 280 if (shrink_i && stride_i <= 0) { 281 return errors::InvalidArgument( 282 "only stride 1 allowed on non-range indexing."); 283 } 284 (*is_simple_slice) &= stride_i == 1; 285 286 const bool begin_and_end_masked = 287 (dense_spec.begin_mask & (1 << i)) && (dense_spec.end_mask & (1 << i)); 288 if (dense_spec.begin_valid && dense_spec.end_valid) { 289 if (shrink_i) { 290 // If we are shrinking, the end index is now possibly incorrect. In 291 // particular foo[-1] produces sparse_begin = -1, sparse_end = 0. 292 // and canonical puts these to n-1 and 0, which implies a degenerate 293 // interval. Fortunately, it is now safe to re-create end as begin+1. 294 int64 x_fwd = begin_i < 0 ? dim_i + begin_i : begin_i; 295 begin_i = x_fwd; 296 end_i = begin_i + 1; 297 if (x_fwd < 0 || x_fwd >= dim_i) { 298 return errors::InvalidArgument( 299 "slice index ", begin_i, " of dimension ", i, " out of bounds."); 300 } 301 } else { 302 begin_i = canonical(begin_i, 0); 303 end_i = canonical(end_i, 1); 304 } 305 // Update optimization values 306 bool take_all_in_dimension = 307 stride_i == 1 && begin_i == 0 && end_i == dim_i; 308 (*is_identity) &= take_all_in_dimension; 309 (*slice_dim0) &= (i == 0 && stride_i == 1) || take_all_in_dimension; 310 } else { 311 (*is_identity) &= stride_i == 1 && begin_and_end_masked; 312 (*slice_dim0) &= (i == 0 && stride_i == 1) || begin_and_end_masked; 313 } 314 // Compute the processing shape (the intermediate Eigen will produce) 315 int64 interval_length; 316 bool known_interval = false; 317 if (dense_spec.begin_valid && dense_spec.end_valid) { 318 interval_length = end_i - begin_i; 319 known_interval = true; 320 } else if (shrink_i) { 321 // The dimension is still known as 1 for the processing_shape, but will be 322 // discarded for the final shape. 323 interval_length = 1; 324 known_interval = true; 325 } else if (begin_and_end_masked) { 326 // Even if we don't have values for begin or end, we do know that this 327 // dimension covers the whole interval. If we have shape information for 328 // this dimension, that tells us the interval length. 329 if (dim_i > 0) { 330 if (stride_i < 0) { 331 interval_length = -dim_i; 332 } else { 333 interval_length = dim_i; 334 } 335 known_interval = true; 336 } 337 } 338 if (known_interval) { 339 int64 size_i; 340 // Hold zero if the interval is degenerate, otherwise account for 341 // remainder 342 if (interval_length == 0 || ((interval_length < 0) != (stride_i < 0))) { 343 size_i = 0; 344 } else { 345 size_i = interval_length / stride_i + 346 (interval_length % stride_i != 0 ? 1 : 0); 347 } 348 processing_shape->AddDim(size_i); 349 } else { 350 processing_shape->AddDim(-1); 351 } 352 } 353 354 // Step 4: Compute the final shape 355 // 356 // new_axis will increase dimension by 1 (with a one-size dimension) 357 // slices like foo[3,...] will reduce dimension by 1. 358 // This cannot be done earlier, because it depends on Step 3. 359 final_shape->Clear(); 360 for (auto gather_index : dense_spec.final_shape_gather_indices) { 361 if (gather_index >= 0) { 362 final_shape->AddDim(processing_shape->dim_size(gather_index)); 363 } else if (gather_index == kNewAxis) { 364 final_shape->AddDim(1); 365 } 366 } 367 return Status::OK(); 368 } 369 370 Status ValidateStridedSliceOp( 371 const Tensor* begin_tensor, const Tensor* end_tensor, 372 const Tensor& strides_tensor, const PartialTensorShape& input_shape, 373 int32 begin_mask_spec, int32 end_mask_spec, const int32 ellipsis_mask, 374 int32 new_axis_mask, int32 shrink_axis_mask, TensorShape* processing_shape, 375 TensorShape* final_shape, bool* is_identity, bool* is_simple_slice, 376 bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin, 377 gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides) { 378 // Validate with PartialTensorShape output 379 PartialTensorShape partial_processing_shape, partial_final_shape; 380 TF_RETURN_IF_ERROR(ValidateStridedSliceOp( 381 begin_tensor, end_tensor, strides_tensor, input_shape, begin_mask_spec, 382 end_mask_spec, ellipsis_mask, new_axis_mask, shrink_axis_mask, 383 &partial_processing_shape, &partial_final_shape, is_identity, 384 is_simple_slice, slice_dim0, begin, end, strides)); 385 386 // Verify that the output shapes are fully known 387 if (!partial_processing_shape.AsTensorShape(processing_shape) || 388 !partial_final_shape.AsTensorShape(final_shape)) { 389 return errors::Internal("ValidateStridedSliceOp returned partial shapes ", 390 partial_processing_shape.DebugString(), " and ", 391 partial_final_shape.DebugString()); 392 } 393 return Status::OK(); 394 } 395 396 } // namespace tensorflow 397