Home | History | Annotate | Download | only in products
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2008-2010 Gael Guennebaud <gael.guennebaud (at) inria.fr>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_TRIANGULAR_SOLVER_VECTOR_H
     11 #define EIGEN_TRIANGULAR_SOLVER_VECTOR_H
     12 
     13 namespace Eigen {
     14 
     15 namespace internal {
     16 
     17 template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate, int StorageOrder>
     18 struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheRight, Mode, Conjugate, StorageOrder>
     19 {
          // Specialization picked when the side argument is OnTheRight.
          //
          // It contains no solver of its own: run() forwards its four arguments,
          // unchanged, to the OnTheLeft specialization after flipping two template
          // arguments:
          //  - the triangle: a Mode with all Upper bits set becomes Lower, any
          //    other Mode becomes Upper; of the remaining Mode bits only UnitDiag
          //    is carried over;
          //  - the storage order: RowMajor becomes ColMajor, any other value
          //    becomes RowMajor.
          // Conjugate is passed through as is.
          //
          // NOTE(review): reading the same buffer with the opposite storage order
          // and the opposite triangle amounts to working on the transposed matrix,
          // so this presumably turns x*A = b into A^T*x^T = b^T -- confirm against
          // the callers.
          //
          // size, _lhs, lhsStride, rhs: forwarded verbatim; their meaning is the
          // one given by the OnTheLeft specialization that ends up being selected.
     20   static void run(Index size, const LhsScalar* _lhs, Index lhsStride, RhsScalar* rhs)
     21   {
     22     triangular_solve_vector<LhsScalar,RhsScalar,Index,OnTheLeft,
     23         ((Mode&Upper)==Upper ? Lower : Upper) | (Mode&UnitDiag),
     24         Conjugate,StorageOrder==RowMajor?ColMajor:RowMajor
     25       >::run(size, _lhs, lhsStride, rhs);
     26   }
     27 };
     28 
     29 // Forward and backward substitution, row-major, rhs is a vector.
        //
        // Left-side solve for a triangular matrix mapped as row-major. The solve is
        // done in place: rhs holds the right-hand side on entry and its entries are
        // overwritten one by one with the solution. When IsLower is set the rows
        // are visited from 0 upwards (forward substitution), otherwise from size-1
        // downwards (backward substitution).
        //
        // Rows are processed in panels of at most EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH
        // rows. For every panel:
        //   1. one general_matrix_vector_product call, with scale factor -1, folds
        //      the entries solved by earlier panels into the panel's rhs entries;
        //   2. the panel's rows are then solved one at a time: a short dot product
        //      over the entries already solved inside the panel is subtracted,
        //      then the entry is divided by the diagonal coefficient (the division
        //      is skipped when Mode has UnitDiag set).
        //
        // No check is made for a zero diagonal coefficient.
     30 template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate>
     31 struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheLeft, Mode, Conjugate, RowMajor>
     32 {
     33   enum {
            // Non-zero when every bit of Lower is set in Mode.
     34     IsLower = ((Mode&Lower)==Lower)
     35   };
          // size:      number of rows and columns of the mapped matrix, and number
          //            of rhs entries that are read and written.
          // _lhs:      pointer to the first coefficient of the matrix.
          // lhsStride: outer stride used to map _lhs (for this row-major map, the
          //            distance between the starts of two consecutive rows).
          // rhs:       in/out vector, accessed with unit increment.
     36   static void run(Index size, const LhsScalar* _lhs, Index lhsStride, RhsScalar* rhs)
     37   {
            // View the raw buffer as a size x size row-major matrix.
     38     typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap;
     39     const LhsMap lhs(_lhs,size,size,OuterStride<>(lhsStride));
     40 
     41     typedef const_blas_data_mapper<LhsScalar,Index,RowMajor> LhsMapper;
     42     typedef const_blas_data_mapper<RhsScalar,Index,ColMajor> RhsMapper;
     43 
            // cjLhs is a conjugating expression over lhs when Conjugate is set,
            // and a plain reference to lhs otherwise. It is used by the scalar
            // code below; the product call receives lhs plus the Conjugate flag.
     44     typename internal::conditional<
     45                           Conjugate,
     46                           const CwiseUnaryOp<typename internal::scalar_conjugate_op<LhsScalar>,LhsMap>,
     47                           const LhsMap&>
     48                         ::type cjLhs(lhs);
     49     static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
            // Lower: pi is the first row of the panel and grows from 0.
            // Upper: pi is one past the last row of the panel and shrinks from size.
     50     for(Index pi=IsLower ? 0 : size;
     51         IsLower ? pi<size : pi>0;
     52         IsLower ? pi+=PanelWidth : pi-=PanelWidth)
     53     {
              // The last panel may hold fewer than PanelWidth rows.
     54       Index actualPanelWidth = (std::min)(IsLower ? size - pi : pi, PanelWidth);
     55 
     56       Index r = IsLower ? pi : size - pi; // number of entries already solved by previous panels
     57       if (r > 0)
     58       {
     59         // let's directly call the low level product function because:
     60         // 1 - it is faster to compile
     61         // 2 - it is slightly faster at runtime
                // The block passed below has actualPanelWidth rows starting at
                // startRow and r columns starting at startCol, i.e. the panel's
                // rows restricted to the columns of the already solved entries.
     62         Index startRow = IsLower ? pi : pi-actualPanelWidth;
     63         Index startCol = IsLower ? 0 : pi;
     64 
                // Input vector: rhs[startCol ...]; result: rhs[startRow ...] with
                // increment 1; scale factor -1.
                // NOTE(review): assumes the product accumulates
                // alpha * block * input into the result -- it is defined
                // elsewhere, confirm there.
     65         general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,Conjugate,RhsScalar,RhsMapper,false>::run(
     66           actualPanelWidth, r,
     67           LhsMapper(&lhs.coeffRef(startRow,startCol), lhsStride),
     68           RhsMapper(rhs + startCol, 1),
     69           rhs + startRow, 1,
     70           RhsScalar(-1));
     71       }
     72 
              // Solve the rows of the panel one at a time.
     73       for(Index k=0; k<actualPanelWidth; ++k)
     74       {
                // i: row being solved. s: first of the k panel entries that were
                // solved before row i (they occupy indices s .. s+k-1).
     75         Index i = IsLower ? pi+k : pi-k-1;
     76         Index s = IsLower ? pi   : i+1;
                // Subtract the dot product of row i, columns s .. s+k-1, with the
                // matching solved entries of rhs.
     77         if (k>0)
     78           rhs[i] -= (cjLhs.row(i).segment(s,k).transpose().cwiseProduct(Map<const Matrix<RhsScalar,Dynamic,1> >(rhs+s,k))).sum();
     79 
                // With UnitDiag the division is skipped and the stored diagonal
                // coefficient is not read.
     80         if(!(Mode & UnitDiag))
     81           rhs[i] /= cjLhs(i,i);
     82       }
     83     }
     84   }
     85 };
     86 
     87 // Forward and backward substitution, column-major, rhs is a vector.
        //
        // Left-side solve for a triangular matrix mapped as column-major. The solve
        // is done in place: rhs holds the right-hand side on entry and its entries
        // are overwritten one by one with the solution. When IsLower is set the
        // entries are visited from 0 upwards (forward substitution), otherwise
        // from size-1 downwards (backward substitution).
        //
        // Entries are processed in panels of at most
        // EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH entries. For every panel:
        //   1. the panel's entries are solved one at a time: the entry is divided
        //      by the diagonal coefficient (skipped when Mode has UnitDiag set),
        //      then its multiple of the matching column segment is subtracted
        //      from the panel entries that are still to be solved;
        //   2. one general_matrix_vector_product call, with scale factor -1,
        //      applies the freshly solved panel to the entries outside the panel
        //      that are still to be solved.
        //
        // No check is made for a zero diagonal coefficient.
     88 template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate>
     89 struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheLeft, Mode, Conjugate, ColMajor>
     90 {
     91   enum {
            // Non-zero when every bit of Lower is set in Mode.
     92     IsLower = ((Mode&Lower)==Lower)
     93   };
          // size:      number of rows and columns of the mapped matrix, and number
          //            of rhs entries that are read and written.
          // _lhs:      pointer to the first coefficient of the matrix.
          // lhsStride: outer stride used to map _lhs (for this column-major map,
          //            the distance between the starts of two consecutive columns).
          // rhs:       in/out vector, accessed with unit increment.
     94   static void run(Index size, const LhsScalar* _lhs, Index lhsStride, RhsScalar* rhs)
     95   {
            // View the raw buffer as a size x size column-major matrix.
     96     typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
     97     const LhsMap lhs(_lhs,size,size,OuterStride<>(lhsStride));
     98     typedef const_blas_data_mapper<LhsScalar,Index,ColMajor> LhsMapper;
     99     typedef const_blas_data_mapper<RhsScalar,Index,ColMajor> RhsMapper;
            // cjLhs is a conjugating expression over lhs when Conjugate is set,
            // and a plain reference to lhs otherwise. It is used by the scalar
            // code below; the product call receives lhs plus the Conjugate flag.
    100     typename internal::conditional<Conjugate,
    101                                    const CwiseUnaryOp<typename internal::scalar_conjugate_op<LhsScalar>,LhsMap>,
    102                                    const LhsMap&
    103                                   >::type cjLhs(lhs);
    104     static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
    105 
            // Lower: pi is the first index of the panel and grows from 0.
            // Upper: pi is one past the last index of the panel and shrinks from size.
    106     for(Index pi=IsLower ? 0 : size;
    107         IsLower ? pi<size : pi>0;
    108         IsLower ? pi+=PanelWidth : pi-=PanelWidth)
    109     {
              // The last panel may hold fewer than PanelWidth entries.
    110       Index actualPanelWidth = (std::min)(IsLower ? size - pi : pi, PanelWidth);
              // startBlock: lowest index of the panel.
              // endBlock: first index of the still unsolved entries outside the
              // panel -- just past the panel for lower, 0 for upper.
    111       Index startBlock = IsLower ? pi : pi-actualPanelWidth;
    112       Index endBlock = IsLower ? pi + actualPanelWidth : 0;
    113 
              // Solve the entries of the panel one at a time.
    114       for(Index k=0; k<actualPanelWidth; ++k)
    115       {
                // i: entry being solved.
    116         Index i = IsLower ? pi+k : pi-k-1;
                // With UnitDiag the division is skipped and the stored diagonal
                // coefficient is not read.
    117         if(!(Mode & UnitDiag))
    118           rhs[i] /= cjLhs.coeff(i,i);
    119 
    120         Index r = actualPanelWidth - k - 1; // panel entries still to be solved after i
                // s: lowest index of those r entries (i+1 for lower; for upper
                // i-r, which equals startBlock).
    121         Index s = IsLower ? i+1 : i-r;
                // Subtract rhs[i] times column i, rows s .. s+r-1, from them.
    122         if (r>0)
    123           Map<Matrix<RhsScalar,Dynamic,1> >(rhs+s,r) -= rhs[i] * cjLhs.col(i).segment(s,r);
    124       }
    125       Index r = IsLower ? size - endBlock : startBlock; // entries outside the panel still to be solved
    126       if (r > 0)
    127       {
    128         // let's directly call the low level product function because:
    129         // 1 - it is faster to compile
    130         // 2 - it is slightly faster at runtime
                // The block passed below has r rows starting at endBlock and
                // actualPanelWidth columns starting at startBlock. Input vector:
                // rhs[startBlock ...]; result: rhs[endBlock ...] with increment 1;
                // scale factor -1.
                // NOTE(review): assumes the product accumulates
                // alpha * block * input into the result -- it is defined
                // elsewhere, confirm there.
    131         general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,Conjugate,RhsScalar,RhsMapper,false>::run(
    132             r, actualPanelWidth,
    133             LhsMapper(&lhs.coeffRef(endBlock,startBlock), lhsStride),
    134             RhsMapper(rhs+startBlock, 1),
    135             rhs+endBlock, 1, RhsScalar(-1));
    136       }
    137     }
    138   }
    139 };
    140 
    141 } // end namespace internal
    142 
    143 } // end namespace Eigen
    144 
    145 #endif // EIGEN_TRIANGULAR_SOLVER_VECTOR_H
    146