1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2008-2011 Gael Guennebaud <gael.guennebaud (at) inria.fr> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H 11 #define EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H 12 13 namespace Eigen { 14 15 namespace internal { 16 17 18 // perform a pseudo in-place sparse * sparse product assuming all matrices are col major 19 template<typename Lhs, typename Rhs, typename ResultType> 20 static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res, const typename ResultType::RealScalar& tolerance) 21 { 22 // return sparse_sparse_product_with_pruning_impl2(lhs,rhs,res); 23 24 typedef typename remove_all<Lhs>::type::Scalar Scalar; 25 typedef typename remove_all<Lhs>::type::Index Index; 26 27 // make sure to call innerSize/outerSize since we fake the storage order. 28 Index rows = lhs.innerSize(); 29 Index cols = rhs.outerSize(); 30 //Index size = lhs.outerSize(); 31 eigen_assert(lhs.outerSize() == rhs.innerSize()); 32 33 // allocate a temporary buffer 34 AmbiVector<Scalar,Index> tempVector(rows); 35 36 // estimate the number of non zero entries 37 // given a rhs column containing Y non zeros, we assume that the respective Y columns 38 // of the lhs differs in average of one non zeros, thus the number of non zeros for 39 // the product of a rhs column with the lhs is X+Y where X is the average number of non zero 40 // per column of the lhs. 41 // Therefore, we have nnz(lhs*rhs) = nnz(lhs) + nnz(rhs) 42 Index estimated_nnz_prod = lhs.nonZeros() + rhs.nonZeros(); 43 44 // mimics a resizeByInnerOuter: 45 if(ResultType::IsRowMajor) 46 res.resize(cols, rows); 47 else 48 res.resize(rows, cols); 49 50 res.reserve(estimated_nnz_prod); 51 double ratioColRes = double(estimated_nnz_prod)/double(lhs.rows()*rhs.cols()); 52 for (Index j=0; j<cols; ++j) 53 { 54 // FIXME: 55 //double ratioColRes = (double(rhs.innerVector(j).nonZeros()) + double(lhs.nonZeros())/double(lhs.cols()))/double(lhs.rows()); 56 // let's do a more accurate determination of the nnz ratio for the current column j of res 57 tempVector.init(ratioColRes); 58 tempVector.setZero(); 59 for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt) 60 { 61 // FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index()) 62 tempVector.restart(); 63 Scalar x = rhsIt.value(); 64 for (typename Lhs::InnerIterator lhsIt(lhs, rhsIt.index()); lhsIt; ++lhsIt) 65 { 66 tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x; 67 } 68 } 69 res.startVec(j); 70 for (typename AmbiVector<Scalar,Index>::Iterator it(tempVector,tolerance); it; ++it) 71 res.insertBackByOuterInner(j,it.index()) = it.value(); 72 } 73 res.finalize(); 74 } 75 76 template<typename Lhs, typename Rhs, typename ResultType, 77 int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit, 78 int RhsStorageOrder = traits<Rhs>::Flags&RowMajorBit, 79 int ResStorageOrder = traits<ResultType>::Flags&RowMajorBit> 80 struct sparse_sparse_product_with_pruning_selector; 81 82 template<typename Lhs, typename Rhs, typename ResultType> 83 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor> 84 { 85 typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar; 86 typedef typename ResultType::RealScalar RealScalar; 87 88 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) 89 { 90 typename remove_all<ResultType>::type _res(res.rows(), res.cols()); 91 internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,ResultType>(lhs, rhs, _res, tolerance); 92 res.swap(_res); 93 } 94 }; 95 96 template<typename Lhs, typename Rhs, typename ResultType> 97 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor> 98 { 99 typedef typename ResultType::RealScalar RealScalar; 100 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) 101 { 102 // we need a col-major matrix to hold the result 103 typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::Index> SparseTemporaryType; 104 SparseTemporaryType _res(res.rows(), res.cols()); 105 internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, _res, tolerance); 106 res = _res; 107 } 108 }; 109 110 template<typename Lhs, typename Rhs, typename ResultType> 111 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor> 112 { 113 typedef typename ResultType::RealScalar RealScalar; 114 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) 115 { 116 // let's transpose the product to get a column x column product 117 typename remove_all<ResultType>::type _res(res.rows(), res.cols()); 118 internal::sparse_sparse_product_with_pruning_impl<Rhs,Lhs,ResultType>(rhs, lhs, _res, tolerance); 119 res.swap(_res); 120 } 121 }; 122 123 template<typename Lhs, typename Rhs, typename ResultType> 124 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor> 125 { 126 typedef typename ResultType::RealScalar RealScalar; 127 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) 128 { 129 typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::Index> ColMajorMatrixLhs; 130 typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::Index> ColMajorMatrixRhs; 131 ColMajorMatrixLhs colLhs(lhs); 132 ColMajorMatrixRhs colRhs(rhs); 133 internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,ColMajorMatrixRhs,ResultType>(colLhs, colRhs, res, tolerance); 134 135 // let's transpose the product to get a column x column product 136 // typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType; 137 // SparseTemporaryType _res(res.cols(), res.rows()); 138 // sparse_sparse_product_with_pruning_impl<Rhs,Lhs,SparseTemporaryType>(rhs, lhs, _res); 139 // res = _res.transpose(); 140 } 141 }; 142 143 // NOTE the 2 others cases (col row *) must never occur since they are caught 144 // by ProductReturnType which transforms it to (col col *) by evaluating rhs. 145 146 } // end namespace internal 147 148 } // end namespace Eigen 149 150 #endif // EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H 151