Home | History | Annotate | Download | only in sparse
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_UTIL_SPARSE_GROUP_ITERATOR_H_
     17 #define TENSORFLOW_UTIL_SPARSE_GROUP_ITERATOR_H_
     18 
     19 #include <vector>
     20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     21 #include "tensorflow/core/framework/tensor.h"
     22 #include "tensorflow/core/lib/core/status.h"
     23 #include "tensorflow/core/platform/logging.h"
     24 #include "tensorflow/core/platform/types.h"
     25 
     26 namespace tensorflow {
     27 namespace sparse {
     28 
     29 class GroupIterable;  // Predeclare GroupIterable for Group.
     30 
     31 // This class is returned when dereferencing a GroupIterable iterator.
     32 // It provides the methods group(), indices(), and values(), which
     33 // provide access into the underlying SparseTensor.
     34 class Group {
     35  public:
     36   Group(GroupIterable* iter, int64 loc, int64 next_loc)
     37       : iter_(iter), loc_(loc), next_loc_(next_loc) {}
     38 
     39   std::vector<int64> group() const;
     40   TTypes<int64>::UnalignedConstMatrix indices() const;
     41   template <typename T>
     42   typename TTypes<T>::UnalignedVec values() const;
     43 
     44  private:
     45   GroupIterable* iter_;
     46   int64 loc_;
     47   int64 next_loc_;
     48 };
     49 
     50 /////////////////
     51 // GroupIterable
     52 /////////////////
     53 //
     54 // Returned when calling sparse_tensor.group({dim0, dim1, ...}).
     55 //
     56 // Please note: the sparse_tensor should already be ordered according
     57 // to {dim0, dim1, ...}.  Otherwise this iteration will return invalid groups.
     58 //
     59 // Allows grouping and iteration of the SparseTensor according to the
     60 // subset of dimensions provided to the group call.
     61 //
     62 // The actual grouping dimensions are stored in the
     63 // internal vector group_dims_.  Iterators inside the iterable provide
     64 // the three methods:
     65 //
     66 // *  group(): returns a vector with the current group dimension values.
     67 // *  indices(): a map of index, providing the indices in
     68 //    this group.
     69 // *  values(): a map of values, providing the values in
     70 //    this group.
     71 //
     72 // To iterate across GroupIterable, see examples in README.md.
     73 //
     74 
     75 // Forward declaration of SparseTensor
     76 class GroupIterable {
     77  public:
     78   typedef gtl::ArraySlice<int64> VarDimArray;
     79 
     80   GroupIterable(Tensor ix, Tensor vals, int dims, const VarDimArray& group_dims)
     81       : ix_(ix), vals_(vals), dims_(dims), group_dims_(group_dims) {}
     82 
     83   class IteratorStep;
     84 
     85   IteratorStep begin() { return IteratorStep(this, 0); }
     86   IteratorStep at(int64 loc) {
     87     CHECK(loc >= 0 && loc <= ix_.dim_size(0))
     88         << "loc provided must lie between 0 and " << ix_.dim_size(0);
     89     return IteratorStep(this, loc);
     90   }
     91   IteratorStep end() { return IteratorStep(this, ix_.dim_size(0)); }
     92 
     93   template <typename TIX>
     94   inline bool GroupMatches(const TIX& ix, int64 loc_a, int64 loc_b) const {
     95     bool matches = true;
     96     for (int d : group_dims_) {
     97       if (ix(loc_a, d) != ix(loc_b, d)) {
     98         matches = false;
     99       }
    100     }
    101     return matches;
    102   }
    103 
    104   class IteratorStep {
    105    public:
    106     IteratorStep(GroupIterable* iter, int64 loc)
    107         : iter_(iter), loc_(loc), next_loc_(loc_) {
    108       UpdateEndOfGroup();
    109     }
    110 
    111     void UpdateEndOfGroup();
    112     bool operator!=(const IteratorStep& rhs) const;
    113     bool operator==(const IteratorStep& rhs) const;
    114     IteratorStep& operator++();    // prefix ++
    115     IteratorStep operator++(int);  // postfix ++
    116     Group operator*() const { return Group(iter_, loc_, next_loc_); }
    117     int64 loc() const { return loc_; }
    118 
    119    private:
    120     GroupIterable* iter_;
    121     int64 loc_;
    122     int64 next_loc_;
    123   };
    124 
    125  private:
    126   friend class Group;
    127   Tensor ix_;
    128   Tensor vals_;
    129   const int dims_;
    130   const VarDimArray group_dims_;
    131 };
    132 
    133 // Implementation of Group::values<T>()
    134 template <typename T>
    135 typename TTypes<T>::UnalignedVec Group::values() const {
    136   return typename TTypes<T>::UnalignedVec(&(iter_->vals_.vec<T>()(loc_)),
    137                                           next_loc_ - loc_);
    138 }
    139 
    140 }  // namespace sparse
    141 }  // namespace tensorflow
    142 
    143 #endif  // TENSORFLOW_UTIL_SPARSE_GROUP_ITERATOR_H_
    144