/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_
#define TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_

#include <limits>
#include <numeric>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/sparse/dim_comparator.h"
#include "tensorflow/core/util/sparse/group_iterator.h"

namespace tensorflow {
namespace sparse {

class SparseTensor {
 public:
  typedef typename gtl::ArraySlice<int64> VarDimArray;
  typedef typename gtl::InlinedVector<int64, 8> ShapeArray;

  SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape)
      : SparseTensor(ix, vals, TensorShapeToVector(shape),
                     UndefinedOrder(TensorShapeToVector(shape))) {}

  SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape)
      : SparseTensor(ix, vals, shape, UndefinedOrder(shape)) {}

  SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape,
               const VarDimArray order)
      : SparseTensor(ix, vals, TensorShapeToVector(shape), order) {}

  SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape,
               const VarDimArray order)
      : ix_(ix),
        vals_(vals),
        shape_(shape.begin(), shape.end()),
        order_(order.begin(), order.end()),
        dims_(GetDimsFromIx(ix)) {
    CHECK_EQ(ix.dtype(), DT_INT64)
        << "indices must be type int64 but got: " << ix.dtype();
    CHECK(TensorShapeUtils::IsVector(vals.shape()))
        << "vals must be a vec, but got: " << vals.shape().DebugString();
    CHECK_EQ(ix.shape().dim_size(0), vals.shape().dim_size(0))
        << "indices and values rows (indexing dimension) must match.";
    CHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank.";
    CHECK_EQ(shape.size(), dims_) << "Shape rank must be SparseTensor rank.";
  }
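
  // Example (illustrative sketch, not from the original header): build a
  // 3x4 SparseTensor holding {[0,0]=1.5f, [2,3]=2.5f} with a row-major
  // (lexicographic) order:
  //
  //   Tensor ix(DT_INT64, TensorShape({2, 2}));
  //   auto ix_m = ix.matrix<int64>();
  //   ix_m(0, 0) = 0; ix_m(0, 1) = 0;
  //   ix_m(1, 0) = 2; ix_m(1, 1) = 3;
  //   Tensor vals(DT_FLOAT, TensorShape({2}));
  //   auto vals_v = vals.vec<float>();
  //   vals_v(0) = 1.5f; vals_v(1) = 2.5f;
  //   SparseTensor st(ix, vals, TensorShape({3, 4}), {0, 1});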

  SparseTensor(const SparseTensor& other)
      : SparseTensor(other.ix_, other.vals_, other.shape_, other.order_) {}

  SparseTensor(SparseTensor&& other)
      : SparseTensor(std::move(other.ix_), std::move(other.vals_),
                     std::move(other.shape_), std::move(other.order_)) {}

  SparseTensor& operator=(const SparseTensor& other) {
    ix_ = other.ix_;
    vals_ = other.vals_;
    shape_ = other.shape_;
    order_ = other.order_;
    return *this;
  }

  std::size_t num_entries() const { return ix_.dim_size(0); }

  int dims() const { return shape_.size(); }

  const Tensor& indices() const { return ix_; }

  const Tensor& values() const { return vals_; }

  DataType dtype() const { return vals_.dtype(); }

  Status IndicesValid() const {
    const auto ix_t = ix_.matrix<int64>();
    for (int64 ord : order_) {
      if (ord < 0) {
        return errors::FailedPrecondition(
            "Order was not provided.  Provide an order at "
            "construction time or run Reorder<T>()");
      }
    }

    for (std::size_t n = 0; n < num_entries(); ++n) {
      TF_RETURN_IF_ERROR(IndexValid(ix_t, n));
    }

    return Status::OK();
  }

  VarDimArray shape() const { return shape_; }

  VarDimArray order() const { return order_; }

  // Re-sorts the indices and values according to the dimensions in order.
  template <typename T>
  void Reorder(const VarDimArray& order);
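
  // Example (illustrative; `st` is the hypothetical tensor from the
  // constructor example above): sort into lexicographic (row-major) order,
  // then validate:
  //
  //   st.Reorder<float>({0, 1});
  //   TF_CHECK_OK(st.IndicesValid());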

  // Returns a group iterable that can be used for clumping indices
  // and values according to the group indices of interest.
  //
  // Precondition: order()[0..group_ix.size()] == group_ix.
  //
  // See the README.md in this directory for more usage information.
  GroupIterable group(const VarDimArray& group_ix) const {
    CHECK_LE(group_ix.size(), dims_);
    for (std::size_t di = 0; di < group_ix.size(); ++di) {
      CHECK_GE(group_ix[di], 0) << "Group dimension out of range";
      CHECK_LT(group_ix[di], dims_) << "Group dimension out of range";
      CHECK_EQ(group_ix[di], order_[di])
          << "Group dimension does not match sorted order";
    }
    return GroupIterable(ix_, vals_, dims_, group_ix);
  }
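
  // Example (illustrative sketch, assuming the Group API described in the
  // README.md in this directory): iterate over groups keyed by dimension 0
  // of a reordered rank-2 SparseTensor holding floats:
  //
  //   st.Reorder<float>({0, 1});
  //   for (const auto& g : st.group({0})) {
  //     const int64 row = g.group()[0];   // the shared dim-0 value
  //     auto g_ix = g.indices();          // indices of entries in this group
  //     auto g_vals = g.values<float>();  // values of entries in this group
  //   }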

  // Stores the sparse indices into the dense tensor out.
  // Preconditions:
  //   out->shape().dims() == dims()
  //   out->shape().dim_size(d) >= shape()[d] for all d
  //
  // Returns true on success, false on failure (mismatched dimensions
  // or out-of-bounds indices).
  //
  // If initialize == true, ToDense first overwrites all coefficients in out
  // with 0.
  //
  template <typename T>
  bool ToDense(Tensor* out, bool initialize = true);
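
  // Example (illustrative): scatter `st` into a zero-initialized 3x4 dense
  // tensor, checking for failure:
  //
  //   Tensor dense(DT_FLOAT, TensorShape({3, 4}));
  //   if (!st.ToDense<float>(&dense, /*initialize=*/true)) {
  //     // Mismatched dimensions or out-of-bounds indices.
  //   }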

  // Concat() concatenates all the tensors along their primary order
  // dimension (order()[0]).  All tensors must have identical shapes except
  // on that dimension, and all tensors' order()[0] must match, since it is
  // the concat dimension.
  //
  // If all of the tensors have identical ordering, then the output
  // will have this ordering.  Otherwise the output is marked as
  // having no order, and Reorder<T>() should be called on it before
  // performing any subsequent operations.
  template <typename T>
  static SparseTensor Concat(const gtl::ArraySlice<SparseTensor>& tensors);
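
  // Example (illustrative; st_a and st_b are hypothetical 3x4 SparseTensors
  // of floats, both ordered on dimension 0): stack them into a 6x4 result:
  //
  //   SparseTensor both = SparseTensor::Concat<float>({st_a, st_b});
  //   // both.shape() == {6, 4}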

  // Split() splits the input SparseTensor into a list of num_split
  // SparseTensors along dimension split_dim.  If the size of dimension
  // split_dim is not an integer multiple of num_split, the first
  // (size % num_split) slices are each one element larger along split_dim.
  template <typename T>
  static std::vector<SparseTensor> Split(const SparseTensor& tensor,
                                         const int split_dim,
                                         const int num_split);
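
  // Example (illustrative): splitting a hypothetical 10x4 SparseTensor of
  // floats into 3 pieces along dimension 0 yields slices of shape 4x4,
  // 3x4, and 3x4:
  //
  //   std::vector<SparseTensor> parts =
  //       SparseTensor::Split<float>(st10x4, /*split_dim=*/0, /*num_split=*/3);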

  // Slice() slices the input SparseTensor into a SparseTensor based on the
  // specified start and size.  Both start and size are 1-D arrays with one
  // element per dimension: start[d] is the starting index and size[d] is
  // the slice size along dimension d.
  template <typename T>
  static SparseTensor Slice(const SparseTensor& tensor,
                            const gtl::ArraySlice<int64>& start,
                            const gtl::ArraySlice<int64>& size);
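
  // Example (illustrative): take rows 1..2 and columns 0..2 of a 3x4
  // SparseTensor of floats; indices in the result are re-based to the
  // start of the slice:
  //
  //   SparseTensor window =
  //       SparseTensor::Slice<float>(st, /*start=*/{1, 0}, /*size=*/{2, 3});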

  // Picks out the dimensions according to `dim_indices`.
  std::vector<int64> PickDims(gtl::ArraySlice<int64> dim_indices) const {
    std::vector<int64> res(dim_indices.size());
    for (size_t i = 0; i < dim_indices.size(); ++i) {
      res[i] = shape_[dim_indices[i]];
    }
    return res;
  }
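
  // Example (illustrative): with shape() == {3, 4, 5}, PickDims({0, 2})
  // returns {3, 5}.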

 private:
  static int GetDimsFromIx(const Tensor& ix) {
    CHECK(TensorShapeUtils::IsMatrix(ix.shape()))
        << "indices must be a matrix, but got: " << ix.shape().DebugString();
    return ix.dim_size(1);
  }

  static inline ShapeArray UndefinedOrder(const VarDimArray shape) {
    return ShapeArray(shape.size(), -1);
  }

  static inline ShapeArray TensorShapeToVector(const TensorShape& shape) {
    ShapeArray vec(shape.dims());
    for (int i = 0; i < shape.dims(); ++i) vec[i] = shape.dim_size(i);
    return vec;
  }

  // Helper for IndicesValid()
  inline Status IndexValid(const TTypes<int64>::ConstMatrix& ix_t,
                           int n) const {
    bool valid = true;
    bool different = false;
    bool increasing = true;
    if (n == 0) {
      for (int di = 0; di < dims_; ++di) {
        if (ix_t(n, di) < 0 || ix_t(n, di) >= shape_[di]) valid = false;
      }
      different = true;
    } else {
      for (int di = 0; di < dims_; ++di) {
        if (ix_t(n, di) < 0 || ix_t(n, di) >= shape_[di]) valid = false;
        int64 diff = ix_t(n, order_[di]) - ix_t(n - 1, order_[di]);
        if (diff > 0) different = true;
        if (!different && diff < 0) increasing = false;
      }
    }
    if (TF_PREDICT_FALSE(!valid || !increasing || !different)) {
      string index = strings::StrCat("indices[", n, "] = [");
      for (int di = 0; di < dims_; ++di) {
        strings::StrAppend(&index, ix_t(n, di), di < dims_ - 1 ? "," : "]");
      }
      if (!valid) {
        return errors::InvalidArgument(index,
                                       " is out of bounds: need 0 <= index < [",
                                       str_util::Join(shape_, ","), "]");
      }
      if (!increasing) {
        return errors::InvalidArgument(index, " is out of order");
      }
      if (!different) {
        return errors::InvalidArgument(index, " is repeated");
      }
    }
    return Status::OK();
  }

  // Helper for ToDense<T>()
  template <typename T>
  bool ValidateAndInitializeToDense(Tensor* out, bool initialize);

  // Helper for Split() that returns the slice index.
  static inline int GetSliceIndex(const int dim, const int split_size,
                                  const int residual) {
    CHECK_GT(split_size, 0);
    CHECK_GE(dim, 0);
    if (residual == 0) return dim / split_size;
    const int offset = residual * (split_size + 1);
    if (dim < offset) {
      return dim / (split_size + 1);
    } else {
      return residual + ((dim - offset) / split_size);
    }
  }

  // Helper for Split() that returns the dimension in the slice.
  static inline int GetDimensionInSlice(const int dim, const int split_size,
                                        const int residual) {
    CHECK_GT(split_size, 0);
    CHECK_GE(dim, 0);
    if (residual == 0) return dim % split_size;
    const int offset = residual * (split_size + 1);
    if (dim < offset) {
      return dim % (split_size + 1);
    } else {
      return (dim - offset) % split_size;
    }
  }

  // Helper for Split() that returns the shape given a slice index.
  static inline int GetSliceShape(const int slice_index, const int split_size,
                                  const int residual) {
    CHECK_GT(split_size, 0);
    CHECK_GE(slice_index, 0);
    if (residual == 0) return split_size;
    if (slice_index < residual) {
      return split_size + 1;
    } else {
      return split_size;
    }
  }
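
  // Worked example (illustrative) for the three Split() helpers above:
  // splitting a dimension of size 10 into num_split = 3 pieces gives
  // split_size = 10 / 3 = 3 and residual = 10 % 3 = 1, so the slice shapes
  // are {4, 3, 3}.  For an index dim = 5: offset = 1 * (3 + 1) = 4, and
  // since 5 >= offset, GetSliceIndex(5, 3, 1) = 1 + (5 - 4) / 3 = 1 and
  // GetDimensionInSlice(5, 3, 1) = (5 - 4) % 3 = 1; i.e. index 5 lands at
  // position 1 of slice 1.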

  Tensor ix_;
  Tensor vals_;
  ShapeArray shape_;
  ShapeArray order_;
  const int dims_;
};

// This operation updates the indices and values Tensor rows, so it is
// an in-place algorithm.  It requires O(N log N) time and O(N)
// temporary space.
template <typename T>
void SparseTensor::Reorder(const VarDimArray& order) {
  CHECK_EQ(DataTypeToEnum<T>::v(), dtype())
      << "Reorder requested with the wrong datatype";
  CHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank";
  auto ix_t = ix_.matrix<int64>();
  auto vals_t = vals_.vec<T>();

  std::vector<int64> reorder(num_entries());
  std::iota(reorder.begin(), reorder.end(), 0);

  // Sort to get order of indices
  switch (order.size()) {
#define CASE_SORT(ORDER_SIZE)                                    \
  case ORDER_SIZE: {                                             \
    FixedDimComparator<ORDER_SIZE> sorter(ix_t, order, shape()); \
    std::sort(reorder.begin(), reorder.end(), sorter);           \
    break;                                                       \
  }
    CASE_SORT(0);
    CASE_SORT(1);
    CASE_SORT(2);
    CASE_SORT(3);
    CASE_SORT(4);
    CASE_SORT(5);
#undef CASE_SORT
    default: {
      DimComparator sorter(ix_t, order, shape());
      std::sort(reorder.begin(), reorder.end(), sorter);
    }
  }

  // We have a forward reordering, but what we'll need is a permutation
  // (the inverse).  This can be calculated with O(1) additional space and
  // O(n) time (INVPERM), but we just do the simple thing here.
  std::vector<size_t> permutation(reorder.size());
  for (std::size_t n = 0; n < reorder.size(); ++n) {
    permutation[reorder[n]] = n;
  }
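
  // For example (illustrative): if reorder = {2, 0, 1} (entry 2 sorts
  // first), then permutation = {1, 2, 0}: entry 0 must move to position 1,
  // entry 1 to position 2, and entry 2 to position 0.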

  // Update indices & values by converting the permutations to
  // a product of transpositions.  Iterate over the cycles in the
  // permutation, and convert each of those into a product of
  // transpositions (swaps):
  //   https://en.wikipedia.org/wiki/Cyclic_permutation
  // This is N swaps, 2*N comparisons.
  for (std::size_t n = 0; n + 1 < permutation.size(); ++n) {
    while (n != permutation[n]) {
      std::size_t r = permutation[n];
      std::swap_ranges(&(ix_t(n, 0)), &(ix_t(n + 1, 0)), &(ix_t(r, 0)));
      std::swap(vals_t(n), vals_t(r));
      std::swap(permutation[n], permutation[r]);
    }
  }

  order_ = ShapeArray(order.begin(), order.end());
}

template <typename T>
bool SparseTensor::ValidateAndInitializeToDense(Tensor* out, bool initialize) {
  CHECK_EQ(DataTypeToEnum<T>::v(), dtype())
      << "ToDense requested with the wrong datatype";

  CHECK_EQ(out->shape().dims(), dims_)
      << "Incompatible dimensions between SparseTensor and output";

  CHECK_EQ(out->dtype(), DataTypeToEnum<T>::v())
      << "Output must be type: " << DataTypeToEnum<T>::v()
      << " but got: " << out->dtype();

  // Make sure the dense output is the same rank and has room
  // to hold the SparseTensor.
  const auto& out_shape = out->shape();
  if (shape_.size() != out_shape.dims()) return false;
  for (int d = 0; d < shape_.size(); ++d) {
    if (shape_[d] > out_shape.dim_size(d)) return false;
  }

  if (initialize) {
    auto out_t = out->flat<T>();
    out_t.setConstant(T());
  }

  return true;
}

template <typename T>
bool SparseTensor::ToDense(Tensor* out, bool initialize) {
  if (!ValidateAndInitializeToDense<T>(out, initialize)) return false;

  auto out_t = out->flat<T>();
  auto ix_t = ix_.matrix<int64>();
  auto vals_t = vals_.vec<T>();

  std::vector<int64> strides(dims_);
  const auto& out_shape = out->shape();
  if (dims_ > 0) {
    strides[dims_ - 1] = 1;
  }
  for (int d = dims_ - 2; d >= 0; --d) {
    strides[d] = strides[d + 1] * out_shape.dim_size(d + 1);
  }

  for (int n = 0; n < vals_t.dimension(0); ++n) {
    bool invalid_dims = false;
    int64 ix = 0;
    for (int d = 0; d < dims_; ++d) {
      const int64 ix_n_d = internal::SubtleMustCopy(ix_t(n, d));
      if (!FastBoundsCheck(ix_n_d, out_shape.dim_size(d))) {
        invalid_dims = true;
      }
      ix += strides[d] * ix_n_d;
    }
    if (invalid_dims) return false;
    out_t(ix) = vals_t(n);
  }
  return true;
}

template <typename T>
SparseTensor SparseTensor::Concat(
    const gtl::ArraySlice<SparseTensor>& tensors) {
  CHECK_GE(tensors.size(), size_t{1}) << "Cannot concat 0 SparseTensors";
  const int dims = tensors[0].dims_;
  CHECK_GE(dims, 1) << "Cannot concat 0-dimensional SparseTensors";
  auto order_0 = tensors[0].order();
  const int primary_dim = order_0[0];
  ShapeArray final_order(order_0.begin(), order_0.end());
  ShapeArray final_shape(tensors[0].shape().begin(), tensors[0].shape().end());
  final_shape[primary_dim] = 0;  // We'll build this up as we go along.
  int num_entries = 0;

  bool fully_ordered = true;
  for (const SparseTensor& st : tensors) {
    CHECK_EQ(st.dims_, dims) << "All SparseTensors must have the same rank.";
    CHECK_EQ(DataTypeToEnum<T>::v(), st.dtype())
        << "Concat requested with the wrong data type";
    CHECK_GE(st.order()[0], 0) << "SparseTensor must be ordered";
    CHECK_EQ(st.order()[0], primary_dim)
        << "All SparseTensors' order[0] must match.  This is the concat dim.";
    if (st.order() != final_order) fully_ordered = false;
    const VarDimArray& st_shape = st.shape();
    for (int d = 0; d < dims - 1; ++d) {
      const int cdim = (d < primary_dim) ? d : d + 1;
      CHECK_EQ(final_shape[cdim], st_shape[cdim])
          << "All SparseTensors' shapes must match except on the concat dim.  "
          << "Concat dim: " << primary_dim
          << ", mismatched shape at dim: " << cdim
          << ".  Expecting shape like: [" << str_util::Join(final_shape, ",")
          << "] but saw shape: [" << str_util::Join(st_shape, ",") << "]";
    }

    // Update dimension of final shape
    final_shape[primary_dim] =
        (final_shape[primary_dim] + st_shape[primary_dim]);

    num_entries += st.num_entries();  // Update number of entries
  }

  // If the ordering is inconsistent among the inputs, mark the final order
  // as undefined (all -1s).
  if (!fully_ordered) {
    final_order = UndefinedOrder(final_shape);
  }

  Tensor output_ix(DT_INT64, TensorShape({num_entries, dims}));
  Tensor output_vals(DataTypeToEnum<T>::v(), TensorShape({num_entries}));

  TTypes<int64>::Matrix ix_t = output_ix.matrix<int64>();
  typename TTypes<T>::Vec vals_t = output_vals.vec<T>();

  Eigen::DenseIndex offset = 0;
  int64 shape_offset = 0;
  for (const SparseTensor& st : tensors) {
    const int st_num_entries = st.num_entries();

    // Fill in indices & values.
    std::copy_n(&st.vals_.vec<T>()(0), st_num_entries, &vals_t(offset));

    const auto* st_ix = &st.ix_.matrix<int64>()(0, 0);
    auto* ix_out = &ix_t(offset, 0);
    for (std::size_t i = 0; i < st_num_entries * dims; ++i) {
      *ix_out++ = *st_ix++ + ((i % dims == primary_dim) ? shape_offset : 0);
    }

    offset += st_num_entries;
    shape_offset += st.shape()[primary_dim];
  }

  return SparseTensor(output_ix, output_vals, final_shape, final_order);
}

template <typename T>
std::vector<SparseTensor> SparseTensor::Split(const SparseTensor& input_tensor,
                                              const int split_dim,
                                              const int num_split) {
  std::vector<Tensor> output_indices;
  std::vector<Tensor> output_values;
  std::vector<TensorShape> output_shapes;
  output_indices.reserve(num_split);
  output_values.reserve(num_split);
  output_shapes.reserve(num_split);

  std::vector<typename TTypes<int64>::Matrix> output_indices_t;
  std::vector<typename TTypes<T>::Vec> output_values_t;
  output_indices_t.reserve(num_split);
  output_values_t.reserve(num_split);
  auto input_values_t = input_tensor.values().vec<T>();
  auto input_indices_t = input_tensor.indices().matrix<int64>();

  std::vector<int> num_values(num_split, 0);
  const int num_dim = input_tensor.shape().size();
  const int split_dim_size = input_tensor.shape()[split_dim];
  const int split_size = split_dim_size / num_split;
  CHECK(num_split > 0 && num_split <= split_dim_size) << "num_split must be in "
                                                         "the interval (0, "
                                                      << split_dim_size << "]";
  CHECK(split_dim >= 0 && split_dim < num_dim) << "split_dim must be in "
                                                  "the interval [0, "
                                               << num_dim << ")";

  const int residual = split_dim_size % num_split;
  for (int i = 0; i < input_tensor.indices().dim_size(0); ++i) {
    const int dim = input_tensor.indices().matrix<int64>()(i, split_dim);
    int slice_index = GetSliceIndex(dim, split_size, residual);
    num_values[slice_index]++;
  }

  for (int i = 0; i < num_split; ++i) {
    // TODO(ataei): Pass an allocator to avoid allocating large memory buffer.
    output_indices.emplace_back(DT_INT64,
                                TensorShape({num_values[i], num_dim}));
    output_values.emplace_back(DataTypeToEnum<T>::v(),
                               TensorShape({num_values[i]}));
    output_shapes.emplace_back(input_tensor.shape());
    output_indices_t.emplace_back(output_indices[i].matrix<int64>());
    output_values_t.emplace_back(output_values[i].vec<T>());
    const int size = GetSliceShape(i, split_size, residual);
    output_shapes[i].set_dim(split_dim, size);
  }

  std::vector<int> values_inserted_in_slice(num_split, 0);
  for (int i = 0; i < input_tensor.indices().dim_size(0); ++i) {
    const int dim = input_indices_t(i, split_dim);
    const int slice_index = GetSliceIndex(dim, split_size, residual);
    const int slice_dim = values_inserted_in_slice[slice_index]++;
    output_values_t[slice_index](slice_dim) = input_values_t(i);
    for (int j = 0; j < num_dim; ++j) {
      const int64 original_dim = input_indices_t(i, j);
      output_indices_t[slice_index](slice_dim, j) =
          (j == split_dim)
              ? GetDimensionInSlice(original_dim, split_size, residual)
              : original_dim;
    }
  }

  std::vector<SparseTensor> output_tensors;
  output_tensors.reserve(num_split);
  for (int i = 0; i < num_split; ++i) {
    output_tensors.emplace_back(output_indices[i], output_values[i],
                                output_shapes[i]);
  }
  return output_tensors;
}

template <typename T>
SparseTensor SparseTensor::Slice(const SparseTensor& input_tensor,
                                 const gtl::ArraySlice<int64>& start,
                                 const gtl::ArraySlice<int64>& size) {
  TensorShape output_shape(input_tensor.shape());

  const int dims = input_tensor.dims();
  for (int dim = 0; dim < dims; dim++) {
    // Clamp the size so the slice never extends past the input shape.
    int64 dim_size = start[dim] + size[dim] < output_shape.dim_size(dim)
                         ? size[dim]
                         : output_shape.dim_size(dim) - start[dim];
    output_shape.set_dim(dim, dim_size);
  }

  auto input_indices_t = input_tensor.indices().matrix<int64>();
  auto input_values_t = input_tensor.values().vec<T>();

  // Count the number of indices that fall inside the [start, start + size)
  // window.  An entry is a hit only if its index is inside the window in
  // every dimension; otherwise it is skipped.
  int count = 0;
  for (int i = 0; i < input_tensor.indices().dim_size(0); i++) {
    bool hit = true;
    for (int dim = 0; dim < dims; dim++) {
      if (!(start[dim] <= input_indices_t(i, dim) &&
            input_indices_t(i, dim) < start[dim] + size[dim])) {
        hit = false;
        break;
      }
    }
    if (!hit) {
      continue;
    }
    count++;
  }

  Tensor output_values(DataTypeToEnum<T>::v(), TensorShape({count}));
  Tensor output_indices(DT_INT64, TensorShape({count, dims}));

  auto output_values_t = output_values.vec<T>();
  auto output_indices_t = output_indices.matrix<int64>();

  // Obtain the output indices that fall inside start and size.
  int index = 0;
  for (int i = 0; i < input_tensor.indices().dim_size(0) && index < count;
       i++) {
    // The logic here is similar to the above, except that instead of just
    // counting the hits we actually generate the output.
    bool hit = true;
    for (int dim = 0; dim < dims; dim++) {
      if (!(start[dim] <= input_indices_t(i, dim) &&
            input_indices_t(i, dim) < start[dim] + size[dim])) {
        hit = false;
        break;
      }
    }
    if (!hit) {
      continue;
    }
    output_values_t(index) = input_values_t(i);
    for (int dim = 0; dim < dims; dim++) {
      output_indices_t(index, dim) = input_indices_t(i, dim) - start[dim];
    }
    index++;
  }

  return SparseTensor(output_indices, output_values, output_shape);
}

}  // namespace sparse
}  // namespace tensorflow

#endif  // TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_