1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2009 Gael Guennebaud <gael.guennebaud (at) inria.fr> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_SPARSE_DIAGONAL_PRODUCT_H 11 #define EIGEN_SPARSE_DIAGONAL_PRODUCT_H 12 13 namespace Eigen { 14 15 // The product of a diagonal matrix with a sparse matrix can be easily 16 // implemented using expression template. 17 // We have two consider very different cases: 18 // 1 - diag * row-major sparse 19 // => each inner vector <=> scalar * sparse vector product 20 // => so we can reuse CwiseUnaryOp::InnerIterator 21 // 2 - diag * col-major sparse 22 // => each inner vector <=> densevector * sparse vector cwise product 23 // => again, we can reuse specialization of CwiseBinaryOp::InnerIterator 24 // for that particular case 25 // The two other cases are symmetric. 26 27 namespace internal { 28 29 template<typename Lhs, typename Rhs> 30 struct traits<SparseDiagonalProduct<Lhs, Rhs> > 31 { 32 typedef typename remove_all<Lhs>::type _Lhs; 33 typedef typename remove_all<Rhs>::type _Rhs; 34 typedef typename _Lhs::Scalar Scalar; 35 typedef typename promote_index_type<typename traits<Lhs>::Index, 36 typename traits<Rhs>::Index>::type Index; 37 typedef Sparse StorageKind; 38 typedef MatrixXpr XprKind; 39 enum { 40 RowsAtCompileTime = _Lhs::RowsAtCompileTime, 41 ColsAtCompileTime = _Rhs::ColsAtCompileTime, 42 43 MaxRowsAtCompileTime = _Lhs::MaxRowsAtCompileTime, 44 MaxColsAtCompileTime = _Rhs::MaxColsAtCompileTime, 45 46 SparseFlags = is_diagonal<_Lhs>::ret ? int(_Rhs::Flags) : int(_Lhs::Flags), 47 Flags = (SparseFlags&RowMajorBit), 48 CoeffReadCost = Dynamic 49 }; 50 }; 51 52 enum {SDP_IsDiagonal, SDP_IsSparseRowMajor, SDP_IsSparseColMajor}; 53 template<typename Lhs, typename Rhs, typename SparseDiagonalProductType, int RhsMode, int LhsMode> 54 class sparse_diagonal_product_inner_iterator_selector; 55 56 } // end namespace internal 57 58 template<typename Lhs, typename Rhs> 59 class SparseDiagonalProduct 60 : public SparseMatrixBase<SparseDiagonalProduct<Lhs,Rhs> >, 61 internal::no_assignment_operator 62 { 63 typedef typename Lhs::Nested LhsNested; 64 typedef typename Rhs::Nested RhsNested; 65 66 typedef typename internal::remove_all<LhsNested>::type _LhsNested; 67 typedef typename internal::remove_all<RhsNested>::type _RhsNested; 68 69 enum { 70 LhsMode = internal::is_diagonal<_LhsNested>::ret ? internal::SDP_IsDiagonal 71 : (_LhsNested::Flags&RowMajorBit) ? internal::SDP_IsSparseRowMajor : internal::SDP_IsSparseColMajor, 72 RhsMode = internal::is_diagonal<_RhsNested>::ret ? internal::SDP_IsDiagonal 73 : (_RhsNested::Flags&RowMajorBit) ? internal::SDP_IsSparseRowMajor : internal::SDP_IsSparseColMajor 74 }; 75 76 public: 77 78 EIGEN_SPARSE_PUBLIC_INTERFACE(SparseDiagonalProduct) 79 80 typedef internal::sparse_diagonal_product_inner_iterator_selector 81 <_LhsNested,_RhsNested,SparseDiagonalProduct,LhsMode,RhsMode> InnerIterator; 82 83 EIGEN_STRONG_INLINE SparseDiagonalProduct(const Lhs& lhs, const Rhs& rhs) 84 : m_lhs(lhs), m_rhs(rhs) 85 { 86 eigen_assert(lhs.cols() == rhs.rows() && "invalid sparse matrix * diagonal matrix product"); 87 } 88 89 EIGEN_STRONG_INLINE Index rows() const { return m_lhs.rows(); } 90 EIGEN_STRONG_INLINE Index cols() const { return m_rhs.cols(); } 91 92 EIGEN_STRONG_INLINE const _LhsNested& lhs() const { return m_lhs; } 93 EIGEN_STRONG_INLINE const _RhsNested& rhs() const { return m_rhs; } 94 95 protected: 96 LhsNested m_lhs; 97 RhsNested m_rhs; 98 }; 99 100 namespace internal { 101 102 template<typename Lhs, typename Rhs, typename SparseDiagonalProductType> 103 class sparse_diagonal_product_inner_iterator_selector 104 <Lhs,Rhs,SparseDiagonalProductType,SDP_IsDiagonal,SDP_IsSparseRowMajor> 105 : public CwiseUnaryOp<scalar_multiple_op<typename Lhs::Scalar>,const Rhs>::InnerIterator 106 { 107 typedef typename CwiseUnaryOp<scalar_multiple_op<typename Lhs::Scalar>,const Rhs>::InnerIterator Base; 108 typedef typename Lhs::Index Index; 109 public: 110 inline sparse_diagonal_product_inner_iterator_selector( 111 const SparseDiagonalProductType& expr, Index outer) 112 : Base(expr.rhs()*(expr.lhs().diagonal().coeff(outer)), outer) 113 {} 114 }; 115 116 template<typename Lhs, typename Rhs, typename SparseDiagonalProductType> 117 class sparse_diagonal_product_inner_iterator_selector 118 <Lhs,Rhs,SparseDiagonalProductType,SDP_IsDiagonal,SDP_IsSparseColMajor> 119 : public CwiseBinaryOp< 120 scalar_product_op<typename Lhs::Scalar>, 121 SparseInnerVectorSet<Rhs,1>, 122 typename Lhs::DiagonalVectorType>::InnerIterator 123 { 124 typedef typename CwiseBinaryOp< 125 scalar_product_op<typename Lhs::Scalar>, 126 SparseInnerVectorSet<Rhs,1>, 127 typename Lhs::DiagonalVectorType>::InnerIterator Base; 128 typedef typename Lhs::Index Index; 129 public: 130 inline sparse_diagonal_product_inner_iterator_selector( 131 const SparseDiagonalProductType& expr, Index outer) 132 : Base(expr.rhs().innerVector(outer) .cwiseProduct(expr.lhs().diagonal()), 0) 133 {} 134 }; 135 136 template<typename Lhs, typename Rhs, typename SparseDiagonalProductType> 137 class sparse_diagonal_product_inner_iterator_selector 138 <Lhs,Rhs,SparseDiagonalProductType,SDP_IsSparseColMajor,SDP_IsDiagonal> 139 : public CwiseUnaryOp<scalar_multiple_op<typename Rhs::Scalar>,const Lhs>::InnerIterator 140 { 141 typedef typename CwiseUnaryOp<scalar_multiple_op<typename Rhs::Scalar>,const Lhs>::InnerIterator Base; 142 typedef typename Lhs::Index Index; 143 public: 144 inline sparse_diagonal_product_inner_iterator_selector( 145 const SparseDiagonalProductType& expr, Index outer) 146 : Base(expr.lhs()*expr.rhs().diagonal().coeff(outer), outer) 147 {} 148 }; 149 150 template<typename Lhs, typename Rhs, typename SparseDiagonalProductType> 151 class sparse_diagonal_product_inner_iterator_selector 152 <Lhs,Rhs,SparseDiagonalProductType,SDP_IsSparseRowMajor,SDP_IsDiagonal> 153 : public CwiseBinaryOp< 154 scalar_product_op<typename Rhs::Scalar>, 155 SparseInnerVectorSet<Lhs,1>, 156 Transpose<const typename Rhs::DiagonalVectorType> >::InnerIterator 157 { 158 typedef typename CwiseBinaryOp< 159 scalar_product_op<typename Rhs::Scalar>, 160 SparseInnerVectorSet<Lhs,1>, 161 Transpose<const typename Rhs::DiagonalVectorType> >::InnerIterator Base; 162 typedef typename Lhs::Index Index; 163 public: 164 inline sparse_diagonal_product_inner_iterator_selector( 165 const SparseDiagonalProductType& expr, Index outer) 166 : Base(expr.lhs().innerVector(outer) .cwiseProduct(expr.rhs().diagonal().transpose()), 0) 167 {} 168 }; 169 170 } // end namespace internal 171 172 // SparseMatrixBase functions 173 174 template<typename Derived> 175 template<typename OtherDerived> 176 const SparseDiagonalProduct<Derived,OtherDerived> 177 SparseMatrixBase<Derived>::operator*(const DiagonalBase<OtherDerived> &other) const 178 { 179 return SparseDiagonalProduct<Derived,OtherDerived>(this->derived(), other.derived()); 180 } 181 182 } // end namespace Eigen 183 184 #endif // EIGEN_SPARSE_DIAGONAL_PRODUCT_H 185