Home | History | Annotate | Download | only in utils
      1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 // =============================================================================
     15 
     16 #ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_SPARSE_COLUMN_ITERABLE_H_
     17 #define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_SPARSE_COLUMN_ITERABLE_H_
     18 
#include <algorithm>
#include <iterator>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
     23 
     24 namespace tensorflow {
     25 namespace boosted_trees {
     26 namespace utils {
     27 
     28 // Enables row-wise iteration through examples on sparse feature columns.
     29 class SparseColumnIterable {
     30  public:
     31   // Indicates a contiguous range for an example: [start, end).
     32   struct ExampleRowRange {
     33     int64 example_idx;
     34     int64 start;
     35     int64 end;
     36   };
     37 
     38   // Helper class to iterate through examples and return the corresponding
     39   // indices row range. Note that the row range can be empty in case a given
     40   // example has no corresponding indices.
     41   // An Iterator can be initialized from any example start offset, the
     42   // corresponding range indicators will be initialized in log time.
     43   class Iterator {
     44    public:
     45     Iterator(SparseColumnIterable* iter, int64 example_idx);
     46 
     47     Iterator& operator++() {
     48       ++example_idx_;
     49       if (cur_ < end_ && iter_->ix()(cur_, 0) < example_idx_) {
     50         cur_ = next_;
     51         UpdateNext();
     52       }
     53       return (*this);
     54     }
     55 
     56     Iterator operator++(int) {
     57       Iterator tmp(*this);
     58       ++(*this);
     59       return tmp;
     60     }
     61 
     62     bool operator!=(const Iterator& other) const {
     63       QCHECK_EQ(iter_, other.iter_);
     64       return (example_idx_ != other.example_idx_);
     65     }
     66 
     67     bool operator==(const Iterator& other) const {
     68       QCHECK_EQ(iter_, other.iter_);
     69       return (example_idx_ == other.example_idx_);
     70     }
     71 
     72     const ExampleRowRange& operator*() {
     73       range_.example_idx = example_idx_;
     74       if (cur_ < end_ && iter_->ix()(cur_, 0) == example_idx_) {
     75         range_.start = cur_;
     76         range_.end = next_;
     77       } else {
     78         range_.start = 0;
     79         range_.end = 0;
     80       }
     81       return range_;
     82     }
     83 
     84    private:
     85     void UpdateNext() {
     86       next_ = std::min(next_ + 1, end_);
     87       while (next_ < end_ && iter_->ix()(cur_, 0) == iter_->ix()(next_, 0)) {
     88         ++next_;
     89       }
     90     }
     91 
     92     const SparseColumnIterable* iter_;
     93     int64 example_idx_;
     94     int64 cur_;
     95     int64 next_;
     96     const int64 end_;
     97     ExampleRowRange range_;
     98   };
     99 
    100   // Constructs an iterable given the desired examples slice and corresponding
    101   // feature columns.
    102   SparseColumnIterable(TTypes<int64>::ConstMatrix ix, int64 example_start,
    103                        int64 example_end)
    104       : ix_(ix), example_start_(example_start), example_end_(example_end) {
    105     QCHECK(example_start >= 0 && example_end >= 0);
    106   }
    107 
    108   Iterator begin() { return Iterator(this, example_start_); }
    109   Iterator end() { return Iterator(this, example_end_); }
    110 
    111   const TTypes<int64>::ConstMatrix& ix() const { return ix_; }
    112   int64 example_start() const { return example_start_; }
    113   int64 example_end() const { return example_end_; }
    114 
    115   const TTypes<int64>::ConstMatrix& sparse_indices() const { return ix_; }
    116 
    117  private:
    118   // Sparse indices matrix.
    119   TTypes<int64>::ConstMatrix ix_;
    120 
    121   // Example slice spec.
    122   const int64 example_start_;
    123   const int64 example_end_;
    124 };
    125 
    126 }  // namespace utils
    127 }  // namespace boosted_trees
    128 }  // namespace tensorflow
    129 
    130 #endif  // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_SPARSE_COLUMN_ITERABLE_H_
    131