Home | History | Annotate | Download | only in Tensor
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2015 Benoit Steiner <benoit.steiner.goog (at) gmail.com>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_GENERATOR_H
     11 #define EIGEN_CXX11_TENSOR_TENSOR_GENERATOR_H
     12 
     13 namespace Eigen {
     14 
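/** \class TensorGeneratorOp
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor expression whose coefficients are produced by a generator functor.
  *
  * The result has the same dimensions and layout as the wrapped expression;
  * the wrapped expression's coefficient values are not used. Each output
  * coefficient is obtained by calling the generator with that coefficient's
  * coordinates, passed as an array of indices (one per dimension).
  */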
     22 namespace internal {
     23 template<typename Generator, typename XprType>
     24 struct traits<TensorGeneratorOp<Generator, XprType> > : public traits<XprType>
     25 {
     26   typedef typename XprType::Scalar Scalar;
     27   typedef traits<XprType> XprTraits;
     28   typedef typename XprTraits::StorageKind StorageKind;
     29   typedef typename XprTraits::Index Index;
     30   typedef typename XprType::Nested Nested;
     31   typedef typename remove_reference<Nested>::type _Nested;
     32   static const int NumDimensions = XprTraits::NumDimensions;
     33   static const int Layout = XprTraits::Layout;
     34 };
     35 
     36 template<typename Generator, typename XprType>
     37 struct eval<TensorGeneratorOp<Generator, XprType>, Eigen::Dense>
     38 {
     39   typedef const TensorGeneratorOp<Generator, XprType>& type;
     40 };
     41 
     42 template<typename Generator, typename XprType>
     43 struct nested<TensorGeneratorOp<Generator, XprType>, 1, typename eval<TensorGeneratorOp<Generator, XprType> >::type>
     44 {
     45   typedef TensorGeneratorOp<Generator, XprType> type;
     46 };
     47 
     48 }  // end namespace internal
     49 
     50 
     51 
     52 template<typename Generator, typename XprType>
     53 class TensorGeneratorOp : public TensorBase<TensorGeneratorOp<Generator, XprType>, ReadOnlyAccessors>
     54 {
     55   public:
     56   typedef typename Eigen::internal::traits<TensorGeneratorOp>::Scalar Scalar;
     57   typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
     58   typedef typename XprType::CoeffReturnType CoeffReturnType;
     59   typedef typename Eigen::internal::nested<TensorGeneratorOp>::type Nested;
     60   typedef typename Eigen::internal::traits<TensorGeneratorOp>::StorageKind StorageKind;
     61   typedef typename Eigen::internal::traits<TensorGeneratorOp>::Index Index;
     62 
     63   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorGeneratorOp(const XprType& expr, const Generator& generator)
     64       : m_xpr(expr), m_generator(generator) {}
     65 
     66     EIGEN_DEVICE_FUNC
     67     const Generator& generator() const { return m_generator; }
     68 
     69     EIGEN_DEVICE_FUNC
     70     const typename internal::remove_all<typename XprType::Nested>::type&
     71     expression() const { return m_xpr; }
     72 
     73   protected:
     74     typename XprType::Nested m_xpr;
     75     const Generator m_generator;
     76 };
     77 
     78 
     79 // Eval as rvalue
     80 template<typename Generator, typename ArgType, typename Device>
     81 struct TensorEvaluator<const TensorGeneratorOp<Generator, ArgType>, Device>
     82 {
     83   typedef TensorGeneratorOp<Generator, ArgType> XprType;
     84   typedef typename XprType::Index Index;
     85   typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;
     86   static const int NumDims = internal::array_size<Dimensions>::value;
     87   typedef typename XprType::Scalar Scalar;
     88   typedef typename XprType::CoeffReturnType CoeffReturnType;
     89   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
     90   enum {
     91     IsAligned = false,
     92     PacketAccess = (internal::unpacket_traits<PacketReturnType>::size > 1),
     93     BlockAccess = false,
     94     Layout = TensorEvaluator<ArgType, Device>::Layout,
     95     CoordAccess = false,  // to be implemented
     96     RawAccess = false
     97   };
     98 
     99   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
    100       : m_generator(op.generator())
    101   {
    102     TensorEvaluator<ArgType, Device> impl(op.expression(), device);
    103     m_dimensions = impl.dimensions();
    104 
    105     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
    106       m_strides[0] = 1;
    107       for (int i = 1; i < NumDims; ++i) {
    108         m_strides[i] = m_strides[i - 1] * m_dimensions[i - 1];
    109       }
    110     } else {
    111       m_strides[NumDims - 1] = 1;
    112       for (int i = NumDims - 2; i >= 0; --i) {
    113         m_strides[i] = m_strides[i + 1] * m_dimensions[i + 1];
    114       }
    115     }
    116   }
    117 
    118   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
    119 
    120   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
    121     return true;
    122   }
    123   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    124   }
    125 
    126   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
    127   {
    128     array<Index, NumDims> coords;
    129     extract_coordinates(index, coords);
    130     return m_generator(coords);
    131   }
    132 
    133   template<int LoadMode>
    134   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
    135   {
    136     const int packetSize = internal::unpacket_traits<PacketReturnType>::size;
    137     EIGEN_STATIC_ASSERT((packetSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    138     eigen_assert(index+packetSize-1 < dimensions().TotalSize());
    139 
    140     EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[packetSize];
    141     for (int i = 0; i < packetSize; ++i) {
    142       values[i] = coeff(index+i);
    143     }
    144     PacketReturnType rslt = internal::pload<PacketReturnType>(values);
    145     return rslt;
    146   }
    147 
    148   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
    149   costPerCoeff(bool) const {
    150     // TODO(rmlarsen): This is just a placeholder. Define interface to make
    151     // generators return their cost.
    152     return TensorOpCost(0, 0, TensorOpCost::AddCost<Scalar>() +
    153                                   TensorOpCost::MulCost<Scalar>());
    154   }
    155 
    156   EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }
    157 
    158  protected:
    159   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    160   void extract_coordinates(Index index, array<Index, NumDims>& coords) const {
    161     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
    162       for (int i = NumDims - 1; i > 0; --i) {
    163         const Index idx = index / m_strides[i];
    164         index -= idx * m_strides[i];
    165         coords[i] = idx;
    166       }
    167       coords[0] = index;
    168     } else {
    169       for (int i = 0; i < NumDims - 1; ++i) {
    170         const Index idx = index / m_strides[i];
    171         index -= idx * m_strides[i];
    172         coords[i] = idx;
    173       }
    174       coords[NumDims-1] = index;
    175     }
    176   }
    177 
    178   Dimensions m_dimensions;
    179   array<Index, NumDims> m_strides;
    180   Generator m_generator;
    181 };
    182 
    183 } // end namespace Eigen
    184 
    185 #endif // EIGEN_CXX11_TENSOR_TENSOR_GENERATOR_H
    186