Home | History | Annotate | Download | only in SparseCore
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2008-2015 Gael Guennebaud <gael.guennebaud (at) inria.fr>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_SPARSEPRODUCT_H
     11 #define EIGEN_SPARSEPRODUCT_H
     12 
     13 namespace Eigen {
     14 
     15 /** \returns an expression of the product of two sparse matrices.
     16   * By default a conservative product preserving the symbolic non zeros is performed.
     17   * The automatic pruning of the small values can be achieved by calling the pruned() function
     18   * in which case a totally different product algorithm is employed:
     19   * \code
     20   * C = (A*B).pruned();             // supress numerical zeros (exact)
     21   * C = (A*B).pruned(ref);
     22   * C = (A*B).pruned(ref,epsilon);
     23   * \endcode
     24   * where \c ref is a meaningful non zero reference value.
     25   * */
     26 template<typename Derived>
     27 template<typename OtherDerived>
     28 inline const Product<Derived,OtherDerived,AliasFreeProduct>
     29 SparseMatrixBase<Derived>::operator*(const SparseMatrixBase<OtherDerived> &other) const
     30 {
     31   return Product<Derived,OtherDerived,AliasFreeProduct>(derived(), other.derived());
     32 }
     33 
     34 namespace internal {
     35 
     36 // sparse * sparse
     37 template<typename Lhs, typename Rhs, int ProductType>
     38 struct generic_product_impl<Lhs, Rhs, SparseShape, SparseShape, ProductType>
     39 {
     40   template<typename Dest>
     41   static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs)
     42   {
     43     evalTo(dst, lhs, rhs, typename evaluator_traits<Dest>::Shape());
     44   }
     45 
     46   // dense += sparse * sparse
     47   template<typename Dest,typename ActualLhs>
     48   static void addTo(Dest& dst, const ActualLhs& lhs, const Rhs& rhs, typename enable_if<is_same<typename evaluator_traits<Dest>::Shape,DenseShape>::value,int*>::type* = 0)
     49   {
     50     typedef typename nested_eval<ActualLhs,Dynamic>::type LhsNested;
     51     typedef typename nested_eval<Rhs,Dynamic>::type RhsNested;
     52     LhsNested lhsNested(lhs);
     53     RhsNested rhsNested(rhs);
     54     internal::sparse_sparse_to_dense_product_selector<typename remove_all<LhsNested>::type,
     55                                                       typename remove_all<RhsNested>::type, Dest>::run(lhsNested,rhsNested,dst);
     56   }
     57 
     58   // dense -= sparse * sparse
     59   template<typename Dest>
     60   static void subTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, typename enable_if<is_same<typename evaluator_traits<Dest>::Shape,DenseShape>::value,int*>::type* = 0)
     61   {
     62     addTo(dst, -lhs, rhs);
     63   }
     64 
     65 protected:
     66 
     67   // sparse = sparse * sparse
     68   template<typename Dest>
     69   static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, SparseShape)
     70   {
     71     typedef typename nested_eval<Lhs,Dynamic>::type LhsNested;
     72     typedef typename nested_eval<Rhs,Dynamic>::type RhsNested;
     73     LhsNested lhsNested(lhs);
     74     RhsNested rhsNested(rhs);
     75     internal::conservative_sparse_sparse_product_selector<typename remove_all<LhsNested>::type,
     76                                                           typename remove_all<RhsNested>::type, Dest>::run(lhsNested,rhsNested,dst);
     77   }
     78 
     79   // dense = sparse * sparse
     80   template<typename Dest>
     81   static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, DenseShape)
     82   {
     83     dst.setZero();
     84     addTo(dst, lhs, rhs);
     85   }
     86 };
     87 
     88 // sparse * sparse-triangular
     89 template<typename Lhs, typename Rhs, int ProductType>
     90 struct generic_product_impl<Lhs, Rhs, SparseShape, SparseTriangularShape, ProductType>
     91  : public generic_product_impl<Lhs, Rhs, SparseShape, SparseShape, ProductType>
     92 {};
     93 
     94 // sparse-triangular * sparse
     95 template<typename Lhs, typename Rhs, int ProductType>
     96 struct generic_product_impl<Lhs, Rhs, SparseTriangularShape, SparseShape, ProductType>
     97  : public generic_product_impl<Lhs, Rhs, SparseShape, SparseShape, ProductType>
     98 {};
     99 
    100 // dense = sparse-product (can be sparse*sparse, sparse*perm, etc.)
    101 template< typename DstXprType, typename Lhs, typename Rhs>
    102 struct Assignment<DstXprType, Product<Lhs,Rhs,AliasFreeProduct>, internal::assign_op<typename DstXprType::Scalar,typename Product<Lhs,Rhs,AliasFreeProduct>::Scalar>, Sparse2Dense>
    103 {
    104   typedef Product<Lhs,Rhs,AliasFreeProduct> SrcXprType;
    105   static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<typename DstXprType::Scalar,typename SrcXprType::Scalar> &)
    106   {
    107     Index dstRows = src.rows();
    108     Index dstCols = src.cols();
    109     if((dst.rows()!=dstRows) || (dst.cols()!=dstCols))
    110       dst.resize(dstRows, dstCols);
    111 
    112     generic_product_impl<Lhs, Rhs>::evalTo(dst,src.lhs(),src.rhs());
    113   }
    114 };
    115 
    116 // dense += sparse-product (can be sparse*sparse, sparse*perm, etc.)
    117 template< typename DstXprType, typename Lhs, typename Rhs>
    118 struct Assignment<DstXprType, Product<Lhs,Rhs,AliasFreeProduct>, internal::add_assign_op<typename DstXprType::Scalar,typename Product<Lhs,Rhs,AliasFreeProduct>::Scalar>, Sparse2Dense>
    119 {
    120   typedef Product<Lhs,Rhs,AliasFreeProduct> SrcXprType;
    121   static void run(DstXprType &dst, const SrcXprType &src, const internal::add_assign_op<typename DstXprType::Scalar,typename SrcXprType::Scalar> &)
    122   {
    123     generic_product_impl<Lhs, Rhs>::addTo(dst,src.lhs(),src.rhs());
    124   }
    125 };
    126 
    127 // dense -= sparse-product (can be sparse*sparse, sparse*perm, etc.)
    128 template< typename DstXprType, typename Lhs, typename Rhs>
    129 struct Assignment<DstXprType, Product<Lhs,Rhs,AliasFreeProduct>, internal::sub_assign_op<typename DstXprType::Scalar,typename Product<Lhs,Rhs,AliasFreeProduct>::Scalar>, Sparse2Dense>
    130 {
    131   typedef Product<Lhs,Rhs,AliasFreeProduct> SrcXprType;
    132   static void run(DstXprType &dst, const SrcXprType &src, const internal::sub_assign_op<typename DstXprType::Scalar,typename SrcXprType::Scalar> &)
    133   {
    134     generic_product_impl<Lhs, Rhs>::subTo(dst,src.lhs(),src.rhs());
    135   }
    136 };
    137 
    138 template<typename Lhs, typename Rhs, int Options>
    139 struct unary_evaluator<SparseView<Product<Lhs, Rhs, Options> >, IteratorBased>
    140  : public evaluator<typename Product<Lhs, Rhs, DefaultProduct>::PlainObject>
    141 {
    142   typedef SparseView<Product<Lhs, Rhs, Options> > XprType;
    143   typedef typename XprType::PlainObject PlainObject;
    144   typedef evaluator<PlainObject> Base;
    145 
    146   explicit unary_evaluator(const XprType& xpr)
    147     : m_result(xpr.rows(), xpr.cols())
    148   {
    149     using std::abs;
    150     ::new (static_cast<Base*>(this)) Base(m_result);
    151     typedef typename nested_eval<Lhs,Dynamic>::type LhsNested;
    152     typedef typename nested_eval<Rhs,Dynamic>::type RhsNested;
    153     LhsNested lhsNested(xpr.nestedExpression().lhs());
    154     RhsNested rhsNested(xpr.nestedExpression().rhs());
    155 
    156     internal::sparse_sparse_product_with_pruning_selector<typename remove_all<LhsNested>::type,
    157                                                           typename remove_all<RhsNested>::type, PlainObject>::run(lhsNested,rhsNested,m_result,
    158                                                                                                                   abs(xpr.reference())*xpr.epsilon());
    159   }
    160 
    161 protected:
    162   PlainObject m_result;
    163 };
    164 
    165 } // end namespace internal
    166 
    167 } // end namespace Eigen
    168 
    169 #endif // EIGEN_SPARSEPRODUCT_H
    170