1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2009-2010 Gael Guennebaud <gael.guennebaud (at) inria.fr> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 #include <iostream> 10 #include "common.h" 11 12 int EIGEN_BLAS_FUNC(gemm)(const char *opa, const char *opb, const int *m, const int *n, const int *k, const RealScalar *palpha, 13 const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc) 14 { 15 // std::cerr << "in gemm " << *opa << " " << *opb << " " << *m << " " << *n << " " << *k << " " << *lda << " " << *ldb << " " << *ldc << " " << *palpha << " " << *pbeta << "\n"; 16 typedef void (*functype)(DenseIndex, DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, Scalar, internal::level3_blocking<Scalar,Scalar>&, Eigen::internal::GemmParallelInfo<DenseIndex>*); 17 static const functype func[12] = { 18 // array index: NOTR | (NOTR << 2) 19 (internal::general_matrix_matrix_product<DenseIndex,Scalar,ColMajor,false,Scalar,ColMajor,false,ColMajor>::run), 20 // array index: TR | (NOTR << 2) 21 (internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,false,Scalar,ColMajor,false,ColMajor>::run), 22 // array index: ADJ | (NOTR << 2) 23 (internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,ColMajor,false,ColMajor>::run), 24 0, 25 // array index: NOTR | (TR << 2) 26 (internal::general_matrix_matrix_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,false,ColMajor>::run), 27 // array index: TR | (TR << 2) 28 (internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,false,Scalar,RowMajor,false,ColMajor>::run), 29 // array index: ADJ | (TR << 2) 30 (internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,RowMajor,false,ColMajor>::run), 31 0, 32 // array index: NOTR | (ADJ << 2) 33 (internal::general_matrix_matrix_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,Conj, ColMajor>::run), 34 // array index: TR | (ADJ << 2) 35 (internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,false,Scalar,RowMajor,Conj, ColMajor>::run), 36 // array index: ADJ | (ADJ << 2) 37 (internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,RowMajor,Conj, ColMajor>::run), 38 0 39 }; 40 41 const Scalar* a = reinterpret_cast<const Scalar*>(pa); 42 const Scalar* b = reinterpret_cast<const Scalar*>(pb); 43 Scalar* c = reinterpret_cast<Scalar*>(pc); 44 Scalar alpha = *reinterpret_cast<const Scalar*>(palpha); 45 Scalar beta = *reinterpret_cast<const Scalar*>(pbeta); 46 47 int info = 0; 48 if(OP(*opa)==INVALID) info = 1; 49 else if(OP(*opb)==INVALID) info = 2; 50 else if(*m<0) info = 3; 51 else if(*n<0) info = 4; 52 else if(*k<0) info = 5; 53 else if(*lda<std::max(1,(OP(*opa)==NOTR)?*m:*k)) info = 8; 54 else if(*ldb<std::max(1,(OP(*opb)==NOTR)?*k:*n)) info = 10; 55 else if(*ldc<std::max(1,*m)) info = 13; 56 if(info) 57 return xerbla_(SCALAR_SUFFIX_UP"GEMM ",&info,6); 58 59 if (*m == 0 || *n == 0) 60 return 0; 61 62 if(beta!=Scalar(1)) 63 { 64 if(beta==Scalar(0)) matrix(c, *m, *n, *ldc).setZero(); 65 else matrix(c, *m, *n, *ldc) *= beta; 66 } 67 68 if(*k == 0) 69 return 0; 70 71 internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic> blocking(*m,*n,*k,1,true); 72 73 int code = OP(*opa) | (OP(*opb) << 2); 74 func[code](*m, *n, *k, a, *lda, b, *ldb, c, *ldc, alpha, blocking, 0); 75 return 0; 76 } 77 78 int EIGEN_BLAS_FUNC(trsm)(const char *side, const char *uplo, const char *opa, const char *diag, const int *m, const int *n, 79 const RealScalar *palpha, const RealScalar *pa, const int *lda, RealScalar *pb, const int *ldb) 80 { 81 // std::cerr << "in trsm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << "," << *n << " " << *palpha << " " << *lda << " " << *ldb<< "\n"; 82 typedef void (*functype)(DenseIndex, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, internal::level3_blocking<Scalar,Scalar>&); 83 static const functype func[32] = { 84 // array index: NOTR | (LEFT << 2) | (UP << 3) | (NUNIT << 4) 85 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|0, false,ColMajor,ColMajor>::run), 86 // array index: TR | (LEFT << 2) | (UP << 3) | (NUNIT << 4) 87 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|0, false,RowMajor,ColMajor>::run), 88 // array index: ADJ | (LEFT << 2) | (UP << 3) | (NUNIT << 4) 89 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|0, Conj, RowMajor,ColMajor>::run),\ 90 0, 91 // array index: NOTR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4) 92 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|0, false,ColMajor,ColMajor>::run), 93 // array index: TR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4) 94 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|0, false,RowMajor,ColMajor>::run), 95 // array index: ADJ | (RIGHT << 2) | (UP << 3) | (NUNIT << 4) 96 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|0, Conj, RowMajor,ColMajor>::run), 97 0, 98 // array index: NOTR | (LEFT << 2) | (LO << 3) | (NUNIT << 4) 99 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|0, false,ColMajor,ColMajor>::run), 100 // array index: TR | (LEFT << 2) | (LO << 3) | (NUNIT << 4) 101 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|0, false,RowMajor,ColMajor>::run), 102 // array index: ADJ | (LEFT << 2) | (LO << 3) | (NUNIT << 4) 103 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|0, Conj, RowMajor,ColMajor>::run), 104 0, 105 // array index: NOTR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4) 106 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|0, false,ColMajor,ColMajor>::run), 107 // array index: TR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4) 108 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|0, false,RowMajor,ColMajor>::run), 109 // array index: ADJ | (RIGHT << 2) | (LO << 3) | (NUNIT << 4) 110 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|0, Conj, RowMajor,ColMajor>::run), 111 0, 112 // array index: NOTR | (LEFT << 2) | (UP << 3) | (UNIT << 4) 113 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|UnitDiag,false,ColMajor,ColMajor>::run), 114 // array index: TR | (LEFT << 2) | (UP << 3) | (UNIT << 4) 115 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|UnitDiag,false,RowMajor,ColMajor>::run), 116 // array index: ADJ | (LEFT << 2) | (UP << 3) | (UNIT << 4) 117 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|UnitDiag,Conj, RowMajor,ColMajor>::run), 118 0, 119 // array index: NOTR | (RIGHT << 2) | (UP << 3) | (UNIT << 4) 120 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|UnitDiag,false,ColMajor,ColMajor>::run), 121 // array index: TR | (RIGHT << 2) | (UP << 3) | (UNIT << 4) 122 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|UnitDiag,false,RowMajor,ColMajor>::run), 123 // array index: ADJ | (RIGHT << 2) | (UP << 3) | (UNIT << 4) 124 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|UnitDiag,Conj, RowMajor,ColMajor>::run), 125 0, 126 // array index: NOTR | (LEFT << 2) | (LO << 3) | (UNIT << 4) 127 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|UnitDiag,false,ColMajor,ColMajor>::run), 128 // array index: TR | (LEFT << 2) | (LO << 3) | (UNIT << 4) 129 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|UnitDiag,false,RowMajor,ColMajor>::run), 130 // array index: ADJ | (LEFT << 2) | (LO << 3) | (UNIT << 4) 131 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|UnitDiag,Conj, RowMajor,ColMajor>::run), 132 0, 133 // array index: NOTR | (RIGHT << 2) | (LO << 3) | (UNIT << 4) 134 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|UnitDiag,false,ColMajor,ColMajor>::run), 135 // array index: TR | (RIGHT << 2) | (LO << 3) | (UNIT << 4) 136 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|UnitDiag,false,RowMajor,ColMajor>::run), 137 // array index: ADJ | (RIGHT << 2) | (LO << 3) | (UNIT << 4) 138 (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|UnitDiag,Conj, RowMajor,ColMajor>::run), 139 0 140 }; 141 142 const Scalar* a = reinterpret_cast<const Scalar*>(pa); 143 Scalar* b = reinterpret_cast<Scalar*>(pb); 144 Scalar alpha = *reinterpret_cast<const Scalar*>(palpha); 145 146 int info = 0; 147 if(SIDE(*side)==INVALID) info = 1; 148 else if(UPLO(*uplo)==INVALID) info = 2; 149 else if(OP(*opa)==INVALID) info = 3; 150 else if(DIAG(*diag)==INVALID) info = 4; 151 else if(*m<0) info = 5; 152 else if(*n<0) info = 6; 153 else if(*lda<std::max(1,(SIDE(*side)==LEFT)?*m:*n)) info = 9; 154 else if(*ldb<std::max(1,*m)) info = 11; 155 if(info) 156 return xerbla_(SCALAR_SUFFIX_UP"TRSM ",&info,6); 157 158 if(*m==0 || *n==0) 159 return 0; 160 161 int code = OP(*opa) | (SIDE(*side) << 2) | (UPLO(*uplo) << 3) | (DIAG(*diag) << 4); 162 163 if(SIDE(*side)==LEFT) 164 { 165 internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic,4> blocking(*m,*n,*m,1,false); 166 func[code](*m, *n, a, *lda, b, *ldb, blocking); 167 } 168 else 169 { 170 internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic,4> blocking(*m,*n,*n,1,false); 171 func[code](*n, *m, a, *lda, b, *ldb, blocking); 172 } 173 174 if(alpha!=Scalar(1)) 175 matrix(b,*m,*n,*ldb) *= alpha; 176 177 return 0; 178 } 179 180 181 // b = alpha*op(a)*b for side = 'L'or'l' 182 // b = alpha*b*op(a) for side = 'R'or'r' 183 int EIGEN_BLAS_FUNC(trmm)(const char *side, const char *uplo, const char *opa, const char *diag, const int *m, const int *n, 184 const RealScalar *palpha, const RealScalar *pa, const int *lda, RealScalar *pb, const int *ldb) 185 { 186 // std::cerr << "in trmm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << " " << *n << " " << *lda << " " << *ldb << " " << *palpha << "\n"; 187 typedef void (*functype)(DenseIndex, DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, const Scalar&, internal::level3_blocking<Scalar,Scalar>&); 188 static const functype func[32] = { 189 // array index: NOTR | (LEFT << 2) | (UP << 3) | (NUNIT << 4) 190 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, true, ColMajor,false,ColMajor,false,ColMajor>::run), 191 // array index: TR | (LEFT << 2) | (UP << 3) | (NUNIT << 4) 192 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, true, RowMajor,false,ColMajor,false,ColMajor>::run), 193 // array index: ADJ | (LEFT << 2) | (UP << 3) | (NUNIT << 4) 194 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, true, RowMajor,Conj, ColMajor,false,ColMajor>::run), 195 0, 196 // array index: NOTR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4) 197 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, false,ColMajor,false,ColMajor,false,ColMajor>::run), 198 // array index: TR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4) 199 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, false,ColMajor,false,RowMajor,false,ColMajor>::run), 200 // array index: ADJ | (RIGHT << 2) | (UP << 3) | (NUNIT << 4) 201 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, false,ColMajor,false,RowMajor,Conj, ColMajor>::run), 202 0, 203 // array index: NOTR | (LEFT << 2) | (LO << 3) | (NUNIT << 4) 204 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, true, ColMajor,false,ColMajor,false,ColMajor>::run), 205 // array index: TR | (LEFT << 2) | (LO << 3) | (NUNIT << 4) 206 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, true, RowMajor,false,ColMajor,false,ColMajor>::run), 207 // array index: ADJ | (LEFT << 2) | (LO << 3) | (NUNIT << 4) 208 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, true, RowMajor,Conj, ColMajor,false,ColMajor>::run), 209 0, 210 // array index: NOTR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4) 211 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, false,ColMajor,false,ColMajor,false,ColMajor>::run), 212 // array index: TR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4) 213 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, false,ColMajor,false,RowMajor,false,ColMajor>::run), 214 // array index: ADJ | (RIGHT << 2) | (LO << 3) | (NUNIT << 4) 215 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, false,ColMajor,false,RowMajor,Conj, ColMajor>::run), 216 0, 217 // array index: NOTR | (LEFT << 2) | (UP << 3) | (UNIT << 4) 218 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,true, ColMajor,false,ColMajor,false,ColMajor>::run), 219 // array index: TR | (LEFT << 2) | (UP << 3) | (UNIT << 4) 220 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,true, RowMajor,false,ColMajor,false,ColMajor>::run), 221 // array index: ADJ | (LEFT << 2) | (UP << 3) | (UNIT << 4) 222 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,true, RowMajor,Conj, ColMajor,false,ColMajor>::run), 223 0, 224 // array index: NOTR | (RIGHT << 2) | (UP << 3) | (UNIT << 4) 225 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,false,ColMajor,false,ColMajor,false,ColMajor>::run), 226 // array index: TR | (RIGHT << 2) | (UP << 3) | (UNIT << 4) 227 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,false,ColMajor,false,RowMajor,false,ColMajor>::run), 228 // array index: ADJ | (RIGHT << 2) | (UP << 3) | (UNIT << 4) 229 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,false,ColMajor,false,RowMajor,Conj, ColMajor>::run), 230 0, 231 // array index: NOTR | (LEFT << 2) | (LO << 3) | (UNIT << 4) 232 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,true, ColMajor,false,ColMajor,false,ColMajor>::run), 233 // array index: TR | (LEFT << 2) | (LO << 3) | (UNIT << 4) 234 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,true, RowMajor,false,ColMajor,false,ColMajor>::run), 235 // array index: ADJ | (LEFT << 2) | (LO << 3) | (UNIT << 4) 236 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,true, RowMajor,Conj, ColMajor,false,ColMajor>::run), 237 0, 238 // array index: NOTR | (RIGHT << 2) | (LO << 3) | (UNIT << 4) 239 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,false,ColMajor,false,ColMajor,false,ColMajor>::run), 240 // array index: TR | (RIGHT << 2) | (LO << 3) | (UNIT << 4) 241 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,false,ColMajor,false,RowMajor,false,ColMajor>::run), 242 // array index: ADJ | (RIGHT << 2) | (LO << 3) | (UNIT << 4) 243 (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,false,ColMajor,false,RowMajor,Conj, ColMajor>::run), 244 0 245 }; 246 247 const Scalar* a = reinterpret_cast<const Scalar*>(pa); 248 Scalar* b = reinterpret_cast<Scalar*>(pb); 249 Scalar alpha = *reinterpret_cast<const Scalar*>(palpha); 250 251 int info = 0; 252 if(SIDE(*side)==INVALID) info = 1; 253 else if(UPLO(*uplo)==INVALID) info = 2; 254 else if(OP(*opa)==INVALID) info = 3; 255 else if(DIAG(*diag)==INVALID) info = 4; 256 else if(*m<0) info = 5; 257 else if(*n<0) info = 6; 258 else if(*lda<std::max(1,(SIDE(*side)==LEFT)?*m:*n)) info = 9; 259 else if(*ldb<std::max(1,*m)) info = 11; 260 if(info) 261 return xerbla_(SCALAR_SUFFIX_UP"TRMM ",&info,6); 262 263 int code = OP(*opa) | (SIDE(*side) << 2) | (UPLO(*uplo) << 3) | (DIAG(*diag) << 4); 264 265 if(*m==0 || *n==0) 266 return 1; 267 268 // FIXME find a way to avoid this copy 269 Matrix<Scalar,Dynamic,Dynamic,ColMajor> tmp = matrix(b,*m,*n,*ldb); 270 matrix(b,*m,*n,*ldb).setZero(); 271 272 if(SIDE(*side)==LEFT) 273 { 274 internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic,4> blocking(*m,*n,*m,1,false); 275 func[code](*m, *n, *m, a, *lda, tmp.data(), tmp.outerStride(), b, *ldb, alpha, blocking); 276 } 277 else 278 { 279 internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic,4> blocking(*m,*n,*n,1,false); 280 func[code](*m, *n, *n, tmp.data(), tmp.outerStride(), a, *lda, b, *ldb, alpha, blocking); 281 } 282 return 1; 283 } 284 285 // c = alpha*a*b + beta*c for side = 'L'or'l' 286 // c = alpha*b*a + beta*c for side = 'R'or'r 287 int EIGEN_BLAS_FUNC(symm)(const char *side, const char *uplo, const int *m, const int *n, const RealScalar *palpha, 288 const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc) 289 { 290 // std::cerr << "in symm " << *side << " " << *uplo << " " << *m << "x" << *n << " lda:" << *lda << " ldb:" << *ldb << " ldc:" << *ldc << " alpha:" << *palpha << " beta:" << *pbeta << "\n"; 291 const Scalar* a = reinterpret_cast<const Scalar*>(pa); 292 const Scalar* b = reinterpret_cast<const Scalar*>(pb); 293 Scalar* c = reinterpret_cast<Scalar*>(pc); 294 Scalar alpha = *reinterpret_cast<const Scalar*>(palpha); 295 Scalar beta = *reinterpret_cast<const Scalar*>(pbeta); 296 297 int info = 0; 298 if(SIDE(*side)==INVALID) info = 1; 299 else if(UPLO(*uplo)==INVALID) info = 2; 300 else if(*m<0) info = 3; 301 else if(*n<0) info = 4; 302 else if(*lda<std::max(1,(SIDE(*side)==LEFT)?*m:*n)) info = 7; 303 else if(*ldb<std::max(1,*m)) info = 9; 304 else if(*ldc<std::max(1,*m)) info = 12; 305 if(info) 306 return xerbla_(SCALAR_SUFFIX_UP"SYMM ",&info,6); 307 308 if(beta!=Scalar(1)) 309 { 310 if(beta==Scalar(0)) matrix(c, *m, *n, *ldc).setZero(); 311 else matrix(c, *m, *n, *ldc) *= beta; 312 } 313 314 if(*m==0 || *n==0) 315 { 316 return 1; 317 } 318 319 int size = (SIDE(*side)==LEFT) ? (*m) : (*n); 320 #if ISCOMPLEX 321 // FIXME add support for symmetric complex matrix 322 Matrix<Scalar,Dynamic,Dynamic,ColMajor> matA(size,size); 323 if(UPLO(*uplo)==UP) 324 { 325 matA.triangularView<Upper>() = matrix(a,size,size,*lda); 326 matA.triangularView<Lower>() = matrix(a,size,size,*lda).transpose(); 327 } 328 else if(UPLO(*uplo)==LO) 329 { 330 matA.triangularView<Lower>() = matrix(a,size,size,*lda); 331 matA.triangularView<Upper>() = matrix(a,size,size,*lda).transpose(); 332 } 333 if(SIDE(*side)==LEFT) 334 matrix(c, *m, *n, *ldc) += alpha * matA * matrix(b, *m, *n, *ldb); 335 else if(SIDE(*side)==RIGHT) 336 matrix(c, *m, *n, *ldc) += alpha * matrix(b, *m, *n, *ldb) * matA; 337 #else 338 internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic> blocking(*m,*n,size,1,false); 339 340 if(SIDE(*side)==LEFT) 341 if(UPLO(*uplo)==UP) internal::product_selfadjoint_matrix<Scalar, DenseIndex, RowMajor,true,false, ColMajor,false,false, ColMajor>::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha, blocking); 342 else if(UPLO(*uplo)==LO) internal::product_selfadjoint_matrix<Scalar, DenseIndex, ColMajor,true,false, ColMajor,false,false, ColMajor>::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha, blocking); 343 else return 0; 344 else if(SIDE(*side)==RIGHT) 345 if(UPLO(*uplo)==UP) internal::product_selfadjoint_matrix<Scalar, DenseIndex, ColMajor,false,false, RowMajor,true,false, ColMajor>::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha, blocking); 346 else if(UPLO(*uplo)==LO) internal::product_selfadjoint_matrix<Scalar, DenseIndex, ColMajor,false,false, ColMajor,true,false, ColMajor>::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha, blocking); 347 else return 0; 348 else 349 return 0; 350 #endif 351 352 return 0; 353 } 354 355 // c = alpha*a*a' + beta*c for op = 'N'or'n' 356 // c = alpha*a'*a + beta*c for op = 'T'or't','C'or'c' 357 int EIGEN_BLAS_FUNC(syrk)(const char *uplo, const char *op, const int *n, const int *k, 358 const RealScalar *palpha, const RealScalar *pa, const int *lda, const RealScalar *pbeta, RealScalar *pc, const int *ldc) 359 { 360 // std::cerr << "in syrk " << *uplo << " " << *op << " " << *n << " " << *k << " " << *palpha << " " << *lda << " " << *pbeta << " " << *ldc << "\n"; 361 #if !ISCOMPLEX 362 typedef void (*functype)(DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, const Scalar&, internal::level3_blocking<Scalar,Scalar>&); 363 static const functype func[8] = { 364 // array index: NOTR | (UP << 2) 365 (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,ColMajor,Conj, Upper>::run), 366 // array index: TR | (UP << 2) 367 (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,false,Scalar,ColMajor,ColMajor,Conj, Upper>::run), 368 // array index: ADJ | (UP << 2) 369 (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,ColMajor,ColMajor,false,Upper>::run), 370 0, 371 // array index: NOTR | (LO << 2) 372 (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,ColMajor,Conj, Lower>::run), 373 // array index: TR | (LO << 2) 374 (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,false,Scalar,ColMajor,ColMajor,Conj, Lower>::run), 375 // array index: ADJ | (LO << 2) 376 (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,ColMajor,ColMajor,false,Lower>::run), 377 0 378 }; 379 #endif 380 381 const Scalar* a = reinterpret_cast<const Scalar*>(pa); 382 Scalar* c = reinterpret_cast<Scalar*>(pc); 383 Scalar alpha = *reinterpret_cast<const Scalar*>(palpha); 384 Scalar beta = *reinterpret_cast<const Scalar*>(pbeta); 385 386 int info = 0; 387 if(UPLO(*uplo)==INVALID) info = 1; 388 else if(OP(*op)==INVALID || (ISCOMPLEX && OP(*op)==ADJ) ) info = 2; 389 else if(*n<0) info = 3; 390 else if(*k<0) info = 4; 391 else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 7; 392 else if(*ldc<std::max(1,*n)) info = 10; 393 if(info) 394 return xerbla_(SCALAR_SUFFIX_UP"SYRK ",&info,6); 395 396 if(beta!=Scalar(1)) 397 { 398 if(UPLO(*uplo)==UP) 399 if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Upper>().setZero(); 400 else matrix(c, *n, *n, *ldc).triangularView<Upper>() *= beta; 401 else 402 if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Lower>().setZero(); 403 else matrix(c, *n, *n, *ldc).triangularView<Lower>() *= beta; 404 } 405 406 if(*n==0 || *k==0) 407 return 0; 408 409 #if ISCOMPLEX 410 // FIXME add support for symmetric complex matrix 411 if(UPLO(*uplo)==UP) 412 { 413 if(OP(*op)==NOTR) 414 matrix(c, *n, *n, *ldc).triangularView<Upper>() += alpha * matrix(a,*n,*k,*lda) * matrix(a,*n,*k,*lda).transpose(); 415 else 416 matrix(c, *n, *n, *ldc).triangularView<Upper>() += alpha * matrix(a,*k,*n,*lda).transpose() * matrix(a,*k,*n,*lda); 417 } 418 else 419 { 420 if(OP(*op)==NOTR) 421 matrix(c, *n, *n, *ldc).triangularView<Lower>() += alpha * matrix(a,*n,*k,*lda) * matrix(a,*n,*k,*lda).transpose(); 422 else 423 matrix(c, *n, *n, *ldc).triangularView<Lower>() += alpha * matrix(a,*k,*n,*lda).transpose() * matrix(a,*k,*n,*lda); 424 } 425 #else 426 internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic> blocking(*n,*n,*k,1,false); 427 428 int code = OP(*op) | (UPLO(*uplo) << 2); 429 func[code](*n, *k, a, *lda, a, *lda, c, *ldc, alpha, blocking); 430 #endif 431 432 return 0; 433 } 434 435 // c = alpha*a*b' + alpha*b*a' + beta*c for op = 'N'or'n' 436 // c = alpha*a'*b + alpha*b'*a + beta*c for op = 'T'or't' 437 int EIGEN_BLAS_FUNC(syr2k)(const char *uplo, const char *op, const int *n, const int *k, const RealScalar *palpha, 438 const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc) 439 { 440 const Scalar* a = reinterpret_cast<const Scalar*>(pa); 441 const Scalar* b = reinterpret_cast<const Scalar*>(pb); 442 Scalar* c = reinterpret_cast<Scalar*>(pc); 443 Scalar alpha = *reinterpret_cast<const Scalar*>(palpha); 444 Scalar beta = *reinterpret_cast<const Scalar*>(pbeta); 445 446 // std::cerr << "in syr2k " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " << *ldb << " " << beta << " " << *ldc << "\n"; 447 448 int info = 0; 449 if(UPLO(*uplo)==INVALID) info = 1; 450 else if(OP(*op)==INVALID || (ISCOMPLEX && OP(*op)==ADJ) ) info = 2; 451 else if(*n<0) info = 3; 452 else if(*k<0) info = 4; 453 else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 7; 454 else if(*ldb<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 9; 455 else if(*ldc<std::max(1,*n)) info = 12; 456 if(info) 457 return xerbla_(SCALAR_SUFFIX_UP"SYR2K",&info,6); 458 459 if(beta!=Scalar(1)) 460 { 461 if(UPLO(*uplo)==UP) 462 if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Upper>().setZero(); 463 else matrix(c, *n, *n, *ldc).triangularView<Upper>() *= beta; 464 else 465 if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Lower>().setZero(); 466 else matrix(c, *n, *n, *ldc).triangularView<Lower>() *= beta; 467 } 468 469 if(*k==0) 470 return 1; 471 472 if(OP(*op)==NOTR) 473 { 474 if(UPLO(*uplo)==UP) 475 { 476 matrix(c, *n, *n, *ldc).triangularView<Upper>() 477 += alpha *matrix(a, *n, *k, *lda)*matrix(b, *n, *k, *ldb).transpose() 478 + alpha*matrix(b, *n, *k, *ldb)*matrix(a, *n, *k, *lda).transpose(); 479 } 480 else if(UPLO(*uplo)==LO) 481 matrix(c, *n, *n, *ldc).triangularView<Lower>() 482 += alpha*matrix(a, *n, *k, *lda)*matrix(b, *n, *k, *ldb).transpose() 483 + alpha*matrix(b, *n, *k, *ldb)*matrix(a, *n, *k, *lda).transpose(); 484 } 485 else if(OP(*op)==TR || OP(*op)==ADJ) 486 { 487 if(UPLO(*uplo)==UP) 488 matrix(c, *n, *n, *ldc).triangularView<Upper>() 489 += alpha*matrix(a, *k, *n, *lda).transpose()*matrix(b, *k, *n, *ldb) 490 + alpha*matrix(b, *k, *n, *ldb).transpose()*matrix(a, *k, *n, *lda); 491 else if(UPLO(*uplo)==LO) 492 matrix(c, *n, *n, *ldc).triangularView<Lower>() 493 += alpha*matrix(a, *k, *n, *lda).transpose()*matrix(b, *k, *n, *ldb) 494 + alpha*matrix(b, *k, *n, *ldb).transpose()*matrix(a, *k, *n, *lda); 495 } 496 497 return 0; 498 } 499 500 501 #if ISCOMPLEX 502 503 // c = alpha*a*b + beta*c for side = 'L'or'l' 504 // c = alpha*b*a + beta*c for side = 'R'or'r 505 int EIGEN_BLAS_FUNC(hemm)(const char *side, const char *uplo, const int *m, const int *n, const RealScalar *palpha, 506 const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc) 507 { 508 const Scalar* a = reinterpret_cast<const Scalar*>(pa); 509 const Scalar* b = reinterpret_cast<const Scalar*>(pb); 510 Scalar* c = reinterpret_cast<Scalar*>(pc); 511 Scalar alpha = *reinterpret_cast<const Scalar*>(palpha); 512 Scalar beta = *reinterpret_cast<const Scalar*>(pbeta); 513 514 // std::cerr << "in hemm " << *side << " " << *uplo << " " << *m << " " << *n << " " << alpha << " " << *lda << " " << beta << " " << *ldc << "\n"; 515 516 int info = 0; 517 if(SIDE(*side)==INVALID) info = 1; 518 else if(UPLO(*uplo)==INVALID) info = 2; 519 else if(*m<0) info = 3; 520 else if(*n<0) info = 4; 521 else if(*lda<std::max(1,(SIDE(*side)==LEFT)?*m:*n)) info = 7; 522 else if(*ldb<std::max(1,*m)) info = 9; 523 else if(*ldc<std::max(1,*m)) info = 12; 524 if(info) 525 return xerbla_(SCALAR_SUFFIX_UP"HEMM ",&info,6); 526 527 if(beta==Scalar(0)) matrix(c, *m, *n, *ldc).setZero(); 528 else if(beta!=Scalar(1)) matrix(c, *m, *n, *ldc) *= beta; 529 530 if(*m==0 || *n==0) 531 { 532 return 1; 533 } 534 535 int size = (SIDE(*side)==LEFT) ? (*m) : (*n); 536 internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic> blocking(*m,*n,size,1,false); 537 538 if(SIDE(*side)==LEFT) 539 { 540 if(UPLO(*uplo)==UP) internal::product_selfadjoint_matrix<Scalar,DenseIndex,RowMajor,true,Conj, ColMajor,false,false, ColMajor> 541 ::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha, blocking); 542 else if(UPLO(*uplo)==LO) internal::product_selfadjoint_matrix<Scalar,DenseIndex,ColMajor,true,false, ColMajor,false,false, ColMajor> 543 ::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha, blocking); 544 else return 0; 545 } 546 else if(SIDE(*side)==RIGHT) 547 { 548 if(UPLO(*uplo)==UP) matrix(c,*m,*n,*ldc) += alpha * matrix(b,*m,*n,*ldb) * matrix(a,*n,*n,*lda).selfadjointView<Upper>();/*internal::product_selfadjoint_matrix<Scalar,DenseIndex,ColMajor,false,false, RowMajor,true,Conj, ColMajor> 549 ::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha, blocking);*/ 550 else if(UPLO(*uplo)==LO) internal::product_selfadjoint_matrix<Scalar,DenseIndex,ColMajor,false,false, ColMajor,true,false, ColMajor> 551 ::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha, blocking); 552 else return 0; 553 } 554 else 555 { 556 return 0; 557 } 558 559 return 0; 560 } 561 562 // c = alpha*a*conj(a') + beta*c for op = 'N'or'n' 563 // c = alpha*conj(a')*a + beta*c for op = 'C'or'c' 564 int EIGEN_BLAS_FUNC(herk)(const char *uplo, const char *op, const int *n, const int *k, 565 const RealScalar *palpha, const RealScalar *pa, const int *lda, const RealScalar *pbeta, RealScalar *pc, const int *ldc) 566 { 567 // std::cerr << "in herk " << *uplo << " " << *op << " " << *n << " " << *k << " " << *palpha << " " << *lda << " " << *pbeta << " " << *ldc << "\n"; 568 569 typedef void (*functype)(DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, const Scalar&, internal::level3_blocking<Scalar,Scalar>&); 570 static const functype func[8] = { 571 // array index: NOTR | (UP << 2) 572 (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,Conj, ColMajor,Upper>::run), 573 0, 574 // array index: ADJ | (UP << 2) 575 (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,ColMajor,false,ColMajor,Upper>::run), 576 0, 577 // array index: NOTR | (LO << 2) 578 (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,Conj, ColMajor,Lower>::run), 579 0, 580 // array index: ADJ | (LO << 2) 581 (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,ColMajor,false,ColMajor,Lower>::run), 582 0 583 }; 584 585 const Scalar* a = reinterpret_cast<const Scalar*>(pa); 586 Scalar* c = reinterpret_cast<Scalar*>(pc); 587 RealScalar alpha = *palpha; 588 RealScalar beta = *pbeta; 589 590 // std::cerr << "in herk " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " << beta << " " << *ldc << "\n"; 591 592 int info = 0; 593 if(UPLO(*uplo)==INVALID) info = 1; 594 else if((OP(*op)==INVALID) || (OP(*op)==TR)) info = 2; 595 else if(*n<0) info = 3; 596 else if(*k<0) info = 4; 597 else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 7; 598 else if(*ldc<std::max(1,*n)) info = 10; 599 if(info) 600 return xerbla_(SCALAR_SUFFIX_UP"HERK ",&info,6); 601 602 int code = OP(*op) | (UPLO(*uplo) << 2); 603 604 if(beta!=RealScalar(1)) 605 { 606 if(UPLO(*uplo)==UP) 607 if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Upper>().setZero(); 608 else matrix(c, *n, *n, *ldc).triangularView<StrictlyUpper>() *= beta; 609 else 610 if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Lower>().setZero(); 611 else matrix(c, *n, *n, *ldc).triangularView<StrictlyLower>() *= beta; 612 613 if(beta!=Scalar(0)) 614 { 615 matrix(c, *n, *n, *ldc).diagonal().real() *= beta; 616 matrix(c, *n, *n, *ldc).diagonal().imag().setZero(); 617 } 618 } 619 620 if(*k>0 && alpha!=RealScalar(0)) 621 { 622 internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic> blocking(*n,*n,*k,1,false); 623 func[code](*n, *k, a, *lda, a, *lda, c, *ldc, alpha, blocking); 624 matrix(c, *n, *n, *ldc).diagonal().imag().setZero(); 625 } 626 return 0; 627 } 628 629 // c = alpha*a*conj(b') + conj(alpha)*b*conj(a') + beta*c, for op = 'N'or'n' 630 // c = alpha*conj(a')*b + conj(alpha)*conj(b')*a + beta*c, for op = 'C'or'c' 631 int EIGEN_BLAS_FUNC(her2k)(const char *uplo, const char *op, const int *n, const int *k, 632 const RealScalar *palpha, const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc) 633 { 634 const Scalar* a = reinterpret_cast<const Scalar*>(pa); 635 const Scalar* b = reinterpret_cast<const Scalar*>(pb); 636 Scalar* c = reinterpret_cast<Scalar*>(pc); 637 Scalar alpha = *reinterpret_cast<const Scalar*>(palpha); 638 RealScalar beta = *pbeta; 639 640 // std::cerr << "in her2k " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " << *ldb << " " << beta << " " << *ldc << "\n"; 641 642 int info = 0; 643 if(UPLO(*uplo)==INVALID) info = 1; 644 else if((OP(*op)==INVALID) || (OP(*op)==TR)) info = 2; 645 else if(*n<0) info = 3; 646 else if(*k<0) info = 4; 647 else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 7; 648 else if(*ldb<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 9; 649 else if(*ldc<std::max(1,*n)) info = 12; 650 if(info) 651 return xerbla_(SCALAR_SUFFIX_UP"HER2K",&info,6); 652 653 if(beta!=RealScalar(1)) 654 { 655 if(UPLO(*uplo)==UP) 656 if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Upper>().setZero(); 657 else matrix(c, *n, *n, *ldc).triangularView<StrictlyUpper>() *= beta; 658 else 659 if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Lower>().setZero(); 660 else matrix(c, *n, *n, *ldc).triangularView<StrictlyLower>() *= beta; 661 662 if(beta!=Scalar(0)) 663 { 664 matrix(c, *n, *n, *ldc).diagonal().real() *= beta; 665 matrix(c, *n, *n, *ldc).diagonal().imag().setZero(); 666 } 667 } 668 else if(*k>0 && alpha!=Scalar(0)) 669 matrix(c, *n, *n, *ldc).diagonal().imag().setZero(); 670 671 if(*k==0) 672 return 1; 673 674 if(OP(*op)==NOTR) 675 { 676 if(UPLO(*uplo)==UP) 677 { 678 matrix(c, *n, *n, *ldc).triangularView<Upper>() 679 += alpha *matrix(a, *n, *k, *lda)*matrix(b, *n, *k, *ldb).adjoint() 680 + numext::conj(alpha)*matrix(b, *n, *k, *ldb)*matrix(a, *n, *k, *lda).adjoint(); 681 } 682 else if(UPLO(*uplo)==LO) 683 matrix(c, *n, *n, *ldc).triangularView<Lower>() 684 += alpha*matrix(a, *n, *k, *lda)*matrix(b, *n, *k, *ldb).adjoint() 685 + numext::conj(alpha)*matrix(b, *n, *k, *ldb)*matrix(a, *n, *k, *lda).adjoint(); 686 } 687 else if(OP(*op)==ADJ) 688 { 689 if(UPLO(*uplo)==UP) 690 matrix(c, *n, *n, *ldc).triangularView<Upper>() 691 += alpha*matrix(a, *k, *n, *lda).adjoint()*matrix(b, *k, *n, *ldb) 692 + numext::conj(alpha)*matrix(b, *k, *n, *ldb).adjoint()*matrix(a, *k, *n, *lda); 693 else if(UPLO(*uplo)==LO) 694 matrix(c, *n, *n, *ldc).triangularView<Lower>() 695 += alpha*matrix(a, *k, *n, *lda).adjoint()*matrix(b, *k, *n, *ldb) 696 + numext::conj(alpha)*matrix(b, *k, *n, *ldb).adjoint()*matrix(a, *k, *n, *lda); 697 } 698 699 return 1; 700 } 701 702 #endif // ISCOMPLEX 703