1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // ============================================================================= 15 16 #ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_SPARSE_COLUMN_ITERABLE_H_ 17 #define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_SPARSE_COLUMN_ITERABLE_H_ 18 19 #include "tensorflow/core/framework/tensor.h" 20 #include "tensorflow/core/framework/tensor_types.h" 21 #include "tensorflow/core/platform/logging.h" 22 #include "tensorflow/core/platform/types.h" 23 24 namespace tensorflow { 25 namespace boosted_trees { 26 namespace utils { 27 28 // Enables row-wise iteration through examples on sparse feature columns. 29 class SparseColumnIterable { 30 public: 31 // Indicates a contiguous range for an example: [start, end). 32 struct ExampleRowRange { 33 int64 example_idx; 34 int64 start; 35 int64 end; 36 }; 37 38 // Helper class to iterate through examples and return the corresponding 39 // indices row range. Note that the row range can be empty in case a given 40 // example has no corresponding indices. 41 // An Iterator can be initialized from any example start offset, the 42 // corresponding range indicators will be initialized in log time. 43 class Iterator { 44 public: 45 Iterator(SparseColumnIterable* iter, int64 example_idx); 46 47 Iterator& operator++() { 48 ++example_idx_; 49 if (cur_ < end_ && iter_->ix()(cur_, 0) < example_idx_) { 50 cur_ = next_; 51 UpdateNext(); 52 } 53 return (*this); 54 } 55 56 Iterator operator++(int) { 57 Iterator tmp(*this); 58 ++(*this); 59 return tmp; 60 } 61 62 bool operator!=(const Iterator& other) const { 63 QCHECK_EQ(iter_, other.iter_); 64 return (example_idx_ != other.example_idx_); 65 } 66 67 bool operator==(const Iterator& other) const { 68 QCHECK_EQ(iter_, other.iter_); 69 return (example_idx_ == other.example_idx_); 70 } 71 72 const ExampleRowRange& operator*() { 73 range_.example_idx = example_idx_; 74 if (cur_ < end_ && iter_->ix()(cur_, 0) == example_idx_) { 75 range_.start = cur_; 76 range_.end = next_; 77 } else { 78 range_.start = 0; 79 range_.end = 0; 80 } 81 return range_; 82 } 83 84 private: 85 void UpdateNext() { 86 next_ = std::min(next_ + 1, end_); 87 while (next_ < end_ && iter_->ix()(cur_, 0) == iter_->ix()(next_, 0)) { 88 ++next_; 89 } 90 } 91 92 const SparseColumnIterable* iter_; 93 int64 example_idx_; 94 int64 cur_; 95 int64 next_; 96 const int64 end_; 97 ExampleRowRange range_; 98 }; 99 100 // Constructs an iterable given the desired examples slice and corresponding 101 // feature columns. 102 SparseColumnIterable(TTypes<int64>::ConstMatrix ix, int64 example_start, 103 int64 example_end) 104 : ix_(ix), example_start_(example_start), example_end_(example_end) { 105 QCHECK(example_start >= 0 && example_end >= 0); 106 } 107 108 Iterator begin() { return Iterator(this, example_start_); } 109 Iterator end() { return Iterator(this, example_end_); } 110 111 const TTypes<int64>::ConstMatrix& ix() const { return ix_; } 112 int64 example_start() const { return example_start_; } 113 int64 example_end() const { return example_end_; } 114 115 const TTypes<int64>::ConstMatrix& sparse_indices() const { return ix_; } 116 117 private: 118 // Sparse indices matrix. 119 TTypes<int64>::ConstMatrix ix_; 120 121 // Example slice spec. 122 const int64 example_start_; 123 const int64 example_end_; 124 }; 125 126 } // namespace utils 127 } // namespace boosted_trees 128 } // namespace tensorflow 129 130 #endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_SPARSE_COLUMN_ITERABLE_H_ 131