Home | History | Annotate | Download | only in util
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/util/strided_slice_op.h"
     17 
     18 #include <array>
     19 #include "tensorflow/core/kernels/bounds_check.h"
     20 #include "tensorflow/core/lib/core/status.h"
     21 
     22 namespace tensorflow {
     23 namespace {
     24 
     25 /// Constants
     26 constexpr int32 kShrinkAxis = -1, kNewAxis = -2;
     27 
     28 // Sparse slicing specification
     29 // if one does foo[3:5, ..., -3], this will have 3 length tensors
     30 struct StridedSliceSparseSpec {
     31   int64 dims;
     32   int32 num_add_axis_after_ellipsis;
     33   const Tensor* begin_tensor;
     34   const Tensor* end_tensor;
     35   const Tensor& strides_tensor;
     36   const int32 begin_mask, end_mask;
     37   int32 ellipsis_mask;
     38   const int32 new_axis_mask, shrink_axis_mask;
     39 };
     40 
     41 // Dense slicing specification
     42 // all ellipses and newaxis' are expanded out. So if
     43 // foo[3:5, ..., -3] where foo is 10 dimensional,
     44 // each inlinedVector will have 10 entries whereas the
     45 // sparse had 3 length tensors.
     46 struct StridedSliceDenseSpec {
     47   const int64 dims;
     48   int32 begin_mask;
     49   int32 end_mask;
     50   bool begin_valid;
     51   bool end_valid;
     52   gtl::InlinedVector<int64, 4>& begin;
     53   gtl::InlinedVector<int64, 4>& end;
     54   gtl::InlinedVector<int64, 4>& strides;
     55   // This vector helps construct the final shape of the slice.
     56   // The final tensor is reduced in rank whenever a single index e.g. foo[3]
     57   // is called for. The final tensor increases in rank with tf.newaxis
     58   // entries. If an index in this array is positive, the size of the dimension
     59   // is obtained from canonical end-begin. Otherwise, if it is a kNewAxis,
     60   // it will be 1. A shrunk dimension is skipped.
     61   gtl::InlinedVector<int32, 4> final_shape_gather_indices;
     62   // The dense indexed shrink mask is which processing dimensions
     63   // should be shrunk. For example, if foo.shape = (10,10,10,10)
     64   // foo[3, ..., 5] has sparse_shrink_axis_mask of 0x5 and
     65   // dense_shrink_axis_mask of 0x9, yielding a final shape (10,10).
     66   int32 shrink_axis_mask;
     67 };
     68 
     69 }  // namespace
     70 
     71 template <class T>
     72 static Status TF_MUST_USE_RESULT BuildDenseSpec(
     73     const StridedSliceSparseSpec& sparse, StridedSliceDenseSpec* dense) {
     74   // Build expanded begin, end, strides, begin_mask, end_mask
     75   // to remove any ellipsis
     76   dense->begin.resize(dense->dims);
     77   dense->end.resize(dense->dims);
     78   dense->strides.resize(dense->dims);
     79   // What indices to get the final shape from.
     80   dense->begin_mask = 0;
     81   dense->end_mask = 0;
     82   dense->shrink_axis_mask = 0;
     83   {
     84     int full_index = 0;
     85 
     86     const auto& strides_flat = sparse.strides_tensor.flat<T>();
     87     dense->begin_valid = sparse.begin_tensor != nullptr;
     88     dense->end_valid = sparse.end_tensor != nullptr;
     89 
     90     for (int i = 0; i < sparse.dims; i++) {
     91       if ((1 << i) & sparse.ellipsis_mask) {
     92         // Expand the ellipsis into the appropriate indices
     93         // NOTE: this only works because we guaranteed one ellipsis
     94         int32 next_index = std::min(dense->dims - (sparse.dims - i) + 1 +
     95                                         sparse.num_add_axis_after_ellipsis,
     96                                     dense->dims);
     97         for (; full_index < next_index; full_index++) {
     98           // new_axis' aren't real axis so you have to skip
     99           dense->begin[full_index] = dense->end[full_index] = 0;
    100           dense->strides[full_index] = 1;
    101           dense->begin_mask |= (1 << full_index);
    102           dense->end_mask |= (1 << full_index);
    103           dense->final_shape_gather_indices.push_back(full_index);
    104         }
    105       } else if ((1 << i) & sparse.new_axis_mask) {
    106         dense->final_shape_gather_indices.push_back(kNewAxis);
    107       } else {
    108         if (full_index == dense->begin.size()) {
    109           return errors::InvalidArgument("Index out of range using input dim ",
    110                                          full_index, "; input has only ",
    111                                          dense->dims, " dims");
    112         }
    113 
    114         // Gather slicing spec into appropriate index
    115         if (sparse.begin_tensor != nullptr) {
    116           const auto& begin_flat = sparse.begin_tensor->flat<T>();
    117           dense->begin[full_index] = internal::SubtleMustCopy<T>(begin_flat(i));
    118         }
    119         if (sparse.end_tensor != nullptr) {
    120           const auto& end_flat = sparse.end_tensor->flat<T>();
    121           dense->end[full_index] = internal::SubtleMustCopy<T>(end_flat(i));
    122         }
    123         dense->strides[full_index] =
    124             internal::SubtleMustCopy<T>(strides_flat(i));
    125         if (sparse.begin_mask & (1 << i)) {
    126           dense->begin_mask |= (1 << full_index);
    127         }
    128         if (sparse.end_mask & (1 << i)) {
    129           dense->end_mask |= (1 << full_index);
    130         }
    131         // If shrink, record where to get the dimensionality from (i.e.
    132         // new_axis creates a fake 1 size dimension. Also remember shrink
    133         // axis (now in dense form) so we can ignore dense->end below.
    134         if (sparse.shrink_axis_mask & (1 << i)) {
    135           dense->final_shape_gather_indices.push_back(kShrinkAxis);
    136           dense->shrink_axis_mask |= (1 << full_index);
    137         } else {
    138           dense->final_shape_gather_indices.push_back(full_index);
    139         }
    140         full_index++;
    141       }
    142     }
    143   }
    144   return Status::OK();
    145 }
    146 
    147 Status ValidateStridedSliceOp(
    148     const Tensor* begin_tensor, const Tensor* end_tensor,
    149     const Tensor& strides_tensor, const PartialTensorShape& input_shape,
    150     int32 begin_mask_spec, int32 end_mask_spec, const int32 ellipsis_mask,
    151     int32 new_axis_mask, int32 shrink_axis_mask,
    152     PartialTensorShape* processing_shape, PartialTensorShape* final_shape,
    153     bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
    154     gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end,
    155     gtl::InlinedVector<int64, 4>* strides) {
    156   const bool begin_is_wrong =
    157       begin_tensor != nullptr &&
    158       !(TensorShapeUtils::IsVector(begin_tensor->shape()) &&
    159         begin_tensor->NumElements() == strides_tensor.NumElements() &&
    160         begin_tensor->NumElements() < 32 /* using 32 bit masks */);
    161   const bool end_is_wrong =
    162       end_tensor != nullptr &&
    163       !(TensorShapeUtils::IsVector(end_tensor->shape()) &&
    164         end_tensor->NumElements() == strides_tensor.NumElements());
    165   if (begin_is_wrong || end_is_wrong ||
    166       !TensorShapeUtils::IsVector(strides_tensor.shape())) {
    167     if (begin_tensor != nullptr && end_tensor != nullptr) {
    168       return errors::InvalidArgument(
    169           "Expected begin, end, and strides to be 1D equal size tensors, ",
    170           "but got shapes ", begin_tensor->shape().DebugString(), ", ",
    171           end_tensor->shape().DebugString(), ", and ",
    172           strides_tensor.shape().DebugString(), " instead.");
    173     } else {
    174       return errors::InvalidArgument(
    175           "Expected begin, end, and strides to be 1D equal size tensors, ",
    176           "but got shape ", strides_tensor.shape().DebugString(),
    177           " for strides.");
    178     }
    179   }
    180   // Use bit compares to ensure ellipsis_mask is 0 or a power of 2
    181   // i.e. there exists only no more than one ellipsis
    182   if (ellipsis_mask && ((ellipsis_mask & (ellipsis_mask - 1)) != 0)) {
    183     return errors::InvalidArgument(
    184         "Multiple ellipses in slice spec not allowed");
    185   }
    186 
    187   // Step 1: Account for ellipsis and new axis
    188   //
    189   // Check for ellipses and count how many non-newaxis' there are after
    190   // TODO(aselle): Convert this to do a fast log2 followed by iteration
    191   //               counting ones in next guys
    192   bool ellipsis_seen = false;
    193 
    194   StridedSliceSparseSpec sparse_spec = {strides_tensor.NumElements(),
    195                                         0,
    196                                         begin_tensor,
    197                                         end_tensor,
    198                                         strides_tensor,
    199                                         begin_mask_spec,
    200                                         end_mask_spec,
    201                                         ellipsis_mask,
    202                                         new_axis_mask,
    203                                         shrink_axis_mask};
    204 
    205   for (int32 i = 0; i < sparse_spec.dims; i++) {
    206     if (ellipsis_seen && ((1 << i) & new_axis_mask) != 0) {
    207       sparse_spec.num_add_axis_after_ellipsis++;
    208     }
    209     if ((1 << i) & ellipsis_mask) {
    210       ellipsis_seen = true;
    211     }
    212   }
    213   // If no ellipsis insert one at the end
    214   if (!ellipsis_seen) {
    215     sparse_spec.ellipsis_mask |= (1 << sparse_spec.dims);
    216     sparse_spec.dims++;  // this effects loop iteration below
    217   }
    218 
    219   // Step 2: Make a sparse spec into a full index spec
    220   //
    221   // The sparse spec does not correspond to the number of dimensions
    222   // Make a dense spec that corresponds to the number of dimensions
    223   //
    224   // For example suppose foo[...,3:] on foo.shape=(2,2,3) then
    225   // we need to produce the missing begin_mask for the first two
    226   // dimensions i.e. from begin_mask_spec=0, end_mask_spec=2
    227   // we achieve begin_mask=6, end_mask=7
    228   StridedSliceDenseSpec dense_spec = {input_shape.dims(),
    229                                       0 /* begin_mask */,
    230                                       0 /* end_mask */,
    231                                       false /* begin_valid */,
    232                                       false /* end_valid */,
    233                                       *begin,
    234                                       *end,
    235                                       *strides};
    236 
    237   if (strides_tensor.dtype() == DT_INT32) {
    238     TF_RETURN_IF_ERROR(BuildDenseSpec<int32>(sparse_spec, &dense_spec));
    239   } else if (strides_tensor.dtype() == DT_INT64) {
    240     TF_RETURN_IF_ERROR(BuildDenseSpec<int64>(sparse_spec, &dense_spec));
    241   } else {
    242     LOG(FATAL) << "begin must be either int32 or int64";
    243   }
    244 
    245   // Step 3: Make implicit ranges (non-zero begin_masks and end_masks) explicit
    246   //         and bounds check!
    247   *is_identity = true;
    248   *slice_dim0 = true;
    249   *is_simple_slice = true;
    250   processing_shape->Clear();
    251   for (int i = 0; i < input_shape.dims(); ++i) {
    252     int64& begin_i = (*begin)[i];
    253     int64& end_i = (*end)[i];
    254     int64& stride_i = (*strides)[i];
    255     int64 dim_i = input_shape.dim_size(i);
    256     if (stride_i == 0) {
    257       return errors::InvalidArgument("strides[", i, "] must be non-zero");
    258     }
    259     bool shrink_i = (dense_spec.shrink_axis_mask & (1 << i));
    260     if (dim_i == -1) {
    261       processing_shape->AddDim(shrink_i ? 1 : -1);
    262       continue;
    263     }
    264 
    265     const std::array<int64, 2> masks = {
    266         {dense_spec.begin_mask & (1 << i), dense_spec.end_mask & (1 << i)}};
    267     const std::array<int64, 2> valid_range = {
    268         {stride_i > 0 ? 0 : -1, stride_i > 0 ? dim_i : dim_i - 1}};
    269 
    270     auto canonical = [stride_i, i, dim_i, masks, valid_range](int64 x, int c) {
    271       if (masks[c]) {
    272         return stride_i > 0 ? valid_range[c] : valid_range[(c + 1) & 1];
    273       } else {
    274         int64 x_fwd = x < 0 ? dim_i + x : x;  // make negative indices positive
    275         return x_fwd < valid_range[0]
    276                    ? valid_range[0]
    277                    : x_fwd > valid_range[1] ? valid_range[1] : x_fwd;
    278       }
    279     };
    280     if (shrink_i && stride_i <= 0) {
    281       return errors::InvalidArgument(
    282           "only stride 1 allowed on non-range indexing.");
    283     }
    284     (*is_simple_slice) &= stride_i == 1;
    285 
    286     const bool begin_and_end_masked =
    287         (dense_spec.begin_mask & (1 << i)) && (dense_spec.end_mask & (1 << i));
    288     if (dense_spec.begin_valid && dense_spec.end_valid) {
    289       if (shrink_i) {
    290         // If we are shrinking, the end index is now possibly incorrect. In
    291         // particular foo[-1] produces sparse_begin = -1, sparse_end = 0.
    292         // and canonical puts these to n-1 and 0, which implies a degenerate
    293         // interval. Fortunately, it is now safe to re-create end as begin+1.
    294         int64 x_fwd = begin_i < 0 ? dim_i + begin_i : begin_i;
    295         begin_i = x_fwd;
    296         end_i = begin_i + 1;
    297         if (x_fwd < 0 || x_fwd >= dim_i) {
    298           return errors::InvalidArgument(
    299               "slice index ", begin_i, " of dimension ", i, " out of bounds.");
    300         }
    301       } else {
    302         begin_i = canonical(begin_i, 0);
    303         end_i = canonical(end_i, 1);
    304       }
    305       // Update optimization values
    306       bool take_all_in_dimension =
    307           stride_i == 1 && begin_i == 0 && end_i == dim_i;
    308       (*is_identity) &= take_all_in_dimension;
    309       (*slice_dim0) &= (i == 0 && stride_i == 1) || take_all_in_dimension;
    310     } else {
    311       (*is_identity) &= stride_i == 1 && begin_and_end_masked;
    312       (*slice_dim0) &= (i == 0 && stride_i == 1) || begin_and_end_masked;
    313     }
    314     // Compute the processing shape (the intermediate Eigen will produce)
    315     int64 interval_length;
    316     bool known_interval = false;
    317     if (dense_spec.begin_valid && dense_spec.end_valid) {
    318       interval_length = end_i - begin_i;
    319       known_interval = true;
    320     } else if (shrink_i) {
    321       // The dimension is still known as 1 for the processing_shape, but will be
    322       // discarded for the final shape.
    323       interval_length = 1;
    324       known_interval = true;
    325     } else if (begin_and_end_masked) {
    326       // Even if we don't have values for begin or end, we do know that this
    327       // dimension covers the whole interval. If we have shape information for
    328       // this dimension, that tells us the interval length.
    329       if (dim_i > 0) {
    330         if (stride_i < 0) {
    331           interval_length = -dim_i;
    332         } else {
    333           interval_length = dim_i;
    334         }
    335         known_interval = true;
    336       }
    337     }
    338     if (known_interval) {
    339       int64 size_i;
    340       // Hold zero if the interval is degenerate, otherwise account for
    341       // remainder
    342       if (interval_length == 0 || ((interval_length < 0) != (stride_i < 0))) {
    343         size_i = 0;
    344       } else {
    345         size_i = interval_length / stride_i +
    346                  (interval_length % stride_i != 0 ? 1 : 0);
    347       }
    348       processing_shape->AddDim(size_i);
    349     } else {
    350       processing_shape->AddDim(-1);
    351     }
    352   }
    353 
    354   // Step 4: Compute the final shape
    355   //
    356   // new_axis will increase dimension by 1 (with a one-size dimension)
    357   // slices like foo[3,...] will reduce dimension by 1.
    358   // This cannot be done earlier, because it depends on Step 3.
    359   final_shape->Clear();
    360   for (auto gather_index : dense_spec.final_shape_gather_indices) {
    361     if (gather_index >= 0) {
    362       final_shape->AddDim(processing_shape->dim_size(gather_index));
    363     } else if (gather_index == kNewAxis) {
    364       final_shape->AddDim(1);
    365     }
    366   }
    367   return Status::OK();
    368 }
    369 
    370 Status ValidateStridedSliceOp(
    371     const Tensor* begin_tensor, const Tensor* end_tensor,
    372     const Tensor& strides_tensor, const PartialTensorShape& input_shape,
    373     int32 begin_mask_spec, int32 end_mask_spec, const int32 ellipsis_mask,
    374     int32 new_axis_mask, int32 shrink_axis_mask, TensorShape* processing_shape,
    375     TensorShape* final_shape, bool* is_identity, bool* is_simple_slice,
    376     bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin,
    377     gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides) {
    378   // Validate with PartialTensorShape output
    379   PartialTensorShape partial_processing_shape, partial_final_shape;
    380   TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
    381       begin_tensor, end_tensor, strides_tensor, input_shape, begin_mask_spec,
    382       end_mask_spec, ellipsis_mask, new_axis_mask, shrink_axis_mask,
    383       &partial_processing_shape, &partial_final_shape, is_identity,
    384       is_simple_slice, slice_dim0, begin, end, strides));
    385 
    386   // Verify that the output shapes are fully known
    387   if (!partial_processing_shape.AsTensorShape(processing_shape) ||
    388       !partial_final_shape.AsTensorShape(final_shape)) {
    389     return errors::Internal("ValidateStridedSliceOp returned partial shapes ",
    390                             partial_processing_shape.DebugString(), " and ",
    391                             partial_final_shape.DebugString());
    392   }
    393   return Status::OK();
    394 }
    395 
    396 }  // namespace tensorflow
    397