// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog (at) gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

     10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_MAP_H
     11 #define EIGEN_CXX11_TENSOR_TENSOR_MAP_H
     12 
     13 namespace Eigen {
     14 
     15 /** \class TensorMap
     16   * \ingroup CXX11_Tensor_Module
     17   *
     18   * \brief A tensor expression mapping an existing array of data.
     19   *
     20   */
     21 /// template <class> class MakePointer_ is added to convert the host pointer to the device pointer.
     22 /// It is added due to the fact that for our device compiler T* is not allowed.
     23 /// If we wanted to use the same Evaluator functions we have to convert that type to our pointer T.
     24 /// This is done through our MakePointer_ class. By default the Type in the MakePointer_<T> is T* .
     25 /// Therefore, by adding the default value, we managed to convert the type and it does not break any
     26 /// existing code as its default value is T*.
     27 template<typename PlainObjectType, int Options_, template <class> class MakePointer_> class TensorMap : public TensorBase<TensorMap<PlainObjectType, Options_, MakePointer_> >
     28 {
     29   public:
     30     typedef TensorMap<PlainObjectType, Options_, MakePointer_> Self;
     31     typedef typename PlainObjectType::Base Base;
     32     typedef typename Eigen::internal::nested<Self>::type Nested;
     33     typedef typename internal::traits<PlainObjectType>::StorageKind StorageKind;
     34     typedef typename internal::traits<PlainObjectType>::Index Index;
     35     typedef typename internal::traits<PlainObjectType>::Scalar Scalar;
     36     typedef typename NumTraits<Scalar>::Real RealScalar;
     37     typedef typename Base::CoeffReturnType CoeffReturnType;
     38 
     39   /*    typedef typename internal::conditional<
     40                          bool(internal::is_lvalue<PlainObjectType>::value),
     41                          Scalar *,
     42                          const Scalar *>::type
     43                      PointerType;*/
     44     typedef typename MakePointer_<Scalar>::Type PointerType;
     45     typedef PointerType PointerArgType;
     46 
     47     static const int Options = Options_;
     48 
     49     static const Index NumIndices = PlainObjectType::NumIndices;
     50     typedef typename PlainObjectType::Dimensions Dimensions;
     51 
     52     enum {
     53       IsAligned = ((int(Options_)&Aligned)==Aligned),
     54       Layout = PlainObjectType::Layout,
     55       CoordAccess = true,
     56       RawAccess = true
     57     };
     58 
     59     EIGEN_DEVICE_FUNC
     60     EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr) : m_data(dataPtr), m_dimensions() {
     61       // The number of dimensions used to construct a tensor must be equal to the rank of the tensor.
     62       EIGEN_STATIC_ASSERT((0 == NumIndices || NumIndices == Dynamic), YOU_MADE_A_PROGRAMMING_MISTAKE)
     63     }
     64 
     65 #if EIGEN_HAS_VARIADIC_TEMPLATES
     66     template<typename... IndexTypes> EIGEN_DEVICE_FUNC
     67     EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension, IndexTypes... otherDimensions) : m_data(dataPtr), m_dimensions(firstDimension, otherDimensions...) {
     68       // The number of dimensions used to construct a tensor must be equal to the rank of the tensor.
     69       EIGEN_STATIC_ASSERT((sizeof...(otherDimensions) + 1 == NumIndices || NumIndices == Dynamic), YOU_MADE_A_PROGRAMMING_MISTAKE)
     70     }
     71 #else
     72     EIGEN_DEVICE_FUNC
     73     EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension) : m_data(dataPtr), m_dimensions(firstDimension) {
     74       // The number of dimensions used to construct a tensor must be equal to the rank of the tensor.
     75       EIGEN_STATIC_ASSERT((1 == NumIndices || NumIndices == Dynamic), YOU_MADE_A_PROGRAMMING_MISTAKE)
     76     }
     77     EIGEN_DEVICE_FUNC
     78     EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index dim1, Index dim2) : m_data(dataPtr), m_dimensions(dim1, dim2) {
     79       EIGEN_STATIC_ASSERT(2 == NumIndices || NumIndices == Dynamic, YOU_MADE_A_PROGRAMMING_MISTAKE)
     80     }
     81     EIGEN_DEVICE_FUNC
     82     EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index dim1, Index dim2, Index dim3) : m_data(dataPtr), m_dimensions(dim1, dim2, dim3) {
     83       EIGEN_STATIC_ASSERT(3 == NumIndices || NumIndices == Dynamic, YOU_MADE_A_PROGRAMMING_MISTAKE)
     84     }
     85     EIGEN_DEVICE_FUNC
     86     EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index dim1, Index dim2, Index dim3, Index dim4) : m_data(dataPtr), m_dimensions(dim1, dim2, dim3, dim4) {
     87       EIGEN_STATIC_ASSERT(4 == NumIndices || NumIndices == Dynamic, YOU_MADE_A_PROGRAMMING_MISTAKE)
     88     }
     89     EIGEN_DEVICE_FUNC
     90     EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index dim1, Index dim2, Index dim3, Index dim4, Index dim5) : m_data(dataPtr), m_dimensions(dim1, dim2, dim3, dim4, dim5) {
     91       EIGEN_STATIC_ASSERT(5 == NumIndices || NumIndices == Dynamic, YOU_MADE_A_PROGRAMMING_MISTAKE)
     92     }
     93 #endif
     94 
     95    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, const array<Index, NumIndices>& dimensions)
     96       : m_data(dataPtr), m_dimensions(dimensions)
     97     { }
     98 
     99     template <typename Dimensions>
    100     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, const Dimensions& dimensions)
    101       : m_data(dataPtr), m_dimensions(dimensions)
    102     { }
    103 
    104     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMap(PlainObjectType& tensor)
    105       : m_data(tensor.data()), m_dimensions(tensor.dimensions())
    106     { }
    107 
    108     EIGEN_DEVICE_FUNC
    109     EIGEN_STRONG_INLINE Index rank() const { return m_dimensions.rank(); }
    110     EIGEN_DEVICE_FUNC
    111     EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_dimensions[n]; }
    112     EIGEN_DEVICE_FUNC
    113     EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
    114     EIGEN_DEVICE_FUNC
    115     EIGEN_STRONG_INLINE Index size() const { return m_dimensions.TotalSize(); }
    116     EIGEN_DEVICE_FUNC
    117     EIGEN_STRONG_INLINE PointerType data() { return m_data; }
    118     EIGEN_DEVICE_FUNC
    119     EIGEN_STRONG_INLINE const PointerType data() const { return m_data; }
    120 
    121     EIGEN_DEVICE_FUNC
    122     EIGEN_STRONG_INLINE const Scalar& operator()(const array<Index, NumIndices>& indices) const
    123     {
    124       //      eigen_assert(checkIndexRange(indices));
    125       if (PlainObjectType::Options&RowMajor) {
    126         const Index index = m_dimensions.IndexOfRowMajor(indices);
    127         return m_data[index];
    128       } else {
    129         const Index index = m_dimensions.IndexOfColMajor(indices);
    130         return m_data[index];
    131       }
    132     }
    133 
    134     EIGEN_DEVICE_FUNC
    135     EIGEN_STRONG_INLINE const Scalar& operator()() const
    136     {
    137       EIGEN_STATIC_ASSERT(NumIndices == 0, YOU_MADE_A_PROGRAMMING_MISTAKE)
    138       return m_data[0];
    139     }
    140 
    141     EIGEN_DEVICE_FUNC
    142     EIGEN_STRONG_INLINE const Scalar& operator()(Index index) const
    143     {
    144       eigen_internal_assert(index >= 0 && index < size());
    145       return m_data[index];
    146     }
    147 
    148 #if EIGEN_HAS_VARIADIC_TEMPLATES
    149     template<typename... IndexTypes> EIGEN_DEVICE_FUNC
    150     EIGEN_STRONG_INLINE const Scalar& operator()(Index firstIndex, Index secondIndex, IndexTypes... otherIndices) const
    151     {
    152       EIGEN_STATIC_ASSERT(sizeof...(otherIndices) + 2 == NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE)
    153       if (PlainObjectType::Options&RowMajor) {
    154         const Index index = m_dimensions.IndexOfRowMajor(array<Index, NumIndices>{{firstIndex, secondIndex, otherIndices...}});
    155         return m_data[index];
    156       } else {
    157         const Index index = m_dimensions.IndexOfColMajor(array<Index, NumIndices>{{firstIndex, secondIndex, otherIndices...}});
    158         return m_data[index];
    159       }
    160     }
    161 #else
    162     EIGEN_DEVICE_FUNC
    163     EIGEN_STRONG_INLINE const Scalar& operator()(Index i0, Index i1) const
    164     {
    165       if (PlainObjectType::Options&RowMajor) {
    166         const Index index = i1 + i0 * m_dimensions[1];
    167         return m_data[index];
    168       } else {
    169         const Index index = i0 + i1 * m_dimensions[0];
    170         return m_data[index];
    171       }
    172     }
    173     EIGEN_DEVICE_FUNC
    174     EIGEN_STRONG_INLINE const Scalar& operator()(Index i0, Index i1, Index i2) const
    175     {
    176       if (PlainObjectType::Options&RowMajor) {
    177          const Index index = i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0);
    178          return m_data[index];
    179       } else {
    180          const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * i2);
    181         return m_data[index];
    182       }
    183     }
    184     EIGEN_DEVICE_FUNC
    185     EIGEN_STRONG_INLINE const Scalar& operator()(Index i0, Index i1, Index i2, Index i3) const
    186     {
    187       if (PlainObjectType::Options&RowMajor) {
    188         const Index index = i3 + m_dimensions[3] * (i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0));
    189         return m_data[index];
    190       } else {
    191         const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * (i2 + m_dimensions[2] * i3));
    192         return m_data[index];
    193       }
    194     }
    195     EIGEN_DEVICE_FUNC
    196     EIGEN_STRONG_INLINE const Scalar& operator()(Index i0, Index i1, Index i2, Index i3, Index i4) const
    197     {
    198       if (PlainObjectType::Options&RowMajor) {
    199         const Index index = i4 + m_dimensions[4] * (i3 + m_dimensions[3] * (i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0)));
    200         return m_data[index];
    201       } else {
    202         const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * (i2 + m_dimensions[2] * (i3 + m_dimensions[3] * i4)));
    203         return m_data[index];
    204       }
    205     }
    206 #endif
    207 
    208     EIGEN_DEVICE_FUNC
    209     EIGEN_STRONG_INLINE Scalar& operator()(const array<Index, NumIndices>& indices)
    210     {
    211       //      eigen_assert(checkIndexRange(indices));
    212       if (PlainObjectType::Options&RowMajor) {
    213         const Index index = m_dimensions.IndexOfRowMajor(indices);
    214         return m_data[index];
    215       } else {
    216         const Index index = m_dimensions.IndexOfColMajor(indices);
    217         return m_data[index];
    218       }
    219     }
    220 
    221     EIGEN_DEVICE_FUNC
    222     EIGEN_STRONG_INLINE Scalar& operator()()
    223     {
    224       EIGEN_STATIC_ASSERT(NumIndices == 0, YOU_MADE_A_PROGRAMMING_MISTAKE)
    225       return m_data[0];
    226     }
    227 
    228     EIGEN_DEVICE_FUNC
    229     EIGEN_STRONG_INLINE Scalar& operator()(Index index)
    230     {
    231       eigen_internal_assert(index >= 0 && index < size());
    232       return m_data[index];
    233     }
    234 
    235 #if EIGEN_HAS_VARIADIC_TEMPLATES
    236     template<typename... IndexTypes> EIGEN_DEVICE_FUNC
    237     EIGEN_STRONG_INLINE Scalar& operator()(Index firstIndex, Index secondIndex, IndexTypes... otherIndices)
    238     {
    239       static_assert(sizeof...(otherIndices) + 2 == NumIndices || NumIndices == Dynamic, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor.");
    240       const std::size_t NumDims = sizeof...(otherIndices) + 2;
    241       if (PlainObjectType::Options&RowMajor) {
    242         const Index index = m_dimensions.IndexOfRowMajor(array<Index, NumDims>{{firstIndex, secondIndex, otherIndices...}});
    243         return m_data[index];
    244       } else {
    245         const Index index = m_dimensions.IndexOfColMajor(array<Index, NumDims>{{firstIndex, secondIndex, otherIndices...}});
    246         return m_data[index];
    247       }
    248     }
    249 #else
    250     EIGEN_DEVICE_FUNC
    251     EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1)
    252     {
    253        if (PlainObjectType::Options&RowMajor) {
    254          const Index index = i1 + i0 * m_dimensions[1];
    255         return m_data[index];
    256       } else {
    257         const Index index = i0 + i1 * m_dimensions[0];
    258         return m_data[index];
    259       }
    260     }
    261     EIGEN_DEVICE_FUNC
    262     EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2)
    263     {
    264        if (PlainObjectType::Options&RowMajor) {
    265          const Index index = i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0);
    266         return m_data[index];
    267       } else {
    268          const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * i2);
    269         return m_data[index];
    270       }
    271     }
    272     EIGEN_DEVICE_FUNC
    273     EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2, Index i3)
    274     {
    275       if (PlainObjectType::Options&RowMajor) {
    276         const Index index = i3 + m_dimensions[3] * (i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0));
    277         return m_data[index];
    278       } else {
    279         const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * (i2 + m_dimensions[2] * i3));
    280         return m_data[index];
    281       }
    282     }
    283     EIGEN_DEVICE_FUNC
    284     EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2, Index i3, Index i4)
    285     {
    286       if (PlainObjectType::Options&RowMajor) {
    287         const Index index = i4 + m_dimensions[4] * (i3 + m_dimensions[3] * (i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0)));
    288         return m_data[index];
    289       } else {
    290         const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * (i2 + m_dimensions[2] * (i3 + m_dimensions[3] * i4)));
    291         return m_data[index];
    292       }
    293     }
    294 #endif
    295 
    296     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Self& operator=(const Self& other)
    297     {
    298       typedef TensorAssignOp<Self, const Self> Assign;
    299       Assign assign(*this, other);
    300       internal::TensorExecutor<const Assign, DefaultDevice>::run(assign, DefaultDevice());
    301       return *this;
    302     }
    303 
    304     template<typename OtherDerived>
    305     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    306     Self& operator=(const OtherDerived& other)
    307     {
    308       typedef TensorAssignOp<Self, const OtherDerived> Assign;
    309       Assign assign(*this, other);
    310       internal::TensorExecutor<const Assign, DefaultDevice>::run(assign, DefaultDevice());
    311       return *this;
    312     }
    313 
    314   private:
    315     typename MakePointer_<Scalar>::Type m_data;
    316     Dimensions m_dimensions;
    317 };
    318 
    319 } // end namespace Eigen
    320 
    321 #endif // EIGEN_CXX11_TENSOR_TENSOR_MAP_H
    322