1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_UTIL_SPARSE_GROUP_ITERATOR_H_ 17 #define TENSORFLOW_UTIL_SPARSE_GROUP_ITERATOR_H_ 18 19 #include <vector> 20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 21 #include "tensorflow/core/framework/tensor.h" 22 #include "tensorflow/core/lib/core/status.h" 23 #include "tensorflow/core/platform/logging.h" 24 #include "tensorflow/core/platform/types.h" 25 26 namespace tensorflow { 27 namespace sparse { 28 29 class GroupIterable; // Predeclare GroupIterable for Group. 30 31 // This class is returned when dereferencing a GroupIterable iterator. 32 // It provides the methods group(), indices(), and values(), which 33 // provide access into the underlying SparseTensor. 34 class Group { 35 public: 36 Group(GroupIterable* iter, int64 loc, int64 next_loc) 37 : iter_(iter), loc_(loc), next_loc_(next_loc) {} 38 39 std::vector<int64> group() const; 40 TTypes<int64>::UnalignedConstMatrix indices() const; 41 template <typename T> 42 typename TTypes<T>::UnalignedVec values() const; 43 44 private: 45 GroupIterable* iter_; 46 int64 loc_; 47 int64 next_loc_; 48 }; 49 50 ///////////////// 51 // GroupIterable 52 ///////////////// 53 // 54 // Returned when calling sparse_tensor.group({dim0, dim1, ...}). 55 // 56 // Please note: the sparse_tensor should already be ordered according 57 // to {dim0, dim1, ...}. Otherwise this iteration will return invalid groups. 58 // 59 // Allows grouping and iteration of the SparseTensor according to the 60 // subset of dimensions provided to the group call. 61 // 62 // The actual grouping dimensions are stored in the 63 // internal vector group_dims_. Iterators inside the iterable provide 64 // the three methods: 65 // 66 // * group(): returns a vector with the current group dimension values. 67 // * indices(): a map of index, providing the indices in 68 // this group. 69 // * values(): a map of values, providing the values in 70 // this group. 71 // 72 // To iterate across GroupIterable, see examples in README.md. 73 // 74 75 // Forward declaration of SparseTensor 76 class GroupIterable { 77 public: 78 typedef gtl::ArraySlice<int64> VarDimArray; 79 80 GroupIterable(Tensor ix, Tensor vals, int dims, const VarDimArray& group_dims) 81 : ix_(ix), vals_(vals), dims_(dims), group_dims_(group_dims) {} 82 83 class IteratorStep; 84 85 IteratorStep begin() { return IteratorStep(this, 0); } 86 IteratorStep at(int64 loc) { 87 CHECK(loc >= 0 && loc <= ix_.dim_size(0)) 88 << "loc provided must lie between 0 and " << ix_.dim_size(0); 89 return IteratorStep(this, loc); 90 } 91 IteratorStep end() { return IteratorStep(this, ix_.dim_size(0)); } 92 93 template <typename TIX> 94 inline bool GroupMatches(const TIX& ix, int64 loc_a, int64 loc_b) const { 95 bool matches = true; 96 for (int d : group_dims_) { 97 if (ix(loc_a, d) != ix(loc_b, d)) { 98 matches = false; 99 } 100 } 101 return matches; 102 } 103 104 class IteratorStep { 105 public: 106 IteratorStep(GroupIterable* iter, int64 loc) 107 : iter_(iter), loc_(loc), next_loc_(loc_) { 108 UpdateEndOfGroup(); 109 } 110 111 void UpdateEndOfGroup(); 112 bool operator!=(const IteratorStep& rhs) const; 113 bool operator==(const IteratorStep& rhs) const; 114 IteratorStep& operator++(); // prefix ++ 115 IteratorStep operator++(int); // postfix ++ 116 Group operator*() const { return Group(iter_, loc_, next_loc_); } 117 int64 loc() const { return loc_; } 118 119 private: 120 GroupIterable* iter_; 121 int64 loc_; 122 int64 next_loc_; 123 }; 124 125 private: 126 friend class Group; 127 Tensor ix_; 128 Tensor vals_; 129 const int dims_; 130 const VarDimArray group_dims_; 131 }; 132 133 // Implementation of Group::values<T>() 134 template <typename T> 135 typename TTypes<T>::UnalignedVec Group::values() const { 136 return typename TTypes<T>::UnalignedVec(&(iter_->vals_.vec<T>()(loc_)), 137 next_loc_ - loc_); 138 } 139 140 } // namespace sparse 141 } // namespace tensorflow 142 143 #endif // TENSORFLOW_UTIL_SPARSE_GROUP_ITERATOR_H_ 144