1 /* 2 * Copyright (C) 2012 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 18 #include "rsCpuIntrinsic.h" 19 #include "rsCpuIntrinsicInlines.h" 20 #include "rsCpuBLASDispatch.h" 21 22 using namespace android; 23 using namespace android::renderscript; 24 25 namespace android { 26 namespace renderscript { 27 28 29 class RsdCpuScriptIntrinsicBLAS : public RsdCpuScriptIntrinsic { 30 public: 31 void invokeForEach(uint32_t slot, 32 const Allocation ** ain, 33 uint32_t inLen, 34 Allocation * aout, 35 const void * usr, 36 uint32_t usrLen, 37 const RsScriptCall *sc) override; 38 39 void populateScript(Script *) override; 40 ~RsdCpuScriptIntrinsicBLAS() override; 41 RsdCpuScriptIntrinsicBLAS(RsdCpuReferenceImpl *ctx, const Script *s); 42 43 protected: 44 45 uint8_t a_offset = 0; 46 uint8_t b_offset = 0; 47 uint8_t c_offset = 0; 48 49 #ifdef RS_COMPATIBILITY_LIB 50 bool isBlasLibInitialized = false; 51 #endif 52 static void kernelBNNM(size_t m, size_t n, size_t k, 53 const uint8_t* a, uint8_t a_offset, size_t lda, 54 const uint8_t* b, uint8_t b_offset, size_t ldb, 55 uint8_t* c, int32_t c_offset, size_t ldc, 56 int32_t c_mult_int); 57 58 59 60 }; 61 62 } 63 } 64 65 void RsdCpuScriptIntrinsicBLAS::populateScript(Script *s) { 66 s->mHal.info.exportedVariableCount = 0; 67 } 68 69 static void initABC(const Allocation ** ain, 70 size_t size, 71 void** A, 72 void** B, 73 void** C, 74 int* lda, 75 int* ldb, 76 int* ldc) 77 { 78 if (ain[0]) { 79 *A = ain[0]->mHal.drvState.lod[0].mallocPtr; 80 *lda = (int)(ain[0]->mHal.drvState.lod[0].stride/size); 81 } 82 if (ain[1]) { 83 *B = ain[1]->mHal.drvState.lod[0].mallocPtr; 84 *ldb = (int)(ain[1]->mHal.drvState.lod[0].stride/size); 85 } 86 if (ain[2]) { 87 *C = ain[2]->mHal.drvState.lod[0].mallocPtr; 88 *ldc = (int)(ain[2]->mHal.drvState.lod[0].stride/size); 89 } 90 91 92 } 93 94 void RsdCpuScriptIntrinsicBLAS::invokeForEach(uint32_t slot, 95 const Allocation ** ain, 96 uint32_t inLen, 97 Allocation * aout, 98 const void * usr, 99 uint32_t usrLen, 100 const RsScriptCall *sc) { 101 RsBlasCall* call = (RsBlasCall*) usr; 102 // setup BLAS enum args 103 enum CBLAS_TRANSPOSE TransA = (enum CBLAS_TRANSPOSE)call->transA; 104 enum CBLAS_TRANSPOSE TransB = (enum CBLAS_TRANSPOSE)call->transB; 105 enum CBLAS_UPLO Uplo = (enum CBLAS_UPLO)call->uplo; 106 enum CBLAS_DIAG Diag = (enum CBLAS_DIAG)call->diag; 107 enum CBLAS_SIDE Side = (enum CBLAS_SIDE)call->side; 108 109 void *A = nullptr; 110 void *B = nullptr; 111 void *C = nullptr; 112 void *X = nullptr; 113 void *Y = nullptr; 114 115 int lda = 0, ldb = 0, ldc = 0; 116 117 #ifdef RS_COMPATIBILITY_LIB 118 // Allow BNNM even without libblas 119 if (call->func != RsBlas_bnnm && !isBlasLibInitialized) { 120 if (!loadBLASLib()) { 121 ALOGE("Failed to load the BLAS lib, IntrinsicBLAS NOT supported!\n"); 122 return; 123 } 124 isBlasLibInitialized = true; 125 } 126 #endif 127 128 switch (call->func) { 129 130 // Level 1 BLAS: returns into a 1D Allocation 131 132 133 // Level 2 BLAS 134 case (RsBlas_sgemv): 135 initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc); 136 cblas_sgemv(CblasRowMajor, TransA, call->M, call->N, call->alpha.f, (float*)A, 137 lda, (float*)X, call->incX, call->beta.f, (float*)Y, call->incY); 138 break; 139 case (RsBlas_sgbmv): 140 initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc); 141 cblas_sgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU, 142 call->alpha.f, (float*)A, lda, (float*)X, call->incX, 143 call->beta.f, (float*)Y, call->incY); 144 break; 145 case (RsBlas_strmv): 146 initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr); 147 cblas_strmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A, 148 lda, (float*)X, call->incX); 149 break; 150 case (RsBlas_stbmv): 151 initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr); 152 cblas_stbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (float*)A, 153 lda, (float*)X, call->incX); 154 break; 155 // stpmv takes a packed 1D Allocation only 156 case (RsBlas_stpmv): 157 initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr); 158 cblas_stpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A, 159 (float*)X, call->incX); 160 break; 161 case (RsBlas_strsv): 162 initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr); 163 cblas_strsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A, lda, 164 (float*)X, call->incX); 165 break; 166 case (RsBlas_stbsv): 167 initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr); 168 cblas_stbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (float*)A, 169 lda, (float*)X, call->incX); 170 break; 171 case (RsBlas_stpsv): 172 initABC(ain, sizeof(float), &A, &X, nullptr, &lda, &ldb, nullptr); 173 cblas_stpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (float*)A, 174 (float*)X, call->incX); 175 break; 176 case (RsBlas_dgemv): 177 initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc); 178 cblas_dgemv(CblasRowMajor, TransA, call->M, call->N, call->alpha.d, (double*)A, 179 lda, (double*)X, call->incX, call->beta.d, (double*)Y, call->incY); 180 break; 181 case (RsBlas_dgbmv): 182 initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc); 183 cblas_dgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU, 184 call->alpha.d, (double*)A, lda, (double*)X, call->incX, 185 call->beta.d, (double*)Y, call->incY); 186 break; 187 case (RsBlas_dtrmv): 188 initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr); 189 cblas_dtrmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A, 190 lda, (double*)X, call->incX); 191 break; 192 case (RsBlas_dtbmv): 193 initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr); 194 cblas_dtbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (double*)A, 195 lda, (double*)X, call->incX); 196 break; 197 // stpmv takes a packed 1D Allocation only 198 case (RsBlas_dtpmv): 199 initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr); 200 cblas_dtpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A, 201 (double*)X, call->incX); 202 break; 203 case (RsBlas_dtrsv): 204 initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr); 205 cblas_dtrsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A, lda, 206 (double*)X, call->incX); 207 break; 208 case (RsBlas_dtbsv): 209 initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr); 210 cblas_dtbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (double*)A, 211 lda, (double*)X, call->incX); 212 break; 213 case (RsBlas_dtpsv): 214 initABC(ain, sizeof(double), &A, &X, nullptr, &lda, &ldb, nullptr); 215 cblas_dtpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (double*)A, 216 (double*)X, call->incX); 217 break; 218 case (RsBlas_cgemv): 219 initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc); 220 cblas_cgemv(CblasRowMajor, TransA, call->M, call->N, (void*)&call->alpha.c, (void*)A, 221 lda, (void*)X, call->incX, (void*)&call->beta.c, (void*)Y, call->incY); 222 break; 223 case (RsBlas_cgbmv): 224 initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc); 225 cblas_cgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU, 226 (void*)&call->alpha.c, (void*)A, lda, (void*)X, call->incX, 227 (void*)&call->beta.c, (void*)Y, call->incY); 228 break; 229 case (RsBlas_ctrmv): 230 initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 231 cblas_ctrmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, 232 lda, (void*)X, call->incX); 233 break; 234 case (RsBlas_ctbmv): 235 initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 236 cblas_ctbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A, 237 lda, (void*)X, call->incX); 238 break; 239 // stpmv takes a packed 1D Allocation only 240 case (RsBlas_ctpmv): 241 initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 242 cblas_ctpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, 243 (void*)X, call->incX); 244 break; 245 case (RsBlas_ctrsv): 246 initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 247 cblas_ctrsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, lda, 248 (void*)X, call->incX); 249 break; 250 case (RsBlas_ctbsv): 251 initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 252 cblas_ctbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A, 253 lda, (void*)X, call->incX); 254 break; 255 case (RsBlas_ctpsv): 256 initABC(ain, sizeof(float)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 257 cblas_ctpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, 258 (void*)X, call->incX); 259 break; 260 case (RsBlas_zgemv): 261 initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc); 262 cblas_zgemv(CblasRowMajor, TransA, call->M, call->N, (void*)&call->alpha.z, (void*)A, 263 lda, (void*)X, call->incX, (void*)&call->beta.z, (void*)Y, call->incY); 264 break; 265 case (RsBlas_zgbmv): 266 initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc); 267 cblas_zgbmv(CblasRowMajor, TransA, call->M, call->N, call->KL, call->KU, 268 (void*)&call->alpha.z, (void*)A, lda, (void*)X, call->incX, 269 (void*)&call->beta.z, (void*)Y, call->incY); 270 break; 271 case (RsBlas_ztrmv): 272 initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 273 cblas_ztrmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, 274 lda, (void*)X, call->incX); 275 break; 276 case (RsBlas_ztbmv): 277 initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 278 cblas_ztbmv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A, 279 lda, (void*)X, call->incX); 280 break; 281 // stpmv takes a packed 1D Allocation only 282 case (RsBlas_ztpmv): 283 initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 284 cblas_ztpmv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, 285 (void*)X, call->incX); 286 break; 287 case (RsBlas_ztrsv): 288 initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 289 cblas_ztrsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, lda, 290 (void*)X, call->incX); 291 break; 292 case (RsBlas_ztbsv): 293 initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 294 cblas_ztbsv(CblasRowMajor, Uplo, TransA, Diag, call->N, call->K, (void*)A, 295 lda, (void*)X, call->incX); 296 break; 297 case (RsBlas_ztpsv): 298 initABC(ain, sizeof(double)*2, &A, &X, nullptr, &lda, &ldb, nullptr); 299 cblas_ztpsv(CblasRowMajor, Uplo, TransA, Diag, call->N, (void*)A, 300 (void*)X, call->incX); 301 break; 302 303 304 // S and D only 305 case (RsBlas_ssymv): 306 initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc); 307 cblas_ssymv(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)A, lda, 308 (float*)X, call->incX, call->beta.f, (float*)Y, call->incY); 309 break; 310 case (RsBlas_ssbmv): 311 initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc); 312 cblas_ssbmv(CblasRowMajor, Uplo, call->N, call->K, call->alpha.f, 313 (float*)A, lda, (float*)X, call->incX, call->beta.f, 314 (float*)Y, call->incY); 315 break; 316 //sspmv requires a packed 1D Allocation 317 case (RsBlas_sspmv): 318 initABC(ain, sizeof(float), &A, &X, &Y, &lda, &ldb, &ldc); 319 cblas_sspmv(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)A, 320 (float*)X, call->incX, call->beta.f, (float*)Y, call->incY); 321 break; 322 // following calls have init reordered because A is output matrix 323 case (RsBlas_sger): 324 initABC(ain, sizeof(float), &X, &Y, &A, &ldb, &ldc, &lda); 325 cblas_sger(CblasRowMajor, call->M, call->N, call->alpha.f, (float*)X, 326 call->incX, (float*)Y, call->incY, (float*)A, lda); 327 break; 328 case (RsBlas_ssyr): 329 initABC(ain, sizeof(float), &X, &A, nullptr, &ldb, &lda, nullptr); 330 cblas_ssyr(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX, 331 (float*)A, lda); 332 break; 333 // sspr is packed 1D Allocation A only 334 case (RsBlas_sspr): 335 initABC(ain, sizeof(float), &X, &A, nullptr, &ldb, &lda, nullptr); 336 cblas_sspr(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX, 337 (float*)A); 338 break; 339 case (RsBlas_ssyr2): 340 initABC(ain, sizeof(float), &X, &Y, &A, &ldb, &ldc, &lda); 341 cblas_ssyr2(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX, 342 (float*)Y, call->incY, (float*)A, lda); 343 break; 344 // sspr2 is packed 1D Allocation A only 345 case (RsBlas_sspr2): 346 initABC(ain, sizeof(float), &X, &Y, &A, &ldb, &ldc, &lda); 347 cblas_sspr2(CblasRowMajor, Uplo, call->N, call->alpha.f, (float*)X, call->incX, 348 (float*)Y, call->incY, (float*)A); 349 break; 350 case (RsBlas_dsymv): 351 initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc); 352 cblas_dsymv(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)A, lda, 353 (double*)X, call->incX, call->beta.d, (double*)Y, call->incY); 354 break; 355 case (RsBlas_dsbmv): 356 initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc); 357 cblas_dsbmv(CblasRowMajor, Uplo, call->N, call->K, call->alpha.d, 358 (double*)A, lda, (double*)X, call->incX, call->beta.d, 359 (double*)Y, call->incY); 360 break; 361 // dspmv requires a packed 1D Allocation 362 case (RsBlas_dspmv): 363 initABC(ain, sizeof(double), &A, &X, &Y, &lda, &ldb, &ldc); 364 cblas_dspmv(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)A, 365 (double*)X, call->incX, call->beta.d, (double*)Y, call->incY); 366 break; 367 // following calls have init reordered because A is output matrix 368 case (RsBlas_dger): 369 initABC(ain, sizeof(double), &X, &Y, &A, &ldb, &ldc, &lda); 370 cblas_dger(CblasRowMajor, call->M, call->N, call->alpha.d, (double*)X, 371 call->incX, (double*)Y, call->incY, (double*)A, lda); 372 break; 373 case (RsBlas_dsyr): 374 initABC(ain, sizeof(double), &X, &A, nullptr, &ldb, &lda, nullptr); 375 cblas_dsyr(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX, 376 (double*)A, lda); 377 break; 378 // dspr is packed 1D Allocation A only 379 case (RsBlas_dspr): 380 initABC(ain, sizeof(double), &X, &A, nullptr, &ldb, &lda, nullptr); 381 cblas_dspr(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX, 382 (double*)A); 383 break; 384 case (RsBlas_dsyr2): 385 initABC(ain, sizeof(double), &X, &Y, &A, &ldb, &ldc, &lda); 386 cblas_dsyr2(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX, 387 (double*)Y, call->incY, (double*)A, lda); 388 break; 389 // dspr2 is packed 1D Allocation A only 390 case (RsBlas_dspr2): 391 initABC(ain, sizeof(double), &X, &Y, &A, &ldb, &ldc, &lda); 392 cblas_dspr2(CblasRowMajor, Uplo, call->N, call->alpha.d, (double*)X, call->incX, 393 (double*)Y, call->incY, (double*)A); 394 break; 395 396 // C and Z only 397 case (RsBlas_chemv): 398 initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc); 399 cblas_chemv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, A, lda, 400 X, call->incX, (void*)&call->beta.c, Y, call->incY); 401 break; 402 case (RsBlas_chbmv): 403 initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc); 404 cblas_chbmv(CblasRowMajor, Uplo, call->N, call->K, (void*)&call->alpha.c, 405 A, lda, X, call->incX, (void*)&call->beta.c, Y, call->incY); 406 break; 407 case (RsBlas_chpmv): 408 initABC(ain, sizeof(float)*2, &A, &X, &Y, &lda, &ldb, &ldc); 409 cblas_chpmv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, A, 410 X, call->incX, (void*)&call->beta.c, Y, call->incY); 411 break; 412 case (RsBlas_cgeru): 413 initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda); 414 cblas_cgeru(CblasRowMajor, call->M, call->N, (void*)&call->alpha.c, 415 X, call->incX, Y, call->incY, A, lda); 416 break; 417 case (RsBlas_cgerc): 418 initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda); 419 cblas_cgerc(CblasRowMajor, call->M, call->N, (void*)&call->alpha.c, 420 X, call->incX, Y, call->incY, A, lda); 421 break; 422 case (RsBlas_cher): 423 initABC(ain, sizeof(float)*2, &X, nullptr, &A, &ldb, nullptr, &lda); 424 cblas_cher(CblasRowMajor, Uplo, call->N, call->alpha.f, 425 X, call->incX, A, lda); 426 break; 427 // packed 1D Allocations only 428 case (RsBlas_chpr): 429 initABC(ain, sizeof(float)*2, &X, nullptr, &A, &ldb, nullptr, &lda); 430 cblas_chpr(CblasRowMajor, Uplo, call->N, call->alpha.f, X, 431 call->incX, A); 432 break; 433 case (RsBlas_cher2): 434 initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda); 435 cblas_cher2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, 436 X, call->incX, Y, call->incY, A, lda); 437 break; 438 // packed 1D Allocations only 439 case (RsBlas_chpr2): 440 initABC(ain, sizeof(float)*2, &X, &Y, &A, &ldb, &ldc, &lda); 441 cblas_chpr2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.c, X, 442 call->incX, Y, call->incY, A); 443 break; 444 case (RsBlas_zhemv): 445 initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc); 446 cblas_zhemv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, A, lda, 447 X, call->incX, (void*)&call->beta.z, Y, call->incY); 448 break; 449 case (RsBlas_zhbmv): 450 initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc); 451 cblas_zhbmv(CblasRowMajor, Uplo, call->N, call->K, (void*)&call->alpha.z, 452 A, lda, X, call->incX, (void*)&call->beta.z, Y, call->incY); 453 break; 454 case (RsBlas_zhpmv): 455 initABC(ain, sizeof(double)*2, &A, &X, &Y, &lda, &ldb, &ldc); 456 cblas_zhpmv(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, A, 457 X, call->incX, (void*)&call->beta.z, Y, call->incY); 458 break; 459 case (RsBlas_zgeru): 460 initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda); 461 cblas_zgeru(CblasRowMajor, call->M, call->N, (void*)&call->alpha.z, 462 X, call->incX, Y, call->incY, A, lda); 463 break; 464 case (RsBlas_zgerc): 465 initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda); 466 cblas_zgerc(CblasRowMajor, call->M, call->N, (void*)&call->alpha.z, 467 X, call->incX, Y, call->incY, A, lda); 468 break; 469 case (RsBlas_zher): 470 initABC(ain, sizeof(double)*2, &X, nullptr, &A, &ldb, nullptr, &lda); 471 cblas_zher(CblasRowMajor, Uplo, call->N, call->alpha.d, 472 X, call->incX, A, lda); 473 break; 474 // packed 1D Allocations only 475 case (RsBlas_zhpr): 476 initABC(ain, sizeof(double)*2, &X, nullptr, &A, &ldb, nullptr, &lda); 477 cblas_zhpr(CblasRowMajor, Uplo, call->N, call->alpha.d, X, 478 call->incX, A); 479 break; 480 case (RsBlas_zher2): 481 initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda); 482 cblas_zher2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, 483 X, call->incX, Y, call->incY, A, lda); 484 break; 485 // packed 1D Allocations only 486 case (RsBlas_zhpr2): 487 initABC(ain, sizeof(double)*2, &X, &Y, &A, &ldb, &ldc, &lda); 488 cblas_zhpr2(CblasRowMajor, Uplo, call->N, (void*)&call->alpha.z, X, 489 call->incX, Y, call->incY, A); 490 break; 491 492 // Level 3 BLAS 493 case (RsBlas_sgemm): 494 initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc); 495 cblas_sgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, call->alpha.f, 496 (float*)A, lda, (float*)B, ldb, call->beta.f, (float*)C, ldc); 497 break; 498 case (RsBlas_ssymm): 499 initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc); 500 cblas_ssymm(CblasRowMajor, Side, Uplo, call->M, call->N, call->alpha.f, (float*)A, 501 lda, (float*)B, ldb, call->beta.f, (float*)C, ldc); 502 break; 503 case (RsBlas_ssyrk): 504 initABC(ain, sizeof(float), &A, nullptr, &C, &lda, nullptr, &ldc); 505 cblas_ssyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.f, (float*)A, 506 lda, call->beta.f, (float*)C, ldc); 507 break; 508 case (RsBlas_ssyr2k): 509 initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc); 510 cblas_ssyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.f, (float*)A, 511 lda, (float*)B, ldb, call->beta.f, (float*)C, ldc); 512 break; 513 case (RsBlas_strmm): 514 initABC(ain, sizeof(float), &A, &B, nullptr, &lda, &ldb, nullptr); 515 cblas_strmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.f, 516 (float*)A, lda, (float*)B, ldb); 517 break; 518 case (RsBlas_strsm): 519 initABC(ain, sizeof(float), &A, &B, nullptr, &lda, &ldb, nullptr); 520 cblas_strsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.f, 521 (float*)A, lda, (float*)B, ldb); 522 break; 523 524 525 case (RsBlas_dgemm): 526 initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc); 527 cblas_dgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, call->alpha.d, 528 (double*)A, lda, (double*)B, ldb, call->beta.d, (double*)C, ldc); 529 break; 530 case (RsBlas_dsymm): 531 initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc); 532 cblas_dsymm(CblasRowMajor, Side, Uplo, call->M, call->N, call->alpha.d, (double*)A, 533 lda, (double*)B, ldb, call->beta.d, (double*)C, ldc); 534 break; 535 case (RsBlas_dsyrk): 536 initABC(ain, sizeof(double), &A, nullptr, &C, &lda, nullptr, &ldc); 537 cblas_dsyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.d, (double*)A, 538 lda, call->beta.d, (double*)C, ldc); 539 break; 540 case (RsBlas_dsyr2k): 541 initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc); 542 cblas_dsyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.d, (double*)A, 543 lda, (double*)B, ldb, call->beta.d, (double*)C, ldc); 544 break; 545 case (RsBlas_dtrmm): 546 initABC(ain, sizeof(double), &A, &B, nullptr, &lda, &ldb, nullptr); 547 cblas_dtrmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.d, 548 (double*)A, lda, (double*)B, ldb); 549 break; 550 case (RsBlas_dtrsm): 551 initABC(ain, sizeof(double), &A, &B, nullptr, &lda, &ldb, nullptr); 552 cblas_dtrsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, call->alpha.d, 553 (double*)A, lda, (double*)B, ldb); 554 break; 555 556 case (RsBlas_cgemm): 557 initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc); 558 cblas_cgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, (void*)&call->alpha.c, 559 A, lda, B, ldb, (void*)&call->beta.c, C, ldc); 560 break; 561 case (RsBlas_csymm): 562 initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc); 563 cblas_csymm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.c, A, 564 lda, B, ldb, (void*)&call->beta.c, C, ldc); 565 break; 566 case (RsBlas_csyrk): 567 initABC(ain, sizeof(float)*2, &A, nullptr, &C, &lda, nullptr, &ldc); 568 cblas_csyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.c, A, 569 lda, (void*)&call->beta.c, C, ldc); 570 break; 571 case (RsBlas_csyr2k): 572 initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc); 573 cblas_csyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.c, A, 574 lda, B, ldb, (void*)&call->beta.c, C, ldc); 575 break; 576 case (RsBlas_ctrmm): 577 initABC(ain, sizeof(float)*2, &A, &B, nullptr, &lda, &ldb, nullptr); 578 cblas_ctrmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.c, 579 A, lda, B, ldb); 580 break; 581 case (RsBlas_ctrsm): 582 initABC(ain, sizeof(float)*2, &A, &B, nullptr, &lda, &ldb, nullptr); 583 cblas_ctrsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.c, 584 A, lda, B, ldb); 585 break; 586 587 case (RsBlas_zgemm): 588 initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc); 589 cblas_zgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, (void*)&call->alpha.z, 590 A, lda, B, ldb, (void*)&call->beta.z, C, ldc); 591 break; 592 case (RsBlas_zsymm): 593 initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc); 594 cblas_zsymm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.z, A, 595 lda, B, ldb, (void*)&call->beta.z, C, ldc); 596 break; 597 case (RsBlas_zsyrk): 598 initABC(ain, sizeof(double)*2, &A, nullptr, &C, &lda, nullptr, &ldc); 599 cblas_zsyrk(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.z, A, 600 lda, (void*)&call->beta.z, C, ldc); 601 break; 602 case (RsBlas_zsyr2k): 603 initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc); 604 cblas_zsyr2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.z, A, 605 lda, B, ldb, (void*)&call->beta.z, C, ldc); 606 break; 607 case (RsBlas_ztrmm): 608 initABC(ain, sizeof(double)*2, &A, &B, nullptr, &lda, &ldb, nullptr); 609 cblas_ztrmm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.z, 610 A, lda, B, ldb); 611 break; 612 case (RsBlas_ztrsm): 613 initABC(ain, sizeof(double)*2, &A, &B, nullptr, &lda, &ldb, nullptr); 614 cblas_ztrsm(CblasRowMajor, Side, Uplo, TransA, Diag, call->M, call->N, (void*)&call->alpha.z, 615 A, lda, B, ldb); 616 break; 617 618 // Level 3 C and Z only 619 case (RsBlas_chemm): 620 initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc); 621 cblas_chemm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.c, A, lda, 622 B, ldb, (void*)&call->beta.c, C, ldc); 623 break; 624 case (RsBlas_cherk): 625 initABC(ain, sizeof(float)*2, &A, nullptr, &C, &lda, nullptr, &ldc); 626 cblas_cherk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.f, A, lda, 627 call->beta.f, C, ldc); 628 break; 629 case (RsBlas_cher2k): 630 initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc); 631 cblas_cher2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.c, A, lda, 632 B, ldb, call->beta.f, C, ldc); 633 break; 634 635 case (RsBlas_zhemm): 636 initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc); 637 cblas_zhemm(CblasRowMajor, Side, Uplo, call->M, call->N, (void*)&call->alpha.z, A, lda, 638 B, ldb, (void*)&call->beta.z, C, ldc); 639 break; 640 case (RsBlas_zherk): 641 initABC(ain, sizeof(double)*2, &A, nullptr, &C, &lda, nullptr, &ldc); 642 cblas_zherk(CblasRowMajor, Uplo, TransA, call->N, call->K, call->alpha.d, A, lda, 643 call->beta.d, C, ldc); 644 break; 645 case (RsBlas_zher2k): 646 initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc); 647 cblas_zher2k(CblasRowMajor, Uplo, TransA, call->N, call->K, (void*)&call->alpha.z, A, lda, 648 B, ldb, call->beta.d, C, ldc); 649 break; 650 651 652 case (RsBlas_bnnm): 653 initABC(ain, sizeof(uint8_t), &A, &B, &C, &lda, &ldb, &ldc); 654 kernelBNNM(call->M, call->N, call->K, 655 (const uint8_t*)A, call->a_offset, lda, 656 (const uint8_t*)B, call->b_offset, ldb, 657 (uint8_t*)C, call->c_offset, ldc, 658 call->c_mult_int); 659 660 break; 661 662 default: 663 ALOGE("unimplemented\n"); 664 } 665 666 667 } 668 669 void RsdCpuScriptIntrinsicBLAS::kernelBNNM(size_t m, size_t n, size_t k, 670 const uint8_t* a, uint8_t a_offset, size_t lda, 671 const uint8_t* b, uint8_t b_offset, size_t ldb, 672 uint8_t* c, int32_t c_offset, size_t ldc, 673 int32_t c_mult_int) { 674 // Calculations are done in 1.10.21 fixed-point format for the final output, 675 // just before there's a shift down to drop the fractional parts. The output 676 // values are gated to 0 to 255 to fit in a byte, but the 10-bit format 677 // gives some headroom to avoid wrapping around on small overflows. 678 const int c_shift = 21; 679 size_t i = 0, j = 0, l = 0; 680 for (j = 0; j < n; j++) { 681 for (i = 0; i < m; i++) { 682 int32_t total = 0; 683 for (l = 0; l < k; l++) { 684 const int a_index = ((i * lda) + l); 685 const uint8_t a_as_byte = a[a_index]; 686 const int32_t a_as_int = (((int32_t)(a_as_byte)) - a_offset); 687 const int b_index = ((j * ldb) + l); 688 const uint8_t b_as_byte = b[b_index]; 689 const int32_t b_as_int = (((int32_t)(b_as_byte)) - b_offset); 690 const int32_t mult_as_int = (a_as_int * b_as_int); 691 total += mult_as_int; 692 } 693 const int c_index = ((ldc * i) + j); 694 int32_t output = 695 ((((total + c_offset) * c_mult_int) + (1 << (c_shift - 1))) 696 >> c_shift); 697 if (output > 255) { 698 output = 255; 699 } 700 if (output < 0) { 701 output = 0; 702 } 703 c[c_index] = (uint8_t)(output); 704 } 705 } 706 } 707 708 709 710 711 712 RsdCpuScriptIntrinsicBLAS::RsdCpuScriptIntrinsicBLAS(RsdCpuReferenceImpl *ctx, 713 const Script *s) 714 : RsdCpuScriptIntrinsic(ctx, s, nullptr, RS_SCRIPT_INTRINSIC_ID_BLAS) { 715 716 717 } 718 719 RsdCpuScriptIntrinsicBLAS::~RsdCpuScriptIntrinsicBLAS() { 720 } 721 722 723 724 725 726 RsdCpuScriptImpl * rsdIntrinsic_BLAS(RsdCpuReferenceImpl *ctx, 727 const Script *s, const Element *e) { 728 729 return new RsdCpuScriptIntrinsicBLAS(ctx, s); 730 } 731