Home | History | Annotate | Download | only in operations
      1 /*
      2  * Copyright (C) 2019 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #define LOG_TAG "Operations"
     18 
     19 #include "HalInterfaces.h"
     20 #include "IndexedShapeWrapper.h"
     21 #include "LSTM.h"
     22 #include "OperationResolver.h"
     23 #include "OperationsUtils.h"
     24 
     25 namespace android {
     26 namespace nn {
     27 namespace unidirectional_sequence_lstm {
     28 
// Operand indices below follow the NNAPI UNIDIRECTIONAL_SEQUENCE_LSTM operand
// order. Tensors marked "Optional" may be omitted by the caller (their input
// buffer is then null; see hasTensor() below).

// Inputs
constexpr uint32_t kNumInputs = 28;

// Input tensor of size {max_time, n_batch, n_input}
constexpr uint32_t kInputTensor = 0;

// Input weight tensors of size: {n_cell, n_input}
constexpr uint32_t kInputToInputWeightsTensor = 1;  // Optional
constexpr uint32_t kInputToForgetWeightsTensor = 2;
constexpr uint32_t kInputToCellWeightsTensor = 3;
constexpr uint32_t kInputToOutputWeightsTensor = 4;

// Recurrent weight tensors of size {n_cell, n_output}
constexpr uint32_t kRecurrentToInputWeightsTensor = 5;  // Optional
constexpr uint32_t kRecurrentToForgetWeightsTensor = 6;
constexpr uint32_t kRecurrentToCellWeightsTensor = 7;
constexpr uint32_t kRecurrentToOutputWeightsTensor = 8;

// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
constexpr uint32_t kCellToInputWeightsTensor = 9;    // Optional
constexpr uint32_t kCellToForgetWeightsTensor = 10;  // Optional
constexpr uint32_t kCellToOutputWeightsTensor = 11;  // Optional

// Gates bias tensors of size {n_cell}
constexpr uint32_t kInputGateBiasTensor = 12;  // Optional
constexpr uint32_t kForgetGateBiasTensor = 13;
constexpr uint32_t kCellGateBiasTensor = 14;
constexpr uint32_t kOutputGateBiasTensor = 15;

// Projection weight tensor of size {n_output, n_cell}
constexpr uint32_t kProjectionWeightsTensor = 16;  // Optional
// Projection bias tensor of size {n_output}
constexpr uint32_t kProjectionBiasTensor = 17;  // Optional

// Input from the output of the previous step, tensor of size {batch_size, n_output}
constexpr uint32_t kOutputStateInTensor = 18;
// Input from the cell state of the previous step, tensor of size {batch_size, n_cell}
constexpr uint32_t kCellStateInTensor = 19;

// Scalar parameters: fused activation, cell/projection clipping thresholds,
// and the time-major layout flag.
constexpr uint32_t kActivationParam = 20;
constexpr uint32_t kCellClipParam = 21;
constexpr uint32_t kProjClipParam = 22;
constexpr uint32_t kTimeMajorParam = 23;

// Layer norm weights tensors of size {n_cell}, representing a diagonal matrix.
constexpr uint32_t kInputLayerNormWeightsTensor = 24;   // Optional
constexpr uint32_t kForgetLayerNormWeightsTensor = 25;  // Optional
constexpr uint32_t kCellLayerNormWeightsTensor = 26;    // Optional
constexpr uint32_t kOutputLayerNormWeightsTensor = 27;  // Optional

// Output tensors.
constexpr uint32_t kNumOutputs = 1;

constexpr uint32_t kOutputTensor = 0;
     83 
     84 namespace {
     85 
     86 inline bool hasTensor(IOperationExecutionContext* context, const uint32_t tensor) {
     87     return context->getInputBuffer(tensor) != nullptr;
     88 }
     89 
     90 inline bool isTimeMajor(IOperationExecutionContext* context) {
     91     return context->getInputValue<bool>(kTimeMajorParam);
     92 }
     93 
     94 template <typename T>
     95 inline LSTMParams getLSTMParams(IOperationExecutionContext* context) {
     96     LSTMParams params;
     97     params.activation =
     98             static_cast<TfLiteFusedActivation>(context->getInputValue<int32_t>(kActivationParam));
     99     params.cell_clip = static_cast<float>(context->getInputValue<T>(kCellClipParam));
    100     params.proj_clip = static_cast<float>(context->getInputValue<T>(kProjClipParam));
    101     params.use_cifg = !hasTensor(context, kInputToInputWeightsTensor);
    102     params.use_peephole = hasTensor(context, kCellToOutputWeightsTensor);
    103     params.use_layer_norm = hasTensor(context, kOutputLayerNormWeightsTensor);
    104     params.use_projection_weight = hasTensor(context, kProjectionWeightsTensor);
    105     params.use_projection_bias = hasTensor(context, kProjectionBiasTensor);
    106     return params;
    107 }
    108 
    109 }  // namespace
    110 
    111 bool validate(const IOperationValidationContext* context) {
    112     NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    113     NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    114     const OperandType inputType = context->getInputType(kInputTensor);
    115     std::vector<OperandType> inExpectedTypes;
    116     std::vector<OperandType> outExpectedTypes;
    117     if (inputType == OperandType::TENSOR_FLOAT32) {
    118         inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
    119                            OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
    120                            OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
    121                            OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
    122                            OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
    123                            OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
    124                            OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
    125                            OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
    126                            OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
    127                            OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
    128                            OperandType::INT32,          OperandType::FLOAT32,
    129                            OperandType::FLOAT32,        OperandType::BOOL,
    130                            OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
    131                            OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32};
    132         outExpectedTypes = {OperandType::TENSOR_FLOAT32};
    133     } else if (inputType == OperandType::TENSOR_FLOAT16) {
    134         inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
    135                            OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
    136                            OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
    137                            OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
    138                            OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
    139                            OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
    140                            OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
    141                            OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
    142                            OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
    143                            OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
    144                            OperandType::INT32,          OperandType::FLOAT16,
    145                            OperandType::FLOAT16,        OperandType::BOOL,
    146                            OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
    147                            OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16};
    148         outExpectedTypes = {OperandType::TENSOR_FLOAT16};
    149     } else {
    150         NN_RET_CHECK_FAIL()
    151                 << "Unsupported input operand type for UNIDIRECTIONAL_SEQUENCE_LSTM op: "
    152                 << toString(inputType);
    153     }
    154     NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
    155     NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes));
    156     return validateHalVersion(context, HalVersion::V1_2);
    157 }
    158 
// Validates the shapes of all provided operands, enforces the all-or-none
// rules for the optional tensor groups (CIFG, peephole, layer norm), and
// computes the output shape. Returns false (through NN_RET_CHECK*) on any
// violation. Dimension names: maxTime, batchSize, inputSize (n_input),
// numCells (n_cell), outputSize (n_output).
bool prepare(IOperationExecutionContext* context) {
    // Check that none of the required inputs are omitted
    const std::vector<int> requiredInputs = {
            kInputTensor,
            kInputToForgetWeightsTensor,
            kInputToCellWeightsTensor,
            kInputToOutputWeightsTensor,
            kRecurrentToForgetWeightsTensor,
            kRecurrentToCellWeightsTensor,
            kRecurrentToOutputWeightsTensor,
            kForgetGateBiasTensor,
            kCellGateBiasTensor,
            kOutputGateBiasTensor,
            kOutputStateInTensor,
            kCellStateInTensor,
            kActivationParam,
            kCellClipParam,
            kProjClipParam,
            kTimeMajorParam,
    };
    for (const int requiredInput : requiredInputs) {
        NN_RET_CHECK(!context->isOmittedInput(requiredInput))
                << "required input " << requiredInput << " is omitted";
    }

    // The input must be rank 3; the time/batch axis order depends on the
    // time-major flag ({max_time, n_batch, ...} vs {n_batch, max_time, ...}).
    const Shape inputShape = context->getInputShape(kInputTensor);
    const uint32_t inputRank = getNumberOfDimensions(inputShape);
    NN_RET_CHECK_EQ(inputRank, 3) << "Invalid input tensor rank: " << inputRank;

    const uint32_t maxTime = getSizeOfDimension(inputShape, isTimeMajor(context) ? 0 : 1);
    const uint32_t batchSize = getSizeOfDimension(inputShape, isTimeMajor(context) ? 1 : 0);
    const uint32_t inputSize = getSizeOfDimension(inputShape, inputRank - 1);

    // numCells is defined by the input-to-output weights ({n_cell, n_input})...
    const Shape inputToOutputShape = context->getInputShape(kInputToOutputWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToOutputShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToOutputShape, 1), inputSize);
    const uint32_t numCells = getSizeOfDimension(inputToOutputShape, 0);

    // ...and outputSize by the recurrent-to-output weights ({n_cell, n_output}).
    const Shape recurrentToOutputShape = context->getInputShape(kRecurrentToOutputWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToOutputShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToOutputShape, 0), numCells);
    const uint32_t outputSize = getSizeOfDimension(recurrentToOutputShape, 1);

    if (hasTensor(context, kInputToInputWeightsTensor)) {
        const Shape inputToInputShape = context->getInputShape(kInputToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputToInputShape), 2);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 0), numCells);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 1), inputSize);
    }

    // Remaining input weights must be {n_cell, n_input}.
    const Shape inputToForgetShape = context->getInputShape(kInputToForgetWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToForgetShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 1), inputSize);
    const Shape inputToCellShape = context->getInputShape(kInputToCellWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToCellShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 1), inputSize);

    if (hasTensor(context, kRecurrentToInputWeightsTensor)) {
        const Shape recurrentToInputShape = context->getInputShape(kRecurrentToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToInputShape), 2);
        NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 0), numCells);
        NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 1), outputSize);
    }

    // Remaining recurrent weights must be {n_cell, n_output}.
    const Shape recurrentToForgetShape = context->getInputShape(kRecurrentToForgetWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToForgetShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 1), outputSize);
    const Shape recurrentToCellShape = context->getInputShape(kRecurrentToCellWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToCellShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 1), outputSize);

    // We make sure the input-gate's parameters are either both present (regular
    // LSTM) or not at all (CIFG-LSTM).
    const bool cifgWeightsAllOrNone = (hasTensor(context, kInputToInputWeightsTensor) &&
                                       hasTensor(context, kRecurrentToInputWeightsTensor)) ||
                                      (!hasTensor(context, kInputToInputWeightsTensor) &&
                                       !hasTensor(context, kRecurrentToInputWeightsTensor));
    NN_RET_CHECK(cifgWeightsAllOrNone);

    // Peephole weights, when present, are {n_cell} vectors.
    if (hasTensor(context, kCellToInputWeightsTensor)) {
        const Shape cellToInputShape = context->getInputShape(kCellToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToInputShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToInputShape, 0), numCells);
    }

    if (hasTensor(context, kCellToForgetWeightsTensor)) {
        const Shape cellToForgetShape = context->getInputShape(kCellToForgetWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToForgetShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToForgetShape, 0), numCells);
    }

    if (hasTensor(context, kCellToOutputWeightsTensor)) {
        const Shape cellToOutputShape = context->getInputShape(kCellToOutputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToOutputShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToOutputShape, 0), numCells);
    }

    // Making sure the peephole weights are there all or none.
    // Under CIFG there is no input gate, so the cell-to-input weights are
    // allowed to be absent even when the other two peephole tensors exist.
    const bool cifgUsed = !hasTensor(context, kInputToInputWeightsTensor);
    const bool peepholeWeightsAllOrNone =
            ((hasTensor(context, kCellToInputWeightsTensor) || cifgUsed) &&
             hasTensor(context, kCellToForgetWeightsTensor) &&
             hasTensor(context, kCellToOutputWeightsTensor)) ||
            (!hasTensor(context, kCellToInputWeightsTensor) &&
             !hasTensor(context, kCellToForgetWeightsTensor) &&
             !hasTensor(context, kCellToOutputWeightsTensor));
    NN_RET_CHECK(peepholeWeightsAllOrNone);

    // Input gate bias is required without CIFG and forbidden with it.
    if (!cifgUsed) {
        NN_RET_CHECK(hasTensor(context, kInputGateBiasTensor));
        const Shape inputGateBiasShape = context->getInputShape(kInputGateBiasTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputGateBiasShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputGateBiasShape, 0), numCells);
    } else {
        NN_RET_CHECK(!hasTensor(context, kInputGateBiasTensor))
                << "Input gate bias tensor is present when CIFG is used";
    }

    // Remaining gate biases are {n_cell} vectors.
    const Shape forgetGateBiasShape = context->getInputShape(kForgetGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(forgetGateBiasShape), 1);
    NN_RET_CHECK_EQ(getSizeOfDimension(forgetGateBiasShape, 0), numCells);
    const Shape cellGateBiasShape = context->getInputShape(kCellGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(cellGateBiasShape), 1);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellGateBiasShape, 0), numCells);
    const Shape outputGateBiasShape = context->getInputShape(kOutputGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(outputGateBiasShape), 1);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputGateBiasShape, 0), numCells);

    // Projection weights {n_output, n_cell} and bias {n_output}, if present.
    if (hasTensor(context, kProjectionWeightsTensor)) {
        const Shape projectionShape = context->getInputShape(kProjectionWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(projectionShape), 2);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 0), outputSize);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 1), numCells);
    }

    if (hasTensor(context, kProjectionBiasTensor)) {
        const Shape projectionBiasShape = context->getInputShape(kProjectionBiasTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(projectionBiasShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionBiasShape, 0), outputSize);
    }

    // Recurrent state inputs: {batch_size, n_output} and {batch_size, n_cell}.
    const Shape outputStateShape = context->getInputShape(kOutputStateInTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(outputStateShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 0), batchSize);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 1), outputSize);
    const Shape cellStateShape = context->getInputShape(kCellStateInTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(cellStateShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 0), batchSize);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 1), numCells);

    // Layer-norm weights, when present, are {n_cell} vectors.
    if (hasTensor(context, kInputLayerNormWeightsTensor)) {
        const Shape inputLayerNormShape = context->getInputShape(kInputLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputLayerNormShape, 0), numCells);
    }

    if (hasTensor(context, kForgetLayerNormWeightsTensor)) {
        const Shape forgetLayerNormShape = context->getInputShape(kForgetLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(forgetLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(forgetLayerNormShape, 0), numCells);
    }

    if (hasTensor(context, kCellLayerNormWeightsTensor)) {
        const Shape cellLayerNormShape = context->getInputShape(kCellLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellLayerNormShape, 0), numCells);
    }

    if (hasTensor(context, kOutputLayerNormWeightsTensor)) {
        const Shape outputLayerNormShape = context->getInputShape(kOutputLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(outputLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(outputLayerNormShape, 0), numCells);
    }

    // Layer-norm tensors must be all-or-none; under CIFG the input layer-norm
    // weights must be absent (no input gate exists).
    if (cifgUsed) {
        NN_RET_CHECK(!hasTensor(context, kInputLayerNormWeightsTensor))
                << "Input layer norm weights tensor is present when CIFG is used";
        const bool layerNormWeightsAllOrNoneCifg =
                (hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 hasTensor(context, kCellLayerNormWeightsTensor) &&
                 hasTensor(context, kOutputLayerNormWeightsTensor)) ||
                (!hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 !hasTensor(context, kCellLayerNormWeightsTensor) &&
                 !hasTensor(context, kOutputLayerNormWeightsTensor));
        NN_RET_CHECK(layerNormWeightsAllOrNoneCifg);
    } else {
        const bool layerNormWeightsAllOrNone =
                (hasTensor(context, kInputLayerNormWeightsTensor) &&
                 hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 hasTensor(context, kCellLayerNormWeightsTensor) &&
                 hasTensor(context, kOutputLayerNormWeightsTensor)) ||
                (!hasTensor(context, kInputLayerNormWeightsTensor) &&
                 !hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 !hasTensor(context, kCellLayerNormWeightsTensor) &&
                 !hasTensor(context, kOutputLayerNormWeightsTensor));
        NN_RET_CHECK(layerNormWeightsAllOrNone);
    }

    // Output has the input's shape with the last dimension replaced by
    // n_output (rank 3 was verified above, so dimensions[2] is valid).
    Shape outputShape = context->getInputShape(kInputTensor);
    outputShape.dimensions[2] = outputSize;

    return context->setOutputShape(kOutputTensor, outputShape);
}
    366 
// Runs the sequence LSTM by dispatching on the input tensor type to the
// float32 or float16 LSTM kernel. Intermediate state (output state, cell
// state, scratch) is allocated locally per invocation; only kOutputTensor is
// an operation output (kNumOutputs == 1), so the final states are discarded.
// Omitted optional inputs yield null buffers (cf. hasTensor), which the
// LSTMEval* kernels are expected to accept — NOTE(review): verify in LSTM.cpp.
bool execute(IOperationExecutionContext* context) {
    const auto outputStateSize = getNumberOfElements(context->getInputShape(kOutputStateInTensor));
    const auto cellStateSize = getNumberOfElements(context->getInputShape(kCellStateInTensor));
    // CIFG has no input gate, so the kernel needs scratch for 3 gates
    // instead of 4.
    const bool use_cifg = !hasTensor(context, kInputToInputWeightsTensor);
    const auto scratchSize = use_cifg ? 3 * cellStateSize : 4 * cellStateSize;

    const OperandType inputType = context->getInputType(kInputTensor);
    switch (inputType) {
        case OperandType::TENSOR_FLOAT32: {
            std::vector<float> outputStateOut(outputStateSize);
            std::vector<float> cellStateOut(cellStateSize);
            std::vector<float> scratchBuffer(scratchSize);
            LSTMCell::LSTMEvalFloat32(
                    getLSTMParams<float>(context), context->getInputBuffer<float>(kInputTensor),
                    context->getInputShape(kInputTensor),
                    context->getInputBuffer<float>(kInputToInputWeightsTensor),
                    context->getInputBuffer<float>(kInputToForgetWeightsTensor),
                    context->getInputBuffer<float>(kInputToCellWeightsTensor),
                    context->getInputBuffer<float>(kInputToOutputWeightsTensor),
                    context->getInputShape(kInputToOutputWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToInputWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToForgetWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToCellWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToOutputWeightsTensor),
                    context->getInputShape(kRecurrentToOutputWeightsTensor),
                    context->getInputBuffer<float>(kCellToInputWeightsTensor),
                    context->getInputBuffer<float>(kCellToForgetWeightsTensor),
                    context->getInputBuffer<float>(kCellToOutputWeightsTensor),
                    // This operation has no auxiliary input (used by the
                    // bidirectional variant), so all aux pointers are null.
                    /*aux_input_buffer=*/nullptr,
                    /*aux_input_to_input_weights_buffer=*/nullptr,
                    /*aux_input_to_forget_weights_buffer=*/nullptr,
                    /*aux_input_to_cell_weights_buffer=*/nullptr,
                    /*aux_input_to_output_weights_buffer=*/nullptr,
                    context->getInputBuffer<float>(kInputGateBiasTensor),
                    context->getInputBuffer<float>(kForgetGateBiasTensor),
                    context->getInputBuffer<float>(kCellGateBiasTensor),
                    context->getInputBuffer<float>(kOutputGateBiasTensor),
                    context->getInputBuffer<float>(kProjectionWeightsTensor),
                    context->getInputBuffer<float>(kProjectionBiasTensor),
                    context->getInputBuffer<float>(kOutputStateInTensor),
                    context->getInputBuffer<float>(kCellStateInTensor),
                    context->getInputBuffer<float>(kInputLayerNormWeightsTensor),
                    context->getInputBuffer<float>(kForgetLayerNormWeightsTensor),
                    context->getInputBuffer<float>(kCellLayerNormWeightsTensor),
                    context->getInputBuffer<float>(kOutputLayerNormWeightsTensor),
                    outputStateOut.data(), cellStateOut.data(),
                    context->getOutputBuffer<float>(kOutputTensor), scratchBuffer.data(),
                    isTimeMajor(context));
        } break;
        case OperandType::TENSOR_FLOAT16: {
            // Same call as above with _Float16 buffers.
            std::vector<_Float16> outputStateOut(outputStateSize);
            std::vector<_Float16> cellStateOut(cellStateSize);
            std::vector<_Float16> scratchBuffer(scratchSize);
            LSTMCell::LSTMEvalFloat16(
                    getLSTMParams<_Float16>(context),
                    context->getInputBuffer<_Float16>(kInputTensor),
                    context->getInputShape(kInputTensor),
                    context->getInputBuffer<_Float16>(kInputToInputWeightsTensor),
                    context->getInputBuffer<_Float16>(kInputToForgetWeightsTensor),
                    context->getInputBuffer<_Float16>(kInputToCellWeightsTensor),
                    context->getInputBuffer<_Float16>(kInputToOutputWeightsTensor),
                    context->getInputShape(kInputToOutputWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToInputWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToForgetWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToCellWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToOutputWeightsTensor),
                    context->getInputShape(kRecurrentToOutputWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellToInputWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellToForgetWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellToOutputWeightsTensor),
                    /*aux_input_buffer=*/nullptr,
                    /*aux_input_to_input_weights_buffer=*/nullptr,
                    /*aux_input_to_forget_weights_buffer=*/nullptr,
                    /*aux_input_to_cell_weights_buffer=*/nullptr,
                    /*aux_input_to_output_weights_buffer=*/nullptr,
                    context->getInputBuffer<_Float16>(kInputGateBiasTensor),
                    context->getInputBuffer<_Float16>(kForgetGateBiasTensor),
                    context->getInputBuffer<_Float16>(kCellGateBiasTensor),
                    context->getInputBuffer<_Float16>(kOutputGateBiasTensor),
                    context->getInputBuffer<_Float16>(kProjectionWeightsTensor),
                    context->getInputBuffer<_Float16>(kProjectionBiasTensor),
                    context->getInputBuffer<_Float16>(kOutputStateInTensor),
                    context->getInputBuffer<_Float16>(kCellStateInTensor),
                    context->getInputBuffer<_Float16>(kInputLayerNormWeightsTensor),
                    context->getInputBuffer<_Float16>(kForgetLayerNormWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellLayerNormWeightsTensor),
                    context->getInputBuffer<_Float16>(kOutputLayerNormWeightsTensor),
                    outputStateOut.data(), cellStateOut.data(),
                    context->getOutputBuffer<_Float16>(kOutputTensor), scratchBuffer.data(),
                    isTimeMajor(context));
        } break;
        default: {
            LOG(ERROR) << "Unsupported data type: " << static_cast<int>(inputType);
            return false;
        }
    }
    return true;
}
    465 
    466 }  // namespace unidirectional_sequence_lstm
    467 
// Register the operation with the resolver. allowOmittedOperand is set
// because many of the 28 inputs are optional and may be omitted by callers.
NN_REGISTER_OPERATION(UNIDIRECTIONAL_SEQUENCE_LSTM, "UNIDIRECTIONAL_SEQUENCE_LSTM",
                      unidirectional_sequence_lstm::validate, unidirectional_sequence_lstm::prepare,
                      unidirectional_sequence_lstm::execute, .allowOmittedOperand = true);
    471 
    472 }  // namespace nn
    473 }  // namespace android
    474