Home | History | Annotate | Download | only in SparseCore
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2008 Gael Guennebaud <gael.guennebaud (at) inria.fr>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_SPARSE_CWISE_BINARY_OP_H
     11 #define EIGEN_SPARSE_CWISE_BINARY_OP_H
     12 
     13 namespace Eigen {
     14 
     15 // Here we have to handle 3 cases:
     16 //  1 - sparse op dense
     17 //  2 - dense op sparse
     18 //  3 - sparse op sparse
     19 // We also need to implement a 4th iterator for:
     20 //  4 - dense op dense
     21 // Finally, we also need to distinguish between the product and the other operations:
     22 //                configuration      returned mode
     23 //  1 - sparse op dense    product      sparse
     24 //                         generic      dense
     25 //  2 - dense op sparse    product      sparse
     26 //                         generic      dense
     27 //  3 - sparse op sparse   product      sparse
     28 //                         generic      sparse
     29 //  4 - dense op dense     product      dense
     30 //                         generic      dense
     31 
     32 namespace internal {
     33 
     34 template<> struct promote_storage_type<Dense,Sparse>
     35 { typedef Sparse ret; };
     36 
     37 template<> struct promote_storage_type<Sparse,Dense>
     38 { typedef Sparse ret; };
     39 
     40 template<typename BinaryOp, typename Lhs, typename Rhs, typename Derived,
     41   typename _LhsStorageMode = typename traits<Lhs>::StorageKind,
     42   typename _RhsStorageMode = typename traits<Rhs>::StorageKind>
     43 class sparse_cwise_binary_op_inner_iterator_selector;
     44 
     45 } // end namespace internal
     46 
     47 template<typename BinaryOp, typename Lhs, typename Rhs>
     48 class CwiseBinaryOpImpl<BinaryOp, Lhs, Rhs, Sparse>
     49   : public SparseMatrixBase<CwiseBinaryOp<BinaryOp, Lhs, Rhs> >
     50 {
     51   public:
     52     class InnerIterator;
     53     class ReverseInnerIterator;
     54     typedef CwiseBinaryOp<BinaryOp, Lhs, Rhs> Derived;
     55     EIGEN_SPARSE_PUBLIC_INTERFACE(Derived)
     56     CwiseBinaryOpImpl()
     57     {
     58       typedef typename internal::traits<Lhs>::StorageKind LhsStorageKind;
     59       typedef typename internal::traits<Rhs>::StorageKind RhsStorageKind;
     60       EIGEN_STATIC_ASSERT((
     61                 (!internal::is_same<LhsStorageKind,RhsStorageKind>::value)
     62             ||  ((Lhs::Flags&RowMajorBit) == (Rhs::Flags&RowMajorBit))),
     63             THE_STORAGE_ORDER_OF_BOTH_SIDES_MUST_MATCH);
     64     }
     65 };
     66 
     67 template<typename BinaryOp, typename Lhs, typename Rhs>
        // Inner iterator over one inner vector (selected by `outer`) of a sparse
        // CwiseBinaryOp expression. All the work is delegated to the
        // internal::sparse_cwise_binary_op_inner_iterator_selector specialization
        // picked from the functor type and the operands' storage kinds. This class
        // passes itself as the selector's Derived parameter, so the selector's
        // operator++ returns an InnerIterator&.
        // NOTE(review): the fully qualified `typename CwiseBinaryOpImpl<...>::InnerIterator`
        // spelling below looks deliberate (compare the VC11 note on the constructor);
        // keep it unless the change is verified on that compiler.
     68 class CwiseBinaryOpImpl<BinaryOp,Lhs,Rhs,Sparse>::InnerIterator
     69   : public internal::sparse_cwise_binary_op_inner_iterator_selector<BinaryOp,Lhs,Rhs,typename CwiseBinaryOpImpl<BinaryOp,Lhs,Rhs,Sparse>::InnerIterator>
     70 {
     71   public:
        // Index type is taken from the lhs operand.
     72     typedef typename Lhs::Index Index;
        // The selector this iterator derives from (same type as the base-clause above).
     73     typedef internal::sparse_cwise_binary_op_inner_iterator_selector<
     74       BinaryOp,Lhs,Rhs, InnerIterator> Base;
     75 
        // binOp.derived() is the whole CwiseBinaryOp expression; the selector reads
        // its lhs(), rhs() and functor() and starts at inner vector `outer`.
     76     // NOTE: we have to prefix Index by "typename Lhs::" to avoid an ICE with VC11
     77     EIGEN_STRONG_INLINE InnerIterator(const CwiseBinaryOpImpl& binOp, typename Lhs::Index outer)
     78       : Base(binOp.derived(),outer)
     79     {}
     80 };
     81 
     82 /***************************************************************************
     83 * Implementation of inner-iterators
     84 ***************************************************************************/
     85 
     86 // template<typename T> struct internal::func_is_conjunction { enum { ret = false }; };
     87 // template<typename T> struct internal::func_is_conjunction<internal::scalar_product_op<T> > { enum { ret = true }; };
     88 
     89 // TODO: generalize the internal::scalar_product_op specializations to all conjunction-like functors, if there are any.
     90 
     91 namespace internal {
     92 
     93 // sparse - sparse  (generic)
     94 template<typename BinaryOp, typename Lhs, typename Rhs, typename Derived>
     95 class sparse_cwise_binary_op_inner_iterator_selector<BinaryOp, Lhs, Rhs, Derived, Sparse, Sparse>
     96 {
     97     typedef CwiseBinaryOp<BinaryOp, Lhs, Rhs> CwiseBinaryXpr;
     98     typedef typename traits<CwiseBinaryXpr>::Scalar Scalar;
     99     typedef typename traits<CwiseBinaryXpr>::_LhsNested _LhsNested;
    100     typedef typename traits<CwiseBinaryXpr>::_RhsNested _RhsNested;
    101     typedef typename _LhsNested::InnerIterator LhsIterator;
    102     typedef typename _RhsNested::InnerIterator RhsIterator;
    103     typedef typename Lhs::Index Index;
    104 
    105   public:
    106 
    107     EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer)
    108       : m_lhsIter(xpr.lhs(),outer), m_rhsIter(xpr.rhs(),outer), m_functor(xpr.functor())
    109     {
    110       this->operator++();
    111     }
    112 
    113     EIGEN_STRONG_INLINE Derived& operator++()
    114     {
    115       if (m_lhsIter && m_rhsIter && (m_lhsIter.index() == m_rhsIter.index()))
    116       {
    117         m_id = m_lhsIter.index();
    118         m_value = m_functor(m_lhsIter.value(), m_rhsIter.value());
    119         ++m_lhsIter;
    120         ++m_rhsIter;
    121       }
    122       else if (m_lhsIter && (!m_rhsIter || (m_lhsIter.index() < m_rhsIter.index())))
    123       {
    124         m_id = m_lhsIter.index();
    125         m_value = m_functor(m_lhsIter.value(), Scalar(0));
    126         ++m_lhsIter;
    127       }
    128       else if (m_rhsIter && (!m_lhsIter || (m_lhsIter.index() > m_rhsIter.index())))
    129       {
    130         m_id = m_rhsIter.index();
    131         m_value = m_functor(Scalar(0), m_rhsIter.value());
    132         ++m_rhsIter;
    133       }
    134       else
    135       {
    136         m_value = 0; // this is to avoid a compilation warning
    137         m_id = -1;
    138       }
    139       return *static_cast<Derived*>(this);
    140     }
    141 
    142     EIGEN_STRONG_INLINE Scalar value() const { return m_value; }
    143 
    144     EIGEN_STRONG_INLINE Index index() const { return m_id; }
    145     EIGEN_STRONG_INLINE Index row() const { return Lhs::IsRowMajor ? m_lhsIter.row() : index(); }
    146     EIGEN_STRONG_INLINE Index col() const { return Lhs::IsRowMajor ? index() : m_lhsIter.col(); }
    147 
    148     EIGEN_STRONG_INLINE operator bool() const { return m_id>=0; }
    149 
    150   protected:
    151     LhsIterator m_lhsIter;
    152     RhsIterator m_rhsIter;
    153     const BinaryOp& m_functor;
    154     Scalar m_value;
    155     Index m_id;
    156 };
    157 
    158 // sparse - sparse  (product)
    159 template<typename T, typename Lhs, typename Rhs, typename Derived>
    160 class sparse_cwise_binary_op_inner_iterator_selector<scalar_product_op<T>, Lhs, Rhs, Derived, Sparse, Sparse>
    161 {
    162     typedef scalar_product_op<T> BinaryFunc;
    163     typedef CwiseBinaryOp<BinaryFunc, Lhs, Rhs> CwiseBinaryXpr;
    164     typedef typename CwiseBinaryXpr::Scalar Scalar;
    165     typedef typename traits<CwiseBinaryXpr>::_LhsNested _LhsNested;
    166     typedef typename _LhsNested::InnerIterator LhsIterator;
    167     typedef typename traits<CwiseBinaryXpr>::_RhsNested _RhsNested;
    168     typedef typename _RhsNested::InnerIterator RhsIterator;
    169     typedef typename Lhs::Index Index;
    170   public:
    171 
    172     EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer)
    173       : m_lhsIter(xpr.lhs(),outer), m_rhsIter(xpr.rhs(),outer), m_functor(xpr.functor())
    174     {
    175       while (m_lhsIter && m_rhsIter && (m_lhsIter.index() != m_rhsIter.index()))
    176       {
    177         if (m_lhsIter.index() < m_rhsIter.index())
    178           ++m_lhsIter;
    179         else
    180           ++m_rhsIter;
    181       }
    182     }
    183 
    184     EIGEN_STRONG_INLINE Derived& operator++()
    185     {
    186       ++m_lhsIter;
    187       ++m_rhsIter;
    188       while (m_lhsIter && m_rhsIter && (m_lhsIter.index() != m_rhsIter.index()))
    189       {
    190         if (m_lhsIter.index() < m_rhsIter.index())
    191           ++m_lhsIter;
    192         else
    193           ++m_rhsIter;
    194       }
    195       return *static_cast<Derived*>(this);
    196     }
    197 
    198     EIGEN_STRONG_INLINE Scalar value() const { return m_functor(m_lhsIter.value(), m_rhsIter.value()); }
    199 
    200     EIGEN_STRONG_INLINE Index index() const { return m_lhsIter.index(); }
    201     EIGEN_STRONG_INLINE Index row() const { return m_lhsIter.row(); }
    202     EIGEN_STRONG_INLINE Index col() const { return m_lhsIter.col(); }
    203 
    204     EIGEN_STRONG_INLINE operator bool() const { return (m_lhsIter && m_rhsIter); }
    205 
    206   protected:
    207     LhsIterator m_lhsIter;
    208     RhsIterator m_rhsIter;
    209     const BinaryFunc& m_functor;
    210 };
    211 
    212 // sparse - dense  (product)
    213 template<typename T, typename Lhs, typename Rhs, typename Derived>
    214 class sparse_cwise_binary_op_inner_iterator_selector<scalar_product_op<T>, Lhs, Rhs, Derived, Sparse, Dense>
    215 {
    216     typedef scalar_product_op<T> BinaryFunc;
    217     typedef CwiseBinaryOp<BinaryFunc, Lhs, Rhs> CwiseBinaryXpr;
    218     typedef typename CwiseBinaryXpr::Scalar Scalar;
    219     typedef typename traits<CwiseBinaryXpr>::_LhsNested _LhsNested;
    220     typedef typename traits<CwiseBinaryXpr>::RhsNested RhsNested;
    221     typedef typename _LhsNested::InnerIterator LhsIterator;
    222     typedef typename Lhs::Index Index;
    223     enum { IsRowMajor = (int(Lhs::Flags)&RowMajorBit)==RowMajorBit };
    224   public:
    225 
    226     EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer)
    227       : m_rhs(xpr.rhs()), m_lhsIter(xpr.lhs(),outer), m_functor(xpr.functor()), m_outer(outer)
    228     {}
    229 
    230     EIGEN_STRONG_INLINE Derived& operator++()
    231     {
    232       ++m_lhsIter;
    233       return *static_cast<Derived*>(this);
    234     }
    235 
    236     EIGEN_STRONG_INLINE Scalar value() const
    237     { return m_functor(m_lhsIter.value(),
    238                        m_rhs.coeff(IsRowMajor?m_outer:m_lhsIter.index(),IsRowMajor?m_lhsIter.index():m_outer)); }
    239 
    240     EIGEN_STRONG_INLINE Index index() const { return m_lhsIter.index(); }
    241     EIGEN_STRONG_INLINE Index row() const { return m_lhsIter.row(); }
    242     EIGEN_STRONG_INLINE Index col() const { return m_lhsIter.col(); }
    243 
    244     EIGEN_STRONG_INLINE operator bool() const { return m_lhsIter; }
    245 
    246   protected:
    247     RhsNested m_rhs;
    248     LhsIterator m_lhsIter;
    249     const BinaryFunc m_functor;
    250     const Index m_outer;
    251 };
    252 
    253 // sparse - dense  (product)
    254 template<typename T, typename Lhs, typename Rhs, typename Derived>
    255 class sparse_cwise_binary_op_inner_iterator_selector<scalar_product_op<T>, Lhs, Rhs, Derived, Dense, Sparse>
    256 {
    257     typedef scalar_product_op<T> BinaryFunc;
    258     typedef CwiseBinaryOp<BinaryFunc, Lhs, Rhs> CwiseBinaryXpr;
    259     typedef typename CwiseBinaryXpr::Scalar Scalar;
    260     typedef typename traits<CwiseBinaryXpr>::_RhsNested _RhsNested;
    261     typedef typename _RhsNested::InnerIterator RhsIterator;
    262     typedef typename Lhs::Index Index;
    263 
    264     enum { IsRowMajor = (int(Rhs::Flags)&RowMajorBit)==RowMajorBit };
    265   public:
    266 
    267     EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer)
    268       : m_xpr(xpr), m_rhsIter(xpr.rhs(),outer), m_functor(xpr.functor()), m_outer(outer)
    269     {}
    270 
    271     EIGEN_STRONG_INLINE Derived& operator++()
    272     {
    273       ++m_rhsIter;
    274       return *static_cast<Derived*>(this);
    275     }
    276 
    277     EIGEN_STRONG_INLINE Scalar value() const
    278     { return m_functor(m_xpr.lhs().coeff(IsRowMajor?m_outer:m_rhsIter.index(),IsRowMajor?m_rhsIter.index():m_outer), m_rhsIter.value()); }
    279 
    280     EIGEN_STRONG_INLINE Index index() const { return m_rhsIter.index(); }
    281     EIGEN_STRONG_INLINE Index row() const { return m_rhsIter.row(); }
    282     EIGEN_STRONG_INLINE Index col() const { return m_rhsIter.col(); }
    283 
    284     EIGEN_STRONG_INLINE operator bool() const { return m_rhsIter; }
    285 
    286   protected:
    287     const CwiseBinaryXpr& m_xpr;
    288     RhsIterator m_rhsIter;
    289     const BinaryFunc& m_functor;
    290     const Index m_outer;
    291 };
    292 
    293 } // end namespace internal
    294 
    295 /***************************************************************************
    296 * Implementation of SparseMatrixBase and SparseCwise functions/operators
    297 ***************************************************************************/
    298 
    299 template<typename Derived>
    300 template<typename OtherDerived>
    301 EIGEN_STRONG_INLINE Derived &
    302 SparseMatrixBase<Derived>::operator-=(const SparseMatrixBase<OtherDerived> &other)
    303 {
    304   return derived() = derived() - other.derived();
    305 }
    306 
    307 template<typename Derived>
    308 template<typename OtherDerived>
    309 EIGEN_STRONG_INLINE Derived &
    310 SparseMatrixBase<Derived>::operator+=(const SparseMatrixBase<OtherDerived>& other)
    311 {
    312   return derived() = derived() + other.derived();
    313 }
    314 
    315 template<typename Derived>
    316 template<typename OtherDerived>
    317 EIGEN_STRONG_INLINE const EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE
    318 SparseMatrixBase<Derived>::cwiseProduct(const MatrixBase<OtherDerived> &other) const
    319 {
    320   return EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE(derived(), other.derived());
    321 }
    322 
    323 } // end namespace Eigen
    324 
    325 #endif // EIGEN_SPARSE_CWISE_BINARY_OP_H
    326