Home | History | Annotate | Download | only in Tensor
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog (at) gmail.com>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_LAYOUT_SWAP_H
     11 #define EIGEN_CXX11_TENSOR_TENSOR_LAYOUT_SWAP_H
     12 
     13 namespace Eigen {
     14 
/** \class TensorLayoutSwapOp
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Swap the layout from col-major to row-major, or from row-major
  * to col-major, and reverse the order of the dimensions.
  *
  * Beware: the dimensions are reversed by this operation. If you want to
  * preserve the order of the dimensions, you need to combine this
  * operation with a shuffle.
  *
  * Example:
  * Tensor<float, 2, ColMajor> input(2, 4);
  * Tensor<float, 2, RowMajor> output = input.swap_layout();
  * eigen_assert(output.dimension(0) == 4);
  * eigen_assert(output.dimension(1) == 2);
  *
  * array<int, 2> shuffle(1, 0);
  * output = input.swap_layout().shuffle(shuffle);
  * eigen_assert(output.dimension(0) == 2);
  * eigen_assert(output.dimension(1) == 4);
  *
  */
     37 namespace internal {
     38 template<typename XprType>
     39 struct traits<TensorLayoutSwapOp<XprType> > : public traits<XprType>
     40 {
     41   typedef typename XprType::Scalar Scalar;
     42   typedef traits<XprType> XprTraits;
     43   typedef typename XprTraits::StorageKind StorageKind;
     44   typedef typename XprTraits::Index Index;
     45   typedef typename XprType::Nested Nested;
     46   typedef typename remove_reference<Nested>::type _Nested;
     47   static const int NumDimensions = traits<XprType>::NumDimensions;
     48   static const int Layout = (traits<XprType>::Layout == ColMajor) ? RowMajor : ColMajor;
     49 };
     50 
     51 template<typename XprType>
     52 struct eval<TensorLayoutSwapOp<XprType>, Eigen::Dense>
     53 {
     54   typedef const TensorLayoutSwapOp<XprType>& type;
     55 };
     56 
     57 template<typename XprType>
     58 struct nested<TensorLayoutSwapOp<XprType>, 1, typename eval<TensorLayoutSwapOp<XprType> >::type>
     59 {
     60   typedef TensorLayoutSwapOp<XprType> type;
     61 };
     62 
     63 }  // end namespace internal
     64 
     65 
     66 
     67 template<typename XprType>
     68 class TensorLayoutSwapOp : public TensorBase<TensorLayoutSwapOp<XprType>, WriteAccessors>
     69 {
     70   public:
     71   typedef typename Eigen::internal::traits<TensorLayoutSwapOp>::Scalar Scalar;
     72   typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
     73   typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
     74   typedef typename Eigen::internal::nested<TensorLayoutSwapOp>::type Nested;
     75   typedef typename Eigen::internal::traits<TensorLayoutSwapOp>::StorageKind StorageKind;
     76   typedef typename Eigen::internal::traits<TensorLayoutSwapOp>::Index Index;
     77 
     78   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorLayoutSwapOp(const XprType& expr)
     79       : m_xpr(expr) {}
     80 
     81     EIGEN_DEVICE_FUNC
     82     const typename internal::remove_all<typename XprType::Nested>::type&
     83     expression() const { return m_xpr; }
     84 
     85     EIGEN_DEVICE_FUNC
     86     EIGEN_STRONG_INLINE TensorLayoutSwapOp& operator = (const TensorLayoutSwapOp& other)
     87     {
     88       typedef TensorAssignOp<TensorLayoutSwapOp, const TensorLayoutSwapOp> Assign;
     89       Assign assign(*this, other);
     90       internal::TensorExecutor<const Assign, DefaultDevice>::run(assign, DefaultDevice());
     91       return *this;
     92     }
     93 
     94     template<typename OtherDerived>
     95     EIGEN_DEVICE_FUNC
     96     EIGEN_STRONG_INLINE TensorLayoutSwapOp& operator = (const OtherDerived& other)
     97     {
     98       typedef TensorAssignOp<TensorLayoutSwapOp, const OtherDerived> Assign;
     99       Assign assign(*this, other);
    100       internal::TensorExecutor<const Assign, DefaultDevice>::run(assign, DefaultDevice());
    101       return *this;
    102     }
    103 
    104   protected:
    105     typename XprType::Nested m_xpr;
    106 };
    107 
    108 
    109 // Eval as rvalue
    110 template<typename ArgType, typename Device>
    111 struct TensorEvaluator<const TensorLayoutSwapOp<ArgType>, Device>
    112 {
    113   typedef TensorLayoutSwapOp<ArgType> XprType;
    114   typedef typename XprType::Index Index;
    115   static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
    116   typedef DSizes<Index, NumDims> Dimensions;
    117 
    118   enum {
    119     IsAligned = TensorEvaluator<ArgType, Device>::IsAligned,
    120     PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
    121     Layout = (static_cast<int>(TensorEvaluator<ArgType, Device>::Layout) == static_cast<int>(ColMajor)) ? RowMajor : ColMajor,
    122     CoordAccess = false,  // to be implemented
    123     RawAccess = TensorEvaluator<ArgType, Device>::RawAccess
    124   };
    125 
    126   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
    127       : m_impl(op.expression(), device)
    128   {
    129     for(int i = 0; i < NumDims; ++i) {
    130       m_dimensions[i] = m_impl.dimensions()[NumDims-1-i];
    131     }
    132   }
    133 
    134   typedef typename XprType::Scalar Scalar;
    135   typedef typename XprType::CoeffReturnType CoeffReturnType;
    136   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
    137 
    138   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
    139 
    140   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) {
    141     return m_impl.evalSubExprsIfNeeded(data);
    142   }
    143   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    144     m_impl.cleanup();
    145   }
    146 
    147   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
    148   {
    149     return m_impl.coeff(index);
    150   }
    151 
    152   template<int LoadMode>
    153   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
    154   {
    155     return m_impl.template packet<LoadMode>(index);
    156   }
    157 
    158   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
    159     return m_impl.costPerCoeff(vectorized);
    160   }
    161 
    162   EIGEN_DEVICE_FUNC Scalar* data() const { return m_impl.data(); }
    163 
    164   const TensorEvaluator<ArgType, Device>& impl() const { return m_impl; }
    165 
    166  protected:
    167   TensorEvaluator<ArgType, Device> m_impl;
    168   Dimensions m_dimensions;
    169 };
    170 
    171 
    172 // Eval as lvalue
    173 template<typename ArgType, typename Device>
    174   struct TensorEvaluator<TensorLayoutSwapOp<ArgType>, Device>
    175   : public TensorEvaluator<const TensorLayoutSwapOp<ArgType>, Device>
    176 {
    177   typedef TensorEvaluator<const TensorLayoutSwapOp<ArgType>, Device> Base;
    178   typedef TensorLayoutSwapOp<ArgType> XprType;
    179 
    180   enum {
    181     IsAligned = TensorEvaluator<ArgType, Device>::IsAligned,
    182     PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
    183     Layout = (static_cast<int>(TensorEvaluator<ArgType, Device>::Layout) == static_cast<int>(ColMajor)) ? RowMajor : ColMajor,
    184     CoordAccess = false  // to be implemented
    185   };
    186 
    187   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
    188     : Base(op, device)
    189   { }
    190 
    191   typedef typename XprType::Index Index;
    192   typedef typename XprType::Scalar Scalar;
    193   typedef typename XprType::CoeffReturnType CoeffReturnType;
    194   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
    195 
    196   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType& coeffRef(Index index)
    197   {
    198     return this->m_impl.coeffRef(index);
    199   }
    200   template <int StoreMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    201   void writePacket(Index index, const PacketReturnType& x)
    202   {
    203     this->m_impl.template writePacket<StoreMode>(index, x);
    204   }
    205 };
    206 
    207 } // end namespace Eigen
    208 
    209 #endif // EIGEN_CXX11_TENSOR_TENSOR_LAYOUT_SWAP_H
    210