1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // ============================================================================= 15 16 #include "tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.h" 17 18 namespace tensorflow { 19 namespace boosted_trees { 20 namespace utils { 21 22 using ExampleRowRange = SparseColumnIterable::ExampleRowRange; 23 using Iterator = SparseColumnIterable::Iterator; 24 25 namespace { 26 27 // Iterator over indices matrix rows. 28 class IndicesRowIterator 29 : public std::iterator<std::random_access_iterator_tag, const int64> { 30 public: 31 IndicesRowIterator() : iter_(nullptr), row_idx_(-1) {} 32 IndicesRowIterator(SparseColumnIterable* iter, int row_idx) 33 : iter_(iter), row_idx_(row_idx) {} 34 IndicesRowIterator(const IndicesRowIterator& other) 35 : iter_(other.iter_), row_idx_(other.row_idx_) {} 36 37 IndicesRowIterator& operator=(const IndicesRowIterator& other) { 38 iter_ = other.iter_; 39 row_idx_ = other.row_idx_; 40 return (*this); 41 } 42 43 IndicesRowIterator& operator++() { 44 ++row_idx_; 45 return (*this); 46 } 47 48 IndicesRowIterator operator++(int) { 49 IndicesRowIterator tmp(*this); 50 ++row_idx_; 51 return tmp; 52 } 53 54 reference operator*() const { return iter_->ix()(row_idx_, 0); } 55 56 pointer operator->() { return &iter_->ix()(row_idx_, 0); } 57 58 IndicesRowIterator& operator--() { 59 --row_idx_; 60 return (*this); 61 } 62 63 IndicesRowIterator operator--(int) { 64 IndicesRowIterator tmp(*this); 65 --row_idx_; 66 return tmp; 67 } 68 69 IndicesRowIterator& operator+=(const difference_type& step) { 70 row_idx_ += step; 71 return (*this); 72 } 73 IndicesRowIterator& operator-=(const difference_type& step) { 74 row_idx_ -= step; 75 return (*this); 76 } 77 78 IndicesRowIterator operator+(const difference_type& step) const { 79 IndicesRowIterator tmp(*this); 80 tmp += step; 81 return tmp; 82 } 83 84 IndicesRowIterator operator-(const difference_type& step) const { 85 IndicesRowIterator tmp(*this); 86 tmp -= step; 87 return tmp; 88 } 89 90 difference_type operator-(const IndicesRowIterator& other) const { 91 return row_idx_ - other.row_idx_; 92 } 93 94 bool operator!=(const IndicesRowIterator& other) const { 95 QCHECK_EQ(iter_, other.iter_); 96 return (row_idx_ != other.row_idx_); 97 } 98 99 bool operator<(const IndicesRowIterator& other) const { 100 return (row_idx_ < other.row_idx_); 101 } 102 103 bool operator==(const IndicesRowIterator& other) const { 104 QCHECK_EQ(iter_, other.iter_); 105 return (row_idx_ == other.row_idx_); 106 } 107 108 Eigen::Index row_idx() const { return row_idx_; } 109 110 private: 111 SparseColumnIterable* iter_; 112 Eigen::Index row_idx_; 113 }; 114 } // namespace 115 116 Iterator::Iterator(SparseColumnIterable* iter, int64 example_idx) 117 : iter_(iter), example_idx_(example_idx), end_(iter->ix_.dimension(0)) { 118 cur_ = next_ = std::lower_bound(IndicesRowIterator(iter, 0), 119 IndicesRowIterator(iter, end_), example_idx_) 120 .row_idx(); 121 UpdateNext(); 122 } 123 124 } // namespace utils 125 } // namespace boosted_trees 126 } // namespace tensorflow 127