Home | History | Annotate | Download | only in products
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2009 Gael Guennebaud <gael.guennebaud (at) inria.fr>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_TRIANGULARMATRIXVECTOR_H
     11 #define EIGEN_TRIANGULARMATRIXVECTOR_H
     12 
     13 namespace Eigen {
     14 
     15 namespace internal {
     16 
/** \internal
  * Primary template of the triangular matrix * vector kernel. Only declared here;
  * this file provides partial specializations for StorageOrder==ColMajor and
  * StorageOrder==RowMajor.
  *
  * \tparam Index        integer type used for sizes, strides and loop counters
  * \tparam Mode         bit mask tested against Lower, UnitDiag and ZeroDiag by the specializations
  * \tparam LhsScalar    scalar type of the triangular (left-hand side) operand
  * \tparam ConjLhs      whether the lhs coefficients are conjugated
  * \tparam RhsScalar    scalar type of the vector (right-hand side) operand
  * \tparam ConjRhs      whether the rhs coefficients are conjugated
  * \tparam StorageOrder storage order of the lhs (ColMajor or RowMajor)
  * \tparam Version      implementation selector, defaults to Specialized
  */
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int StorageOrder, int Version=Specialized>
struct triangular_matrix_vector_product;
     19 
     20 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
     21 struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version>
     22 {
     23   typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
     24   enum {
     25     IsLower = ((Mode&Lower)==Lower),
     26     HasUnitDiag = (Mode & UnitDiag)==UnitDiag,
     27     HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag
     28   };
     29   static EIGEN_DONT_INLINE  void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
     30                                      const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha);
     31 };
     32 
/** \internal
  * Column-major kernel: accumulates alpha * T * rhs into the result vector, where T is
  * the triangular (more generally trapezoidal) part of the lhs selected by \a Mode.
  * The lhs and rhs coefficients are conjugated when ConjLhs / ConjRhs are set.
  *
  * The diagonal is processed by panels of at most EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH
  * columns: the small triangular block of each panel is handled column by column, and
  * the rectangular block above (upper) or below (lower) it is delegated to
  * general_matrix_vector_product.
  *
  * NOTE(review): the result is mapped as a contiguous vector (resIncr is not used by
  * the map), although resIncr is forwarded to general_matrix_vector_product. This looks
  * like it expects resIncr==1 -- confirm against the callers.
  */
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version>
  ::run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
        const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha)
  {
    static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
    // Length of the diagonal.
    Index size = (std::min)(_rows,_cols);
    // Lower: all rows but only the first 'size' columns are referenced.
    // Upper: only the first 'size' rows but all columns are referenced.
    Index rows = IsLower ? _rows : (std::min)(_rows,_cols);
    Index cols = IsLower ? (std::min)(_rows,_cols) : _cols;

    // Wrap the raw pointers into Eigen maps; cjLhs/cjRhs apply the requested conjugation.
    typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
    const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
    typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);

    typedef Map<const Matrix<RhsScalar,Dynamic,1>, 0, InnerStride<> > RhsMap;
    const RhsMap rhs(_rhs,cols,InnerStride<>(rhsIncr));
    typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);

    // The result is mapped without any stride (see the note in the header comment).
    typedef Map<Matrix<ResScalar,Dynamic,1> > ResMap;
    ResMap res(_res,rows);

    // Walk along the diagonal, one panel of columns at a time.
    for (Index pi=0; pi<size; pi+=PanelWidth)
    {
      Index actualPanelWidth = (std::min)(PanelWidth, size-pi);
      // Triangular block of the panel, one column at a time.
      for (Index k=0; k<actualPanelWidth; ++k)
      {
        Index i = pi + k;
        // s: first row of the column segment; for a lower matrix with an implicit
        // (unit or zero) diagonal the segment starts right below the diagonal.
        Index s = IsLower ? ((HasUnitDiag||HasZeroDiag) ? i+1 : i ) : pi;
        // r: segment length, diagonal coefficient included at this point.
        Index r = IsLower ? actualPanelWidth-k : k+1;
        // With an implicit diagonal the diagonal coefficient is excluded: r is
        // decremented (only in that case, thanks to the short-circuit) and the
        // update is skipped if nothing remains.
        if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
          res.segment(s,r) += (alpha * cjRhs.coeff(i)) * cjLhs.col(i).segment(s,r);
        // Contribution of the implicit unit diagonal.
        if (HasUnitDiag)
          res.coeffRef(i) += alpha * cjRhs.coeff(i);
      }
      // Rectangular block of the panel: the rows below it (lower) or above it (upper).
      Index r = IsLower ? rows - pi - actualPanelWidth : pi;
      if (r>0)
      {
        Index s = IsLower ? pi+actualPanelWidth : 0;
        general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjLhs,RhsScalar,ConjRhs,BuiltIn>::run(
            r, actualPanelWidth,
            &lhs.coeffRef(s,pi), lhsStride,
            &rhs.coeffRef(pi), rhsIncr,
            &res.coeffRef(s), resIncr, alpha);
      }
    }
    // Upper trapezoidal case: the columns past the diagonal form a dense block.
    if((!IsLower) && cols>size)
    {
      general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjLhs,RhsScalar,ConjRhs>::run(
          rows, cols-size,
          &lhs.coeffRef(0,size), lhsStride,
          &rhs.coeffRef(size), rhsIncr,
          _res, resIncr, alpha);
    }
  }
     87 
     88 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version>
     89 struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version>
     90 {
     91   typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
     92   enum {
     93     IsLower = ((Mode&Lower)==Lower),
     94     HasUnitDiag = (Mode & UnitDiag)==UnitDiag,
     95     HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag
     96   };
     97   static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
     98                                     const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha);
     99 };
    100 
/** \internal
  * Row-major kernel: accumulates alpha * T * rhs into the result vector, where T is
  * the triangular (more generally trapezoidal) part of the lhs selected by \a Mode.
  * The lhs and rhs coefficients are conjugated when ConjLhs / ConjRhs are set.
  *
  * The diagonal is processed by panels of at most EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH
  * rows: the small triangular block of each panel is handled row by row with dot
  * products, and the rectangular block on its left (lower) or right (upper) is
  * delegated to general_matrix_vector_product.
  *
  * NOTE(review): the rhs is mapped as a contiguous vector (rhsIncr is not used by the
  * map), although rhsIncr is forwarded to general_matrix_vector_product. This looks
  * like it expects rhsIncr==1 -- confirm against the callers.
  */
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version>
EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version>
  ::run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
        const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha)
  {
    static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
    // Length of the diagonal.
    Index diagSize = (std::min)(_rows,_cols);
    // Lower: all rows but only the first 'diagSize' columns are referenced.
    // Upper: only the first 'diagSize' rows but all columns are referenced.
    Index rows = IsLower ? _rows : diagSize;
    Index cols = IsLower ? diagSize : _cols;

    // Wrap the raw pointers into Eigen maps; cjLhs/cjRhs apply the requested conjugation.
    typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap;
    const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
    typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);

    // The rhs is mapped without any stride (see the note in the header comment).
    typedef Map<const Matrix<RhsScalar,Dynamic,1> > RhsMap;
    const RhsMap rhs(_rhs,cols);
    typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);

    typedef Map<Matrix<ResScalar,Dynamic,1>, 0, InnerStride<> > ResMap;
    ResMap res(_res,rows,InnerStride<>(resIncr));

    // Walk along the diagonal, one panel of rows at a time.
    for (Index pi=0; pi<diagSize; pi+=PanelWidth)
    {
      Index actualPanelWidth = (std::min)(PanelWidth, diagSize-pi);
      // Triangular block of the panel, one row at a time.
      for (Index k=0; k<actualPanelWidth; ++k)
      {
        Index i = pi + k;
        // s: first column of the row segment; for an upper matrix with an implicit
        // (unit or zero) diagonal the segment starts right after the diagonal.
        Index s = IsLower ? pi  : ((HasUnitDiag||HasZeroDiag) ? i+1 : i);
        // r: segment length, diagonal coefficient included at this point.
        Index r = IsLower ? k+1 : actualPanelWidth-k;
        // With an implicit diagonal the diagonal coefficient is excluded: r is
        // decremented (only in that case, thanks to the short-circuit) and the
        // dot product is skipped if nothing remains.
        if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
          res.coeffRef(i) += alpha * (cjLhs.row(i).segment(s,r).cwiseProduct(cjRhs.segment(s,r).transpose())).sum();
        // Contribution of the implicit unit diagonal.
        if (HasUnitDiag)
          res.coeffRef(i) += alpha * cjRhs.coeff(i);
      }
      // Rectangular block of the panel: the columns on its left (lower) or right (upper).
      Index r = IsLower ? pi : cols - pi - actualPanelWidth;
      if (r>0)
      {
        Index s = IsLower ? 0 : pi + actualPanelWidth;
        general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjLhs,RhsScalar,ConjRhs,BuiltIn>::run(
            actualPanelWidth, r,
            &lhs.coeffRef(pi,s), lhsStride,
            &rhs.coeffRef(s), rhsIncr,
            &res.coeffRef(pi), resIncr, alpha);
      }
    }
    // Lower trapezoidal case: the rows past the diagonal form a dense block.
    if(IsLower && rows>diagSize)
    {
      general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjLhs,RhsScalar,ConjRhs>::run(
            rows-diagSize, cols,
            &lhs.coeffRef(diagSize,0), lhsStride,
            &rhs.coeffRef(0), rhsIncr,
            &res.coeffRef(diagSize), resIncr, alpha);
    }
  }
    155 
    156 /***************************************************************************
    157 * Wrapper to product_triangular_vector
    158 ***************************************************************************/
    159 
// Traits of a TriangularProduct whose last boolean argument is true (the form used
// below for "triangular matrix * vector"): they are inherited unchanged from the
// traits of the underlying ProductBase.
template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true> >
 : traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true>, Lhs, Rhs> >
{};
    164 
// Traits of a TriangularProduct whose fourth argument is true (the form used below
// for "vector * triangular matrix"): they are inherited unchanged from the traits
// of the underlying ProductBase.
template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false> >
 : traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false>, Lhs, Rhs> >
{};
    169 
    170 
// Dispatcher preparing the operands of a triangular matrix * vector product before
// calling triangular_matrix_vector_product. Only declared here; it is specialized
// for ColMajor and RowMajor further down in this file.
template<int StorageOrder>
struct trmv_selector;
    173 
    174 } // end namespace internal
    175 
    176 template<int Mode, typename Lhs, typename Rhs>
    177 struct TriangularProduct<Mode,true,Lhs,false,Rhs,true>
    178   : public ProductBase<TriangularProduct<Mode,true,Lhs,false,Rhs,true>, Lhs, Rhs >
    179 {
    180   EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct)
    181 
    182   TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
    183 
    184   template<typename Dest> void scaleAndAddTo(Dest& dst, const Scalar& alpha) const
    185   {
    186     eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
    187 
    188     internal::trmv_selector<(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(*this, dst, alpha);
    189   }
    190 };
    191 
    192 template<int Mode, typename Lhs, typename Rhs>
    193 struct TriangularProduct<Mode,false,Lhs,true,Rhs,false>
    194   : public ProductBase<TriangularProduct<Mode,false,Lhs,true,Rhs,false>, Lhs, Rhs >
    195 {
    196   EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct)
    197 
    198   TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
    199 
    200   template<typename Dest> void scaleAndAddTo(Dest& dst, const Scalar& alpha) const
    201   {
    202     eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
    203 
    204     typedef TriangularProduct<(Mode & (UnitDiag|ZeroDiag)) | ((Mode & Lower) ? Upper : Lower),true,Transpose<const Rhs>,false,Transpose<const Lhs>,true> TriangularProductTranspose;
    205     Transpose<Dest> dstT(dst);
    206     internal::trmv_selector<(int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>::run(
    207       TriangularProductTranspose(m_rhs.transpose(),m_lhs.transpose()), dstT, alpha);
    208   }
    209 };
    210 
    211 namespace internal {
    212 
    213 // TODO: find a way to factorize this piece of code with gemv_selector since the logic is exactly the same.
    214 
// Column-major dispatcher: extracts the actual operands and scalar factors of the
// product, makes sure the kernel can write into a contiguous destination (going
// through a temporary when needed), and calls triangular_matrix_vector_product.
template<> struct trmv_selector<ColMajor>
{
  // Performs dest += alpha * prod.
  template<int Mode, typename Lhs, typename Rhs, typename Dest>
  static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, const typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar& alpha)
  {
    typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
    typedef typename ProductType::Index Index;
    typedef typename ProductType::LhsScalar   LhsScalar;
    typedef typename ProductType::RhsScalar   RhsScalar;
    typedef typename ProductType::Scalar      ResScalar;
    typedef typename ProductType::RealScalar  RealScalar;
    typedef typename ProductType::ActualLhsType ActualLhsType;
    typedef typename ProductType::ActualRhsType ActualRhsType;
    typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
    typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
    typedef Map<Matrix<ResScalar,Dynamic,1>, Aligned> MappedDest;

    // Operands with the wrappers handled by the blas traits removed.
    typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
    typename internal::add_const_on_value_type<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs());

    // Fold the scalar factors extracted from both operands into alpha.
    // NOTE(review): when Mode has UnitDiag, the kernel adds alpha*rhs for the implicit
    // diagonal, so the lhs factor folded in here also scales that unit diagonal --
    // verify that this is the intended result for scaled unit-triangular operands.
    ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
                                  * RhsBlasTraits::extractScalarFactor(prod.rhs());

    enum {
      // FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
      // on the other hand, it is good for the cache to pack the vector anyway...
      EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1,
      // Complex lhs combined with a real rhs.
      ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex),
      // True if a temporary destination may be required.
      MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal
    };

    // Presumably provides static storage for the temporary destination when
    // MightCannotUseDest is set (helper defined outside this file).
    gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;

    // In the complex-by-real case alpha can only be passed to the kernel if it has no imaginary part.
    bool alphaIsCompatible = (!ComplexByReal) || (numext::imag(actualAlpha)==RealScalar(0));
    bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;

    // alpha expressed as an RhsScalar (conversion helper defined outside this file).
    RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);

    // Declares actualDestPtr: dest itself when it can be used directly, the static
    // buffer otherwise (the macro presumably allocates when handed a null buffer -- confirm).
    ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(),
                                                  evalToDest ? dest.data() : static_dest.data());

    if(!evalToDest)
    {
      #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
      Index size = dest.size();
      EIGEN_DENSE_STORAGE_CTOR_PLUGIN
      #endif
      if(!alphaIsCompatible)
      {
        // Compute the unscaled product into a zeroed temporary; the actual alpha is
        // applied when accumulating into dest below.
        MappedDest(actualDestPtr, dest.size()).setZero();
        compatibleAlpha = RhsScalar(1);
      }
      else
        // Start from the current destination so that the kernel accumulates on top of it.
        MappedDest(actualDestPtr, dest.size()) = dest;
    }

    // The destination seen by the kernel is contiguous, hence the result increment of 1.
    internal::triangular_matrix_vector_product
      <Index,Mode,
       LhsScalar, LhsBlasTraits::NeedToConjugate,
       RhsScalar, RhsBlasTraits::NeedToConjugate,
       ColMajor>
      ::run(actualLhs.rows(),actualLhs.cols(),
            actualLhs.data(),actualLhs.outerStride(),
            actualRhs.data(),actualRhs.innerStride(),
            actualDestPtr,1,compatibleAlpha);

    // Copy (or scale and accumulate) the temporary back into the real destination.
    if (!evalToDest)
    {
      if(!alphaIsCompatible)
        dest += actualAlpha * MappedDest(actualDestPtr, dest.size());
      else
        dest = MappedDest(actualDestPtr, dest.size());
    }
  }
};
    290 
    291 template<> struct trmv_selector<RowMajor>
    292 {
    293   template<int Mode, typename Lhs, typename Rhs, typename Dest>
    294   static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, const typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar& alpha)
    295   {
    296     typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
    297     typedef typename ProductType::LhsScalar LhsScalar;
    298     typedef typename ProductType::RhsScalar RhsScalar;
    299     typedef typename ProductType::Scalar    ResScalar;
    300     typedef typename ProductType::Index Index;
    301     typedef typename ProductType::ActualLhsType ActualLhsType;
    302     typedef typename ProductType::ActualRhsType ActualRhsType;
    303     typedef typename ProductType::_ActualRhsType _ActualRhsType;
    304     typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
    305     typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
    306 
    307     typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
    308     typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
    309 
    310     ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
    311                                   * RhsBlasTraits::extractScalarFactor(prod.rhs());
    312 
    313     enum {
    314       DirectlyUseRhs = _ActualRhsType::InnerStrideAtCompileTime==1
    315     };
    316 
    317     gemv_static_vector_if<RhsScalar,_ActualRhsType::SizeAtCompileTime,_ActualRhsType::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;
    318 
    319     ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,actualRhs.size(),
    320         DirectlyUseRhs ? const_cast<RhsScalar*>(actualRhs.data()) : static_rhs.data());
    321 
    322     if(!DirectlyUseRhs)
    323     {
    324       #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
    325       int size = actualRhs.size();
    326       EIGEN_DENSE_STORAGE_CTOR_PLUGIN
    327       #endif
    328       Map<typename _ActualRhsType::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
    329     }
    330 
    331     internal::triangular_matrix_vector_product
    332       <Index,Mode,
    333        LhsScalar, LhsBlasTraits::NeedToConjugate,
    334        RhsScalar, RhsBlasTraits::NeedToConjugate,
    335        RowMajor>
    336       ::run(actualLhs.rows(),actualLhs.cols(),
    337             actualLhs.data(),actualLhs.outerStride(),
    338             actualRhsPtr,1,
    339             dest.data(),dest.innerStride(),
    340             actualAlpha);
    341   }
    342 };
    343 
    344 } // end namespace internal
    345 
    346 } // end namespace Eigen
    347 
    348 #endif // EIGEN_TRIANGULARMATRIXVECTOR_H
    349