1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // ============================================================================= 15 16 #ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLES_ITERABLE_H_ 17 #define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLES_ITERABLE_H_ 18 19 #include <vector> 20 21 #include "tensorflow/contrib/boosted_trees/lib/utils/example.h" 22 #include "tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.h" 23 #include "tensorflow/core/framework/tensor.h" 24 #include "tensorflow/core/util/sparse/sparse_tensor.h" 25 26 namespace tensorflow { 27 namespace boosted_trees { 28 namespace utils { 29 30 // Enables row-wise iteration through examples from feature columns. 31 class ExamplesIterable { 32 public: 33 // Constructs an iterable given the desired examples slice and corresponding 34 // feature columns. 35 ExamplesIterable( 36 const std::vector<Tensor>& dense_float_feature_columns, 37 const std::vector<sparse::SparseTensor>& sparse_float_feature_columns, 38 const std::vector<sparse::SparseTensor>& sparse_int_feature_columns, 39 int64 example_start, int64 example_end); 40 41 // Helper class to iterate through examples. 42 class Iterator { 43 public: 44 Iterator(ExamplesIterable* iter, int64 example_idx); 45 46 Iterator& operator++() { 47 // Advance to next example. 48 ++example_idx_; 49 50 // Update sparse column iterables. 51 for (auto& it : sparse_float_column_iterators_) { 52 ++it; 53 } 54 for (auto& it : sparse_int_column_iterators_) { 55 ++it; 56 } 57 return (*this); 58 } 59 60 Iterator operator++(int) { 61 Iterator tmp(*this); 62 ++(*this); 63 return tmp; 64 } 65 66 bool operator!=(const Iterator& other) const { 67 QCHECK_EQ(iter_, other.iter_); 68 return (example_idx_ != other.example_idx_); 69 } 70 71 bool operator==(const Iterator& other) const { 72 QCHECK_EQ(iter_, other.iter_); 73 return (example_idx_ == other.example_idx_); 74 } 75 76 const Example& operator*() { 77 // Set example index based on iterator. 78 example_.example_idx = example_idx_; 79 80 // Get dense float values per column. 81 auto& dense_float_features = example_.dense_float_features; 82 for (size_t dense_float_idx = 0; 83 dense_float_idx < dense_float_features.size(); ++dense_float_idx) { 84 dense_float_features[dense_float_idx] = 85 iter_->dense_float_column_values_[dense_float_idx](example_idx_, 0); 86 } 87 88 // Get sparse float values per column. 89 auto& sparse_float_features = example_.sparse_float_features; 90 // Iterate through each sparse float feature column. 91 for (size_t sparse_float_idx = 0; 92 sparse_float_idx < iter_->sparse_float_column_iterables_.size(); 93 ++sparse_float_idx) { 94 // Clear info from a previous instance. 95 sparse_float_features[sparse_float_idx].Clear(); 96 97 // Get range for values tensor. 98 const auto& row_range = 99 (*sparse_float_column_iterators_[sparse_float_idx]); 100 DCHECK_EQ(example_idx_, row_range.example_idx); 101 102 // If the example has this feature column. 103 if (row_range.start < row_range.end) { 104 const int32 dimension = 105 iter_->sparse_float_dimensions_[sparse_float_idx]; 106 sparse_float_features[sparse_float_idx].SetDimension(dimension); 107 if (dimension <= 1) { 108 // single dimensional sparse feature column. 109 DCHECK_EQ(1, row_range.end - row_range.start); 110 sparse_float_features[sparse_float_idx].Add( 111 0, iter_->sparse_float_column_values_[sparse_float_idx]( 112 row_range.start)); 113 } else { 114 // Retrieve original indices tensor. 115 const TTypes<int64>::ConstMatrix& indices = 116 iter_->sparse_float_column_iterables_[sparse_float_idx] 117 .sparse_indices(); 118 119 sparse_float_features[sparse_float_idx].Reserve(row_range.end - 120 row_range.start); 121 122 // For each value. 123 for (int64 row_idx = row_range.start; row_idx < row_range.end; 124 ++row_idx) { 125 // Get the feature id for the feature column and the value. 126 const int32 feature_id = indices(row_idx, 1); 127 DCHECK_EQ(example_idx_, indices(row_idx, 0)); 128 129 // Save the value to our sparse matrix. 130 sparse_float_features[sparse_float_idx].Add( 131 feature_id, 132 iter_->sparse_float_column_values_[sparse_float_idx]( 133 row_idx)); 134 } 135 } 136 } 137 } 138 139 // Get sparse int values per column. 140 auto& sparse_int_features = example_.sparse_int_features; 141 for (size_t sparse_int_idx = 0; 142 sparse_int_idx < sparse_int_features.size(); ++sparse_int_idx) { 143 const auto& row_range = (*sparse_int_column_iterators_[sparse_int_idx]); 144 DCHECK_EQ(example_idx_, row_range.example_idx); 145 sparse_int_features[sparse_int_idx].clear(); 146 if (row_range.start < row_range.end) { 147 sparse_int_features[sparse_int_idx].reserve(row_range.end - 148 row_range.start); 149 for (int64 row_idx = row_range.start; row_idx < row_range.end; 150 ++row_idx) { 151 sparse_int_features[sparse_int_idx].insert( 152 iter_->sparse_int_column_values_[sparse_int_idx](row_idx)); 153 } 154 } 155 } 156 157 return example_; 158 } 159 160 private: 161 // Examples iterable (not owned). 162 const ExamplesIterable* iter_; 163 164 // Example index. 165 int64 example_idx_; 166 167 // Sparse float column iterators. 168 std::vector<SparseColumnIterable::Iterator> sparse_float_column_iterators_; 169 170 // Sparse int column iterators. 171 std::vector<SparseColumnIterable::Iterator> sparse_int_column_iterators_; 172 173 // Example placeholder. 174 Example example_; 175 }; 176 177 Iterator begin() { return Iterator(this, example_start_); } 178 Iterator end() { return Iterator(this, example_end_); } 179 180 private: 181 // Example slice spec. 182 const int64 example_start_; 183 const int64 example_end_; 184 185 // Dense float column values. 186 std::vector<TTypes<float>::ConstMatrix> dense_float_column_values_; 187 188 // Sparse float column iterables. 189 std::vector<SparseColumnIterable> sparse_float_column_iterables_; 190 191 // Sparse float column values. 192 std::vector<TTypes<float>::ConstVec> sparse_float_column_values_; 193 194 // Dimensions for sparse float feature columns. 195 std::vector<int32> sparse_float_dimensions_; 196 197 // Sparse int column iterables. 198 std::vector<SparseColumnIterable> sparse_int_column_iterables_; 199 200 // Sparse int column values. 201 std::vector<TTypes<int64>::ConstVec> sparse_int_column_values_; 202 }; 203 204 } // namespace utils 205 } // namespace boosted_trees 206 } // namespace tensorflow 207 208 #endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLES_ITERABLE_H_ 209