// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2009 Gael Guennebaud <gael.guennebaud (at) inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_TRIANGULARMATRIXVECTOR_H
#define EIGEN_TRIANGULARMATRIXVECTOR_H

namespace Eigen {

namespace internal {

// Low-level triangular-matrix * vector kernel: accumulates
//   res += alpha * T * rhs
// where T is the triangular/trapezoidal part of lhs selected by Mode
// (Lower/Upper, optionally UnitDiag or ZeroDiag).
// Specialized below for ColMajor and RowMajor lhs storage.
// NOTE(review): the Version parameter is not referenced by the bodies in this
// file; presumably it exists so other (backend-specific) specializations can be
// selected -- confirm.
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int StorageOrder, int Version=Specialized>
struct triangular_matrix_vector_product;

// Column-major lhs: the product is accumulated column by column, i.e. each
// rhs coefficient scales a slice of an lhs column that is added into res.
// The diagonal is processed in panels of EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH
// columns: the small triangular piece of each panel is handled coefficient by
// coefficient, the rectangular remainder of the panel with one gemv call.
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version>
{
  typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
  enum {
    IsLower = ((Mode&Lower)==Lower),               // use the lower triangle of lhs (else the upper one)
    HasUnitDiag = (Mode & UnitDiag)==UnitDiag,     // diagonal is taken as 1; lhs diagonal is never read
    HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag      // diagonal is taken as 0; lhs diagonal is never read
  };
  static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
                                    const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha);
};

// Computes res += alpha * T * rhs for a column-major lhs.
//   _rows, _cols : dimensions of the lhs array. A lower T may be taller than
//                  wide (extra rows are handled by the panel gemv); an upper T
//                  may be wider than tall (extra columns are handled by the
//                  trailing gemv).
//   _lhs, lhsStride : column-major lhs data and its outer stride.
//   _rhs, rhsIncr   : input vector and its element stride (honoured everywhere).
//   _res, resIncr   : accumulated output vector.
//   alpha           : scale applied to every contribution.
// ConjLhs / ConjRhs select conjugation; they are forwarded to conj_expr_if and
// to the gemv calls.
// NOTE: the in-panel updates map _res with unit stride, while resIncr is only
// forwarded to the gemv calls. The kernel is therefore only consistent for
// resIncr == 1, which is what trmv_selector<ColMajor> in this file passes.
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version>
  ::run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
        const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha)
{
  static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
  // Length of the diagonal. Lower keeps all rows; upper keeps all columns.
  Index size = (std::min)(_rows,_cols);
  Index rows = IsLower ? _rows : (std::min)(_rows,_cols);
  Index cols = IsLower ? (std::min)(_rows,_cols) : _cols;

  typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
  const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
  typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);

  typedef Map<const Matrix<RhsScalar,Dynamic,1>, 0, InnerStride<> > RhsMap;
  const RhsMap rhs(_rhs,cols,InnerStride<>(rhsIncr));
  typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);

  // Unit-stride view of the result (see the resIncr note above).
  typedef Map<Matrix<ResScalar,Dynamic,1> > ResMap;
  ResMap res(_res,rows);

  for (Index pi=0; pi<size; pi+=PanelWidth)
  {
    Index actualPanelWidth = (std::min)(PanelWidth, size-pi);
    // Triangular piece of the diagonal panel [pi, pi+actualPanelWidth).
    for (Index k=0; k<actualPanelWidth; ++k)
    {
      Index i = pi + k;
      // Rows [s, s+r) of column i that lie inside the panel on the stored
      // side of the diagonal. Lower: from the diagonal (or just below it)
      // down to the panel end. Upper: from the panel start down to the diagonal.
      Index s = IsLower ? ((HasUnitDiag||HasZeroDiag) ? i+1 : i ) : pi;
      Index r = IsLower ? actualPanelWidth-k : k+1;
      // With a unit/zero diagonal, --r drops the diagonal entry from the
      // slice. It is only evaluated in that case, because of short-circuiting.
      if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
        res.segment(s,r) += (alpha * cjRhs.coeff(i)) * cjLhs.col(i).segment(s,r);
      // Implicit unit diagonal: contributes 1 * rhs(i) to res(i).
      if (HasUnitDiag)
        res.coeffRef(i) += alpha * cjRhs.coeff(i);
    }
    // Rectangular part of the panel's columns: all rows below the panel
    // (lower) or all rows above it (upper), done with a single gemv.
    Index r = IsLower ? rows - pi - actualPanelWidth : pi;
    if (r>0)
    {
      Index s = IsLower ? pi+actualPanelWidth : 0;
      general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjLhs,RhsScalar,ConjRhs,BuiltIn>::run(
          r, actualPanelWidth,
          &lhs.coeffRef(s,pi), lhsStride,
          &rhs.coeffRef(pi), rhsIncr,
          &res.coeffRef(s), resIncr, alpha);
    }
  }
  // Upper trapezoid wider than tall: the trailing columns [size, cols) are
  // fully dense and handled by one gemv over all rows.
  if((!IsLower) && cols>size)
  {
    general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjLhs,RhsScalar,ConjRhs>::run(
        rows, cols-size,
        &lhs.coeffRef(0,size), lhsStride,
        &rhs.coeffRef(size), rhsIncr,
        _res, resIncr, alpha);
  }
}

// Row-major lhs: the product is accumulated row by row, i.e. each res
// coefficient receives the dot product of a slice of an lhs row with rhs.
// Uses the same panel scheme as the column-major kernel.
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version>
struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version>
{
  typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
  enum {
    IsLower = ((Mode&Lower)==Lower),               // use the lower triangle of lhs (else the upper one)
    HasUnitDiag = (Mode & UnitDiag)==UnitDiag,     // diagonal is taken as 1; lhs diagonal is never read
    HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag      // diagonal is taken as 0; lhs diagonal is never read
  };
  static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
                                    const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha);
};

// Computes res += alpha * T * rhs for a row-major lhs. Parameters are as in the
// column-major kernel, with these differences:
//   _res, resIncr : resIncr is honoured everywhere (res is mapped with it).
//   _rhs, rhsIncr : the in-panel dot products map _rhs with unit stride, while
//                   rhsIncr is only forwarded to the gemv calls. The kernel is
//                   therefore only consistent for rhsIncr == 1, which is what
//                   trmv_selector<RowMajor> in this file passes (after packing
//                   rhs when needed).
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version>
EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version>
  ::run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
        const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha)
{
  static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
  // Length of the diagonal. Lower keeps all rows; upper keeps all columns.
  Index diagSize = (std::min)(_rows,_cols);
  Index rows = IsLower ? _rows : diagSize;
  Index cols = IsLower ? diagSize : _cols;

  typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap;
  const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
  typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);

  // Unit-stride view of rhs (see the rhsIncr note above).
  typedef Map<const Matrix<RhsScalar,Dynamic,1> > RhsMap;
  const RhsMap rhs(_rhs,cols);
  typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);

  typedef Map<Matrix<ResScalar,Dynamic,1>, 0, InnerStride<> > ResMap;
  ResMap res(_res,rows,InnerStride<>(resIncr));

  for (Index pi=0; pi<diagSize; pi+=PanelWidth)
  {
    Index actualPanelWidth = (std::min)(PanelWidth, diagSize-pi);
    // Triangular piece of the diagonal panel [pi, pi+actualPanelWidth).
    for (Index k=0; k<actualPanelWidth; ++k)
    {
      Index i = pi + k;
      // Columns [s, s+r) of row i that lie inside the panel on the stored side
      // of the diagonal. Lower: from the panel start up to the diagonal.
      // Upper: from the diagonal (or just right of it) to the panel end.
      Index s = IsLower ? pi : ((HasUnitDiag||HasZeroDiag) ? i+1 : i);
      Index r = IsLower ? k+1 : actualPanelWidth-k;
      // With a unit/zero diagonal, --r drops the diagonal entry from the
      // slice. It is only evaluated in that case, because of short-circuiting.
      if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
        res.coeffRef(i) += alpha * (cjLhs.row(i).segment(s,r).cwiseProduct(cjRhs.segment(s,r).transpose())).sum();
      // Implicit unit diagonal: contributes 1 * rhs(i) to res(i).
      if (HasUnitDiag)
        res.coeffRef(i) += alpha * cjRhs.coeff(i);
    }
    // Rectangular part of the panel's rows: all columns left of the panel
    // (lower) or right of it (upper), done with a single gemv.
    Index r = IsLower ? pi : cols - pi - actualPanelWidth;
    if (r>0)
    {
      Index s = IsLower ? 0 : pi + actualPanelWidth;
      general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjLhs,RhsScalar,ConjRhs,BuiltIn>::run(
          actualPanelWidth, r,
          &lhs.coeffRef(pi,s), lhsStride,
          &rhs.coeffRef(s), rhsIncr,
          &res.coeffRef(pi), resIncr, alpha);
    }
  }
  // Lower trapezoid taller than wide: the trailing rows [diagSize, rows) are
  // fully dense and handled by one gemv over all columns.
  if(IsLower && rows>diagSize)
  {
    general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjLhs,RhsScalar,ConjRhs>::run(
        rows-diagSize, cols,
        &lhs.coeffRef(diagSize,0), lhsStride,
        &rhs.coeffRef(0), rhsIncr,
        &res.coeffRef(diagSize), resIncr, alpha);
  }
}

/***************************************************************************
* Wrapper to product_triangular_vector
***************************************************************************/

// The triangular-product expressions below reuse the traits of their
// ProductBase.
template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true> >
 : traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true>, Lhs, Rhs> >
{};

template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false> >
 : traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false>, Lhs, Rhs> >
{};


// Dispatches a triangular * vector product to the kernel matching the
// storage order of the triangular operand (specialized further below).
template<int StorageOrder>
struct trmv_selector;

} // end namespace internal

// Triangular lhs (LhsIsTriangular == true) times rhs.
// NOTE(review): the trailing false/true arguments presumably mark rhs as the
// vector operand -- confirm against TriangularProduct's primary template.
template<int Mode, typename Lhs, typename Rhs>
struct TriangularProduct<Mode,true,Lhs,false,Rhs,true>
  : public ProductBase<TriangularProduct<Mode,true,Lhs,false,Rhs,true>, Lhs, Rhs >
{
  EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct)

  TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}

  // dst += alpha * lhs * rhs. Chooses the selector from lhs's storage order.
  template<typename Dest> void scaleAndAddTo(Dest& dst, const Scalar& alpha) const
  {
    eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());

    internal::trmv_selector<(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(*this, dst, alpha);
  }
};

// Triangular rhs (LhsIsTriangular == false) with lhs as the (presumably row)
// vector operand. This is evaluated through the transposed identity
// (lhs * rhs)^T == rhs^T * lhs^T, so it reuses the triangular-lhs path above.
template<int Mode, typename Lhs, typename Rhs>
struct TriangularProduct<Mode,false,Lhs,true,Rhs,false>
  : public ProductBase<TriangularProduct<Mode,false,Lhs,true,Rhs,false>, Lhs, Rhs >
{
  EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct)

  TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}

  // dst += alpha * lhs * rhs, computed as dst^T += alpha * rhs^T * lhs^T.
  template<typename Dest> void scaleAndAddTo(Dest& dst, const Scalar& alpha) const
  {
    eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());

    // Transposing a triangular matrix swaps Lower and Upper. The UnitDiag and
    // ZeroDiag bits are preserved.
    typedef TriangularProduct<(Mode & (UnitDiag|ZeroDiag)) | ((Mode & Lower) ? Upper : Lower),true,Transpose<const Rhs>,false,Transpose<const Lhs>,true> TriangularProductTranspose;
    Transpose<Dest> dstT(dst);
    // The storage order is inverted because rhs^T has the opposite storage
    // order of rhs.
    internal::trmv_selector<(int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>::run(
      TriangularProductTranspose(m_rhs.transpose(),m_lhs.transpose()), dstT, alpha);
  }
};

namespace internal {

// TODO: find a way to factorize this piece of code with gemv_selector since the logic is exactly the same.
214 215 template<> struct trmv_selector<ColMajor> 216 { 217 template<int Mode, typename Lhs, typename Rhs, typename Dest> 218 static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, const typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar& alpha) 219 { 220 typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType; 221 typedef typename ProductType::Index Index; 222 typedef typename ProductType::LhsScalar LhsScalar; 223 typedef typename ProductType::RhsScalar RhsScalar; 224 typedef typename ProductType::Scalar ResScalar; 225 typedef typename ProductType::RealScalar RealScalar; 226 typedef typename ProductType::ActualLhsType ActualLhsType; 227 typedef typename ProductType::ActualRhsType ActualRhsType; 228 typedef typename ProductType::LhsBlasTraits LhsBlasTraits; 229 typedef typename ProductType::RhsBlasTraits RhsBlasTraits; 230 typedef Map<Matrix<ResScalar,Dynamic,1>, Aligned> MappedDest; 231 232 typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs()); 233 typename internal::add_const_on_value_type<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs()); 234 235 ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs()) 236 * RhsBlasTraits::extractScalarFactor(prod.rhs()); 237 238 enum { 239 // FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1 240 // on, the other hand it is good for the cache to pack the vector anyways... 
241 EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1, 242 ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex), 243 MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal 244 }; 245 246 gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest; 247 248 bool alphaIsCompatible = (!ComplexByReal) || (numext::imag(actualAlpha)==RealScalar(0)); 249 bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible; 250 251 RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha); 252 253 ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(), 254 evalToDest ? dest.data() : static_dest.data()); 255 256 if(!evalToDest) 257 { 258 #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN 259 Index size = dest.size(); 260 EIGEN_DENSE_STORAGE_CTOR_PLUGIN 261 #endif 262 if(!alphaIsCompatible) 263 { 264 MappedDest(actualDestPtr, dest.size()).setZero(); 265 compatibleAlpha = RhsScalar(1); 266 } 267 else 268 MappedDest(actualDestPtr, dest.size()) = dest; 269 } 270 271 internal::triangular_matrix_vector_product 272 <Index,Mode, 273 LhsScalar, LhsBlasTraits::NeedToConjugate, 274 RhsScalar, RhsBlasTraits::NeedToConjugate, 275 ColMajor> 276 ::run(actualLhs.rows(),actualLhs.cols(), 277 actualLhs.data(),actualLhs.outerStride(), 278 actualRhs.data(),actualRhs.innerStride(), 279 actualDestPtr,1,compatibleAlpha); 280 281 if (!evalToDest) 282 { 283 if(!alphaIsCompatible) 284 dest += actualAlpha * MappedDest(actualDestPtr, dest.size()); 285 else 286 dest = MappedDest(actualDestPtr, dest.size()); 287 } 288 } 289 }; 290 291 template<> struct trmv_selector<RowMajor> 292 { 293 template<int Mode, typename Lhs, typename Rhs, typename Dest> 294 static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, const typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar& alpha) 295 { 296 typedef 
TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType; 297 typedef typename ProductType::LhsScalar LhsScalar; 298 typedef typename ProductType::RhsScalar RhsScalar; 299 typedef typename ProductType::Scalar ResScalar; 300 typedef typename ProductType::Index Index; 301 typedef typename ProductType::ActualLhsType ActualLhsType; 302 typedef typename ProductType::ActualRhsType ActualRhsType; 303 typedef typename ProductType::_ActualRhsType _ActualRhsType; 304 typedef typename ProductType::LhsBlasTraits LhsBlasTraits; 305 typedef typename ProductType::RhsBlasTraits RhsBlasTraits; 306 307 typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs()); 308 typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs()); 309 310 ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs()) 311 * RhsBlasTraits::extractScalarFactor(prod.rhs()); 312 313 enum { 314 DirectlyUseRhs = _ActualRhsType::InnerStrideAtCompileTime==1 315 }; 316 317 gemv_static_vector_if<RhsScalar,_ActualRhsType::SizeAtCompileTime,_ActualRhsType::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs; 318 319 ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,actualRhs.size(), 320 DirectlyUseRhs ? 
const_cast<RhsScalar*>(actualRhs.data()) : static_rhs.data()); 321 322 if(!DirectlyUseRhs) 323 { 324 #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN 325 int size = actualRhs.size(); 326 EIGEN_DENSE_STORAGE_CTOR_PLUGIN 327 #endif 328 Map<typename _ActualRhsType::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs; 329 } 330 331 internal::triangular_matrix_vector_product 332 <Index,Mode, 333 LhsScalar, LhsBlasTraits::NeedToConjugate, 334 RhsScalar, RhsBlasTraits::NeedToConjugate, 335 RowMajor> 336 ::run(actualLhs.rows(),actualLhs.cols(), 337 actualLhs.data(),actualLhs.outerStride(), 338 actualRhsPtr,1, 339 dest.data(),dest.innerStride(), 340 actualAlpha); 341 } 342 }; 343 344 } // end namespace internal 345 346 } // end namespace Eigen 347 348 #endif // EIGEN_TRIANGULARMATRIXVECTOR_H 349