Home | History | Annotate | Download | only in products
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2009 Gael Guennebaud <gael.guennebaud (at) inria.fr>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_SELFADJOINT_PRODUCT_H
     11 #define EIGEN_SELFADJOINT_PRODUCT_H
     12 
/**********************************************************************
* This file implements a self-adjoint rank update: C += alpha * A * A^*
* (that is, A * A^T for real scalars), updating only one triangular
* half of the self-adjoint matrix C.
* It corresponds to the level 3 SYRK and level 2 SYR BLAS routines.
**********************************************************************/
     18 
     19 namespace Eigen {
     20 
     21 
     22 template<typename Scalar, typename Index, int UpLo, bool ConjLhs, bool ConjRhs>
     23 struct selfadjoint_rank1_update<Scalar,Index,ColMajor,UpLo,ConjLhs,ConjRhs>
     24 {
     25   static void run(Index size, Scalar* mat, Index stride, const Scalar* vecX, const Scalar* vecY, const Scalar& alpha)
     26   {
     27     internal::conj_if<ConjRhs> cj;
     28     typedef Map<const Matrix<Scalar,Dynamic,1> > OtherMap;
     29     typedef typename internal::conditional<ConjLhs,typename OtherMap::ConjugateReturnType,const OtherMap&>::type ConjLhsType;
     30     for (Index i=0; i<size; ++i)
     31     {
     32       Map<Matrix<Scalar,Dynamic,1> >(mat+stride*i+(UpLo==Lower ? i : 0), (UpLo==Lower ? size-i : (i+1)))
     33           += (alpha * cj(vecY[i])) * ConjLhsType(OtherMap(vecX+(UpLo==Lower ? i : 0),UpLo==Lower ? size-i : (i+1)));
     34     }
     35   }
     36 };
     37 
     38 template<typename Scalar, typename Index, int UpLo, bool ConjLhs, bool ConjRhs>
     39 struct selfadjoint_rank1_update<Scalar,Index,RowMajor,UpLo,ConjLhs,ConjRhs>
     40 {
     41   static void run(Index size, Scalar* mat, Index stride, const Scalar* vecX, const Scalar* vecY, const Scalar& alpha)
     42   {
     43     selfadjoint_rank1_update<Scalar,Index,ColMajor,UpLo==Lower?Upper:Lower,ConjRhs,ConjLhs>::run(size,mat,stride,vecY,vecX,alpha);
     44   }
     45 };
     46 
     47 template<typename MatrixType, typename OtherType, int UpLo, bool OtherIsVector = OtherType::IsVectorAtCompileTime>
     48 struct selfadjoint_product_selector;
     49 
     50 template<typename MatrixType, typename OtherType, int UpLo>
     51 struct selfadjoint_product_selector<MatrixType,OtherType,UpLo,true>
     52 {
     53   static void run(MatrixType& mat, const OtherType& other, const typename MatrixType::Scalar& alpha)
     54   {
     55     typedef typename MatrixType::Scalar Scalar;
     56     typedef internal::blas_traits<OtherType> OtherBlasTraits;
     57     typedef typename OtherBlasTraits::DirectLinearAccessType ActualOtherType;
     58     typedef typename internal::remove_all<ActualOtherType>::type _ActualOtherType;
     59     typename internal::add_const_on_value_type<ActualOtherType>::type actualOther = OtherBlasTraits::extract(other.derived());
     60 
     61     Scalar actualAlpha = alpha * OtherBlasTraits::extractScalarFactor(other.derived());
     62 
     63     enum {
     64       StorageOrder = (internal::traits<MatrixType>::Flags&RowMajorBit) ? RowMajor : ColMajor,
     65       UseOtherDirectly = _ActualOtherType::InnerStrideAtCompileTime==1
     66     };
     67     internal::gemv_static_vector_if<Scalar,OtherType::SizeAtCompileTime,OtherType::MaxSizeAtCompileTime,!UseOtherDirectly> static_other;
     68 
     69     ei_declare_aligned_stack_constructed_variable(Scalar, actualOtherPtr, other.size(),
     70       (UseOtherDirectly ? const_cast<Scalar*>(actualOther.data()) : static_other.data()));
     71 
     72     if(!UseOtherDirectly)
     73       Map<typename _ActualOtherType::PlainObject>(actualOtherPtr, actualOther.size()) = actualOther;
     74 
     75     selfadjoint_rank1_update<Scalar,Index,StorageOrder,UpLo,
     76                               OtherBlasTraits::NeedToConjugate  && NumTraits<Scalar>::IsComplex,
     77                             (!OtherBlasTraits::NeedToConjugate) && NumTraits<Scalar>::IsComplex>
     78           ::run(other.size(), mat.data(), mat.outerStride(), actualOtherPtr, actualOtherPtr, actualAlpha);
     79   }
     80 };
     81 
     82 template<typename MatrixType, typename OtherType, int UpLo>
     83 struct selfadjoint_product_selector<MatrixType,OtherType,UpLo,false>
     84 {
     85   static void run(MatrixType& mat, const OtherType& other, const typename MatrixType::Scalar& alpha)
     86   {
     87     typedef typename MatrixType::Scalar Scalar;
     88     typedef internal::blas_traits<OtherType> OtherBlasTraits;
     89     typedef typename OtherBlasTraits::DirectLinearAccessType ActualOtherType;
     90     typedef typename internal::remove_all<ActualOtherType>::type _ActualOtherType;
     91     typename internal::add_const_on_value_type<ActualOtherType>::type actualOther = OtherBlasTraits::extract(other.derived());
     92 
     93     Scalar actualAlpha = alpha * OtherBlasTraits::extractScalarFactor(other.derived());
     94 
     95     enum {
     96       IsRowMajor = (internal::traits<MatrixType>::Flags&RowMajorBit) ? 1 : 0,
     97       OtherIsRowMajor = _ActualOtherType::Flags&RowMajorBit ? 1 : 0
     98     };
     99 
    100     Index size = mat.cols();
    101     Index depth = actualOther.cols();
    102 
    103     typedef internal::gemm_blocking_space<IsRowMajor ? RowMajor : ColMajor,Scalar,Scalar,
    104               MatrixType::MaxColsAtCompileTime, MatrixType::MaxColsAtCompileTime, _ActualOtherType::MaxColsAtCompileTime> BlockingType;
    105 
    106     BlockingType blocking(size, size, depth, 1, false);
    107 
    108 
    109     internal::general_matrix_matrix_triangular_product<Index,
    110       Scalar, OtherIsRowMajor ? RowMajor : ColMajor,   OtherBlasTraits::NeedToConjugate  && NumTraits<Scalar>::IsComplex,
    111       Scalar, OtherIsRowMajor ? ColMajor : RowMajor, (!OtherBlasTraits::NeedToConjugate) && NumTraits<Scalar>::IsComplex,
    112       IsRowMajor ? RowMajor : ColMajor, UpLo>
    113       ::run(size, depth,
    114             &actualOther.coeffRef(0,0), actualOther.outerStride(), &actualOther.coeffRef(0,0), actualOther.outerStride(),
    115             mat.data(), mat.outerStride(), actualAlpha, blocking);
    116   }
    117 };
    118 
    119 // high level API
    120 
    121 template<typename MatrixType, unsigned int UpLo>
    122 template<typename DerivedU>
    123 SelfAdjointView<MatrixType,UpLo>& SelfAdjointView<MatrixType,UpLo>
    124 ::rankUpdate(const MatrixBase<DerivedU>& u, const Scalar& alpha)
    125 {
    126   selfadjoint_product_selector<MatrixType,DerivedU,UpLo>::run(_expression().const_cast_derived(), u.derived(), alpha);
    127 
    128   return *this;
    129 }
    130 
    131 } // end namespace Eigen
    132 
    133 #endif // EIGEN_SELFADJOINT_PRODUCT_H
    134