Home | History | Annotate | Download | only in SparseCore
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2009 Gael Guennebaud <gael.guennebaud (at) inria.fr>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_SPARSE_DIAGONAL_PRODUCT_H
     11 #define EIGEN_SPARSE_DIAGONAL_PRODUCT_H
     12 
     13 namespace Eigen {
     14 
// The product of a diagonal matrix with a sparse matrix can be easily
// implemented using expression templates.
// We have to consider two very different cases:
// 1 - diag * row-major sparse
//     => each inner vector <=> scalar * sparse vector product
//     => so we can reuse CwiseUnaryOp::InnerIterator
// 2 - diag * col-major sparse
//     => each inner vector <=> dense vector * sparse vector cwise product
//     => again, we can reuse the specialization of CwiseBinaryOp::InnerIterator
//        for that particular case
// The two other cases (col-major sparse * diag and row-major sparse * diag)
// are symmetric.
     26 
     27 namespace internal {
     28 
     29 template<typename Lhs, typename Rhs>
     30 struct traits<SparseDiagonalProduct<Lhs, Rhs> >
     31 {
     32   typedef typename remove_all<Lhs>::type _Lhs;
     33   typedef typename remove_all<Rhs>::type _Rhs;
     34   typedef typename _Lhs::Scalar Scalar;
     35   typedef typename promote_index_type<typename traits<Lhs>::Index,
     36                                          typename traits<Rhs>::Index>::type Index;
     37   typedef Sparse StorageKind;
     38   typedef MatrixXpr XprKind;
     39   enum {
     40     RowsAtCompileTime = _Lhs::RowsAtCompileTime,
     41     ColsAtCompileTime = _Rhs::ColsAtCompileTime,
     42 
     43     MaxRowsAtCompileTime = _Lhs::MaxRowsAtCompileTime,
     44     MaxColsAtCompileTime = _Rhs::MaxColsAtCompileTime,
     45 
     46     SparseFlags = is_diagonal<_Lhs>::ret ? int(_Rhs::Flags) : int(_Lhs::Flags),
     47     Flags = (SparseFlags&RowMajorBit),
     48     CoeffReadCost = Dynamic
     49   };
     50 };
     51 
     52 enum {SDP_IsDiagonal, SDP_IsSparseRowMajor, SDP_IsSparseColMajor};
     53 template<typename Lhs, typename Rhs, typename SparseDiagonalProductType, int RhsMode, int LhsMode>
     54 class sparse_diagonal_product_inner_iterator_selector;
     55 
     56 } // end namespace internal
     57 
     58 template<typename Lhs, typename Rhs>
     59 class SparseDiagonalProduct
     60   : public SparseMatrixBase<SparseDiagonalProduct<Lhs,Rhs> >,
     61     internal::no_assignment_operator
     62 {
     63     typedef typename Lhs::Nested LhsNested;
     64     typedef typename Rhs::Nested RhsNested;
     65 
     66     typedef typename internal::remove_all<LhsNested>::type _LhsNested;
     67     typedef typename internal::remove_all<RhsNested>::type _RhsNested;
     68 
     69     enum {
     70       LhsMode = internal::is_diagonal<_LhsNested>::ret ? internal::SDP_IsDiagonal
     71               : (_LhsNested::Flags&RowMajorBit) ? internal::SDP_IsSparseRowMajor : internal::SDP_IsSparseColMajor,
     72       RhsMode = internal::is_diagonal<_RhsNested>::ret ? internal::SDP_IsDiagonal
     73               : (_RhsNested::Flags&RowMajorBit) ? internal::SDP_IsSparseRowMajor : internal::SDP_IsSparseColMajor
     74     };
     75 
     76   public:
     77 
     78     EIGEN_SPARSE_PUBLIC_INTERFACE(SparseDiagonalProduct)
     79 
     80     typedef internal::sparse_diagonal_product_inner_iterator_selector
     81                       <_LhsNested,_RhsNested,SparseDiagonalProduct,LhsMode,RhsMode> InnerIterator;
     82 
     83     // We do not want ReverseInnerIterator for diagonal-sparse products,
     84     // but this dummy declaration is needed to make diag * sparse * diag compile.
     85     class ReverseInnerIterator;
     86 
     87     EIGEN_STRONG_INLINE SparseDiagonalProduct(const Lhs& lhs, const Rhs& rhs)
     88       : m_lhs(lhs), m_rhs(rhs)
     89     {
     90       eigen_assert(lhs.cols() == rhs.rows() && "invalid sparse matrix * diagonal matrix product");
     91     }
     92 
     93     EIGEN_STRONG_INLINE Index rows() const { return m_lhs.rows(); }
     94     EIGEN_STRONG_INLINE Index cols() const { return m_rhs.cols(); }
     95 
     96     EIGEN_STRONG_INLINE const _LhsNested& lhs() const { return m_lhs; }
     97     EIGEN_STRONG_INLINE const _RhsNested& rhs() const { return m_rhs; }
     98 
     99   protected:
    100     LhsNested m_lhs;
    101     RhsNested m_rhs;
    102 };
    103 
    104 namespace internal {
    105 
    106 template<typename Lhs, typename Rhs, typename SparseDiagonalProductType>
    107 class sparse_diagonal_product_inner_iterator_selector
    108 <Lhs,Rhs,SparseDiagonalProductType,SDP_IsDiagonal,SDP_IsSparseRowMajor>
    109   : public CwiseUnaryOp<scalar_multiple_op<typename Lhs::Scalar>,const Rhs>::InnerIterator
    110 {
    111     typedef typename CwiseUnaryOp<scalar_multiple_op<typename Lhs::Scalar>,const Rhs>::InnerIterator Base;
    112     typedef typename Lhs::Index Index;
    113   public:
    114     inline sparse_diagonal_product_inner_iterator_selector(
    115               const SparseDiagonalProductType& expr, Index outer)
    116       : Base(expr.rhs()*(expr.lhs().diagonal().coeff(outer)), outer)
    117     {}
    118 };
    119 
    120 template<typename Lhs, typename Rhs, typename SparseDiagonalProductType>
    121 class sparse_diagonal_product_inner_iterator_selector
    122 <Lhs,Rhs,SparseDiagonalProductType,SDP_IsDiagonal,SDP_IsSparseColMajor>
    123   : public CwiseBinaryOp<
    124       scalar_product_op<typename Lhs::Scalar>,
    125       const typename Rhs::ConstInnerVectorReturnType,
    126       const typename Lhs::DiagonalVectorType>::InnerIterator
    127 {
    128     typedef typename CwiseBinaryOp<
    129       scalar_product_op<typename Lhs::Scalar>,
    130       const typename Rhs::ConstInnerVectorReturnType,
    131       const typename Lhs::DiagonalVectorType>::InnerIterator Base;
    132     typedef typename Lhs::Index Index;
    133     Index m_outer;
    134   public:
    135     inline sparse_diagonal_product_inner_iterator_selector(
    136               const SparseDiagonalProductType& expr, Index outer)
    137       : Base(expr.rhs().innerVector(outer) .cwiseProduct(expr.lhs().diagonal()), 0), m_outer(outer)
    138     {}
    139 
    140     inline Index outer() const { return m_outer; }
    141     inline Index col() const { return m_outer; }
    142 };
    143 
    144 template<typename Lhs, typename Rhs, typename SparseDiagonalProductType>
    145 class sparse_diagonal_product_inner_iterator_selector
    146 <Lhs,Rhs,SparseDiagonalProductType,SDP_IsSparseColMajor,SDP_IsDiagonal>
    147   : public CwiseUnaryOp<scalar_multiple_op<typename Rhs::Scalar>,const Lhs>::InnerIterator
    148 {
    149     typedef typename CwiseUnaryOp<scalar_multiple_op<typename Rhs::Scalar>,const Lhs>::InnerIterator Base;
    150     typedef typename Lhs::Index Index;
    151   public:
    152     inline sparse_diagonal_product_inner_iterator_selector(
    153               const SparseDiagonalProductType& expr, Index outer)
    154       : Base(expr.lhs()*expr.rhs().diagonal().coeff(outer), outer)
    155     {}
    156 };
    157 
    158 template<typename Lhs, typename Rhs, typename SparseDiagonalProductType>
    159 class sparse_diagonal_product_inner_iterator_selector
    160 <Lhs,Rhs,SparseDiagonalProductType,SDP_IsSparseRowMajor,SDP_IsDiagonal>
    161   : public CwiseBinaryOp<
    162       scalar_product_op<typename Rhs::Scalar>,
    163       const typename Lhs::ConstInnerVectorReturnType,
    164       const Transpose<const typename Rhs::DiagonalVectorType> >::InnerIterator
    165 {
    166     typedef typename CwiseBinaryOp<
    167       scalar_product_op<typename Rhs::Scalar>,
    168       const typename Lhs::ConstInnerVectorReturnType,
    169       const Transpose<const typename Rhs::DiagonalVectorType> >::InnerIterator Base;
    170     typedef typename Lhs::Index Index;
    171     Index m_outer;
    172   public:
    173     inline sparse_diagonal_product_inner_iterator_selector(
    174               const SparseDiagonalProductType& expr, Index outer)
    175       : Base(expr.lhs().innerVector(outer) .cwiseProduct(expr.rhs().diagonal().transpose()), 0), m_outer(outer)
    176     {}
    177 
    178     inline Index outer() const { return m_outer; }
    179     inline Index row() const { return m_outer; }
    180 };
    181 
    182 } // end namespace internal
    183 
    184 // SparseMatrixBase functions
    185 
    186 template<typename Derived>
    187 template<typename OtherDerived>
    188 const SparseDiagonalProduct<Derived,OtherDerived>
    189 SparseMatrixBase<Derived>::operator*(const DiagonalBase<OtherDerived> &other) const
    190 {
    191   return SparseDiagonalProduct<Derived,OtherDerived>(this->derived(), other.derived());
    192 }
    193 
    194 } // end namespace Eigen
    195 
    196 #endif // EIGEN_SPARSE_DIAGONAL_PRODUCT_H
    197