Home | History | Annotate | Download | only in products
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2009 Gael Guennebaud <gael.guennebaud (at) inria.fr>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_TRIANGULARMATRIXVECTOR_H
     11 #define EIGEN_TRIANGULARMATRIXVECTOR_H
     12 
     13 namespace Eigen {
     14 
     15 namespace internal {
     16 
     /** \internal Triangular matrix * vector product kernel, specialized below on StorageOrder
     *  (ColMajor / RowMajor) of the left-hand side matrix. Mode selects Lower or Upper and
     *  optionally UnitDiag / ZeroDiag; ConjLhs / ConjRhs request conjugation of the operands. */
     17 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int StorageOrder, int Version=Specialized>
     18 struct triangular_matrix_vector_product;
     19 
     20 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
     21 struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version>
     22 {
     23   typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
     24   enum {
     25     IsLower = ((Mode&Lower)==Lower),
     26     HasUnitDiag = (Mode & UnitDiag)==UnitDiag,
     27     HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag
     28   };
     29   static EIGEN_DONT_INLINE  void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
     30                                      const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, ResScalar alpha)
     31   {
     32     static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
     33     Index size = (std::min)(_rows,_cols);
     34     Index rows = IsLower ? _rows : (std::min)(_rows,_cols);
     35     Index cols = IsLower ? (std::min)(_rows,_cols) : _cols;
     36 
     37     typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
     38     const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
     39     typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
     40 
     41     typedef Map<const Matrix<RhsScalar,Dynamic,1>, 0, InnerStride<> > RhsMap;
     42     const RhsMap rhs(_rhs,cols,InnerStride<>(rhsIncr));
     43     typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
     44 
     45     typedef Map<Matrix<ResScalar,Dynamic,1> > ResMap;
     46     ResMap res(_res,rows);
     47 
     48     for (Index pi=0; pi<size; pi+=PanelWidth)
     49     {
     50       Index actualPanelWidth = (std::min)(PanelWidth, size-pi);
     51       for (Index k=0; k<actualPanelWidth; ++k)
     52       {
     53         Index i = pi + k;
     54         Index s = IsLower ? ((HasUnitDiag||HasZeroDiag) ? i+1 : i ) : pi;
     55         Index r = IsLower ? actualPanelWidth-k : k+1;
     56         if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
     57           res.segment(s,r) += (alpha * cjRhs.coeff(i)) * cjLhs.col(i).segment(s,r);
     58         if (HasUnitDiag)
     59           res.coeffRef(i) += alpha * cjRhs.coeff(i);
     60       }
     61       Index r = IsLower ? rows - pi - actualPanelWidth : pi;
     62       if (r>0)
     63       {
     64         Index s = IsLower ? pi+actualPanelWidth : 0;
     65         general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjLhs,RhsScalar,ConjRhs,BuiltIn>::run(
     66             r, actualPanelWidth,
     67             &lhs.coeffRef(s,pi), lhsStride,
     68             &rhs.coeffRef(pi), rhsIncr,
     69             &res.coeffRef(s), resIncr, alpha);
     70       }
     71     }
     72     if((!IsLower) && cols>size)
     73     {
     74       general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjLhs,RhsScalar,ConjRhs>::run(
     75           rows, cols-size,
     76           &lhs.coeffRef(0,size), lhsStride,
     77           &rhs.coeffRef(size), rhsIncr,
     78           _res, resIncr, alpha);
     79     }
     80   }
     81 };
     82 
     83 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version>
     84 struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version>
     85 {
     86   typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
     87   enum {
     88     IsLower = ((Mode&Lower)==Lower),
     89     HasUnitDiag = (Mode & UnitDiag)==UnitDiag,
     90     HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag
     91   };
     92   static void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
     93                   const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, ResScalar alpha)
     94   {
     95     static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
     96     Index diagSize = (std::min)(_rows,_cols);
     97     Index rows = IsLower ? _rows : diagSize;
     98     Index cols = IsLower ? diagSize : _cols;
     99 
    100     typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap;
    101     const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
    102     typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
    103 
    104     typedef Map<const Matrix<RhsScalar,Dynamic,1> > RhsMap;
    105     const RhsMap rhs(_rhs,cols);
    106     typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
    107 
    108     typedef Map<Matrix<ResScalar,Dynamic,1>, 0, InnerStride<> > ResMap;
    109     ResMap res(_res,rows,InnerStride<>(resIncr));
    110 
    111     for (Index pi=0; pi<diagSize; pi+=PanelWidth)
    112     {
    113       Index actualPanelWidth = (std::min)(PanelWidth, diagSize-pi);
    114       for (Index k=0; k<actualPanelWidth; ++k)
    115       {
    116         Index i = pi + k;
    117         Index s = IsLower ? pi  : ((HasUnitDiag||HasZeroDiag) ? i+1 : i);
    118         Index r = IsLower ? k+1 : actualPanelWidth-k;
    119         if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
    120           res.coeffRef(i) += alpha * (cjLhs.row(i).segment(s,r).cwiseProduct(cjRhs.segment(s,r).transpose())).sum();
    121         if (HasUnitDiag)
    122           res.coeffRef(i) += alpha * cjRhs.coeff(i);
    123       }
    124       Index r = IsLower ? pi : cols - pi - actualPanelWidth;
    125       if (r>0)
    126       {
    127         Index s = IsLower ? 0 : pi + actualPanelWidth;
    128         general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjLhs,RhsScalar,ConjRhs,BuiltIn>::run(
    129             actualPanelWidth, r,
    130             &lhs.coeffRef(pi,s), lhsStride,
    131             &rhs.coeffRef(s), rhsIncr,
    132             &res.coeffRef(pi), resIncr, alpha);
    133       }
    134     }
    135     if(IsLower && rows>diagSize)
    136     {
    137       general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjLhs,RhsScalar,ConjRhs>::run(
    138             rows-diagSize, cols,
    139             &lhs.coeffRef(diagSize,0), lhsStride,
    140             &rhs.coeffRef(0), rhsIncr,
    141             &res.coeffRef(diagSize), resIncr, alpha);
    142     }
    143   }
    144 };
    145 
    146 /***************************************************************************
    147 * Expression-level wrappers around triangular_matrix_vector_product
    148 ***************************************************************************/
    149 
    // traits of the triangular product expression with a vector right-hand side: simply
    // those of the corresponding ProductBase. (The bool following each operand presumably
    // flags whether that operand is a vector — TODO confirm against TriangularProduct.)
    150 template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
    151 struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true> >
    152  : traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true>, Lhs, Rhs> >
    153 {};
    154 
    // Same for the variant with a vector left-hand side.
    155 template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
    156 struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false> >
    157  : traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false>, Lhs, Rhs> >
    158 {};
    159 
    160 
    // Evaluates a triangular * vector product by dispatching on the storage order of the
    // triangular factor; specialized for ColMajor and RowMajor further below.
    161 template<int StorageOrder>
    162 struct trmv_selector;
    163 
    164 } // end namespace internal
    165 
    166 template<int Mode, typename Lhs, typename Rhs>
    167 struct TriangularProduct<Mode,true,Lhs,false,Rhs,true>
    168   : public ProductBase<TriangularProduct<Mode,true,Lhs,false,Rhs,true>, Lhs, Rhs >
    169 {
    170   EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct)
    171 
    172   TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
    173 
    174   template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
    175   {
    176     eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
    177 
    178     internal::trmv_selector<(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(*this, dst, alpha);
    179   }
    180 };
    181 
    182 template<int Mode, typename Lhs, typename Rhs>
    183 struct TriangularProduct<Mode,false,Lhs,true,Rhs,false>
    184   : public ProductBase<TriangularProduct<Mode,false,Lhs,true,Rhs,false>, Lhs, Rhs >
    185 {
    186   EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct)
    187 
    188   TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
    189 
    190   template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
    191   {
    192     eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
    193 
    194     typedef TriangularProduct<(Mode & (UnitDiag|ZeroDiag)) | ((Mode & Lower) ? Upper : Lower),true,Transpose<const Rhs>,false,Transpose<const Lhs>,true> TriangularProductTranspose;
    195     Transpose<Dest> dstT(dst);
    196     internal::trmv_selector<(int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>::run(
    197       TriangularProductTranspose(m_rhs.transpose(),m_lhs.transpose()), dstT, alpha);
    198   }
    199 };
    200 
    201 namespace internal {
    202 
    203 // TODO: find a way to factorize this piece of code with gemv_selector since the logic is exactly the same.
    204 
    205 template<> struct trmv_selector<ColMajor>
    206 {
    207   template<int Mode, typename Lhs, typename Rhs, typename Dest>
    208   static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar alpha)
    209   {
    210     typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
    211     typedef typename ProductType::Index Index;
    212     typedef typename ProductType::LhsScalar   LhsScalar;
    213     typedef typename ProductType::RhsScalar   RhsScalar;
    214     typedef typename ProductType::Scalar      ResScalar;
    215     typedef typename ProductType::RealScalar  RealScalar;
    216     typedef typename ProductType::ActualLhsType ActualLhsType;
    217     typedef typename ProductType::ActualRhsType ActualRhsType;
    218     typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
    219     typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
    220     typedef Map<Matrix<ResScalar,Dynamic,1>, Aligned> MappedDest;
    221 
    222     typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
    223     typename internal::add_const_on_value_type<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
    224 
    225     ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
    226                                   * RhsBlasTraits::extractScalarFactor(prod.rhs());
    227 
    228     enum {
    229       // FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
    230       // on, the other hand it is good for the cache to pack the vector anyways...
    231       EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1,
    232       ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex),
    233       MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal
    234     };
    235 
    236     gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;
    237 
    238     bool alphaIsCompatible = (!ComplexByReal) || (imag(actualAlpha)==RealScalar(0));
    239     bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
    240 
    241     RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);
    242 
    243     ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(),
    244                                                   evalToDest ? dest.data() : static_dest.data());
    245 
    246     if(!evalToDest)
    247     {
    248       #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
    249       int size = dest.size();
    250       EIGEN_DENSE_STORAGE_CTOR_PLUGIN
    251       #endif
    252       if(!alphaIsCompatible)
    253       {
    254         MappedDest(actualDestPtr, dest.size()).setZero();
    255         compatibleAlpha = RhsScalar(1);
    256       }
    257       else
    258         MappedDest(actualDestPtr, dest.size()) = dest;
    259     }
    260 
    261     internal::triangular_matrix_vector_product
    262       <Index,Mode,
    263        LhsScalar, LhsBlasTraits::NeedToConjugate,
    264        RhsScalar, RhsBlasTraits::NeedToConjugate,
    265        ColMajor>
    266       ::run(actualLhs.rows(),actualLhs.cols(),
    267             actualLhs.data(),actualLhs.outerStride(),
    268             actualRhs.data(),actualRhs.innerStride(),
    269             actualDestPtr,1,compatibleAlpha);
    270 
    271     if (!evalToDest)
    272     {
    273       if(!alphaIsCompatible)
    274         dest += actualAlpha * MappedDest(actualDestPtr, dest.size());
    275       else
    276         dest = MappedDest(actualDestPtr, dest.size());
    277     }
    278   }
    279 };
    280 
    281 template<> struct trmv_selector<RowMajor>
    282 {
    283   template<int Mode, typename Lhs, typename Rhs, typename Dest>
    284   static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar alpha)
    285   {
    286     typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
    287     typedef typename ProductType::LhsScalar LhsScalar;
    288     typedef typename ProductType::RhsScalar RhsScalar;
    289     typedef typename ProductType::Scalar    ResScalar;
    290     typedef typename ProductType::Index Index;
    291     typedef typename ProductType::ActualLhsType ActualLhsType;
    292     typedef typename ProductType::ActualRhsType ActualRhsType;
    293     typedef typename ProductType::_ActualRhsType _ActualRhsType;
    294     typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
    295     typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
    296 
    297     typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
    298     typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
    299 
    300     ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
    301                                   * RhsBlasTraits::extractScalarFactor(prod.rhs());
    302 
    303     enum {
    304       DirectlyUseRhs = _ActualRhsType::InnerStrideAtCompileTime==1
    305     };
    306 
    307     gemv_static_vector_if<RhsScalar,_ActualRhsType::SizeAtCompileTime,_ActualRhsType::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;
    308 
    309     ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,actualRhs.size(),
    310         DirectlyUseRhs ? const_cast<RhsScalar*>(actualRhs.data()) : static_rhs.data());
    311 
    312     if(!DirectlyUseRhs)
    313     {
    314       #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
    315       int size = actualRhs.size();
    316       EIGEN_DENSE_STORAGE_CTOR_PLUGIN
    317       #endif
    318       Map<typename _ActualRhsType::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
    319     }
    320 
    321     internal::triangular_matrix_vector_product
    322       <Index,Mode,
    323        LhsScalar, LhsBlasTraits::NeedToConjugate,
    324        RhsScalar, RhsBlasTraits::NeedToConjugate,
    325        RowMajor>
    326       ::run(actualLhs.rows(),actualLhs.cols(),
    327             actualLhs.data(),actualLhs.outerStride(),
    328             actualRhsPtr,1,
    329             dest.data(),dest.innerStride(),
    330             actualAlpha);
    331   }
    332 };
    333 
    334 } // end namespace internal
    335 
    336 } // end namespace Eigen
    337 
    338 #endif // EIGEN_TRIANGULARMATRIXVECTOR_H
    339