// operations/LSTM.cpp
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "LSTM.h"
     18 
     19 #include "CpuExecutor.h"
     20 #include "CpuOperationUtils.h"
     21 #include "HalInterfaces.h"
     22 #include "OperationsUtils.h"
     23 
     24 #include "Tracing.h"
     25 #include "Utils.h"
     26 
     27 namespace android {
     28 namespace nn {
     29 
     30 namespace {
     31 
     32 template <typename T>
     33 inline T* GetBuffer(RunTimeOperandInfo* operand) {
     34     return reinterpret_cast<T*>(operand->buffer);
     35 }
     36 
     37 template <typename T>
     38 inline const T* GetBuffer(const RunTimeOperandInfo* operand) {
     39     return reinterpret_cast<const T*>(operand->buffer);
     40 }
     41 
     42 template <typename T>
     43 inline const T* GetOptionalBuffer(const RunTimeOperandInfo* operand) {
     44     return !IsNullInput(operand) ? reinterpret_cast<const T*>(operand->buffer) : nullptr;
     45 }
     46 
     47 }  // anonymous namespace
     48 
     49 LSTMCell::LSTMCell(const Operation& operation, std::vector<RunTimeOperandInfo>& operands) {
     50     input_ = GetInput(operation, operands, kInputTensor);
     51 
     52     input_to_input_weights_ =
     53             GetInput(operation, operands, kInputToInputWeightsTensor);  // optional
     54     input_to_forget_weights_ = GetInput(operation, operands, kInputToForgetWeightsTensor);
     55     input_to_cell_weights_ = GetInput(operation, operands, kInputToCellWeightsTensor);
     56     input_to_output_weights_ = GetInput(operation, operands, kInputToOutputWeightsTensor);
     57 
     58     recurrent_to_input_weights_ =
     59             GetInput(operation, operands, kRecurrentToInputWeightsTensor);  // optional
     60     recurrent_to_forget_weights_ = GetInput(operation, operands, kRecurrentToForgetWeightsTensor);
     61     recurrent_to_cell_weights_ = GetInput(operation, operands, kRecurrentToCellWeightsTensor);
     62     recurrent_to_output_weights_ = GetInput(operation, operands, kRecurrentToOutputWeightsTensor);
     63 
     64     cell_to_input_weights_ = GetInput(operation, operands, kCellToInputWeightsTensor);  // optional
     65     cell_to_forget_weights_ =
     66             GetInput(operation, operands, kCellToForgetWeightsTensor);  // optional
     67     cell_to_output_weights_ =
     68             GetInput(operation, operands, kCellToOutputWeightsTensor);  // optional
     69 
     70     input_gate_bias_ = GetInput(operation, operands, kInputGateBiasTensor);
     71     forget_gate_bias_ = GetInput(operation, operands, kForgetGateBiasTensor);
     72     cell_bias_ = GetInput(operation, operands, kCellGateBiasTensor);
     73     output_gate_bias_ = GetInput(operation, operands, kOutputGateBiasTensor);
     74 
     75     projection_weights_ = GetInput(operation, operands, kProjectionWeightsTensor);  // optional
     76     projection_bias_ = GetInput(operation, operands, kProjectionBiasTensor);        // optional
     77 
     78     output_state_in_ = GetInput(operation, operands, kOutputStateInTensor);
     79     cell_state_in_ = GetInput(operation, operands, kCellStateInTensor);
     80 
     81     params_.activation = static_cast<TfLiteFusedActivation>(
     82             getScalarData<int32_t>(*GetInput(operation, operands, kActivationParam)));
     83     if (input_->type == OperandType::TENSOR_FLOAT32) {
     84         params_.cell_clip = getScalarData<float>(*GetInput(operation, operands, kCellClipParam));
     85         params_.proj_clip = getScalarData<float>(*GetInput(operation, operands, kProjClipParam));
     86     } else {
     87         params_.cell_clip = static_cast<float>(
     88                 getScalarData<_Float16>(*GetInput(operation, operands, kCellClipParam)));
     89         params_.proj_clip = static_cast<float>(
     90                 getScalarData<_Float16>(*GetInput(operation, operands, kProjClipParam)));
     91     }
     92 
     93     // We check the version of LSTM by checking the number of the inputs to the
     94     // op. For LSTM version 1.0 there were 23 inputs and for 1.2 there are 27.
     95     if (operation.inputs.size() == 27) {
     96         input_layer_norm_weights_ =
     97                 GetInput(operation, operands, kInputLayerNormWeightsTensor);  // optional
     98         forget_layer_norm_weights_ =
     99                 GetInput(operation, operands, kForgetLayerNormWeightsTensor);  // optional
    100         cell_layer_norm_weights_ =
    101                 GetInput(operation, operands, kCellLayerNormWeightsTensor);  // optional
    102         output_layer_norm_weights_ =
    103                 GetInput(operation, operands, kOutputLayerNormWeightsTensor);  // optional
    104     } else {
    105         // For LSTM from HAL v1.0 assign operands with no values
    106         static RunTimeOperandInfo no_value;
    107         no_value.lifetime = OperandLifeTime::NO_VALUE;
    108 
    109         input_layer_norm_weights_ = &no_value;
    110         forget_layer_norm_weights_ = &no_value;
    111         cell_layer_norm_weights_ = &no_value;
    112         output_layer_norm_weights_ = &no_value;
    113     }
    114 
    115     output_state_out_ = GetOutput(operation, operands, kOutputStateOutTensor);
    116     cell_state_out_ = GetOutput(operation, operands, kCellStateOutTensor);
    117     output_ = GetOutput(operation, operands, kOutputTensor);
    118 
    119     scratch_buffer_ = GetOutput(operation, operands, kScratchBufferTensor);
    120 }
    121 
    122 // static
    123 bool LSTMCell::CheckInputTensorDimensions(
    124         const RunTimeOperandInfo* input_, const RunTimeOperandInfo* input_to_input_weights,
    125         const RunTimeOperandInfo* input_to_forget_weights,
    126         const RunTimeOperandInfo* input_to_cell_weights,
    127         const RunTimeOperandInfo* input_to_output_weights,
    128         const RunTimeOperandInfo* recurrent_to_input_weights,
    129         const RunTimeOperandInfo* recurrent_to_forget_weights,
    130         const RunTimeOperandInfo* recurrent_to_cell_weights,
    131         const RunTimeOperandInfo* recurrent_to_output_weights,
    132         const RunTimeOperandInfo* cell_to_input_weights,
    133         const RunTimeOperandInfo* cell_to_forget_weights,
    134         const RunTimeOperandInfo* cell_to_output_weights, const RunTimeOperandInfo* input_gate_bias,
    135         const RunTimeOperandInfo* forget_gate_bias, const RunTimeOperandInfo* cell_bias,
    136         const RunTimeOperandInfo* output_gate_bias, const RunTimeOperandInfo* projection_weights,
    137         const RunTimeOperandInfo* projection_bias,
    138         const RunTimeOperandInfo* input_layer_norm_weights,
    139         const RunTimeOperandInfo* forget_layer_norm_weights,
    140         const RunTimeOperandInfo* cell_layer_norm_weights,
    141         const RunTimeOperandInfo* output_layer_norm_weights, uint32_t n_input, uint32_t n_output,
    142         uint32_t n_cell, LSTMParams* params) {
    143     // Making sure clipping parameters have valid values.
    144     // == 0 means no clipping
    145     //  > 0 means clipping
    146     NN_CHECK(params->cell_clip >= 0);
    147     NN_CHECK(params->proj_clip >= 0);
    148 
    149     if (!IsNullInput(input_to_input_weights)) {
    150         NN_CHECK_EQ(NumDimensions(input_to_input_weights), 2);
    151         NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 0), n_cell);
    152         NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 1), n_input);
    153     }
    154 
    155     NN_CHECK_EQ(NumDimensions(input_to_forget_weights), 2);
    156     NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 0), n_cell);
    157     NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 1), n_input);
    158 
    159     NN_CHECK_EQ(NumDimensions(input_to_cell_weights), 2);
    160     NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 0), n_cell);
    161     NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 1), n_input);
    162 
    163     if (!IsNullInput(recurrent_to_input_weights)) {
    164         NN_CHECK_EQ(NumDimensions(recurrent_to_input_weights), 2);
    165         NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 0), n_cell);
    166         NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 1), n_output);
    167     }
    168 
    169     NN_CHECK_EQ(NumDimensions(recurrent_to_forget_weights), 2);
    170     NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 0), n_cell);
    171     NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 1), n_output);
    172 
    173     NN_CHECK_EQ(NumDimensions(recurrent_to_cell_weights), 2);
    174     NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 0), n_cell);
    175     NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 1), n_output);
    176 
    177     // We make sure the input-gate's parameters are either both present (regular
    178     // LSTM) or not at all (CIFG-LSTM).
    179     const bool cifg_weights_all_or_none =
    180             (!IsNullInput(input_to_input_weights) && !IsNullInput(recurrent_to_input_weights)) ||
    181             (IsNullInput(input_to_input_weights) && IsNullInput(recurrent_to_input_weights));
    182     NN_CHECK(cifg_weights_all_or_none);
    183 
    184     if (!IsNullInput(cell_to_input_weights)) {
    185         NN_CHECK_EQ(NumDimensions(cell_to_input_weights), 1);
    186         NN_CHECK_EQ(SizeOfDimension(cell_to_input_weights, 0), n_cell);
    187     }
    188 
    189     if (!IsNullInput(cell_to_forget_weights)) {
    190         NN_CHECK_EQ(NumDimensions(cell_to_forget_weights), 1);
    191         NN_CHECK_EQ(SizeOfDimension(cell_to_forget_weights, 0), n_cell);
    192     }
    193 
    194     if (!IsNullInput(cell_to_output_weights)) {
    195         NN_CHECK_EQ(NumDimensions(cell_to_output_weights), 1);
    196         NN_CHECK_EQ(SizeOfDimension(cell_to_output_weights, 0), n_cell);
    197     }
    198 
    199     // Making sure the peephole weights are there all or none.
    200     params->use_cifg = IsNullInput(input_to_input_weights);
    201     const bool peephole_weights_all_or_none =
    202             ((!IsNullInput(cell_to_input_weights) || params->use_cifg) &&
    203              !IsNullInput(cell_to_forget_weights) && !IsNullInput(cell_to_output_weights)) ||
    204             (IsNullInput(cell_to_input_weights) && IsNullInput(cell_to_forget_weights) &&
    205              IsNullInput(cell_to_output_weights));
    206     NN_CHECK(peephole_weights_all_or_none);
    207 
    208     // Since we have already checked that weights are all there or none, we can
    209     // check the existence of only one to the get the condition.
    210     params->use_peephole = !IsNullInput(cell_to_output_weights);
    211     // Checking output instead of input layer norm weights because input can be
    212     // omitted ones can be omited in case CIFG LSTM is used.
    213     params->use_layer_norm = !IsNullInput(output_layer_norm_weights);
    214 
    215     params->use_projection_weight = (projection_weights->lifetime != OperandLifeTime::NO_VALUE);
    216     params->use_projection_bias = (projection_bias->lifetime != OperandLifeTime::NO_VALUE);
    217 
    218     // Make sure the input gate bias is present only when not a CIFG-LSTM.
    219     if (params->use_cifg) {
    220         NN_CHECK(IsNullInput(input_gate_bias));
    221     } else {
    222         NN_CHECK_EQ(NumDimensions(input_gate_bias), 1);
    223         NN_CHECK_EQ(SizeOfDimension(input_gate_bias, 0), n_cell);
    224     }
    225 
    226     NN_CHECK_EQ(NumDimensions(forget_gate_bias), 1);
    227     NN_CHECK_EQ(SizeOfDimension(forget_gate_bias, 0), n_cell);
    228 
    229     NN_CHECK_EQ(NumDimensions(cell_bias), 1);
    230     NN_CHECK_EQ(SizeOfDimension(cell_bias, 0), n_cell);
    231 
    232     NN_CHECK_EQ(NumDimensions(output_gate_bias), 1);
    233     NN_CHECK_EQ(SizeOfDimension(output_gate_bias, 0), n_cell);
    234 
    235     if (!IsNullInput(projection_weights)) {
    236         NN_CHECK_EQ(NumDimensions(projection_weights), 2);
    237         NN_CHECK_EQ(SizeOfDimension(projection_weights, 0), n_output);
    238         NN_CHECK_EQ(SizeOfDimension(projection_weights, 1), n_cell);
    239     }
    240 
    241     if (!IsNullInput(projection_bias)) {
    242         NN_CHECK_EQ(NumDimensions(projection_bias), 1);
    243         NN_CHECK_EQ(SizeOfDimension(projection_bias, 0), n_output);
    244     }
    245 
    246     // Making sure the projection tensors are consistent:
    247     // 1) If projection weight is not present, then projection bias should not be
    248     // present.
    249     // 2) If projection weight is present, then projection bias is optional.
    250     // TODO: make sure this is correct.
    251     const bool projecton_tensors_consistent =
    252             (!IsNullInput(projection_weights) || IsNullInput(projection_bias));
    253     NN_CHECK(projecton_tensors_consistent == true);
    254 
    255     if (!IsNullInput(input_layer_norm_weights)) {
    256         NN_CHECK_EQ(NumDimensions(input_layer_norm_weights), 1);
    257         NN_CHECK_EQ(SizeOfDimension(input_layer_norm_weights, 0), n_cell);
    258     }
    259     if (!IsNullInput(forget_layer_norm_weights)) {
    260         NN_CHECK_EQ(NumDimensions(forget_layer_norm_weights), 1);
    261         NN_CHECK_EQ(SizeOfDimension(forget_layer_norm_weights, 0), n_cell);
    262     }
    263     if (!IsNullInput(cell_layer_norm_weights)) {
    264         NN_CHECK_EQ(NumDimensions(cell_layer_norm_weights), 1);
    265         NN_CHECK_EQ(SizeOfDimension(cell_layer_norm_weights, 0), n_cell);
    266     }
    267     if (!IsNullInput(output_layer_norm_weights)) {
    268         NN_CHECK_EQ(NumDimensions(output_layer_norm_weights), 1);
    269         NN_CHECK_EQ(SizeOfDimension(output_layer_norm_weights, 0), n_cell);
    270     }
    271 
    272     if (params->use_cifg) {
    273         NN_RET_CHECK(IsNullInput(input_layer_norm_weights))
    274                 << "input_layer_norm_weights are provided while CIFG is used";
    275         const bool layer_norm_weights_all_or_none_cifg =
    276                 (IsNullInput(forget_layer_norm_weights) && IsNullInput(cell_layer_norm_weights) &&
    277                  IsNullInput(output_layer_norm_weights)) ||
    278                 (!IsNullInput(forget_layer_norm_weights) && !IsNullInput(cell_layer_norm_weights) &&
    279                  !IsNullInput(output_layer_norm_weights));
    280         NN_RET_CHECK(layer_norm_weights_all_or_none_cifg);
    281     } else {
    282         const bool layer_norm_weights_all_or_none =
    283                 (IsNullInput(input_layer_norm_weights) && IsNullInput(forget_layer_norm_weights) &&
    284                  IsNullInput(cell_layer_norm_weights) && IsNullInput(output_layer_norm_weights)) ||
    285                 (!IsNullInput(input_layer_norm_weights) &&
    286                  !IsNullInput(forget_layer_norm_weights) && !IsNullInput(cell_layer_norm_weights) &&
    287                  !IsNullInput(output_layer_norm_weights));
    288         NN_RET_CHECK(layer_norm_weights_all_or_none);
    289     }
    290 
    291     return true;
    292 }
    293 
    294 bool LSTMCell::Prepare(const Operation& operation, std::vector<RunTimeOperandInfo>& operands,
    295                        Shape* scratchShape, Shape* outputStateShape, Shape* cellStateShape,
    296                        Shape* outputShape) {
    297     // Check we have all the inputs and outputs we need.
    298     NN_CHECK(NumInputsWithValues(operation, operands) >= 15 &&
    299              NumInputsWithValues(operation, operands) <= 27);
    300     NN_CHECK_EQ(NumOutputs(operation), 4);
    301 
    302     // Inferring batch size, number of outputs and number of cells from the
    303     // input tensors.
    304     NN_CHECK(NumDimensions(input_) > 1);
    305     const uint32_t n_batch = SizeOfDimension(input_, 0);
    306     const uint32_t n_input = SizeOfDimension(input_, 1);
    307 
    308     const uint32_t n_cell = SizeOfDimension(input_to_output_weights_, 0);
    309     NN_CHECK_EQ(NumDimensions(input_to_output_weights_), 2);
    310     NN_CHECK_EQ(SizeOfDimension(input_to_output_weights_, 1), n_input);
    311 
    312     NN_CHECK_EQ(NumDimensions(recurrent_to_output_weights_), 2);
    313     NN_CHECK_EQ(SizeOfDimension(recurrent_to_output_weights_, 0), n_cell);
    314     const uint32_t n_output = SizeOfDimension(recurrent_to_output_weights_, 1);
    315 
    316     // Check that input tensor dimensions matches with each other.
    317     if (!CheckInputTensorDimensions(
    318                 input_, input_to_input_weights_, input_to_forget_weights_, input_to_cell_weights_,
    319                 input_to_output_weights_, recurrent_to_input_weights_, recurrent_to_forget_weights_,
    320                 recurrent_to_cell_weights_, recurrent_to_output_weights_, cell_to_input_weights_,
    321                 cell_to_forget_weights_, cell_to_output_weights_, input_gate_bias_,
    322                 forget_gate_bias_, cell_bias_, output_gate_bias_, projection_weights_,
    323                 projection_bias_, input_layer_norm_weights_, forget_layer_norm_weights_,
    324                 cell_layer_norm_weights_, output_layer_norm_weights_, n_input, n_output, n_cell,
    325                 &params_)) {
    326         return false;
    327     }
    328 
    329     // Resize the output and output_state tensors.
    330     const Shape& inputShape = input_->shape();
    331 
    332     outputShape->type = inputShape.type;
    333     outputShape->dimensions = {n_batch, n_output};
    334     outputShape->offset = inputShape.offset;
    335     outputShape->scale = inputShape.scale;
    336 
    337     outputStateShape->type = inputShape.type;
    338     outputStateShape->dimensions = {n_batch, n_output};
    339     outputStateShape->offset = inputShape.offset;
    340     outputStateShape->scale = inputShape.scale;
    341 
    342     cellStateShape->type = inputShape.type;
    343     cellStateShape->dimensions = {n_batch, n_cell};
    344     cellStateShape->offset = inputShape.offset;
    345     cellStateShape->scale = inputShape.scale;
    346 
    347     if (params_.use_cifg) {
    348         // Reserving space for Cell, Forget, Output gates
    349         scratchShape->dimensions = {n_batch, n_cell * 3};
    350     } else {
    351         // Reserving space for Input, Cell, Forget, Output gates
    352         scratchShape->dimensions = {n_batch, n_cell * 4};
    353     }
    354     scratchShape->type = inputShape.type;
    355     scratchShape->offset = inputShape.offset;
    356     scratchShape->scale = inputShape.scale;
    357 
    358     return true;
    359 }
    360 
    361 // static
    362 bool LSTMCell::LSTMEvalFloat32(
    363         const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
    364         const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
    365         const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
    366         const Shape& input_to_output_weights_shape, const float* recurrent_to_input_weights_buffer,
    367         const float* recurrent_to_forget_weights_buffer,
    368         const float* recurrent_to_cell_weights_buffer,
    369         const float* recurrent_to_output_weights_buffer,
    370         const Shape& recurrent_to_output_weights_shape, const float* cell_to_input_weights_buffer,
    371         const float* cell_to_forget_weights_buffer, const float* cell_to_output_weights_buffer,
    372         const float* aux_input_buffer, const float* aux_input_to_input_weights_buffer,
    373         const float* aux_input_to_forget_weights_buffer,
    374         const float* aux_input_to_cell_weights_buffer,
    375         const float* aux_input_to_output_weights_buffer, const float* input_gate_bias_buffer,
    376         const float* forget_gate_bias_buffer, const float* cell_bias_buffer,
    377         const float* output_gate_bias_buffer, const float* projection_weights_buffer,
    378         const float* projection_bias_buffer, const float* output_state_in_buffer,
    379         const float* cell_state_in_buffer, const float* input_layer_norm_weights_buffer,
    380         const float* forget_layer_norm_weights_buffer, const float* cell_layer_norm_weights_buffer,
    381         const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
    382         float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer,
    383         bool timeMajor, bool forwardSequence) {
    384     NNTRACE_COMP("LSTMCell::LSTMEvalFloat32");
    385 
    386     const uint32_t inputRank = getNumberOfDimensions(input_shape);
    387     NN_CHECK(inputRank == 2 || inputRank == 3);
    388 
    389     const uint32_t maxTime =
    390             (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 0 : 1) : 1;
    391     const uint32_t batchSize = (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 1 : 0)
    392                                                 : getSizeOfDimension(input_shape, 0);
    393     const uint32_t inputSize = getSizeOfDimension(input_shape, inputRank - 1);
    394     const uint32_t numCells = getSizeOfDimension(input_to_output_weights_shape, 0);
    395     const uint32_t outputSize = getSizeOfDimension(recurrent_to_output_weights_shape, 1);
    396 
    397     Shape batchInputShape = input_shape;
    398     batchInputShape.dimensions = {batchSize, inputSize};
    399     const uint32_t batchInputSize = batchSize * inputSize;
    400     const uint32_t batchOutputSize = batchSize * outputSize;
    401 
    402     std::vector<float> transposedInput;
    403     const bool hasAuxInput = (aux_input_buffer != nullptr);
    404     std::vector<float> transposedAuxInput;
    405     std::vector<float> transposedOutput;
    406     Shape transposedInputShape;
    407     Shape transposedOutputShape;
    408     if (!timeMajor) {
    409         transposedInput.resize(maxTime * batchInputSize);
    410         transposeFirstTwoDimensions<float>(input_buffer, input_shape, transposedInput.data());
    411         if (hasAuxInput) {
    412             transposedAuxInput.resize(maxTime * batchInputSize);
    413             transposeFirstTwoDimensions<float>(aux_input_buffer, input_shape,
    414                                                transposedAuxInput.data());
    415         }
    416         transposeFirstTwoDimensions(input_shape, &transposedInputShape);
    417         transposedOutput.resize(maxTime * batchOutputSize);
    418         transposedOutputShape = transposedInputShape;
    419         transposedOutputShape.dimensions[2] = outputSize;
    420     }
    421     const float* inputData = timeMajor ? input_buffer : transposedInput.data();
    422     const float* auxInputData =
    423             hasAuxInput ? (timeMajor ? aux_input_buffer : transposedAuxInput.data()) : nullptr;
    424     float* outputData = timeMajor ? output_buffer : transposedOutput.data();
    425 
    426     std::vector<float> outputStateInCurrentTimeStep(
    427             output_state_in_buffer, output_state_in_buffer + batchSize * outputSize);
    428     std::vector<float> cellStateInCurrentTimeStep(cell_state_in_buffer,
    429                                                   cell_state_in_buffer + batchSize * numCells);
    430     const float* inputCurrentTimeStep =
    431             inputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1));
    432     const float* auxInputCurrentTimeStep =
    433             hasAuxInput ? (auxInputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1)))
    434                         : nullptr;
    435     float* outputCurrentTimeStep =
    436             outputData + (forwardSequence ? 0 : batchOutputSize * (maxTime - 1));
    437     const int batchInputDelta = forwardSequence ? batchInputSize : -batchInputSize;
    438     const int batchOutputDelta = forwardSequence ? batchOutputSize : -batchOutputSize;
    439 
    440     for (int t = 0; t < maxTime; ++t) {
    441         LSTMStep(params, inputCurrentTimeStep, batchInputShape, input_to_input_weights_buffer,
    442                  input_to_forget_weights_buffer, input_to_cell_weights_buffer,
    443                  input_to_output_weights_buffer, input_to_output_weights_shape,
    444                  recurrent_to_input_weights_buffer, recurrent_to_forget_weights_buffer,
    445                  recurrent_to_cell_weights_buffer, recurrent_to_output_weights_buffer,
    446                  recurrent_to_output_weights_shape, cell_to_input_weights_buffer,
    447                  cell_to_forget_weights_buffer, cell_to_output_weights_buffer,
    448                  auxInputCurrentTimeStep, aux_input_to_input_weights_buffer,
    449                  aux_input_to_forget_weights_buffer, aux_input_to_cell_weights_buffer,
    450                  aux_input_to_output_weights_buffer, input_gate_bias_buffer,
    451                  forget_gate_bias_buffer, cell_bias_buffer, output_gate_bias_buffer,
    452                  projection_weights_buffer, projection_bias_buffer,
    453                  outputStateInCurrentTimeStep.data(), cellStateInCurrentTimeStep.data(),
    454                  input_layer_norm_weights_buffer, forget_layer_norm_weights_buffer,
    455                  cell_layer_norm_weights_buffer, output_layer_norm_weights_buffer,
    456                  output_state_out_buffer, cell_state_out_buffer, outputCurrentTimeStep,
    457                  scratch_buffer_buffer);
    458         inputCurrentTimeStep += batchInputDelta;
    459         if (hasAuxInput) {
    460             auxInputCurrentTimeStep += batchInputDelta;
    461         }
    462         outputCurrentTimeStep += batchOutputDelta;
    463         outputStateInCurrentTimeStep.assign(output_state_out_buffer,
    464                                             output_state_out_buffer + batchSize * outputSize);
    465         cellStateInCurrentTimeStep.assign(cell_state_out_buffer,
    466                                           cell_state_out_buffer + batchSize * numCells);
    467     }
    468 
    469     if (!timeMajor) {
    470         transposeFirstTwoDimensions<float>(transposedOutput.data(), transposedOutputShape,
    471                                            output_buffer);
    472     }
    473 
    474     return true;
    475 }
    476 
    477 // static
    478 bool LSTMCell::LSTMEvalFloat16(
    479         const LSTMParams& params, const _Float16* input_buffer, const Shape& input_shape,
    480         const _Float16* input_to_input_weights_buffer,
    481         const _Float16* input_to_forget_weights_buffer,
    482         const _Float16* input_to_cell_weights_buffer,
    483         const _Float16* input_to_output_weights_buffer, const Shape& input_to_output_weights_shape,
    484         const _Float16* recurrent_to_input_weights_buffer,
    485         const _Float16* recurrent_to_forget_weights_buffer,
    486         const _Float16* recurrent_to_cell_weights_buffer,
    487         const _Float16* recurrent_to_output_weights_buffer,
    488         const Shape& recurrent_to_output_weights_shape,
    489         const _Float16* cell_to_input_weights_buffer, const _Float16* cell_to_forget_weights_buffer,
    490         const _Float16* cell_to_output_weights_buffer, const _Float16* aux_input_buffer,
    491         const _Float16* aux_input_to_input_weights_buffer,
    492         const _Float16* aux_input_to_forget_weights_buffer,
    493         const _Float16* aux_input_to_cell_weights_buffer,
    494         const _Float16* aux_input_to_output_weights_buffer, const _Float16* input_gate_bias_buffer,
    495         const _Float16* forget_gate_bias_buffer, const _Float16* cell_bias_buffer,
    496         const _Float16* output_gate_bias_buffer, const _Float16* projection_weights_buffer,
    497         const _Float16* projection_bias_buffer, const _Float16* output_state_in_buffer,
    498         const _Float16* cell_state_in_buffer, const _Float16* input_layer_norm_weights_buffer,
    499         const _Float16* forget_layer_norm_weights_buffer,
    500         const _Float16* cell_layer_norm_weights_buffer,
    501         const _Float16* output_layer_norm_weights_buffer, _Float16* output_state_out_buffer,
    502         _Float16* cell_state_out_buffer, _Float16* output_buffer, _Float16* scratch_buffer_buffer,
    503         bool timeMajor, bool forwardSequence) {
    504     NNTRACE_COMP("LSTMCell::LSTMEvalFloat16");
    505 
    506     const uint32_t inputRank = getNumberOfDimensions(input_shape);
    507     NN_CHECK(inputRank == 2 || inputRank == 3);
    508 
    509     const uint32_t maxTime =
    510             (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 0 : 1) : 1;
    511     const uint32_t batchSize = (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 1 : 0)
    512                                                 : getSizeOfDimension(input_shape, 0);
    513     const uint32_t inputSize = getSizeOfDimension(input_shape, inputRank - 1);
    514     const uint32_t numCells = getSizeOfDimension(input_to_output_weights_shape, 0);
    515     const uint32_t outputSize = getSizeOfDimension(recurrent_to_output_weights_shape, 1);
    516 
    517     Shape batchInputShape = input_shape;
    518     batchInputShape.dimensions = {batchSize, inputSize};
    519     const uint32_t batchInputSize = batchSize * inputSize;
    520     const uint32_t batchOutputSize = batchSize * outputSize;
    521 
    522     std::vector<float> input_float32(maxTime * batchInputSize);
    523     convertFloat16ToFloat32(input_buffer, &input_float32);
    524     std::vector<float> input_to_input_weights_float32(numCells * inputSize);
    525     if (input_to_input_weights_buffer != nullptr) {
    526         convertFloat16ToFloat32(input_to_input_weights_buffer, &input_to_input_weights_float32);
    527     }
    528     std::vector<float> input_to_forget_weights_float32(numCells * inputSize);
    529     convertFloat16ToFloat32(input_to_forget_weights_buffer, &input_to_forget_weights_float32);
    530     std::vector<float> input_to_cell_weights_float32(numCells * inputSize);
    531     convertFloat16ToFloat32(input_to_cell_weights_buffer, &input_to_cell_weights_float32);
    532     std::vector<float> input_to_output_weights_float32(numCells * inputSize);
    533     convertFloat16ToFloat32(input_to_output_weights_buffer, &input_to_output_weights_float32);
    534 
    535     std::vector<float> recurrent_to_input_weights_float32(numCells * outputSize);
    536     if (recurrent_to_input_weights_buffer != nullptr) {
    537         convertFloat16ToFloat32(recurrent_to_input_weights_buffer,
    538                                 &recurrent_to_input_weights_float32);
    539     }
    540     std::vector<float> recurrent_to_forget_weights_float32(numCells * outputSize);
    541     convertFloat16ToFloat32(recurrent_to_forget_weights_buffer,
    542                             &recurrent_to_forget_weights_float32);
    543     std::vector<float> recurrent_to_cell_weights_float32(numCells * outputSize);
    544     convertFloat16ToFloat32(recurrent_to_cell_weights_buffer, &recurrent_to_cell_weights_float32);
    545     std::vector<float> recurrent_to_output_weights_float32(numCells * outputSize);
    546     convertFloat16ToFloat32(recurrent_to_output_weights_buffer,
    547                             &recurrent_to_output_weights_float32);
    548 
    549     std::vector<float> cell_to_input_weights_float32(numCells);
    550     if (cell_to_input_weights_buffer != nullptr) {
    551         convertFloat16ToFloat32(cell_to_input_weights_buffer, &cell_to_input_weights_float32);
    552     }
    553     std::vector<float> cell_to_forget_weights_float32(numCells);
    554     if (cell_to_forget_weights_buffer != nullptr) {
    555         convertFloat16ToFloat32(cell_to_forget_weights_buffer, &cell_to_forget_weights_float32);
    556     }
    557     std::vector<float> cell_to_output_weights_float32(numCells);
    558     if (cell_to_output_weights_buffer != nullptr) {
    559         convertFloat16ToFloat32(cell_to_output_weights_buffer, &cell_to_output_weights_float32);
    560     }
    561 
    562     std::vector<float> aux_input_float32(maxTime * batchInputSize);
    563     if (aux_input_buffer != nullptr) {
    564         convertFloat16ToFloat32(aux_input_buffer, &aux_input_float32);
    565     }
    566     std::vector<float> aux_input_to_input_weights_float32(numCells * inputSize);
    567     if (aux_input_to_input_weights_buffer != nullptr) {
    568         convertFloat16ToFloat32(aux_input_to_input_weights_buffer,
    569                                 &aux_input_to_input_weights_float32);
    570     }
    571     std::vector<float> aux_input_to_forget_weights_float32(numCells * inputSize);
    572     if (aux_input_to_forget_weights_buffer != nullptr) {
    573         convertFloat16ToFloat32(aux_input_to_forget_weights_buffer,
    574                                 &aux_input_to_forget_weights_float32);
    575     }
    576     std::vector<float> aux_input_to_cell_weights_float32(numCells * inputSize);
    577     if (aux_input_to_cell_weights_buffer != nullptr) {
    578         convertFloat16ToFloat32(aux_input_to_cell_weights_buffer,
    579                                 &aux_input_to_cell_weights_float32);
    580     }
    581     std::vector<float> aux_input_to_output_weights_float32(numCells * inputSize);
    582     if (aux_input_to_output_weights_buffer != nullptr) {
    583         convertFloat16ToFloat32(aux_input_to_output_weights_buffer,
    584                                 &aux_input_to_output_weights_float32);
    585     }
    586 
    587     std::vector<float> input_gate_bias_float32(numCells);
    588     if (input_gate_bias_buffer != nullptr) {
    589         convertFloat16ToFloat32(input_gate_bias_buffer, &input_gate_bias_float32);
    590     }
    591     std::vector<float> forget_gate_bias_float32(numCells);
    592     convertFloat16ToFloat32(forget_gate_bias_buffer, &forget_gate_bias_float32);
    593     std::vector<float> cell_bias_float32(numCells);
    594     convertFloat16ToFloat32(cell_bias_buffer, &cell_bias_float32);
    595     std::vector<float> output_gate_bias_float32(numCells);
    596     convertFloat16ToFloat32(output_gate_bias_buffer, &output_gate_bias_float32);
    597 
    598     std::vector<float> projection_weights_float32(numCells * outputSize);
    599     if (projection_weights_buffer != nullptr) {
    600         convertFloat16ToFloat32(projection_weights_buffer, &projection_weights_float32);
    601     }
    602     std::vector<float> projection_bias_float32(outputSize);
    603     if (projection_bias_buffer != nullptr) {
    604         convertFloat16ToFloat32(projection_bias_buffer, &projection_bias_float32);
    605     }
    606 
    607     std::vector<float> input_layer_norm_weights_float32(numCells);
    608     if (input_layer_norm_weights_buffer != nullptr) {
    609         convertFloat16ToFloat32(input_layer_norm_weights_buffer, &input_layer_norm_weights_float32);
    610     }
    611     std::vector<float> forget_layer_norm_weights_float32(numCells);
    612     if (forget_layer_norm_weights_buffer != nullptr) {
    613         convertFloat16ToFloat32(forget_layer_norm_weights_buffer,
    614                                 &forget_layer_norm_weights_float32);
    615     }
    616     std::vector<float> cell_layer_norm_weights_float32(numCells);
    617     if (cell_layer_norm_weights_buffer != nullptr) {
    618         convertFloat16ToFloat32(cell_layer_norm_weights_buffer, &cell_layer_norm_weights_float32);
    619     }
    620     std::vector<float> output_layer_norm_weights_float32(numCells);
    621     if (output_layer_norm_weights_buffer != nullptr) {
    622         convertFloat16ToFloat32(output_layer_norm_weights_buffer,
    623                                 &output_layer_norm_weights_float32);
    624     }
    625 
    626     std::vector<float> output_state_out_float32(batchOutputSize);
    627     convertFloat16ToFloat32(output_state_out_buffer, &output_state_out_float32);
    628     std::vector<float> cell_state_out_float32(batchSize * numCells);
    629     convertFloat16ToFloat32(cell_state_out_buffer, &cell_state_out_float32);
    630 
    631     std::vector<float> output_float32(maxTime * batchOutputSize);
    632     convertFloat16ToFloat32(output_buffer, &output_float32);
    633     std::vector<float> scratch_buffer_float32(params.use_cifg ? 3 * batchSize * numCells
    634                                                               : 4 * batchSize * numCells);
    635     convertFloat16ToFloat32(scratch_buffer_buffer, &scratch_buffer_float32);
    636 
    637     std::vector<float> transposedInput;
    638     const bool hasAuxInput = (aux_input_buffer != nullptr);
    639     std::vector<float> transposedAuxInput;
    640     std::vector<float> transposedOutput;
    641     Shape transposedInputShape;
    642     Shape transposedOutputShape;
    643     if (!timeMajor) {
    644         transposedInput.resize(maxTime * batchInputSize);
    645         transposeFirstTwoDimensions<float>(input_float32.data(), input_shape,
    646                                            transposedInput.data());
    647         if (hasAuxInput) {
    648             transposedAuxInput.resize(maxTime * batchInputSize);
    649             transposeFirstTwoDimensions<float>(aux_input_float32.data(), input_shape,
    650                                                transposedAuxInput.data());
    651         }
    652         transposeFirstTwoDimensions(input_shape, &transposedInputShape);
    653         transposedOutput.resize(maxTime * batchOutputSize);
    654         transposedOutputShape = transposedInputShape;
    655         transposedOutputShape.dimensions[2] = outputSize;
    656     }
    657     const float* inputData = timeMajor ? input_float32.data() : transposedInput.data();
    658     const float* auxInputData =
    659             hasAuxInput ? (timeMajor ? aux_input_float32.data() : transposedAuxInput.data())
    660                         : nullptr;
    661     float* outputData = timeMajor ? output_float32.data() : transposedOutput.data();
    662 
    663     std::vector<float> outputStateInCurrentTimeStep(batchSize * outputSize);
    664     convertFloat16ToFloat32(output_state_in_buffer, &outputStateInCurrentTimeStep);
    665     std::vector<float> cellStateInCurrentTimeStep(batchSize * numCells);
    666     convertFloat16ToFloat32(cell_state_in_buffer, &cellStateInCurrentTimeStep);
    667 
    668     const float* inputCurrentTimeStep =
    669             inputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1));
    670     const float* auxInputCurrentTimeStep =
    671             hasAuxInput ? (auxInputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1)))
    672                         : nullptr;
    673     float* outputCurrentTimeStep =
    674             outputData + (forwardSequence ? 0 : batchOutputSize * (maxTime - 1));
    675     const int batchInputDelta = forwardSequence ? batchInputSize : -batchInputSize;
    676     const int batchOutputDelta = forwardSequence ? batchOutputSize : -batchOutputSize;
    677 
    678     for (int t = 0; t < maxTime; ++t) {
    679         LSTMStep(params, inputCurrentTimeStep, batchInputShape,
    680                  input_to_input_weights_float32.data(), input_to_forget_weights_float32.data(),
    681                  input_to_cell_weights_float32.data(), input_to_output_weights_float32.data(),
    682                  input_to_output_weights_shape, recurrent_to_input_weights_float32.data(),
    683                  recurrent_to_forget_weights_float32.data(),
    684                  recurrent_to_cell_weights_float32.data(),
    685                  recurrent_to_output_weights_float32.data(), recurrent_to_output_weights_shape,
    686                  cell_to_input_weights_float32.data(), cell_to_forget_weights_float32.data(),
    687                  cell_to_output_weights_float32.data(), auxInputCurrentTimeStep,
    688                  aux_input_to_input_weights_float32.data(),
    689                  aux_input_to_forget_weights_float32.data(),
    690                  aux_input_to_cell_weights_float32.data(),
    691                  aux_input_to_output_weights_float32.data(), input_gate_bias_float32.data(),
    692                  forget_gate_bias_float32.data(), cell_bias_float32.data(),
    693                  output_gate_bias_float32.data(), projection_weights_float32.data(),
    694                  projection_bias_float32.data(), outputStateInCurrentTimeStep.data(),
    695                  cellStateInCurrentTimeStep.data(), input_layer_norm_weights_float32.data(),
    696                  forget_layer_norm_weights_float32.data(), cell_layer_norm_weights_float32.data(),
    697                  output_layer_norm_weights_float32.data(), output_state_out_float32.data(),
    698                  cell_state_out_float32.data(), outputCurrentTimeStep,
    699                  scratch_buffer_float32.data());
    700         inputCurrentTimeStep += batchInputDelta;
    701         if (hasAuxInput) {
    702             auxInputCurrentTimeStep += batchInputDelta;
    703         }
    704         outputCurrentTimeStep += batchOutputDelta;
    705         outputStateInCurrentTimeStep = output_state_out_float32;
    706         cellStateInCurrentTimeStep = cell_state_out_float32;
    707     }
    708 
    709     if (!timeMajor) {
    710         transposeFirstTwoDimensions<float>(transposedOutput.data(), transposedOutputShape,
    711                                            output_float32.data());
    712     }
    713 
    714     convertFloat32ToFloat16(output_state_out_float32, output_state_out_buffer);
    715     convertFloat32ToFloat16(cell_state_out_float32, cell_state_out_buffer);
    716     convertFloat32ToFloat16(output_float32, output_buffer);
    717     convertFloat32ToFloat16(scratch_buffer_float32, scratch_buffer_buffer);
    718     return true;
    719 }
    720 
// static
//
// Computes a single LSTM time step for a whole batch.
//
// Dimensions (derived from the shape arguments below):
//   n_batch  = input_shape.dimensions[0]
//   n_input  = input_shape.dimensions[1]
//   n_cell   = input_to_output_weights_shape.dimensions[0]
//   n_output = recurrent_to_output_weights_shape.dimensions[1]
//
// Optional LSTM variants are selected via `params` (use_cifg, use_peephole,
// use_layer_norm, use_projection_weight, use_projection_bias) together with
// the corresponding buffers; buffers for unused variants may be null and are
// never dereferenced (each use is guarded by the matching `params` flag, and
// `aux_input_*` buffers are only read when aux_input_buffer != nullptr).
//
// Inputs:
//   input_buffer:           [n_batch, n_input] input for this time step.
//   aux_input_buffer:       optional auxiliary input, same shape as input.
//   output_state_in_buffer: [n_batch, n_output] previous output state.
//   cell_state_in_buffer:   [n_batch, n_cell] previous cell state.
// Outputs:
//   cell_state_out_buffer:   [n_batch, n_cell] updated cell state.
//   output_state_out_buffer: [n_batch, n_output] copy of the step output.
//   output_buffer:           [n_batch, n_output] step output.
//   scratch_buffer_buffer:   gate workspace, 3 (CIFG) or 4 slabs of
//                            n_batch * n_cell floats.
// Returns true unconditionally (there are no failure paths).
bool LSTMCell::LSTMStep(
        const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
        const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
        const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
        const Shape& input_to_output_weights_shape, const float* recurrent_to_input_weights_buffer,
        const float* recurrent_to_forget_weights_buffer,
        const float* recurrent_to_cell_weights_buffer,
        const float* recurrent_to_output_weights_buffer,
        const Shape& recurrent_to_output_weights_shape, const float* cell_to_input_weights_buffer,
        const float* cell_to_forget_weights_buffer, const float* cell_to_output_weights_buffer,
        const float* aux_input_buffer, const float* aux_input_to_input_weights_buffer,
        const float* aux_input_to_forget_weights_buffer,
        const float* aux_input_to_cell_weights_buffer,
        const float* aux_input_to_output_weights_buffer, const float* input_gate_bias_buffer,
        const float* forget_gate_bias_buffer, const float* cell_bias_buffer,
        const float* output_gate_bias_buffer, const float* projection_weights_buffer,
        const float* projection_bias_buffer, const float* output_state_in_buffer,
        const float* cell_state_in_buffer, const float* input_layer_norm_weights_buffer,
        const float* forget_layer_norm_weights_buffer, const float* cell_layer_norm_weights_buffer,
        const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
        float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer) {
    NNTRACE_COMP("LSTMCell::LSTMStep");

    const uint32_t n_batch = input_shape.dimensions[0];
    const uint32_t n_input = input_shape.dimensions[1];
    // n_cell and n_output will be the same size when there is no projection.
    const uint32_t n_cell = input_to_output_weights_shape.dimensions[0];
    const uint32_t n_output = recurrent_to_output_weights_shape.dimensions[1];
    const uint32_t n_aux_input = aux_input_buffer == nullptr ? 0 : n_input;

    // Index the scratch buffers pointers to the global scratch buffer.
    // With CIFG the input gate is derived from the forget gate (computed as
    // 1 - forget below), so no scratch slab is reserved for it.
    float* input_gate_scratch = nullptr;
    float* cell_scratch = nullptr;
    float* forget_gate_scratch = nullptr;
    float* output_gate_scratch = nullptr;
    if (params.use_cifg) {
        cell_scratch = scratch_buffer_buffer;
        forget_gate_scratch = cell_scratch + n_cell * n_batch;
        output_gate_scratch = cell_scratch + 2 * n_cell * n_batch;
    } else {
        input_gate_scratch = scratch_buffer_buffer;
        cell_scratch = input_gate_scratch + n_cell * n_batch;
        forget_gate_scratch = input_gate_scratch + 2 * n_cell * n_batch;
        output_gate_scratch = input_gate_scratch + 3 * n_cell * n_batch;
    }

    if (!params.use_layer_norm) {
        // Initialize scratch buffers with bias.
        if (!params.use_cifg) {
            tflite::tensor_utils::VectorBatchVectorAssign(input_gate_bias_buffer, n_cell, n_batch,
                                                          input_gate_scratch);
        }
        tflite::tensor_utils::VectorBatchVectorAssign(forget_gate_bias_buffer, n_cell, n_batch,
                                                      forget_gate_scratch);
        tflite::tensor_utils::VectorBatchVectorAssign(cell_bias_buffer, n_cell, n_batch,
                                                      cell_scratch);
        tflite::tensor_utils::VectorBatchVectorAssign(output_gate_bias_buffer, n_cell, n_batch,
                                                      output_gate_scratch);
    } else {
        // Initialize scratch buffers with zeroes. With layer normalization the
        // biases are added after normalization (see the VectorBatchVectorAdd
        // calls in each gate update below), so they must not be pre-loaded.
        if (!params.use_cifg) {
            tflite::tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch);
        }
        tflite::tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch);
        tflite::tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch);
        tflite::tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch);
    }

    // For each batch and cell: compute input_weight * input.
    if (!params.use_cifg) {
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                input_to_input_weights_buffer, n_cell, n_input, input_buffer, n_batch,
                input_gate_scratch, /*result_stride*/ 1);
    }
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            input_to_forget_weights_buffer, n_cell, n_input, input_buffer, n_batch,
            forget_gate_scratch, /*result_stride*/ 1);
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_cell_weights_buffer, n_cell,
                                                              n_input, input_buffer, n_batch,
                                                              cell_scratch, /*result_stride*/ 1);
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            input_to_output_weights_buffer, n_cell, n_input, input_buffer, n_batch,
            output_gate_scratch, /*result_stride*/ 1);

    // If auxiliary input is available then compute aux_input_weight * aux_input
    if (aux_input_buffer != nullptr) {
        if (!params.use_cifg) {
            tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                    aux_input_to_input_weights_buffer, n_cell, n_aux_input, aux_input_buffer,
                    n_batch, input_gate_scratch,
                    /*result_stride=*/1);
        }

        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                aux_input_to_forget_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch,
                forget_gate_scratch, /*result_stride=*/1);
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                aux_input_to_cell_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch,
                cell_scratch, /*result_stride=*/1);
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                aux_input_to_output_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch,
                output_gate_scratch, /*result_stride=*/1);
    }

    // For each batch and cell: compute recurrent_weight * output_state.
    if (!params.use_cifg) {
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                recurrent_to_input_weights_buffer, n_cell, n_output, output_state_in_buffer,
                n_batch, input_gate_scratch,
                /*result_stride*/ 1);
    }
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            recurrent_to_forget_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch,
            forget_gate_scratch, /*result_stride*/ 1);
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            recurrent_to_cell_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch,
            cell_scratch, /*result_stride*/ 1);
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            recurrent_to_output_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch,
            output_gate_scratch, /*result_stride*/ 1);

    // For each batch and cell: update input gate. Optional peephole adds
    // cell_to_input_weights * cell_state_in; optional layer norm normalizes,
    // scales, then adds the bias; finally the gate is squashed with a sigmoid.
    if (!params.use_cifg) {
        if (params.use_peephole) {
            tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(
                    cell_to_input_weights_buffer, n_cell, cell_state_in_buffer, n_batch,
                    input_gate_scratch);
        }
        if (params.use_layer_norm) {
            tflite::tensor_utils::MeanStddevNormalization(input_gate_scratch, input_gate_scratch,
                                                          n_cell, n_batch, kLayerNormEpsilon);
            tflite::tensor_utils::VectorBatchVectorCwiseProduct(input_layer_norm_weights_buffer,
                                                                n_cell, input_gate_scratch, n_batch,
                                                                input_gate_scratch);
            tflite::tensor_utils::VectorBatchVectorAdd(input_gate_bias_buffer, n_cell, n_batch,
                                                       input_gate_scratch);
        }
        tflite::tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
                                                   input_gate_scratch);
    }

    // For each batch and cell: update forget gate.
    if (params.use_peephole) {
        tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(cell_to_forget_weights_buffer,
                                                                      n_cell, cell_state_in_buffer,
                                                                      n_batch, forget_gate_scratch);
    }
    if (params.use_layer_norm) {
        tflite::tensor_utils::MeanStddevNormalization(forget_gate_scratch, forget_gate_scratch,
                                                      n_cell, n_batch, kLayerNormEpsilon);
        tflite::tensor_utils::VectorBatchVectorCwiseProduct(forget_layer_norm_weights_buffer,
                                                            n_cell, forget_gate_scratch, n_batch,
                                                            forget_gate_scratch);
        tflite::tensor_utils::VectorBatchVectorAdd(forget_gate_bias_buffer, n_cell, n_batch,
                                                   forget_gate_scratch);
    }
    tflite::tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
                                               forget_gate_scratch);

    // For each batch and cell: update the cell.
    // cell_state_out = forget_gate .* cell_state_in
    //                + gate .* activation(cell candidate),
    // where gate is (1 - forget_gate) for CIFG, else the input gate.
    if (params.use_layer_norm) {
        tflite::tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell, n_batch,
                                                      kLayerNormEpsilon);
        tflite::tensor_utils::VectorBatchVectorCwiseProduct(cell_layer_norm_weights_buffer, n_cell,
                                                            cell_scratch, n_batch, cell_scratch);
        tflite::tensor_utils::VectorBatchVectorAdd(cell_bias_buffer, n_cell, n_batch, cell_scratch);
    }
    tflite::tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_in_buffer,
                                                   n_batch * n_cell, cell_state_out_buffer);
    tflite::tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, params.activation,
                                                  cell_scratch);
    if (params.use_cifg) {
        // CIFG: input gate is coupled to the forget gate as (1 - forget_gate).
        tflite::tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
                                         forget_gate_scratch);
        tflite::tensor_utils::VectorVectorCwiseProductAccumulate(
                cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_out_buffer);
    } else {
        tflite::tensor_utils::VectorVectorCwiseProductAccumulate(
                cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_out_buffer);
    }
    if (params.cell_clip > 0.0) {
        tflite::tensor_utils::ClipVector(cell_state_out_buffer, n_batch * n_cell, params.cell_clip,
                                         cell_state_out_buffer);
    }

    // For each batch and cell: update the output gate. Note the peephole here
    // uses the freshly computed cell_state_out, not cell_state_in.
    if (params.use_peephole) {
        tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(cell_to_output_weights_buffer,
                                                                      n_cell, cell_state_out_buffer,
                                                                      n_batch, output_gate_scratch);
    }
    if (params.use_layer_norm) {
        tflite::tensor_utils::MeanStddevNormalization(output_gate_scratch, output_gate_scratch,
                                                      n_cell, n_batch, kLayerNormEpsilon);
        tflite::tensor_utils::VectorBatchVectorCwiseProduct(output_layer_norm_weights_buffer,
                                                            n_cell, output_gate_scratch, n_batch,
                                                            output_gate_scratch);
        tflite::tensor_utils::VectorBatchVectorAdd(output_gate_bias_buffer, n_cell, n_batch,
                                                   output_gate_scratch);
    }
    tflite::tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
                                               output_gate_scratch);
    // cell_scratch is reused here to hold activation(cell_state_out);
    // output_gate_scratch then holds output_gate .* activation(cell_state_out).
    tflite::tensor_utils::ApplyActivationToVector(cell_state_out_buffer, n_batch * n_cell,
                                                  params.activation, cell_scratch);
    tflite::tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
                                                   n_batch * n_cell, output_gate_scratch);

    // For each batch: update the projection and output_state.
    if (params.use_projection_weight) {
        if (params.use_projection_bias) {
            tflite::tensor_utils::VectorBatchVectorAssign(projection_bias_buffer, n_output, n_batch,
                                                          output_buffer);
        } else {
            tflite::tensor_utils::ZeroVector(output_buffer, n_batch * n_output);
        }
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                projection_weights_buffer, n_output, n_cell, output_gate_scratch, n_batch,
                output_buffer,
                /*result_stride*/ 1);
        if (params.proj_clip > 0.0) {
            tflite::tensor_utils::ClipVector(output_buffer, n_batch * n_output, params.proj_clip,
                                             output_buffer);
        }
    } else {
        tflite::tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output, output_buffer);
    }
    // The output also becomes the next step's output state.
    tflite::tensor_utils::CopyVector(output_buffer, n_batch * n_output, output_state_out_buffer);
    return true;
}
    951 
    952 bool LSTMCell::Eval() {
    953     switch (input_->type) {
    954         case OperandType::TENSOR_FLOAT32: {
    955             LSTMEvalFloat32(params_, GetBuffer<const float>(input_), input_->shape(),
    956                             GetBuffer<const float>(input_to_input_weights_),
    957                             GetBuffer<const float>(input_to_forget_weights_),
    958                             GetBuffer<const float>(input_to_cell_weights_),
    959                             GetBuffer<const float>(input_to_output_weights_),
    960                             input_to_output_weights_->shape(),
    961                             GetBuffer<const float>(recurrent_to_input_weights_),
    962                             GetBuffer<const float>(recurrent_to_forget_weights_),
    963                             GetBuffer<const float>(recurrent_to_cell_weights_),
    964                             GetBuffer<const float>(recurrent_to_output_weights_),
    965                             recurrent_to_output_weights_->shape(),
    966                             GetBuffer<const float>(cell_to_input_weights_),
    967                             GetBuffer<const float>(cell_to_forget_weights_),
    968                             GetBuffer<const float>(cell_to_output_weights_),
    969                             /*aux_input_buffer=*/nullptr,
    970                             /*aux_input_to_input_weights_buffer=*/nullptr,
    971                             /*aux_input_to_forget_weights_buffer=*/nullptr,
    972                             /*aux_input_to_cell_weights_buffer=*/nullptr,
    973                             /*aux_input_to_output_weights_buffer=*/nullptr,
    974                             GetBuffer<const float>(input_gate_bias_),
    975                             GetBuffer<const float>(forget_gate_bias_),
    976                             GetBuffer<const float>(cell_bias_),
    977                             GetBuffer<const float>(output_gate_bias_),
    978                             GetBuffer<const float>(projection_weights_),
    979                             GetBuffer<const float>(projection_bias_),
    980                             GetBuffer<const float>(output_state_in_),
    981                             GetBuffer<const float>(cell_state_in_),
    982                             GetBuffer<const float>(input_layer_norm_weights_),
    983                             GetBuffer<const float>(forget_layer_norm_weights_),
    984                             GetBuffer<const float>(cell_layer_norm_weights_),
    985                             GetBuffer<const float>(output_layer_norm_weights_),
    986                             GetBuffer<float>(output_state_out_), GetBuffer<float>(cell_state_out_),
    987                             GetBuffer<float>(output_), GetBuffer<float>(scratch_buffer_));
    988         } break;
    989         case OperandType::TENSOR_FLOAT16: {
    990             LSTMEvalFloat16(params_, GetBuffer<const _Float16>(input_), input_->shape(),
    991                             GetOptionalBuffer<const _Float16>(input_to_input_weights_),
    992                             GetBuffer<const _Float16>(input_to_forget_weights_),
    993                             GetBuffer<const _Float16>(input_to_cell_weights_),
    994                             GetBuffer<const _Float16>(input_to_output_weights_),
    995                             input_to_output_weights_->shape(),
    996                             GetOptionalBuffer<const _Float16>(recurrent_to_input_weights_),
    997                             GetBuffer<const _Float16>(recurrent_to_forget_weights_),
    998                             GetBuffer<const _Float16>(recurrent_to_cell_weights_),
    999                             GetBuffer<const _Float16>(recurrent_to_output_weights_),
   1000                             recurrent_to_output_weights_->shape(),
   1001                             GetOptionalBuffer<const _Float16>(cell_to_input_weights_),
   1002                             GetOptionalBuffer<const _Float16>(cell_to_forget_weights_),
   1003                             GetOptionalBuffer<const _Float16>(cell_to_output_weights_),
   1004                             /*aux_input_buffer=*/nullptr,
   1005                             /*aux_input_to_input_weights_buffer=*/nullptr,
   1006                             /*aux_input_to_forget_weights_buffer=*/nullptr,
   1007                             /*aux_input_to_cell_weights_buffer=*/nullptr,
   1008                             /*aux_input_to_output_weights_buffer=*/nullptr,
   1009                             GetOptionalBuffer<const _Float16>(input_gate_bias_),
   1010                             GetBuffer<const _Float16>(forget_gate_bias_),
   1011                             GetBuffer<const _Float16>(cell_bias_),
   1012                             GetBuffer<const _Float16>(output_gate_bias_),
   1013                             GetOptionalBuffer<const _Float16>(projection_weights_),
   1014                             GetOptionalBuffer<const _Float16>(projection_bias_),
   1015                             GetBuffer<const _Float16>(output_state_in_),
   1016                             GetBuffer<const _Float16>(cell_state_in_),
   1017                             GetOptionalBuffer<const _Float16>(input_layer_norm_weights_),
   1018                             GetOptionalBuffer<const _Float16>(forget_layer_norm_weights_),
   1019                             GetOptionalBuffer<const _Float16>(cell_layer_norm_weights_),
   1020                             GetOptionalBuffer<const _Float16>(output_layer_norm_weights_),
   1021                             GetBuffer<_Float16>(output_state_out_),
   1022                             GetBuffer<_Float16>(cell_state_out_), GetBuffer<_Float16>(output_),
   1023                             GetBuffer<_Float16>(scratch_buffer_));
   1024         } break;
   1025         default: {
   1026             LOG(ERROR) << "Unsupported data type: " << static_cast<int>(input_->type);
   1027             return false;
   1028         }
   1029     }
   1030     return true;
   1031 }
   1032 
   1033 }  // namespace nn
   1034 }  // namespace android
   1035