Home | History | Annotate | Download | only in products
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2009 Gael Guennebaud <gael.guennebaud (at) inria.fr>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_TRIANGULARMATRIXVECTOR_H
     11 #define EIGEN_TRIANGULARMATRIXVECTOR_H
     12 
     13 namespace Eigen {
     14 
     15 namespace internal {
     16 
     // Primary template of the triangular matrix * vector kernel. It is only
     // declared here; the ColMajor and RowMajor storage orders are provided by the
     // partial specializations that follow. Version defaults to Specialized.
     17 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int StorageOrder, int Version=Specialized>
     18 struct triangular_matrix_vector_product;
     19 
     20 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
     21 struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version>
     22 {
     23   typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
     24   enum {
     25     IsLower = ((Mode&Lower)==Lower),
     26     HasUnitDiag = (Mode & UnitDiag)==UnitDiag,
     27     HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag
     28   };
     29   static EIGEN_DONT_INLINE  void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
     30                                      const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const RhsScalar& alpha);
     31 };
     32 
     // Column-major triangular matrix * vector kernel.
     //
     // Accumulates alpha * op(lhs) * op(rhs) into _res, reading only the triangle of
     // lhs selected by Mode; op() conjugates when ConjLhs / ConjRhs is set.
     //
     //   _rows, _cols : dimensions of the lhs. For a lower triangle only the first
     //                  min(_rows,_cols) columns are used, for an upper triangle
     //                  only the first min(_rows,_cols) rows.
     //   _lhs         : lhs coefficients, consecutive columns lhsStride apart.
     //   _rhs         : rhs coefficients, spaced by rhsIncr.
     //   _res         : result vector, updated in place (+=). It is mapped without
     //                  an inner stride here, i.e. accessed as contiguous; resIncr
     //                  is only forwarded to general_matrix_vector_product.
     //   alpha        : factor applied to every contribution.
     33 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
     34 EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version>
     35   ::run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
     36         const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const RhsScalar& alpha)
     37   {
     38     static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
            // size is the length of the diagonal; rows/cols are clipped to the part of
            // the lhs that can hold non-zero triangular coefficients.
     39     Index size = (std::min)(_rows,_cols);
     40     Index rows = IsLower ? _rows : (std::min)(_rows,_cols);
     41     Index cols = IsLower ? (std::min)(_rows,_cols) : _cols;
     42 
     43     typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
     44     const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
     45     typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
     46 
     47     typedef Map<const Matrix<RhsScalar,Dynamic,1>, 0, InnerStride<> > RhsMap;
     48     const RhsMap rhs(_rhs,cols,InnerStride<>(rhsIncr));
     49     typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
     50 
     51     typedef Map<Matrix<ResScalar,Dynamic,1> > ResMap;
     52     ResMap res(_res,rows);
     53 
     54     typedef const_blas_data_mapper<LhsScalar,Index,ColMajor> LhsMapper;
     55     typedef const_blas_data_mapper<RhsScalar,Index,RowMajor> RhsMapper;
     56 
            // Walk along the diagonal in panels of at most PanelWidth columns.
     57     for (Index pi=0; pi<size; pi+=PanelWidth)
     58     {
     59       Index actualPanelWidth = (std::min)(PanelWidth, size-pi);
              // Diagonal block of the panel, processed one column at a time.
     60       for (Index k=0; k<actualPanelWidth; ++k)
     61       {
     62         Index i = pi + k;
                // [s, s+r) is the row range of column i lying both in the triangle and
                // in the panel's diagonal block: from the diagonal down to the last
                // panel row (lower), or from the first panel row up to the diagonal
                // (upper). For a lower triangle with a unit/zero diagonal, s already
                // starts one row past the diagonal.
     63         Index s = IsLower ? ((HasUnitDiag||HasZeroDiag) ? i+1 : i ) : pi;
     64         Index r = IsLower ? actualPanelWidth-k : k+1;
                // With a unit/zero diagonal the stored diagonal coefficient is not read,
                // so the range is shortened by one (--r) and may become empty.
     65         if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
     66           res.segment(s,r) += (alpha * cjRhs.coeff(i)) * cjLhs.col(i).segment(s,r);
                // Implicit unit diagonal: contributes 1 * rhs(i) to res(i).
     67         if (HasUnitDiag)
     68           res.coeffRef(i) += alpha * cjRhs.coeff(i);
     69       }
              // Dense part of the panel outside the diagonal block: the rows below the
              // block (lower) or the rows above it (upper).
     70       Index r = IsLower ? rows - pi - actualPanelWidth : pi;
     71       if (r>0)
     72       {
     73         Index s = IsLower ? pi+actualPanelWidth : 0;
     74         general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs,BuiltIn>::run(
     75             r, actualPanelWidth,
     76             LhsMapper(&lhs.coeffRef(s,pi), lhsStride),
     77             RhsMapper(&rhs.coeffRef(pi), rhsIncr),
     78             &res.coeffRef(s), resIncr, alpha);
     79       }
     80     }
            // Upper triangle with more columns than rows: the columns past the square
            // part form a dense rows x (cols-size) block.
     81     if((!IsLower) && cols>size)
     82     {
     83       general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs>::run(
     84           rows, cols-size,
     85           LhsMapper(&lhs.coeffRef(0,size), lhsStride),
     86           RhsMapper(&rhs.coeffRef(size), rhsIncr),
     87           _res, resIncr, alpha);
     88     }
     89   }
     90 
     91 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version>
     92 struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version>
     93 {
     94   typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
     95   enum {
     96     IsLower = ((Mode&Lower)==Lower),
     97     HasUnitDiag = (Mode & UnitDiag)==UnitDiag,
     98     HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag
     99   };
    100   static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
    101                                     const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha);
    102 };
    103 
    // Row-major triangular matrix * vector kernel.
    //
    // Accumulates alpha * op(lhs) * op(rhs) into _res, reading only the triangle of
    // lhs selected by Mode; op() conjugates when ConjLhs / ConjRhs is set.
    //
    //   _rows, _cols : dimensions of the lhs. For a lower triangle only the first
    //                  min(_rows,_cols) columns are used, for an upper triangle
    //                  only the first min(_rows,_cols) rows.
    //   _lhs         : lhs coefficients, consecutive rows lhsStride apart.
    //   _rhs         : rhs coefficients. They are mapped without an inner stride
    //                  here, i.e. accessed as contiguous; rhsIncr is only forwarded
    //                  to general_matrix_vector_product.
    //   _res         : result vector, updated in place (+=), spaced by resIncr.
    //   alpha        : factor applied to every contribution.
    104 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version>
    105 EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version>
    106   ::run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
    107         const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha)
    108   {
    109     static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
            // diagSize is the length of the diagonal; rows/cols are clipped to the part
            // of the lhs that can hold non-zero triangular coefficients.
    110     Index diagSize = (std::min)(_rows,_cols);
    111     Index rows = IsLower ? _rows : diagSize;
    112     Index cols = IsLower ? diagSize : _cols;
    113 
    114     typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap;
    115     const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
    116     typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
    117 
    118     typedef Map<const Matrix<RhsScalar,Dynamic,1> > RhsMap;
    119     const RhsMap rhs(_rhs,cols);
    120     typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
    121 
    122     typedef Map<Matrix<ResScalar,Dynamic,1>, 0, InnerStride<> > ResMap;
    123     ResMap res(_res,rows,InnerStride<>(resIncr));
    124 
    125     typedef const_blas_data_mapper<LhsScalar,Index,RowMajor> LhsMapper;
    126     typedef const_blas_data_mapper<RhsScalar,Index,RowMajor> RhsMapper;
    127 
            // Walk along the diagonal in panels of at most PanelWidth rows.
    128     for (Index pi=0; pi<diagSize; pi+=PanelWidth)
    129     {
    130       Index actualPanelWidth = (std::min)(PanelWidth, diagSize-pi);
              // Diagonal block of the panel, processed one row at a time.
    131       for (Index k=0; k<actualPanelWidth; ++k)
    132       {
    133         Index i = pi + k;
                // [s, s+r) is the column range of row i lying both in the triangle and
                // in the panel's diagonal block: from the first panel column up to the
                // diagonal (lower), or from the diagonal to the last panel column
                // (upper). For an upper triangle with a unit/zero diagonal, s already
                // starts one column past the diagonal.
    134         Index s = IsLower ? pi  : ((HasUnitDiag||HasZeroDiag) ? i+1 : i);
    135         Index r = IsLower ? k+1 : actualPanelWidth-k;
                // With a unit/zero diagonal the stored diagonal coefficient is not read,
                // so the range is shortened by one (--r) and may become empty.
    136         if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
    137           res.coeffRef(i) += alpha * (cjLhs.row(i).segment(s,r).cwiseProduct(cjRhs.segment(s,r).transpose())).sum();
                // Implicit unit diagonal: contributes 1 * rhs(i) to res(i).
    138         if (HasUnitDiag)
    139           res.coeffRef(i) += alpha * cjRhs.coeff(i);
    140       }
              // Dense part of the panel outside the diagonal block: the columns left of
              // the block (lower) or the columns right of it (upper).
    141       Index r = IsLower ? pi : cols - pi - actualPanelWidth;
    142       if (r>0)
    143       {
    144         Index s = IsLower ? 0 : pi + actualPanelWidth;
    145         general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs,BuiltIn>::run(
    146             actualPanelWidth, r,
    147             LhsMapper(&lhs.coeffRef(pi,s), lhsStride),
    148             RhsMapper(&rhs.coeffRef(s), rhsIncr),
    149             &res.coeffRef(pi), resIncr, alpha);
    150       }
    151     }
            // Lower triangle with more rows than columns: the rows past the square part
            // form a dense (rows-diagSize) x cols block.
    152     if(IsLower && rows>diagSize)
    153     {
    154       general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs>::run(
    155             rows-diagSize, cols,
    156             LhsMapper(&lhs.coeffRef(diagSize,0), lhsStride),
    157             RhsMapper(&rhs.coeffRef(0), rhsIncr),
    158             &res.coeffRef(diagSize), resIncr, alpha);
    159     }
    160   }
    161 
    162 /***************************************************************************
    163 * Wrapper to product_triangular_vector
    164 ***************************************************************************/
    165 
    // Front end of the triangular matrix * vector kernels, selected by the
    // triangular Mode and the storage order of the matrix operand. Only declared
    // here; the ColMajor and RowMajor specializations are defined further down.
    166 template<int Mode,int StorageOrder>
    167 struct trmv_selector;
    168 
    169 } // end namespace internal
    170 
    171 namespace internal {
    172 
    173 template<int Mode, typename Lhs, typename Rhs>
    174 struct triangular_product_impl<Mode,true,Lhs,false,Rhs,true>
    175 {
    176   template<typename Dest> static void run(Dest& dst, const Lhs &lhs, const Rhs &rhs, const typename Dest::Scalar& alpha)
    177   {
    178     eigen_assert(dst.rows()==lhs.rows() && dst.cols()==rhs.cols());
    179 
    180     internal::trmv_selector<Mode,(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(lhs, rhs, dst, alpha);
    181   }
    182 };
    183 
    184 template<int Mode, typename Lhs, typename Rhs>
    185 struct triangular_product_impl<Mode,false,Lhs,true,Rhs,false>
    186 {
    187   template<typename Dest> static void run(Dest& dst, const Lhs &lhs, const Rhs &rhs, const typename Dest::Scalar& alpha)
    188   {
    189     eigen_assert(dst.rows()==lhs.rows() && dst.cols()==rhs.cols());
    190 
    191     Transpose<Dest> dstT(dst);
    192     internal::trmv_selector<(Mode & (UnitDiag|ZeroDiag)) | ((Mode & Lower) ? Upper : Lower),
    193                             (int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>
    194             ::run(rhs.transpose(),lhs.transpose(), dstT, alpha);
    195   }
    196 };
    197 
    198 } // end namespace internal
    199 
    200 namespace internal {
    201 
    202 // TODO: find a way to factorize this piece of code with gemv_selector since the logic is exactly the same.
    203 
    204 template<int Mode> struct trmv_selector<Mode,ColMajor>
    205 {
    206   template<typename Lhs, typename Rhs, typename Dest>
    207   static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
    208   {
    209     typedef typename Lhs::Scalar      LhsScalar;
    210     typedef typename Rhs::Scalar      RhsScalar;
    211     typedef typename Dest::Scalar     ResScalar;
    212     typedef typename Dest::RealScalar RealScalar;
    213 
    214     typedef internal::blas_traits<Lhs> LhsBlasTraits;
    215     typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
    216     typedef internal::blas_traits<Rhs> RhsBlasTraits;
    217     typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
    218 
    219     typedef Map<Matrix<ResScalar,Dynamic,1>, EIGEN_PLAIN_ENUM_MIN(AlignedMax,internal::packet_traits<ResScalar>::size)> MappedDest;
    220 
    221     typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
    222     typename internal::add_const_on_value_type<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);
    223 
    224     ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs)
    225                                   * RhsBlasTraits::extractScalarFactor(rhs);
    226 
    227     enum {
    228       // FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
    229       // on, the other hand it is good for the cache to pack the vector anyways...
    230       EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1,
    231       ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex),
    232       MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal
    233     };
    234 
    235     gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;
    236 
    237     bool alphaIsCompatible = (!ComplexByReal) || (numext::imag(actualAlpha)==RealScalar(0));
    238     bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
    239 
    240     RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);
    241 
    242     ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(),
    243                                                   evalToDest ? dest.data() : static_dest.data());
    244 
    245     if(!evalToDest)
    246     {
    247       #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
    248       Index size = dest.size();
    249       EIGEN_DENSE_STORAGE_CTOR_PLUGIN
    250       #endif
    251       if(!alphaIsCompatible)
    252       {
    253         MappedDest(actualDestPtr, dest.size()).setZero();
    254         compatibleAlpha = RhsScalar(1);
    255       }
    256       else
    257         MappedDest(actualDestPtr, dest.size()) = dest;
    258     }
    259 
    260     internal::triangular_matrix_vector_product
    261       <Index,Mode,
    262        LhsScalar, LhsBlasTraits::NeedToConjugate,
    263        RhsScalar, RhsBlasTraits::NeedToConjugate,
    264        ColMajor>
    265       ::run(actualLhs.rows(),actualLhs.cols(),
    266             actualLhs.data(),actualLhs.outerStride(),
    267             actualRhs.data(),actualRhs.innerStride(),
    268             actualDestPtr,1,compatibleAlpha);
    269 
    270     if (!evalToDest)
    271     {
    272       if(!alphaIsCompatible)
    273         dest += actualAlpha * MappedDest(actualDestPtr, dest.size());
    274       else
    275         dest = MappedDest(actualDestPtr, dest.size());
    276     }
    277   }
    278 };
    279 
    280 template<int Mode> struct trmv_selector<Mode,RowMajor>
    281 {
    282   template<typename Lhs, typename Rhs, typename Dest>
    283   static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
    284   {
    285     typedef typename Lhs::Scalar      LhsScalar;
    286     typedef typename Rhs::Scalar      RhsScalar;
    287     typedef typename Dest::Scalar     ResScalar;
    288 
    289     typedef internal::blas_traits<Lhs> LhsBlasTraits;
    290     typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
    291     typedef internal::blas_traits<Rhs> RhsBlasTraits;
    292     typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
    293     typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
    294 
    295     typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
    296     typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);
    297 
    298     ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs)
    299                                   * RhsBlasTraits::extractScalarFactor(rhs);
    300 
    301     enum {
    302       DirectlyUseRhs = ActualRhsTypeCleaned::InnerStrideAtCompileTime==1
    303     };
    304 
    305     gemv_static_vector_if<RhsScalar,ActualRhsTypeCleaned::SizeAtCompileTime,ActualRhsTypeCleaned::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;
    306 
    307     ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,actualRhs.size(),
    308         DirectlyUseRhs ? const_cast<RhsScalar*>(actualRhs.data()) : static_rhs.data());
    309 
    310     if(!DirectlyUseRhs)
    311     {
    312       #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
    313       Index size = actualRhs.size();
    314       EIGEN_DENSE_STORAGE_CTOR_PLUGIN
    315       #endif
    316       Map<typename ActualRhsTypeCleaned::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
    317     }
    318 
    319     internal::triangular_matrix_vector_product
    320       <Index,Mode,
    321        LhsScalar, LhsBlasTraits::NeedToConjugate,
    322        RhsScalar, RhsBlasTraits::NeedToConjugate,
    323        RowMajor>
    324       ::run(actualLhs.rows(),actualLhs.cols(),
    325             actualLhs.data(),actualLhs.outerStride(),
    326             actualRhsPtr,1,
    327             dest.data(),dest.innerStride(),
    328             actualAlpha);
    329   }
    330 };
    331 
    332 } // end namespace internal
    333 
    334 } // end namespace Eigen
    335 
    336 #endif // EIGEN_TRIANGULARMATRIXVECTOR_H
    337