Home | History | Annotate | Download | only in SparseCore
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2008 Gael Guennebaud <gael.guennebaud (at) inria.fr>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_SPARSE_CWISE_BINARY_OP_H
     11 #define EIGEN_SPARSE_CWISE_BINARY_OP_H
     12 
     13 namespace Eigen {
     14 
     15 // Here we have to handle 3 cases:
     16 //  1 - sparse op dense
     17 //  2 - dense op sparse
     18 //  3 - sparse op sparse
     19 // We also need to implement a 4th iterator for:
     20 //  4 - dense op dense
     21 // Finally, we also need to distinguish between the product and other operations :
     22 //                configuration      returned mode
     23 //  1 - sparse op dense    product      sparse
     24 //                         generic      dense
     25 //  2 - dense op sparse    product      sparse
     26 //                         generic      dense
     27 //  3 - sparse op sparse   product      sparse
     28 //                         generic      sparse
     29 //  4 - dense op dense     product      dense
     30 //                         generic      dense
     31 
     32 namespace internal {
     33 
     34 template<> struct promote_storage_type<Dense,Sparse>
     35 { typedef Sparse ret; };
     36 
     37 template<> struct promote_storage_type<Sparse,Dense>
     38 { typedef Sparse ret; };
     39 
     40 template<typename BinaryOp, typename Lhs, typename Rhs, typename Derived,
     41   typename _LhsStorageMode = typename traits<Lhs>::StorageKind,
     42   typename _RhsStorageMode = typename traits<Rhs>::StorageKind>
     43 class sparse_cwise_binary_op_inner_iterator_selector;
     44 
     45 } // end namespace internal
     46 
     47 template<typename BinaryOp, typename Lhs, typename Rhs>
     48 class CwiseBinaryOpImpl<BinaryOp, Lhs, Rhs, Sparse>
     49   : public SparseMatrixBase<CwiseBinaryOp<BinaryOp, Lhs, Rhs> >
     50 {
     51   public:
     52     class InnerIterator;
     53     class ReverseInnerIterator;
     54     typedef CwiseBinaryOp<BinaryOp, Lhs, Rhs> Derived;
     55     EIGEN_SPARSE_PUBLIC_INTERFACE(Derived)
     56     CwiseBinaryOpImpl()
     57     {
     58       typedef typename internal::traits<Lhs>::StorageKind LhsStorageKind;
     59       typedef typename internal::traits<Rhs>::StorageKind RhsStorageKind;
     60       EIGEN_STATIC_ASSERT((
     61                 (!internal::is_same<LhsStorageKind,RhsStorageKind>::value)
     62             ||  ((Lhs::Flags&RowMajorBit) == (Rhs::Flags&RowMajorBit))),
     63             THE_STORAGE_ORDER_OF_BOTH_SIDES_MUST_MATCH);
     64     }
     65 };
     66 
     67 template<typename BinaryOp, typename Lhs, typename Rhs>
     68 class CwiseBinaryOpImpl<BinaryOp,Lhs,Rhs,Sparse>::InnerIterator
     69   : public internal::sparse_cwise_binary_op_inner_iterator_selector<BinaryOp,Lhs,Rhs,typename CwiseBinaryOpImpl<BinaryOp,Lhs,Rhs,Sparse>::InnerIterator>
     70 {
     71   public:
     72     typedef typename Lhs::Index Index;
     73     typedef internal::sparse_cwise_binary_op_inner_iterator_selector<
     74       BinaryOp,Lhs,Rhs, InnerIterator> Base;
     75 
     76     EIGEN_STRONG_INLINE InnerIterator(const CwiseBinaryOpImpl& binOp, typename CwiseBinaryOpImpl::Index outer)
     77       : Base(binOp.derived(),outer)
     78     {}
     79 };
     80 
     81 /***************************************************************************
     82 * Implementation of inner-iterators
     83 ***************************************************************************/
     84 
     85 // template<typename T> struct internal::func_is_conjunction { enum { ret = false }; };
     86 // template<typename T> struct internal::func_is_conjunction<internal::scalar_product_op<T> > { enum { ret = true }; };
     87 
     88 // TODO generalize the internal::scalar_product_op specialization to all conjunctions if any !
     89 
     90 namespace internal {
     91 
     92 // sparse - sparse  (generic)
     93 template<typename BinaryOp, typename Lhs, typename Rhs, typename Derived>
     94 class sparse_cwise_binary_op_inner_iterator_selector<BinaryOp, Lhs, Rhs, Derived, Sparse, Sparse>
     95 {
     96     typedef CwiseBinaryOp<BinaryOp, Lhs, Rhs> CwiseBinaryXpr;
     97     typedef typename traits<CwiseBinaryXpr>::Scalar Scalar;
     98     typedef typename traits<CwiseBinaryXpr>::_LhsNested _LhsNested;
     99     typedef typename traits<CwiseBinaryXpr>::_RhsNested _RhsNested;
    100     typedef typename _LhsNested::InnerIterator LhsIterator;
    101     typedef typename _RhsNested::InnerIterator RhsIterator;
    102     typedef typename Lhs::Index Index;
    103 
    104   public:
    105 
    106     EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer)
    107       : m_lhsIter(xpr.lhs(),outer), m_rhsIter(xpr.rhs(),outer), m_functor(xpr.functor())
    108     {
    109       this->operator++();
    110     }
    111 
    112     EIGEN_STRONG_INLINE Derived& operator++()
    113     {
    114       if (m_lhsIter && m_rhsIter && (m_lhsIter.index() == m_rhsIter.index()))
    115       {
    116         m_id = m_lhsIter.index();
    117         m_value = m_functor(m_lhsIter.value(), m_rhsIter.value());
    118         ++m_lhsIter;
    119         ++m_rhsIter;
    120       }
    121       else if (m_lhsIter && (!m_rhsIter || (m_lhsIter.index() < m_rhsIter.index())))
    122       {
    123         m_id = m_lhsIter.index();
    124         m_value = m_functor(m_lhsIter.value(), Scalar(0));
    125         ++m_lhsIter;
    126       }
    127       else if (m_rhsIter && (!m_lhsIter || (m_lhsIter.index() > m_rhsIter.index())))
    128       {
    129         m_id = m_rhsIter.index();
    130         m_value = m_functor(Scalar(0), m_rhsIter.value());
    131         ++m_rhsIter;
    132       }
    133       else
    134       {
    135         m_value = 0; // this is to avoid a compilation warning
    136         m_id = -1;
    137       }
    138       return *static_cast<Derived*>(this);
    139     }
    140 
    141     EIGEN_STRONG_INLINE Scalar value() const { return m_value; }
    142 
    143     EIGEN_STRONG_INLINE Index index() const { return m_id; }
    144     EIGEN_STRONG_INLINE Index row() const { return Lhs::IsRowMajor ? m_lhsIter.row() : index(); }
    145     EIGEN_STRONG_INLINE Index col() const { return Lhs::IsRowMajor ? index() : m_lhsIter.col(); }
    146 
    147     EIGEN_STRONG_INLINE operator bool() const { return m_id>=0; }
    148 
    149   protected:
    150     LhsIterator m_lhsIter;
    151     RhsIterator m_rhsIter;
    152     const BinaryOp& m_functor;
    153     Scalar m_value;
    154     Index m_id;
    155 };
    156 
    157 // sparse - sparse  (product)
    158 template<typename T, typename Lhs, typename Rhs, typename Derived>
    159 class sparse_cwise_binary_op_inner_iterator_selector<scalar_product_op<T>, Lhs, Rhs, Derived, Sparse, Sparse>
    160 {
    161     typedef scalar_product_op<T> BinaryFunc;
    162     typedef CwiseBinaryOp<BinaryFunc, Lhs, Rhs> CwiseBinaryXpr;
    163     typedef typename CwiseBinaryXpr::Scalar Scalar;
    164     typedef typename traits<CwiseBinaryXpr>::_LhsNested _LhsNested;
    165     typedef typename _LhsNested::InnerIterator LhsIterator;
    166     typedef typename traits<CwiseBinaryXpr>::_RhsNested _RhsNested;
    167     typedef typename _RhsNested::InnerIterator RhsIterator;
    168     typedef typename Lhs::Index Index;
    169   public:
    170 
    171     EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer)
    172       : m_lhsIter(xpr.lhs(),outer), m_rhsIter(xpr.rhs(),outer), m_functor(xpr.functor())
    173     {
    174       while (m_lhsIter && m_rhsIter && (m_lhsIter.index() != m_rhsIter.index()))
    175       {
    176         if (m_lhsIter.index() < m_rhsIter.index())
    177           ++m_lhsIter;
    178         else
    179           ++m_rhsIter;
    180       }
    181     }
    182 
    183     EIGEN_STRONG_INLINE Derived& operator++()
    184     {
    185       ++m_lhsIter;
    186       ++m_rhsIter;
    187       while (m_lhsIter && m_rhsIter && (m_lhsIter.index() != m_rhsIter.index()))
    188       {
    189         if (m_lhsIter.index() < m_rhsIter.index())
    190           ++m_lhsIter;
    191         else
    192           ++m_rhsIter;
    193       }
    194       return *static_cast<Derived*>(this);
    195     }
    196 
    197     EIGEN_STRONG_INLINE Scalar value() const { return m_functor(m_lhsIter.value(), m_rhsIter.value()); }
    198 
    199     EIGEN_STRONG_INLINE Index index() const { return m_lhsIter.index(); }
    200     EIGEN_STRONG_INLINE Index row() const { return m_lhsIter.row(); }
    201     EIGEN_STRONG_INLINE Index col() const { return m_lhsIter.col(); }
    202 
    203     EIGEN_STRONG_INLINE operator bool() const { return (m_lhsIter && m_rhsIter); }
    204 
    205   protected:
    206     LhsIterator m_lhsIter;
    207     RhsIterator m_rhsIter;
    208     const BinaryFunc& m_functor;
    209 };
    210 
    211 // sparse - dense  (product)
    212 template<typename T, typename Lhs, typename Rhs, typename Derived>
    213 class sparse_cwise_binary_op_inner_iterator_selector<scalar_product_op<T>, Lhs, Rhs, Derived, Sparse, Dense>
    214 {
    215     typedef scalar_product_op<T> BinaryFunc;
    216     typedef CwiseBinaryOp<BinaryFunc, Lhs, Rhs> CwiseBinaryXpr;
    217     typedef typename CwiseBinaryXpr::Scalar Scalar;
    218     typedef typename traits<CwiseBinaryXpr>::_LhsNested _LhsNested;
    219     typedef typename traits<CwiseBinaryXpr>::RhsNested RhsNested;
    220     typedef typename _LhsNested::InnerIterator LhsIterator;
    221     typedef typename Lhs::Index Index;
    222     enum { IsRowMajor = (int(Lhs::Flags)&RowMajorBit)==RowMajorBit };
    223   public:
    224 
    225     EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer)
    226       : m_rhs(xpr.rhs()), m_lhsIter(xpr.lhs(),outer), m_functor(xpr.functor()), m_outer(outer)
    227     {}
    228 
    229     EIGEN_STRONG_INLINE Derived& operator++()
    230     {
    231       ++m_lhsIter;
    232       return *static_cast<Derived*>(this);
    233     }
    234 
    235     EIGEN_STRONG_INLINE Scalar value() const
    236     { return m_functor(m_lhsIter.value(),
    237                        m_rhs.coeff(IsRowMajor?m_outer:m_lhsIter.index(),IsRowMajor?m_lhsIter.index():m_outer)); }
    238 
    239     EIGEN_STRONG_INLINE Index index() const { return m_lhsIter.index(); }
    240     EIGEN_STRONG_INLINE Index row() const { return m_lhsIter.row(); }
    241     EIGEN_STRONG_INLINE Index col() const { return m_lhsIter.col(); }
    242 
    243     EIGEN_STRONG_INLINE operator bool() const { return m_lhsIter; }
    244 
    245   protected:
    246     RhsNested m_rhs;
    247     LhsIterator m_lhsIter;
    248     const BinaryFunc m_functor;
    249     const Index m_outer;
    250 };
    251 
    252 // sparse - dense  (product)
    253 template<typename T, typename Lhs, typename Rhs, typename Derived>
    254 class sparse_cwise_binary_op_inner_iterator_selector<scalar_product_op<T>, Lhs, Rhs, Derived, Dense, Sparse>
    255 {
    256     typedef scalar_product_op<T> BinaryFunc;
    257     typedef CwiseBinaryOp<BinaryFunc, Lhs, Rhs> CwiseBinaryXpr;
    258     typedef typename CwiseBinaryXpr::Scalar Scalar;
    259     typedef typename traits<CwiseBinaryXpr>::_RhsNested _RhsNested;
    260     typedef typename _RhsNested::InnerIterator RhsIterator;
    261     typedef typename Lhs::Index Index;
    262 
    263     enum { IsRowMajor = (int(Rhs::Flags)&RowMajorBit)==RowMajorBit };
    264   public:
    265 
    266     EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer)
    267       : m_xpr(xpr), m_rhsIter(xpr.rhs(),outer), m_functor(xpr.functor()), m_outer(outer)
    268     {}
    269 
    270     EIGEN_STRONG_INLINE Derived& operator++()
    271     {
    272       ++m_rhsIter;
    273       return *static_cast<Derived*>(this);
    274     }
    275 
    276     EIGEN_STRONG_INLINE Scalar value() const
    277     { return m_functor(m_xpr.lhs().coeff(IsRowMajor?m_outer:m_rhsIter.index(),IsRowMajor?m_rhsIter.index():m_outer), m_rhsIter.value()); }
    278 
    279     EIGEN_STRONG_INLINE Index index() const { return m_rhsIter.index(); }
    280     EIGEN_STRONG_INLINE Index row() const { return m_rhsIter.row(); }
    281     EIGEN_STRONG_INLINE Index col() const { return m_rhsIter.col(); }
    282 
    283     EIGEN_STRONG_INLINE operator bool() const { return m_rhsIter; }
    284 
    285   protected:
    286     const CwiseBinaryXpr& m_xpr;
    287     RhsIterator m_rhsIter;
    288     const BinaryFunc& m_functor;
    289     const Index m_outer;
    290 };
    291 
    292 } // end namespace internal
    293 
    294 /***************************************************************************
    295 * Implementation of SparseMatrixBase and SparseCwise functions/operators
    296 ***************************************************************************/
    297 
    298 template<typename Derived>
    299 template<typename OtherDerived>
    300 EIGEN_STRONG_INLINE Derived &
    301 SparseMatrixBase<Derived>::operator-=(const SparseMatrixBase<OtherDerived> &other)
    302 {
    303   return *this = derived() - other.derived();
    304 }
    305 
    306 template<typename Derived>
    307 template<typename OtherDerived>
    308 EIGEN_STRONG_INLINE Derived &
    309 SparseMatrixBase<Derived>::operator+=(const SparseMatrixBase<OtherDerived>& other)
    310 {
    311   return *this = derived() + other.derived();
    312 }
    313 
    314 template<typename Derived>
    315 template<typename OtherDerived>
    316 EIGEN_STRONG_INLINE const EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE
    317 SparseMatrixBase<Derived>::cwiseProduct(const MatrixBase<OtherDerived> &other) const
    318 {
    319   return EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE(derived(), other.derived());
    320 }
    321 
    322 } // end namespace Eigen
    323 
    324 #endif // EIGEN_SPARSE_CWISE_BINARY_OP_H
    325