Home | History | Annotate | Download | only in Tensor
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog (at) gmail.com>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
     11 #define EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
     12 
     13 namespace Eigen {
     14 
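/** \class TensorCustomUnaryOp
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor expression that applies a user-provided functor to a whole
  * input expression.
  *
  * The functor must provide `dimensions(expr)`, returning the dimensions of
  * the result, and `eval(expr, output, device)`, which writes the complete
  * result into `output` (a TensorMap over the destination buffer).
  */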
     22 namespace internal {
     23 template<typename CustomUnaryFunc, typename XprType>
     24 struct traits<TensorCustomUnaryOp<CustomUnaryFunc, XprType> >
     25 {
     26   typedef typename XprType::Scalar Scalar;
     27   typedef typename XprType::StorageKind StorageKind;
     28   typedef typename XprType::Index Index;
     29   typedef typename XprType::Nested Nested;
     30   typedef typename remove_reference<Nested>::type _Nested;
     31   static const int NumDimensions = traits<XprType>::NumDimensions;
     32   static const int Layout = traits<XprType>::Layout;
     33 };
     34 
     35 template<typename CustomUnaryFunc, typename XprType>
     36 struct eval<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Eigen::Dense>
     37 {
     38   typedef const TensorCustomUnaryOp<CustomUnaryFunc, XprType>& type;
     39 };
     40 
     41 template<typename CustomUnaryFunc, typename XprType>
     42 struct nested<TensorCustomUnaryOp<CustomUnaryFunc, XprType> >
     43 {
     44   typedef TensorCustomUnaryOp<CustomUnaryFunc, XprType> type;
     45 };
     46 
     47 }  // end namespace internal
     48 
     49 
     50 
     51 template<typename CustomUnaryFunc, typename XprType>
     52 class TensorCustomUnaryOp : public TensorBase<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, ReadOnlyAccessors>
     53 {
     54   public:
     55   typedef typename internal::traits<TensorCustomUnaryOp>::Scalar Scalar;
     56   typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
     57   typedef typename XprType::CoeffReturnType CoeffReturnType;
     58   typedef typename internal::nested<TensorCustomUnaryOp>::type Nested;
     59   typedef typename internal::traits<TensorCustomUnaryOp>::StorageKind StorageKind;
     60   typedef typename internal::traits<TensorCustomUnaryOp>::Index Index;
     61 
     62   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomUnaryOp(const XprType& expr, const CustomUnaryFunc& func)
     63       : m_expr(expr), m_func(func) {}
     64 
     65   EIGEN_DEVICE_FUNC
     66   const CustomUnaryFunc& func() const { return m_func; }
     67 
     68   EIGEN_DEVICE_FUNC
     69   const typename internal::remove_all<typename XprType::Nested>::type&
     70   expression() const { return m_expr; }
     71 
     72   protected:
     73     typename XprType::Nested m_expr;
     74     const CustomUnaryFunc m_func;
     75 };
     76 
     77 
     78 // Eval as rvalue
     79 template<typename CustomUnaryFunc, typename XprType, typename Device>
     80 struct TensorEvaluator<const TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Device>
     81 {
     82   typedef TensorCustomUnaryOp<CustomUnaryFunc, XprType> ArgType;
     83   typedef typename internal::traits<ArgType>::Index Index;
     84   static const int NumDims = internal::traits<ArgType>::NumDimensions;
     85   typedef DSizes<Index, NumDims> Dimensions;
     86   typedef typename internal::remove_const<typename ArgType::Scalar>::type Scalar;
     87   typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
     88   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
     89   static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
     90 
     91   enum {
     92     IsAligned = false,
     93     PacketAccess = (internal::packet_traits<Scalar>::size > 1),
     94     BlockAccess = false,
     95     Layout = TensorEvaluator<XprType, Device>::Layout,
     96     CoordAccess = false,  // to be implemented
     97     RawAccess = false
     98   };
     99 
    100   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const ArgType& op, const Device& device)
    101       : m_op(op), m_device(device), m_result(NULL)
    102   {
    103     m_dimensions = op.func().dimensions(op.expression());
    104   }
    105 
    106   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
    107 
    108   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) {
    109     if (data) {
    110       evalTo(data);
    111       return false;
    112     } else {
    113       m_result = static_cast<CoeffReturnType*>(
    114           m_device.allocate(dimensions().TotalSize() * sizeof(Scalar)));
    115       evalTo(m_result);
    116       return true;
    117     }
    118   }
    119 
    120   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    121     if (m_result != NULL) {
    122       m_device.deallocate(m_result);
    123       m_result = NULL;
    124     }
    125   }
    126 
    127   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    128     return m_result[index];
    129   }
    130 
    131   template<int LoadMode>
    132   EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
    133     return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
    134   }
    135 
    136   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
    137     // TODO(rmlarsen): Extend CustomOp API to return its cost estimate.
    138     return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize);
    139   }
    140 
    141   EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return m_result; }
    142 
    143  protected:
    144   EIGEN_DEVICE_FUNC void evalTo(Scalar* data) {
    145     TensorMap<Tensor<CoeffReturnType, NumDims, Layout, Index> > result(
    146         data, m_dimensions);
    147     m_op.func().eval(m_op.expression(), result, m_device);
    148   }
    149 
    150   Dimensions m_dimensions;
    151   const ArgType m_op;
    152   const Device& m_device;
    153   CoeffReturnType* m_result;
    154 };
    155 
    156 
    157 
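/** \class TensorCustomBinaryOp
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor expression that applies a user-provided functor to two whole
  * input expressions.
  *
  * The functor must provide `dimensions(lhs, rhs)`, returning the dimensions
  * of the result, and `eval(lhs, rhs, output, device)`, which writes the
  * complete result into `output` (a TensorMap over the destination buffer).
  * The rank and layout of the result are taken from the left-hand expression.
  */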
    165 namespace internal {
    166 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
    167 struct traits<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> >
    168 {
    169   typedef typename internal::promote_storage_type<typename LhsXprType::Scalar,
    170                                                   typename RhsXprType::Scalar>::ret Scalar;
    171   typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType,
    172                                                   typename RhsXprType::CoeffReturnType>::ret CoeffReturnType;
    173   typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
    174                                         typename traits<RhsXprType>::StorageKind>::ret StorageKind;
    175   typedef typename promote_index_type<typename traits<LhsXprType>::Index,
    176                                       typename traits<RhsXprType>::Index>::type Index;
    177   typedef typename LhsXprType::Nested LhsNested;
    178   typedef typename RhsXprType::Nested RhsNested;
    179   typedef typename remove_reference<LhsNested>::type _LhsNested;
    180   typedef typename remove_reference<RhsNested>::type _RhsNested;
    181   static const int NumDimensions = traits<LhsXprType>::NumDimensions;
    182   static const int Layout = traits<LhsXprType>::Layout;
    183 };
    184 
    185 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
    186 struct eval<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Eigen::Dense>
    187 {
    188   typedef const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>& type;
    189 };
    190 
    191 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
    192 struct nested<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> >
    193 {
    194   typedef TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> type;
    195 };
    196 
    197 }  // end namespace internal
    198 
    199 
    200 
    201 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
    202 class TensorCustomBinaryOp : public TensorBase<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, ReadOnlyAccessors>
    203 {
    204   public:
    205   typedef typename internal::traits<TensorCustomBinaryOp>::Scalar Scalar;
    206   typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
    207   typedef typename internal::traits<TensorCustomBinaryOp>::CoeffReturnType CoeffReturnType;
    208   typedef typename internal::nested<TensorCustomBinaryOp>::type Nested;
    209   typedef typename internal::traits<TensorCustomBinaryOp>::StorageKind StorageKind;
    210   typedef typename internal::traits<TensorCustomBinaryOp>::Index Index;
    211 
    212   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const CustomBinaryFunc& func)
    213 
    214       : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_func(func) {}
    215 
    216   EIGEN_DEVICE_FUNC
    217   const CustomBinaryFunc& func() const { return m_func; }
    218 
    219   EIGEN_DEVICE_FUNC
    220   const typename internal::remove_all<typename LhsXprType::Nested>::type&
    221   lhsExpression() const { return m_lhs_xpr; }
    222 
    223   EIGEN_DEVICE_FUNC
    224   const typename internal::remove_all<typename RhsXprType::Nested>::type&
    225   rhsExpression() const { return m_rhs_xpr; }
    226 
    227   protected:
    228     typename LhsXprType::Nested m_lhs_xpr;
    229     typename RhsXprType::Nested m_rhs_xpr;
    230     const CustomBinaryFunc m_func;
    231 };
    232 
    233 
    234 // Eval as rvalue
    235 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType, typename Device>
    236 struct TensorEvaluator<const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Device>
    237 {
    238   typedef TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> XprType;
    239   typedef typename internal::traits<XprType>::Index Index;
    240   static const int NumDims = internal::traits<XprType>::NumDimensions;
    241   typedef DSizes<Index, NumDims> Dimensions;
    242   typedef typename XprType::Scalar Scalar;
    243   typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
    244   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
    245   static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
    246 
    247   enum {
    248     IsAligned = false,
    249     PacketAccess = (internal::packet_traits<Scalar>::size > 1),
    250     BlockAccess = false,
    251     Layout = TensorEvaluator<LhsXprType, Device>::Layout,
    252     CoordAccess = false,  // to be implemented
    253     RawAccess = false
    254   };
    255 
    256   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
    257       : m_op(op), m_device(device), m_result(NULL)
    258   {
    259     m_dimensions = op.func().dimensions(op.lhsExpression(), op.rhsExpression());
    260   }
    261 
    262   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
    263 
    264   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) {
    265     if (data) {
    266       evalTo(data);
    267       return false;
    268     } else {
    269       m_result = static_cast<Scalar *>(m_device.allocate(dimensions().TotalSize() * sizeof(Scalar)));
    270       evalTo(m_result);
    271       return true;
    272     }
    273   }
    274 
    275   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    276     if (m_result != NULL) {
    277       m_device.deallocate(m_result);
    278       m_result = NULL;
    279     }
    280   }
    281 
    282   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    283     return m_result[index];
    284   }
    285 
    286   template<int LoadMode>
    287   EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
    288     return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
    289   }
    290 
    291   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
    292     // TODO(rmlarsen): Extend CustomOp API to return its cost estimate.
    293     return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize);
    294   }
    295 
    296   EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return m_result; }
    297 
    298  protected:
    299   EIGEN_DEVICE_FUNC void evalTo(Scalar* data) {
    300     TensorMap<Tensor<Scalar, NumDims, Layout> > result(data, m_dimensions);
    301     m_op.func().eval(m_op.lhsExpression(), m_op.rhsExpression(), result, m_device);
    302   }
    303 
    304   Dimensions m_dimensions;
    305   const XprType m_op;
    306   const Device& m_device;
    307   CoeffReturnType* m_result;
    308 };
    309 
    310 
    311 } // end namespace Eigen
    312 
    313 #endif // EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
    314