Home | History | Annotate | Download | only in Core
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2008 Gael Guennebaud <gael.guennebaud (at) inria.fr>
      5 // Copyright (C) 2007-2009 Benoit Jacob <jacob.benoit.1 (at) gmail.com>
      6 //
      7 // This Source Code Form is subject to the terms of the Mozilla
      8 // Public License v. 2.0. If a copy of the MPL was not distributed
      9 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
     10 
     11 #ifndef EIGEN_DIAGONALPRODUCT_H
     12 #define EIGEN_DIAGONALPRODUCT_H
     13 
     14 namespace Eigen {
     15 
     16 namespace internal {
     17 template<typename MatrixType, typename DiagonalType, int ProductOrder>
     18 struct traits<DiagonalProduct<MatrixType, DiagonalType, ProductOrder> >
     19  : traits<MatrixType>
     20 {
     21   typedef typename scalar_product_traits<typename MatrixType::Scalar, typename DiagonalType::Scalar>::ReturnType Scalar;
     22   enum {
     23     RowsAtCompileTime = MatrixType::RowsAtCompileTime,
     24     ColsAtCompileTime = MatrixType::ColsAtCompileTime,
     25     MaxRowsAtCompileTime = MatrixType::MaxRowsAtCompileTime,
     26     MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime,
     27 
     28     _StorageOrder = MatrixType::Flags & RowMajorBit ? RowMajor : ColMajor,
     29     _PacketOnDiag = !((int(_StorageOrder) == RowMajor && int(ProductOrder) == OnTheLeft)
     30                     ||(int(_StorageOrder) == ColMajor && int(ProductOrder) == OnTheRight)),
     31     _SameTypes = is_same<typename MatrixType::Scalar, typename DiagonalType::Scalar>::value,
     32     // FIXME currently we need same types, but in the future the next rule should be the one
     33     //_Vectorizable = bool(int(MatrixType::Flags)&PacketAccessBit) && ((!_PacketOnDiag) || (_SameTypes && bool(int(DiagonalType::Flags)&PacketAccessBit))),
     34     _Vectorizable = bool(int(MatrixType::Flags)&PacketAccessBit) && _SameTypes && ((!_PacketOnDiag) || (bool(int(DiagonalType::Flags)&PacketAccessBit))),
     35 
     36     Flags = (HereditaryBits & (unsigned int)(MatrixType::Flags)) | (_Vectorizable ? PacketAccessBit : 0),
     37     CoeffReadCost = NumTraits<Scalar>::MulCost + MatrixType::CoeffReadCost + DiagonalType::DiagonalVectorType::CoeffReadCost
     38   };
     39 };
     40 }
     41 
     42 template<typename MatrixType, typename DiagonalType, int ProductOrder>
     43 class DiagonalProduct : internal::no_assignment_operator,
     44                         public MatrixBase<DiagonalProduct<MatrixType, DiagonalType, ProductOrder> >
     45 {
     46   public:
     47 
     48     typedef MatrixBase<DiagonalProduct> Base;
     49     EIGEN_DENSE_PUBLIC_INTERFACE(DiagonalProduct)
     50 
     51     inline DiagonalProduct(const MatrixType& matrix, const DiagonalType& diagonal)
     52       : m_matrix(matrix), m_diagonal(diagonal)
     53     {
     54       eigen_assert(diagonal.diagonal().size() == (ProductOrder == OnTheLeft ? matrix.rows() : matrix.cols()));
     55     }
     56 
     57     inline Index rows() const { return m_matrix.rows(); }
     58     inline Index cols() const { return m_matrix.cols(); }
     59 
     60     const Scalar coeff(Index row, Index col) const
     61     {
     62       return m_diagonal.diagonal().coeff(ProductOrder == OnTheLeft ? row : col) * m_matrix.coeff(row, col);
     63     }
     64 
     65     template<int LoadMode>
     66     EIGEN_STRONG_INLINE PacketScalar packet(Index row, Index col) const
     67     {
     68       enum {
     69         StorageOrder = Flags & RowMajorBit ? RowMajor : ColMajor
     70       };
     71       const Index indexInDiagonalVector = ProductOrder == OnTheLeft ? row : col;
     72 
     73       return packet_impl<LoadMode>(row,col,indexInDiagonalVector,typename internal::conditional<
     74         ((int(StorageOrder) == RowMajor && int(ProductOrder) == OnTheLeft)
     75        ||(int(StorageOrder) == ColMajor && int(ProductOrder) == OnTheRight)), internal::true_type, internal::false_type>::type());
     76     }
     77 
     78   protected:
     79     template<int LoadMode>
     80     EIGEN_STRONG_INLINE PacketScalar packet_impl(Index row, Index col, Index id, internal::true_type) const
     81     {
     82       return internal::pmul(m_matrix.template packet<LoadMode>(row, col),
     83                      internal::pset1<PacketScalar>(m_diagonal.diagonal().coeff(id)));
     84     }
     85 
     86     template<int LoadMode>
     87     EIGEN_STRONG_INLINE PacketScalar packet_impl(Index row, Index col, Index id, internal::false_type) const
     88     {
     89       enum {
     90         InnerSize = (MatrixType::Flags & RowMajorBit) ? MatrixType::ColsAtCompileTime : MatrixType::RowsAtCompileTime,
     91         DiagonalVectorPacketLoadMode = (LoadMode == Aligned && ((InnerSize%16) == 0)) ? Aligned : Unaligned
     92       };
     93       return internal::pmul(m_matrix.template packet<LoadMode>(row, col),
     94                      m_diagonal.diagonal().template packet<DiagonalVectorPacketLoadMode>(id));
     95     }
     96 
     97     typename MatrixType::Nested m_matrix;
     98     typename DiagonalType::Nested m_diagonal;
     99 };
    100 
    101 /** \returns the diagonal matrix product of \c *this by the diagonal matrix \a diagonal.
    102   */
    103 template<typename Derived>
    104 template<typename DiagonalDerived>
    105 inline const DiagonalProduct<Derived, DiagonalDerived, OnTheRight>
    106 MatrixBase<Derived>::operator*(const DiagonalBase<DiagonalDerived> &diagonal) const
    107 {
    108   return DiagonalProduct<Derived, DiagonalDerived, OnTheRight>(derived(), diagonal.derived());
    109 }
    110 
    111 /** \returns the diagonal matrix product of \c *this by the matrix \a matrix.
    112   */
    113 template<typename DiagonalDerived>
    114 template<typename MatrixDerived>
    115 inline const DiagonalProduct<MatrixDerived, DiagonalDerived, OnTheLeft>
    116 DiagonalBase<DiagonalDerived>::operator*(const MatrixBase<MatrixDerived> &matrix) const
    117 {
    118   return DiagonalProduct<MatrixDerived, DiagonalDerived, OnTheLeft>(matrix.derived(), derived());
    119 }
    120 
    121 } // end namespace Eigen
    122 
    123 #endif // EIGEN_DIAGONALPRODUCT_H
    124