Home | History | Annotate | Download | only in Core
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2008-2014 Gael Guennebaud <gael.guennebaud (at) inria.fr>
      5 // Copyright (C) 2006-2008 Benoit Jacob <jacob.benoit.1 (at) gmail.com>
      6 // Copyright (C) 2016 Eugene Brevdo <ebrevdo (at) gmail.com>
      7 //
      8 // This Source Code Form is subject to the terms of the Mozilla
      9 // Public License v. 2.0. If a copy of the MPL was not distributed
     10 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
     11 
     12 #ifndef EIGEN_CWISE_TERNARY_OP_H
     13 #define EIGEN_CWISE_TERNARY_OP_H
     14 
     15 namespace Eigen {
     16 
     17 namespace internal {
     18 template <typename TernaryOp, typename Arg1, typename Arg2, typename Arg3>
     19 struct traits<CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> > {
     20   // we must not inherit from traits<Arg1> since it has
     21   // the potential to cause problems with MSVC
     22   typedef typename remove_all<Arg1>::type Ancestor;
     23   typedef typename traits<Ancestor>::XprKind XprKind;
     24   enum {
     25     RowsAtCompileTime = traits<Ancestor>::RowsAtCompileTime,
     26     ColsAtCompileTime = traits<Ancestor>::ColsAtCompileTime,
     27     MaxRowsAtCompileTime = traits<Ancestor>::MaxRowsAtCompileTime,
     28     MaxColsAtCompileTime = traits<Ancestor>::MaxColsAtCompileTime
     29   };
     30 
     31   // even though we require Arg1, Arg2, and Arg3 to have the same scalar type
     32   // (see CwiseTernaryOp constructor),
     33   // we still want to handle the case when the result type is different.
     34   typedef typename result_of<TernaryOp(
     35       const typename Arg1::Scalar&, const typename Arg2::Scalar&,
     36       const typename Arg3::Scalar&)>::type Scalar;
     37 
     38   typedef typename internal::traits<Arg1>::StorageKind StorageKind;
     39   typedef typename internal::traits<Arg1>::StorageIndex StorageIndex;
     40 
     41   typedef typename Arg1::Nested Arg1Nested;
     42   typedef typename Arg2::Nested Arg2Nested;
     43   typedef typename Arg3::Nested Arg3Nested;
     44   typedef typename remove_reference<Arg1Nested>::type _Arg1Nested;
     45   typedef typename remove_reference<Arg2Nested>::type _Arg2Nested;
     46   typedef typename remove_reference<Arg3Nested>::type _Arg3Nested;
     47   enum { Flags = _Arg1Nested::Flags & RowMajorBit };
     48 };
     49 }  // end namespace internal
     50 
     51 template <typename TernaryOp, typename Arg1, typename Arg2, typename Arg3,
     52           typename StorageKind>
     53 class CwiseTernaryOpImpl;
     54 
     55 /** \class CwiseTernaryOp
     56   * \ingroup Core_Module
     57   *
     58   * \brief Generic expression where a coefficient-wise ternary operator is
     59  * applied to two expressions
     60   *
     61   * \tparam TernaryOp template functor implementing the operator
     62   * \tparam Arg1Type the type of the first argument
     63   * \tparam Arg2Type the type of the second argument
     64   * \tparam Arg3Type the type of the third argument
     65   *
     66   * This class represents an expression where a coefficient-wise ternary
     67  * operator is applied to three expressions.
     68   * It is the return type of ternary operators, by which we mean only those
     69  * ternary operators where
     70   * all three arguments are Eigen expressions.
     71   * For example, the return type of betainc(matrix1, matrix2, matrix3) is a
     72  * CwiseTernaryOp.
     73   *
     74   * Most of the time, this is the only way that it is used, so you typically
     75  * don't have to name
     76   * CwiseTernaryOp types explicitly.
     77   *
     78   * \sa MatrixBase::ternaryExpr(const MatrixBase<Argument2> &, const
     79  * MatrixBase<Argument3> &, const CustomTernaryOp &) const, class CwiseBinaryOp,
     80  * class CwiseUnaryOp, class CwiseNullaryOp
     81   */
     82 template <typename TernaryOp, typename Arg1Type, typename Arg2Type,
     83           typename Arg3Type>
     84 class CwiseTernaryOp : public CwiseTernaryOpImpl<
     85                            TernaryOp, Arg1Type, Arg2Type, Arg3Type,
     86                            typename internal::traits<Arg1Type>::StorageKind>,
     87                        internal::no_assignment_operator
     88 {
     89  public:
     90   typedef typename internal::remove_all<Arg1Type>::type Arg1;
     91   typedef typename internal::remove_all<Arg2Type>::type Arg2;
     92   typedef typename internal::remove_all<Arg3Type>::type Arg3;
     93 
     94   typedef typename CwiseTernaryOpImpl<
     95       TernaryOp, Arg1Type, Arg2Type, Arg3Type,
     96       typename internal::traits<Arg1Type>::StorageKind>::Base Base;
     97   EIGEN_GENERIC_PUBLIC_INTERFACE(CwiseTernaryOp)
     98 
     99   typedef typename internal::ref_selector<Arg1Type>::type Arg1Nested;
    100   typedef typename internal::ref_selector<Arg2Type>::type Arg2Nested;
    101   typedef typename internal::ref_selector<Arg3Type>::type Arg3Nested;
    102   typedef typename internal::remove_reference<Arg1Nested>::type _Arg1Nested;
    103   typedef typename internal::remove_reference<Arg2Nested>::type _Arg2Nested;
    104   typedef typename internal::remove_reference<Arg3Nested>::type _Arg3Nested;
    105 
    106   EIGEN_DEVICE_FUNC
    107   EIGEN_STRONG_INLINE CwiseTernaryOp(const Arg1& a1, const Arg2& a2,
    108                                      const Arg3& a3,
    109                                      const TernaryOp& func = TernaryOp())
    110       : m_arg1(a1), m_arg2(a2), m_arg3(a3), m_functor(func) {
    111     // require the sizes to match
    112     EIGEN_STATIC_ASSERT_SAME_MATRIX_SIZE(Arg1, Arg2)
    113     EIGEN_STATIC_ASSERT_SAME_MATRIX_SIZE(Arg1, Arg3)
    114 
    115     // The index types should match
    116     EIGEN_STATIC_ASSERT((internal::is_same<
    117                          typename internal::traits<Arg1Type>::StorageKind,
    118                          typename internal::traits<Arg2Type>::StorageKind>::value),
    119                         STORAGE_KIND_MUST_MATCH)
    120     EIGEN_STATIC_ASSERT((internal::is_same<
    121                          typename internal::traits<Arg1Type>::StorageKind,
    122                          typename internal::traits<Arg3Type>::StorageKind>::value),
    123                         STORAGE_KIND_MUST_MATCH)
    124 
    125     eigen_assert(a1.rows() == a2.rows() && a1.cols() == a2.cols() &&
    126                  a1.rows() == a3.rows() && a1.cols() == a3.cols());
    127   }
    128 
    129   EIGEN_DEVICE_FUNC
    130   EIGEN_STRONG_INLINE Index rows() const {
    131     // return the fixed size type if available to enable compile time
    132     // optimizations
    133     if (internal::traits<typename internal::remove_all<Arg1Nested>::type>::
    134                 RowsAtCompileTime == Dynamic &&
    135         internal::traits<typename internal::remove_all<Arg2Nested>::type>::
    136                 RowsAtCompileTime == Dynamic)
    137       return m_arg3.rows();
    138     else if (internal::traits<typename internal::remove_all<Arg1Nested>::type>::
    139                      RowsAtCompileTime == Dynamic &&
    140              internal::traits<typename internal::remove_all<Arg3Nested>::type>::
    141                      RowsAtCompileTime == Dynamic)
    142       return m_arg2.rows();
    143     else
    144       return m_arg1.rows();
    145   }
    146   EIGEN_DEVICE_FUNC
    147   EIGEN_STRONG_INLINE Index cols() const {
    148     // return the fixed size type if available to enable compile time
    149     // optimizations
    150     if (internal::traits<typename internal::remove_all<Arg1Nested>::type>::
    151                 ColsAtCompileTime == Dynamic &&
    152         internal::traits<typename internal::remove_all<Arg2Nested>::type>::
    153                 ColsAtCompileTime == Dynamic)
    154       return m_arg3.cols();
    155     else if (internal::traits<typename internal::remove_all<Arg1Nested>::type>::
    156                      ColsAtCompileTime == Dynamic &&
    157              internal::traits<typename internal::remove_all<Arg3Nested>::type>::
    158                      ColsAtCompileTime == Dynamic)
    159       return m_arg2.cols();
    160     else
    161       return m_arg1.cols();
    162   }
    163 
    164   /** \returns the first argument nested expression */
    165   EIGEN_DEVICE_FUNC
    166   const _Arg1Nested& arg1() const { return m_arg1; }
    167   /** \returns the first argument nested expression */
    168   EIGEN_DEVICE_FUNC
    169   const _Arg2Nested& arg2() const { return m_arg2; }
    170   /** \returns the third argument nested expression */
    171   EIGEN_DEVICE_FUNC
    172   const _Arg3Nested& arg3() const { return m_arg3; }
    173   /** \returns the functor representing the ternary operation */
    174   EIGEN_DEVICE_FUNC
    175   const TernaryOp& functor() const { return m_functor; }
    176 
    177  protected:
    178   Arg1Nested m_arg1;
    179   Arg2Nested m_arg2;
    180   Arg3Nested m_arg3;
    181   const TernaryOp m_functor;
    182 };
    183 
    184 // Generic API dispatcher
    185 template <typename TernaryOp, typename Arg1, typename Arg2, typename Arg3,
    186           typename StorageKind>
    187 class CwiseTernaryOpImpl
    188     : public internal::generic_xpr_base<
    189           CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> >::type {
    190  public:
    191   typedef typename internal::generic_xpr_base<
    192       CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> >::type Base;
    193 };
    194 
    195 }  // end namespace Eigen
    196 
    197 #endif  // EIGEN_CWISE_TERNARY_OP_H
    198