Home | History | Annotate | Download | only in cpp
      1 /*
      2  * Copyright (C) 2015 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 
     18 #include "RenderScript.h"
     19 #include "rsCppInternal.h"
     20 
     21 using namespace android;
     22 using namespace RSC;
     23 
     24 // ScriptIntrinsicBLAS APIS
     25 ScriptIntrinsicBLAS::ScriptIntrinsicBLAS(sp<RS> rs, sp<const Element> e)
     26     : ScriptIntrinsic(rs, RS_SCRIPT_INTRINSIC_ID_BLAS, e) {
     27 
     28 }
     29 
     30 sp<ScriptIntrinsicBLAS> ScriptIntrinsicBLAS::create(sp<RS> rs) {
     31     return new ScriptIntrinsicBLAS(rs, Element::U32(rs));
     32 }
     33 
     34 enum RsBlasDataType {
     35     SINGLE,
     36     DOUBLE,
     37     SINGLE_COMPLEX,
     38     DOUBLE_COMPLEX
     39 };
     40 
     41 static RsBlasCall
     42 setUpBLASCall(RsBlasDataType dataType, RsBlasFunction func,
     43               int TransA, int TransB, int Side, int Uplo, int Diag,
     44               int M, int N, int K, int incX, int incY, int KL, int KU,
     45               float alphaF, float betaF, double alphaD, double betaD,
     46               float alphaCX, float alphaCY, float betaCX, float betaCY,
     47               double alphaZX, double alphaZY, double betaZX, double betaZY
     48               ) {
     49     RsBlasCall call;
     50     memset(&call, 0, sizeof(call));
     51     call.func = func;
     52     call.transA = (RsBlasTranspose)TransA;
     53     call.transB = (RsBlasTranspose)TransB;
     54     call.side = (RsBlasSide)Side;
     55     call.uplo = (RsBlasUplo)Uplo;
     56     call.diag = (RsBlasDiag)Diag;
     57     call.M = M;
     58     call.N = N;
     59     call.K = K;
     60 
     61     switch (dataType) {
     62         case SINGLE:
     63             // For Single-precision BLAS.
     64             call.alpha.f = alphaF;
     65             call.beta.f = betaF;
     66             break;
     67         case DOUBLE:
     68             // For Double-precision BLAS.
     69             call.alpha.d = alphaD;
     70             call.beta.d = betaD;
     71             break;
     72         case SINGLE_COMPLEX:
     73             // For Single-precision complex BLAS.
     74             call.alpha.c.r = alphaCX;
     75             call.alpha.c.i = alphaCY;
     76             call.beta.c.r = betaCX;
     77             call.beta.c.i = betaCY;
     78             break;
     79         case DOUBLE_COMPLEX:
     80             // For Double-precision complex BLAS.
     81             call.alpha.z.r = alphaZX;
     82             call.alpha.z.i = alphaZY;
     83             call.beta.z.r = betaZX;
     84             call.beta.z.i = betaZY;
     85             break;
     86         default:
     87             break;
     88     }
     89 
     90     call.incX = incX;
     91     call.incY = incY;
     92     call.KL = KL;
     93     call.KU = KU;
     94 
     95     return call;
     96 }
     97 
     98 static void
     99 nScriptIntrinsicBLAS_Single(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA,
    100                             int TransB, int Side, int Uplo, int Diag, int M, int N, int K,
    101                             float alpha, RsAllocation A, RsAllocation B,
    102                             float beta, RsAllocation C, int incX, int incY, int KL, int KU) {
    103     RsBlasCall call = setUpBLASCall(SINGLE, func, TransA, TransB, Side, Uplo, Diag,
    104                                     M, N, K, incX, incY, KL, KU, alpha, beta, 0.0, 0.0,
    105                                     0.0f, 0.0f, 0.0f, 0.0f, 0.0, 0.0, 0.0, 0.0);
    106     RsAllocation in_allocs[3] = {A, B, C};
    107     tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, sizeof(in_allocs), nullptr,
    108                                                       &call, sizeof(call), nullptr, 0));
    109 }
    110 
    111 
    112 static void
    113 nScriptIntrinsicBLAS_Double(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA,
    114                             int TransB, int Side, int Uplo, int Diag, int M, int N, int K,
    115                             double alpha, RsAllocation A, RsAllocation B,
    116                             double beta, RsAllocation C, int incX, int incY, int KL, int KU) {
    117     RsBlasCall call = setUpBLASCall(DOUBLE, func, TransA, TransB, Side, Uplo, Diag,
    118                                     M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, alpha, beta,
    119                                     0.0f, 0.0f, 0.0f, 0.0f, 0.0, 0.0, 0.0, 0.0);
    120     RsAllocation in_allocs[3] = {A, B, C};
    121     tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, sizeof(in_allocs), nullptr,
    122                                                       &call, sizeof(call), nullptr, 0));
    123 }
    124 
    125 static void
    126 nScriptIntrinsicBLAS_Complex(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA,
    127                              int TransB, int Side, int Uplo, int Diag, int M, int N, int K,
    128                              float alphaX, float alphaY, RsAllocation A, RsAllocation B,
    129                              float betaX, float betaY, RsAllocation C, int incX, int incY, int KL, int KU) {
    130     RsBlasCall call = setUpBLASCall(SINGLE_COMPLEX, func, TransA, TransB, Side, Uplo, Diag,
    131                                     M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, 0.0, 0.0,
    132                                     alphaX, alphaY, betaX, betaY, 0.0, 0.0, 0.0, 0.0);
    133     RsAllocation in_allocs[3] = {A, B, C};
    134     tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, sizeof(in_allocs), nullptr,
    135                                                       &call, sizeof(call), nullptr, 0));
    136 }
    137 
    138 static void
    139 nScriptIntrinsicBLAS_Z(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA,
    140                        int TransB, int Side, int Uplo, int Diag, int M, int N, int K,
    141                        double alphaX, double alphaY, RsAllocation A, RsAllocation B,
    142                        double betaX, double betaY, RsAllocation C, int incX, int incY, int KL, int KU) {
    143     RsBlasCall call = setUpBLASCall(DOUBLE_COMPLEX, func, TransA, TransB, Side, Uplo, Diag,
    144                                     M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, 0.0, 0.0,
    145                                     0.0f, 0.0f, 0.0f, 0.0f, alphaX, alphaY, betaX, betaY);
    146     RsAllocation in_allocs[3] = {A, B, C};
    147     tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, sizeof(in_allocs), nullptr,
    148                                                       &call, sizeof(call), nullptr, 0));
    149 }
    150 
    151 
    152 static void
    153 nScriptIntrinsicBLAS_BNNM(RS* mRS, RsContext con, RsScript id, int M, int N, int K,
    154                           RsAllocation A, int a_offset, RsAllocation B, int b_offset,
    155                           RsAllocation C, int c_offset, int c_mult_int) {
    156     RsBlasCall call;
    157     memset(&call, 0, sizeof(call));
    158     call.func = RsBlas_bnnm;
    159     call.M = M;
    160     call.N = N;
    161     call.K = K;
    162     call.a_offset = a_offset & 0xFF;
    163     call.b_offset = b_offset & 0xFF;
    164     call.c_offset = c_offset;
    165     call.c_mult_int = c_mult_int;
    166 
    167     RsAllocation in_allocs[3] = {A, B, C};
    168     tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, sizeof(in_allocs), nullptr,
    169                                                       &call, sizeof(call), nullptr, 0));
    170 }
    171 
    172 /**
    173  * Level 2 BLAS
    174  */
    175 static void validateGEMV(RS* mRS, sp<const Element> e, RsBlasTranspose TransA, sp<Allocation> A,
    176                          sp<Allocation> X, int incX, sp<Allocation> Y, int incY) {
    177     int M = A->getType()->getY();
    178     int N = A->getType()->getX();
    179     if (!A->getType()->getElement()->isCompatible(e) ||
    180         !X->getType()->getElement()->isCompatible(e) ||
    181         !Y->getType()->getElement()->isCompatible(e)) {
    182         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
    183     }
    184     if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
    185         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
    186     }
    187 
    188     if (incX <= 0 || incY <= 0) {
    189         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
    190     }
    191     int expectedXDim = -1, expectedYDim = -1;
    192     if (TransA == RsBlasNoTrans) {
    193         expectedXDim = 1 + (N - 1) * incX;
    194         expectedYDim = 1 + (M - 1) * incY;
    195     } else {
    196         expectedXDim = 1 + (M - 1) * incX;
    197         expectedYDim = 1 + (N - 1) * incY;
    198     }
    199     if ((int)X->getType()->getX() != expectedXDim ||
    200         (int)Y->getType()->getX() != expectedYDim) {
    201         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GEMV");
    202     }
    203 }
    204 
    205 void ScriptIntrinsicBLAS::SGEMV(RsBlasTranspose TransA, float alpha, sp<Allocation> A, sp<Allocation> X,
    206                                 int incX, float beta, sp<Allocation> Y, int incY) {
    207     validateGEMV(mRS, Element::F32(mRS), TransA, A, X, incX, Y, incY);
    208     int M = A->getType()->getY();
    209     int N = A->getType()->getX();
    210     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgemv,
    211                                 TransA, 0, 0, 0, 0, M, N, 0,
    212                                 alpha, A->getID(), X->getID(),
    213                                 beta, Y->getID(), incX, incY, 0, 0);
    214 }
    215 
    216 void ScriptIntrinsicBLAS::DGEMV(RsBlasTranspose TransA, double alpha, sp<Allocation> A, sp<Allocation> X,
    217                                 int incX, double beta, sp<Allocation> Y, int incY) {
    218     validateGEMV(mRS, Element::F64(mRS), TransA, A, X, incX, Y, incY);
    219     int M = A->getType()->getY();
    220     int N = A->getType()->getX();
    221     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgemv,
    222                                 TransA, 0, 0, 0, 0, M, N, 0,
    223                                 alpha, A->getID(), X->getID(),
    224                                 beta, Y->getID(), incX, incY, 0, 0);
    225 }
    226 
    227 void ScriptIntrinsicBLAS::CGEMV(RsBlasTranspose TransA, Float2 alpha, sp<Allocation> A, sp<Allocation> X,
    228                                 int incX, Float2 beta, sp<Allocation> Y, int incY) {
    229     validateGEMV(mRS, Element::F32_2(mRS), TransA, A, X, incX, Y, incY);
    230     int M = A->getType()->getY();
    231     int N = A->getType()->getX();
    232     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgemv,
    233                                  TransA, 0, 0, 0, 0, M, N, 0,
    234                                  alpha.x, alpha.y, A->getID(), X->getID(),
    235                                  beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
    236 }
    237 
    238 void ScriptIntrinsicBLAS::ZGEMV(RsBlasTranspose TransA, Double2 alpha, sp<Allocation> A, sp<Allocation> X,
    239                                 int incX, Double2 beta, sp<Allocation> Y, int incY) {
    240     validateGEMV(mRS, Element::F64_2(mRS), TransA, A, X, incX, Y, incY);
    241     int M = A->getType()->getY();
    242     int N = A->getType()->getX();
    243     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgemv,
    244                            TransA, 0, 0, 0, 0, M, N, 0,
    245                            alpha.x, alpha.y, A->getID(), X->getID(),
    246                            beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
    247 }
    248 
    249 void ScriptIntrinsicBLAS::SGBMV(RsBlasTranspose TransA, int KL, int KU, float alpha, sp<Allocation> A,
    250                                 sp<Allocation> X, int incX, float beta, sp<Allocation> Y, int incY) {
    251     // GBMV has the same validation requirements as GEMV + KL and KU >= 0
    252     validateGEMV(mRS, Element::F32(mRS), TransA, A, X, incX, Y, incY);
    253     if (KL < 0 || KU < 0) {
    254         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
    255     }
    256     int M = A->getType()->getY();
    257     int N = A->getType()->getX();
    258 
    259     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgbmv,
    260                                 TransA, 0, 0, 0, 0, M, N, 0,
    261                                 alpha, A->getID(), X->getID(),
    262                                 beta, Y->getID(), incX, incY, KL, KU);
    263 }
    264 
    265 void ScriptIntrinsicBLAS::DGBMV(RsBlasTranspose TransA, int KL, int KU, double alpha, sp<Allocation> A,
    266                                 sp<Allocation> X, int incX, double beta, sp<Allocation> Y, int incY) {
    267     // GBMV has the same validation requirements as GEMV + KL and KU >= 0
    268     validateGEMV(mRS, Element::F64(mRS), TransA, A, X, incX, Y, incY);
    269     if (KL < 0 || KU < 0) {
    270         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
    271     }
    272     int M = A->getType()->getY();
    273     int N = A->getType()->getX();
    274 
    275     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgbmv,
    276                                 TransA, 0, 0, 0, 0, M, N, 0,
    277                                 alpha, A->getID(), X->getID(),
    278                                 beta, Y->getID(), incX, incY, KL, KU);
    279 }
    280 
    281 void ScriptIntrinsicBLAS::CGBMV(RsBlasTranspose TransA, int KL, int KU, Float2 alpha, sp<Allocation> A,
    282                                 sp<Allocation> X, int incX, Float2 beta, sp<Allocation> Y, int incY) {
    283     // GBMV has the same validation requirements as GEMV + KL and KU >= 0
    284     validateGEMV(mRS, Element::F32_2(mRS), TransA, A, X, incX, Y, incY);
    285     if (KL < 0 || KU < 0) {
    286         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
    287     }
    288     int M = A->getType()->getY();
    289     int N = A->getType()->getX();
    290 
    291     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgbmv,
    292                                  TransA, 0, 0, 0, 0, M, N, 0,
    293                                  alpha.x, alpha.y, A->getID(), X->getID(),
    294                                  beta.x, beta.y, Y->getID(), incX, incY, KL, KU);
    295 }
    296 
    297 void ScriptIntrinsicBLAS::ZGBMV(RsBlasTranspose TransA, int KL, int KU, Double2 alpha, sp<Allocation> A,
    298                                 sp<Allocation> X, int incX, Double2 beta, sp<Allocation> Y, int incY) {
    299     // GBMV has the same validation requirements as GEMV + KL and KU >= 0
    300     validateGEMV(mRS, Element::F64_2(mRS), TransA, A, X, incX, Y, incY);
    301     if (KL < 0 || KU < 0) {
    302         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
    303     }
    304     int M = A->getType()->getY();
    305     int N = A->getType()->getX();
    306 
    307     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgbmv,
    308                            TransA, 0, 0, 0, 0, M, N, 0,
    309                            alpha.x, alpha.y, A->getID(), X->getID(),
    310                            beta.x, beta.y, Y->getID(), incX, incY, KL, KU);
    311 }
    312 
    313 static void validateTRMV(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, RsBlasTranspose TransA,
    314                          RsBlasDiag Diag, sp<Allocation> A, sp<Allocation> X, int incX) {
    315     int N = A->getType()->getY();
    316     if ((int)A->getType()->getX() != N) {
    317         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a square matrix for TRMV");
    318     }
    319     if (!A->getType()->getElement()->isCompatible(e) ||
    320         !X->getType()->getElement()->isCompatible(e)) {
    321         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
    322     }
    323     if (X->getType()->getY() > 1) {
    324         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
    325     }
    326 
    327     if (incX <= 0) {
    328         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
    329     }
    330     int expectedXDim = 1 + (N - 1) * incX;
    331     if ((int)X->getType()->getX() != expectedXDim) {
    332         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for TRMV");
    333     }
    334 }
    335 
    336 static int validateTPMV(RS* mRS, sp<const Element> e,  RsBlasUplo Uplo, RsBlasTranspose TransA,
    337                         RsBlasDiag Diag, sp<Allocation> Ap, sp<Allocation> X, int incX) {
    338     if (!Ap->getType()->getElement()->isCompatible(e) ||
    339         !X->getType()->getElement()->isCompatible(e)) {
    340         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
    341     }
    342     if (X->getType()->getY() > 1) {
    343         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
    344     }
    345 
    346     if (Ap->getType()->getY() > 1) {
    347         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
    348     }
    349 
    350     int N = sqrt((double)Ap->getType()->getX() * 2);
    351     if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) {
    352         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
    353     }
    354     if (incX <= 0) {
    355         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
    356     }
    357     int expectedXDim = 1 + (N - 1) * incX;
    358     if ((int)X->getType()->getX() != expectedXDim) {
    359         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for TPMV");
    360     }
    361 
    362     return N;
    363 }
    364 
    365 
    366 void ScriptIntrinsicBLAS::STRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
    367                                 sp<Allocation> A, sp<Allocation> X, int incX) {
    368     validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
    369     int N = A->getType()->getY();
    370     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strmv,
    371                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
    372                                 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
    373 }
    374 
    375 void ScriptIntrinsicBLAS::DTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
    376                                 sp<Allocation> A, sp<Allocation> X, int incX) {
    377     validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
    378     int N = A->getType()->getY();
    379     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrmv,
    380                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
    381                                 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
    382 }
    383 
    384 void ScriptIntrinsicBLAS::CTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
    385                                 sp<Allocation> A, sp<Allocation> X, int incX) {
    386     validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
    387     int N = A->getType()->getY();
    388     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrmv,
    389                                  TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
    390                                  A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
    391 }
    392 
    393 void ScriptIntrinsicBLAS::ZTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
    394                                 sp<Allocation> A, sp<Allocation> X, int incX) {
    395     validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
    396     int N = A->getType()->getY();
    397     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrmv,
    398                            TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
    399                            A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
    400 }
    401 
    402 void ScriptIntrinsicBLAS::STBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
    403                                 int K, sp<Allocation> A, sp<Allocation> X, int incX) {
    404     // TBMV has the same requirements as TRMV + K >= 0
    405     if (K < 0) {
    406         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
    407     }
    408     validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
    409     int N = A->getType()->getY();
    410     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stbmv,
    411                                 TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
    412                                 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
    413 }
    414 
    415 void ScriptIntrinsicBLAS::DTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
    416                                 int K, sp<Allocation> A, sp<Allocation> X, int incX) {
    417     // TBMV has the same requirements as TRMV + K >= 0
    418     if (K < 0) {
    419         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
    420     }
    421     validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
    422     int N = A->getType()->getY();
    423     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtbmv,
    424                                 TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
    425                                 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
    426 }
    427 
    428 void ScriptIntrinsicBLAS::CTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
    429                                 int K, sp<Allocation> A, sp<Allocation> X, int incX) {
    430     // TBMV has the same requirements as TRMV + K >= 0
    431     if (K < 0) {
    432         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
    433     }
    434     validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
    435     int N = A->getType()->getY();
    436     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctbmv,
    437                                  TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0,
    438                                  A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
    439 }
    440 
    441 void ScriptIntrinsicBLAS::ZTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
    442                                 int K, sp<Allocation> A, sp<Allocation> X, int incX) {
    443     // TBMV has the same requirements as TRMV + K >= 0
    444     if (K < 0) {
    445         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
    446     }
    447     validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
    448     int N = A->getType()->getY();
    449     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztbmv,
    450                            TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0,
    451                            A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
    452 }
    453 
    454 void ScriptIntrinsicBLAS::STPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
    455                                 sp<Allocation> Ap, sp<Allocation> X, int incX) {
    456     int N = validateTPMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, Ap, X, incX);
    457     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stpmv,
    458                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
    459                                 Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
    460 }
    461 
    462 void ScriptIntrinsicBLAS::DTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
    463                                 sp<Allocation> Ap, sp<Allocation> X, int incX) {
    464     int N = validateTPMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, Ap, X, incX);
    465     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtpmv,
    466                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
    467                                 Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
    468 }
    469 
    470 void ScriptIntrinsicBLAS::CTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
    471                                 sp<Allocation> Ap,  sp<Allocation> X,  int incX) {
    472     int N = validateTPMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
    473     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctpmv,
    474                                  TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
    475                                  Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
    476 }
    477 
    478 void ScriptIntrinsicBLAS::ZTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
    479                                 sp<Allocation> Ap, sp<Allocation> X, int incX) {
    480     int N = validateTPMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
    481     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztpmv,
    482                            TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
    483                            Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
    484 }
    485 
    486 void ScriptIntrinsicBLAS::STRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
    487                                 sp<Allocation> A, sp<Allocation> X, int incX) {
    488     // TRSV is the same as TRMV
    489     validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
    490     int N = A->getType()->getY();
    491     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strsv,
    492                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
    493                                 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
    494 }
    495 
    496 void ScriptIntrinsicBLAS::DTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
    497                                 sp<Allocation> A,  sp<Allocation> X,  int incX) {
    498     // TRSV is the same as TRMV
    499     validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
    500     int N = A->getType()->getY();
    501     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrsv,
    502                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
    503                                 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
    504 
    505 }
    506 
    507 void ScriptIntrinsicBLAS::CTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
    508                                 sp<Allocation> A, sp<Allocation> X, int incX) {
    509     // TRSV is the same as TRMV
    510     validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
    511     int N = A->getType()->getY();
    512     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrsv,
    513                                  TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
    514                                  A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
    515 
    516 }
    517 
    518 void ScriptIntrinsicBLAS::ZTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
    519                                 sp<Allocation> A, sp<Allocation> X, int incX) {
    520     // TRSV is the same as TRMV
    521     validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
    522     int N = A->getType()->getY();
    523     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrsv,
    524                            TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
    525                            A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
    526 
    527 }
    528 
    529 void ScriptIntrinsicBLAS::STBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
    530                                 int K, sp<Allocation> A, sp<Allocation> X, int incX) {
    531     // TBSV is the same as TRMV + K >= 0
    532     validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
    533     int N = A->getType()->getY();
    534     if (K < 0) {
    535         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive");
    536     }
    537     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stbsv,
    538                                 TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
    539                                 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
    540 }
    541 
    542 void ScriptIntrinsicBLAS::DTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
    543                                 int K, sp<Allocation> A, sp<Allocation> X, int incX) {
    544     // TBSV is the same as TRMV + K >= 0
    545     validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
    546     int N = A->getType()->getY();
    547     if (K < 0) {
    548         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive");
    549     }
    550     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtbsv,
    551                                 TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
    552                                 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
    553 }
    554 
    555 void ScriptIntrinsicBLAS::CTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
    556                                 int K, sp<Allocation> A, sp<Allocation> X, int incX) {
    557     // TBSV is the same as TRMV + K >= 0
    558     validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
    559     int N = A->getType()->getY();
    560     if (K < 0) {
    561         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive");
    562     }
    563     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctbsv,
    564                                  TransA, 0, 0, Uplo, Diag, 0, N, K,
    565                                  0, 0, A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
    566 }
    567 
    568 void ScriptIntrinsicBLAS::ZTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
    569                                 int K, sp<Allocation> A, sp<Allocation> X, int incX) {
    570     // TBSV is the same as TRMV + K >= 0
    571     validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
    572     int N = A->getType()->getY();
    573     if (K < 0) {
    574         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive");
    575     }
    576     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztbsv,
    577                            TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0,
    578                            A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
    579 }
    580 
    581 void ScriptIntrinsicBLAS::STPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
    582                                 sp<Allocation> Ap, sp<Allocation> X, int incX) {
    583     // TPSV is same as TPMV
    584     int N = validateTPMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, Ap, X, incX);
    585     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stpsv,
    586                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
    587                                 Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
    588 }
    589 
    590 void ScriptIntrinsicBLAS::DTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
    591                                 sp<Allocation> Ap, sp<Allocation> X, int incX) {
    592     // TPSV is same as TPMV
    593     int N = validateTPMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, Ap, X, incX);
    594     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtpsv,
    595                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
    596                                 Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
    597 }
    598 
    599 void ScriptIntrinsicBLAS::CTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
    600                                 sp<Allocation> Ap, sp<Allocation> X, int incX) {
    601     // TPSV is same as TPMV
    602     int N = validateTPMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
    603     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctpsv,
    604                                  TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
    605                                  Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
    606 }
    607 
    608 void ScriptIntrinsicBLAS::ZTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
    609                                 sp<Allocation> Ap, sp<Allocation> X, int incX) {
    610     // TPSV is same as TPMV
    611     int N = validateTPMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
    612     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztpsv,
    613                            TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
    614                            Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
    615 }
    616 
    617 /**
    618  * Level 2, S and D only
    619  */
    620 static int validateSYMV(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, sp<Allocation> A,
    621                         sp<Allocation> X, sp<Allocation> Y, int incX, int incY) {
    622     int N = A->getType()->getY();
    623     if ((int)A->getType()->getX() != N) {
    624         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a square matrix for SYMV");
    625     }
    626     if (!A->getType()->getElement()->isCompatible(e) ||
    627         !X->getType()->getElement()->isCompatible(e) ||
    628         !Y->getType()->getElement()->isCompatible(e) ) {
    629         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
    630     }
    631     if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
    632         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
    633     }
    634 
    635     if (incX <= 0 || incY <= 0) {
    636         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
    637     }
    638     int expectedXDim = 1 + (N - 1) * incX;
    639     if ((int)X->getType()->getX() != expectedXDim) {
    640         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYMV");
    641     }
    642     int expectedYDim = 1 + (N - 1) * incY;
    643     if ((int)Y->getType()->getX() != expectedYDim) {
    644         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYMV");
    645     }
    646     return N;
    647 }
    648 static int validateSPMV(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, sp<Allocation> Ap,
    649                         sp<Allocation> X, int incX, sp<Allocation> Y, int incY) {
    650     if (!Ap->getType()->getElement()->isCompatible(e) ||
    651         !X->getType()->getElement()->isCompatible(e) ||
    652         !Y->getType()->getElement()->isCompatible(e)) {
    653         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
    654     }
    655     if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
    656         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
    657     }
    658 
    659     if (Ap->getType()->getY() > 1) {
    660         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
    661     }
    662 
    663     int N = sqrt((double)Ap->getType()->getX() * 2);
    664     if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) {
    665         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
    666     }
    667     if (incX <= 0 || incY <= 0) {
    668         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
    669     }
    670     int expectedXDim = 1 + (N - 1) * incX;
    671     if ((int)X->getType()->getX() != expectedXDim) {
    672         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPMV");
    673     }
    674     int expectedYDim = 1 + (N - 1) * incY;
    675     if ((int)Y->getType()->getX() != expectedYDim) {
    676         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPMV");
    677     }
    678 
    679     return N;
    680 }
    681 static void validateGER(RS* mRS, sp<const Element> e, sp<Allocation> X, int incX,
    682                         sp<Allocation> Y, int incY, sp<Allocation> A) {
    683     if (!A->getType()->getElement()->isCompatible(e) ||
    684         !X->getType()->getElement()->isCompatible(e) ||
    685         !Y->getType()->getElement()->isCompatible(e) ) {
    686         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
    687     }
    688 
    689     if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
    690         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
    691     }
    692 
    693     int M = A->getType()->getY();
    694     int N = A->getType()->getX();
    695 
    696     if (N < 1 || M < 1) {
    697         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "M and N must be 1 or greater for GER");
    698     }
    699     if (incX <= 0 || incY <= 0) {
    700         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
    701     }
    702     int expectedXDim = 1 + (M - 1) * incX;
    703     if ((int)X->getType()->getX() != expectedXDim) {
    704         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GER");
    705     }
    706     int expectedYDim = 1 + (N - 1) * incY;
    707     if ((int)Y->getType()->getX() != expectedYDim) {
    708         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GER");
    709     }
    710 
    711 
    712 }
    713 static int validateSYR(RS* mRS, sp<const Element> e, RsBlasUplo Uplo,
    714                        sp<Allocation> X, int incX, sp<Allocation> A) {
    715     if (!A->getType()->getElement()->isCompatible(e) ||
    716         !X->getType()->getElement()->isCompatible(e)) {
    717         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
    718     }
    719 
    720     int N = A->getType()->getX();
    721 
    722     if (X->getType()->getY() > 1) {
    723         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
    724     }
    725     if (N != (int)A->getType()->getY()) {
    726         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a symmetric matrix");
    727     }
    728     if (incX <= 0) {
    729         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
    730     }
    731     int expectedXDim = 1 + (N - 1) * incX;
    732     if ((int)X->getType()->getX() != expectedXDim) {
    733         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYR");
    734     }
    735     return N;
    736 }
    737 static int validateSPR(RS* mRS, sp<const Element> e, RsBlasUplo Uplo,
    738                        sp<Allocation> X, int incX, sp<Allocation> Ap) {
    739     if (!Ap->getType()->getElement()->isCompatible(e) ||
    740         !X->getType()->getElement()->isCompatible(e)) {
    741         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
    742     }
    743     if (X->getType()->getY() > 1) {
    744         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
    745     }
    746 
    747     if (Ap->getType()->getY() > 1) {
    748         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
    749     }
    750 
    751     int N = sqrt((double)Ap->getType()->getX() * 2);
    752     if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) {
    753         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
    754     }
    755     if (incX <= 0) {
    756         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
    757     }
    758     int expectedXDim = 1 + (N - 1) * incX;
    759     if ((int)X->getType()->getX() != expectedXDim) {
    760         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPR");
    761     }
    762 
    763     return N;
    764 }
    765 
    766 static int validateSYR2(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, sp<Allocation> X,
    767                         int incX, sp<Allocation> Y, int incY, sp<Allocation> A) {
    768     if (!A->getType()->getElement()->isCompatible(e) ||
    769         !X->getType()->getElement()->isCompatible(e) ||
    770         !Y->getType()->getElement()->isCompatible(e)) {
    771         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
    772     }
    773 
    774     if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
    775         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
    776     }
    777 
    778     int N = A->getType()->getX();
    779 
    780     if (N != (int)A->getType()->getY()) {
    781         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a symmetric matrix");
    782     }
    783     if (incX <= 0 || incY <= 0) {
    784         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
    785     }
    786     int expectedXDim = 1 + (N - 1) * incX;
    787     int expectedYDim = 1 + (N - 1) * incY;
    788     if ((int)X->getType()->getX() != expectedXDim || (int)Y->getType()->getX() != expectedYDim) {
    789         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYR");
    790     }
    791     return N;
    792 
    793 }
    794 static int validateSPR2(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, sp<Allocation> X,
    795                         int incX, sp<Allocation> Y, int incY, sp<Allocation> Ap) {
    796     if (!Ap->getType()->getElement()->isCompatible(e) ||
    797         !X->getType()->getElement()->isCompatible(e) ||
    798         !Y->getType()->getElement()->isCompatible(e)) {
    799         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
    800     }
    801     if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
    802         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
    803     }
    804 
    805     if (Ap->getType()->getY() > 1) {
    806         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
    807     }
    808 
    809     int N = sqrt((double)Ap->getType()->getX() * 2);
    810     if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) {
    811         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
    812     }
    813     if (incX <= 0 || incY <= 0) {
    814         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
    815     }
    816     int expectedXDim = 1 + (N - 1) * incX;
    817     int expectedYDim = 1 + (N - 1) * incY;
    818     if ((int)X->getType()->getX() != expectedXDim || (int)Y->getType()->getX() != expectedYDim) {
    819         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPR2");
    820     }
    821 
    822     return N;
    823 }
    824 
    825 void ScriptIntrinsicBLAS::SSYMV(RsBlasUplo Uplo, float alpha, sp<Allocation> A, sp<Allocation> X,
    826                                 int incX, float beta, sp<Allocation> Y, int incY) {
    827     int N = validateSYMV(mRS, Element::F32(mRS), Uplo, A, X, Y, incX, incY);
    828     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssymv,
    829                                 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
    830                                 A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
    831 }
    832 
    833 void ScriptIntrinsicBLAS::SSBMV(RsBlasUplo Uplo, int K, float alpha, sp<Allocation> A, sp<Allocation> X,
    834                                 int incX, float beta, sp<Allocation> Y, int incY) {
    835     // SBMV is the same as SYMV + K >= 0
    836     if (K < 0) {
    837         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
    838     }
    839     int N = validateSYMV(mRS, Element::F32(mRS), Uplo, A, X, Y, incX, incY);
    840     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssbmv,
    841                                 0, 0, 0, Uplo, 0, 0, N, K, alpha,
    842                                 A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
    843 }
    844 
    845 void ScriptIntrinsicBLAS::SSPMV(RsBlasUplo Uplo, float alpha, sp<Allocation> Ap, sp<Allocation> X,
    846                                 int incX, float beta, sp<Allocation> Y, int incY) {
    847     int N = validateSPMV(mRS, Element::F32(mRS), Uplo, Ap, X, incX, Y, incY);
    848     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspmv,
    849                                 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
    850                                 Ap->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
    851 }
    852 
    853 void ScriptIntrinsicBLAS::SGER(float alpha, sp<Allocation> X, int incX,
    854                                sp<Allocation> Y, int incY, sp<Allocation> A) {
    855     int M = A->getType()->getY();
    856     int N = A->getType()->getX();
    857     validateGER(mRS, Element::F32(mRS), X, incX, Y, incY, A);
    858     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sger,
    859                                 0, 0, 0, 0, 0, M, N, 0, alpha,
    860                                 X->getID(), Y->getID(), 0.f, A->getID(), incX, incY, 0, 0);
    861 }
    862 
    863 void ScriptIntrinsicBLAS::SSYR(RsBlasUplo Uplo, float alpha, sp<Allocation> X,
    864                                int incX, sp<Allocation> A) {
    865     int N = validateSYR(mRS, Element::F32(mRS), Uplo, X, incX, A);
    866     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr,
    867                                 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
    868                                 X->getID(), A->getID(), 0.f, 0, incX, 0, 0, 0);
    869 }
    870 
    871 void ScriptIntrinsicBLAS::SSPR(RsBlasUplo Uplo, float alpha, sp<Allocation> X,
    872                                int incX, sp<Allocation> Ap) {
    873     int N = validateSPR(mRS, Element::F32(mRS), Uplo, X, incX, Ap);
    874     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspr,
    875                                 0, 0, 0, Uplo, 0, 0, N, 0,
    876                                 alpha, X->getID(), Ap->getID(), 0.f, 0, incX, 0, 0, 0);
    877 }
    878 
    879 void ScriptIntrinsicBLAS::SSYR2(RsBlasUplo Uplo, float alpha, sp<Allocation> X, int incX,
    880                                 sp<Allocation> Y, int incY, sp<Allocation> A) {
    881     int N = validateSYR2(mRS, Element::F32(mRS), Uplo, X, incX, Y, incY, A);
    882     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr2,
    883                                 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
    884                                 X->getID(), Y->getID(), 0, A->getID(), incX, incY, 0, 0);
    885 }
    886 
    887 void ScriptIntrinsicBLAS::SSPR2(RsBlasUplo Uplo, float alpha, sp<Allocation> X, int incX,
    888                                 sp<Allocation> Y, int incY, sp<Allocation> Ap) {
    889     int N = validateSPR2(mRS, Element::F32(mRS), Uplo, X, incX, Y, incY, Ap);
    890     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspr2,
    891                                 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
    892                                 X->getID(), Y->getID(), 0, Ap->getID(), incX, incY, 0, 0);
    893 }
    894 
    895 void ScriptIntrinsicBLAS::DSYMV(RsBlasUplo Uplo, double alpha, sp<Allocation> A, sp<Allocation> X,
    896                                 int incX, double beta, sp<Allocation> Y, int incY) {
    897     int N = validateSYMV(mRS, Element::F64(mRS), Uplo, A, X, Y, incX, incY);
    898     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsymv,
    899                                 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
    900                                 A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
    901 }
    902 
    903 void ScriptIntrinsicBLAS::DSBMV(RsBlasUplo Uplo, int K, double alpha, sp<Allocation> A, sp<Allocation> X,
    904                                 int incX, double beta, sp<Allocation> Y, int incY) {
    905     // SBMV is the same as SYMV + K >= 0
    906     if (K < 0) {
    907         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
    908     }
    909     int N = validateSYMV(mRS, Element::F64(mRS), Uplo, A, X, Y, incX, incY);
    910     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsbmv,
    911                                 0, 0, 0, Uplo, 0, 0, N, K, alpha,
    912                                 A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
    913 }
    914 
    915 void ScriptIntrinsicBLAS::DSPMV(RsBlasUplo Uplo, double alpha, sp<Allocation> Ap, sp<Allocation> X,
    916                                 int incX, double beta, sp<Allocation> Y, int incY) {
    917     int N = validateSPMV(mRS, Element::F64(mRS), Uplo, Ap, X, incX, Y, incY);
    918     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspmv,
    919                                 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
    920                                 Ap->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
    921 }
    922 
    923 void ScriptIntrinsicBLAS::DGER(double alpha, sp<Allocation> X, int incX, sp<Allocation> Y,
    924                                int incY, sp<Allocation> A) {
    925     int M = A->getType()->getY();
    926     int N = A->getType()->getX();
    927     validateGER(mRS, Element::F64(mRS), X, incX, Y, incY, A);
    928     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dger,
    929                                 0, 0, 0, 0, 0, M, N, 0, alpha,
    930                                 X->getID(), Y->getID(), 0.f, A->getID(), incX, incY, 0, 0);
    931 }
    932 
    933 void ScriptIntrinsicBLAS::DSYR(RsBlasUplo Uplo, double alpha, sp<Allocation> X,
    934                                int incX, sp<Allocation> A) {
    935     int N = validateSYR(mRS, Element::F64(mRS), Uplo, X, incX, A);
    936     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr,
    937                                 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
    938                                 X->getID(), A->getID(), 0.f, 0, incX, 0, 0, 0);
    939 }
    940 
    941 void ScriptIntrinsicBLAS::DSPR(RsBlasUplo Uplo, double alpha, sp<Allocation> X,
    942                                int incX, sp<Allocation> Ap) {
    943     int N = validateSPR(mRS, Element::F64(mRS), Uplo, X, incX, Ap);
    944     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspr,
    945                                 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
    946                                 X->getID(), Ap->getID(), 0.f, 0, incX, 0, 0, 0);
    947 }
    948 
    949 void ScriptIntrinsicBLAS::DSYR2(RsBlasUplo Uplo, double alpha, sp<Allocation> X, int incX,
    950                                 sp<Allocation> Y, int incY, sp<Allocation> A) {
    951     int N = validateSYR2(mRS, Element::F64(mRS), Uplo, X, incX, Y, incY, A);
    952     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr2,
    953                                 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
    954                                 X->getID(), Y->getID(), 0, A->getID(), incX, incY, 0, 0);
    955 }
    956 
    957 void ScriptIntrinsicBLAS::DSPR2(RsBlasUplo Uplo, double alpha, sp<Allocation> X, int incX,
    958                                 sp<Allocation> Y, int incY, sp<Allocation> Ap) {
    959     int N = validateSPR2(mRS, Element::F64(mRS), Uplo, X, incX, Y, incY, Ap);
    960     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspr2,
    961                                 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
    962                                 X->getID(), Y->getID(), 0, Ap->getID(), incX, incY, 0, 0);
    963 }
    964 
    965 
    966 /**
    967  * Level 2, C and Z only
    968  */
    969 
    970 static void validateGERU(RS* mRS, sp<const Element> e, sp<Allocation> X, int incX,
    971                          sp<Allocation> Y, int incY, sp<Allocation> A) {
    972     if (!A->getType()->getElement()->isCompatible(e) ||
    973         !X->getType()->getElement()->isCompatible(e) ||
    974         !Y->getType()->getElement()->isCompatible(e)) {
    975         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
    976     }
    977     if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
    978         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
    979     }
    980 
    981     int M = A->getType()->getY();
    982     int N = A->getType()->getX();
    983     if (incX <= 0 || incY <= 0) {
    984         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
    985     }
    986     int expectedXDim = 1 + (M - 1) * incX;
    987     if ((int)X->getType()->getX() != expectedXDim) {
    988         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GERU");
    989     }
    990     int expectedYDim = 1 + (N - 1) * incY;
    991     if ((int)Y->getType()->getX() != expectedYDim) {
    992         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GERU");
    993     }
    994 
    995 }
    996 
    997 void ScriptIntrinsicBLAS::CHEMV(RsBlasUplo Uplo, Float2 alpha, sp<Allocation> A,
    998                                 sp<Allocation> X, int incX, Float2 beta, sp<Allocation> Y, int incY) {
    999     // HEMV is the same as SYR2 validation-wise
   1000     int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A);
   1001     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chemv,
   1002                                  0, 0, 0, Uplo, 0, 0, N, 0,
   1003                                  alpha.x, alpha.y, A->getID(), X->getID(),
   1004                                  beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
   1005 }
   1006 
   1007 void ScriptIntrinsicBLAS::CHBMV(RsBlasUplo Uplo, int K, Float2 alpha, sp<Allocation> A,
   1008                                 sp<Allocation> X, int incX, Float2 beta, sp<Allocation> Y, int incY) {
   1009     // HBMV is the same as SYR2 validation-wise
   1010     int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A);
   1011     if (K < 0) {
   1012         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be 0 or greater for HBMV");
   1013     }
   1014     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chbmv,
   1015                                  0, 0, 0, Uplo, 0, 0, N, K,
   1016                                  alpha.x, alpha.y, A->getID(), X->getID(),
   1017                                  beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
   1018 }
   1019 
   1020 void ScriptIntrinsicBLAS::CHPMV(RsBlasUplo Uplo, Float2 alpha, sp<Allocation> Ap,
   1021                                 sp<Allocation> X, int incX, Float2 beta, sp<Allocation> Y, int incY) {
   1022     // HPMV is the same as SPR2
   1023     int N = validateSPR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, Ap);
   1024     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpmv,
   1025                                  0, 0, 0, Uplo, 0, 0, N, 0,
   1026                                  alpha.x, alpha.y, Ap->getID(), X->getID(),
   1027                                  beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
   1028 }
   1029 
   1030 void ScriptIntrinsicBLAS::CGERU(Float2 alpha, sp<Allocation> X, int incX,
   1031                                 sp<Allocation> Y, int incY, sp<Allocation> A) {
   1032     validateGERU(mRS, Element::F32_2(mRS), X, incX, Y, incY, A);
   1033     int M = A->getType()->getY();
   1034     int N = A->getType()->getX();
   1035     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgeru,
   1036                                  0, 0, 0, 0, 0, M, N, 0,
   1037                                  alpha.x, alpha.y, X->getID(), Y->getID(),
   1038                                  0, 0, A->getID(), incX, incY, 0, 0);
   1039 }
   1040 
   1041 void ScriptIntrinsicBLAS::CGERC(Float2 alpha, sp<Allocation> X, int incX,
   1042                                 sp<Allocation> Y, int incY, sp<Allocation> A) {
   1043     // Same as GERU
   1044     validateGERU(mRS, Element::F32_2(mRS), X, incX, Y, incY, A);
   1045     int M = A->getType()->getY();
   1046     int N = A->getType()->getX();
   1047     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgerc,
   1048                                  0, 0, 0, 0, 0, M, N, 0,
   1049                                  alpha.x, alpha.y, X->getID(), Y->getID(),
   1050                                  0, 0, A->getID(), incX, incY, 0, 0);
   1051 }
   1052 
   1053 void ScriptIntrinsicBLAS::CHER(RsBlasUplo Uplo, float alpha, sp<Allocation> X,
   1054                                int incX, sp<Allocation> A) {
   1055     // Same as SYR
   1056     int N = validateSYR(mRS, Element::F32_2(mRS), Uplo, X, incX, A);
   1057     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher,
   1058                                  0, 0, 0, Uplo, 0, 0, N, 0,
   1059                                  alpha, 0, X->getID(), 0,
   1060                                  0, 0, A->getID(), incX, 0, 0, 0);
   1061 }
   1062 
   1063 void ScriptIntrinsicBLAS::CHPR(RsBlasUplo Uplo, float alpha, sp<Allocation> X,
   1064                                int incX, sp<Allocation> Ap) {
   1065     // Equivalent to SPR for validation
   1066     int N = validateSPR(mRS, Element::F32_2(mRS), Uplo, X, incX, Ap);
   1067     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpr,
   1068                                  0, 0, 0, Uplo, 0, 0, N, 0,
   1069                                  alpha, 0, X->getID(), 0,
   1070                                  0, 0, Ap->getID(), incX, 0, 0, 0);
   1071 }
   1072 
   1073 void ScriptIntrinsicBLAS::CHER2(RsBlasUplo Uplo, Float2 alpha, sp<Allocation> X, int incX,
   1074                                 sp<Allocation> Y, int incY, sp<Allocation> A) {
   1075     // Same as SYR2
   1076     int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A);
   1077     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher2,
   1078                                  0, 0, 0, Uplo, 0, 0, N, 0,
   1079                                  alpha.x, alpha.y, X->getID(), Y->getID(),
   1080                                  0, 0, A->getID(), incX, incY, 0, 0);
   1081 }
   1082 
   1083 void ScriptIntrinsicBLAS::CHPR2(RsBlasUplo Uplo, Float2 alpha, sp<Allocation> X, int incX,
   1084                                 sp<Allocation> Y, int incY, sp<Allocation> Ap) {
   1085     // Same as SPR2
   1086     int N = validateSPR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, Ap);
   1087     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpr2,
   1088                                  0, 0, 0, Uplo, 0, 0, N, 0,
   1089                                  alpha.x, alpha.y, X->getID(), Y->getID(),
   1090                                  0, 0, Ap->getID(), incX, incY, 0, 0);
   1091 }
   1092 
   1093 void ScriptIntrinsicBLAS::ZHEMV(RsBlasUplo Uplo, Double2 alpha, sp<Allocation> A,
   1094                                 sp<Allocation> X, int incX, Double2 beta, sp<Allocation> Y, int incY) {
   1095     // HEMV is the same as SYR2 validation-wise
   1096     int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A);
   1097     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhemv,
   1098                            0, 0, 0, Uplo, 0, 0, N, 0,
   1099                            alpha.x, alpha.y, A->getID(), X->getID(),
   1100                            beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
   1101 }
   1102 
   1103 void ScriptIntrinsicBLAS::ZHBMV(RsBlasUplo Uplo, int K, Double2 alpha, sp<Allocation> A, sp<Allocation> X,
   1104                                 int incX, Double2 beta, sp<Allocation> Y, int incY) {
   1105     // HBMV is the same as SYR2 validation-wise
   1106     int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A);
   1107     if (K < 0) {
   1108         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be 0 or greater for HBMV");
   1109     }
   1110     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhbmv,
   1111                            0, 0, 0, Uplo, 0, 0, N, K,
   1112                            alpha.x, alpha.y, A->getID(), X->getID(),
   1113                            beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
   1114 }
   1115 
   1116 void ScriptIntrinsicBLAS::ZHPMV(RsBlasUplo Uplo, Double2 alpha, sp<Allocation> Ap, sp<Allocation> X,
   1117                                 int incX, Double2 beta, sp<Allocation> Y, int incY) {
   1118     // HPMV is the same as SPR2
   1119     int N = validateSPR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, Ap);
   1120     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpmv,
   1121                            0, 0, 0, Uplo, 0, 0, N, 0,
   1122                            alpha.x, alpha.y, Ap->getID(), X->getID(),
   1123                            beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
   1124 }
   1125 
   1126 void ScriptIntrinsicBLAS::ZGERU(Double2 alpha, sp<Allocation> X, int incX,
   1127                                 sp<Allocation> Y, int incY, sp<Allocation> A) {
   1128     validateGERU(mRS, Element::F64_2(mRS), X, incX, Y, incY, A);
   1129     int M = A->getType()->getY();
   1130     int N = A->getType()->getX();
   1131     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgeru,
   1132                            0, 0, 0, 0, 0, M, N, 0,
   1133                            alpha.x, alpha.y, X->getID(), Y->getID(),
   1134                            0, 0, A->getID(), incX, incY, 0, 0);
   1135 }
   1136 
   1137 void ScriptIntrinsicBLAS::ZGERC(Double2 alpha, sp<Allocation> X, int incX,
   1138                                 sp<Allocation> Y, int incY, sp<Allocation> A) {
   1139     // Same as GERU
   1140     validateGERU(mRS, Element::F64_2(mRS), X, incX, Y, incY, A);
   1141     int M = A->getType()->getY();
   1142     int N = A->getType()->getX();
   1143     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgerc,
   1144                            0, 0, 0, 0, 0, M, N, 0,
   1145                            alpha.x, alpha.y, X->getID(), Y->getID(),
   1146                            0, 0, A->getID(), incX, incY, 0, 0);
   1147 }
   1148 
   1149 void ScriptIntrinsicBLAS::ZHER(RsBlasUplo Uplo, double alpha, sp<Allocation> X,
   1150                                int incX, sp<Allocation> A) {
   1151     // Same as SYR
   1152     int N = validateSYR(mRS, Element::F64_2(mRS), Uplo, X, incX, A);
   1153     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher,
   1154                            0, 0, 0, Uplo, 0, 0, N, 0,
   1155                            alpha, 0, X->getID(), 0,
   1156                            0, 0, A->getID(), incX, 0, 0, 0);
   1157 }
   1158 
   1159 void ScriptIntrinsicBLAS::ZHPR(RsBlasUplo Uplo, double alpha, sp<Allocation> X,
   1160                                int incX, sp<Allocation> Ap) {
   1161     // Equivalent to SPR for validation
   1162     int N = validateSPR(mRS, Element::F64_2(mRS), Uplo, X, incX, Ap);
   1163     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpr,
   1164                            0, 0, 0, Uplo, 0, 0, N, 0,
   1165                            alpha, 0, X->getID(), 0,
   1166                            0, 0, Ap->getID(), incX, 0, 0, 0);
   1167 }
   1168 
   1169 void ScriptIntrinsicBLAS::ZHER2(RsBlasUplo Uplo, Double2 alpha, sp<Allocation> X, int incX,
   1170                                 sp<Allocation> Y, int incY, sp<Allocation> A) {
   1171     // Same as SYR2
   1172     int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A);
   1173     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher2,
   1174                            0, 0, 0, Uplo, 0, 0, N, 0,
   1175                            alpha.x, alpha.y, X->getID(), Y->getID(),
   1176                            0, 0, A->getID(), incX, incY, 0, 0);
   1177 }
   1178 
   1179 void ScriptIntrinsicBLAS::ZHPR2(RsBlasUplo Uplo, Double2 alpha, sp<Allocation> X, int incX,
   1180                                 sp<Allocation> Y, int incY, sp<Allocation> Ap) {
   1181     // Same as SPR2
   1182     int N = validateSPR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, Ap);
   1183     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpr2,
   1184                            0, 0, 0, Uplo, 0, 0, N, 0,
   1185                            alpha.x, alpha.y, X->getID(), Y->getID(),
   1186                            0, 0, Ap->getID(), incX, incY, 0, 0);
   1187 }
   1188 
   1189 
   1190 /**
   1191  * Level 3 BLAS
   1192  */
   1193 
   1194 static void validateL3(RS* mRS, sp<const Element> e, int TransA, int TransB, int Side,
   1195                        sp<Allocation> A, sp<Allocation> B, sp<Allocation> C) {
   1196     int aM = -1, aN = -1, bM = -1, bN = -1, cM = -1, cN = -1;
   1197     if ((A != nullptr && !A->getType()->getElement()->isCompatible(e)) ||
   1198         (B != nullptr && !B->getType()->getElement()->isCompatible(e)) ||
   1199         (C != nullptr && !C->getType()->getElement()->isCompatible(e))) {
   1200         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
   1201     }
   1202     if (C == nullptr) {
   1203         // Since matrix C is used to store the result, it cannot be null.
   1204         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Allocation C cannot be null");
   1205     }
   1206     cM = C->getType()->getY();
   1207     cN = C->getType()->getX();
   1208 
   1209     if (Side == RsBlasRight) {
   1210         if ((A == nullptr && B != nullptr) || (A != nullptr && B == nullptr)) {
   1211             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Provided Matrix A without Matrix B, or vice versa");
   1212         }
   1213         if (B != nullptr) {
   1214             bM = A->getType()->getY();
   1215             bN = A->getType()->getX();
   1216         }
   1217         if (A != nullptr) {
   1218             aM = B->getType()->getY();
   1219             aN = B->getType()->getX();
   1220         }
   1221     } else {
   1222         if (A != nullptr) {
   1223             if (TransA == RsBlasTrans || TransA == RsBlasConjTrans) {
   1224                 aN = A->getType()->getY();
   1225                 aM = A->getType()->getX();
   1226             } else {
   1227                 aM = A->getType()->getY();
   1228                 aN = A->getType()->getX();
   1229             }
   1230         }
   1231         if (B != nullptr) {
   1232             if (TransB == RsBlasTrans || TransB == RsBlasConjTrans) {
   1233                 bN = B->getType()->getY();
   1234                 bM = B->getType()->getX();
   1235             } else {
   1236                 bM = B->getType()->getY();
   1237                 bN = B->getType()->getX();
   1238             }
   1239         }
   1240     }
   1241     if (A != nullptr && B != nullptr && C != nullptr) {
   1242         if (aN != bM || aM != cM || bN != cN) {
   1243             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions");
   1244         }
   1245     } else if (A != nullptr && C != nullptr) {
   1246         // A and C only, for SYRK
   1247         if (cM != cN) {
   1248             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix C is not symmetric");
   1249         }
   1250         if (aM != cM) {
   1251             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions");
   1252         }
   1253     } else if (A != nullptr && B != nullptr) {
   1254         // A and B only
   1255         if (aN != bM) {
   1256             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions");
   1257         }
   1258     }
   1259 
   1260 }
   1261 
   1262 void ScriptIntrinsicBLAS::SGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, float alpha,
   1263                                 sp<Allocation> A, sp<Allocation> B, float beta, sp<Allocation> C) {
   1264     validateL3(mRS, Element::F32(mRS), TransA, TransB, 0, A, B, C);
   1265 
   1266     int M = -1, N = -1, K = -1;
   1267     if (TransA != RsBlasNoTrans) {
   1268         M = A->getType()->getX();
   1269         K = A->getType()->getY();
   1270     } else {
   1271         M = A->getType()->getY();
   1272         K = A->getType()->getX();
   1273     }
   1274     if (TransB != RsBlasNoTrans) {
   1275         N = B->getType()->getY();
   1276     } else {
   1277         N = B->getType()->getX();
   1278     }
   1279     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgemm,
   1280                                 TransA, TransB, 0, 0, 0, M, N, K,
   1281                                 alpha, A->getID(), B->getID(),
   1282                                 beta, C->getID(), 0, 0, 0, 0);
   1283 }
   1284 
   1285 void ScriptIntrinsicBLAS::DGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, double alpha,
   1286                                 sp<Allocation> A, sp<Allocation> B, double beta, sp<Allocation> C) {
   1287     validateL3(mRS, Element::F64(mRS), TransA, TransB, 0, A, B, C);
   1288     int M = -1, N = -1, K = -1;
   1289     if (TransA != RsBlasNoTrans) {
   1290         M = A->getType()->getX();
   1291         K = A->getType()->getY();
   1292     } else {
   1293         M = A->getType()->getY();
   1294         K = A->getType()->getX();
   1295     }
   1296     if (TransB != RsBlasNoTrans) {
   1297         N = B->getType()->getY();
   1298     } else {
   1299         N = B->getType()->getX();
   1300     }
   1301     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgemm,
   1302                                 TransA, TransB, 0, 0, 0, M, N, K,
   1303                                 alpha, A->getID(), B->getID(),
   1304                                 beta, C->getID(), 0, 0, 0, 0);
   1305 }
   1306 
   1307 void ScriptIntrinsicBLAS::CGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, Float2 alpha,
   1308                                 sp<Allocation> A, sp<Allocation> B, Float2 beta, sp<Allocation> C) {
   1309     validateL3(mRS, Element::F32_2(mRS), TransA, TransB, 0, A, B, C);
   1310     int M = -1, N = -1, K = -1;
   1311     if (TransA != RsBlasNoTrans) {
   1312         M = A->getType()->getX();
   1313         K = A->getType()->getY();
   1314     } else {
   1315         M = A->getType()->getY();
   1316         K = A->getType()->getX();
   1317     }
   1318     if (TransB != RsBlasNoTrans) {
   1319         N = B->getType()->getY();
   1320     } else {
   1321         N = B->getType()->getX();
   1322     }
   1323     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgemm,
   1324                                  TransA, TransB, 0, 0, 0, M, N, K,
   1325                                  alpha.x, alpha.y, A->getID(), B->getID(),
   1326                                  beta.x, beta.y, C->getID(), 0, 0, 0, 0);
   1327 }
   1328 
   1329 void ScriptIntrinsicBLAS::ZGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, Double2 alpha,
   1330                                 sp<Allocation> A, sp<Allocation> B, Double2 beta, sp<Allocation> C) {
   1331     validateL3(mRS, Element::F64_2(mRS), TransA, TransB, 0, A, B, C);
   1332     int M = -1, N = -1, K = -1;
   1333     if (TransA != RsBlasNoTrans) {
   1334         M = A->getType()->getX();
   1335         K = A->getType()->getY();
   1336     } else {
   1337         M = A->getType()->getY();
   1338         K = A->getType()->getX();
   1339     }
   1340     if (TransB != RsBlasNoTrans) {
   1341         N = B->getType()->getY();
   1342     } else {
   1343         N = B->getType()->getX();
   1344     }
   1345     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgemm,
   1346                            TransA, TransB, 0, 0, 0, M, N, K,
   1347                            alpha.x, alpha.y, A->getID(), B->getID(),
   1348                            beta.x, beta.y, C->getID(), 0, 0, 0, 0);
   1349 }
   1350 
   1351 void ScriptIntrinsicBLAS::SSYMM(RsBlasSide Side, RsBlasUplo Uplo, float alpha,
   1352                                 sp<Allocation> A, sp<Allocation> B, float beta, sp<Allocation> C) {
   1353     //For SYMM, Matrix A should be symmetric
   1354     if (A->getType()->getX() != A->getType()->getY()) {
   1355         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric");
   1356     }
   1357     validateL3(mRS, Element::F32(mRS), 0, 0, Side, A, B, C);
   1358     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssymm,
   1359                                 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0,
   1360                                 alpha, A->getID(), B->getID(),
   1361                                 beta, C->getID(), 0, 0, 0, 0);
   1362 }
   1363 
   1364 void ScriptIntrinsicBLAS::DSYMM(RsBlasSide Side, RsBlasUplo Uplo, double alpha,
   1365                                 sp<Allocation> A, sp<Allocation> B, double beta, sp<Allocation> C) {
   1366     if (A->getType()->getX() != A->getType()->getY()) {
   1367         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric");
   1368     }
   1369     validateL3(mRS, Element::F64(mRS), 0, 0, Side, A, B, C);
   1370     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsymm,
   1371                                 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0,
   1372                                 alpha, A->getID(), B->getID(),
   1373                                 beta, C->getID(), 0, 0, 0, 0);
   1374 }
   1375 
   1376 void ScriptIntrinsicBLAS::CSYMM(RsBlasSide Side, RsBlasUplo Uplo, Float2 alpha,
   1377                                 sp<Allocation> A, sp<Allocation> B, Float2 beta, sp<Allocation> C) {
   1378     if (A->getType()->getX() != A->getType()->getY()) {
   1379         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric");
   1380     }
   1381     validateL3(mRS, Element::F32_2(mRS), 0, 0, Side, A, B, C);
   1382     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csymm,
   1383                                  0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0,
   1384                                  alpha.x, alpha.y, A->getID(), B->getID(),
   1385                                  beta.x, beta.y, C->getID(), 0, 0, 0, 0);
   1386 }
   1387 
   1388 void ScriptIntrinsicBLAS::ZSYMM(RsBlasSide Side, RsBlasUplo Uplo, Double2 alpha,
   1389                                 sp<Allocation> A, sp<Allocation> B, Double2 beta, sp<Allocation> C) {
   1390     if (A->getType()->getX() != A->getType()->getY()) {
   1391         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric");
   1392     }
   1393     validateL3(mRS, Element::F64_2(mRS), 0, 0, Side, A, B, C);
   1394     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsymm,
   1395                            0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0,
   1396                            alpha.x, alpha.y, A->getID(), B->getID(),
   1397                            beta.x, beta.y, C->getID(), 0, 0, 0, 0);
   1398 }
   1399 
   1400 void ScriptIntrinsicBLAS::SSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha,
   1401                                 sp<Allocation> A, float beta, sp<Allocation> C) {
   1402     validateL3(mRS, Element::F32(mRS), Trans, 0, 0, A, nullptr, C);
   1403     int K = -1;
   1404     if (Trans != RsBlasNoTrans) {
   1405         K = A->getType()->getY();
   1406     } else {
   1407         K = A->getType()->getX();
   1408     }
   1409     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyrk,
   1410                                 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
   1411                                 alpha, A->getID(), 0,
   1412                                 beta, C->getID(), 0, 0, 0, 0);
   1413 }
   1414 
   1415 void ScriptIntrinsicBLAS::DSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha,
   1416                                 sp<Allocation> A, double beta, sp<Allocation> C) {
   1417     validateL3(mRS, Element::F64(mRS), Trans, 0, 0, A, nullptr, C);
   1418     int K = -1;
   1419     if (Trans != RsBlasNoTrans) {
   1420         K = A->getType()->getY();
   1421     } else {
   1422         K = A->getType()->getX();
   1423     }
   1424     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyrk,
   1425                                 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
   1426                                 alpha, A->getID(), 0,
   1427                                 beta, C->getID(), 0, 0, 0, 0);
   1428 }
   1429 
   1430 void ScriptIntrinsicBLAS::CSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha,
   1431                                 sp<Allocation> A, Float2 beta, sp<Allocation> C) {
   1432     validateL3(mRS, Element::F32_2(mRS), Trans, 0, 0, A, nullptr, C);
   1433     int K = -1;
   1434     if (Trans != RsBlasNoTrans) {
   1435         K = A->getType()->getY();
   1436     } else {
   1437         K = A->getType()->getX();
   1438     }
   1439     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csyrk,
   1440                                  Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
   1441                                  alpha.x, alpha.y, A->getID(), 0,
   1442                                  beta.x, beta.y, C->getID(), 0, 0, 0, 0);
   1443 }
   1444 
   1445 void ScriptIntrinsicBLAS::ZSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha,
   1446                                 sp<Allocation> A, Double2 beta, sp<Allocation> C) {
   1447     validateL3(mRS, Element::F64_2(mRS), Trans, 0, 0, A, nullptr, C);
   1448     int K = -1;
   1449     if (Trans != RsBlasNoTrans) {
   1450         K = A->getType()->getY();
   1451     } else {
   1452         K = A->getType()->getX();
   1453     }
   1454     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsyrk,
   1455                            Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
   1456                            alpha.x, alpha.y, A->getID(), 0,
   1457                            beta.x, beta.y, C->getID(), 0, 0, 0, 0);
   1458 }
   1459 
   1460 static void validateSYR2K(RS* mRS, sp<const Element> e, RsBlasTranspose Trans,
   1461                           sp<Allocation> A, sp<Allocation> B, sp<Allocation> C) {
   1462     if (!A->getType()->getElement()->isCompatible(e) ||
   1463         !B->getType()->getElement()->isCompatible(e) ||
   1464         !C->getType()->getElement()->isCompatible(e)) {
   1465         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
   1466     }
   1467     int Cdim = -1;
   1468     // A is n x k if no transpose, k x n if transpose
   1469     // C is n x n
   1470     if (Trans == RsBlasTrans) {
   1471         // check columns versus C
   1472         Cdim = A->getType()->getX();
   1473     } else {
   1474         // check rows versus C
   1475         Cdim = A->getType()->getY();
   1476     }
   1477     if ((int)C->getType()->getX() != Cdim || (int)C->getType()->getY() != Cdim) {
   1478         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid symmetric matrix in SYR2K");
   1479     }
   1480     // A dims == B dims
   1481     if (A->getType()->getX() != B->getType()->getX() || A->getType()->getY() != B->getType()->getY()) {
   1482         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid A and B in SYR2K");
   1483     }
   1484 }
   1485 
   1486 void ScriptIntrinsicBLAS::SSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha,
   1487                                  sp<Allocation> A, sp<Allocation> B, float beta, sp<Allocation> C) {
   1488     validateSYR2K(mRS, Element::F32(mRS), Trans, A, B, C);
   1489     int K = -1;
   1490     if (Trans != RsBlasNoTrans) {
   1491         K = A->getType()->getY();
   1492     } else {
   1493         K = A->getType()->getX();
   1494     }
   1495     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr2k,
   1496                                 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
   1497                                 alpha, A->getID(), B->getID(),
   1498                                 beta, C->getID(), 0, 0, 0, 0);
   1499 }
   1500 
   1501 void ScriptIntrinsicBLAS::DSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha,
   1502                                  sp<Allocation> A, sp<Allocation> B, double beta, sp<Allocation> C) {
   1503     validateSYR2K(mRS, Element::F64(mRS), Trans, A, B, C);
   1504     int K = -1;
   1505     if (Trans != RsBlasNoTrans) {
   1506         K = A->getType()->getY();
   1507     } else {
   1508         K = A->getType()->getX();
   1509     }
   1510     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr2k,
   1511                                 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
   1512                                 alpha, A->getID(), B->getID(),
   1513                                 beta, C->getID(), 0, 0, 0, 0);
   1514 }
   1515 
   1516 void ScriptIntrinsicBLAS::CSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha,
   1517                                  sp<Allocation> A, sp<Allocation> B, Float2 beta, sp<Allocation> C) {
   1518     validateSYR2K(mRS, Element::F32_2(mRS), Trans, A, B, C);
   1519     int K = -1;
   1520     if (Trans != RsBlasNoTrans) {
   1521         K = A->getType()->getY();
   1522     } else {
   1523         K = A->getType()->getX();
   1524     }
   1525     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csyr2k,
   1526                                  Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
   1527                                  alpha.x, alpha.y, A->getID(), B->getID(),
   1528                                  beta.x, beta.y, C->getID(), 0, 0, 0, 0);
   1529 }
   1530 
   1531 void ScriptIntrinsicBLAS::ZSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha,
   1532                                  sp<Allocation> A, sp<Allocation> B, Double2 beta, sp<Allocation> C) {
   1533     validateSYR2K(mRS, Element::F64_2(mRS), Trans, A, B, C);
   1534     int K = -1;
   1535     if (Trans != RsBlasNoTrans) {
   1536         K = A->getType()->getY();
   1537     } else {
   1538         K = A->getType()->getX();
   1539     }
   1540     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsyr2k,
   1541                            Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
   1542                            alpha.x, alpha.y, A->getID(), B->getID(),
   1543                            beta.x, beta.y, C->getID(), 0, 0, 0, 0);
   1544 }
   1545 
   1546 static void validateTRMM(RS* mRS, sp<const Element> e, RsBlasSide Side, RsBlasTranspose TransA,
   1547                          sp<Allocation> A, sp<Allocation> B) {
   1548     int aM = -1, aN = -1, bM = -1, bN = -1;
   1549     if (!A->getType()->getElement()->isCompatible(e) ||
   1550         !B->getType()->getElement()->isCompatible(e)) {
   1551         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
   1552     }
   1553 
   1554     aM = A->getType()->getY();
   1555     aN = A->getType()->getX();
   1556     if (aM != aN) {
   1557         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with a non-symmetric matrix A");
   1558     }
   1559 
   1560     bM = B->getType()->getY();
   1561     bN = B->getType()->getX();
   1562     if (Side == RsBlasLeft) {
   1563         if (aN != bM) {
   1564             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with invalid matrices");
   1565         }
   1566     } else {
   1567         if (bN != aM) {
   1568             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with invalid matrices");
   1569         }
   1570     }
   1571 }
   1572 
   1573 void ScriptIntrinsicBLAS::STRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
   1574                                 float alpha, sp<Allocation> A, sp<Allocation> B) {
   1575     validateTRMM(mRS, Element::F32(mRS), Side, TransA, A, B);
   1576     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strmm,
   1577                                 TransA, 0, Side, Uplo, Diag,\
   1578                                 B->getType()->getY(), B->getType()->getX(), 0,
   1579                                 alpha, A->getID(), B->getID(), 0.f, 0, 0, 0, 0, 0);
   1580 }
   1581 
   1582 void ScriptIntrinsicBLAS::DTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
   1583                                 double alpha, sp<Allocation> A, sp<Allocation> B) {
   1584     validateTRMM(mRS, Element::F64(mRS), Side, TransA, A, B);
   1585     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrmm,
   1586                                 TransA, 0, Side, Uplo, Diag,
   1587                                 B->getType()->getY(), B->getType()->getX(), 0,
   1588                                 alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0);
   1589 }
   1590 
   1591 void ScriptIntrinsicBLAS::CTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
   1592                                 Float2 alpha, sp<Allocation> A, sp<Allocation> B) {
   1593     validateTRMM(mRS, Element::F32_2(mRS), Side, TransA, A, B);
   1594     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrmm,
   1595                                  TransA, 0, Side, Uplo, Diag,
   1596                                  B->getType()->getY(), B->getType()->getX(), 0,
   1597                                  alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0);
   1598 }
   1599 
   1600 void ScriptIntrinsicBLAS::ZTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
   1601                                 Double2 alpha, sp<Allocation> A, sp<Allocation> B) {
   1602     validateTRMM(mRS, Element::F64_2(mRS), Side, TransA, A, B);
   1603     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrmm,
   1604                            TransA, 0, Side, Uplo, Diag,
   1605                            B->getType()->getY(), B->getType()->getX(), 0,
   1606                            alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0);
   1607 }
   1608 
   1609 static void validateTRSM(RS* mRS, sp<const Element> e, RsBlasSide Side, RsBlasTranspose TransA,
   1610                          sp<Allocation> A, sp<Allocation> B) {
   1611     int adim = -1, bM = -1, bN = -1;
   1612     if (!A->getType()->getElement()->isCompatible(e) ||
   1613         !B->getType()->getElement()->isCompatible(e)) {
   1614         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
   1615     }
   1616     adim = A->getType()->getX();
   1617     if (adim != (int)A->getType()->getY()) {
   1618         // This may be unnecessary, the restriction could potentially be relaxed.
   1619         // Allocation A needs to contain at least that symmetric matrix but could theoretically
   1620         // be larger for now we assume adapters are sufficient, will reevaluate in the future.
   1621         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with a non-symmetric matrix A");
   1622     }
   1623     bM = B->getType()->getY();
   1624     bN = B->getType()->getX();
   1625     if (Side == RsBlasLeft) {
   1626         // A is M*M
   1627         if (adim != bM) {
   1628             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with invalid matrix dimensions");
   1629         }
   1630     } else {
   1631         // A is N*N
   1632         if (adim != bN) {
   1633             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with invalid matrix dimensions");
   1634         }
   1635     }
   1636 }
   1637 
   1638 void ScriptIntrinsicBLAS::STRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
   1639                                 float alpha, sp<Allocation> A, sp<Allocation> B) {
   1640     validateTRSM(mRS, Element::F32(mRS), Side, TransA, A, B);
   1641     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strsm,
   1642                                 TransA, 0, Side, Uplo, Diag,
   1643                                 B->getType()->getY(), B->getType()->getX(), 0,
   1644                                 alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0);
   1645 }
   1646 
   1647 void ScriptIntrinsicBLAS::DTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
   1648                                 double alpha, sp<Allocation> A, sp<Allocation> B) {
   1649     validateTRSM(mRS, Element::F64(mRS), Side, TransA, A, B);
   1650     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrsm,
   1651                                 TransA, 0, Side, Uplo, Diag,
   1652                                 B->getType()->getY(), B->getType()->getX(), 0,
   1653                                 alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0);
   1654 }
   1655 
   1656 void ScriptIntrinsicBLAS::CTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
   1657                                 Float2 alpha, sp<Allocation> A, sp<Allocation> B) {
   1658     validateTRSM(mRS, Element::F32_2(mRS), Side, TransA, A, B);
   1659     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrsm,
   1660                                  TransA, 0, Side, Uplo, Diag,
   1661                                  B->getType()->getY(), B->getType()->getX(), 0,
   1662                                  alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0);
   1663 }
   1664 
   1665 void ScriptIntrinsicBLAS::ZTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
   1666                                 Double2 alpha, sp<Allocation> A, sp<Allocation> B) {
   1667     validateTRSM(mRS, Element::F64_2(mRS), Side, TransA, A, B);
   1668     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrsm,
   1669                            TransA, 0, Side, Uplo, Diag,
   1670                            B->getType()->getY(), B->getType()->getX(), 0,
   1671                            alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0);
   1672 }
   1673 
   1674 static void validateHEMM(RS* mRS, sp<const Element> e, RsBlasSide Side,
   1675                          sp<Allocation> A, sp<Allocation> B, sp<Allocation> C) {
   1676     if (!A->getType()->getElement()->isCompatible(e) ||
   1677         !B->getType()->getElement()->isCompatible(e) ||
   1678         !C->getType()->getElement()->isCompatible(e)) {
   1679         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
   1680     }
   1681 
   1682     // A must be square; can potentially be relaxed similar to TRSM
   1683     int adim = A->getType()->getX();
   1684     if (adim != (int)A->getType()->getY()) {
   1685         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with non-square A");
   1686     }
   1687     if ((Side == RsBlasLeft && adim != (int)B->getType()->getY()) ||
   1688         (Side == RsBlasRight && adim != (int)B->getType()->getX())) {
   1689         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with invalid B");
   1690     }
   1691     if (B->getType()->getX() != C->getType()->getX() ||
   1692         B->getType()->getY() != C->getType()->getY()) {
   1693         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with mismatched B and C");
   1694     }
   1695 }
   1696 
   1697 void ScriptIntrinsicBLAS::CHEMM(RsBlasSide Side, RsBlasUplo Uplo, Float2 alpha,
   1698                                 sp<Allocation> A, sp<Allocation> B, Float2 beta, sp<Allocation> C) {
   1699     validateHEMM(mRS, Element::F32_2(mRS), Side, A, B, C);
   1700     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chemm,
   1701                                  0, 0, Side, Uplo, 0,
   1702                                  C->getType()->getY(), C->getType()->getX(), 0,
   1703                                  alpha.x, alpha.y, A->getID(), B->getID(),
   1704                                  beta.x, beta.y, C->getID(), 0, 0, 0, 0);
   1705 }
   1706 
   1707 void ScriptIntrinsicBLAS::ZHEMM(RsBlasSide Side, RsBlasUplo Uplo, Double2 alpha,
   1708                                 sp<Allocation> A, sp<Allocation> B, Double2 beta, sp<Allocation> C) {
   1709     validateHEMM(mRS, Element::F64_2(mRS), Side, A, B, C);
   1710     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhemm,
   1711                            0, 0, Side, Uplo, 0,
   1712                            C->getType()->getY(), C->getType()->getX(), 0,
   1713                            alpha.x, alpha.y, A->getID(), B->getID(),
   1714                            beta.x, beta.y, C->getID(), 0, 0, 0, 0);
   1715 }
   1716 
   1717 static void validateHERK(RS* mRS, sp<const Element> e, RsBlasTranspose Trans,
   1718                          sp<Allocation> A, sp<Allocation> C) {
   1719     if (!A->getType()->getElement()->isCompatible(e) ||
   1720         !C->getType()->getElement()->isCompatible(e)) {
   1721         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
   1722     }
   1723     if (Trans != RsBlasNoTrans && Trans != RsBlasConjTrans) {
   1724         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Call HERK with invalid Transpose");
   1725     }
   1726     int cdim = C->getType()->getX();
   1727     if (cdim != (int)C->getType()->getY()) {
   1728         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with non-square C");
   1729     }
   1730     if (Trans == RsBlasNoTrans) {
   1731         if (cdim != (int)A->getType()->getY()) {
   1732             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with invalid A");
   1733         }
   1734     } else {
   1735         if (cdim != (int)A->getType()->getX()) {
   1736             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with invalid A");
   1737         }
   1738     }
   1739 }
   1740 
   1741 void ScriptIntrinsicBLAS::CHERK(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha,
   1742                                 sp<Allocation> A, float beta, sp<Allocation> C) {
   1743     validateHERK(mRS, Element::F32_2(mRS), Trans, A, C);
   1744     int k = 0;
   1745     if (Trans == RsBlasConjTrans) {
   1746         k = A->getType()->getY();
   1747     } else {
   1748         k = A->getType()->getX();
   1749     }
   1750     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cherk,
   1751                                  Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k,
   1752                                  alpha, 0, A->getID(), 0,
   1753                                  beta, 0, C->getID(), 0, 0, 0, 0);
   1754 }
   1755 
   1756 void ScriptIntrinsicBLAS::ZHERK(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha,
   1757                                 sp<Allocation> A, double beta, sp<Allocation> C) {
   1758     validateHERK(mRS, Element::F64_2(mRS), Trans, A, C);
   1759     int k = 0;
   1760     if (Trans == RsBlasConjTrans) {
   1761         k = A->getType()->getY();
   1762     } else {
   1763         k = A->getType()->getX();
   1764     }
   1765     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zherk,
   1766                            Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k,
   1767                            alpha, 0, A->getID(), 0,
   1768                            beta, 0, C->getID(), 0, 0, 0, 0);
   1769 }
   1770 
   1771 static void validateHER2K(RS* mRS, sp<const Element> e, RsBlasTranspose Trans,
   1772                           sp<Allocation> A, sp<Allocation> B, sp<Allocation> C) {
   1773     if (!A->getType()->getElement()->isCompatible(e) ||
   1774         !B->getType()->getElement()->isCompatible(e) ||
   1775         !C->getType()->getElement()->isCompatible(e)) {
   1776         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
   1777     }
   1778     if (Trans != RsBlasNoTrans && Trans != RsBlasConjTrans) {
   1779         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Call HERK with invalid Transpose");
   1780     }
   1781     int cdim = C->getType()->getX();
   1782     if (cdim != (int)C->getType()->getY()) {
   1783         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with non-square C");
   1784     }
   1785     if (Trans == RsBlasNoTrans) {
   1786         if ((int)A->getType()->getY() != cdim) {
   1787             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid matrices");
   1788         }
   1789     } else {
   1790         if ((int)A->getType()->getX() != cdim) {
   1791             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid matrices");
   1792         }
   1793     }
   1794     if (A->getType()->getX() != B->getType()->getX() || A->getType()->getY() != B->getType()->getY()) {
   1795         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid A and B matrices");
   1796     }
   1797 }
   1798 
   1799 void ScriptIntrinsicBLAS::CHER2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha,
   1800                                  sp<Allocation> A, sp<Allocation> B, float beta, sp<Allocation> C) {
   1801     validateHER2K(mRS, Element::F32_2(mRS), Trans, A, B, C);
   1802     int k = 0;
   1803     if (Trans == RsBlasNoTrans) {
   1804         k = A->getType()->getX();
   1805     } else {
   1806         k = A->getType()->getY();
   1807     }
   1808     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher2k,
   1809                                  Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k,
   1810                                  alpha.x, alpha.y, A->getID(), B->getID(),
   1811                                  beta, 0, C->getID(), 0, 0, 0, 0);
   1812 }
   1813 
   1814 void ScriptIntrinsicBLAS::ZHER2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha,
   1815                                  sp<Allocation> A, sp<Allocation> B, double beta, sp<Allocation> C) {
   1816     validateHER2K(mRS, Element::F64_2(mRS), Trans, A, B, C);
   1817     int k = 0;
   1818     if (Trans == RsBlasNoTrans) {
   1819         k = A->getType()->getX();
   1820     } else {
   1821         k = A->getType()->getY();
   1822     }
   1823     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher2k,
   1824                            Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k,
   1825                            alpha.x, alpha.y, A->getID(), B->getID(),
   1826                            beta, 0, C->getID(), 0, 0, 0, 0);
   1827 }
   1828 
   1829 
   1830 
   1831 void ScriptIntrinsicBLAS::BNNM(sp<Allocation> A, int a_offset, sp<Allocation> B, int b_offset,
   1832                                sp<Allocation> C, int c_offset, int c_mult) {
   1833     validateL3(mRS, Element::U8(mRS), RsBlasNoTrans, RsBlasTrans, 0, A, B, C);
   1834 
   1835     if (a_offset < 0 || a_offset > 255) {
   1836         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid a_offset passed to BNNM");
   1837     }
   1838     if (b_offset < 0 || b_offset > 255) {
   1839         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid b_offset passed to BNNM");
   1840     }
   1841     int M = -1, N = -1, K = -1;
   1842     M = A->getType()->getY();
   1843     N = B->getType()->getY();
   1844     K = A->getType()->getX();
   1845 
   1846     nScriptIntrinsicBLAS_BNNM(mRS, mRS->getContext(), getID(), M, N, K, A->getID(), a_offset,
   1847                               B->getID(), b_offset, C->getID(), c_offset, c_mult);
   1848 }
   1849