1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2008 Gael Guennebaud <gael.guennebaud (at) inria.fr> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_SPARSE_CWISE_BINARY_OP_H 11 #define EIGEN_SPARSE_CWISE_BINARY_OP_H 12 13 namespace Eigen { 14 15 // Here we have to handle 3 cases: 16 // 1 - sparse op dense 17 // 2 - dense op sparse 18 // 3 - sparse op sparse 19 // We also need to implement a 4th iterator for: 20 // 4 - dense op dense 21 // Finally, we also need to distinguish between the product and other operations : 22 // configuration returned mode 23 // 1 - sparse op dense product sparse 24 // generic dense 25 // 2 - dense op sparse product sparse 26 // generic dense 27 // 3 - sparse op sparse product sparse 28 // generic sparse 29 // 4 - dense op dense product dense 30 // generic dense 31 32 namespace internal { 33 34 template<> struct promote_storage_type<Dense,Sparse> 35 { typedef Sparse ret; }; 36 37 template<> struct promote_storage_type<Sparse,Dense> 38 { typedef Sparse ret; }; 39 40 template<typename BinaryOp, typename Lhs, typename Rhs, typename Derived, 41 typename _LhsStorageMode = typename traits<Lhs>::StorageKind, 42 typename _RhsStorageMode = typename traits<Rhs>::StorageKind> 43 class sparse_cwise_binary_op_inner_iterator_selector; 44 45 } // end namespace internal 46 47 template<typename BinaryOp, typename Lhs, typename Rhs> 48 class CwiseBinaryOpImpl<BinaryOp, Lhs, Rhs, Sparse> 49 : public SparseMatrixBase<CwiseBinaryOp<BinaryOp, Lhs, Rhs> > 50 { 51 public: 52 class InnerIterator; 53 class ReverseInnerIterator; 54 typedef CwiseBinaryOp<BinaryOp, Lhs, Rhs> Derived; 55 EIGEN_SPARSE_PUBLIC_INTERFACE(Derived) 56 CwiseBinaryOpImpl() 57 { 58 typedef typename internal::traits<Lhs>::StorageKind LhsStorageKind; 59 typedef typename internal::traits<Rhs>::StorageKind RhsStorageKind; 60 EIGEN_STATIC_ASSERT(( 61 (!internal::is_same<LhsStorageKind,RhsStorageKind>::value) 62 || ((Lhs::Flags&RowMajorBit) == (Rhs::Flags&RowMajorBit))), 63 THE_STORAGE_ORDER_OF_BOTH_SIDES_MUST_MATCH); 64 } 65 }; 66 67 template<typename BinaryOp, typename Lhs, typename Rhs> 68 class CwiseBinaryOpImpl<BinaryOp,Lhs,Rhs,Sparse>::InnerIterator 69 : public internal::sparse_cwise_binary_op_inner_iterator_selector<BinaryOp,Lhs,Rhs,typename CwiseBinaryOpImpl<BinaryOp,Lhs,Rhs,Sparse>::InnerIterator> 70 { 71 public: 72 typedef typename Lhs::Index Index; 73 typedef internal::sparse_cwise_binary_op_inner_iterator_selector< 74 BinaryOp,Lhs,Rhs, InnerIterator> Base; 75 76 EIGEN_STRONG_INLINE InnerIterator(const CwiseBinaryOpImpl& binOp, typename CwiseBinaryOpImpl::Index outer) 77 : Base(binOp.derived(),outer) 78 {} 79 }; 80 81 /*************************************************************************** 82 * Implementation of inner-iterators 83 ***************************************************************************/ 84 85 // template<typename T> struct internal::func_is_conjunction { enum { ret = false }; }; 86 // template<typename T> struct internal::func_is_conjunction<internal::scalar_product_op<T> > { enum { ret = true }; }; 87 88 // TODO generalize the internal::scalar_product_op specialization to all conjunctions if any ! 89 90 namespace internal { 91 92 // sparse - sparse (generic) 93 template<typename BinaryOp, typename Lhs, typename Rhs, typename Derived> 94 class sparse_cwise_binary_op_inner_iterator_selector<BinaryOp, Lhs, Rhs, Derived, Sparse, Sparse> 95 { 96 typedef CwiseBinaryOp<BinaryOp, Lhs, Rhs> CwiseBinaryXpr; 97 typedef typename traits<CwiseBinaryXpr>::Scalar Scalar; 98 typedef typename traits<CwiseBinaryXpr>::_LhsNested _LhsNested; 99 typedef typename traits<CwiseBinaryXpr>::_RhsNested _RhsNested; 100 typedef typename _LhsNested::InnerIterator LhsIterator; 101 typedef typename _RhsNested::InnerIterator RhsIterator; 102 typedef typename Lhs::Index Index; 103 104 public: 105 106 EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer) 107 : m_lhsIter(xpr.lhs(),outer), m_rhsIter(xpr.rhs(),outer), m_functor(xpr.functor()) 108 { 109 this->operator++(); 110 } 111 112 EIGEN_STRONG_INLINE Derived& operator++() 113 { 114 if (m_lhsIter && m_rhsIter && (m_lhsIter.index() == m_rhsIter.index())) 115 { 116 m_id = m_lhsIter.index(); 117 m_value = m_functor(m_lhsIter.value(), m_rhsIter.value()); 118 ++m_lhsIter; 119 ++m_rhsIter; 120 } 121 else if (m_lhsIter && (!m_rhsIter || (m_lhsIter.index() < m_rhsIter.index()))) 122 { 123 m_id = m_lhsIter.index(); 124 m_value = m_functor(m_lhsIter.value(), Scalar(0)); 125 ++m_lhsIter; 126 } 127 else if (m_rhsIter && (!m_lhsIter || (m_lhsIter.index() > m_rhsIter.index()))) 128 { 129 m_id = m_rhsIter.index(); 130 m_value = m_functor(Scalar(0), m_rhsIter.value()); 131 ++m_rhsIter; 132 } 133 else 134 { 135 m_value = 0; // this is to avoid a compilation warning 136 m_id = -1; 137 } 138 return *static_cast<Derived*>(this); 139 } 140 141 EIGEN_STRONG_INLINE Scalar value() const { return m_value; } 142 143 EIGEN_STRONG_INLINE Index index() const { return m_id; } 144 EIGEN_STRONG_INLINE Index row() const { return Lhs::IsRowMajor ? m_lhsIter.row() : index(); } 145 EIGEN_STRONG_INLINE Index col() const { return Lhs::IsRowMajor ? index() : m_lhsIter.col(); } 146 147 EIGEN_STRONG_INLINE operator bool() const { return m_id>=0; } 148 149 protected: 150 LhsIterator m_lhsIter; 151 RhsIterator m_rhsIter; 152 const BinaryOp& m_functor; 153 Scalar m_value; 154 Index m_id; 155 }; 156 157 // sparse - sparse (product) 158 template<typename T, typename Lhs, typename Rhs, typename Derived> 159 class sparse_cwise_binary_op_inner_iterator_selector<scalar_product_op<T>, Lhs, Rhs, Derived, Sparse, Sparse> 160 { 161 typedef scalar_product_op<T> BinaryFunc; 162 typedef CwiseBinaryOp<BinaryFunc, Lhs, Rhs> CwiseBinaryXpr; 163 typedef typename CwiseBinaryXpr::Scalar Scalar; 164 typedef typename traits<CwiseBinaryXpr>::_LhsNested _LhsNested; 165 typedef typename _LhsNested::InnerIterator LhsIterator; 166 typedef typename traits<CwiseBinaryXpr>::_RhsNested _RhsNested; 167 typedef typename _RhsNested::InnerIterator RhsIterator; 168 typedef typename Lhs::Index Index; 169 public: 170 171 EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer) 172 : m_lhsIter(xpr.lhs(),outer), m_rhsIter(xpr.rhs(),outer), m_functor(xpr.functor()) 173 { 174 while (m_lhsIter && m_rhsIter && (m_lhsIter.index() != m_rhsIter.index())) 175 { 176 if (m_lhsIter.index() < m_rhsIter.index()) 177 ++m_lhsIter; 178 else 179 ++m_rhsIter; 180 } 181 } 182 183 EIGEN_STRONG_INLINE Derived& operator++() 184 { 185 ++m_lhsIter; 186 ++m_rhsIter; 187 while (m_lhsIter && m_rhsIter && (m_lhsIter.index() != m_rhsIter.index())) 188 { 189 if (m_lhsIter.index() < m_rhsIter.index()) 190 ++m_lhsIter; 191 else 192 ++m_rhsIter; 193 } 194 return *static_cast<Derived*>(this); 195 } 196 197 EIGEN_STRONG_INLINE Scalar value() const { return m_functor(m_lhsIter.value(), m_rhsIter.value()); } 198 199 EIGEN_STRONG_INLINE Index index() const { return m_lhsIter.index(); } 200 EIGEN_STRONG_INLINE Index row() const { return m_lhsIter.row(); } 201 EIGEN_STRONG_INLINE Index col() const { return m_lhsIter.col(); } 202 203 EIGEN_STRONG_INLINE operator bool() const { return (m_lhsIter && m_rhsIter); } 204 205 protected: 206 LhsIterator m_lhsIter; 207 RhsIterator m_rhsIter; 208 const BinaryFunc& m_functor; 209 }; 210 211 // sparse - dense (product) 212 template<typename T, typename Lhs, typename Rhs, typename Derived> 213 class sparse_cwise_binary_op_inner_iterator_selector<scalar_product_op<T>, Lhs, Rhs, Derived, Sparse, Dense> 214 { 215 typedef scalar_product_op<T> BinaryFunc; 216 typedef CwiseBinaryOp<BinaryFunc, Lhs, Rhs> CwiseBinaryXpr; 217 typedef typename CwiseBinaryXpr::Scalar Scalar; 218 typedef typename traits<CwiseBinaryXpr>::_LhsNested _LhsNested; 219 typedef typename traits<CwiseBinaryXpr>::RhsNested RhsNested; 220 typedef typename _LhsNested::InnerIterator LhsIterator; 221 typedef typename Lhs::Index Index; 222 enum { IsRowMajor = (int(Lhs::Flags)&RowMajorBit)==RowMajorBit }; 223 public: 224 225 EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer) 226 : m_rhs(xpr.rhs()), m_lhsIter(xpr.lhs(),outer), m_functor(xpr.functor()), m_outer(outer) 227 {} 228 229 EIGEN_STRONG_INLINE Derived& operator++() 230 { 231 ++m_lhsIter; 232 return *static_cast<Derived*>(this); 233 } 234 235 EIGEN_STRONG_INLINE Scalar value() const 236 { return m_functor(m_lhsIter.value(), 237 m_rhs.coeff(IsRowMajor?m_outer:m_lhsIter.index(),IsRowMajor?m_lhsIter.index():m_outer)); } 238 239 EIGEN_STRONG_INLINE Index index() const { return m_lhsIter.index(); } 240 EIGEN_STRONG_INLINE Index row() const { return m_lhsIter.row(); } 241 EIGEN_STRONG_INLINE Index col() const { return m_lhsIter.col(); } 242 243 EIGEN_STRONG_INLINE operator bool() const { return m_lhsIter; } 244 245 protected: 246 RhsNested m_rhs; 247 LhsIterator m_lhsIter; 248 const BinaryFunc m_functor; 249 const Index m_outer; 250 }; 251 252 // sparse - dense (product) 253 template<typename T, typename Lhs, typename Rhs, typename Derived> 254 class sparse_cwise_binary_op_inner_iterator_selector<scalar_product_op<T>, Lhs, Rhs, Derived, Dense, Sparse> 255 { 256 typedef scalar_product_op<T> BinaryFunc; 257 typedef CwiseBinaryOp<BinaryFunc, Lhs, Rhs> CwiseBinaryXpr; 258 typedef typename CwiseBinaryXpr::Scalar Scalar; 259 typedef typename traits<CwiseBinaryXpr>::_RhsNested _RhsNested; 260 typedef typename _RhsNested::InnerIterator RhsIterator; 261 typedef typename Lhs::Index Index; 262 263 enum { IsRowMajor = (int(Rhs::Flags)&RowMajorBit)==RowMajorBit }; 264 public: 265 266 EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer) 267 : m_xpr(xpr), m_rhsIter(xpr.rhs(),outer), m_functor(xpr.functor()), m_outer(outer) 268 {} 269 270 EIGEN_STRONG_INLINE Derived& operator++() 271 { 272 ++m_rhsIter; 273 return *static_cast<Derived*>(this); 274 } 275 276 EIGEN_STRONG_INLINE Scalar value() const 277 { return m_functor(m_xpr.lhs().coeff(IsRowMajor?m_outer:m_rhsIter.index(),IsRowMajor?m_rhsIter.index():m_outer), m_rhsIter.value()); } 278 279 EIGEN_STRONG_INLINE Index index() const { return m_rhsIter.index(); } 280 EIGEN_STRONG_INLINE Index row() const { return m_rhsIter.row(); } 281 EIGEN_STRONG_INLINE Index col() const { return m_rhsIter.col(); } 282 283 EIGEN_STRONG_INLINE operator bool() const { return m_rhsIter; } 284 285 protected: 286 const CwiseBinaryXpr& m_xpr; 287 RhsIterator m_rhsIter; 288 const BinaryFunc& m_functor; 289 const Index m_outer; 290 }; 291 292 } // end namespace internal 293 294 /*************************************************************************** 295 * Implementation of SparseMatrixBase and SparseCwise functions/operators 296 ***************************************************************************/ 297 298 template<typename Derived> 299 template<typename OtherDerived> 300 EIGEN_STRONG_INLINE Derived & 301 SparseMatrixBase<Derived>::operator-=(const SparseMatrixBase<OtherDerived> &other) 302 { 303 return *this = derived() - other.derived(); 304 } 305 306 template<typename Derived> 307 template<typename OtherDerived> 308 EIGEN_STRONG_INLINE Derived & 309 SparseMatrixBase<Derived>::operator+=(const SparseMatrixBase<OtherDerived>& other) 310 { 311 return *this = derived() + other.derived(); 312 } 313 314 template<typename Derived> 315 template<typename OtherDerived> 316 EIGEN_STRONG_INLINE const EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE 317 SparseMatrixBase<Derived>::cwiseProduct(const MatrixBase<OtherDerived> &other) const 318 { 319 return EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE(derived(), other.derived()); 320 } 321 322 } // end namespace Eigen 323 324 #endif // EIGEN_SPARSE_CWISE_BINARY_OP_H 325