Home | History | Annotate | Download | only in SparseCore
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2008-2011 Gael Guennebaud <gael.guennebaud (at) inria.fr>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_CONSERVATIVESPARSESPARSEPRODUCT_H
     11 #define EIGEN_CONSERVATIVESPARSESPARSEPRODUCT_H
     12 
     13 namespace Eigen {
     14 
     15 namespace internal {
     16 
     17 template<typename Lhs, typename Rhs, typename ResultType>
     18 static void conservative_sparse_sparse_product_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res)
     19 {
     20   typedef typename remove_all<Lhs>::type::Scalar Scalar;
     21   typedef typename remove_all<Lhs>::type::Index Index;
     22 
     23   // make sure to call innerSize/outerSize since we fake the storage order.
     24   Index rows = lhs.innerSize();
     25   Index cols = rhs.outerSize();
     26   eigen_assert(lhs.outerSize() == rhs.innerSize());
     27 
     28   std::vector<bool> mask(rows,false);
     29   Matrix<Scalar,Dynamic,1> values(rows);
     30   Matrix<Index,Dynamic,1>  indices(rows);
     31 
     32   // estimate the number of non zero entries
     33   // given a rhs column containing Y non zeros, we assume that the respective Y columns
     34   // of the lhs differs in average of one non zeros, thus the number of non zeros for
     35   // the product of a rhs column with the lhs is X+Y where X is the average number of non zero
     36   // per column of the lhs.
     37   // Therefore, we have nnz(lhs*rhs) = nnz(lhs) + nnz(rhs)
     38   Index estimated_nnz_prod = lhs.nonZeros() + rhs.nonZeros();
     39 
     40   res.setZero();
     41   res.reserve(Index(estimated_nnz_prod));
     42   // we compute each column of the result, one after the other
     43   for (Index j=0; j<cols; ++j)
     44   {
     45 
     46     res.startVec(j);
     47     Index nnz = 0;
     48     for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt)
     49     {
     50       Scalar y = rhsIt.value();
     51       Index k = rhsIt.index();
     52       for (typename Lhs::InnerIterator lhsIt(lhs, k); lhsIt; ++lhsIt)
     53       {
     54         Index i = lhsIt.index();
     55         Scalar x = lhsIt.value();
     56         if(!mask[i])
     57         {
     58           mask[i] = true;
     59           values[i] = x * y;
     60           indices[nnz] = i;
     61           ++nnz;
     62         }
     63         else
     64           values[i] += x * y;
     65       }
     66     }
     67 
     68     // unordered insertion
     69     for(int k=0; k<nnz; ++k)
     70     {
     71       int i = indices[k];
     72       res.insertBackByOuterInnerUnordered(j,i) = values[i];
     73       mask[i] = false;
     74     }
     75 
     76 #if 0
     77     // alternative ordered insertion code:
     78 
     79     int t200 = rows/(log2(200)*1.39);
     80     int t = (rows*100)/139;
     81 
     82     // FIXME reserve nnz non zeros
     83     // FIXME implement fast sort algorithms for very small nnz
     84     // if the result is sparse enough => use a quick sort
     85     // otherwise => loop through the entire vector
     86     // In order to avoid to perform an expensive log2 when the
     87     // result is clearly very sparse we use a linear bound up to 200.
     88     //if((nnz<200 && nnz<t200) || nnz * log2(nnz) < t)
     89     //res.startVec(j);
     90     if(true)
     91     {
     92       if(nnz>1) std::sort(indices.data(),indices.data()+nnz);
     93       for(int k=0; k<nnz; ++k)
     94       {
     95         int i = indices[k];
     96         res.insertBackByOuterInner(j,i) = values[i];
     97         mask[i] = false;
     98       }
     99     }
    100     else
    101     {
    102       // dense path
    103       for(int i=0; i<rows; ++i)
    104       {
    105         if(mask[i])
    106         {
    107           mask[i] = false;
    108           res.insertBackByOuterInner(j,i) = values[i];
    109         }
    110       }
    111     }
    112 #endif
    113 
    114   }
    115   res.finalize();
    116 }
    117 
    118 
    119 } // end namespace internal
    120 
    121 namespace internal {
    122 
    123 template<typename Lhs, typename Rhs, typename ResultType,
    124   int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit,
    125   int RhsStorageOrder = traits<Rhs>::Flags&RowMajorBit,
    126   int ResStorageOrder = traits<ResultType>::Flags&RowMajorBit>
    127 struct conservative_sparse_sparse_product_selector;
    128 
    129 template<typename Lhs, typename Rhs, typename ResultType>
    130 struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor>
    131 {
    132   typedef typename remove_all<Lhs>::type LhsCleaned;
    133   typedef typename LhsCleaned::Scalar Scalar;
    134 
    135   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
    136   {
    137     typedef SparseMatrix<typename ResultType::Scalar,RowMajor> RowMajorMatrix;
    138     typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix;
    139     ColMajorMatrix resCol(lhs.rows(),rhs.cols());
    140     internal::conservative_sparse_sparse_product_impl<Lhs,Rhs,ColMajorMatrix>(lhs, rhs, resCol);
    141     // sort the non zeros:
    142     RowMajorMatrix resRow(resCol);
    143     res = resRow;
    144   }
    145 };
    146 
    147 template<typename Lhs, typename Rhs, typename ResultType>
    148 struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,ColMajor>
    149 {
    150   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
    151   {
    152      typedef SparseMatrix<typename ResultType::Scalar,RowMajor> RowMajorMatrix;
    153      RowMajorMatrix rhsRow = rhs;
    154      RowMajorMatrix resRow(lhs.rows(), rhs.cols());
    155      internal::conservative_sparse_sparse_product_impl<RowMajorMatrix,Lhs,RowMajorMatrix>(rhsRow, lhs, resRow);
    156      res = resRow;
    157   }
    158 };
    159 
    160 template<typename Lhs, typename Rhs, typename ResultType>
    161 struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,ColMajor>
    162 {
    163   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
    164   {
    165     typedef SparseMatrix<typename ResultType::Scalar,RowMajor> RowMajorMatrix;
    166     RowMajorMatrix lhsRow = lhs;
    167     RowMajorMatrix resRow(lhs.rows(), rhs.cols());
    168     internal::conservative_sparse_sparse_product_impl<Rhs,RowMajorMatrix,RowMajorMatrix>(rhs, lhsRow, resRow);
    169     res = resRow;
    170   }
    171 };
    172 
    173 template<typename Lhs, typename Rhs, typename ResultType>
    174 struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor>
    175 {
    176   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
    177   {
    178     typedef SparseMatrix<typename ResultType::Scalar,RowMajor> RowMajorMatrix;
    179     RowMajorMatrix resRow(lhs.rows(), rhs.cols());
    180     internal::conservative_sparse_sparse_product_impl<Rhs,Lhs,RowMajorMatrix>(rhs, lhs, resRow);
    181     res = resRow;
    182   }
    183 };
    184 
    185 
    186 template<typename Lhs, typename Rhs, typename ResultType>
    187 struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor>
    188 {
    189   typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar;
    190 
    191   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
    192   {
    193     typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix;
    194     ColMajorMatrix resCol(lhs.rows(), rhs.cols());
    195     internal::conservative_sparse_sparse_product_impl<Lhs,Rhs,ColMajorMatrix>(lhs, rhs, resCol);
    196     res = resCol;
    197   }
    198 };
    199 
    200 template<typename Lhs, typename Rhs, typename ResultType>
    201 struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,RowMajor>
    202 {
    203   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
    204   {
    205     typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix;
    206     ColMajorMatrix lhsCol = lhs;
    207     ColMajorMatrix resCol(lhs.rows(), rhs.cols());
    208     internal::conservative_sparse_sparse_product_impl<ColMajorMatrix,Rhs,ColMajorMatrix>(lhsCol, rhs, resCol);
    209     res = resCol;
    210   }
    211 };
    212 
    213 template<typename Lhs, typename Rhs, typename ResultType>
    214 struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,RowMajor>
    215 {
    216   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
    217   {
    218     typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix;
    219     ColMajorMatrix rhsCol = rhs;
    220     ColMajorMatrix resCol(lhs.rows(), rhs.cols());
    221     internal::conservative_sparse_sparse_product_impl<Lhs,ColMajorMatrix,ColMajorMatrix>(lhs, rhsCol, resCol);
    222     res = resCol;
    223   }
    224 };
    225 
    226 template<typename Lhs, typename Rhs, typename ResultType>
    227 struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor>
    228 {
    229   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
    230   {
    231     typedef SparseMatrix<typename ResultType::Scalar,RowMajor> RowMajorMatrix;
    232     typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix;
    233     RowMajorMatrix resRow(lhs.rows(),rhs.cols());
    234     internal::conservative_sparse_sparse_product_impl<Rhs,Lhs,RowMajorMatrix>(rhs, lhs, resRow);
    235     // sort the non zeros:
    236     ColMajorMatrix resCol(resRow);
    237     res = resCol;
    238   }
    239 };
    240 
    241 } // end namespace internal
    242 
    243 } // end namespace Eigen
    244 
    245 #endif // EIGEN_CONSERVATIVESPARSESPARSEPRODUCT_H
    246