1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_ 17 #define TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_ 18 19 #include <limits> 20 #include <numeric> 21 #include <vector> 22 23 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 24 #include "tensorflow/core/framework/tensor.h" 25 #include "tensorflow/core/framework/tensor_types.h" 26 #include "tensorflow/core/framework/types.h" 27 #include "tensorflow/core/framework/types.pb.h" 28 #include "tensorflow/core/kernels/bounds_check.h" 29 #include "tensorflow/core/lib/core/status.h" 30 #include "tensorflow/core/lib/strings/str_util.h" 31 #include "tensorflow/core/platform/logging.h" 32 #include "tensorflow/core/platform/types.h" 33 #include "tensorflow/core/util/sparse/dim_comparator.h" 34 #include "tensorflow/core/util/sparse/group_iterator.h" 35 36 namespace tensorflow { 37 namespace sparse { 38 39 class SparseTensor { 40 public: 41 typedef typename gtl::ArraySlice<int64> VarDimArray; 42 typedef typename gtl::InlinedVector<int64, 8> ShapeArray; 43 44 SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape) 45 : SparseTensor(ix, vals, TensorShapeToVector(shape), 46 UndefinedOrder(TensorShapeToVector(shape))) {} 47 48 SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape) 49 : SparseTensor(ix, vals, shape, UndefinedOrder(shape)) {} 50 51 SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape, 52 const VarDimArray order) 53 : SparseTensor(ix, vals, TensorShapeToVector(shape), order) {} 54 55 SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape, 56 const VarDimArray order) 57 : ix_(ix), 58 vals_(vals), 59 shape_(shape.begin(), shape.end()), 60 order_(order.begin(), order.end()), 61 dims_(GetDimsFromIx(ix)) { 62 CHECK_EQ(ix.dtype(), DT_INT64) 63 << "indices must be type int64 but got: " << ix.dtype(); 64 CHECK(TensorShapeUtils::IsVector(vals.shape())) 65 << "vals must be a vec, but got: " << vals.shape().DebugString(); 66 CHECK_EQ(ix.shape().dim_size(0), vals.shape().dim_size(0)) 67 << "indices and values rows (indexing dimension) must match."; 68 CHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank."; 69 CHECK_EQ(shape.size(), dims_) << "Shape rank must be SparseTensor rank."; 70 } 71 72 SparseTensor(const SparseTensor& other) 73 : SparseTensor(other.ix_, other.vals_, other.shape_, other.order_) {} 74 75 SparseTensor(SparseTensor&& other) 76 : SparseTensor(std::move(other.ix_), std::move(other.vals_), 77 std::move(other.shape_), std::move(other.order_)) {} 78 79 SparseTensor& operator=(const SparseTensor& other) { 80 ix_ = other.ix_; 81 vals_ = other.vals_; 82 shape_ = other.shape_; 83 order_ = other.order_; 84 return *this; 85 } 86 87 std::size_t num_entries() const { return ix_.dim_size(0); } 88 89 int dims() const { return shape_.size(); } 90 91 const Tensor& indices() const { return ix_; } 92 93 const Tensor& values() const { return vals_; } 94 95 DataType dtype() const { return vals_.dtype(); } 96 97 Status IndicesValid() const { 98 const auto ix_t = ix_.matrix<int64>(); 99 for (int64 ord : order_) { 100 if (ord < 0) { 101 return errors::FailedPrecondition( 102 "Order was not provided. Provide an order at " 103 "construction time or run ReorderInPlace"); 104 } 105 } 106 107 for (std::size_t n = 0; n < num_entries(); ++n) { 108 TF_RETURN_IF_ERROR(IndexValid(ix_t, n)); 109 } 110 111 return Status::OK(); 112 } 113 114 VarDimArray shape() const { return shape_; } 115 116 VarDimArray order() const { return order_; } 117 118 // Resorts the indices and values according to the dimensions in order. 119 template <typename T> 120 void Reorder(const VarDimArray& order); 121 122 // Returns a group iterable that can be used for clumping indices 123 // and values according to the group indices of interest. 124 // 125 // Precondition: order()[0..group_ix.size()] == group_ix. 126 // 127 // See the README.md in this directory for more usage information. 128 GroupIterable group(const VarDimArray& group_ix) const { 129 CHECK_LE(group_ix.size(), dims_); 130 for (std::size_t di = 0; di < group_ix.size(); ++di) { 131 CHECK_GE(group_ix[di], 0) << "Group dimension out of range"; 132 CHECK_LT(group_ix[di], dims_) << "Group dimension out of range"; 133 CHECK_EQ(group_ix[di], order_[di]) 134 << "Group dimension does not match sorted order"; 135 } 136 return GroupIterable(ix_, vals_, dims_, group_ix); 137 } 138 139 // Stores the sparse indices into the dense tensor out. 140 // Preconditions: 141 // out->shape().dims() == shape().dims() 142 // out->shape().dim_size(d) >= shape(d) for all d 143 // 144 // Returns true on success. False on failure (mismatched dimensions 145 // or out-of-bounds indices). 146 // 147 // If initialize==True, ToDense first overwrites all coefficients in out to 0. 148 // 149 template <typename T> 150 bool ToDense(Tensor* out, bool initialize = true); 151 152 // Concat() will concatenate all the tensors according to their first order 153 // dimension. All tensors must have identical shape except for 154 // the first order dimension. All tensors orders' first dimension 155 // must match. 156 // 157 // If all of the tensors have identical ordering, then the output 158 // will have this ordering. Otherwise the output is set as not 159 // having any order and a Reorder<T>() should be called on it before 160 // performing any subsequent operations. 161 template <typename T> 162 static SparseTensor Concat(const gtl::ArraySlice<SparseTensor>& tensors); 163 164 // Split() will split the input SparseTensor into a list of num_split 165 // SparseTensor given a splitting dimension. If the input dimension range 166 // isn't an integer multiple of split_dim, we add one extra dimension for 167 // each slice. 168 template <typename T> 169 static std::vector<SparseTensor> Split(const SparseTensor& tensor, 170 const int split_dim, 171 const int num_split); 172 173 // Slice() will slice the input SparseTensor into a SparseTensor based on 174 // specified start and size. Both start and size are 1-D array with each 175 // element of the array representing one dimension. The start is the start 176 // index at each dimension and the size is the size at each dimension. 177 template <typename T> 178 static SparseTensor Slice(const SparseTensor& tensor, 179 const gtl::ArraySlice<int64>& start, 180 const gtl::ArraySlice<int64>& size); 181 182 // Picks out the dimensions according to `dim_indices`. 183 std::vector<int64> PickDims(gtl::ArraySlice<int64> dim_indices) const { 184 std::vector<int64> res(dim_indices.size()); 185 for (size_t i = 0; i < dim_indices.size(); ++i) { 186 res[i] = shape_[dim_indices[i]]; 187 } 188 return res; 189 } 190 191 private: 192 static int GetDimsFromIx(const Tensor& ix) { 193 CHECK(TensorShapeUtils::IsMatrix(ix.shape())) 194 << "indices must be a matrix, but got: " << ix.shape().DebugString(); 195 return ix.dim_size(1); 196 } 197 198 static inline ShapeArray UndefinedOrder(const VarDimArray shape) { 199 return ShapeArray(shape.size(), -1); 200 } 201 202 static inline ShapeArray TensorShapeToVector(const TensorShape& shape) { 203 ShapeArray vec(shape.dims()); 204 for (int i = 0; i < shape.dims(); ++i) vec[i] = shape.dim_size(i); 205 return vec; 206 } 207 208 // Helper for IndicesValid() 209 inline Status IndexValid(const TTypes<int64>::ConstMatrix& ix_t, 210 int n) const { 211 bool valid = true; 212 bool different = false; 213 bool increasing = true; 214 if (n == 0) { 215 for (int di = 0; di < dims_; ++di) { 216 if (ix_t(n, di) < 0 || ix_t(n, di) >= shape_[di]) valid = false; 217 } 218 different = true; 219 } else { 220 for (int di = 0; di < dims_; ++di) { 221 if (ix_t(n, di) < 0 || ix_t(n, di) >= shape_[di]) valid = false; 222 int64 diff = ix_t(n, order_[di]) - ix_t(n - 1, order_[di]); 223 if (diff > 0) different = true; 224 if (!different && diff < 0) increasing = false; 225 } 226 } 227 if (TF_PREDICT_FALSE(!valid || !increasing || !different)) { 228 string index = strings::StrCat("indices[", n, "] = ["); 229 for (int di = 0; di < dims_; ++di) { 230 strings::StrAppend(&index, ix_t(n, di), di < dims_ - 1 ? "," : "]"); 231 } 232 if (!valid) { 233 return errors::InvalidArgument(index, 234 " is out of bounds: need 0 <= index < [", 235 str_util::Join(shape_, ","), "]"); 236 } 237 if (!increasing) { 238 return errors::InvalidArgument(index, " is out of order"); 239 } 240 if (!different) { 241 return errors::InvalidArgument(index, " is repeated"); 242 } 243 } 244 return Status::OK(); 245 } 246 247 // Helper for ToDense<T>() 248 template <typename T> 249 bool ValidateAndInitializeToDense(Tensor* out, bool initialize); 250 251 // Helper for Split() that returns the slice index. 252 static inline int GetSliceIndex(const int dim, const int split_size, 253 const int residual) { 254 CHECK_GT(split_size, 0); 255 CHECK_GE(dim, 0); 256 if (residual == 0) return dim / split_size; 257 const int offset = residual * (split_size + 1); 258 if (dim < offset) { 259 return dim / (split_size + 1); 260 } else { 261 return residual + ((dim - offset) / split_size); 262 } 263 } 264 265 // Helper for Split() that returns the dimension in the slice. 266 static inline int GetDimensionInSlice(const int dim, const int split_size, 267 const int residual) { 268 CHECK_GT(split_size, 0); 269 CHECK_GE(dim, 0); 270 if (residual == 0) return dim % split_size; 271 const int offset = residual * (split_size + 1); 272 if (dim < offset) { 273 return dim % (split_size + 1); 274 } else { 275 return (dim - offset) % split_size; 276 } 277 } 278 279 // Helper for Split() that returns the shape given a slice index. 280 static inline int GetSliceShape(const int slice_index, const int split_size, 281 const int residual) { 282 CHECK_GT(split_size, 0); 283 CHECK_GE(slice_index, 0); 284 if (residual == 0) return split_size; 285 if (slice_index < residual) { 286 return split_size + 1; 287 } else { 288 return split_size; 289 } 290 } 291 292 Tensor ix_; 293 Tensor vals_; 294 ShapeArray shape_; 295 ShapeArray order_; 296 const int dims_; 297 }; 298 299 // This operation updates the indices and values Tensor rows, so it is 300 // an in-place algorithm. It requires O(N log N) time and O(N) 301 // temporary space. 302 template <typename T> 303 void SparseTensor::Reorder(const VarDimArray& order) { 304 CHECK_EQ(DataTypeToEnum<T>::v(), dtype()) 305 << "Reorder requested with the wrong datatype"; 306 CHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank"; 307 auto ix_t = ix_.matrix<int64>(); 308 auto vals_t = vals_.vec<T>(); 309 310 std::vector<int64> reorder(num_entries()); 311 std::iota(reorder.begin(), reorder.end(), 0); 312 313 // Sort to get order of indices 314 switch (order.size()) { 315 #define CASE_SORT(ORDER_SIZE) \ 316 case ORDER_SIZE: { \ 317 FixedDimComparator<ORDER_SIZE> sorter(ix_t, order, shape()); \ 318 std::sort(reorder.begin(), reorder.end(), sorter); \ 319 break; \ 320 } 321 CASE_SORT(0); 322 CASE_SORT(1); 323 CASE_SORT(2); 324 CASE_SORT(3); 325 CASE_SORT(4); 326 CASE_SORT(5); 327 #undef CASE_SORT 328 default: { 329 DimComparator sorter(ix_t, order, shape()); 330 std::sort(reorder.begin(), reorder.end(), sorter); 331 } 332 } 333 334 // We have a forward reordering, but what we'll need is a 335 // permutation (the inverse). This can be calculated with O(1) 336 // additional 337 // and O(n) time (INVPERM) but we just do the simple thing here. 338 std::vector<size_t> permutation(reorder.size()); 339 for (std::size_t n = 0; n < reorder.size(); ++n) { 340 permutation[reorder[n]] = n; 341 } 342 343 // Update indices & values by converting the permutations to 344 // a product of transpositions. Iterate over the cycles in the 345 // permutation, and convert each of those into a product of 346 // transpositions (swaps): 347 // https://en.wikipedia.org/wiki/Cyclic_permutation 348 // This is N swaps, 2*N comparisons. 349 for (std::size_t n = 0; n + 1 < permutation.size(); ++n) { 350 while (n != permutation[n]) { 351 std::size_t r = permutation[n]; 352 std::swap_ranges(&(ix_t(n, 0)), &(ix_t(n + 1, 0)), &(ix_t(r, 0))); 353 std::swap(vals_t(n), vals_t(r)); 354 std::swap(permutation[n], permutation[r]); 355 } 356 } 357 358 order_ = ShapeArray(order.begin(), order.end()); 359 } 360 361 template <typename T> 362 bool SparseTensor::ValidateAndInitializeToDense(Tensor* out, bool initialize) { 363 CHECK_EQ(DataTypeToEnum<T>::v(), dtype()) 364 << "ToDense requested with the wrong datatype"; 365 366 CHECK_EQ(out->shape().dims(), dims_) 367 << "Incompatible dimensions between SparseTensor and output"; 368 369 CHECK_EQ(out->dtype(), DataTypeToEnum<T>::v()) 370 << "Output must be type: " << DataTypeToEnum<T>::v() 371 << " but got: " << out->dtype(); 372 373 // Make sure the dense output is the same rank and has room 374 // to hold the SparseTensor. 375 const auto& out_shape = out->shape(); 376 if (shape_.size() != out_shape.dims()) return false; 377 for (int d = 0; d < shape_.size(); ++d) { 378 if (shape_[d] > out_shape.dim_size(d)) return false; 379 } 380 381 if (initialize) { 382 auto out_t = out->flat<T>(); 383 out_t.setConstant(T()); 384 } 385 386 return true; 387 } 388 389 template <typename T> 390 bool SparseTensor::ToDense(Tensor* out, bool initialize) { 391 if (!ValidateAndInitializeToDense<T>(out, initialize)) return false; 392 393 auto out_t = out->flat<T>(); 394 auto ix_t = ix_.matrix<int64>(); 395 auto vals_t = vals_.vec<T>(); 396 397 std::vector<int64> strides(dims_); 398 const auto& out_shape = out->shape(); 399 if (dims_ > 0) { 400 strides[dims_ - 1] = 1; 401 } 402 for (int d = dims_ - 2; d >= 0; --d) { 403 strides[d] = strides[d + 1] * out_shape.dim_size(d + 1); 404 } 405 406 for (int n = 0; n < vals_t.dimension(0); ++n) { 407 bool invalid_dims = false; 408 int64 ix = 0; 409 for (int d = 0; d < dims_; ++d) { 410 const int64 ix_n_d = internal::SubtleMustCopy(ix_t(n, d)); 411 if (!FastBoundsCheck(ix_n_d, out_shape.dim_size(d))) { 412 invalid_dims = true; 413 } 414 ix += strides[d] * ix_n_d; 415 } 416 if (invalid_dims) return false; 417 out_t(ix) = vals_t(n); 418 } 419 return true; 420 } 421 422 template <typename T> 423 SparseTensor SparseTensor::Concat( 424 const gtl::ArraySlice<SparseTensor>& tensors) { 425 CHECK_GE(tensors.size(), size_t{1}) << "Cannot concat 0 SparseTensors"; 426 const int dims = tensors[0].dims_; 427 CHECK_GE(dims, 1) << "Cannot concat 0-dimensional SparseTensors"; 428 auto order_0 = tensors[0].order(); 429 const int primary_dim = order_0[0]; 430 ShapeArray final_order(order_0.begin(), order_0.end()); 431 ShapeArray final_shape(tensors[0].shape().begin(), tensors[0].shape().end()); 432 final_shape[primary_dim] = 0; // We'll build this up as we go along. 433 int num_entries = 0; 434 435 bool fully_ordered = true; 436 for (const SparseTensor& st : tensors) { 437 CHECK_EQ(st.dims_, dims) << "All SparseTensors must have the same rank."; 438 CHECK_EQ(DataTypeToEnum<T>::v(), st.dtype()) 439 << "Concat requested with the wrong data type"; 440 CHECK_GE(st.order()[0], 0) << "SparseTensor must be ordered"; 441 CHECK_EQ(st.order()[0], primary_dim) 442 << "All SparseTensors' order[0] must match. This is the concat dim."; 443 if (st.order() != final_order) fully_ordered = false; 444 const VarDimArray& st_shape = st.shape(); 445 for (int d = 0; d < dims - 1; ++d) { 446 const int cdim = (d < primary_dim) ? d : d + 1; 447 CHECK_EQ(final_shape[cdim], st_shape[cdim]) 448 << "All SparseTensors' shapes must match except on the concat dim. " 449 << "Concat dim: " << primary_dim 450 << ", mismatched shape at dim: " << cdim 451 << ". Expecting shape like: [" << str_util::Join(final_shape, ",") 452 << "] but saw shape: [" << str_util::Join(st_shape, ",") << "]"; 453 } 454 455 // Update dimension of final shape 456 final_shape[primary_dim] = 457 (final_shape[primary_dim] + st_shape[primary_dim]); 458 459 num_entries += st.num_entries(); // Update number of entries 460 } 461 462 // If nonconsistent ordering among inputs, set final order to -1s. 463 if (!fully_ordered) { 464 final_order = UndefinedOrder(final_shape); 465 } 466 467 Tensor output_ix(DT_INT64, TensorShape({num_entries, dims})); 468 Tensor output_vals(DataTypeToEnum<T>::v(), TensorShape({num_entries})); 469 470 TTypes<int64>::Matrix ix_t = output_ix.matrix<int64>(); 471 typename TTypes<T>::Vec vals_t = output_vals.vec<T>(); 472 473 Eigen::DenseIndex offset = 0; 474 int64 shape_offset = 0; 475 for (const SparseTensor& st : tensors) { 476 const int st_num_entries = st.num_entries(); 477 478 // Fill in indices & values. 479 std::copy_n(&st.vals_.vec<T>()(0), st_num_entries, &vals_t(offset)); 480 481 const auto* st_ix = &st.ix_.matrix<int64>()(0, 0); 482 auto* ix_out = &ix_t(offset, 0); 483 for (std::size_t i = 0; i < st_num_entries * dims; ++i) { 484 *ix_out++ = *st_ix++ + ((i % dims == primary_dim) ? shape_offset : 0); 485 } 486 487 offset += st_num_entries; 488 shape_offset += st.shape()[primary_dim]; 489 } 490 491 return SparseTensor(output_ix, output_vals, final_shape, final_order); 492 } 493 494 template <typename T> 495 std::vector<SparseTensor> SparseTensor::Split(const SparseTensor& input_tensor, 496 const int split_dim, 497 const int num_split) { 498 std::vector<Tensor> output_indices; 499 std::vector<Tensor> output_values; 500 std::vector<TensorShape> output_shapes; 501 output_indices.reserve(num_split); 502 output_values.reserve(num_split); 503 output_shapes.reserve(num_split); 504 505 std::vector<typename TTypes<int64>::Matrix> output_indices_t; 506 std::vector<typename TTypes<T>::Vec> output_values_t; 507 output_indices_t.reserve(num_split); 508 output_values_t.reserve(num_split); 509 auto input_values_t = input_tensor.values().vec<T>(); 510 auto input_indices_t = input_tensor.indices().matrix<int64>(); 511 512 std::vector<int> num_values(num_split, 0); 513 const int num_dim = input_tensor.shape().size(); 514 const int split_dim_size = input_tensor.shape()[split_dim]; 515 const int split_size = split_dim_size / num_split; 516 517 CHECK(num_split > 0 && num_split <= split_dim_size) << "num_split must be in " 518 "the interval (0, " 519 << split_dim_size << "]"; 520 CHECK(split_dim >= 0 && split_dim < num_dim) << "num_dim must be in " 521 "the interval [0, " 522 << num_dim << ")"; 523 524 const int residual = split_dim_size % num_split; 525 for (int i = 0; i < input_tensor.indices().dim_size(0); ++i) { 526 const int dim = input_tensor.indices().matrix<int64>()(i, split_dim); 527 int slice_index = GetSliceIndex(dim, split_size, residual); 528 num_values[slice_index]++; 529 } 530 531 for (int i = 0; i < num_split; ++i) { 532 // TODO(ataei): Pass an allocator to avoid allocating large memory buffer. 533 output_indices.emplace_back(DT_INT64, 534 TensorShape({num_values[i], num_dim})); 535 output_values.emplace_back(DataTypeToEnum<T>::v(), 536 TensorShape({num_values[i]})); 537 output_shapes.emplace_back(input_tensor.shape()); 538 output_indices_t.emplace_back(output_indices[i].matrix<int64>()); 539 output_values_t.emplace_back(output_values[i].vec<T>()); 540 const int size = GetSliceShape(i, split_size, residual); 541 output_shapes[i].set_dim(split_dim, size); 542 } 543 544 std::vector<int> values_inserted_in_slice(num_split, 0); 545 for (int i = 0; i < input_tensor.indices().dim_size(0); ++i) { 546 const int dim = input_indices_t(i, split_dim); 547 const int slice_index = GetSliceIndex(dim, split_size, residual); 548 const int slice_dim = values_inserted_in_slice[slice_index]++; 549 output_values_t[slice_index](slice_dim) = input_values_t(i); 550 for (int j = 0; j < num_dim; ++j) { 551 const int64 original_dim = input_indices_t(i, j); 552 output_indices_t[slice_index](slice_dim, j) = 553 (j == split_dim) 554 ? GetDimensionInSlice(original_dim, split_size, residual) 555 : original_dim; 556 } 557 } 558 559 std::vector<SparseTensor> output_tensors; 560 output_tensors.reserve(num_split); 561 for (int i = 0; i < num_split; ++i) { 562 output_tensors.emplace_back(output_indices[i], output_values[i], 563 output_shapes[i]); 564 } 565 return output_tensors; 566 } 567 568 template <typename T> 569 SparseTensor SparseTensor::Slice(const SparseTensor& input_tensor, 570 const gtl::ArraySlice<int64>& start, 571 const gtl::ArraySlice<int64>& size) { 572 TensorShape output_shape(input_tensor.shape()); 573 574 const int dims = input_tensor.dims(); 575 for (int dim = 0; dim < dims; dim++) { 576 int64 dim_size = start[dim] + size[dim] < output_shape.dim_size(dim) 577 ? size[dim] 578 : output_shape.dim_size(dim) - start[dim]; 579 output_shape.set_dim(dim, dim_size); 580 } 581 582 auto input_indices_t = input_tensor.indices().matrix<int64>(); 583 auto input_values_t = input_tensor.values().vec<T>(); 584 585 // Find the number of indices that fall inside start and size. 586 int count = 0; 587 for (int i = 0; i < input_tensor.indices().dim_size(0); i++) { 588 // The following will check to see if an input is within the 589 // range specified by start and size. 590 // The for loop below iterates through all dimensions. In case 591 // the index falls outside of the start and size at any dimension, 592 // it will be considered as a "no hit" (hit = false). In this 593 // case, it will not be counted as the index that fall inside 594 // the range specified by start and size. 595 bool hit = true; 596 for (int dim = 0; dim < dims; dim++) { 597 if (!(start[dim] <= input_indices_t(i, dim) && 598 input_indices_t(i, dim) < start[dim] + size[dim])) { 599 hit = false; 600 break; 601 } 602 } 603 if (!hit) { 604 continue; 605 } 606 count++; 607 } 608 609 Tensor output_values(DataTypeToEnum<T>::v(), TensorShape({count})); 610 Tensor output_indices(DT_INT64, TensorShape({count, dims})); 611 612 auto output_values_t = output_values.vec<T>(); 613 auto output_indices_t = output_indices.matrix<int64>(); 614 615 // Obtain the output indices that fall inside start and size. 616 int index = 0; 617 for (int i = 0; i < input_tensor.indices().dim_size(0) && index < count; 618 i++) { 619 // The logic here is similar as the above except that the above 620 // only count the number of indices while here we actually generate 621 // the output. 622 bool hit = true; 623 for (int dim = 0; dim < dims; dim++) { 624 if (!(start[dim] <= input_indices_t(i, dim) && 625 input_indices_t(i, dim) < start[dim] + size[dim])) { 626 hit = false; 627 break; 628 } 629 } 630 if (!hit) { 631 continue; 632 } 633 output_values_t(index) = input_values_t(i); 634 for (int dim = 0; dim < dims; dim++) { 635 output_indices_t(index, dim) = input_indices_t(i, dim) - start[dim]; 636 } 637 index++; 638 } 639 640 return SparseTensor(output_indices, output_values, output_shape); 641 } 642 643 } // namespace sparse 644 } // namespace tensorflow 645 646 #endif // TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_ 647