Home | History | Annotate | Download | only in util
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2009-2010 Gael Guennebaud <gael.guennebaud (at) inria.fr>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_BLASUTIL_H
     11 #define EIGEN_BLASUTIL_H
     12 
     13 // This file contains many lightweight helper classes used to
     14 // implement and control fast level 2 and level 3 BLAS-like routines.
     15 
     16 namespace Eigen {
     17 
     18 namespace internal {
     19 
     20 // forward declarations
     21 template<typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjugateLhs=false, bool ConjugateRhs=false>
     22 struct gebp_kernel;
     23 
     24 template<typename Scalar, typename Index, int nr, int StorageOrder, bool Conjugate = false, bool PanelMode=false>
     25 struct gemm_pack_rhs;
     26 
     27 template<typename Scalar, typename Index, int Pack1, int Pack2, int StorageOrder, bool Conjugate = false, bool PanelMode = false>
     28 struct gemm_pack_lhs;
     29 
     30 template<
     31   typename Index,
     32   typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
     33   typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
     34   int ResStorageOrder>
     35 struct general_matrix_matrix_product;
     36 
     37 template<typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, typename RhsScalar, bool ConjugateRhs, int Version=Specialized>
     38 struct general_matrix_vector_product;
     39 
     40 
     41 template<bool Conjugate> struct conj_if;
     42 
     43 template<> struct conj_if<true> {
     44   template<typename T>
     45   inline T operator()(const T& x) { return numext::conj(x); }
     46   template<typename T>
     47   inline T pconj(const T& x) { return internal::pconj(x); }
     48 };
     49 
     50 template<> struct conj_if<false> {
     51   template<typename T>
     52   inline const T& operator()(const T& x) { return x; }
     53   template<typename T>
     54   inline const T& pconj(const T& x) { return x; }
     55 };
     56 
     57 template<typename Scalar> struct conj_helper<Scalar,Scalar,false,false>
     58 {
     59   EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const { return internal::pmadd(x,y,c); }
     60   EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const { return internal::pmul(x,y); }
     61 };
     62 
     63 template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, false,true>
     64 {
     65   typedef std::complex<RealScalar> Scalar;
     66   EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
     67   { return c + pmul(x,y); }
     68 
     69   EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
     70   { return Scalar(numext::real(x)*numext::real(y) + numext::imag(x)*numext::imag(y), numext::imag(x)*numext::real(y) - numext::real(x)*numext::imag(y)); }
     71 };
     72 
     73 template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,false>
     74 {
     75   typedef std::complex<RealScalar> Scalar;
     76   EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
     77   { return c + pmul(x,y); }
     78 
     79   EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
     80   { return Scalar(numext::real(x)*numext::real(y) + numext::imag(x)*numext::imag(y), numext::real(x)*numext::imag(y) - numext::imag(x)*numext::real(y)); }
     81 };
     82 
     83 template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,true>
     84 {
     85   typedef std::complex<RealScalar> Scalar;
     86   EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
     87   { return c + pmul(x,y); }
     88 
     89   EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
     90   { return Scalar(numext::real(x)*numext::real(y) - numext::imag(x)*numext::imag(y), - numext::real(x)*numext::imag(y) - numext::imag(x)*numext::real(y)); }
     91 };
     92 
     93 template<typename RealScalar,bool Conj> struct conj_helper<std::complex<RealScalar>, RealScalar, Conj,false>
     94 {
     95   typedef std::complex<RealScalar> Scalar;
     96   EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const RealScalar& y, const Scalar& c) const
     97   { return padd(c, pmul(x,y)); }
     98   EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const RealScalar& y) const
     99   { return conj_if<Conj>()(x)*y; }
    100 };
    101 
    102 template<typename RealScalar,bool Conj> struct conj_helper<RealScalar, std::complex<RealScalar>, false,Conj>
    103 {
    104   typedef std::complex<RealScalar> Scalar;
    105   EIGEN_STRONG_INLINE Scalar pmadd(const RealScalar& x, const Scalar& y, const Scalar& c) const
    106   { return padd(c, pmul(x,y)); }
    107   EIGEN_STRONG_INLINE Scalar pmul(const RealScalar& x, const Scalar& y) const
    108   { return x*conj_if<Conj>()(y); }
    109 };
    110 
    111 template<typename From,typename To> struct get_factor {
    112   static EIGEN_STRONG_INLINE To run(const From& x) { return x; }
    113 };
    114 
    115 template<typename Scalar> struct get_factor<Scalar,typename NumTraits<Scalar>::Real> {
    116   static EIGEN_STRONG_INLINE typename NumTraits<Scalar>::Real run(const Scalar& x) { return numext::real(x); }
    117 };
    118 
    119 // Lightweight helper class to access matrix coefficients.
    120 // Yes, this is somehow redundant with Map<>, but this version is much much lighter,
    121 // and so I hope better compilation performance (time and code quality).
    122 template<typename Scalar, typename Index, int StorageOrder>
    123 class blas_data_mapper
    124 {
    125   public:
    126     blas_data_mapper(Scalar* data, Index stride) : m_data(data), m_stride(stride) {}
    127     EIGEN_STRONG_INLINE Scalar& operator()(Index i, Index j)
    128     { return m_data[StorageOrder==RowMajor ? j + i*m_stride : i + j*m_stride]; }
    129   protected:
    130     Scalar* EIGEN_RESTRICT m_data;
    131     Index m_stride;
    132 };
    133 
    134 // lightweight helper class to access matrix coefficients (const version)
    135 template<typename Scalar, typename Index, int StorageOrder>
    136 class const_blas_data_mapper
    137 {
    138   public:
    139     const_blas_data_mapper(const Scalar* data, Index stride) : m_data(data), m_stride(stride) {}
    140     EIGEN_STRONG_INLINE const Scalar& operator()(Index i, Index j) const
    141     { return m_data[StorageOrder==RowMajor ? j + i*m_stride : i + j*m_stride]; }
    142   protected:
    143     const Scalar* EIGEN_RESTRICT m_data;
    144     Index m_stride;
    145 };
    146 
    147 
    148 /* Helper class to analyze the factors of a Product expression.
    149  * In particular it allows to pop out operator-, scalar multiples,
    150  * and conjugate */
    151 template<typename XprType> struct blas_traits
    152 {
    153   typedef typename traits<XprType>::Scalar Scalar;
    154   typedef const XprType& ExtractType;
    155   typedef XprType _ExtractType;
    156   enum {
    157     IsComplex = NumTraits<Scalar>::IsComplex,
    158     IsTransposed = false,
    159     NeedToConjugate = false,
    160     HasUsableDirectAccess = (    (int(XprType::Flags)&DirectAccessBit)
    161                               && (   bool(XprType::IsVectorAtCompileTime)
    162                                   || int(inner_stride_at_compile_time<XprType>::ret) == 1)
    163                              ) ?  1 : 0
    164   };
    165   typedef typename conditional<bool(HasUsableDirectAccess),
    166     ExtractType,
    167     typename _ExtractType::PlainObject
    168     >::type DirectLinearAccessType;
    169   static inline ExtractType extract(const XprType& x) { return x; }
    170   static inline const Scalar extractScalarFactor(const XprType&) { return Scalar(1); }
    171 };
    172 
    173 // pop conjugate
    174 template<typename Scalar, typename NestedXpr>
    175 struct blas_traits<CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> >
    176  : blas_traits<NestedXpr>
    177 {
    178   typedef blas_traits<NestedXpr> Base;
    179   typedef CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> XprType;
    180   typedef typename Base::ExtractType ExtractType;
    181 
    182   enum {
    183     IsComplex = NumTraits<Scalar>::IsComplex,
    184     NeedToConjugate = Base::NeedToConjugate ? 0 : IsComplex
    185   };
    186   static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
    187   static inline Scalar extractScalarFactor(const XprType& x) { return conj(Base::extractScalarFactor(x.nestedExpression())); }
    188 };
    189 
    190 // pop scalar multiple
    191 template<typename Scalar, typename NestedXpr>
    192 struct blas_traits<CwiseUnaryOp<scalar_multiple_op<Scalar>, NestedXpr> >
    193  : blas_traits<NestedXpr>
    194 {
    195   typedef blas_traits<NestedXpr> Base;
    196   typedef CwiseUnaryOp<scalar_multiple_op<Scalar>, NestedXpr> XprType;
    197   typedef typename Base::ExtractType ExtractType;
    198   static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
    199   static inline Scalar extractScalarFactor(const XprType& x)
    200   { return x.functor().m_other * Base::extractScalarFactor(x.nestedExpression()); }
    201 };
    202 
    203 // pop opposite
    204 template<typename Scalar, typename NestedXpr>
    205 struct blas_traits<CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> >
    206  : blas_traits<NestedXpr>
    207 {
    208   typedef blas_traits<NestedXpr> Base;
    209   typedef CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> XprType;
    210   typedef typename Base::ExtractType ExtractType;
    211   static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
    212   static inline Scalar extractScalarFactor(const XprType& x)
    213   { return - Base::extractScalarFactor(x.nestedExpression()); }
    214 };
    215 
    216 // pop/push transpose
    217 template<typename NestedXpr>
    218 struct blas_traits<Transpose<NestedXpr> >
    219  : blas_traits<NestedXpr>
    220 {
    221   typedef typename NestedXpr::Scalar Scalar;
    222   typedef blas_traits<NestedXpr> Base;
    223   typedef Transpose<NestedXpr> XprType;
    224   typedef Transpose<const typename Base::_ExtractType>  ExtractType; // const to get rid of a compile error; anyway blas traits are only used on the RHS
    225   typedef Transpose<const typename Base::_ExtractType> _ExtractType;
    226   typedef typename conditional<bool(Base::HasUsableDirectAccess),
    227     ExtractType,
    228     typename ExtractType::PlainObject
    229     >::type DirectLinearAccessType;
    230   enum {
    231     IsTransposed = Base::IsTransposed ? 0 : 1
    232   };
    233   static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
    234   static inline Scalar extractScalarFactor(const XprType& x) { return Base::extractScalarFactor(x.nestedExpression()); }
    235 };
    236 
    237 template<typename T>
    238 struct blas_traits<const T>
    239      : blas_traits<T>
    240 {};
    241 
    242 template<typename T, bool HasUsableDirectAccess=blas_traits<T>::HasUsableDirectAccess>
    243 struct extract_data_selector {
    244   static const typename T::Scalar* run(const T& m)
    245   {
    246     return blas_traits<T>::extract(m).data();
    247   }
    248 };
    249 
    250 template<typename T>
    251 struct extract_data_selector<T,false> {
    252   static typename T::Scalar* run(const T&) { return 0; }
    253 };
    254 
    255 template<typename T> const typename T::Scalar* extract_data(const T& m)
    256 {
    257   return extract_data_selector<T>::run(m);
    258 }
    259 
    260 } // end namespace internal
    261 
    262 } // end namespace Eigen
    263 
    264 #endif // EIGEN_BLASUTIL_H
    265