1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2009 Gael Guennebaud <gael.guennebaud (at) inria.fr> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_TRIANGULARMATRIXVECTOR_H 11 #define EIGEN_TRIANGULARMATRIXVECTOR_H 12 13 namespace Eigen { 14 15 namespace internal { 16 17 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int StorageOrder, int Version=Specialized> 18 struct triangular_matrix_vector_product; 19 20 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version> 21 struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version> 22 { 23 typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar; 24 enum { 25 IsLower = ((Mode&Lower)==Lower), 26 HasUnitDiag = (Mode & UnitDiag)==UnitDiag, 27 HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag 28 }; 29 static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride, 30 const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, ResScalar alpha) 31 { 32 static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH; 33 Index size = (std::min)(_rows,_cols); 34 Index rows = IsLower ? _rows : (std::min)(_rows,_cols); 35 Index cols = IsLower ? 
(std::min)(_rows,_cols) : _cols; 36 37 typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap; 38 const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride)); 39 typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs); 40 41 typedef Map<const Matrix<RhsScalar,Dynamic,1>, 0, InnerStride<> > RhsMap; 42 const RhsMap rhs(_rhs,cols,InnerStride<>(rhsIncr)); 43 typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs); 44 45 typedef Map<Matrix<ResScalar,Dynamic,1> > ResMap; 46 ResMap res(_res,rows); 47 48 for (Index pi=0; pi<size; pi+=PanelWidth) 49 { 50 Index actualPanelWidth = (std::min)(PanelWidth, size-pi); 51 for (Index k=0; k<actualPanelWidth; ++k) 52 { 53 Index i = pi + k; 54 Index s = IsLower ? ((HasUnitDiag||HasZeroDiag) ? i+1 : i ) : pi; 55 Index r = IsLower ? actualPanelWidth-k : k+1; 56 if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0) 57 res.segment(s,r) += (alpha * cjRhs.coeff(i)) * cjLhs.col(i).segment(s,r); 58 if (HasUnitDiag) 59 res.coeffRef(i) += alpha * cjRhs.coeff(i); 60 } 61 Index r = IsLower ? rows - pi - actualPanelWidth : pi; 62 if (r>0) 63 { 64 Index s = IsLower ? 
pi+actualPanelWidth : 0; 65 general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjLhs,RhsScalar,ConjRhs,BuiltIn>::run( 66 r, actualPanelWidth, 67 &lhs.coeffRef(s,pi), lhsStride, 68 &rhs.coeffRef(pi), rhsIncr, 69 &res.coeffRef(s), resIncr, alpha); 70 } 71 } 72 if((!IsLower) && cols>size) 73 { 74 general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjLhs,RhsScalar,ConjRhs>::run( 75 rows, cols-size, 76 &lhs.coeffRef(0,size), lhsStride, 77 &rhs.coeffRef(size), rhsIncr, 78 _res, resIncr, alpha); 79 } 80 } 81 }; 82 83 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version> 84 struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version> 85 { 86 typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar; 87 enum { 88 IsLower = ((Mode&Lower)==Lower), 89 HasUnitDiag = (Mode & UnitDiag)==UnitDiag, 90 HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag 91 }; 92 static void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride, 93 const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, ResScalar alpha) 94 { 95 static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH; 96 Index diagSize = (std::min)(_rows,_cols); 97 Index rows = IsLower ? _rows : diagSize; 98 Index cols = IsLower ? 
diagSize : _cols; 99 100 typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap; 101 const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride)); 102 typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs); 103 104 typedef Map<const Matrix<RhsScalar,Dynamic,1> > RhsMap; 105 const RhsMap rhs(_rhs,cols); 106 typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs); 107 108 typedef Map<Matrix<ResScalar,Dynamic,1>, 0, InnerStride<> > ResMap; 109 ResMap res(_res,rows,InnerStride<>(resIncr)); 110 111 for (Index pi=0; pi<diagSize; pi+=PanelWidth) 112 { 113 Index actualPanelWidth = (std::min)(PanelWidth, diagSize-pi); 114 for (Index k=0; k<actualPanelWidth; ++k) 115 { 116 Index i = pi + k; 117 Index s = IsLower ? pi : ((HasUnitDiag||HasZeroDiag) ? i+1 : i); 118 Index r = IsLower ? k+1 : actualPanelWidth-k; 119 if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0) 120 res.coeffRef(i) += alpha * (cjLhs.row(i).segment(s,r).cwiseProduct(cjRhs.segment(s,r).transpose())).sum(); 121 if (HasUnitDiag) 122 res.coeffRef(i) += alpha * cjRhs.coeff(i); 123 } 124 Index r = IsLower ? pi : cols - pi - actualPanelWidth; 125 if (r>0) 126 { 127 Index s = IsLower ? 
0 : pi + actualPanelWidth; 128 general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjLhs,RhsScalar,ConjRhs,BuiltIn>::run( 129 actualPanelWidth, r, 130 &lhs.coeffRef(pi,s), lhsStride, 131 &rhs.coeffRef(s), rhsIncr, 132 &res.coeffRef(pi), resIncr, alpha); 133 } 134 } 135 if(IsLower && rows>diagSize) 136 { 137 general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjLhs,RhsScalar,ConjRhs>::run( 138 rows-diagSize, cols, 139 &lhs.coeffRef(diagSize,0), lhsStride, 140 &rhs.coeffRef(0), rhsIncr, 141 &res.coeffRef(diagSize), resIncr, alpha); 142 } 143 } 144 }; 145 146 /*************************************************************************** 147 * Wrapper to product_triangular_vector 148 ***************************************************************************/ 149 150 template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs> 151 struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true> > 152 : traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true>, Lhs, Rhs> > 153 {}; 154 155 template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs> 156 struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false> > 157 : traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false>, Lhs, Rhs> > 158 {}; 159 160 161 template<int StorageOrder> 162 struct trmv_selector; 163 164 } // end namespace internal 165 166 template<int Mode, typename Lhs, typename Rhs> 167 struct TriangularProduct<Mode,true,Lhs,false,Rhs,true> 168 : public ProductBase<TriangularProduct<Mode,true,Lhs,false,Rhs,true>, Lhs, Rhs > 169 { 170 EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct) 171 172 TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {} 173 174 template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const 175 { 176 eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols()); 177 178 internal::trmv_selector<(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? 
RowMajor : ColMajor>::run(*this, dst, alpha); 179 } 180 }; 181 182 template<int Mode, typename Lhs, typename Rhs> 183 struct TriangularProduct<Mode,false,Lhs,true,Rhs,false> 184 : public ProductBase<TriangularProduct<Mode,false,Lhs,true,Rhs,false>, Lhs, Rhs > 185 { 186 EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct) 187 188 TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {} 189 190 template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const 191 { 192 eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols()); 193 194 typedef TriangularProduct<(Mode & (UnitDiag|ZeroDiag)) | ((Mode & Lower) ? Upper : Lower),true,Transpose<const Rhs>,false,Transpose<const Lhs>,true> TriangularProductTranspose; 195 Transpose<Dest> dstT(dst); 196 internal::trmv_selector<(int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>::run( 197 TriangularProductTranspose(m_rhs.transpose(),m_lhs.transpose()), dstT, alpha); 198 } 199 }; 200 201 namespace internal { 202 203 // TODO: find a way to factorize this piece of code with gemv_selector since the logic is exactly the same. 
/** \internal
  * Column-major driver for "triangular matrix * vector" products.
  * Scalar factors and conjugations are extracted from both operands through their BLAS traits,
  * then the column-major triangular_matrix_vector_product kernel is called with a unit-stride
  * result buffer. That buffer is \a dest itself when \a dest has a compile-time unit inner
  * stride and alpha is usable as a RhsScalar; otherwise a temporary is used and copied/added back.
  */
template<> struct trmv_selector<ColMajor>
{
  template<int Mode, typename Lhs, typename Rhs, typename Dest>
  static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar alpha)
  {
    typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
    typedef typename ProductType::Index Index;
    typedef typename ProductType::LhsScalar LhsScalar;
    typedef typename ProductType::RhsScalar RhsScalar;
    typedef typename ProductType::Scalar ResScalar;
    typedef typename ProductType::RealScalar RealScalar;
    typedef typename ProductType::ActualLhsType ActualLhsType;
    typedef typename ProductType::ActualRhsType ActualRhsType;
    typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
    typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
    typedef Map<Matrix<ResScalar,Dynamic,1>, Aligned> MappedDest;

    // Underlying operands, stripped of any scalar factor / conjugation wrapper.
    typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
    typename internal::add_const_on_value_type<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs());

    // Fold the scalar factors extracted from both operands into alpha.
    ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
                                  * RhsBlasTraits::extractScalarFactor(prod.rhs());

    enum {
      // FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1;
      // on the other hand, it is good for the cache to pack the vector anyways...
      EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1,
      // complex matrix times real vector: the kernel receives alpha as a RhsScalar (real).
      ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex),
      MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal
    };

    // Statically-sized scratch storage, only requested when dest might not be usable directly.
    gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;

    // A complex alpha with a non-zero imaginary part cannot be represented as a real RhsScalar.
    bool alphaIsCompatible = (!ComplexByReal) || (imag(actualAlpha)==RealScalar(0));
    bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;

    RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);

    // actualDestPtr points either at dest's own data or at a temporary buffer (the macro
    // presumably allocates one when given a null buffer -- see its definition).
    ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(),
                                                  evalToDest ? dest.data() : static_dest.data());

    if(!evalToDest)
    {
      #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
      // `size` exists only for the plugin macro below, which presumably refers to it.
      int size = dest.size();
      EIGEN_DENSE_STORAGE_CTOR_PLUGIN
      #endif
      if(!alphaIsCompatible)
      {
        // Compute lhs*rhs alone into a zeroed buffer; the complex alpha is applied afterwards.
        MappedDest(actualDestPtr, dest.size()).setZero();
        compatibleAlpha = RhsScalar(1);
      }
      else
        // Strided destination: accumulate into a contiguous copy of dest.
        MappedDest(actualDestPtr, dest.size()) = dest;
    }

    internal::triangular_matrix_vector_product
      <Index,Mode,
       LhsScalar, LhsBlasTraits::NeedToConjugate,
       RhsScalar, RhsBlasTraits::NeedToConjugate,
       ColMajor>
      ::run(actualLhs.rows(),actualLhs.cols(),
            actualLhs.data(),actualLhs.outerStride(),
            actualRhs.data(),actualRhs.innerStride(),
            actualDestPtr,1,compatibleAlpha);

    // Write the temporary result back into dest when one was used.
    if (!evalToDest)
    {
      if(!alphaIsCompatible)
        dest += actualAlpha * MappedDest(actualDestPtr, dest.size());
      else
        dest = MappedDest(actualDestPtr, dest.size());
    }
  }
};

/** \internal
  * Row-major driver for "triangular matrix * vector" products.
  * The row-major kernel is called with a unit-stride rhs: when the rhs does not have a
  * compile-time unit inner stride, it is first copied into a contiguous temporary. The
  * destination is passed with its own inner stride, and alpha (including the scalar factors
  * extracted from the operands) is passed unchanged.
  */
template<> struct trmv_selector<RowMajor>
{
  template<int Mode, typename Lhs, typename Rhs, typename Dest>
  static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar alpha)
  {
    typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
    typedef typename ProductType::LhsScalar LhsScalar;
    typedef typename ProductType::RhsScalar RhsScalar;
    typedef typename ProductType::Scalar ResScalar;
    typedef typename ProductType::Index Index;
    typedef typename ProductType::ActualLhsType ActualLhsType;
    typedef typename ProductType::ActualRhsType ActualRhsType;
    typedef typename ProductType::_ActualRhsType _ActualRhsType;
    typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
    typedef typename ProductType::RhsBlasTraits RhsBlasTraits;

    // Underlying operands, stripped of any scalar factor / conjugation wrapper.
    typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
    typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs());

    // Fold the scalar factors extracted from both operands into alpha.
    ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
                                  * RhsBlasTraits::extractScalarFactor(prod.rhs());

    enum {
      // The rhs can be handed to the kernel as-is only if it is contiguous.
      DirectlyUseRhs = _ActualRhsType::InnerStrideAtCompileTime==1
    };

    // Statically-sized scratch storage, only requested when the rhs must be packed.
    gemv_static_vector_if<RhsScalar,_ActualRhsType::SizeAtCompileTime,_ActualRhsType::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;

    // actualRhsPtr points either at the rhs's own data or at a temporary buffer (the macro
    // presumably allocates one when given a null buffer -- see its definition).
    ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,actualRhs.size(),
        DirectlyUseRhs ? const_cast<RhsScalar*>(actualRhs.data()) : static_rhs.data());

    if(!DirectlyUseRhs)
    {
      #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
      // `size` exists only for the plugin macro below, which presumably refers to it.
      int size = actualRhs.size();
      EIGEN_DENSE_STORAGE_CTOR_PLUGIN
      #endif
      // Pack the strided rhs into the contiguous buffer.
      Map<typename _ActualRhsType::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
    }

    internal::triangular_matrix_vector_product
      <Index,Mode,
       LhsScalar, LhsBlasTraits::NeedToConjugate,
       RhsScalar, RhsBlasTraits::NeedToConjugate,
       RowMajor>
      ::run(actualLhs.rows(),actualLhs.cols(),
            actualLhs.data(),actualLhs.outerStride(),
            actualRhsPtr,1,
            dest.data(),dest.innerStride(),
            actualAlpha);
  }
};

} // end namespace internal

} // end namespace Eigen

#endif // EIGEN_TRIANGULARMATRIXVECTOR_H