Home | History | Annotate | Download | only in SparseCore
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2008-2011 Gael Guennebaud <gael.guennebaud (at) inria.fr>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
     11 #define EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
     12 
     13 namespace Eigen {
     14 
     15 namespace internal {
     16 
     17 
     18 // perform a pseudo in-place sparse * sparse product assuming all matrices are col major
     19 template<typename Lhs, typename Rhs, typename ResultType>
     20 static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res, const typename ResultType::RealScalar& tolerance)
     21 {
     22   // return sparse_sparse_product_with_pruning_impl2(lhs,rhs,res);
     23 
     24   typedef typename remove_all<Lhs>::type::Scalar Scalar;
     25   typedef typename remove_all<Lhs>::type::Index Index;
     26 
     27   // make sure to call innerSize/outerSize since we fake the storage order.
     28   Index rows = lhs.innerSize();
     29   Index cols = rhs.outerSize();
     30   //Index size = lhs.outerSize();
     31   eigen_assert(lhs.outerSize() == rhs.innerSize());
     32 
     33   // allocate a temporary buffer
     34   AmbiVector<Scalar,Index> tempVector(rows);
     35 
     36   // estimate the number of non zero entries
     37   // given a rhs column containing Y non zeros, we assume that the respective Y columns
     38   // of the lhs differs in average of one non zeros, thus the number of non zeros for
     39   // the product of a rhs column with the lhs is X+Y where X is the average number of non zero
     40   // per column of the lhs.
     41   // Therefore, we have nnz(lhs*rhs) = nnz(lhs) + nnz(rhs)
     42   Index estimated_nnz_prod = lhs.nonZeros() + rhs.nonZeros();
     43 
     44   // mimics a resizeByInnerOuter:
     45   if(ResultType::IsRowMajor)
     46     res.resize(cols, rows);
     47   else
     48     res.resize(rows, cols);
     49 
     50   res.reserve(estimated_nnz_prod);
     51   double ratioColRes = double(estimated_nnz_prod)/double(lhs.rows()*rhs.cols());
     52   for (Index j=0; j<cols; ++j)
     53   {
     54     // FIXME:
     55     //double ratioColRes = (double(rhs.innerVector(j).nonZeros()) + double(lhs.nonZeros())/double(lhs.cols()))/double(lhs.rows());
     56     // let's do a more accurate determination of the nnz ratio for the current column j of res
     57     tempVector.init(ratioColRes);
     58     tempVector.setZero();
     59     for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt)
     60     {
     61       // FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index())
     62       tempVector.restart();
     63       Scalar x = rhsIt.value();
     64       for (typename Lhs::InnerIterator lhsIt(lhs, rhsIt.index()); lhsIt; ++lhsIt)
     65       {
     66         tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x;
     67       }
     68     }
     69     res.startVec(j);
     70     for (typename AmbiVector<Scalar,Index>::Iterator it(tempVector,tolerance); it; ++it)
     71       res.insertBackByOuterInner(j,it.index()) = it.value();
     72   }
     73   res.finalize();
     74 }
     75 
     76 template<typename Lhs, typename Rhs, typename ResultType,
     77   int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit,
     78   int RhsStorageOrder = traits<Rhs>::Flags&RowMajorBit,
     79   int ResStorageOrder = traits<ResultType>::Flags&RowMajorBit>
     80 struct sparse_sparse_product_with_pruning_selector;
     81 
     82 template<typename Lhs, typename Rhs, typename ResultType>
     83 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor>
     84 {
     85   typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar;
     86   typedef typename ResultType::RealScalar RealScalar;
     87 
     88   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
     89   {
     90     typename remove_all<ResultType>::type _res(res.rows(), res.cols());
     91     internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,ResultType>(lhs, rhs, _res, tolerance);
     92     res.swap(_res);
     93   }
     94 };
     95 
     96 template<typename Lhs, typename Rhs, typename ResultType>
     97 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor>
     98 {
     99   typedef typename ResultType::RealScalar RealScalar;
    100   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
    101   {
    102     // we need a col-major matrix to hold the result
    103     typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::Index> SparseTemporaryType;
    104     SparseTemporaryType _res(res.rows(), res.cols());
    105     internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, _res, tolerance);
    106     res = _res;
    107   }
    108 };
    109 
    110 template<typename Lhs, typename Rhs, typename ResultType>
    111 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor>
    112 {
    113   typedef typename ResultType::RealScalar RealScalar;
    114   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
    115   {
    116     // let's transpose the product to get a column x column product
    117     typename remove_all<ResultType>::type _res(res.rows(), res.cols());
    118     internal::sparse_sparse_product_with_pruning_impl<Rhs,Lhs,ResultType>(rhs, lhs, _res, tolerance);
    119     res.swap(_res);
    120   }
    121 };
    122 
    123 template<typename Lhs, typename Rhs, typename ResultType>
    124 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor>
    125 {
    126   typedef typename ResultType::RealScalar RealScalar;
    127   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
    128   {
    129     typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::Index> ColMajorMatrixLhs;
    130     typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::Index> ColMajorMatrixRhs;
    131     ColMajorMatrixLhs colLhs(lhs);
    132     ColMajorMatrixRhs colRhs(rhs);
    133     internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,ColMajorMatrixRhs,ResultType>(colLhs, colRhs, res, tolerance);
    134 
    135     // let's transpose the product to get a column x column product
    136 //     typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType;
    137 //     SparseTemporaryType _res(res.cols(), res.rows());
    138 //     sparse_sparse_product_with_pruning_impl<Rhs,Lhs,SparseTemporaryType>(rhs, lhs, _res);
    139 //     res = _res.transpose();
    140   }
    141 };
    142 
// NOTE the 2 other cases (col row *) must never occur since they are caught
// by ProductReturnType, which transforms them to (col col *) by evaluating rhs.
    145 
    146 } // end namespace internal
    147 
    148 } // end namespace Eigen
    149 
    150 #endif // EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
    151