Home | History | Annotate | Download | only in Tensor
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog (at) gmail.com>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
     11 #define EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
     12 
     13 namespace Eigen {
     14 
     15 /** \class TensorExpr
     16   * \ingroup CXX11_Tensor_Module
     17   *
     18   * \brief Tensor expression classes.
     19   *
  * The TensorCwiseNullaryOp class applies a nullary operator to an expression.
  * This is typically used to generate constants.
     22   *
     23   * The TensorCwiseUnaryOp class represents an expression where a unary operator
     24   * (e.g. cwiseSqrt) is applied to an expression.
     25   *
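  * The TensorCwiseBinaryOp class represents an expression where a binary
  * operator (e.g. addition) is applied to a lhs and a rhs expression.
  *
  * The TensorCwiseTernaryOp class represents an expression where a ternary
  * operator is applied to three argument expressions.
  *
  * The TensorSelectOp class represents an expression that chooses between a
  * "then" and an "else" expression according to a condition expression.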
     28   *
     29   */
     30 namespace internal {
     31 template<typename NullaryOp, typename XprType>
     32 struct traits<TensorCwiseNullaryOp<NullaryOp, XprType> >
     33     : traits<XprType>
     34 {
     35   typedef traits<XprType> XprTraits;
     36   typedef typename XprType::Scalar Scalar;
     37   typedef typename XprType::Nested XprTypeNested;
     38   typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
     39   static const int NumDimensions = XprTraits::NumDimensions;
     40   static const int Layout = XprTraits::Layout;
     41 
     42   enum {
     43     Flags = 0
     44   };
     45 };
     46 
     47 }  // end namespace internal
     48 
     49 
     50 
     51 template<typename NullaryOp, typename XprType>
     52 class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, XprType>, ReadOnlyAccessors>
     53 {
     54   public:
     55     typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Scalar Scalar;
     56     typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
     57     typedef typename XprType::CoeffReturnType CoeffReturnType;
     58     typedef TensorCwiseNullaryOp<NullaryOp, XprType> Nested;
     59     typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::StorageKind StorageKind;
     60     typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Index Index;
     61 
     62     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseNullaryOp(const XprType& xpr, const NullaryOp& func = NullaryOp())
     63         : m_xpr(xpr), m_functor(func) {}
     64 
     65     EIGEN_DEVICE_FUNC
     66     const typename internal::remove_all<typename XprType::Nested>::type&
     67     nestedExpression() const { return m_xpr; }
     68 
     69     EIGEN_DEVICE_FUNC
     70     const NullaryOp& functor() const { return m_functor; }
     71 
     72   protected:
     73     typename XprType::Nested m_xpr;
     74     const NullaryOp m_functor;
     75 };
     76 
     77 
     78 
     79 namespace internal {
     80 template<typename UnaryOp, typename XprType>
     81 struct traits<TensorCwiseUnaryOp<UnaryOp, XprType> >
     82     : traits<XprType>
     83 {
     84   // TODO(phli): Add InputScalar, InputPacket.  Check references to
     85   // current Scalar/Packet to see if the intent is Input or Output.
     86   typedef typename result_of<UnaryOp(typename XprType::Scalar)>::type Scalar;
     87   typedef traits<XprType> XprTraits;
     88   typedef typename XprType::Nested XprTypeNested;
     89   typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
     90   static const int NumDimensions = XprTraits::NumDimensions;
     91   static const int Layout = XprTraits::Layout;
     92 };
     93 
     94 template<typename UnaryOp, typename XprType>
     95 struct eval<TensorCwiseUnaryOp<UnaryOp, XprType>, Eigen::Dense>
     96 {
     97   typedef const TensorCwiseUnaryOp<UnaryOp, XprType>& type;
     98 };
     99 
    100 template<typename UnaryOp, typename XprType>
    101 struct nested<TensorCwiseUnaryOp<UnaryOp, XprType>, 1, typename eval<TensorCwiseUnaryOp<UnaryOp, XprType> >::type>
    102 {
    103   typedef TensorCwiseUnaryOp<UnaryOp, XprType> type;
    104 };
    105 
    106 }  // end namespace internal
    107 
    108 
    109 
    110 template<typename UnaryOp, typename XprType>
    111 class TensorCwiseUnaryOp : public TensorBase<TensorCwiseUnaryOp<UnaryOp, XprType>, ReadOnlyAccessors>
    112 {
    113   public:
    114     // TODO(phli): Add InputScalar, InputPacket.  Check references to
    115     // current Scalar/Packet to see if the intent is Input or Output.
    116     typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Scalar Scalar;
    117     typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
    118     typedef Scalar CoeffReturnType;
    119     typedef typename Eigen::internal::nested<TensorCwiseUnaryOp>::type Nested;
    120     typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::StorageKind StorageKind;
    121     typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Index Index;
    122 
    123     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseUnaryOp(const XprType& xpr, const UnaryOp& func = UnaryOp())
    124       : m_xpr(xpr), m_functor(func) {}
    125 
    126     EIGEN_DEVICE_FUNC
    127     const UnaryOp& functor() const { return m_functor; }
    128 
    129     /** \returns the nested expression */
    130     EIGEN_DEVICE_FUNC
    131     const typename internal::remove_all<typename XprType::Nested>::type&
    132     nestedExpression() const { return m_xpr; }
    133 
    134   protected:
    135     typename XprType::Nested m_xpr;
    136     const UnaryOp m_functor;
    137 };
    138 
    139 
    140 namespace internal {
    141 template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
    142 struct traits<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> >
    143 {
    144   // Type promotion to handle the case where the types of the lhs and the rhs
    145   // are different.
    146   // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket.  Check references to
    147   // current Scalar/Packet to see if the intent is Inputs or Output.
    148   typedef typename result_of<
    149       BinaryOp(typename LhsXprType::Scalar,
    150                typename RhsXprType::Scalar)>::type Scalar;
    151   typedef traits<LhsXprType> XprTraits;
    152   typedef typename promote_storage_type<
    153       typename traits<LhsXprType>::StorageKind,
    154       typename traits<RhsXprType>::StorageKind>::ret StorageKind;
    155   typedef typename promote_index_type<
    156       typename traits<LhsXprType>::Index,
    157       typename traits<RhsXprType>::Index>::type Index;
    158   typedef typename LhsXprType::Nested LhsNested;
    159   typedef typename RhsXprType::Nested RhsNested;
    160   typedef typename remove_reference<LhsNested>::type _LhsNested;
    161   typedef typename remove_reference<RhsNested>::type _RhsNested;
    162   static const int NumDimensions = XprTraits::NumDimensions;
    163   static const int Layout = XprTraits::Layout;
    164 
    165   enum {
    166     Flags = 0
    167   };
    168 };
    169 
    170 template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
    171 struct eval<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, Eigen::Dense>
    172 {
    173   typedef const TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>& type;
    174 };
    175 
    176 template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
    177 struct nested<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, 1, typename eval<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> >::type>
    178 {
    179   typedef TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> type;
    180 };
    181 
    182 }  // end namespace internal
    183 
    184 
    185 
    186 template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
    187 class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, ReadOnlyAccessors>
    188 {
    189   public:
    190     // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket.  Check references to
    191     // current Scalar/Packet to see if the intent is Inputs or Output.
    192     typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Scalar Scalar;
    193     typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
    194     typedef Scalar CoeffReturnType;
    195     typedef typename Eigen::internal::nested<TensorCwiseBinaryOp>::type Nested;
    196     typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::StorageKind StorageKind;
    197     typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Index Index;
    198 
    199     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const BinaryOp& func = BinaryOp())
    200         : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_functor(func) {}
    201 
    202     EIGEN_DEVICE_FUNC
    203     const BinaryOp& functor() const { return m_functor; }
    204 
    205     /** \returns the nested expressions */
    206     EIGEN_DEVICE_FUNC
    207     const typename internal::remove_all<typename LhsXprType::Nested>::type&
    208     lhsExpression() const { return m_lhs_xpr; }
    209 
    210     EIGEN_DEVICE_FUNC
    211     const typename internal::remove_all<typename RhsXprType::Nested>::type&
    212     rhsExpression() const { return m_rhs_xpr; }
    213 
    214   protected:
    215     typename LhsXprType::Nested m_lhs_xpr;
    216     typename RhsXprType::Nested m_rhs_xpr;
    217     const BinaryOp m_functor;
    218 };
    219 
    220 
    221 namespace internal {
    222 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
    223 struct traits<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >
    224 {
    225   // Type promotion to handle the case where the types of the args are different.
    226   typedef typename result_of<
    227       TernaryOp(typename Arg1XprType::Scalar,
    228                 typename Arg2XprType::Scalar,
    229                 typename Arg3XprType::Scalar)>::type Scalar;
    230   typedef traits<Arg1XprType> XprTraits;
    231   typedef typename traits<Arg1XprType>::StorageKind StorageKind;
    232   typedef typename traits<Arg1XprType>::Index Index;
    233   typedef typename Arg1XprType::Nested Arg1Nested;
    234   typedef typename Arg2XprType::Nested Arg2Nested;
    235   typedef typename Arg3XprType::Nested Arg3Nested;
    236   typedef typename remove_reference<Arg1Nested>::type _Arg1Nested;
    237   typedef typename remove_reference<Arg2Nested>::type _Arg2Nested;
    238   typedef typename remove_reference<Arg3Nested>::type _Arg3Nested;
    239   static const int NumDimensions = XprTraits::NumDimensions;
    240   static const int Layout = XprTraits::Layout;
    241 
    242   enum {
    243     Flags = 0
    244   };
    245 };
    246 
    247 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
    248 struct eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, Eigen::Dense>
    249 {
    250   typedef const TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>& type;
    251 };
    252 
    253 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
    254 struct nested<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, 1, typename eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >::type>
    255 {
    256   typedef TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> type;
    257 };
    258 
    259 }  // end namespace internal
    260 
    261 
    262 
    263 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
    264 class TensorCwiseTernaryOp : public TensorBase<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, ReadOnlyAccessors>
    265 {
    266   public:
    267     typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Scalar Scalar;
    268     typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
    269     typedef Scalar CoeffReturnType;
    270     typedef typename Eigen::internal::nested<TensorCwiseTernaryOp>::type Nested;
    271     typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::StorageKind StorageKind;
    272     typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Index Index;
    273 
    274     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseTernaryOp(const Arg1XprType& arg1, const Arg2XprType& arg2, const Arg3XprType& arg3, const TernaryOp& func = TernaryOp())
    275         : m_arg1_xpr(arg1), m_arg2_xpr(arg2), m_arg3_xpr(arg3), m_functor(func) {}
    276 
    277     EIGEN_DEVICE_FUNC
    278     const TernaryOp& functor() const { return m_functor; }
    279 
    280     /** \returns the nested expressions */
    281     EIGEN_DEVICE_FUNC
    282     const typename internal::remove_all<typename Arg1XprType::Nested>::type&
    283     arg1Expression() const { return m_arg1_xpr; }
    284 
    285     EIGEN_DEVICE_FUNC
    286     const typename internal::remove_all<typename Arg2XprType::Nested>::type&
    287     arg2Expression() const { return m_arg2_xpr; }
    288 
    289     EIGEN_DEVICE_FUNC
    290     const typename internal::remove_all<typename Arg3XprType::Nested>::type&
    291     arg3Expression() const { return m_arg3_xpr; }
    292 
    293   protected:
    294     typename Arg1XprType::Nested m_arg1_xpr;
    295     typename Arg2XprType::Nested m_arg2_xpr;
    296     typename Arg3XprType::Nested m_arg3_xpr;
    297     const TernaryOp m_functor;
    298 };
    299 
    300 
    301 namespace internal {
    302 template<typename IfXprType, typename ThenXprType, typename ElseXprType>
    303 struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >
    304     : traits<ThenXprType>
    305 {
    306   typedef typename traits<ThenXprType>::Scalar Scalar;
    307   typedef traits<ThenXprType> XprTraits;
    308   typedef typename promote_storage_type<typename traits<ThenXprType>::StorageKind,
    309                                         typename traits<ElseXprType>::StorageKind>::ret StorageKind;
    310   typedef typename promote_index_type<typename traits<ElseXprType>::Index,
    311                                       typename traits<ThenXprType>::Index>::type Index;
    312   typedef typename IfXprType::Nested IfNested;
    313   typedef typename ThenXprType::Nested ThenNested;
    314   typedef typename ElseXprType::Nested ElseNested;
    315   static const int NumDimensions = XprTraits::NumDimensions;
    316   static const int Layout = XprTraits::Layout;
    317 };
    318 
    319 template<typename IfXprType, typename ThenXprType, typename ElseXprType>
    320 struct eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, Eigen::Dense>
    321 {
    322   typedef const TensorSelectOp<IfXprType, ThenXprType, ElseXprType>& type;
    323 };
    324 
    325 template<typename IfXprType, typename ThenXprType, typename ElseXprType>
    326 struct nested<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, 1, typename eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >::type>
    327 {
    328   typedef TensorSelectOp<IfXprType, ThenXprType, ElseXprType> type;
    329 };
    330 
    331 }  // end namespace internal
    332 
    333 
    334 template<typename IfXprType, typename ThenXprType, typename ElseXprType>
    335 class TensorSelectOp : public TensorBase<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, ReadOnlyAccessors>
    336 {
    337   public:
    338     typedef typename Eigen::internal::traits<TensorSelectOp>::Scalar Scalar;
    339     typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
    340     typedef typename internal::promote_storage_type<typename ThenXprType::CoeffReturnType,
    341                                                     typename ElseXprType::CoeffReturnType>::ret CoeffReturnType;
    342     typedef typename Eigen::internal::nested<TensorSelectOp>::type Nested;
    343     typedef typename Eigen::internal::traits<TensorSelectOp>::StorageKind StorageKind;
    344     typedef typename Eigen::internal::traits<TensorSelectOp>::Index Index;
    345 
    346     EIGEN_DEVICE_FUNC
    347     TensorSelectOp(const IfXprType& a_condition,
    348                    const ThenXprType& a_then,
    349                    const ElseXprType& a_else)
    350       : m_condition(a_condition), m_then(a_then), m_else(a_else)
    351     { }
    352 
    353     EIGEN_DEVICE_FUNC
    354     const IfXprType& ifExpression() const { return m_condition; }
    355 
    356     EIGEN_DEVICE_FUNC
    357     const ThenXprType& thenExpression() const { return m_then; }
    358 
    359     EIGEN_DEVICE_FUNC
    360     const ElseXprType& elseExpression() const { return m_else; }
    361 
    362   protected:
    363     typename IfXprType::Nested m_condition;
    364     typename ThenXprType::Nested m_then;
    365     typename ElseXprType::Nested m_else;
    366 };
    367 
    368 
    369 } // end namespace Eigen
    370 
    371 #endif // EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
    372