Home | History | Annotate | Download | only in SparseCore
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2008 Gael Guennebaud <gael.guennebaud (at) inria.fr>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_SPARSETRIANGULARSOLVER_H
     11 #define EIGEN_SPARSETRIANGULARSOLVER_H
     12 
     13 namespace Eigen {
     14 
     15 namespace internal {
     16 
     17 template<typename Lhs, typename Rhs, int Mode,
     18   int UpLo = (Mode & Lower)
     19            ? Lower
     20            : (Mode & Upper)
     21            ? Upper
     22            : -1,
     23   int StorageOrder = int(traits<Lhs>::Flags) & RowMajorBit>
     24 struct sparse_solve_triangular_selector;
     25 
     26 // forward substitution, row-major
     27 template<typename Lhs, typename Rhs, int Mode>
     28 struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Lower,RowMajor>
     29 {
     30   typedef typename Rhs::Scalar Scalar;
     31   static void run(const Lhs& lhs, Rhs& other)
     32   {
     33     for(int col=0 ; col<other.cols() ; ++col)
     34     {
     35       for(int i=0; i<lhs.rows(); ++i)
     36       {
     37         Scalar tmp = other.coeff(i,col);
     38         Scalar lastVal(0);
     39         int lastIndex = 0;
     40         for(typename Lhs::InnerIterator it(lhs, i); it; ++it)
     41         {
     42           lastVal = it.value();
     43           lastIndex = it.index();
     44           if(lastIndex==i)
     45             break;
     46           tmp -= lastVal * other.coeff(lastIndex,col);
     47         }
     48         if (Mode & UnitDiag)
     49           other.coeffRef(i,col) = tmp;
     50         else
     51         {
     52           eigen_assert(lastIndex==i);
     53           other.coeffRef(i,col) = tmp/lastVal;
     54         }
     55       }
     56     }
     57   }
     58 };
     59 
     60 // backward substitution, row-major
     61 template<typename Lhs, typename Rhs, int Mode>
     62 struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Upper,RowMajor>
     63 {
     64   typedef typename Rhs::Scalar Scalar;
     65   static void run(const Lhs& lhs, Rhs& other)
     66   {
     67     for(int col=0 ; col<other.cols() ; ++col)
     68     {
     69       for(int i=lhs.rows()-1 ; i>=0 ; --i)
     70       {
     71         Scalar tmp = other.coeff(i,col);
     72         Scalar l_ii = 0;
     73         typename Lhs::InnerIterator it(lhs, i);
     74         while(it && it.index()<i)
     75           ++it;
     76         if(!(Mode & UnitDiag))
     77         {
     78           eigen_assert(it && it.index()==i);
     79           l_ii = it.value();
     80           ++it;
     81         }
     82         else if (it && it.index() == i)
     83           ++it;
     84         for(; it; ++it)
     85         {
     86           tmp -= it.value() * other.coeff(it.index(),col);
     87         }
     88 
     89         if (Mode & UnitDiag)
     90           other.coeffRef(i,col) = tmp;
     91         else
     92           other.coeffRef(i,col) = tmp/l_ii;
     93       }
     94     }
     95   }
     96 };
     97 
     98 // forward substitution, col-major
     99 template<typename Lhs, typename Rhs, int Mode>
    100 struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Lower,ColMajor>
    101 {
    102   typedef typename Rhs::Scalar Scalar;
    103   static void run(const Lhs& lhs, Rhs& other)
    104   {
    105     for(int col=0 ; col<other.cols() ; ++col)
    106     {
    107       for(int i=0; i<lhs.cols(); ++i)
    108       {
    109         Scalar& tmp = other.coeffRef(i,col);
    110         if (tmp!=Scalar(0)) // optimization when other is actually sparse
    111         {
    112           typename Lhs::InnerIterator it(lhs, i);
    113           while(it && it.index()<i)
    114             ++it;
    115           if(!(Mode & UnitDiag))
    116           {
    117             eigen_assert(it && it.index()==i);
    118             tmp /= it.value();
    119           }
    120           if (it && it.index()==i)
    121             ++it;
    122           for(; it; ++it)
    123             other.coeffRef(it.index(), col) -= tmp * it.value();
    124         }
    125       }
    126     }
    127   }
    128 };
    129 
    130 // backward substitution, col-major
    131 template<typename Lhs, typename Rhs, int Mode>
    132 struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Upper,ColMajor>
    133 {
    134   typedef typename Rhs::Scalar Scalar;
    135   static void run(const Lhs& lhs, Rhs& other)
    136   {
    137     for(int col=0 ; col<other.cols() ; ++col)
    138     {
    139       for(int i=lhs.cols()-1; i>=0; --i)
    140       {
    141         Scalar& tmp = other.coeffRef(i,col);
    142         if (tmp!=Scalar(0)) // optimization when other is actually sparse
    143         {
    144           if(!(Mode & UnitDiag))
    145           {
    146             // TODO replace this by a binary search. make sure the binary search is safe for partially sorted elements
    147             typename Lhs::ReverseInnerIterator it(lhs, i);
    148             while(it && it.index()!=i)
    149               --it;
    150             eigen_assert(it && it.index()==i);
    151             other.coeffRef(i,col) /= it.value();
    152           }
    153           typename Lhs::InnerIterator it(lhs, i);
    154           for(; it && it.index()<i; ++it)
    155             other.coeffRef(it.index(), col) -= tmp * it.value();
    156         }
    157       }
    158     }
    159   }
    160 };
    161 
    162 } // end namespace internal
    163 
    164 template<typename ExpressionType,int Mode>
    165 template<typename OtherDerived>
    166 void SparseTriangularView<ExpressionType,Mode>::solveInPlace(MatrixBase<OtherDerived>& other) const
    167 {
    168   eigen_assert(m_matrix.cols() == m_matrix.rows() && m_matrix.cols() == other.rows());
    169   eigen_assert((!(Mode & ZeroDiag)) && bool(Mode & (Upper|Lower)));
    170 
    171   enum { copy = internal::traits<OtherDerived>::Flags & RowMajorBit };
    172 
    173   typedef typename internal::conditional<copy,
    174     typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&>::type OtherCopy;
    175   OtherCopy otherCopy(other.derived());
    176 
    177   internal::sparse_solve_triangular_selector<ExpressionType, typename internal::remove_reference<OtherCopy>::type, Mode>::run(m_matrix, otherCopy);
    178 
    179   if (copy)
    180     other = otherCopy;
    181 }
    182 
    183 template<typename ExpressionType,int Mode>
    184 template<typename OtherDerived>
    185 typename internal::plain_matrix_type_column_major<OtherDerived>::type
    186 SparseTriangularView<ExpressionType,Mode>::solve(const MatrixBase<OtherDerived>& other) const
    187 {
    188   typename internal::plain_matrix_type_column_major<OtherDerived>::type res(other);
    189   solveInPlace(res);
    190   return res;
    191 }
    192 
    193 // pure sparse path
    194 
    195 namespace internal {
    196 
    197 template<typename Lhs, typename Rhs, int Mode,
    198   int UpLo = (Mode & Lower)
    199            ? Lower
    200            : (Mode & Upper)
    201            ? Upper
    202            : -1,
    203   int StorageOrder = int(Lhs::Flags) & (RowMajorBit)>
    204 struct sparse_solve_triangular_sparse_selector;
    205 
    206 // forward substitution, col-major
    207 template<typename Lhs, typename Rhs, int Mode, int UpLo>
    208 struct sparse_solve_triangular_sparse_selector<Lhs,Rhs,Mode,UpLo,ColMajor>
    209 {
    210   typedef typename Rhs::Scalar Scalar;
    211   typedef typename promote_index_type<typename traits<Lhs>::Index,
    212                                          typename traits<Rhs>::Index>::type Index;
    213   static void run(const Lhs& lhs, Rhs& other)
    214   {
    215     const bool IsLower = (UpLo==Lower);
    216     AmbiVector<Scalar,Index> tempVector(other.rows()*2);
    217     tempVector.setBounds(0,other.rows());
    218 
    219     Rhs res(other.rows(), other.cols());
    220     res.reserve(other.nonZeros());
    221 
    222     for(int col=0 ; col<other.cols() ; ++col)
    223     {
    224       // FIXME estimate number of non zeros
    225       tempVector.init(.99/*float(other.col(col).nonZeros())/float(other.rows())*/);
    226       tempVector.setZero();
    227       tempVector.restart();
    228       for (typename Rhs::InnerIterator rhsIt(other, col); rhsIt; ++rhsIt)
    229       {
    230         tempVector.coeffRef(rhsIt.index()) = rhsIt.value();
    231       }
    232 
    233       for(int i=IsLower?0:lhs.cols()-1;
    234           IsLower?i<lhs.cols():i>=0;
    235           i+=IsLower?1:-1)
    236       {
    237         tempVector.restart();
    238         Scalar& ci = tempVector.coeffRef(i);
    239         if (ci!=Scalar(0))
    240         {
    241           // find
    242           typename Lhs::InnerIterator it(lhs, i);
    243           if(!(Mode & UnitDiag))
    244           {
    245             if (IsLower)
    246             {
    247               eigen_assert(it.index()==i);
    248               ci /= it.value();
    249             }
    250             else
    251               ci /= lhs.coeff(i,i);
    252           }
    253           tempVector.restart();
    254           if (IsLower)
    255           {
    256             if (it.index()==i)
    257               ++it;
    258             for(; it; ++it)
    259               tempVector.coeffRef(it.index()) -= ci * it.value();
    260           }
    261           else
    262           {
    263             for(; it && it.index()<i; ++it)
    264               tempVector.coeffRef(it.index()) -= ci * it.value();
    265           }
    266         }
    267       }
    268 
    269 
    270       int count = 0;
    271       // FIXME compute a reference value to filter zeros
    272       for (typename AmbiVector<Scalar,Index>::Iterator it(tempVector/*,1e-12*/); it; ++it)
    273       {
    274         ++ count;
    275 //         std::cerr << "fill " << it.index() << ", " << col << "\n";
    276 //         std::cout << it.value() << "  ";
    277         // FIXME use insertBack
    278         res.insert(it.index(), col) = it.value();
    279       }
    280 //       std::cout << "tempVector.nonZeros() == " << int(count) << " / " << (other.rows()) << "\n";
    281     }
    282     res.finalize();
    283     other = res.markAsRValue();
    284   }
    285 };
    286 
    287 } // end namespace internal
    288 
    289 template<typename ExpressionType,int Mode>
    290 template<typename OtherDerived>
    291 void SparseTriangularView<ExpressionType,Mode>::solveInPlace(SparseMatrixBase<OtherDerived>& other) const
    292 {
    293   eigen_assert(m_matrix.cols() == m_matrix.rows() && m_matrix.cols() == other.rows());
    294   eigen_assert( (!(Mode & ZeroDiag)) && bool(Mode & (Upper|Lower)));
    295 
    296 //   enum { copy = internal::traits<OtherDerived>::Flags & RowMajorBit };
    297 
    298 //   typedef typename internal::conditional<copy,
    299 //     typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&>::type OtherCopy;
    300 //   OtherCopy otherCopy(other.derived());
    301 
    302   internal::sparse_solve_triangular_sparse_selector<ExpressionType, OtherDerived, Mode>::run(m_matrix, other.derived());
    303 
    304 //   if (copy)
    305 //     other = otherCopy;
    306 }
    307 
    308 #ifdef EIGEN2_SUPPORT
    309 
    310 // deprecated stuff:
    311 
    312 /** \deprecated */
    313 template<typename Derived>
    314 template<typename OtherDerived>
    315 void SparseMatrixBase<Derived>::solveTriangularInPlace(MatrixBase<OtherDerived>& other) const
    316 {
    317   this->template triangular<Flags&(Upper|Lower)>().solveInPlace(other);
    318 }
    319 
    320 /** \deprecated */
    321 template<typename Derived>
    322 template<typename OtherDerived>
    323 typename internal::plain_matrix_type_column_major<OtherDerived>::type
    324 SparseMatrixBase<Derived>::solveTriangular(const MatrixBase<OtherDerived>& other) const
    325 {
    326   typename internal::plain_matrix_type_column_major<OtherDerived>::type res(other);
    327   derived().solveTriangularInPlace(res);
    328   return res;
    329 }
    330 #endif // EIGEN2_SUPPORT
    331 
    332 } // end namespace Eigen
    333 
    334 #endif // EIGEN_SPARSETRIANGULARSOLVER_H
    335