Home | History | Annotate | Download | only in Tensor
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog (at) gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_REF_H
     11 #define EIGEN_CXX11_TENSOR_TENSOR_REF_H
     12 
     13 namespace Eigen {
     14 
     15 namespace internal {
     16 
     17 template <typename Dimensions, typename Scalar>
     18 class TensorLazyBaseEvaluator {
     19  public:
     20   TensorLazyBaseEvaluator() : m_refcount(0) { }
     21   virtual ~TensorLazyBaseEvaluator() { }
     22 
     23   EIGEN_DEVICE_FUNC virtual const Dimensions& dimensions() const = 0;
     24   EIGEN_DEVICE_FUNC virtual const Scalar* data() const = 0;
     25 
     26   EIGEN_DEVICE_FUNC virtual const Scalar coeff(DenseIndex index) const = 0;
     27   EIGEN_DEVICE_FUNC virtual Scalar& coeffRef(DenseIndex index) = 0;
     28 
     29   void incrRefCount() { ++m_refcount; }
     30   void decrRefCount() { --m_refcount; }
     31   int refCount() const { return m_refcount; }
     32 
     33  private:
     34   // No copy, no assigment;
     35   TensorLazyBaseEvaluator(const TensorLazyBaseEvaluator& other);
     36   TensorLazyBaseEvaluator& operator = (const TensorLazyBaseEvaluator& other);
     37 
     38   int m_refcount;
     39 };
     40 
     41 
     42 template <typename Dimensions, typename Expr, typename Device>
     43 class TensorLazyEvaluatorReadOnly : public TensorLazyBaseEvaluator<Dimensions, typename TensorEvaluator<Expr, Device>::Scalar> {
     44  public:
     45   //  typedef typename TensorEvaluator<Expr, Device>::Dimensions Dimensions;
     46   typedef typename TensorEvaluator<Expr, Device>::Scalar Scalar;
     47 
     48   TensorLazyEvaluatorReadOnly(const Expr& expr, const Device& device) : m_impl(expr, device), m_dummy(Scalar(0)) {
     49     m_dims = m_impl.dimensions();
     50     m_impl.evalSubExprsIfNeeded(NULL);
     51   }
     52   virtual ~TensorLazyEvaluatorReadOnly() {
     53     m_impl.cleanup();
     54   }
     55 
     56   EIGEN_DEVICE_FUNC virtual const Dimensions& dimensions() const {
     57     return m_dims;
     58   }
     59   EIGEN_DEVICE_FUNC virtual const Scalar* data() const {
     60     return m_impl.data();
     61   }
     62 
     63   EIGEN_DEVICE_FUNC virtual const Scalar coeff(DenseIndex index) const {
     64     return m_impl.coeff(index);
     65   }
     66   EIGEN_DEVICE_FUNC virtual Scalar& coeffRef(DenseIndex /*index*/) {
     67     eigen_assert(false && "can't reference the coefficient of a rvalue");
     68     return m_dummy;
     69   };
     70 
     71  protected:
     72   TensorEvaluator<Expr, Device> m_impl;
     73   Dimensions m_dims;
     74   Scalar m_dummy;
     75 };
     76 
     77 template <typename Dimensions, typename Expr, typename Device>
     78 class TensorLazyEvaluatorWritable : public TensorLazyEvaluatorReadOnly<Dimensions, Expr, Device> {
     79  public:
     80   typedef TensorLazyEvaluatorReadOnly<Dimensions, Expr, Device> Base;
     81   typedef typename Base::Scalar Scalar;
     82 
     83   TensorLazyEvaluatorWritable(const Expr& expr, const Device& device) : Base(expr, device) {
     84   }
     85   virtual ~TensorLazyEvaluatorWritable() {
     86   }
     87 
     88   EIGEN_DEVICE_FUNC virtual Scalar& coeffRef(DenseIndex index) {
     89     return this->m_impl.coeffRef(index);
     90   }
     91 };
     92 
     93 template <typename Dimensions, typename Expr, typename Device>
     94 class TensorLazyEvaluator : public internal::conditional<bool(internal::is_lvalue<Expr>::value),
     95                             TensorLazyEvaluatorWritable<Dimensions, Expr, Device>,
     96                             TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device> >::type {
     97  public:
     98   typedef typename internal::conditional<bool(internal::is_lvalue<Expr>::value),
     99                                          TensorLazyEvaluatorWritable<Dimensions, Expr, Device>,
    100                                          TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device> >::type Base;
    101   typedef typename Base::Scalar Scalar;
    102 
    103   TensorLazyEvaluator(const Expr& expr, const Device& device) : Base(expr, device) {
    104   }
    105   virtual ~TensorLazyEvaluator() {
    106   }
    107 };
    108 
    109 }  // namespace internal
    110 
    111 
    112 /** \class TensorRef
    113   * \ingroup CXX11_Tensor_Module
    114   *
    115   * \brief A reference to a tensor expression
    116   * The expression will be evaluated lazily (as much as possible).
    117   *
    118   */
    119 template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef<PlainObjectType> >
    120 {
    121   public:
    122     typedef TensorRef<PlainObjectType> Self;
    123     typedef typename PlainObjectType::Base Base;
    124     typedef typename Eigen::internal::nested<Self>::type Nested;
    125     typedef typename internal::traits<PlainObjectType>::StorageKind StorageKind;
    126     typedef typename internal::traits<PlainObjectType>::Index Index;
    127     typedef typename internal::traits<PlainObjectType>::Scalar Scalar;
    128     typedef typename NumTraits<Scalar>::Real RealScalar;
    129     typedef typename Base::CoeffReturnType CoeffReturnType;
    130     typedef Scalar* PointerType;
    131     typedef PointerType PointerArgType;
    132 
    133     static const Index NumIndices = PlainObjectType::NumIndices;
    134     typedef typename PlainObjectType::Dimensions Dimensions;
    135 
    136     enum {
    137       IsAligned = false,
    138       PacketAccess = false,
    139       Layout = PlainObjectType::Layout,
    140       CoordAccess = false,  // to be implemented
    141       RawAccess = false
    142     };
    143 
    144     EIGEN_STRONG_INLINE TensorRef() : m_evaluator(NULL) {
    145     }
    146 
    147     template <typename Expression>
    148     EIGEN_STRONG_INLINE TensorRef(const Expression& expr) : m_evaluator(new internal::TensorLazyEvaluator<Dimensions, Expression, DefaultDevice>(expr, DefaultDevice())) {
    149       m_evaluator->incrRefCount();
    150     }
    151 
    152     template <typename Expression>
    153     EIGEN_STRONG_INLINE TensorRef& operator = (const Expression& expr) {
    154       unrefEvaluator();
    155       m_evaluator = new internal::TensorLazyEvaluator<Dimensions, Expression, DefaultDevice>(expr, DefaultDevice());
    156       m_evaluator->incrRefCount();
    157       return *this;
    158     }
    159 
    160     ~TensorRef() {
    161       unrefEvaluator();
    162     }
    163 
    164     TensorRef(const TensorRef& other) : m_evaluator(other.m_evaluator) {
    165       eigen_assert(m_evaluator->refCount() > 0);
    166       m_evaluator->incrRefCount();
    167     }
    168 
    169     TensorRef& operator = (const TensorRef& other) {
    170       if (this != &other) {
    171         unrefEvaluator();
    172         m_evaluator = other.m_evaluator;
    173         eigen_assert(m_evaluator->refCount() > 0);
    174         m_evaluator->incrRefCount();
    175       }
    176       return *this;
    177     }
    178 
    179     EIGEN_DEVICE_FUNC
    180     EIGEN_STRONG_INLINE Index rank() const { return m_evaluator->dimensions().size(); }
    181     EIGEN_DEVICE_FUNC
    182     EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_evaluator->dimensions()[n]; }
    183     EIGEN_DEVICE_FUNC
    184     EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_evaluator->dimensions(); }
    185     EIGEN_DEVICE_FUNC
    186     EIGEN_STRONG_INLINE Index size() const { return m_evaluator->dimensions().TotalSize(); }
    187     EIGEN_DEVICE_FUNC
    188     EIGEN_STRONG_INLINE const Scalar* data() const { return m_evaluator->data(); }
    189 
    190     EIGEN_DEVICE_FUNC
    191     EIGEN_STRONG_INLINE const Scalar operator()(Index index) const
    192     {
    193       return m_evaluator->coeff(index);
    194     }
    195 
    196 #if EIGEN_HAS_VARIADIC_TEMPLATES
    197     template<typename... IndexTypes> EIGEN_DEVICE_FUNC
    198     EIGEN_STRONG_INLINE const Scalar operator()(Index firstIndex, IndexTypes... otherIndices) const
    199     {
    200       const std::size_t num_indices = (sizeof...(otherIndices) + 1);
    201       const array<Index, num_indices> indices{{firstIndex, otherIndices...}};
    202       return coeff(indices);
    203     }
    204     template<typename... IndexTypes> EIGEN_DEVICE_FUNC
    205     EIGEN_STRONG_INLINE Scalar& coeffRef(Index firstIndex, IndexTypes... otherIndices)
    206     {
    207       const std::size_t num_indices = (sizeof...(otherIndices) + 1);
    208       const array<Index, num_indices> indices{{firstIndex, otherIndices...}};
    209       return coeffRef(indices);
    210     }
    211 #else
    212 
    213     EIGEN_DEVICE_FUNC
    214     EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1) const
    215     {
    216       array<Index, 2> indices;
    217       indices[0] = i0;
    218       indices[1] = i1;
    219       return coeff(indices);
    220     }
    221     EIGEN_DEVICE_FUNC
    222     EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1, Index i2) const
    223     {
    224       array<Index, 3> indices;
    225       indices[0] = i0;
    226       indices[1] = i1;
    227       indices[2] = i2;
    228       return coeff(indices);
    229     }
    230     EIGEN_DEVICE_FUNC
    231     EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1, Index i2, Index i3) const
    232     {
    233       array<Index, 4> indices;
    234       indices[0] = i0;
    235       indices[1] = i1;
    236       indices[2] = i2;
    237       indices[3] = i3;
    238       return coeff(indices);
    239     }
    240     EIGEN_DEVICE_FUNC
    241     EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1, Index i2, Index i3, Index i4) const
    242     {
    243       array<Index, 5> indices;
    244       indices[0] = i0;
    245       indices[1] = i1;
    246       indices[2] = i2;
    247       indices[3] = i3;
    248       indices[4] = i4;
    249       return coeff(indices);
    250     }
    251     EIGEN_DEVICE_FUNC
    252     EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1)
    253     {
    254       array<Index, 2> indices;
    255       indices[0] = i0;
    256       indices[1] = i1;
    257       return coeffRef(indices);
    258     }
    259     EIGEN_DEVICE_FUNC
    260     EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1, Index i2)
    261     {
    262       array<Index, 3> indices;
    263       indices[0] = i0;
    264       indices[1] = i1;
    265       indices[2] = i2;
    266       return coeffRef(indices);
    267     }
    268     EIGEN_DEVICE_FUNC
    269     EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2, Index i3)
    270     {
    271       array<Index, 4> indices;
    272       indices[0] = i0;
    273       indices[1] = i1;
    274       indices[2] = i2;
    275       indices[3] = i3;
    276       return coeffRef(indices);
    277     }
    278     EIGEN_DEVICE_FUNC
    279     EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1, Index i2, Index i3, Index i4)
    280     {
    281       array<Index, 5> indices;
    282       indices[0] = i0;
    283       indices[1] = i1;
    284       indices[2] = i2;
    285       indices[3] = i3;
    286       indices[4] = i4;
    287       return coeffRef(indices);
    288     }
    289 #endif
    290 
    291     template <std::size_t NumIndices> EIGEN_DEVICE_FUNC
    292     EIGEN_STRONG_INLINE const Scalar coeff(const array<Index, NumIndices>& indices) const
    293     {
    294       const Dimensions& dims = this->dimensions();
    295       Index index = 0;
    296       if (PlainObjectType::Options & RowMajor) {
    297         index += indices[0];
    298         for (size_t i = 1; i < NumIndices; ++i) {
    299           index = index * dims[i] + indices[i];
    300         }
    301       } else {
    302         index += indices[NumIndices-1];
    303         for (int i = NumIndices-2; i >= 0; --i) {
    304           index = index * dims[i] + indices[i];
    305         }
    306       }
    307       return m_evaluator->coeff(index);
    308     }
    309     template <std::size_t NumIndices> EIGEN_DEVICE_FUNC
    310     EIGEN_STRONG_INLINE Scalar& coeffRef(const array<Index, NumIndices>& indices)
    311     {
    312       const Dimensions& dims = this->dimensions();
    313       Index index = 0;
    314       if (PlainObjectType::Options & RowMajor) {
    315         index += indices[0];
    316         for (size_t i = 1; i < NumIndices; ++i) {
    317           index = index * dims[i] + indices[i];
    318         }
    319       } else {
    320         index += indices[NumIndices-1];
    321         for (int i = NumIndices-2; i >= 0; --i) {
    322           index = index * dims[i] + indices[i];
    323         }
    324       }
    325       return m_evaluator->coeffRef(index);
    326     }
    327 
    328     EIGEN_DEVICE_FUNC
    329     EIGEN_STRONG_INLINE const Scalar coeff(Index index) const
    330     {
    331       return m_evaluator->coeff(index);
    332     }
    333 
    334     EIGEN_DEVICE_FUNC
    335     EIGEN_STRONG_INLINE Scalar& coeffRef(Index index)
    336     {
    337       return m_evaluator->coeffRef(index);
    338     }
    339 
    340   private:
    341     EIGEN_STRONG_INLINE void unrefEvaluator() {
    342       if (m_evaluator) {
    343         m_evaluator->decrRefCount();
    344         if (m_evaluator->refCount() == 0) {
    345           delete m_evaluator;
    346         }
    347       }
    348     }
    349 
    350   internal::TensorLazyBaseEvaluator<Dimensions, Scalar>* m_evaluator;
    351 };
    352 
    353 
    354 // evaluator for rvalues
    355 template<typename Derived, typename Device>
    356 struct TensorEvaluator<const TensorRef<Derived>, Device>
    357 {
    358   typedef typename Derived::Index Index;
    359   typedef typename Derived::Scalar Scalar;
    360   typedef typename Derived::Scalar CoeffReturnType;
    361   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
    362   typedef typename Derived::Dimensions Dimensions;
    363 
    364   enum {
    365     IsAligned = false,
    366     PacketAccess = false,
    367     Layout = TensorRef<Derived>::Layout,
    368     CoordAccess = false,  // to be implemented
    369     RawAccess = false
    370   };
    371 
    372   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const TensorRef<Derived>& m, const Device&)
    373       : m_ref(m)
    374   { }
    375 
    376   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_ref.dimensions(); }
    377 
    378   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar*) {
    379     return true;
    380   }
    381 
    382   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { }
    383 
    384   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    385     return m_ref.coeff(index);
    386   }
    387 
    388   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
    389     return m_ref.coeffRef(index);
    390   }
    391 
    392   EIGEN_DEVICE_FUNC Scalar* data() const { return m_ref.data(); }
    393 
    394  protected:
    395   TensorRef<Derived> m_ref;
    396 };
    397 
    398 
    399 // evaluator for lvalues
    400 template<typename Derived, typename Device>
    401 struct TensorEvaluator<TensorRef<Derived>, Device> : public TensorEvaluator<const TensorRef<Derived>, Device>
    402 {
    403   typedef typename Derived::Index Index;
    404   typedef typename Derived::Scalar Scalar;
    405   typedef typename Derived::Scalar CoeffReturnType;
    406   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
    407   typedef typename Derived::Dimensions Dimensions;
    408 
    409   typedef TensorEvaluator<const TensorRef<Derived>, Device> Base;
    410 
    411   enum {
    412     IsAligned = false,
    413     PacketAccess = false,
    414     RawAccess = false
    415   };
    416 
    417   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(TensorRef<Derived>& m, const Device& d) : Base(m, d)
    418   { }
    419 
    420   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
    421     return this->m_ref.coeffRef(index);
    422   }
    423 };
    424 
    425 
    426 
    427 } // end namespace Eigen
    428 
    429 #endif // EIGEN_CXX11_TENSOR_TENSOR_REF_H
    430