Home | History | Annotate | Download | only in utils
      1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 // =============================================================================
     15 
#include "tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.h"

#include <algorithm>
#include <cstddef>
#include <iterator>
     17 
     18 namespace tensorflow {
     19 namespace boosted_trees {
     20 namespace utils {
     21 
     22 using ExampleRowRange = SparseColumnIterable::ExampleRowRange;
     23 using Iterator = SparseColumnIterable::Iterator;
     24 
     25 namespace {
     26 
     27 // Iterator over indices matrix rows.
     28 class IndicesRowIterator
     29     : public std::iterator<std::random_access_iterator_tag, const int64> {
     30  public:
     31   IndicesRowIterator() : iter_(nullptr), row_idx_(-1) {}
     32   IndicesRowIterator(SparseColumnIterable* iter, int row_idx)
     33       : iter_(iter), row_idx_(row_idx) {}
     34   IndicesRowIterator(const IndicesRowIterator& other)
     35       : iter_(other.iter_), row_idx_(other.row_idx_) {}
     36 
     37   IndicesRowIterator& operator=(const IndicesRowIterator& other) {
     38     iter_ = other.iter_;
     39     row_idx_ = other.row_idx_;
     40     return (*this);
     41   }
     42 
     43   IndicesRowIterator& operator++() {
     44     ++row_idx_;
     45     return (*this);
     46   }
     47 
     48   IndicesRowIterator operator++(int) {
     49     IndicesRowIterator tmp(*this);
     50     ++row_idx_;
     51     return tmp;
     52   }
     53 
     54   reference operator*() const { return iter_->ix()(row_idx_, 0); }
     55 
     56   pointer operator->() { return &iter_->ix()(row_idx_, 0); }
     57 
     58   IndicesRowIterator& operator--() {
     59     --row_idx_;
     60     return (*this);
     61   }
     62 
     63   IndicesRowIterator operator--(int) {
     64     IndicesRowIterator tmp(*this);
     65     --row_idx_;
     66     return tmp;
     67   }
     68 
     69   IndicesRowIterator& operator+=(const difference_type& step) {
     70     row_idx_ += step;
     71     return (*this);
     72   }
     73   IndicesRowIterator& operator-=(const difference_type& step) {
     74     row_idx_ -= step;
     75     return (*this);
     76   }
     77 
     78   IndicesRowIterator operator+(const difference_type& step) const {
     79     IndicesRowIterator tmp(*this);
     80     tmp += step;
     81     return tmp;
     82   }
     83 
     84   IndicesRowIterator operator-(const difference_type& step) const {
     85     IndicesRowIterator tmp(*this);
     86     tmp -= step;
     87     return tmp;
     88   }
     89 
     90   difference_type operator-(const IndicesRowIterator& other) const {
     91     return row_idx_ - other.row_idx_;
     92   }
     93 
     94   bool operator!=(const IndicesRowIterator& other) const {
     95     QCHECK_EQ(iter_, other.iter_);
     96     return (row_idx_ != other.row_idx_);
     97   }
     98 
     99   bool operator<(const IndicesRowIterator& other) const {
    100     return (row_idx_ < other.row_idx_);
    101   }
    102 
    103   bool operator==(const IndicesRowIterator& other) const {
    104     QCHECK_EQ(iter_, other.iter_);
    105     return (row_idx_ == other.row_idx_);
    106   }
    107 
    108   Eigen::Index row_idx() const { return row_idx_; }
    109 
    110  private:
    111   SparseColumnIterable* iter_;
    112   Eigen::Index row_idx_;
    113 };
    114 }  // namespace
    115 
    116 Iterator::Iterator(SparseColumnIterable* iter, int64 example_idx)
    117     : iter_(iter), example_idx_(example_idx), end_(iter->ix_.dimension(0)) {
    118   cur_ = next_ = std::lower_bound(IndicesRowIterator(iter, 0),
    119                                   IndicesRowIterator(iter, end_), example_idx_)
    120                      .row_idx();
    121   UpdateNext();
    122 }
    123 
    124 }  // namespace utils
    125 }  // namespace boosted_trees
    126 }  // namespace tensorflow
    127