1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2008-2014 Gael Guennebaud <gael.guennebaud (at) inria.fr> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H 11 #define EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H 12 13 namespace Eigen { 14 15 namespace internal { 16 17 18 // perform a pseudo in-place sparse * sparse product assuming all matrices are col major 19 template<typename Lhs, typename Rhs, typename ResultType> 20 static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res, const typename ResultType::RealScalar& tolerance) 21 { 22 // return sparse_sparse_product_with_pruning_impl2(lhs,rhs,res); 23 24 typedef typename remove_all<Lhs>::type::Scalar Scalar; 25 typedef typename remove_all<Lhs>::type::StorageIndex StorageIndex; 26 27 // make sure to call innerSize/outerSize since we fake the storage order. 28 Index rows = lhs.innerSize(); 29 Index cols = rhs.outerSize(); 30 //Index size = lhs.outerSize(); 31 eigen_assert(lhs.outerSize() == rhs.innerSize()); 32 33 // allocate a temporary buffer 34 AmbiVector<Scalar,StorageIndex> tempVector(rows); 35 36 // mimics a resizeByInnerOuter: 37 if(ResultType::IsRowMajor) 38 res.resize(cols, rows); 39 else 40 res.resize(rows, cols); 41 42 evaluator<Lhs> lhsEval(lhs); 43 evaluator<Rhs> rhsEval(rhs); 44 45 // estimate the number of non zero entries 46 // given a rhs column containing Y non zeros, we assume that the respective Y columns 47 // of the lhs differs in average of one non zeros, thus the number of non zeros for 48 // the product of a rhs column with the lhs is X+Y where X is the average number of non zero 49 // per column of the lhs. 50 // Therefore, we have nnz(lhs*rhs) = nnz(lhs) + nnz(rhs) 51 Index estimated_nnz_prod = lhsEval.nonZerosEstimate() + rhsEval.nonZerosEstimate(); 52 53 res.reserve(estimated_nnz_prod); 54 double ratioColRes = double(estimated_nnz_prod)/(double(lhs.rows())*double(rhs.cols())); 55 for (Index j=0; j<cols; ++j) 56 { 57 // FIXME: 58 //double ratioColRes = (double(rhs.innerVector(j).nonZeros()) + double(lhs.nonZeros())/double(lhs.cols()))/double(lhs.rows()); 59 // let's do a more accurate determination of the nnz ratio for the current column j of res 60 tempVector.init(ratioColRes); 61 tempVector.setZero(); 62 for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt) 63 { 64 // FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index()) 65 tempVector.restart(); 66 Scalar x = rhsIt.value(); 67 for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, rhsIt.index()); lhsIt; ++lhsIt) 68 { 69 tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x; 70 } 71 } 72 res.startVec(j); 73 for (typename AmbiVector<Scalar,StorageIndex>::Iterator it(tempVector,tolerance); it; ++it) 74 res.insertBackByOuterInner(j,it.index()) = it.value(); 75 } 76 res.finalize(); 77 } 78 79 template<typename Lhs, typename Rhs, typename ResultType, 80 int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit, 81 int RhsStorageOrder = traits<Rhs>::Flags&RowMajorBit, 82 int ResStorageOrder = traits<ResultType>::Flags&RowMajorBit> 83 struct sparse_sparse_product_with_pruning_selector; 84 85 template<typename Lhs, typename Rhs, typename ResultType> 86 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor> 87 { 88 typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar; 89 typedef typename ResultType::RealScalar RealScalar; 90 91 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) 92 { 93 typename remove_all<ResultType>::type _res(res.rows(), res.cols()); 94 internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,ResultType>(lhs, rhs, _res, tolerance); 95 res.swap(_res); 96 } 97 }; 98 99 template<typename Lhs, typename Rhs, typename ResultType> 100 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor> 101 { 102 typedef typename ResultType::RealScalar RealScalar; 103 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) 104 { 105 // we need a col-major matrix to hold the result 106 typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> SparseTemporaryType; 107 SparseTemporaryType _res(res.rows(), res.cols()); 108 internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, _res, tolerance); 109 res = _res; 110 } 111 }; 112 113 template<typename Lhs, typename Rhs, typename ResultType> 114 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor> 115 { 116 typedef typename ResultType::RealScalar RealScalar; 117 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) 118 { 119 // let's transpose the product to get a column x column product 120 typename remove_all<ResultType>::type _res(res.rows(), res.cols()); 121 internal::sparse_sparse_product_with_pruning_impl<Rhs,Lhs,ResultType>(rhs, lhs, _res, tolerance); 122 res.swap(_res); 123 } 124 }; 125 126 template<typename Lhs, typename Rhs, typename ResultType> 127 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor> 128 { 129 typedef typename ResultType::RealScalar RealScalar; 130 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) 131 { 132 typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixLhs; 133 typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixRhs; 134 ColMajorMatrixLhs colLhs(lhs); 135 ColMajorMatrixRhs colRhs(rhs); 136 internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,ColMajorMatrixRhs,ResultType>(colLhs, colRhs, res, tolerance); 137 138 // let's transpose the product to get a column x column product 139 // typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType; 140 // SparseTemporaryType _res(res.cols(), res.rows()); 141 // sparse_sparse_product_with_pruning_impl<Rhs,Lhs,SparseTemporaryType>(rhs, lhs, _res); 142 // res = _res.transpose(); 143 } 144 }; 145 146 template<typename Lhs, typename Rhs, typename ResultType> 147 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,RowMajor> 148 { 149 typedef typename ResultType::RealScalar RealScalar; 150 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) 151 { 152 typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename Lhs::StorageIndex> RowMajorMatrixLhs; 153 RowMajorMatrixLhs rowLhs(lhs); 154 sparse_sparse_product_with_pruning_selector<RowMajorMatrixLhs,Rhs,ResultType,RowMajor,RowMajor>(rowLhs,rhs,res,tolerance); 155 } 156 }; 157 158 template<typename Lhs, typename Rhs, typename ResultType> 159 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,RowMajor> 160 { 161 typedef typename ResultType::RealScalar RealScalar; 162 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) 163 { 164 typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename Lhs::StorageIndex> RowMajorMatrixRhs; 165 RowMajorMatrixRhs rowRhs(rhs); 166 sparse_sparse_product_with_pruning_selector<Lhs,RowMajorMatrixRhs,ResultType,RowMajor,RowMajor,RowMajor>(lhs,rowRhs,res,tolerance); 167 } 168 }; 169 170 template<typename Lhs, typename Rhs, typename ResultType> 171 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,ColMajor> 172 { 173 typedef typename ResultType::RealScalar RealScalar; 174 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) 175 { 176 typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixRhs; 177 ColMajorMatrixRhs colRhs(rhs); 178 internal::sparse_sparse_product_with_pruning_impl<Lhs,ColMajorMatrixRhs,ResultType>(lhs, colRhs, res, tolerance); 179 } 180 }; 181 182 template<typename Lhs, typename Rhs, typename ResultType> 183 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,ColMajor> 184 { 185 typedef typename ResultType::RealScalar RealScalar; 186 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) 187 { 188 typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixLhs; 189 ColMajorMatrixLhs colLhs(lhs); 190 internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,Rhs,ResultType>(colLhs, rhs, res, tolerance); 191 } 192 }; 193 194 } // end namespace internal 195 196 } // end namespace Eigen 197 198 #endif // EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H 199