1 /* 2 * Copyright (C) 2012 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 18 #include "rsCpuIntrinsic.h" 19 #include "rsCpuIntrinsicInlines.h" 20 #include "rsCpuBLASDispatch.h" 21 #include "eight_bit_int_gemm.h" 22 23 namespace android { 24 namespace renderscript { 25 26 27 class RsdCpuScriptIntrinsicBLAS : public RsdCpuScriptIntrinsic { 28 public: 29 void invokeForEach(uint32_t slot, 30 const Allocation ** ain, 31 uint32_t inLen, 32 Allocation * aout, 33 const void * usr, 34 uint32_t usrLen, 35 const RsScriptCall *sc) override; 36 void populateScript(Script *) override; 37 ~RsdCpuScriptIntrinsicBLAS() override; 38 RsdCpuScriptIntrinsicBLAS(RsdCpuReferenceImpl *ctx, const Script *s); 39 40 protected: 41 42 uint8_t a_offset = 0; 43 uint8_t b_offset = 0; 44 uint8_t c_offset = 0; 45 46 #ifdef RS_COMPATIBILITY_LIB 47 bool isBlasLibInitialized = false; 48 #endif 49 static void kernelBNNM(size_t m, size_t n, size_t k, 50 const uint8_t* a, uint8_t a_offset, size_t lda, 51 const uint8_t* b, uint8_t b_offset, size_t ldb, 52 uint8_t* c, int32_t c_offset, size_t ldc, 53 int32_t c_mult_int); 54 55 56 57 }; 58 59 void RsdCpuScriptIntrinsicBLAS::populateScript(Script *s) { 60 s->mHal.info.exportedVariableCount = 0; 61 } 62 63 static void initABC(const Allocation ** ain, 64 size_t size, 65 void** A, 66 void** B, 67 void** C, 68 int* lda, 69 int* ldb, 70 int* ldc) 71 { 72 if (ain[0]) { 73 *A = ain[0]->mHal.drvState.lod[0].mallocPtr; 74 *lda = (int)(ain[0]->mHal.drvState.lod[0].stride/size); 75 } 76 if (ain[1]) { 77 *B = ain[1]->mHal.drvState.lod[0].mallocPtr; 78 *ldb = (int)(ain[1]->mHal.drvState.lod[0].stride/size); 79 } 80 if (ain[2]) { 81 *C = ain[2]->mHal.drvState.lod[0].mallocPtr; 82 *ldc = (int)(ain[2]->mHal.drvState.lod[0].stride/size); 83 } 84 } 85 86 // Routine to setup LaunchStruct for GEMM callback. 87 static void setupGEMM(MTLaunchStructForEachBlas *mtls, const Allocation **ain, RsBlasCall* call, 88 RsdCpuReferenceImpl *ctx) { 89 uint32_t mm, nn, kk; 90 mm = call->M; 91 nn = call->N; 92 kk = call->K; 93 94 memset(mtls, 0, sizeof(MTLaunchStructForEachBlas)); 95 mtls->rs = ctx; 96 mtls->sc = call; 97 mtls->dimPtr = &mtls->fep.dim; 98 mtls->fep.dim.x = nn; 99 mtls->fep.dim.y = mm; 100 mtls->fep.dim.z = kk; 101 if (ain) { 102 memcpy(mtls->ains, ain, 3 * sizeof(ain[0])); 103 } 104 uint32_t elementBytes = 4; 105 if (ain[0]) { 106 elementBytes = ain[0]->getType()->getElement()->getSizeBytes(); 107 } 108 const uint32_t MIN_SIZE_TO_TILE = 64 * 1024 / elementBytes; 109 const uint32_t MAX_WORK_PER_THREAD = 512 / elementBytes; 110 const uint32_t THREAD_COUNT = ctx->getThreadCount(); 111 uint32_t tileSizeN = 0; 112 uint32_t tileSizeM = 0; 113 114 // Do not tile the matrix if: 115 // 1. It is too small comparing to the other matrix. 116 // 2. It is too small comparing to MIN_SIZE_TO_TILE . 117 if (nn * kk > MIN_SIZE_TO_TILE && nn * THREAD_COUNT > mm) { 118 tileSizeN = rsMin(nn / THREAD_COUNT, MAX_WORK_PER_THREAD); 119 } 120 if (mm * kk > MIN_SIZE_TO_TILE && mm * THREAD_COUNT > nn) { 121 tileSizeM = rsMin(mm / THREAD_COUNT, MAX_WORK_PER_THREAD); 122 } 123 mtls->numTileM = 1; 124 mtls->numTileN = 1; 125 mtls->tileSizeM = mm; 126 mtls->tileSizeN = nn; 127 128 // If tiling is needed, compute the number of slices for A & B. 129 mtls->isThreadable = (tileSizeM > 0 || tileSizeN > 0); 130 if (tileSizeM) { 131 mtls->numTileM += (mm - 1) / tileSizeM; 132 mtls->tileSizeM = tileSizeM; 133 } 134 if (tileSizeN) { 135 mtls->numTileN += (nn - 1) / tileSizeN; 136 mtls->tileSizeN = tileSizeN; 137 } 138 139 mtls->mSliceNum = 0; 140 } 141 142 // Generic GEMM callback routine. 143 template <typename T_data, typename T_param, typename Func> 144 static void walk_tiled_gemm(Func blasFunc, T_param alpha, T_param beta, int vecSize, 145 RsBlasCall* call, const MTLaunchStructForEachBlas *mtls) { 146 // setup BLAS enum args 147 enum CBLAS_TRANSPOSE TransA = (enum CBLAS_TRANSPOSE)call->transA; 148 enum CBLAS_TRANSPOSE TransB = (enum CBLAS_TRANSPOSE)call->transB; 149 150 void *A = nullptr; 151 void *B = nullptr; 152 void *C = nullptr; 153 154 int lda = 0, ldb = 0, ldc = 0; 155 156 const Allocation *ain[RS_KERNEL_INPUT_LIMIT]; 157 ain[0] = mtls->ains[0]; 158 ain[1] = mtls->ains[1]; 159 ain[2] = mtls->ains[2]; 160 161 initABC(ain, sizeof(T_data) * vecSize, &A, &B, &C, &lda, &ldb, &ldc); 162 163 // Determin the stride of the tiled matrices. 164 int mStride = (TransA == CblasNoTrans) ? lda : 1; 165 int nStride = (TransB == CblasNoTrans) ? 1 : ldb; 166 while (1) { 167 uint32_t slice = (uint32_t)__sync_fetch_and_add(&mtls->mSliceNum, 1); 168 169 uint32_t mStart = (slice % mtls->numTileM) * mtls->tileSizeM; 170 uint32_t mEnd = mStart + mtls->tileSizeM; 171 mEnd = rsMin(mEnd, (uint32_t)call->M); 172 if (mEnd <= mStart) { 173 return; 174 } 175 176 uint32_t nStart = (slice / mtls->numTileM) * mtls->tileSizeN; 177 uint32_t nEnd = nStart + mtls->tileSizeN; 178 nEnd = rsMin(nEnd, (uint32_t)call->N); 179 if (nEnd <= nStart) { 180 return; 181 } 182 183 blasFunc(CblasRowMajor, TransA, TransB, 184 mEnd - mStart, nEnd - nStart, call->K, alpha, 185 (T_data *)A + mStart * mStride * vecSize, lda, 186 (T_data *)B + nStart * nStride * vecSize, ldb, beta, 187 (T_data *)C + (mStart * ldc + nStart) * vecSize, ldc); 188 } 189 } 190 191 // SGEMM callback 192 static void walk_2d_sgemm(void *usr, uint32_t idx) { 193 const MTLaunchStructForEachBlas *mtls = (const MTLaunchStructForEachBlas *)usr; 194 RsBlasCall* call = (RsBlasCall*) mtls->sc; 195 196 float alpha = call->alpha.f; 197 float beta = call->beta.f; 198 199 walk_tiled_gemm<float, float, FnPtr_cblas_sgemm>(cblas_sgemm, alpha, beta, 1, call, mtls); 200 } 201 202 // DGEMM callback 203 static void walk_2d_dgemm(void *usr, uint32_t idx) { 204 const MTLaunchStructForEachBlas *mtls = (const MTLaunchStructForEachBlas *)usr; 205 RsBlasCall* call = (RsBlasCall*) mtls->sc; 206 207 double alpha = call->alpha.d; 208 double beta = call->beta.d; 209 210 walk_tiled_gemm<double, double, FnPtr_cblas_dgemm>(cblas_dgemm, alpha, beta, 1, call, mtls); 211 } 212 213 // CGEMM callback 214 static void walk_2d_cgemm(void *usr, uint32_t idx) { 215 const MTLaunchStructForEachBlas *mtls = (const MTLaunchStructForEachBlas *)usr; 216 RsBlasCall* call = (RsBlasCall*) mtls->sc; 217 218 void * alpha = (void *)&call->alpha.c; 219 void * beta = (void *)&call->beta.c; 220 221 walk_tiled_gemm<float, void *, FnPtr_cblas_cgemm>(cblas_cgemm, alpha, beta, 2, call, mtls); 222 } 223 224 // ZGEMM callback 225 static void walk_2d_zgemm(void *usr, uint32_t idx) { 226 const MTLaunchStructForEachBlas *mtls = (const MTLaunchStructForEachBlas *)usr; 227 RsBlasCall* call = (RsBlasCall*) mtls->sc; 228 229 void * alpha = (void *)&call->alpha.z; 230 void * beta = (void *)&call->beta.z; 231 232 walk_tiled_gemm<double, void *, FnPtr_cblas_zgemm>(cblas_zgemm, alpha, beta, 2, call, mtls); 233 } 234 235 236 void RsdCpuScriptIntrinsicBLAS::invokeForEach(uint32_t slot, 237 const Allocation ** ain, 238 uint32_t inLen, 239 Allocation * aout, 240 const void * usr, 241 uint32_t usrLen, 242 const RsScriptCall *sc) { 243 RsBlasCall* call = (RsBlasCall*) usr; 244 // setup BLAS enum args 245 enum CBLAS_TRANSPOSE TransA = (enum CBLAS_TRANSPOSE)call->transA; 246 enum CBLAS_TRANSPOSE TransB = (enum CBLAS_TRANSPOSE)call->transB; 247 enum CBLAS_UPLO Uplo = (enum CBLAS_UPLO)call->uplo; 248 enum CBLAS_DIAG Diag = (enum CBLAS_DIAG)call->diag; 249 enum CBLAS_SIDE Side = (enum CBLAS_SIDE)call->side; 250 251 void *A = nullptr; 252 void *B = nullptr; 253 void *C = nullptr; 254 void *X = nullptr; 255 void *Y = nullptr; 256 257 int lda = 0, ldb = 0, ldc = 0; 258 259 MTLaunchStructForEachBlas mtls; 260 261 #ifdef RS_COMPATIBILITY_LIB 262 // Allow BNNM even without libblas 263 if (call->func != RsBlas_bnnm && !isBlasLibInitialized) { 264 if (!loadBLASLib()) { 265 ALOGE("Failed to load the BLAS lib, IntrinsicBLAS NOT supported!\n"); 266 return; 267 } 268 isBlasLibInitialized = true; 269 } 270 #endif 271 272 switch (call->func) { 273 274 // Level 1 BLAS: returns into a 1D Allocation 275 276 277 // Level 2 BLAS 278 case (RsBlas_sgemv): 279 initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc); 280 cblas_sgemv(CblasRowMajor, TransA, call->M, call->N, call->alpha.f, (float*)A, 281 lda, (float*)X, call->incX, call->beta.f, (float*)Y, call->incY); 282 break; 283 case (RsBlas_sgbmv): 284 initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc); 285 cblas_sgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU, 286 call->alpha.f, (float*)A, lda, (float*)X, call->incX, 287 call->beta.f, (float*)Y, call->incY); 288 break; 289 case (RsBlas_strmv): 290 initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr); 291 cblas_strmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A, 292 lda, (float*)X, call->incX); 293 break; 294 case (RsBlas_stbmv): 295 initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr); 296 cblas_stbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (float*)A, 297 lda, (float*)X, call->incX); 298 break; 299 // stpmv takes a packed 1D Allocation only 300 case (RsBlas_stpmv): 301 initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr); 302 cblas_stpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A, 303 (float*)X, call->incX); 304 break; 305 case (RsBlas_strsv): 306 initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr); 307 cblas_strsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A, lda, 308 (float*)X, call->incX); 309 break; 310 case (RsBlas_stbsv): 311 initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr); 312 cblas_stbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (float*)A, 313 lda, (float*)X, call->incX); 314 break; 315 case (RsBlas_stpsv): 316 initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr); 317 cblas_stpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A, 318 (float*)X, call->incX); 319 break; 320 case (RsBlas_dgemv): 321 initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc); 322 cblas_dgemv(CblasRowMajor, TransA, call->M, call->N, call->alpha.d, (double*)A, 323 lda, (double*)X, call->incX, call->beta.d, (double*)Y, call->incY); 324 break; 325 case (RsBlas_dgbmv): 326 initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc); 327 cblas_dgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU, 328 call->alpha.d, (double*)A, lda, (double*)X, call->incX, 329 call->beta.d, (double*)Y, call->incY); 330 break; 331 case (RsBlas_dtrmv): 332 initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr); 333 cblas_dtrmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A, 334 lda, (double*)X, call->incX); 335 break; 336 case (RsBlas_dtbmv): 337 initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr); 338 cblas_dtbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (double*)A, 339 lda, (double*)X, call->incX); 340 break; 341 // stpmv takes a packed 1D Allocation only 342 case (RsBlas_dtpmv): 343 initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr); 344 cblas_dtpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A, 345 (double*)X, call->incX); 346 break; 347 case (RsBlas_dtrsv): 348 initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr); 349 cblas_dtrsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A, lda, 350 (double*)X, call->incX); 351 break; 352 case (RsBlas_dtbsv): 353 initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr); 354 cblas_dtbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (double*)A, 355 lda, (double*)X, call->incX); 356 break; 357 case (RsBlas_dtpsv): 358 initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr); 359 cblas_dtpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A, 360 (double*)X, call->incX); 361 break; 362 case (RsBlas_cgemv): 363 initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc); 364 cblas_cgemv(CblasRowMajor, TransA, call->M, call->N, (void*)&call->alpha.c, (void*)A, 365 lda, (void*)X, call->incX, (void*)&call->beta.c, (void*)Y, call->incY); 366 break; 367 case (RsBlas_cgbmv): 368 initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc); 369 cblas_cgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU, 370 (void*)&call->alpha.c, (void*)A, lda, (void*)X, call->incX, 371 (void*)&call->beta.c, (void*)Y, call->incY); 372 break; 373 case (RsBlas_ctrmv): 374 initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 375 cblas_ctrmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, 376 lda, (void*)X, call->incX); 377 break; 378 case (RsBlas_ctbmv): 379 initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 380 cblas_ctbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A, 381 lda, (void*)X, call->incX); 382 break; 383 // stpmv takes a packed 1D Allocation only 384 case (RsBlas_ctpmv): 385 initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 386 cblas_ctpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, 387 (void*)X, call->incX); 388 break; 389 case (RsBlas_ctrsv): 390 initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 391 cblas_ctrsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, lda, 392 (void*)X, call->incX); 393 break; 394 case (RsBlas_ctbsv): 395 initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 396 cblas_ctbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A, 397 lda, (void*)X, call->incX); 398 break; 399 case (RsBlas_ctpsv): 400 initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 401 cblas_ctpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, 402 (void*)X, call->incX); 403 break; 404 case (RsBlas_zgemv): 405 initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc); 406 cblas_zgemv(CblasRowMajor, TransA, call->M, call->N, (void*)&call->alpha.z, (void*)A, 407 lda, (void*)X, call->incX, (void*)&call->beta.z, (void*)Y, call->incY); 408 break; 409 case (RsBlas_zgbmv): 410 initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc); 411 cblas_zgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU, 412 (void*)&call->alpha.z, (void*)A, lda, (void*)X, call->incX, 413 (void*)&call->beta.z, (void*)Y, call->incY); 414 break; 415 case (RsBlas_ztrmv): 416 initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 417 cblas_ztrmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, 418 lda, (void*)X, call->incX); 419 break; 420 case (RsBlas_ztbmv): 421 initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 422 cblas_ztbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A, 423 lda, (void*)X, call->incX); 424 break; 425 // stpmv takes a packed 1D Allocation only 426 case (RsBlas_ztpmv): 427 initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 428 cblas_ztpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, 429 (void*)X, call->incX); 430 break; 431 case (RsBlas_ztrsv): 432 initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 433 cblas_ztrsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, lda, 434 (void*)X, call->incX); 435 break; 436 case (RsBlas_ztbsv): 437 initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 438 cblas_ztbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A, 439 lda, (void*)X, call->incX); 440 break; 441 case (RsBlas_ztpsv): 442 initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 443 cblas_ztpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, 444 (void*)X, call->incX); 445 break; 446 447 448 // S and D only 449 case (RsBlas_ssymv): 450 initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc); 451 cblas_ssymv(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)A, lda, 452 (float*)X, call->incX, call->beta.f, (float*)Y, call->incY); 453 break; 454 case (RsBlas_ssbmv): 455 initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc); 456 cblas_ssbmv(CblasRowMajor, Uplo, call->N, call->K, call->alpha.f, 457 (float*)A, lda, (float*)X, call->incX, call->beta.f, 458 (float*)Y, call->incY); 459 break; 460 //sspmv requires a packed 1D Allocation 461 case (RsBlas_sspmv): 462 initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc); 463 cblas_sspmv(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)A, 464 (float*)X, call->incX, call->beta.f, (float*)Y, call->incY); 465 break; 466 // following calls have init reordered because A is output matrix 467 case (RsBlas_sger): 468 initABC(ain, sizeof(float), &X, &Y, &A, &ldb, &ldc, &lda); 469 cblas_sger(CblasRowMajor, call->M, call->N, call->alpha.f, (float*)X, 470 call->incX, (float*)Y, call->incY, (float*)A, lda); 471 break; 472 case (RsBlas_ssyr): 473 initABC(ain, sizeof(float), &X, &A, nullptr, &ldb, &lda, nullptr); 474 cblas_ssyr(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX, 475 (float*)A, lda); 476 break; 477 // sspr is packed 1D Allocation A only 478 case (RsBlas_sspr): 479 initABC(ain, sizeof(float), &X, &A, nullptr, &ldb, &lda, nullptr); 480 cblas_sspr(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX, 481 (float*)A); 482 break; 483 case (RsBlas_ssyr2): 484 initABC(ain, sizeof(float), &X, &Y, &A, &ldb, &ldc, &lda); 485 cblas_ssyr2(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX, 486 (float*)Y, call->incY, (float*)A, lda); 487 break; 488 // sspr2 is packed 1D Allocation A only 489 case (RsBlas_sspr2): 490 initABC(ain, sizeof(float), &X, &Y, &A, &ldb, &ldc, &lda); 491 cblas_sspr2(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX, 492 (float*)Y, call->incY, (float*)A); 493 break; 494 case (RsBlas_dsymv): 495 initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc); 496 cblas_dsymv(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)A, lda, 497 (double*)X, call->incX, call->beta.d, (double*)Y, call->incY); 498 break; 499 case (RsBlas_dsbmv): 500 initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc); 501 cblas_dsbmv(CblasRowMajor, Uplo, call->N, call->K, call->alpha.d, 502 (double*)A, lda, (double*)X, call->incX, call->beta.d, 503 (double*)Y, call->incY); 504 break; 505 // dspmv requires a packed 1D Allocation 506 case (RsBlas_dspmv): 507 initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc); 508 cblas_dspmv(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)A, 509 (double*)X, call->incX, call->beta.d, (double*)Y, call->incY); 510 break; 511 // following calls have init reordered because A is output matrix 512 case (RsBlas_dger): 513 initABC(ain, sizeof(double), &X, &Y, &A, &ldb, &ldc, &lda); 514 cblas_dger(CblasRowMajor, call->M, call->N, call->alpha.d, (double*)X, 515 call->incX, (double*)Y, call->incY, (double*)A, lda); 516 break; 517 case (RsBlas_dsyr): 518 initABC(ain, sizeof(double), &X, &A, nullptr, &ldb, &lda, nullptr); 519 cblas_dsyr(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX, 520 (double*)A, lda); 521 break; 522 // dspr is packed 1D Allocation A only 523 case (RsBlas_dspr): 524 initABC(ain, sizeof(double), &X, &A, nullptr, &ldb, &lda, nullptr); 525 cblas_dspr(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX, 526 (double*)A); 527 break; 528 case (RsBlas_dsyr2): 529 initABC(ain, sizeof(double), &X, &Y, &A, &ldb, &ldc, &lda); 530 cblas_dsyr2(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX, 531 (double*)Y, call->incY, (double*)A, lda); 532 break; 533 // dspr2 is packed 1D Allocation A only 534 case (RsBlas_dspr2): 535 initABC(ain, sizeof(double), &X, &Y, &A, &ldb, &ldc, &lda); 536 cblas_dspr2(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX, 537 (double*)Y, call->incY, (double*)A); 538 break; 539 540 // C and Z only 541 case (RsBlas_chemv): 542 initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc); 543 cblas_chemv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, A, lda, 544 X, call->incX, (void*)&call->beta.c, Y, call->incY); 545 break; 546 case (RsBlas_chbmv): 547 initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc); 548 cblas_chbmv(CblasRowMajor, Uplo, call->N, call->K, (void*)&call->alpha.c, 549 A, lda, X, call->incX, (void*)&call->beta.c, Y, call->incY); 550 break; 551 case (RsBlas_chpmv): 552 initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc); 553 cblas_chpmv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, A, 554 X, call->incX, (void*)&call->beta.c, Y, call->incY); 555 break; 556 case (RsBlas_cgeru): 557 initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda); 558 cblas_cgeru(CblasRowMajor, call->M, call->N, (void*)&call->alpha.c, 559 X, call->incX, Y, call->incY, A, lda); 560 break; 561 case (RsBlas_cgerc): 562 initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda); 563 cblas_cgerc(CblasRowMajor, call->M, call->N, (void*)&call->alpha.c, 564 X, call->incX, Y, call->incY, A, lda); 565 break; 566 case (RsBlas_cher): 567 initABC(ain, sizeof(float)*2, &X, nullptr, &A, &ldb, nullptr, &lda); 568 cblas_cher(CblasRowMajor, Uplo, call->N, call->alpha.f, 569 X, call->incX, A, lda); 570 break; 571 // packed 1D Allocations only 572 case (RsBlas_chpr): 573 initABC(ain, sizeof(float)*2, &X, nullptr, &A, &ldb, nullptr, &lda); 574 cblas_chpr(CblasRowMajor, Uplo, call->N, call->alpha.f, X, 575 call->incX, A); 576 break; 577 case (RsBlas_cher2): 578 initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda); 579 cblas_cher2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, 580 X, call->incX, Y, call->incY, A, lda); 581 break; 582 // packed 1D Allocations only 583 case (RsBlas_chpr2): 584 initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda); 585 cblas_chpr2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, X, 586 call->incX, Y, call->incY, A); 587 break; 588 case (RsBlas_zhemv): 589 initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc); 590 cblas_zhemv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, A, lda, 591 X, call->incX, (void*)&call->beta.z, Y, call->incY); 592 break; 593 case (RsBlas_zhbmv): 594 initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc); 595 cblas_zhbmv(CblasRowMajor, Uplo, call->N, call->K, (void*)&call->alpha.z, 596 A, lda, X, call->incX, (void*)&call->beta.z, Y, call->incY); 597 break; 598 case (RsBlas_zhpmv): 599 initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc); 600 cblas_zhpmv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, A, 601 X, call->incX, (void*)&call->beta.z, Y, call->incY); 602 break; 603 case (RsBlas_zgeru): 604 initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda); 605 cblas_zgeru(CblasRowMajor, call->M, call->N, (void*)&call->alpha.z, 606 X, call->incX, Y, call->incY, A, lda); 607 break; 608 case (RsBlas_zgerc): 609 initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda); 610 cblas_zgerc(CblasRowMajor, call->M, call->N, (void*)&call->alpha.z, 611 X, call->incX, Y, call->incY, A, lda); 612 break; 613 case (RsBlas_zher): 614 initABC(ain, sizeof(double)*2, &X, nullptr, &A, &ldb, nullptr, &lda); 615 cblas_zher(CblasRowMajor, Uplo, call->N, call->alpha.d, 616 X, call->incX, A, lda); 617 break; 618 // packed 1D Allocations only 619 case (RsBlas_zhpr): 620 initABC(ain, sizeof(double)*2, &X, nullptr, &A, &ldb, nullptr, &lda); 621 cblas_zhpr(CblasRowMajor, Uplo, call->N, call->alpha.d, X, 622 call->incX, A); 623 break; 624 case (RsBlas_zher2): 625 initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda); 626 cblas_zher2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, 627 X, call->incX, Y, call->incY, A, lda); 628 break; 629 // packed 1D Allocations only 630 case (RsBlas_zhpr2): 631 initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda); 632 cblas_zhpr2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, X, 633 call->incX, Y, call->incY, A); 634 break; 635 636 // Level 3 BLAS 637 case (RsBlas_sgemm): 638 setupGEMM(&mtls, ain, call, mCtx); 639 if (mtls.isThreadable) { 640 mCtx->launchThreads(walk_2d_sgemm, &mtls); 641 } else { 642 initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc); 643 cblas_sgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, call->alpha.f, 644 (float*)A, lda, (float*)B, ldb, call->beta.f, (float*)C, ldc); 645 } 646 break; 647 case (RsBlas_ssymm): 648 initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc); 649 cblas_ssymm(CblasRowMajor, Side, Uplo, call->M, call->N, call->alpha.f, (float*)A, 650 lda, (float*)B, ldb, call->beta.f, (float*)C, ldc); 651 break; 652 case (RsBlas_ssyrk): 653 initABC(ain, sizeof(float), &A, nullptr, &C, &lda, nullptr, &ldc); 654 cblas_ssyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.f, (float*)A, 655 lda, call->beta.f, (float*)C, ldc); 656 break; 657 case (RsBlas_ssyr2k): 658 initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc); 659 cblas_ssyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.f, (float*)A, 660 lda, (float*)B, ldb, call->beta.f, (float*)C, ldc); 661 break; 662 case (RsBlas_strmm): 663 initABC(ain, sizeof(float), &A, &B, nullptr, &lda, &ldb, nullptr); 664 cblas_strmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.f, 665 (float*)A, lda, (float*)B, ldb); 666 break; 667 case (RsBlas_strsm): 668 initABC(ain, sizeof(float), &A, &B, nullptr, &lda, &ldb, nullptr); 669 cblas_strsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.f, 670 (float*)A, lda, (float*)B, ldb); 671 break; 672 673 674 case (RsBlas_dgemm): 675 setupGEMM(&mtls, ain, call, mCtx); 676 if (mtls.isThreadable) { 677 mCtx->launchThreads(walk_2d_dgemm, &mtls); 678 } else { 679 initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc); 680 cblas_dgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, call->alpha.d, 681 (double*)A, lda, (double*)B, ldb, call->beta.d, (double*)C, ldc); 682 } 683 break; 684 case (RsBlas_dsymm): 685 initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc); 686 cblas_dsymm(CblasRowMajor, Side, Uplo, call->M, call->N, call->alpha.d, (double*)A, 687 lda, (double*)B, ldb, call->beta.d, (double*)C, ldc); 688 break; 689 case (RsBlas_dsyrk): 690 initABC(ain, sizeof(double), &A, nullptr, &C, &lda, nullptr, &ldc); 691 cblas_dsyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.d, (double*)A, 692 lda, call->beta.d, (double*)C, ldc); 693 break; 694 case (RsBlas_dsyr2k): 695 initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc); 696 cblas_dsyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.d, (double*)A, 697 lda, (double*)B, ldb, call->beta.d, (double*)C, ldc); 698 break; 699 case (RsBlas_dtrmm): 700 initABC(ain, sizeof(double), &A, &B, nullptr, &lda, &ldb, nullptr); 701 cblas_dtrmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.d, 702 (double*)A, lda, (double*)B, ldb); 703 break; 704 case (RsBlas_dtrsm): 705 initABC(ain, sizeof(double), &A, &B, nullptr, &lda, &ldb, nullptr); 706 cblas_dtrsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.d, 707 (double*)A, lda, (double*)B, ldb); 708 break; 709 710 case (RsBlas_cgemm): 711 setupGEMM(&mtls, ain, call, mCtx); 712 if (mtls.isThreadable) { 713 mCtx->launchThreads(walk_2d_cgemm, &mtls); 714 } else { 715 initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc); 716 cblas_cgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, (void*)&call->alpha.c, 717 A, lda, B, ldb, (void*)&call->beta.c, C, ldc); 718 } 719 break; 720 case (RsBlas_csymm): 721 initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc); 722 cblas_csymm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.c, A, 723 lda, B, ldb, (void*)&call->beta.c, C, ldc); 724 break; 725 case (RsBlas_csyrk): 726 initABC(ain, sizeof(float)*2, &A, nullptr, &C, &lda, nullptr, &ldc); 727 cblas_csyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.c, A, 728 lda, (void*)&call->beta.c, C, ldc); 729 break; 730 case (RsBlas_csyr2k): 731 initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc); 732 cblas_csyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.c, A, 733 lda, B, ldb, (void*)&call->beta.c, C, ldc); 734 break; 735 case (RsBlas_ctrmm): 736 initABC(ain, sizeof(float)*2, &A, &B, nullptr, &lda, &ldb, nullptr); 737 cblas_ctrmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.c, 738 A, lda, B, ldb); 739 break; 740 case (RsBlas_ctrsm): 741 initABC(ain, sizeof(float)*2, &A, &B, nullptr, &lda, &ldb, nullptr); 742 cblas_ctrsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.c, 743 A, lda, B, ldb); 744 break; 745 746 case (RsBlas_zgemm): 747 setupGEMM(&mtls, ain, call, mCtx); 748 if (mtls.isThreadable) { 749 mCtx->launchThreads(walk_2d_zgemm, &mtls); 750 } else { 751 initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc); 752 cblas_zgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, (void*)&call->alpha.z, 753 A, lda, B, ldb, (void*)&call->beta.z, C, ldc); 754 } 755 break; 756 case (RsBlas_zsymm): 757 initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc); 758 cblas_zsymm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.z, A, 759 lda, B, ldb, (void*)&call->beta.z, C, ldc); 760 break; 761 case (RsBlas_zsyrk): 762 initABC(ain, sizeof(double)*2, &A, nullptr, &C, &lda, nullptr, &ldc); 763 cblas_zsyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.z, A, 764 lda, (void*)&call->beta.z, C, ldc); 765 break; 766 case (RsBlas_zsyr2k): 767 initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc); 768 cblas_zsyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.z, A, 769 lda, B, ldb, (void*)&call->beta.z, C, ldc); 770 break; 771 case (RsBlas_ztrmm): 772 initABC(ain, sizeof(double)*2, &A, &B, nullptr, &lda, &ldb, nullptr); 773 cblas_ztrmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.z, 774 A, lda, B, ldb); 775 break; 776 case (RsBlas_ztrsm): 777 initABC(ain, sizeof(double)*2, &A, &B, nullptr, &lda, &ldb, nullptr); 778 cblas_ztrsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.z, 779 A, lda, B, ldb); 780 break; 781 782 // Level 3 C and Z only 783 case (RsBlas_chemm): 784 initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc); 785 cblas_chemm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.c, A, lda, 786 B, ldb, (void*)&call->beta.c, C, ldc); 787 break; 788 case (RsBlas_cherk): 789 initABC(ain, sizeof(float)*2, &A, nullptr, &C, &lda, nullptr, &ldc); 790 cblas_cherk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.f, A, lda, 791 call->beta.f, C, ldc); 792 break; 793 case (RsBlas_cher2k): 794 initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc); 795 cblas_cher2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.c, A, lda, 796 B, ldb, call->beta.f, C, ldc); 797 break; 798 799 case (RsBlas_zhemm): 800 initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc); 801 cblas_zhemm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.z, A, lda, 802 B, ldb, (void*)&call->beta.z, C, ldc); 803 break; 804 case (RsBlas_zherk): 805 initABC(ain, sizeof(double)*2, &A, nullptr, &C, &lda, nullptr, &ldc); 806 cblas_zherk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.d, A, lda, 807 call->beta.d, C, ldc); 808 break; 809 case (RsBlas_zher2k): 810 initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc); 811 cblas_zher2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.z, A, lda, 812 B, ldb, call->beta.d, C, ldc); 813 break; 814 815 816 case (RsBlas_bnnm): 817 initABC(ain, sizeof(uint8_t), &A, &B, &C, &lda, &ldb, &ldc); 818 kernelBNNM(call->M, call->N, call->K, 819 (const uint8_t*)A, call->a_offset, lda, 820 (const uint8_t*)B, call->b_offset, ldb, 821 (uint8_t*)C, call->c_offset, ldc, 822 call->c_mult_int); 823 824 break; 825 826 default: 827 ALOGE("unimplemented\n"); 828 } 829 830 831 } 832 833 void RsdCpuScriptIntrinsicBLAS::kernelBNNM(size_t m, size_t n, size_t k, 834 const uint8_t* a, uint8_t a_offset, size_t lda, 835 const uint8_t* b, uint8_t b_offset, size_t ldb, 836 uint8_t* c, int32_t c_offset, size_t ldc, 837 int32_t c_mult_int) { 838 const int c_shift = 21; 839 #if defined(ARCH_ARM_HAVE_VFP) || defined(ARCH_ARM_USE_INTRINSICS) 840 // Non-optimized path for ARMv7 devices without SIMD instructions. 841 if (!gArchUseSIMD) { 842 /* 843 * Calculations are done in 1.10.21 fixed-point format for the final output, 844 * just before there's a shift down to drop the fractional parts. The output 845 * values are gated to 0 to 255 to fit in a byte, but the 10-bit format 846 * gives some headroom to avoid wrapping around on small overflows. 847 */ 848 size_t i = 0, j = 0, l = 0; 849 for (j = 0; j < n; j++) { 850 for (i = 0; i < m; i++) { 851 int32_t total = 0; 852 for (l = 0; l < k; l++) { 853 const int a_index = ((i * lda) + l); 854 const uint8_t a_as_byte = a[a_index]; 855 const int32_t a_as_int = (((int32_t)(a_as_byte)) - a_offset); 856 const int b_index = ((j * ldb) + l); 857 const uint8_t b_as_byte = b[b_index]; 858 const int32_t b_as_int = (((int32_t)(b_as_byte)) - b_offset); 859 const int32_t mult_as_int = (a_as_int * b_as_int); 860 total += mult_as_int; 861 } 862 const int c_index = ((ldc * i) + j); 863 int32_t output = 864 ((((total + c_offset) * c_mult_int) + (1 << (c_shift - 1))) 865 >> c_shift); 866 if (output > 255) { 867 output = 255; 868 } 869 if (output < 0) { 870 output = 0; 871 } 872 c[c_index] = (uint8_t)(output); 873 } 874 } 875 return; 876 } 877 #endif 878 879 // Using gemmlowp to calculate the low precision 8 bit GEMM. 880 // Set MaxNumThreads to 0. The value 0 lets the implementation query 881 // the system to determine the number of hardware threads 882 gemmlowp::eight_bit_int_gemm::SetMaxNumThreads(0); 883 884 bool transpose_a = true; 885 bool transpose_b = false; 886 bool transpose_c = true; 887 gemmlowp::eight_bit_int_gemm::EightBitIntGemm(transpose_a, transpose_b, transpose_c, 888 m, n, k, a, -a_offset, lda, 889 b, -b_offset, ldb, c, c_offset, 890 c_mult_int, c_shift, ldc, 891 gemmlowp::eight_bit_int_gemm::BitDepthSetting::A8B8); 892 893 } 894 895 896 897 898 899 RsdCpuScriptIntrinsicBLAS::RsdCpuScriptIntrinsicBLAS(RsdCpuReferenceImpl *ctx, 900 const Script *s) 901 : RsdCpuScriptIntrinsic(ctx, s, nullptr, RS_SCRIPT_INTRINSIC_ID_BLAS) { 902 903 904 } 905 906 RsdCpuScriptIntrinsicBLAS::~RsdCpuScriptIntrinsicBLAS() { 907 } 908 909 RsdCpuScriptImpl * rsdIntrinsic_BLAS(RsdCpuReferenceImpl *ctx, 910 const Script *s, const Element *e) { 911 912 return new RsdCpuScriptIntrinsicBLAS(ctx, s); 913 } 914 915 } // namespace renderscript 916 } // namespace android 917