Home | History | Annotate | Download | only in Skyline
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2008-2009 Guillaume Saupin <guillaume.saupin (at) cea.fr>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_SKYLINEPRODUCT_H
     11 #define EIGEN_SKYLINEPRODUCT_H
     12 
     13 namespace Eigen {
     14 
     15 template<typename Lhs, typename Rhs, int ProductMode>
     16 struct SkylineProductReturnType {
     17     typedef const typename internal::nested<Lhs, Rhs::RowsAtCompileTime>::type LhsNested;
     18     typedef const typename internal::nested<Rhs, Lhs::RowsAtCompileTime>::type RhsNested;
     19 
     20     typedef SkylineProduct<LhsNested, RhsNested, ProductMode> Type;
     21 };
     22 
/** \internal
  * Compile-time traits of the SkylineProduct<LhsNested, RhsNested, ProductMode>
  * expression: scalar type, compile-time sizes, flags, and the base class
  * (Base) the expression derives from.
  */
template<typename LhsNested, typename RhsNested, int ProductMode>
struct internal::traits<SkylineProduct<LhsNested, RhsNested, ProductMode> > {
    // clean the nested types (internal::remove_all presumably strips
    // const/reference qualifiers so the plain expression types can be queried):
    typedef typename internal::remove_all<LhsNested>::type _LhsNested;
    typedef typename internal::remove_all<RhsNested>::type _RhsNested;
    // The scalar type of the product is taken from the left operand.
    typedef typename _LhsNested::Scalar Scalar;

    // NOTE(review): LhsCoeffReadCost, RhsCoeffReadCost and InnerSize are not
    // used inside this struct; they may be read by other code.
    enum {
        LhsCoeffReadCost = _LhsNested::CoeffReadCost,
        RhsCoeffReadCost = _RhsNested::CoeffReadCost,
        LhsFlags = _LhsNested::Flags,
        RhsFlags = _RhsNested::Flags,

        // Result is (lhs rows) x (rhs cols); InnerSize is the shared dimension.
        RowsAtCompileTime = _LhsNested::RowsAtCompileTime,
        ColsAtCompileTime = _RhsNested::ColsAtCompileTime,
        InnerSize = EIGEN_SIZE_MIN_PREFER_FIXED(_LhsNested::ColsAtCompileTime, _RhsNested::RowsAtCompileTime),

        MaxRowsAtCompileTime = _LhsNested::MaxRowsAtCompileTime,
        MaxColsAtCompileTime = _RhsNested::MaxColsAtCompileTime,

        // Non-zero only when BOTH operands carry RowMajorBit.
        EvalToRowMajor = (RhsFlags & LhsFlags & RowMajorBit),
        // Only the SkylineTimeSkylineProduct mode yields a skyline result.
        ResultIsSkyline = ProductMode == SkylineTimeSkylineProduct,

        // Mask that clears RowMajorBit unless EvalToRowMajor is set, and
        // clears SkylineBit unless ResultIsSkyline is set.
        RemovedBits = ~((EvalToRowMajor ? 0 : RowMajorBit) | (ResultIsSkyline ? 0 : SkylineBit)),

        // Hereditary bits of either operand, filtered through RemovedBits,
        // plus EvalBeforeAssigningBit and EvalBeforeNestingBit.
        Flags = (int(LhsFlags | RhsFlags) & HereditaryBits & RemovedBits)
        | EvalBeforeAssigningBit
        | EvalBeforeNestingBit,

        // The per-coefficient read cost is reported as Dynamic.
        CoeffReadCost = Dynamic
    };

    // A skyline result derives from SkylineMatrixBase; otherwise MatrixBase.
    typedef typename internal::conditional<ResultIsSkyline,
            SkylineMatrixBase<SkylineProduct<LhsNested, RhsNested, ProductMode> >,
            MatrixBase<SkylineProduct<LhsNested, RhsNested, ProductMode> > >::type Base;
};
     59 
     60 namespace internal {
/** \internal
  * Expression node representing the product lhs * rhs. It stores the two
  * nested operands (m_lhs, m_rhs), validates their dimensions on
  * construction, and reports the result size (lhs rows x rhs cols). It
  * performs no arithmetic itself.
  *
  * The base class is selected by traits<SkylineProduct>::Base. The
  * no_assignment_operator base presumably disables assignment to the
  * expression.
  */
template<typename LhsNested, typename RhsNested, int ProductMode>
class SkylineProduct : no_assignment_operator,
public traits<SkylineProduct<LhsNested, RhsNested, ProductMode> >::Base {
public:

    // Presumably declares the usual public typedefs (Scalar, Index, ...)
    // used below, e.g. Index in rows()/cols() -- TODO confirm against the macro.
    EIGEN_GENERIC_PUBLIC_INTERFACE(SkylineProduct)

private:

    // Operand types with qualifiers removed (see traits<SkylineProduct>).
    typedef typename traits<SkylineProduct>::_LhsNested _LhsNested;
    typedef typename traits<SkylineProduct>::_RhsNested _RhsNested;

public:

    /** Builds the product expression from \a lhs and \a rhs.
      * Runtime check: lhs.cols() must equal rhs.rows().
      * Compile-time check: the inner dimensions must agree unless at least
      * one of them is Dynamic. Otherwise a diagnostic is emitted that hints
      * at the explicit dot / coefficient-wise product functions.
      */
    template<typename Lhs, typename Rhs>
    EIGEN_STRONG_INLINE SkylineProduct(const Lhs& lhs, const Rhs& rhs)
    : m_lhs(lhs), m_rhs(rhs) {
        eigen_assert(lhs.cols() == rhs.rows());

        enum {
            // Inner sizes agree, or cannot be checked at compile time.
            ProductIsValid = _LhsNested::ColsAtCompileTime == Dynamic
            || _RhsNested::RowsAtCompileTime == Dynamic
            || int(_LhsNested::ColsAtCompileTime) == int(_RhsNested::RowsAtCompileTime),
            AreVectors = _LhsNested::IsVectorAtCompileTime && _RhsNested::IsVectorAtCompileTime,
            SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(_LhsNested, _RhsNested)
        };
        // note to the lost user:
        //    * for a dot product use: v1.dot(v2)
        //    * for a coeff-wise product use: v1.cwiseProduct(v2)
        // NOTE(review): the three asserts below have no separating semicolons;
        // this relies on EIGEN_STATIC_ASSERT's expansion terminating each
        // statement -- confirm against StaticAssert.h.
        EIGEN_STATIC_ASSERT(ProductIsValid || !(AreVectors && SameSizes),
                INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS)
        EIGEN_STATIC_ASSERT(ProductIsValid || !(SameSizes && !AreVectors),
                INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION)
        EIGEN_STATIC_ASSERT(ProductIsValid || SameSizes, INVALID_MATRIX_PRODUCT)
    }

    /** Number of rows of the product (the row count of the left operand). */
    EIGEN_STRONG_INLINE Index rows() const {
        return m_lhs.rows();
    }

    /** Number of columns of the product (the column count of the right operand). */
    EIGEN_STRONG_INLINE Index cols() const {
        return m_rhs.cols();
    }

    /** The stored left operand. */
    EIGEN_STRONG_INLINE const _LhsNested& lhs() const {
        return m_lhs;
    }

    /** The stored right operand. */
    EIGEN_STRONG_INLINE const _RhsNested& rhs() const {
        return m_rhs;
    }

protected:
    // Operands are held as their nested types (as passed in the template arguments).
    LhsNested m_lhs;
    RhsNested m_rhs;
};
    117 
    118 // dense = skyline * dense
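// Note that here we force no inlining (EIGEN_DONT_INLINE) because GCC messes up otherwise.
// Instead of a separate setZero(), dst is initialized by the diagonal pass (plain assignment)
// before the off-diagonal passes accumulate into it with +=.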
    120 
    121 template<typename Lhs, typename Rhs, typename Dest>
    122 EIGEN_DONT_INLINE void skyline_row_major_time_dense_product(const Lhs& lhs, const Rhs& rhs, Dest& dst) {
    123     typedef typename remove_all<Lhs>::type _Lhs;
    124     typedef typename remove_all<Rhs>::type _Rhs;
    125     typedef typename traits<Lhs>::Scalar Scalar;
    126 
    127     enum {
    128         LhsIsRowMajor = (_Lhs::Flags & RowMajorBit) == RowMajorBit,
    129         LhsIsSelfAdjoint = (_Lhs::Flags & SelfAdjointBit) == SelfAdjointBit,
    130         ProcessFirstHalf = LhsIsSelfAdjoint
    131         && (((_Lhs::Flags & (UpperTriangularBit | LowerTriangularBit)) == 0)
    132         || ((_Lhs::Flags & UpperTriangularBit) && !LhsIsRowMajor)
    133         || ((_Lhs::Flags & LowerTriangularBit) && LhsIsRowMajor)),
    134         ProcessSecondHalf = LhsIsSelfAdjoint && (!ProcessFirstHalf)
    135     };
    136 
    137     //Use matrix diagonal part <- Improvement : use inner iterator on dense matrix.
    138     for (Index col = 0; col < rhs.cols(); col++) {
    139         for (Index row = 0; row < lhs.rows(); row++) {
    140             dst(row, col) = lhs.coeffDiag(row) * rhs(row, col);
    141         }
    142     }
    143     //Use matrix lower triangular part
    144     for (Index row = 0; row < lhs.rows(); row++) {
    145         typename _Lhs::InnerLowerIterator lIt(lhs, row);
    146         const Index stop = lIt.col() + lIt.size();
    147         for (Index col = 0; col < rhs.cols(); col++) {
    148 
    149             Index k = lIt.col();
    150             Scalar tmp = 0;
    151             while (k < stop) {
    152                 tmp +=
    153                         lIt.value() *
    154                         rhs(k++, col);
    155                 ++lIt;
    156             }
    157             dst(row, col) += tmp;
    158             lIt += -lIt.size();
    159         }
    160 
    161     }
    162 
    163     //Use matrix upper triangular part
    164     for (Index lhscol = 0; lhscol < lhs.cols(); lhscol++) {
    165         typename _Lhs::InnerUpperIterator uIt(lhs, lhscol);
    166         const Index stop = uIt.size() + uIt.row();
    167         for (Index rhscol = 0; rhscol < rhs.cols(); rhscol++) {
    168 
    169 
    170             const Scalar rhsCoeff = rhs.coeff(lhscol, rhscol);
    171             Index k = uIt.row();
    172             while (k < stop) {
    173                 dst(k++, rhscol) +=
    174                         uIt.value() *
    175                         rhsCoeff;
    176                 ++uIt;
    177             }
    178             uIt += -uIt.size();
    179         }
    180     }
    181 
    182 }
    183 
    184 template<typename Lhs, typename Rhs, typename Dest>
    185 EIGEN_DONT_INLINE void skyline_col_major_time_dense_product(const Lhs& lhs, const Rhs& rhs, Dest& dst) {
    186     typedef typename remove_all<Lhs>::type _Lhs;
    187     typedef typename remove_all<Rhs>::type _Rhs;
    188     typedef typename traits<Lhs>::Scalar Scalar;
    189 
    190     enum {
    191         LhsIsRowMajor = (_Lhs::Flags & RowMajorBit) == RowMajorBit,
    192         LhsIsSelfAdjoint = (_Lhs::Flags & SelfAdjointBit) == SelfAdjointBit,
    193         ProcessFirstHalf = LhsIsSelfAdjoint
    194         && (((_Lhs::Flags & (UpperTriangularBit | LowerTriangularBit)) == 0)
    195         || ((_Lhs::Flags & UpperTriangularBit) && !LhsIsRowMajor)
    196         || ((_Lhs::Flags & LowerTriangularBit) && LhsIsRowMajor)),
    197         ProcessSecondHalf = LhsIsSelfAdjoint && (!ProcessFirstHalf)
    198     };
    199 
    200     //Use matrix diagonal part <- Improvement : use inner iterator on dense matrix.
    201     for (Index col = 0; col < rhs.cols(); col++) {
    202         for (Index row = 0; row < lhs.rows(); row++) {
    203             dst(row, col) = lhs.coeffDiag(row) * rhs(row, col);
    204         }
    205     }
    206 
    207     //Use matrix upper triangular part
    208     for (Index row = 0; row < lhs.rows(); row++) {
    209         typename _Lhs::InnerUpperIterator uIt(lhs, row);
    210         const Index stop = uIt.col() + uIt.size();
    211         for (Index col = 0; col < rhs.cols(); col++) {
    212 
    213             Index k = uIt.col();
    214             Scalar tmp = 0;
    215             while (k < stop) {
    216                 tmp +=
    217                         uIt.value() *
    218                         rhs(k++, col);
    219                 ++uIt;
    220             }
    221 
    222 
    223             dst(row, col) += tmp;
    224             uIt += -uIt.size();
    225         }
    226     }
    227 
    228     //Use matrix lower triangular part
    229     for (Index lhscol = 0; lhscol < lhs.cols(); lhscol++) {
    230         typename _Lhs::InnerLowerIterator lIt(lhs, lhscol);
    231         const Index stop = lIt.size() + lIt.row();
    232         for (Index rhscol = 0; rhscol < rhs.cols(); rhscol++) {
    233 
    234             const Scalar rhsCoeff = rhs.coeff(lhscol, rhscol);
    235             Index k = lIt.row();
    236             while (k < stop) {
    237                 dst(k++, rhscol) +=
    238                         lIt.value() *
    239                         rhsCoeff;
    240                 ++lIt;
    241             }
    242             lIt += -lIt.size();
    243         }
    244     }
    245 
    246 }
    247 
    248 template<typename Lhs, typename Rhs, typename ResultType,
    249         int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit>
    250         struct skyline_product_selector;
    251 
    252 template<typename Lhs, typename Rhs, typename ResultType>
    253 struct skyline_product_selector<Lhs, Rhs, ResultType, RowMajor> {
    254     typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar;
    255 
    256     static void run(const Lhs& lhs, const Rhs& rhs, ResultType & res) {
    257         skyline_row_major_time_dense_product<Lhs, Rhs, ResultType > (lhs, rhs, res);
    258     }
    259 };
    260 
    261 template<typename Lhs, typename Rhs, typename ResultType>
    262 struct skyline_product_selector<Lhs, Rhs, ResultType, ColMajor> {
    263     typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar;
    264 
    265     static void run(const Lhs& lhs, const Rhs& rhs, ResultType & res) {
    266         skyline_col_major_time_dense_product<Lhs, Rhs, ResultType > (lhs, rhs, res);
    267     }
    268 };
    269 
    270 } // end namespace internal
    271 
    272 // template<typename Derived>
    273 // template<typename Lhs, typename Rhs >
    274 // Derived & MatrixBase<Derived>::lazyAssign(const SkylineProduct<Lhs, Rhs, SkylineTimeDenseProduct>& product) {
    275 //     typedef typename internal::remove_all<Lhs>::type _Lhs;
    276 //     internal::skyline_product_selector<typename internal::remove_all<Lhs>::type,
    277 //             typename internal::remove_all<Rhs>::type,
    278 //             Derived>::run(product.lhs(), product.rhs(), derived());
    279 //
    280 //     return derived();
    281 // }
    282 
    283 // skyline * dense
    284 
    285 template<typename Derived>
    286 template<typename OtherDerived >
    287 EIGEN_STRONG_INLINE const typename SkylineProductReturnType<Derived, OtherDerived>::Type
    288 SkylineMatrixBase<Derived>::operator*(const MatrixBase<OtherDerived> &other) const {
    289 
    290     return typename SkylineProductReturnType<Derived, OtherDerived>::Type(derived(), other.derived());
    291 }
    292 
    293 } // end namespace Eigen
    294 
    295 #endif // EIGEN_SKYLINEPRODUCT_H
    296