// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2008-2009 Gael Guennebaud <gael.guennebaud (at) inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_GENERAL_BLOCK_PANEL_H
#define EIGEN_GENERAL_BLOCK_PANEL_H

namespace Eigen {

namespace internal {

template<typename _LhsScalar, typename _RhsScalar, bool _ConjLhs=false, bool _ConjRhs=false>
class gebp_traits;


/** \internal \returns b if a<=0, and returns a otherwise. */
inline std::ptrdiff_t manage_caching_sizes_helper(std::ptrdiff_t a, std::ptrdiff_t b)
{
  return a<=0 ? b : a;
}

/** \internal */
inline void manage_caching_sizes(Action action, std::ptrdiff_t* l1=0, std::ptrdiff_t* l2=0)
{
  static std::ptrdiff_t m_l1CacheSize = 0;
  static std::ptrdiff_t m_l2CacheSize = 0;
  if(m_l2CacheSize==0)
  {
    m_l1CacheSize = manage_caching_sizes_helper(queryL1CacheSize(),8 * 1024);
    m_l2CacheSize = manage_caching_sizes_helper(queryTopLevelCacheSize(),1*1024*1024);
  }

  if(action==SetAction)
  {
    // set the cpu cache sizes (in bytes) from the user-provided values
    eigen_internal_assert(l1!=0 && l2!=0);
    m_l1CacheSize = *l1;
    m_l2CacheSize = *l2;
  }
  else if(action==GetAction)
  {
    eigen_internal_assert(l1!=0 && l2!=0);
    *l1 = m_l1CacheSize;
    *l2 = m_l2CacheSize;
  }
  else
  {
    eigen_internal_assert(false);
  }
}

/** \brief Computes the blocking parameters for an m x k times k x n matrix product
  *
  * \param[in,out] k Input: the third dimension of the product. Output: the blocking size along the same dimension.
  * \param[in,out] m Input: the number of rows of the left hand side. Output: the blocking size along the same dimension.
  * \param[in,out] n Input: the number of columns of the right hand side. Output: the blocking size along the same dimension.
  *
  * Given an m x k times k x n matrix product of scalar types \c LhsScalar and \c RhsScalar,
  * this function computes the blocking size parameters along the respective dimensions
  * for matrix products and related algorithms. The blocking sizes depend on various
  * parameters:
  * - the L1 and L2 cache sizes,
  * - the register level blocking sizes defined by gebp_traits,
  * - the number of scalars that fit into a packet (when vectorization is enabled).
  *
  * \sa setCpuCacheSizes */
template<typename LhsScalar, typename RhsScalar, int KcFactor, typename SizeType>
void computeProductBlockingSizes(SizeType& k, SizeType& m, SizeType& n)
{
  EIGEN_UNUSED_VARIABLE(n);
  // Explanations:
  // Let's recall that the product algorithm forms kc x nc horizontal panels B' on the rhs and
  // mc x kc blocks A' on the lhs. A' has to fit into the L2 cache. Moreover, B' is processed
  // per kc x nr vertical small panels, where nr is the blocking size along the n dimension
  // at the register level. For vectorization purposes, these small vertical panels are unpacked,
  // e.g., each coefficient is replicated to fill a packet. Such a small vertical panel has to
  // stay in the L1 cache.
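  //
  // Illustrative example (assumed figures, the real values depend on the target):
  // with float scalars, 4-wide packets, nr==4 and KcFactor==1, and the fallback
  // cache sizes l1 = 8*1024 and l2 = 1024*1024 used above, we get
  //   kdiv = 1 * 2 * 4 * 4 * sizeof(float) = 128,
  // so k is capped at l1/kdiv = 8192/128 = 64, and the candidate row blocking is
  //   _m = l2/(4 * sizeof(float) * k) = 1048576/(4*4*64) = 1024;
  // if _m is smaller than the requested m, m is replaced by _m rounded down to a
  // multiple of mr via mr_mask.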
  std::ptrdiff_t l1, l2;

  typedef gebp_traits<LhsScalar,RhsScalar> Traits;
  enum {
    kdiv = KcFactor * 2 * Traits::nr
         * Traits::RhsProgress * sizeof(RhsScalar),
    mr = gebp_traits<LhsScalar,RhsScalar>::mr,
    mr_mask = (0xffffffff/mr)*mr
  };

  manage_caching_sizes(GetAction, &l1, &l2);
  k = std::min<SizeType>(k, l1/kdiv);
  SizeType _m = k>0 ? l2/(4 * sizeof(LhsScalar) * k) : 0;
  if(_m<m) m = _m & mr_mask;
}

template<typename LhsScalar, typename RhsScalar, typename SizeType>
inline void computeProductBlockingSizes(SizeType& k, SizeType& m, SizeType& n)
{
  computeProductBlockingSizes<LhsScalar,RhsScalar,1>(k, m, n);
}

#ifdef EIGEN_HAS_FUSE_CJMADD
  #define MADD(CJ,A,B,C,T)  C = CJ.pmadd(A,B,C);
#else

  // FIXME (a bit overkill maybe ?)

  template<typename CJ, typename A, typename B, typename C, typename T> struct gebp_madd_selector {
    EIGEN_ALWAYS_INLINE static void run(const CJ& cj, A& a, B& b, C& c, T& /*t*/)
    {
      c = cj.pmadd(a,b,c);
    }
  };

  template<typename CJ, typename T> struct gebp_madd_selector<CJ,T,T,T,T> {
    EIGEN_ALWAYS_INLINE static void run(const CJ& cj, T& a, T& b, T& c, T& t)
    {
      t = b; t = cj.pmul(a,t); c = padd(c,t);
    }
  };

  template<typename CJ, typename A, typename B, typename C, typename T>
  EIGEN_STRONG_INLINE void gebp_madd(const CJ& cj, A& a, B& b, C& c, T& t)
  {
    gebp_madd_selector<CJ,A,B,C,T>::run(cj,a,b,c,t);
  }

  #define MADD(CJ,A,B,C,T)  gebp_madd(CJ,A,B,C,T);
//   #define MADD(CJ,A,B,C,T)  T = B; T = CJ.pmul(A,T); C = padd(C,T);
#endif

/* Vectorization logic
 *  real*real: unpack rhs to constant packets, ...
 *
 *  cd*cd : unpack rhs to (b_r,b_r), (b_i,b_i), mul to get (a_r b_r,a_i b_r) (a_r b_i,a_i b_i),
 *          storing each res packet into two packets (2x2),
 *          at the end combine them: swap the second and addsub them
 *  cf*cf : same but with 2x4 blocks
 *  cplx*real : unpack rhs to constant packets, ...
 *  real*cplx : load lhs as (a0,a0,a1,a1), and mul as usual
 */
template<typename _LhsScalar, typename _RhsScalar, bool _ConjLhs, bool _ConjRhs>
class gebp_traits
{
public:
  typedef _LhsScalar LhsScalar;
  typedef _RhsScalar RhsScalar;
  typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;

  enum {
    ConjLhs = _ConjLhs,
    ConjRhs = _ConjRhs,
    Vectorizable = packet_traits<LhsScalar>::Vectorizable && packet_traits<RhsScalar>::Vectorizable,
    LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1,
    RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1,
    ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1,

    NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,

    // register block size along the N direction (must be either 2 or 4)
    nr = NumberOfRegisters/4,

    // register block size along the M direction (currently, this one cannot be modified)
    mr = 2 * LhsPacketSize,

    WorkSpaceFactor = nr * RhsPacketSize,

    LhsProgress = LhsPacketSize,
    RhsProgress = RhsPacketSize
  };

  typedef typename packet_traits<LhsScalar>::type  _LhsPacket;
  typedef typename packet_traits<RhsScalar>::type  _RhsPacket;
  typedef typename packet_traits<ResScalar>::type  _ResPacket;

  typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket;
  typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket;
  typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket;

  typedef ResPacket AccPacket;

  EIGEN_STRONG_INLINE void initAcc(AccPacket& p)
  {
    p = pset1<ResPacket>(ResScalar(0));
  }

  EIGEN_STRONG_INLINE void unpackRhs(DenseIndex n, const RhsScalar* rhs, RhsScalar* b)
  {
    for(DenseIndex k=0; k<n; k++)
      pstore1<RhsPacket>(&b[k*RhsPacketSize], rhs[k]);
  }

  EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const
  {
    dest = pload<RhsPacket>(b);
  }

  EIGEN_STRONG_INLINE void loadLhs(const LhsScalar* a, LhsPacket& dest) const
  {
    dest = pload<LhsPacket>(a);
  }

  EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, AccPacket& tmp) const
  {
    tmp = b; tmp = pmul(a,tmp); c = padd(c,tmp);
  }

  EIGEN_STRONG_INLINE void acc(const AccPacket& c, const ResPacket& alpha, ResPacket& r) const
  {
    r = pmadd(c,alpha,r);
  }

protected:
//   conj_helper<LhsScalar,RhsScalar,ConjLhs,ConjRhs> cj;
//   conj_helper<LhsPacket,RhsPacket,ConjLhs,ConjRhs> pcj;
};

template<typename RealScalar, bool _ConjLhs>
class gebp_traits<std::complex<RealScalar>, RealScalar, _ConjLhs, false>
{
public:
  typedef std::complex<RealScalar> LhsScalar;
  typedef RealScalar RhsScalar;
  typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;

  enum {
    ConjLhs = _ConjLhs,
    ConjRhs = false,
    Vectorizable = packet_traits<LhsScalar>::Vectorizable && packet_traits<RhsScalar>::Vectorizable,
    LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1,
    RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1,
    ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1,

    NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
    nr = NumberOfRegisters/4,
    mr = 2 * LhsPacketSize,
    WorkSpaceFactor = nr*RhsPacketSize,

    LhsProgress = LhsPacketSize,
    RhsProgress = RhsPacketSize
  };

  typedef typename packet_traits<LhsScalar>::type  _LhsPacket;
  typedef typename packet_traits<RhsScalar>::type  _RhsPacket;
  typedef typename packet_traits<ResScalar>::type  _ResPacket;

  typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket;
  typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket;
  typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket;

  typedef ResPacket AccPacket;

  EIGEN_STRONG_INLINE void initAcc(AccPacket& p)
  {
    p = pset1<ResPacket>(ResScalar(0));
  }

  EIGEN_STRONG_INLINE void unpackRhs(DenseIndex n, const RhsScalar* rhs, RhsScalar* b)
  {
    for(DenseIndex k=0; k<n; k++)
      pstore1<RhsPacket>(&b[k*RhsPacketSize], rhs[k]);
  }

  EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const
  {
    dest = pload<RhsPacket>(b);
  }

  EIGEN_STRONG_INLINE void loadLhs(const LhsScalar* a, LhsPacket& dest) const
  {
    dest = pload<LhsPacket>(a);
  }

  EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& tmp) const
  {
    madd_impl(a, b, c, tmp, typename conditional<Vectorizable,true_type,false_type>::type());
  }

  EIGEN_STRONG_INLINE void madd_impl(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& tmp, const true_type&) const
  {
    tmp = b; tmp = pmul(a.v,tmp); c.v = padd(c.v,tmp);
  }

  EIGEN_STRONG_INLINE void madd_impl(const LhsScalar& a, const RhsScalar& b, ResScalar& c, RhsScalar& /*tmp*/, const false_type&) const
  {
    c += a * b;
  }

  EIGEN_STRONG_INLINE void acc(const AccPacket& c, const ResPacket& alpha, ResPacket& r) const
  {
    r = cj.pmadd(c,alpha,r);
  }

protected:
  conj_helper<ResPacket,ResPacket,ConjLhs,false> cj;
};

template<typename RealScalar, bool _ConjLhs, bool _ConjRhs>
class gebp_traits<std::complex<RealScalar>, std::complex<RealScalar>, _ConjLhs, _ConjRhs >
{
public:
  typedef std::complex<RealScalar>  Scalar;
  typedef std::complex<RealScalar>  LhsScalar;
  typedef std::complex<RealScalar>  RhsScalar;
  typedef std::complex<RealScalar>  ResScalar;

  enum {
    ConjLhs = _ConjLhs,
    ConjRhs = _ConjRhs,
    Vectorizable = packet_traits<RealScalar>::Vectorizable
                && packet_traits<Scalar>::Vectorizable,
    RealPacketSize  = Vectorizable ? packet_traits<RealScalar>::size : 1,
    ResPacketSize   = Vectorizable ? packet_traits<ResScalar>::size : 1,

    nr = 2,
    mr = 2 * ResPacketSize,
    WorkSpaceFactor = Vectorizable ? 2*nr*RealPacketSize : nr,

    LhsProgress = ResPacketSize,
    RhsProgress = Vectorizable ? 2*ResPacketSize : 1
  };

  typedef typename packet_traits<RealScalar>::type RealPacket;
  typedef typename packet_traits<Scalar>::type     ScalarPacket;
  struct DoublePacket
  {
    RealPacket first;
    RealPacket second;
  };

  typedef typename conditional<Vectorizable,RealPacket,  Scalar>::type LhsPacket;
  typedef typename conditional<Vectorizable,DoublePacket,Scalar>::type RhsPacket;
  typedef typename conditional<Vectorizable,ScalarPacket,Scalar>::type ResPacket;
  typedef typename conditional<Vectorizable,DoublePacket,Scalar>::type AccPacket;

  EIGEN_STRONG_INLINE void initAcc(Scalar& p) { p = Scalar(0); }

  EIGEN_STRONG_INLINE void initAcc(DoublePacket& p)
  {
    p.first  = pset1<RealPacket>(RealScalar(0));
    p.second = pset1<RealPacket>(RealScalar(0));
  }

  /* Unpack the rhs coefficients such that each complex coefficient is spread over
   * two packets containing respectively the real and imaginary parts,
   * duplicated as many times as needed: (x+iy) => [x, ..., x] [y, ..., y]
   */
  EIGEN_STRONG_INLINE void unpackRhs(DenseIndex n, const Scalar* rhs, Scalar* b)
  {
    for(DenseIndex k=0; k<n; k++)
    {
      if(Vectorizable)
      {
        pstore1<RealPacket>((RealScalar*)&b[k*ResPacketSize*2+0],             real(rhs[k]));
        pstore1<RealPacket>((RealScalar*)&b[k*ResPacketSize*2+ResPacketSize], imag(rhs[k]));
      }
      else
        b[k] = rhs[k];
    }
  }

  EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, ResPacket& dest) const { dest = *b; }

  EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, DoublePacket& dest) const
  {
    dest.first  = pload<RealPacket>((const RealScalar*)b);
    dest.second = pload<RealPacket>((const RealScalar*)(b+ResPacketSize));
  }

  // nothing special here
  EIGEN_STRONG_INLINE void loadLhs(const LhsScalar* a, LhsPacket& dest) const
  {
    dest = pload<LhsPacket>((const typename unpacket_traits<LhsPacket>::type*)(a));
  }

  EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, DoublePacket& c, RhsPacket& /*tmp*/) const
  {
    c.first  = padd(pmul(a,b.first), c.first);
    c.second = padd(pmul(a,b.second),c.second);
  }

  EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, ResPacket& c, RhsPacket& /*tmp*/) const
  {
    c = cj.pmadd(a,b,c);
  }

  EIGEN_STRONG_INLINE void acc(const Scalar& c, const Scalar& alpha, Scalar& r) const { r += alpha * c; }

  EIGEN_STRONG_INLINE void acc(const DoublePacket& c, const ResPacket& alpha, ResPacket& r) const
  {
    // assemble c
    ResPacket tmp;
    if((!ConjLhs)&&(!ConjRhs))
    {
      tmp = pcplxflip(pconj(ResPacket(c.second)));
      tmp = padd(ResPacket(c.first),tmp);
    }
    else if((!ConjLhs)&&(ConjRhs))
    {
      tmp = pconj(pcplxflip(ResPacket(c.second)));
      tmp = padd(ResPacket(c.first),tmp);
    }
    else if((ConjLhs)&&(!ConjRhs))
    {
      tmp = pcplxflip(ResPacket(c.second));
      tmp = padd(pconj(ResPacket(c.first)),tmp);
    }
    else if((ConjLhs)&&(ConjRhs))
    {
      tmp = pcplxflip(ResPacket(c.second));
      tmp = psub(pconj(ResPacket(c.first)),tmp);
    }

    r = pmadd(tmp,alpha,r);
  }

protected:
  conj_helper<LhsScalar,RhsScalar,ConjLhs,ConjRhs> cj;
};

template<typename RealScalar, bool _ConjRhs>
class gebp_traits<RealScalar, std::complex<RealScalar>, false, _ConjRhs >
{
public:
  typedef std::complex<RealScalar>  Scalar;
  typedef RealScalar  LhsScalar;
  typedef Scalar      RhsScalar;
  typedef Scalar      ResScalar;

  enum {
    ConjLhs = false,
    ConjRhs = _ConjRhs,
    Vectorizable = packet_traits<RealScalar>::Vectorizable
                && packet_traits<Scalar>::Vectorizable,
    LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1,
    RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1,
    ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1,

    NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
    nr = 4,
    mr = 2*ResPacketSize,
    WorkSpaceFactor = nr*RhsPacketSize,

    LhsProgress = ResPacketSize,
    RhsProgress = ResPacketSize
  };

  typedef typename packet_traits<LhsScalar>::type  _LhsPacket;
  typedef typename packet_traits<RhsScalar>::type  _RhsPacket;
  typedef typename packet_traits<ResScalar>::type  _ResPacket;

  typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket;
  typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket;
  typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket;

  typedef ResPacket AccPacket;

  EIGEN_STRONG_INLINE void initAcc(AccPacket& p)
  {
    p = pset1<ResPacket>(ResScalar(0));
  }

  EIGEN_STRONG_INLINE void unpackRhs(DenseIndex n, const RhsScalar* rhs, RhsScalar* b)
  {
    for(DenseIndex k=0; k<n; k++)
      pstore1<RhsPacket>(&b[k*RhsPacketSize], rhs[k]);
  }

  EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const
  {
    dest = pload<RhsPacket>(b);
  }

  EIGEN_STRONG_INLINE void loadLhs(const LhsScalar* a, LhsPacket& dest) const
  {
    dest = ploaddup<LhsPacket>(a);
  }

  EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& tmp) const
  {
    madd_impl(a, b, c, tmp, typename conditional<Vectorizable,true_type,false_type>::type());
  }

  EIGEN_STRONG_INLINE void madd_impl(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& tmp, const true_type&) const
  {
    tmp = b; tmp.v = pmul(a,tmp.v); c = padd(c,tmp);
  }

  EIGEN_STRONG_INLINE void madd_impl(const LhsScalar& a, const RhsScalar& b, ResScalar& c, RhsScalar& /*tmp*/, const false_type&) const
  {
    c += a * b;
  }

  EIGEN_STRONG_INLINE void acc(const AccPacket& c, const ResPacket& alpha, ResPacket& r) const
  {
    r = cj.pmadd(alpha,c,r);
  }

protected:
  conj_helper<ResPacket,ResPacket,false,ConjRhs> cj;
};

/* optimized GEneral packed Block * packed Panel product kernel
 *
 * Mixing type logic: C += A * B
 *  |  A  |  B  | comments
 *  |real |cplx | no vectorization yet, would require packing A with duplication
 *  |cplx |real | easy vectorization
 */
template<typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel
{
  typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> Traits;
  typedef typename Traits::ResScalar ResScalar;
  typedef typename Traits::LhsPacket LhsPacket;
  typedef typename Traits::RhsPacket RhsPacket;
  typedef typename Traits::ResPacket ResPacket;
  typedef typename Traits::AccPacket AccPacket;

  enum {
    Vectorizable  = Traits::Vectorizable,
    LhsProgress   = Traits::LhsProgress,
    RhsProgress   = Traits::RhsProgress,
    ResPacketSize = Traits::ResPacketSize
  };

  EIGEN_DONT_INLINE
  void operator()(ResScalar* res, Index resStride, const LhsScalar* blockA, const RhsScalar* blockB,
                  Index rows, Index depth, Index cols, ResScalar alpha,
                  Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0, RhsScalar* unpackedB=0);
};

template<typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
EIGEN_DONT_INLINE
void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
  ::operator()(ResScalar* res, Index resStride, const LhsScalar* blockA, const RhsScalar* blockB,
               Index rows, Index depth, Index cols, ResScalar alpha,
               Index strideA, Index strideB, Index offsetA, Index offsetB, RhsScalar* unpackedB)
{
  Traits traits;

  if(strideA==-1) strideA = depth;
  if(strideB==-1) strideB = depth;
  conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj;
//   conj_helper<LhsPacket,RhsPacket,ConjugateLhs,ConjugateRhs> pcj;
  Index packet_cols = (cols/nr) * nr;
  const Index peeled_mc = (rows/mr)*mr;
  // FIXME:
  const Index peeled_mc2 = peeled_mc + (rows-peeled_mc >= LhsProgress ? LhsProgress : 0);
  const Index peeled_kc = (depth/4)*4;

  if(unpackedB==0)
    unpackedB = const_cast<RhsScalar*>(blockB - strideB * nr * RhsProgress);

  // loops on each micro vertical panel of rhs (depth x nr)
  for(Index j2=0; j2<packet_cols; j2+=nr)
  {
    traits.unpackRhs(depth*nr,&blockB[j2*strideB+offsetB*nr],unpackedB);

    // loops on each largest micro horizontal panel of lhs (mr x depth)
    // => we select a mr x nr micro block of res which is entirely
    //    stored into mr/packet_size x nr registers.
    for(Index i=0; i<peeled_mc; i+=mr)
    {
      const LhsScalar* blA = &blockA[i*strideA+offsetA*mr];
      prefetch(&blA[0]);

      // gets res block as register
      AccPacket C0, C1, C2, C3, C4, C5, C6, C7;
      traits.initAcc(C0);
      traits.initAcc(C1);
      if(nr==4) traits.initAcc(C2);
      if(nr==4) traits.initAcc(C3);
      traits.initAcc(C4);
      traits.initAcc(C5);
      if(nr==4) traits.initAcc(C6);
      if(nr==4) traits.initAcc(C7);

      ResScalar* r0 = &res[(j2+0)*resStride + i];
      ResScalar* r1 = r0 + resStride;
      ResScalar* r2 = r1 + resStride;
      ResScalar* r3 = r2 + resStride;

      prefetch(r0+16);
      prefetch(r1+16);
      prefetch(r2+16);
      prefetch(r3+16);

      // performs "inner" product
      // TODO let's check whether the following peeled loop could not be
      //      optimized via optimal prefetching from one loop to the other
      const RhsScalar* blB = unpackedB;
      for(Index k=0; k<peeled_kc; k+=4)
      {
        if(nr==2)
        {
          LhsPacket A0, A1;
          RhsPacket B_0;
          RhsPacket T0;

          EIGEN_ASM_COMMENT("mybegin2");
          traits.loadLhs(&blA[0*LhsProgress], A0);
          traits.loadLhs(&blA[1*LhsProgress], A1);
          traits.loadRhs(&blB[0*RhsProgress], B_0);
          traits.madd(A0,B_0,C0,T0);
          traits.madd(A1,B_0,C4,B_0);
          traits.loadRhs(&blB[1*RhsProgress], B_0);
          traits.madd(A0,B_0,C1,T0);
          traits.madd(A1,B_0,C5,B_0);

          traits.loadLhs(&blA[2*LhsProgress], A0);
          traits.loadLhs(&blA[3*LhsProgress], A1);
          traits.loadRhs(&blB[2*RhsProgress], B_0);
          traits.madd(A0,B_0,C0,T0);
          traits.madd(A1,B_0,C4,B_0);
          traits.loadRhs(&blB[3*RhsProgress], B_0);
          traits.madd(A0,B_0,C1,T0);
          traits.madd(A1,B_0,C5,B_0);

          traits.loadLhs(&blA[4*LhsProgress], A0);
          traits.loadLhs(&blA[5*LhsProgress], A1);
          traits.loadRhs(&blB[4*RhsProgress], B_0);
          traits.madd(A0,B_0,C0,T0);
          traits.madd(A1,B_0,C4,B_0);
          traits.loadRhs(&blB[5*RhsProgress], B_0);
          traits.madd(A0,B_0,C1,T0);
          traits.madd(A1,B_0,C5,B_0);

          traits.loadLhs(&blA[6*LhsProgress], A0);
          traits.loadLhs(&blA[7*LhsProgress], A1);
          traits.loadRhs(&blB[6*RhsProgress], B_0);
          traits.madd(A0,B_0,C0,T0);
          traits.madd(A1,B_0,C4,B_0);
          traits.loadRhs(&blB[7*RhsProgress], B_0);
          traits.madd(A0,B_0,C1,T0);
          traits.madd(A1,B_0,C5,B_0);
          EIGEN_ASM_COMMENT("myend");
        }
        else
        {
          EIGEN_ASM_COMMENT("mybegin4");
          LhsPacket A0, A1;
          RhsPacket B_0, B1, B2, B3;
          RhsPacket T0;

          traits.loadLhs(&blA[0*LhsProgress], A0);
          traits.loadLhs(&blA[1*LhsProgress], A1);
          traits.loadRhs(&blB[0*RhsProgress], B_0);
          traits.loadRhs(&blB[1*RhsProgress], B1);

          traits.madd(A0,B_0,C0,T0);
          traits.loadRhs(&blB[2*RhsProgress], B2);
          traits.madd(A1,B_0,C4,B_0);
          traits.loadRhs(&blB[3*RhsProgress], B3);
          traits.loadRhs(&blB[4*RhsProgress], B_0);
          traits.madd(A0,B1,C1,T0);
          traits.madd(A1,B1,C5,B1);
          traits.loadRhs(&blB[5*RhsProgress], B1);
          traits.madd(A0,B2,C2,T0);
          traits.madd(A1,B2,C6,B2);
          traits.loadRhs(&blB[6*RhsProgress], B2);
          traits.madd(A0,B3,C3,T0);
          traits.loadLhs(&blA[2*LhsProgress], A0);
          traits.madd(A1,B3,C7,B3);
          traits.loadLhs(&blA[3*LhsProgress], A1);
          traits.loadRhs(&blB[7*RhsProgress], B3);
          traits.madd(A0,B_0,C0,T0);
          traits.madd(A1,B_0,C4,B_0);
          traits.loadRhs(&blB[8*RhsProgress], B_0);
          traits.madd(A0,B1,C1,T0);
          traits.madd(A1,B1,C5,B1);
          traits.loadRhs(&blB[9*RhsProgress], B1);
          traits.madd(A0,B2,C2,T0);
          traits.madd(A1,B2,C6,B2);
          traits.loadRhs(&blB[10*RhsProgress], B2);
          traits.madd(A0,B3,C3,T0);
          traits.loadLhs(&blA[4*LhsProgress], A0);
          traits.madd(A1,B3,C7,B3);
          traits.loadLhs(&blA[5*LhsProgress], A1);
          traits.loadRhs(&blB[11*RhsProgress], B3);

          traits.madd(A0,B_0,C0,T0);
          traits.madd(A1,B_0,C4,B_0);
          traits.loadRhs(&blB[12*RhsProgress], B_0);
          traits.madd(A0,B1,C1,T0);
          traits.madd(A1,B1,C5,B1);
          traits.loadRhs(&blB[13*RhsProgress], B1);
          traits.madd(A0,B2,C2,T0);
          traits.madd(A1,B2,C6,B2);
          traits.loadRhs(&blB[14*RhsProgress], B2);
          traits.madd(A0,B3,C3,T0);
          traits.loadLhs(&blA[6*LhsProgress], A0);
          traits.madd(A1,B3,C7,B3);
          traits.loadLhs(&blA[7*LhsProgress], A1);
          traits.loadRhs(&blB[15*RhsProgress], B3);
          traits.madd(A0,B_0,C0,T0);
          traits.madd(A1,B_0,C4,B_0);
          traits.madd(A0,B1,C1,T0);
          traits.madd(A1,B1,C5,B1);
          traits.madd(A0,B2,C2,T0);
          traits.madd(A1,B2,C6,B2);
          traits.madd(A0,B3,C3,T0);
          traits.madd(A1,B3,C7,B3);
        }

        blB += 4*nr*RhsProgress;
        blA += 4*mr;
      }
      // process the remaining, non-peeled k iterations
      for(Index k=peeled_kc; k<depth; k++)
      {
        if(nr==2)
        {
          LhsPacket A0, A1;
          RhsPacket B_0;
          RhsPacket T0;

          traits.loadLhs(&blA[0*LhsProgress], A0);
          traits.loadLhs(&blA[1*LhsProgress], A1);
          traits.loadRhs(&blB[0*RhsProgress], B_0);
          traits.madd(A0,B_0,C0,T0);
          traits.madd(A1,B_0,C4,B_0);
          traits.loadRhs(&blB[1*RhsProgress], B_0);
          traits.madd(A0,B_0,C1,T0);
          traits.madd(A1,B_0,C5,B_0);
        }
        else
        {
          LhsPacket A0, A1;
          RhsPacket B_0, B1, B2, B3;
          RhsPacket T0;

          traits.loadLhs(&blA[0*LhsProgress], A0);
          traits.loadLhs(&blA[1*LhsProgress], A1);
          traits.loadRhs(&blB[0*RhsProgress], B_0);
          traits.loadRhs(&blB[1*RhsProgress], B1);

          traits.madd(A0,B_0,C0,T0);
          traits.loadRhs(&blB[2*RhsProgress], B2);
          traits.madd(A1,B_0,C4,B_0);
          traits.loadRhs(&blB[3*RhsProgress], B3);
          traits.madd(A0,B1,C1,T0);
          traits.madd(A1,B1,C5,B1);
          traits.madd(A0,B2,C2,T0);
          traits.madd(A1,B2,C6,B2);
          traits.madd(A0,B3,C3,T0);
          traits.madd(A1,B3,C7,B3);
        }

        blB += nr*RhsProgress;
        blA += mr;
      }

      if(nr==4)
      {
        ResPacket R0, R1, R2, R3, R4, R5, R6;
        ResPacket alphav = pset1<ResPacket>(alpha);

        R0 = ploadu<ResPacket>(r0);
        R1 = ploadu<ResPacket>(r1);
        R2 = ploadu<ResPacket>(r2);
        R3 = ploadu<ResPacket>(r3);
        R4 = ploadu<ResPacket>(r0 + ResPacketSize);
        R5 = ploadu<ResPacket>(r1 + ResPacketSize);
        R6 = ploadu<ResPacket>(r2 + ResPacketSize);
        traits.acc(C0, alphav, R0);
        pstoreu(r0, R0);
        R0 = ploadu<ResPacket>(r3 + ResPacketSize);

        traits.acc(C1, alphav, R1);
        traits.acc(C2, alphav, R2);
        traits.acc(C3, alphav, R3);
        traits.acc(C4, alphav, R4);
        traits.acc(C5, alphav, R5);
        traits.acc(C6, alphav, R6);
        traits.acc(C7, alphav, R0);

        pstoreu(r1, R1);
        pstoreu(r2, R2);
        pstoreu(r3, R3);
        pstoreu(r0 + ResPacketSize, R4);
        pstoreu(r1 + ResPacketSize, R5);
        pstoreu(r2 + ResPacketSize, R6);
        pstoreu(r3 + ResPacketSize, R0);
      }
      else
      {
        ResPacket R0, R1, R4;
        ResPacket alphav = pset1<ResPacket>(alpha);

        R0 = ploadu<ResPacket>(r0);
        R1 = ploadu<ResPacket>(r1);
        R4 = ploadu<ResPacket>(r0 + ResPacketSize);
        traits.acc(C0, alphav, R0);
        pstoreu(r0, R0);
        R0 = ploadu<ResPacket>(r1 + ResPacketSize);
        traits.acc(C1, alphav, R1);
        traits.acc(C4, alphav, R4);
        traits.acc(C5, alphav, R0);
        pstoreu(r1, R1);
        pstoreu(r0 + ResPacketSize, R4);
        pstoreu(r1 + ResPacketSize, R0);
      }

    }

    if(rows-peeled_mc>=LhsProgress)
    {
      Index i = peeled_mc;
      const LhsScalar* blA = &blockA[i*strideA+offsetA*LhsProgress];
      prefetch(&blA[0]);

      // gets res block as register
      AccPacket C0, C1, C2, C3;
      traits.initAcc(C0);
      traits.initAcc(C1);
      if(nr==4) traits.initAcc(C2);
      if(nr==4) traits.initAcc(C3);

      // performs "inner" product
      const RhsScalar* blB = unpackedB;
      for(Index k=0; k<peeled_kc; k+=4)
      {
        if(nr==2)
        {
          LhsPacket A0;
          RhsPacket B_0, B1;

          traits.loadLhs(&blA[0*LhsProgress], A0);
          traits.loadRhs(&blB[0*RhsProgress], B_0);
          traits.loadRhs(&blB[1*RhsProgress], B1);
          traits.madd(A0,B_0,C0,B_0);
          traits.loadRhs(&blB[2*RhsProgress], B_0);
          traits.madd(A0,B1,C1,B1);
          traits.loadLhs(&blA[1*LhsProgress], A0);
          traits.loadRhs(&blB[3*RhsProgress], B1);
          traits.madd(A0,B_0,C0,B_0);
          traits.loadRhs(&blB[4*RhsProgress], B_0);
          traits.madd(A0,B1,C1,B1);
          traits.loadLhs(&blA[2*LhsProgress], A0);
          traits.loadRhs(&blB[5*RhsProgress], B1);
          traits.madd(A0,B_0,C0,B_0);
          traits.loadRhs(&blB[6*RhsProgress], B_0);
          traits.madd(A0,B1,C1,B1);
          traits.loadLhs(&blA[3*LhsProgress], A0);
          traits.loadRhs(&blB[7*RhsProgress], B1);
          traits.madd(A0,B_0,C0,B_0);
          traits.madd(A0,B1,C1,B1);
        }
        else
        {
          LhsPacket A0;
          RhsPacket B_0, B1, B2, B3;

          traits.loadLhs(&blA[0*LhsProgress], A0);
          traits.loadRhs(&blB[0*RhsProgress], B_0);
          traits.loadRhs(&blB[1*RhsProgress], B1);

          traits.madd(A0,B_0,C0,B_0);
          traits.loadRhs(&blB[2*RhsProgress], B2);
          traits.loadRhs(&blB[3*RhsProgress], B3);
          traits.loadRhs(&blB[4*RhsProgress], B_0);
          traits.madd(A0,B1,C1,B1);
          traits.loadRhs(&blB[5*RhsProgress], B1);
          traits.madd(A0,B2,C2,B2);
          traits.loadRhs(&blB[6*RhsProgress], B2);
          traits.madd(A0,B3,C3,B3);
          traits.loadLhs(&blA[1*LhsProgress], A0);
          traits.loadRhs(&blB[7*RhsProgress], B3);
          traits.madd(A0,B_0,C0,B_0);
          traits.loadRhs(&blB[8*RhsProgress], B_0);
          traits.madd(A0,B1,C1,B1);
          traits.loadRhs(&blB[9*RhsProgress], B1);
          traits.madd(A0,B2,C2,B2);
          traits.loadRhs(&blB[10*RhsProgress], B2);
          traits.madd(A0,B3,C3,B3);
          traits.loadLhs(&blA[2*LhsProgress], A0);
          traits.loadRhs(&blB[11*RhsProgress], B3);

          traits.madd(A0,B_0,C0,B_0);
          traits.loadRhs(&blB[12*RhsProgress], B_0);
          traits.madd(A0,B1,C1,B1);
          traits.loadRhs(&blB[13*RhsProgress], B1);
          traits.madd(A0,B2,C2,B2);
          traits.loadRhs(&blB[14*RhsProgress], B2);
          traits.madd(A0,B3,C3,B3);

          traits.loadLhs(&blA[3*LhsProgress], A0);
          traits.loadRhs(&blB[15*RhsProgress], B3);
          traits.madd(A0,B_0,C0,B_0);
          traits.madd(A0,B1,C1,B1);
          traits.madd(A0,B2,C2,B2);
          traits.madd(A0,B3,C3,B3);
        }

        blB += nr*4*RhsProgress;
        blA += 4*LhsProgress;
      }
      // process the remaining, non-peeled k iterations
      for(Index k=peeled_kc; k<depth; k++)
      {
        if(nr==2)
        {
          LhsPacket A0;
          RhsPacket B_0, B1;

          traits.loadLhs(&blA[0*LhsProgress], A0);
          traits.loadRhs(&blB[0*RhsProgress], B_0);
          traits.loadRhs(&blB[1*RhsProgress], B1);
          traits.madd(A0,B_0,C0,B_0);
          traits.madd(A0,B1,C1,B1);
        }
        else
        {
          LhsPacket A0;
          RhsPacket B_0, B1, B2, B3;

          traits.loadLhs(&blA[0*LhsProgress], A0);
          traits.loadRhs(&blB[0*RhsProgress], B_0);
          traits.loadRhs(&blB[1*RhsProgress], B1);
          traits.loadRhs(&blB[2*RhsProgress], B2);
          traits.loadRhs(&blB[3*RhsProgress], B3);

          traits.madd(A0,B_0,C0,B_0);
          traits.madd(A0,B1,C1,B1);
          traits.madd(A0,B2,C2,B2);
          traits.madd(A0,B3,C3,B3);
        }

        blB += nr*RhsProgress;
        blA += LhsProgress;
      }

      ResPacket R0, R1, R2, R3;
      ResPacket alphav = pset1<ResPacket>(alpha);

      ResScalar* r0 = &res[(j2+0)*resStride + i];
      ResScalar* r1 = r0 + resStride;
      ResScalar* r2 = r1 + resStride;
      ResScalar* r3 = r2 + resStride;

      R0 = ploadu<ResPacket>(r0);
      R1 = ploadu<ResPacket>(r1);
      if(nr==4) R2 = ploadu<ResPacket>(r2);
      if(nr==4) R3 = ploadu<ResPacket>(r3);

      traits.acc(C0, alphav, R0);
      traits.acc(C1, alphav, R1);
      if(nr==4) traits.acc(C2, alphav, R2);
      if(nr==4) traits.acc(C3, alphav, R3);

      pstoreu(r0, R0);
      pstoreu(r1, R1);
      if(nr==4) pstoreu(r2, R2);
      if(nr==4) pstoreu(r3, R3);
    }
    for(Index i=peeled_mc2; i<rows; i++)
    {
      const LhsScalar* blA = &blockA[i*strideA+offsetA];
      prefetch(&blA[0]);

      // gets a 1 x nr res block as registers
      ResScalar C0(0), C1(0), C2(0), C3(0);
      // TODO directly use blockB ???
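      // Note: this tail path handles the leftover rows not covered by the mr- and
      // LhsProgress-sized panels above; it runs in plain scalar mode, one row at a
      // time, using the MADD macro defined earlier in this file.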
      const RhsScalar* blB = &blockB[j2*strideB+offsetB*nr];
      for(Index k=0; k<depth; k++)
      {
        if(nr==2)
        {
          LhsScalar A0;
          RhsScalar B_0, B1;

          A0 = blA[k];
          B_0 = blB[0];
          B1 = blB[1];
          MADD(cj,A0,B_0,C0,B_0);
          MADD(cj,A0,B1,C1,B1);
        }
        else
        {
          LhsScalar A0;
          RhsScalar B_0, B1, B2, B3;

          A0 = blA[k];
          B_0 = blB[0];
          B1 = blB[1];
          B2 = blB[2];
          B3 = blB[3];

          MADD(cj,A0,B_0,C0,B_0);
          MADD(cj,A0,B1,C1,B1);
          MADD(cj,A0,B2,C2,B2);
          MADD(cj,A0,B3,C3,B3);
        }

        blB += nr;
      }
      res[(j2+0)*resStride + i] += alpha*C0;
      res[(j2+1)*resStride + i] += alpha*C1;
      if(nr==4) res[(j2+2)*resStride + i] += alpha*C2;
      if(nr==4) res[(j2+3)*resStride + i] += alpha*C3;
    }
  }
  // process remaining rhs/res columns one at a time
  // => do the same but with nr==1
  for(Index j2=packet_cols; j2<cols; j2++)
  {
    // unpack B
    traits.unpackRhs(depth, &blockB[j2*strideB+offsetB], unpackedB);

    for(Index i=0; i<peeled_mc; i+=mr)
    {
      const LhsScalar* blA = &blockA[i*strideA+offsetA*mr];
      prefetch(&blA[0]);

      // TODO move the res loads to the stores

      // get res block as registers
      AccPacket C0, C4;
      traits.initAcc(C0);
      traits.initAcc(C4);

      const RhsScalar* blB = unpackedB;
      for(Index k=0; k<depth; k++)
      {
        LhsPacket A0, A1;
        RhsPacket B_0;
        RhsPacket T0;

        traits.loadLhs(&blA[0*LhsProgress], A0);
        traits.loadLhs(&blA[1*LhsProgress], A1);
        traits.loadRhs(&blB[0*RhsProgress], B_0);
        traits.madd(A0,B_0,C0,T0);
        traits.madd(A1,B_0,C4,B_0);

        blB += RhsProgress;
        blA += 2*LhsProgress;
      }
      ResPacket R0, R4;
      ResPacket alphav = pset1<ResPacket>(alpha);

      ResScalar* r0 = &res[(j2+0)*resStride + i];

      R0 = ploadu<ResPacket>(r0);
      R4 = ploadu<ResPacket>(r0+ResPacketSize);

      traits.acc(C0, alphav, R0);
      traits.acc(C4, alphav, R4);

      pstoreu(r0,               R0);
      pstoreu(r0+ResPacketSize, R4);
    }
    if(rows-peeled_mc>=LhsProgress)
    {
      Index i = peeled_mc;
      const LhsScalar* blA = &blockA[i*strideA+offsetA*LhsProgress];
      prefetch(&blA[0]);

      AccPacket C0;
      traits.initAcc(C0);

      const RhsScalar* blB = unpackedB;
      for(Index k=0; k<depth; k++)
      {
        LhsPacket A0;
        RhsPacket B_0;
        traits.loadLhs(blA, A0);
        traits.loadRhs(blB, B_0);
        traits.madd(A0, B_0, C0, B_0);
        blB += RhsProgress;
        blA += LhsProgress;
      }

      ResPacket alphav = pset1<ResPacket>(alpha);
      ResPacket R0 = ploadu<ResPacket>(&res[(j2+0)*resStride + i]);
      traits.acc(C0, alphav, R0);
      pstoreu(&res[(j2+0)*resStride + i], R0);
    }
    for(Index i=peeled_mc2; i<rows; i++)
    {
      const LhsScalar* blA = &blockA[i*strideA+offsetA];
      prefetch(&blA[0]);

      // gets a 1 x 1 res block as registers
      ResScalar C0(0);
      // FIXME directly use blockB ??
      const RhsScalar* blB = &blockB[j2*strideB+offsetB];
      for(Index k=0; k<depth; k++)
      {
        LhsScalar A0 = blA[k];
        RhsScalar B_0 = blB[k];
        MADD(cj, A0, B_0, C0, B_0);
      }
      res[(j2+0)*resStride + i] += alpha*C0;
    }
  }
}


#undef MADD

// pack a block of the lhs
// The traversal is as follows (mr==4):
//   0  4  8 12 ...
//   1  5  9 13 ...
//   2  6 10 14 ...
//   3  7 11 15 ...
//
//  16 20 24 28 ...
//  17 21 25 29 ...
//  18 22 26 30 ...
//  19 23 27 31 ...
//
//  32 33 34 35 ...
//  36 37 38 39 ...
template<typename Scalar, typename Index, int Pack1, int Pack2, int StorageOrder, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs
{
  EIGEN_DONT_INLINE void operator()(Scalar* blockA, const Scalar* EIGEN_RESTRICT _lhs, Index lhsStride, Index depth, Index rows, Index stride=0, Index offset=0);
};

template<typename Scalar, typename Index, int Pack1, int Pack2, int StorageOrder, bool Conjugate, bool PanelMode>
EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, StorageOrder, Conjugate, PanelMode>
  ::operator()(Scalar* blockA, const Scalar* EIGEN_RESTRICT _lhs, Index lhsStride, Index depth, Index rows, Index stride, Index offset)
{
  typedef typename packet_traits<Scalar>::type Packet;
  enum { PacketSize = packet_traits<Scalar>::size };

  EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK LHS");
  EIGEN_UNUSED_VARIABLE(stride)
  EIGEN_UNUSED_VARIABLE(offset)
  eigen_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride));
  eigen_assert( (StorageOrder==RowMajor) || ((Pack1%PacketSize)==0 && Pack1<=4*PacketSize) );
  conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
  const_blas_data_mapper<Scalar, Index, StorageOrder> lhs(_lhs,lhsStride);
  Index count = 0;
  Index peeled_mc = (rows/Pack1)*Pack1;
  for(Index i=0; i<peeled_mc; i+=Pack1)
  {
    if(PanelMode) count += Pack1 * offset;

    if(StorageOrder==ColMajor)
    {
      for(Index k=0; k<depth; k++)
      {
        Packet A, B, C, D;
        if(Pack1>=1*PacketSize) A = ploadu<Packet>(&lhs(i+0*PacketSize, k));
        if(Pack1>=2*PacketSize) B = ploadu<Packet>(&lhs(i+1*PacketSize, k));
        if(Pack1>=3*PacketSize) C = ploadu<Packet>(&lhs(i+2*PacketSize, k));
        if(Pack1>=4*PacketSize) D = ploadu<Packet>(&lhs(i+3*PacketSize, k));
        if(Pack1>=1*PacketSize) { pstore(blockA+count, cj.pconj(A)); count+=PacketSize; }
        if(Pack1>=2*PacketSize) { pstore(blockA+count, cj.pconj(B)); count+=PacketSize; }
        if(Pack1>=3*PacketSize) { pstore(blockA+count, cj.pconj(C)); count+=PacketSize; }
        if(Pack1>=4*PacketSize) { pstore(blockA+count, cj.pconj(D)); count+=PacketSize; }
      }
    }
    else
    {
      for(Index k=0; k<depth; k++)
      {
        // TODO add a vectorized transpose here
        Index w=0;
        for(; w<Pack1-3; w+=4)
        {
          Scalar a(cj(lhs(i+w+0, k))),
                 b(cj(lhs(i+w+1, k))),
                 c(cj(lhs(i+w+2, k))),
                 d(cj(lhs(i+w+3, k)));
          blockA[count++] = a;
          blockA[count++] = b;
          blockA[count++] = c;
          blockA[count++] = d;
        }
        if(Pack1%4)
          for(;w<Pack1;++w)
            blockA[count++] = cj(lhs(i+w, k));
      }
    }
    if(PanelMode) count += Pack1 * (stride-offset-depth);
  }
  if(rows-peeled_mc>=Pack2)
  {
    if(PanelMode) count += Pack2*offset;
    for(Index k=0; k<depth; k++)
      for(Index w=0; w<Pack2; w++)
        blockA[count++] = cj(lhs(peeled_mc+w, k));
    if(PanelMode) count += Pack2 * (stride-offset-depth);
    peeled_mc += Pack2;
  }
  for(Index i=peeled_mc; i<rows; i++)
  {
    if(PanelMode) count += offset;
    for(Index k=0; k<depth; k++)
      blockA[count++] = cj(lhs(i, k));
    if(PanelMode) count += (stride-offset-depth);
  }
}

// copy a complete panel of the rhs
// this version is optimized for column major matrices
// The traversal order is as follows (nr==4):
//   0  1  2  3   12 13 14 15   24 27
//   4  5  6  7   16 17 18 19   25 28
//   8  9 10 11   20 21 22 23   26 29
//   .  .  .  .    .  .  .  .    .  .
template<typename Scalar, typename Index, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<Scalar, Index, nr, ColMajor, Conjugate, PanelMode>
{
  typedef typename packet_traits<Scalar>::type Packet;
  enum { PacketSize = packet_traits<Scalar>::size };
  EIGEN_DONT_INLINE void operator()(Scalar* blockB, const Scalar* rhs, Index rhsStride, Index depth, Index cols, Index stride=0, Index offset=0);
};

template<typename Scalar, typename Index, int nr, bool Conjugate, bool PanelMode>
EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, nr, ColMajor, Conjugate, PanelMode>
  ::operator()(Scalar* blockB, const Scalar* rhs, Index rhsStride, Index depth, Index cols, Index stride, Index offset)
{
  EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS COLMAJOR");
  EIGEN_UNUSED_VARIABLE(stride)
  EIGEN_UNUSED_VARIABLE(offset)
  eigen_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride));
  conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
  Index packet_cols = (cols/nr) * nr;
  Index count = 0;
  for(Index j2=0; j2<packet_cols; j2+=nr)
  {
    // skip what we have before
    if(PanelMode) count += nr * offset;
    const Scalar* b0 = &rhs[(j2+0)*rhsStride];
    const Scalar* b1 = &rhs[(j2+1)*rhsStride];
    const Scalar* b2 = &rhs[(j2+2)*rhsStride];
    const Scalar* b3 = &rhs[(j2+3)*rhsStride];
    for(Index k=0; k<depth; k++)
    {
      blockB[count+0] = cj(b0[k]);
      blockB[count+1] = cj(b1[k]);
      if(nr==4) blockB[count+2] = cj(b2[k]);
      if(nr==4) blockB[count+3] = cj(b3[k]);
      count += nr;
    }
    // skip what we have after
    if(PanelMode) count += nr * (stride-offset-depth);
  }

  // copy the remaining columns one at a time (nr==1)
  for(Index j2=packet_cols; j2<cols; ++j2)
  {
    if(PanelMode) count += offset;
    const Scalar* b0 = &rhs[(j2+0)*rhsStride];
    for(Index k=0; k<depth; k++)
    {
      blockB[count] = cj(b0[k]);
      count += 1;
    }
    if(PanelMode) count += (stride-offset-depth);
  }
}

// this version is optimized for row major matrices
template<typename Scalar, typename Index, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<Scalar, Index, nr, RowMajor, Conjugate, PanelMode>
{
  enum { PacketSize = packet_traits<Scalar>::size };
  EIGEN_DONT_INLINE void operator()(Scalar* blockB, const Scalar* rhs, Index rhsStride, Index depth, Index cols, Index stride=0, Index offset=0);
};

template<typename Scalar, typename Index, int nr, bool Conjugate, bool PanelMode>
EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, nr, RowMajor, Conjugate, PanelMode>
  ::operator()(Scalar* blockB, const Scalar* rhs, Index rhsStride, Index depth, Index cols, Index stride, Index offset)
{
  EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS ROWMAJOR");
  EIGEN_UNUSED_VARIABLE(stride)
  EIGEN_UNUSED_VARIABLE(offset)
  eigen_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride));
  conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
  Index packet_cols = (cols/nr) * nr;
  Index count = 0;
  for(Index j2=0; j2<packet_cols; j2+=nr)
  {
    // skip what we have before
    if(PanelMode) count += nr * offset;
    for(Index k=0; k<depth; k++)
    {
      const Scalar* b0 = &rhs[k*rhsStride + j2];
      blockB[count+0] = cj(b0[0]);
      blockB[count+1] = cj(b0[1]);
      if(nr==4) blockB[count+2] = cj(b0[2]);
      if(nr==4) blockB[count+3] = cj(b0[3]);
      count += nr;
    }
    // skip what we have after
    if(PanelMode) count += nr * (stride-offset-depth);
  }
  // copy the remaining columns one at a time (nr==1)
  for(Index j2=packet_cols; j2<cols; ++j2)
  {
    if(PanelMode) count += offset;
    const Scalar* b0 = &rhs[j2];
    for(Index k=0; k<depth; k++)
    {
      blockB[count] = cj(b0[k*rhsStride]);
      count += 1;
    }
    if(PanelMode) count += stride-offset-depth;
  }
}

} // end namespace internal

/** \returns the currently set level 1 cpu cache size (in bytes) used to estimate the ideal blocking size parameters.
  * \sa setCpuCacheSizes */
inline std::ptrdiff_t l1CacheSize()
{
  std::ptrdiff_t l1, l2;
  internal::manage_caching_sizes(GetAction, &l1, &l2);
  return l1;
}

/** \returns the currently set level 2 cpu cache size (in bytes) used to estimate the ideal blocking size parameters.
  * \sa setCpuCacheSizes */
inline std::ptrdiff_t l2CacheSize()
{
  std::ptrdiff_t l1, l2;
  internal::manage_caching_sizes(GetAction, &l1, &l2);
  return l2;
}

/** Set the cpu L1 and L2 cache sizes (in bytes).
  * These values are used to adjust the block sizes of the block-wise algorithms.
  *
  * \sa computeProductBlockingSizes */
inline void setCpuCacheSizes(std::ptrdiff_t l1, std::ptrdiff_t l2)
{
  internal::manage_caching_sizes(SetAction, &l1, &l2);
}

} // end namespace Eigen

#endif // EIGEN_GENERAL_BLOCK_PANEL_H
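
// Usage sketch (illustrative only, not part of the library): the cache sizes above
// drive the blocking heuristics and can be overridden with measured values, e.g.
//
//   std::ptrdiff_t l1 = Eigen::l1CacheSize();        // current L1 size in bytes
//   std::ptrdiff_t l2 = Eigen::l2CacheSize();        // current top-level cache size in bytes
//   Eigen::setCpuCacheSizes(32*1024, 4*1024*1024);   // hypothetical measured values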