1 /* 2 Copyright (c) 2011, Intel Corporation. All rights reserved. 3 4 Redistribution and use in source and binary forms, with or without modification, 5 are permitted provided that the following conditions are met: 6 7 * Redistributions of source code must retain the above copyright notice, this 8 list of conditions and the following disclaimer. 9 * Redistributions in binary form must reproduce the above copyright notice, 10 this list of conditions and the following disclaimer in the documentation 11 and/or other materials provided with the distribution. 12 * Neither the name of Intel Corporation nor the names of its contributors may 13 be used to endorse or promote products derived from this software without 14 specific prior written permission. 15 16 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 17 ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 18 WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 20 ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 21 (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 22 LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 23 ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 27 ******************************************************************************** 28 * Content : Eigen bindings to Intel(R) MKL 29 * Triangular matrix * matrix product functionality based on ?TRMM. 30 ******************************************************************************** 31 */ 32 33 #ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H 34 #define EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H 35 36 namespace Eigen { 37 38 namespace internal { 39 40 41 template <typename Scalar, typename Index, 42 int Mode, bool LhsIsTriangular, 43 int LhsStorageOrder, bool ConjugateLhs, 44 int RhsStorageOrder, bool ConjugateRhs, 45 int ResStorageOrder> 46 struct product_triangular_matrix_matrix_trmm : 47 product_triangular_matrix_matrix<Scalar,Index,Mode, 48 LhsIsTriangular,LhsStorageOrder,ConjugateLhs, 49 RhsStorageOrder, ConjugateRhs, ResStorageOrder, BuiltIn> {}; 50 51 52 // try to go to BLAS specialization 53 #define EIGEN_MKL_TRMM_SPECIALIZE(Scalar, LhsIsTriangular) \ 54 template <typename Index, int Mode, \ 55 int LhsStorageOrder, bool ConjugateLhs, \ 56 int RhsStorageOrder, bool ConjugateRhs> \ 57 struct product_triangular_matrix_matrix<Scalar,Index, Mode, LhsIsTriangular, \ 58 LhsStorageOrder,ConjugateLhs, RhsStorageOrder,ConjugateRhs,ColMajor,Specialized> { \ 59 static inline void run(Index _rows, Index _cols, Index _depth, const Scalar* _lhs, Index lhsStride,\ 60 const Scalar* _rhs, Index rhsStride, Scalar* res, Index resStride, Scalar alpha) { \ 61 product_triangular_matrix_matrix_trmm<Scalar,Index,Mode, \ 62 LhsIsTriangular,LhsStorageOrder,ConjugateLhs, \ 63 RhsStorageOrder, ConjugateRhs, ColMajor>::run( \ 64 _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha); \ 65 } \ 66 }; 67 68 EIGEN_MKL_TRMM_SPECIALIZE(double, true) 69 EIGEN_MKL_TRMM_SPECIALIZE(double, false) 70 EIGEN_MKL_TRMM_SPECIALIZE(dcomplex, true) 71 EIGEN_MKL_TRMM_SPECIALIZE(dcomplex, false) 72 EIGEN_MKL_TRMM_SPECIALIZE(float, true) 73 EIGEN_MKL_TRMM_SPECIALIZE(float, false) 74 EIGEN_MKL_TRMM_SPECIALIZE(scomplex, true) 75 EIGEN_MKL_TRMM_SPECIALIZE(scomplex, false) 76 77 // implements col-major += alpha * op(triangular) * op(general) 78 #define EIGEN_MKL_TRMM_L(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \ 79 template <typename Index, int Mode, \ 80 int LhsStorageOrder, bool ConjugateLhs, \ 81 int RhsStorageOrder, bool ConjugateRhs> \ 82 struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \ 83 LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \ 84 { \ 85 enum { \ 86 IsLower = (Mode&Lower) == Lower, \ 87 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \ 88 IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \ 89 IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \ 90 LowUp = IsLower ? Lower : Upper, \ 91 conjA = ((LhsStorageOrder==ColMajor) && ConjugateLhs) ? 1 : 0 \ 92 }; \ 93 \ 94 static EIGEN_DONT_INLINE void run( \ 95 Index _rows, Index _cols, Index _depth, \ 96 const EIGTYPE* _lhs, Index lhsStride, \ 97 const EIGTYPE* _rhs, Index rhsStride, \ 98 EIGTYPE* res, Index resStride, \ 99 EIGTYPE alpha) \ 100 { \ 101 Index diagSize = (std::min)(_rows,_depth); \ 102 Index rows = IsLower ? _rows : diagSize; \ 103 Index depth = IsLower ? diagSize : _depth; \ 104 Index cols = _cols; \ 105 \ 106 typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \ 107 typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \ 108 \ 109 /* Non-square case - doesn't fit to MKL ?TRMM. Fall to default triangular product or call MKL ?GEMM*/ \ 110 if (rows != depth) { \ 111 \ 112 int nthr = mkl_domain_get_max_threads(MKL_BLAS); \ 113 \ 114 if (((nthr==1) && (((std::max)(rows,depth)-diagSize)/(double)diagSize < 0.5))) { \ 115 /* Most likely no benefit to call TRMM or GEMM from MKL*/ \ 116 product_triangular_matrix_matrix<EIGTYPE,Index,Mode,true, \ 117 LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \ 118 _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha); \ 119 /*std::cout << "TRMM_L: A is not square! Go to Eigen TRMM implementation!\n";*/ \ 120 } else { \ 121 /* Make sense to call GEMM */ \ 122 Map<const MatrixLhs, 0, OuterStride<> > lhsMap(_lhs,rows,depth,OuterStride<>(lhsStride)); \ 123 MatrixLhs aa_tmp=lhsMap.template triangularView<Mode>(); \ 124 MKL_INT aStride = aa_tmp.outerStride(); \ 125 gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> blocking(_rows,_cols,_depth); \ 126 general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \ 127 rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, resStride, alpha, blocking, 0); \ 128 \ 129 /*std::cout << "TRMM_L: A is not square! Go to MKL GEMM implementation! " << nthr<<" \n";*/ \ 130 } \ 131 return; \ 132 } \ 133 char side = 'L', transa, uplo, diag = 'N'; \ 134 EIGTYPE *b; \ 135 const EIGTYPE *a; \ 136 MKL_INT m, n, lda, ldb; \ 137 MKLTYPE alpha_; \ 138 \ 139 /* Set alpha_*/ \ 140 assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \ 141 \ 142 /* Set m, n */ \ 143 m = (MKL_INT)diagSize; \ 144 n = (MKL_INT)cols; \ 145 \ 146 /* Set trans */ \ 147 transa = (LhsStorageOrder==RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \ 148 \ 149 /* Set b, ldb */ \ 150 Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols,OuterStride<>(rhsStride)); \ 151 MatrixX##EIGPREFIX b_tmp; \ 152 \ 153 if (ConjugateRhs) b_tmp = rhs.conjugate(); else b_tmp = rhs; \ 154 b = b_tmp.data(); \ 155 ldb = b_tmp.outerStride(); \ 156 \ 157 /* Set uplo */ \ 158 uplo = IsLower ? 'L' : 'U'; \ 159 if (LhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \ 160 /* Set a, lda */ \ 161 Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \ 162 MatrixLhs a_tmp; \ 163 \ 164 if ((conjA!=0) || (SetDiag==0)) { \ 165 if (conjA) a_tmp = lhs.conjugate(); else a_tmp = lhs; \ 166 if (IsZeroDiag) \ 167 a_tmp.diagonal().setZero(); \ 168 else if (IsUnitDiag) \ 169 a_tmp.diagonal().setOnes();\ 170 a = a_tmp.data(); \ 171 lda = a_tmp.outerStride(); \ 172 } else { \ 173 a = _lhs; \ 174 lda = lhsStride; \ 175 } \ 176 /*std::cout << "TRMM_L: A is square! Go to MKL TRMM implementation! \n";*/ \ 177 /* call ?trmm*/ \ 178 MKLPREFIX##trmm(&side, &uplo, &transa, &diag, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (MKLTYPE*)b, &ldb); \ 179 \ 180 /* Add op(a_triangular)*b into res*/ \ 181 Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \ 182 res_tmp=res_tmp+b_tmp; \ 183 } \ 184 }; 185 186 EIGEN_MKL_TRMM_L(double, double, d, d) 187 EIGEN_MKL_TRMM_L(dcomplex, MKL_Complex16, cd, z) 188 EIGEN_MKL_TRMM_L(float, float, f, s) 189 EIGEN_MKL_TRMM_L(scomplex, MKL_Complex8, cf, c) 190 191 // implements col-major += alpha * op(general) * op(triangular) 192 #define EIGEN_MKL_TRMM_R(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \ 193 template <typename Index, int Mode, \ 194 int LhsStorageOrder, bool ConjugateLhs, \ 195 int RhsStorageOrder, bool ConjugateRhs> \ 196 struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \ 197 LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \ 198 { \ 199 enum { \ 200 IsLower = (Mode&Lower) == Lower, \ 201 SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \ 202 IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \ 203 IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \ 204 LowUp = IsLower ? Lower : Upper, \ 205 conjA = ((RhsStorageOrder==ColMajor) && ConjugateRhs) ? 1 : 0 \ 206 }; \ 207 \ 208 static EIGEN_DONT_INLINE void run( \ 209 Index _rows, Index _cols, Index _depth, \ 210 const EIGTYPE* _lhs, Index lhsStride, \ 211 const EIGTYPE* _rhs, Index rhsStride, \ 212 EIGTYPE* res, Index resStride, \ 213 EIGTYPE alpha) \ 214 { \ 215 Index diagSize = (std::min)(_cols,_depth); \ 216 Index rows = _rows; \ 217 Index depth = IsLower ? _depth : diagSize; \ 218 Index cols = IsLower ? diagSize : _cols; \ 219 \ 220 typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \ 221 typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \ 222 \ 223 /* Non-square case - doesn't fit to MKL ?TRMM. Fall to default triangular product or call MKL ?GEMM*/ \ 224 if (cols != depth) { \ 225 \ 226 int nthr = mkl_domain_get_max_threads(MKL_BLAS); \ 227 \ 228 if ((nthr==1) && (((std::max)(cols,depth)-diagSize)/(double)diagSize < 0.5)) { \ 229 /* Most likely no benefit to call TRMM or GEMM from MKL*/ \ 230 product_triangular_matrix_matrix<EIGTYPE,Index,Mode,false, \ 231 LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \ 232 _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha); \ 233 /*std::cout << "TRMM_R: A is not square! Go to Eigen TRMM implementation!\n";*/ \ 234 } else { \ 235 /* Make sense to call GEMM */ \ 236 Map<const MatrixRhs, 0, OuterStride<> > rhsMap(_rhs,depth,cols, OuterStride<>(rhsStride)); \ 237 MatrixRhs aa_tmp=rhsMap.template triangularView<Mode>(); \ 238 MKL_INT aStride = aa_tmp.outerStride(); \ 239 gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> blocking(_rows,_cols,_depth); \ 240 general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \ 241 rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, resStride, alpha, blocking, 0); \ 242 \ 243 /*std::cout << "TRMM_R: A is not square! Go to MKL GEMM implementation! " << nthr<<" \n";*/ \ 244 } \ 245 return; \ 246 } \ 247 char side = 'R', transa, uplo, diag = 'N'; \ 248 EIGTYPE *b; \ 249 const EIGTYPE *a; \ 250 MKL_INT m, n, lda, ldb; \ 251 MKLTYPE alpha_; \ 252 \ 253 /* Set alpha_*/ \ 254 assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \ 255 \ 256 /* Set m, n */ \ 257 m = (MKL_INT)rows; \ 258 n = (MKL_INT)diagSize; \ 259 \ 260 /* Set trans */ \ 261 transa = (RhsStorageOrder==RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \ 262 \ 263 /* Set b, ldb */ \ 264 Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \ 265 MatrixX##EIGPREFIX b_tmp; \ 266 \ 267 if (ConjugateLhs) b_tmp = lhs.conjugate(); else b_tmp = lhs; \ 268 b = b_tmp.data(); \ 269 ldb = b_tmp.outerStride(); \ 270 \ 271 /* Set uplo */ \ 272 uplo = IsLower ? 'L' : 'U'; \ 273 if (RhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \ 274 /* Set a, lda */ \ 275 Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols, OuterStride<>(rhsStride)); \ 276 MatrixRhs a_tmp; \ 277 \ 278 if ((conjA!=0) || (SetDiag==0)) { \ 279 if (conjA) a_tmp = rhs.conjugate(); else a_tmp = rhs; \ 280 if (IsZeroDiag) \ 281 a_tmp.diagonal().setZero(); \ 282 else if (IsUnitDiag) \ 283 a_tmp.diagonal().setOnes();\ 284 a = a_tmp.data(); \ 285 lda = a_tmp.outerStride(); \ 286 } else { \ 287 a = _rhs; \ 288 lda = rhsStride; \ 289 } \ 290 /*std::cout << "TRMM_R: A is square! Go to MKL TRMM implementation! \n";*/ \ 291 /* call ?trmm*/ \ 292 MKLPREFIX##trmm(&side, &uplo, &transa, &diag, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (MKLTYPE*)b, &ldb); \ 293 \ 294 /* Add op(a_triangular)*b into res*/ \ 295 Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \ 296 res_tmp=res_tmp+b_tmp; \ 297 } \ 298 }; 299 300 EIGEN_MKL_TRMM_R(double, double, d, d) 301 EIGEN_MKL_TRMM_R(dcomplex, MKL_Complex16, cd, z) 302 EIGEN_MKL_TRMM_R(float, float, f, s) 303 EIGEN_MKL_TRMM_R(scomplex, MKL_Complex8, cf, c) 304 305 } // end namespace internal 306 307 } // end namespace Eigen 308 309 #endif // EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H 310