Home | History | Annotate | Download | only in SparseCore
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2009 Gael Guennebaud <gael.guennebaud (at) inria.fr>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_SPARSE_DIAGONAL_PRODUCT_H
     11 #define EIGEN_SPARSE_DIAGONAL_PRODUCT_H
     12 
     13 namespace Eigen {
     14 
// The product of a diagonal matrix with a sparse matrix can be easily
// implemented using expression templates.
// We have to consider two very different cases:
// 1 - diag * row-major sparse
//     => each inner vector <=> scalar * sparse vector product
//     => so we can reuse CwiseUnaryOp::InnerIterator
// 2 - diag * col-major sparse
//     => each inner vector <=> dense vector * sparse vector cwise product
//     => again, we can reuse the specialization of CwiseBinaryOp::InnerIterator
//        for that particular case
// The two other cases (sparse * diag) are symmetric.
     26 
     27 namespace internal {
     28 
     29 template<typename Lhs, typename Rhs>
     30 struct traits<SparseDiagonalProduct<Lhs, Rhs> >
     31 {
     32   typedef typename remove_all<Lhs>::type _Lhs;
     33   typedef typename remove_all<Rhs>::type _Rhs;
     34   typedef typename _Lhs::Scalar Scalar;
     35   typedef typename promote_index_type<typename traits<Lhs>::Index,
     36                                          typename traits<Rhs>::Index>::type Index;
     37   typedef Sparse StorageKind;
     38   typedef MatrixXpr XprKind;
     39   enum {
     40     RowsAtCompileTime = _Lhs::RowsAtCompileTime,
     41     ColsAtCompileTime = _Rhs::ColsAtCompileTime,
     42 
     43     MaxRowsAtCompileTime = _Lhs::MaxRowsAtCompileTime,
     44     MaxColsAtCompileTime = _Rhs::MaxColsAtCompileTime,
     45 
     46     SparseFlags = is_diagonal<_Lhs>::ret ? int(_Rhs::Flags) : int(_Lhs::Flags),
     47     Flags = (SparseFlags&RowMajorBit),
     48     CoeffReadCost = Dynamic
     49   };
     50 };
     51 
     52 enum {SDP_IsDiagonal, SDP_IsSparseRowMajor, SDP_IsSparseColMajor};
     53 template<typename Lhs, typename Rhs, typename SparseDiagonalProductType, int RhsMode, int LhsMode>
     54 class sparse_diagonal_product_inner_iterator_selector;
     55 
     56 } // end namespace internal
     57 
     58 template<typename Lhs, typename Rhs>
     59 class SparseDiagonalProduct
     60   : public SparseMatrixBase<SparseDiagonalProduct<Lhs,Rhs> >,
     61     internal::no_assignment_operator
     62 {
     63     typedef typename Lhs::Nested LhsNested;
     64     typedef typename Rhs::Nested RhsNested;
     65 
     66     typedef typename internal::remove_all<LhsNested>::type _LhsNested;
     67     typedef typename internal::remove_all<RhsNested>::type _RhsNested;
     68 
     69     enum {
     70       LhsMode = internal::is_diagonal<_LhsNested>::ret ? internal::SDP_IsDiagonal
     71               : (_LhsNested::Flags&RowMajorBit) ? internal::SDP_IsSparseRowMajor : internal::SDP_IsSparseColMajor,
     72       RhsMode = internal::is_diagonal<_RhsNested>::ret ? internal::SDP_IsDiagonal
     73               : (_RhsNested::Flags&RowMajorBit) ? internal::SDP_IsSparseRowMajor : internal::SDP_IsSparseColMajor
     74     };
     75 
     76   public:
     77 
     78     EIGEN_SPARSE_PUBLIC_INTERFACE(SparseDiagonalProduct)
     79 
     80     typedef internal::sparse_diagonal_product_inner_iterator_selector
     81                 <_LhsNested,_RhsNested,SparseDiagonalProduct,LhsMode,RhsMode> InnerIterator;
     82 
     83     EIGEN_STRONG_INLINE SparseDiagonalProduct(const Lhs& lhs, const Rhs& rhs)
     84       : m_lhs(lhs), m_rhs(rhs)
     85     {
     86       eigen_assert(lhs.cols() == rhs.rows() && "invalid sparse matrix * diagonal matrix product");
     87     }
     88 
     89     EIGEN_STRONG_INLINE Index rows() const { return m_lhs.rows(); }
     90     EIGEN_STRONG_INLINE Index cols() const { return m_rhs.cols(); }
     91 
     92     EIGEN_STRONG_INLINE const _LhsNested& lhs() const { return m_lhs; }
     93     EIGEN_STRONG_INLINE const _RhsNested& rhs() const { return m_rhs; }
     94 
     95   protected:
     96     LhsNested m_lhs;
     97     RhsNested m_rhs;
     98 };
     99 
    100 namespace internal {
    101 
    102 template<typename Lhs, typename Rhs, typename SparseDiagonalProductType>
    103 class sparse_diagonal_product_inner_iterator_selector
    104 <Lhs,Rhs,SparseDiagonalProductType,SDP_IsDiagonal,SDP_IsSparseRowMajor>
    105   : public CwiseUnaryOp<scalar_multiple_op<typename Lhs::Scalar>,const Rhs>::InnerIterator
    106 {
    107     typedef typename CwiseUnaryOp<scalar_multiple_op<typename Lhs::Scalar>,const Rhs>::InnerIterator Base;
    108     typedef typename Lhs::Index Index;
    109   public:
    110     inline sparse_diagonal_product_inner_iterator_selector(
    111               const SparseDiagonalProductType& expr, Index outer)
    112       : Base(expr.rhs()*(expr.lhs().diagonal().coeff(outer)), outer)
    113     {}
    114 };
    115 
    116 template<typename Lhs, typename Rhs, typename SparseDiagonalProductType>
    117 class sparse_diagonal_product_inner_iterator_selector
    118 <Lhs,Rhs,SparseDiagonalProductType,SDP_IsDiagonal,SDP_IsSparseColMajor>
    119   : public CwiseBinaryOp<
    120       scalar_product_op<typename Lhs::Scalar>,
    121       SparseInnerVectorSet<Rhs,1>,
    122       typename Lhs::DiagonalVectorType>::InnerIterator
    123 {
    124     typedef typename CwiseBinaryOp<
    125       scalar_product_op<typename Lhs::Scalar>,
    126       SparseInnerVectorSet<Rhs,1>,
    127       typename Lhs::DiagonalVectorType>::InnerIterator Base;
    128     typedef typename Lhs::Index Index;
    129   public:
    130     inline sparse_diagonal_product_inner_iterator_selector(
    131               const SparseDiagonalProductType& expr, Index outer)
    132       : Base(expr.rhs().innerVector(outer) .cwiseProduct(expr.lhs().diagonal()), 0)
    133     {}
    134 };
    135 
    136 template<typename Lhs, typename Rhs, typename SparseDiagonalProductType>
    137 class sparse_diagonal_product_inner_iterator_selector
    138 <Lhs,Rhs,SparseDiagonalProductType,SDP_IsSparseColMajor,SDP_IsDiagonal>
    139   : public CwiseUnaryOp<scalar_multiple_op<typename Rhs::Scalar>,const Lhs>::InnerIterator
    140 {
    141     typedef typename CwiseUnaryOp<scalar_multiple_op<typename Rhs::Scalar>,const Lhs>::InnerIterator Base;
    142     typedef typename Lhs::Index Index;
    143   public:
    144     inline sparse_diagonal_product_inner_iterator_selector(
    145               const SparseDiagonalProductType& expr, Index outer)
    146       : Base(expr.lhs()*expr.rhs().diagonal().coeff(outer), outer)
    147     {}
    148 };
    149 
    150 template<typename Lhs, typename Rhs, typename SparseDiagonalProductType>
    151 class sparse_diagonal_product_inner_iterator_selector
    152 <Lhs,Rhs,SparseDiagonalProductType,SDP_IsSparseRowMajor,SDP_IsDiagonal>
    153   : public CwiseBinaryOp<
    154       scalar_product_op<typename Rhs::Scalar>,
    155       SparseInnerVectorSet<Lhs,1>,
    156       Transpose<const typename Rhs::DiagonalVectorType> >::InnerIterator
    157 {
    158     typedef typename CwiseBinaryOp<
    159       scalar_product_op<typename Rhs::Scalar>,
    160       SparseInnerVectorSet<Lhs,1>,
    161       Transpose<const typename Rhs::DiagonalVectorType> >::InnerIterator Base;
    162     typedef typename Lhs::Index Index;
    163   public:
    164     inline sparse_diagonal_product_inner_iterator_selector(
    165               const SparseDiagonalProductType& expr, Index outer)
    166       : Base(expr.lhs().innerVector(outer) .cwiseProduct(expr.rhs().diagonal().transpose()), 0)
    167     {}
    168 };
    169 
    170 } // end namespace internal
    171 
    172 // SparseMatrixBase functions
    173 
    174 template<typename Derived>
    175 template<typename OtherDerived>
    176 const SparseDiagonalProduct<Derived,OtherDerived>
    177 SparseMatrixBase<Derived>::operator*(const DiagonalBase<OtherDerived> &other) const
    178 {
    179   return SparseDiagonalProduct<Derived,OtherDerived>(this->derived(), other.derived());
    180 }
    181 
    182 } // end namespace Eigen
    183 
    184 #endif // EIGEN_SPARSE_DIAGONAL_PRODUCT_H
    185