Home | History | Annotate | Download | only in utils
      1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 // =============================================================================
     15 
     16 #ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLES_ITERABLE_H_
     17 #define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLES_ITERABLE_H_
     18 
     19 #include <vector>
     20 
     21 #include "tensorflow/contrib/boosted_trees/lib/utils/example.h"
     22 #include "tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.h"
     23 #include "tensorflow/core/framework/tensor.h"
     24 #include "tensorflow/core/util/sparse/sparse_tensor.h"
     25 
     26 namespace tensorflow {
     27 namespace boosted_trees {
     28 namespace utils {
     29 
     30 // Enables row-wise iteration through examples from feature columns.
     31 class ExamplesIterable {
     32  public:
     33   // Constructs an iterable given the desired examples slice and corresponding
     34   // feature columns.
     35   ExamplesIterable(
     36       const std::vector<Tensor>& dense_float_feature_columns,
     37       const std::vector<sparse::SparseTensor>& sparse_float_feature_columns,
     38       const std::vector<sparse::SparseTensor>& sparse_int_feature_columns,
     39       int64 example_start, int64 example_end);
     40 
     41   // Helper class to iterate through examples.
     42   class Iterator {
     43    public:
     44     Iterator(ExamplesIterable* iter, int64 example_idx);
     45 
     46     Iterator& operator++() {
     47       // Advance to next example.
     48       ++example_idx_;
     49 
     50       // Update sparse column iterables.
     51       for (auto& it : sparse_float_column_iterators_) {
     52         ++it;
     53       }
     54       for (auto& it : sparse_int_column_iterators_) {
     55         ++it;
     56       }
     57       return (*this);
     58     }
     59 
     60     Iterator operator++(int) {
     61       Iterator tmp(*this);
     62       ++(*this);
     63       return tmp;
     64     }
     65 
     66     bool operator!=(const Iterator& other) const {
     67       QCHECK_EQ(iter_, other.iter_);
     68       return (example_idx_ != other.example_idx_);
     69     }
     70 
     71     bool operator==(const Iterator& other) const {
     72       QCHECK_EQ(iter_, other.iter_);
     73       return (example_idx_ == other.example_idx_);
     74     }
     75 
     76     const Example& operator*() {
     77       // Set example index based on iterator.
     78       example_.example_idx = example_idx_;
     79 
     80       // Get dense float values per column.
     81       auto& dense_float_features = example_.dense_float_features;
     82       for (size_t dense_float_idx = 0;
     83            dense_float_idx < dense_float_features.size(); ++dense_float_idx) {
     84         dense_float_features[dense_float_idx] =
     85             iter_->dense_float_column_values_[dense_float_idx](example_idx_, 0);
     86       }
     87 
     88       // Get sparse float values per column.
     89       auto& sparse_float_features = example_.sparse_float_features;
     90       // Iterate through each sparse float feature column.
     91       for (size_t sparse_float_idx = 0;
     92            sparse_float_idx < iter_->sparse_float_column_iterables_.size();
     93            ++sparse_float_idx) {
     94         // Clear info from a previous instance.
     95         sparse_float_features[sparse_float_idx].Clear();
     96 
     97         // Get range for values tensor.
     98         const auto& row_range =
     99             (*sparse_float_column_iterators_[sparse_float_idx]);
    100         DCHECK_EQ(example_idx_, row_range.example_idx);
    101 
    102         // If the example has this feature column.
    103         if (row_range.start < row_range.end) {
    104           const int32 dimension =
    105               iter_->sparse_float_dimensions_[sparse_float_idx];
    106           sparse_float_features[sparse_float_idx].SetDimension(dimension);
    107           if (dimension <= 1) {
    108             // single dimensional sparse feature column.
    109             DCHECK_EQ(1, row_range.end - row_range.start);
    110             sparse_float_features[sparse_float_idx].Add(
    111                 0, iter_->sparse_float_column_values_[sparse_float_idx](
    112                        row_range.start));
    113           } else {
    114             // Retrieve original indices tensor.
    115             const TTypes<int64>::ConstMatrix& indices =
    116                 iter_->sparse_float_column_iterables_[sparse_float_idx]
    117                     .sparse_indices();
    118 
    119             sparse_float_features[sparse_float_idx].Reserve(row_range.end -
    120                                                             row_range.start);
    121 
    122             // For each value.
    123             for (int64 row_idx = row_range.start; row_idx < row_range.end;
    124                  ++row_idx) {
    125               // Get the feature id for the feature column and the value.
    126               const int32 feature_id = indices(row_idx, 1);
    127               DCHECK_EQ(example_idx_, indices(row_idx, 0));
    128 
    129               // Save the value to our sparse matrix.
    130               sparse_float_features[sparse_float_idx].Add(
    131                   feature_id,
    132                   iter_->sparse_float_column_values_[sparse_float_idx](
    133                       row_idx));
    134             }
    135           }
    136         }
    137       }
    138 
    139       // Get sparse int values per column.
    140       auto& sparse_int_features = example_.sparse_int_features;
    141       for (size_t sparse_int_idx = 0;
    142            sparse_int_idx < sparse_int_features.size(); ++sparse_int_idx) {
    143         const auto& row_range = (*sparse_int_column_iterators_[sparse_int_idx]);
    144         DCHECK_EQ(example_idx_, row_range.example_idx);
    145         sparse_int_features[sparse_int_idx].clear();
    146         if (row_range.start < row_range.end) {
    147           sparse_int_features[sparse_int_idx].reserve(row_range.end -
    148                                                       row_range.start);
    149           for (int64 row_idx = row_range.start; row_idx < row_range.end;
    150                ++row_idx) {
    151             sparse_int_features[sparse_int_idx].insert(
    152                 iter_->sparse_int_column_values_[sparse_int_idx](row_idx));
    153           }
    154         }
    155       }
    156 
    157       return example_;
    158     }
    159 
    160    private:
    161     // Examples iterable (not owned).
    162     const ExamplesIterable* iter_;
    163 
    164     // Example index.
    165     int64 example_idx_;
    166 
    167     // Sparse float column iterators.
    168     std::vector<SparseColumnIterable::Iterator> sparse_float_column_iterators_;
    169 
    170     // Sparse int column iterators.
    171     std::vector<SparseColumnIterable::Iterator> sparse_int_column_iterators_;
    172 
    173     // Example placeholder.
    174     Example example_;
    175   };
    176 
    177   Iterator begin() { return Iterator(this, example_start_); }
    178   Iterator end() { return Iterator(this, example_end_); }
    179 
    180  private:
    181   // Example slice spec.
    182   const int64 example_start_;
    183   const int64 example_end_;
    184 
    185   // Dense float column values.
    186   std::vector<TTypes<float>::ConstMatrix> dense_float_column_values_;
    187 
    188   // Sparse float column iterables.
    189   std::vector<SparseColumnIterable> sparse_float_column_iterables_;
    190 
    191   // Sparse float column values.
    192   std::vector<TTypes<float>::ConstVec> sparse_float_column_values_;
    193 
    194   // Dimensions for sparse float feature columns.
    195   std::vector<int32> sparse_float_dimensions_;
    196 
    197   // Sparse int column iterables.
    198   std::vector<SparseColumnIterable> sparse_int_column_iterables_;
    199 
    200   // Sparse int column values.
    201   std::vector<TTypes<int64>::ConstVec> sparse_int_column_values_;
    202 };
    203 
    204 }  // namespace utils
    205 }  // namespace boosted_trees
    206 }  // namespace tensorflow
    207 
    208 #endif  // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLES_ITERABLE_H_
    209