1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2008-2011 Gael Guennebaud <gael.guennebaud (at) inria.fr> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_CONSERVATIVESPARSESPARSEPRODUCT_H 11 #define EIGEN_CONSERVATIVESPARSESPARSEPRODUCT_H 12 13 namespace Eigen { 14 15 namespace internal { 16 17 template<typename Lhs, typename Rhs, typename ResultType> 18 static void conservative_sparse_sparse_product_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res) 19 { 20 typedef typename remove_all<Lhs>::type::Scalar Scalar; 21 typedef typename remove_all<Lhs>::type::Index Index; 22 23 // make sure to call innerSize/outerSize since we fake the storage order. 24 Index rows = lhs.innerSize(); 25 Index cols = rhs.outerSize(); 26 eigen_assert(lhs.outerSize() == rhs.innerSize()); 27 28 std::vector<bool> mask(rows,false); 29 Matrix<Scalar,Dynamic,1> values(rows); 30 Matrix<Index,Dynamic,1> indices(rows); 31 32 // estimate the number of non zero entries 33 // given a rhs column containing Y non zeros, we assume that the respective Y columns 34 // of the lhs differs in average of one non zeros, thus the number of non zeros for 35 // the product of a rhs column with the lhs is X+Y where X is the average number of non zero 36 // per column of the lhs. 37 // Therefore, we have nnz(lhs*rhs) = nnz(lhs) + nnz(rhs) 38 Index estimated_nnz_prod = lhs.nonZeros() + rhs.nonZeros(); 39 40 res.setZero(); 41 res.reserve(Index(estimated_nnz_prod)); 42 // we compute each column of the result, one after the other 43 for (Index j=0; j<cols; ++j) 44 { 45 46 res.startVec(j); 47 Index nnz = 0; 48 for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt) 49 { 50 Scalar y = rhsIt.value(); 51 Index k = rhsIt.index(); 52 for (typename Lhs::InnerIterator lhsIt(lhs, k); lhsIt; ++lhsIt) 53 { 54 Index i = lhsIt.index(); 55 Scalar x = lhsIt.value(); 56 if(!mask[i]) 57 { 58 mask[i] = true; 59 values[i] = x * y; 60 indices[nnz] = i; 61 ++nnz; 62 } 63 else 64 values[i] += x * y; 65 } 66 } 67 68 // unordered insertion 69 for(int k=0; k<nnz; ++k) 70 { 71 int i = indices[k]; 72 res.insertBackByOuterInnerUnordered(j,i) = values[i]; 73 mask[i] = false; 74 } 75 76 #if 0 77 // alternative ordered insertion code: 78 79 int t200 = rows/(log2(200)*1.39); 80 int t = (rows*100)/139; 81 82 // FIXME reserve nnz non zeros 83 // FIXME implement fast sort algorithms for very small nnz 84 // if the result is sparse enough => use a quick sort 85 // otherwise => loop through the entire vector 86 // In order to avoid to perform an expensive log2 when the 87 // result is clearly very sparse we use a linear bound up to 200. 88 //if((nnz<200 && nnz<t200) || nnz * log2(nnz) < t) 89 //res.startVec(j); 90 if(true) 91 { 92 if(nnz>1) std::sort(indices.data(),indices.data()+nnz); 93 for(int k=0; k<nnz; ++k) 94 { 95 int i = indices[k]; 96 res.insertBackByOuterInner(j,i) = values[i]; 97 mask[i] = false; 98 } 99 } 100 else 101 { 102 // dense path 103 for(int i=0; i<rows; ++i) 104 { 105 if(mask[i]) 106 { 107 mask[i] = false; 108 res.insertBackByOuterInner(j,i) = values[i]; 109 } 110 } 111 } 112 #endif 113 114 } 115 res.finalize(); 116 } 117 118 119 } // end namespace internal 120 121 namespace internal { 122 123 template<typename Lhs, typename Rhs, typename ResultType, 124 int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit, 125 int RhsStorageOrder = traits<Rhs>::Flags&RowMajorBit, 126 int ResStorageOrder = traits<ResultType>::Flags&RowMajorBit> 127 struct conservative_sparse_sparse_product_selector; 128 129 template<typename Lhs, typename Rhs, typename ResultType> 130 struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor> 131 { 132 typedef typename remove_all<Lhs>::type LhsCleaned; 133 typedef typename LhsCleaned::Scalar Scalar; 134 135 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) 136 { 137 typedef SparseMatrix<typename ResultType::Scalar,RowMajor> RowMajorMatrix; 138 typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix; 139 ColMajorMatrix resCol(lhs.rows(),rhs.cols()); 140 internal::conservative_sparse_sparse_product_impl<Lhs,Rhs,ColMajorMatrix>(lhs, rhs, resCol); 141 // sort the non zeros: 142 RowMajorMatrix resRow(resCol); 143 res = resRow; 144 } 145 }; 146 147 template<typename Lhs, typename Rhs, typename ResultType> 148 struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,ColMajor> 149 { 150 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) 151 { 152 typedef SparseMatrix<typename ResultType::Scalar,RowMajor> RowMajorMatrix; 153 RowMajorMatrix rhsRow = rhs; 154 RowMajorMatrix resRow(lhs.rows(), rhs.cols()); 155 internal::conservative_sparse_sparse_product_impl<RowMajorMatrix,Lhs,RowMajorMatrix>(rhsRow, lhs, resRow); 156 res = resRow; 157 } 158 }; 159 160 template<typename Lhs, typename Rhs, typename ResultType> 161 struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,ColMajor> 162 { 163 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) 164 { 165 typedef SparseMatrix<typename ResultType::Scalar,RowMajor> RowMajorMatrix; 166 RowMajorMatrix lhsRow = lhs; 167 RowMajorMatrix resRow(lhs.rows(), rhs.cols()); 168 internal::conservative_sparse_sparse_product_impl<Rhs,RowMajorMatrix,RowMajorMatrix>(rhs, lhsRow, resRow); 169 res = resRow; 170 } 171 }; 172 173 template<typename Lhs, typename Rhs, typename ResultType> 174 struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor> 175 { 176 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) 177 { 178 typedef SparseMatrix<typename ResultType::Scalar,RowMajor> RowMajorMatrix; 179 RowMajorMatrix resRow(lhs.rows(), rhs.cols()); 180 internal::conservative_sparse_sparse_product_impl<Rhs,Lhs,RowMajorMatrix>(rhs, lhs, resRow); 181 res = resRow; 182 } 183 }; 184 185 186 template<typename Lhs, typename Rhs, typename ResultType> 187 struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor> 188 { 189 typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar; 190 191 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) 192 { 193 typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix; 194 ColMajorMatrix resCol(lhs.rows(), rhs.cols()); 195 internal::conservative_sparse_sparse_product_impl<Lhs,Rhs,ColMajorMatrix>(lhs, rhs, resCol); 196 res = resCol; 197 } 198 }; 199 200 template<typename Lhs, typename Rhs, typename ResultType> 201 struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,RowMajor> 202 { 203 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) 204 { 205 typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix; 206 ColMajorMatrix lhsCol = lhs; 207 ColMajorMatrix resCol(lhs.rows(), rhs.cols()); 208 internal::conservative_sparse_sparse_product_impl<ColMajorMatrix,Rhs,ColMajorMatrix>(lhsCol, rhs, resCol); 209 res = resCol; 210 } 211 }; 212 213 template<typename Lhs, typename Rhs, typename ResultType> 214 struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,RowMajor> 215 { 216 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) 217 { 218 typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix; 219 ColMajorMatrix rhsCol = rhs; 220 ColMajorMatrix resCol(lhs.rows(), rhs.cols()); 221 internal::conservative_sparse_sparse_product_impl<Lhs,ColMajorMatrix,ColMajorMatrix>(lhs, rhsCol, resCol); 222 res = resCol; 223 } 224 }; 225 226 template<typename Lhs, typename Rhs, typename ResultType> 227 struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor> 228 { 229 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res) 230 { 231 typedef SparseMatrix<typename ResultType::Scalar,RowMajor> RowMajorMatrix; 232 typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix; 233 RowMajorMatrix resRow(lhs.rows(),rhs.cols()); 234 internal::conservative_sparse_sparse_product_impl<Rhs,Lhs,RowMajorMatrix>(rhs, lhs, resRow); 235 // sort the non zeros: 236 ColMajorMatrix resCol(resRow); 237 res = resCol; 238 } 239 }; 240 241 } // end namespace internal 242 243 } // end namespace Eigen 244 245 #endif // EIGEN_CONSERVATIVESPARSESPARSEPRODUCT_H 246