// Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <cassert>
     17 #include <cmath>
     18 #include <cstdio>
     19 #include <cstdlib>
     20 #include <iostream>
     21 #include <limits>
     22 
     23 #include "tensorflow/lite/c/builtin_op_data.h"
     24 #include "tensorflow/lite/c/c_api_internal.h"
     25 #include "tensorflow/lite/kernels/activation_functor.h"
     26 #include "tensorflow/lite/kernels/internal/kernel_utils.h"
     27 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
     28 #include "tensorflow/lite/kernels/kernel_util.h"
     29 #include "tensorflow/lite/kernels/lstm_eval.h"
     30 #include "tensorflow/lite/kernels/op_macros.h"
     31 
     32 namespace tflite {
     33 namespace ops {
     34 namespace builtin {
     35 namespace bidirectional_sequence_lstm {
     36 
// Indices of the node's input tensors, used with GetInput() /
// GetOptionalInputTensor() below. The numeric values are part of the op's
// schema contract with the converter — do not reorder or renumber.

// Input Tensors of size {max_time, n_batch, n_input}
constexpr int kInputTensor = 0;

// Forward LSTM cell tensors.
// Input weight tensors of size: {n_cell, n_input}
constexpr int kFwInputToInputWeightsTensor = 1;  // Optional
constexpr int kFwInputToForgetWeightsTensor = 2;
constexpr int kFwInputToCellWeightsTensor = 3;
constexpr int kFwInputToOutputWeightsTensor = 4;

// Recurrent weight tensors of size {n_cell, n_output}
constexpr int kFwRecurrentToInputWeightsTensor = 5;  // Optional
constexpr int kFwRecurrentToForgetWeightsTensor = 6;
constexpr int kFwRecurrentToCellWeightsTensor = 7;
constexpr int kFwRecurrentToOutputWeightsTensor = 8;

// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
constexpr int kFwCellToInputWeightsTensor = 9;    // Optional
constexpr int kFwCellToForgetWeightsTensor = 10;  // Optional
constexpr int kFwCellToOutputWeightsTensor = 11;  // Optional

// Gates bias tensors of size {n_cell}
constexpr int kFwInputGateBiasTensor = 12;  // Optional (absent for CIFG)
constexpr int kFwForgetGateBiasTensor = 13;
constexpr int kFwCellGateBiasTensor = 14;
constexpr int kFwOutputGateBiasTensor = 15;

// Projection weight tensor of size {n_output, n_cell}
constexpr int kFwProjectionWeightsTensor = 16;  // Optional
// Projection bias tensor of size {n_output}
constexpr int kFwProjectionBiasTensor = 17;  // Optional

// Backward LSTM cell tensors. Mirror the forward-cell layout above, offset
// by 17.
// Input weight tensors of size: {n_cell, n_input}
constexpr int kBwInputToInputWeightsTensor = 18;  // Optional
constexpr int kBwInputToForgetWeightsTensor = 19;
constexpr int kBwInputToCellWeightsTensor = 20;
constexpr int kBwInputToOutputWeightsTensor = 21;

// Recurrent weight tensors of size {n_cell, n_output}
constexpr int kBwRecurrentToInputWeightsTensor = 22;  // Optional
constexpr int kBwRecurrentToForgetWeightsTensor = 23;
constexpr int kBwRecurrentToCellWeightsTensor = 24;
constexpr int kBwRecurrentToOutputWeightsTensor = 25;

// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
constexpr int kBwCellToInputWeightsTensor = 26;   // Optional
constexpr int kBwCellToForgetWeightsTensor = 27;  // Optional
constexpr int kBwCellToOutputWeightsTensor = 28;  // Optional

// Gates bias tensors of size {n_cell}
constexpr int kBwInputGateBiasTensor = 29;  // Optional (absent for CIFG)
constexpr int kBwForgetGateBiasTensor = 30;
constexpr int kBwCellGateBiasTensor = 31;
constexpr int kBwOutputGateBiasTensor = 32;

// Projection weight tensor of size {n_output, n_cell}
constexpr int kBwProjectionWeightsTensor = 33;  // Optional
// Projection bias tensor of size {n_output}
constexpr int kBwProjectionBiasTensor = 34;  // Optional

// Stateful input tensors that are variables and will be modified by the Op.
// Activation state tensors of size {n_batch, n_output}
constexpr int kFwInputActivationStateTensor = 35;
// Cell state tensors of size {n_batch, n_cell}
constexpr int kFwInputCellStateTensor = 36;
// Activation state tensors of size {n_batch, n_output}
constexpr int kBwInputActivationStateTensor = 37;
// Cell state tensors of size {n_batch, n_cell}
constexpr int kBwInputCellStateTensor = 38;

// Used as auxiliary input and weights when stacking for
// tf.contrib.rnn.stack_bidirectional_rnn case (with cross links); Used as input
// to the backward cell when stacking for tf.nn.static_bidirectional_rnn case
// (without cross links).
constexpr int kAuxInputTensor = 39;  // Optional
// Forward weights.
constexpr int kFwAuxInputToInputWeightsTensor = 40;   // Optional
constexpr int kFwAuxInputToForgetWeightsTensor = 41;  // Optional
constexpr int kFwAuxInputToCellWeightsTensor = 42;    // Optional
constexpr int kFwAuxInputToOutputWeightsTensor = 43;  // Optional
// Backward weights.
constexpr int kBwAuxInputToInputWeightsTensor = 44;   // Optional
constexpr int kBwAuxInputToForgetWeightsTensor = 45;  // Optional
constexpr int kBwAuxInputToCellWeightsTensor = 46;    // Optional
constexpr int kBwAuxInputToOutputWeightsTensor = 47;  // Optional

// Output tensors.
constexpr int kFwOutputTensor = 0;
constexpr int kBwOutputTensor = 1;  // Ignored if merge_outputs is set.

// Temporary tensors.
// Prepare() allocates only the first two entries for the float path, all of
// them (minus kAuxInputQuantized when there is no aux input) for the hybrid
// path — which is why kAuxInputQuantized must stay last.
enum TemporaryTensor {
  // Scratch buffers for input, forget, etc. gates
  kFwScratchBuffer = 0,
  kBwScratchBuffer = 1,
  // Quantized tensors needed for the hybrid kernel.
  kInputQuantized = 2,
  kFwActivationStateQuantized = 3,
  kBwActivationStateQuantized = 4,
  kFwCellStateQuantized = 5,
  kBwCellStateQuantized = 6,
  kScalingFactors = 7,
  kProductScalingFactors = 8,
  kRecoveredCellWeights = 9,
  kAuxInputQuantized = 10,  // Optional, quantized tensor for auxiliary input.
  kNumTemporaryTensors = 11
};
    145 
    146 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
    147   auto* scratch_tensor_index = new int;
    148   context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index);
    149   return scratch_tensor_index;
    150 }
    151 
    152 void Free(TfLiteContext* context, void* buffer) {
    153   delete reinterpret_cast<int*>(buffer);
    154 }
    155 
    156 // Check that input tensor dimensions matches with each other.
    157 TfLiteStatus CheckLstmTensorDimensionsAndTypes(
    158     TfLiteContext* context, TfLiteNode* node, int n_input, int n_output,
    159     int n_cell, int input_to_input_weights_tensor,
    160     int input_to_forget_weights_tensor, int input_to_cell_weights_tensor,
    161     int input_to_output_weights_tensor, int recurrent_to_input_weights_tensor,
    162     int recurrent_to_forget_weights_tensor,
    163     int recurrent_to_cell_weights_tensor,
    164     int recurrent_to_output_weights_tensor, int cell_to_input_weights_tensor,
    165     int cell_to_forget_weights_tensor, int cell_to_output_weights_tensor,
    166     int input_gate_bias_tensor, int forget_gate_bias_tensor,
    167     int cell_gate_bias_tensor, int output_gate_bias_tensor,
    168     int projection_weights_tensor, int projection_bias_tensor) {
    169   const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
    170       node->builtin_data);
    171 
    172   // Making sure clipping parameters have valid values.
    173   // == 0 means no clipping
    174   //  > 0 means clipping
    175   TF_LITE_ENSURE(context, params->cell_clip >= 0);
    176   TF_LITE_ENSURE(context, params->proj_clip >= 0);
    177 
    178   const TfLiteTensor* input_to_forget_weights =
    179       GetInput(context, node, input_to_forget_weights_tensor);
    180   TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
    181   TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
    182   TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
    183   TF_LITE_ENSURE(context, (input_to_forget_weights->type == kTfLiteFloat32) ||
    184                               (input_to_forget_weights->type == kTfLiteUInt8));
    185 
    186   const TfLiteTensor* input_to_input_weights =
    187       GetOptionalInputTensor(context, node, input_to_input_weights_tensor);
    188   if (input_to_input_weights != nullptr) {
    189     TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
    190     TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
    191     TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
    192     TF_LITE_ENSURE_EQ(context, input_to_input_weights->type,
    193                       input_to_forget_weights->type);
    194   }
    195 
    196   const TfLiteTensor* input_to_cell_weights =
    197       GetInput(context, node, input_to_cell_weights_tensor);
    198   TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
    199   TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
    200   TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
    201   TF_LITE_ENSURE_EQ(context, input_to_cell_weights->type,
    202                     input_to_forget_weights->type);
    203 
    204   const TfLiteTensor* input_to_output_weights =
    205       GetInput(context, node, input_to_output_weights_tensor);
    206   TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
    207   TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[0], n_cell);
    208   TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);
    209   TF_LITE_ENSURE_EQ(context, input_to_output_weights->type,
    210                     input_to_forget_weights->type);
    211 
    212   const TfLiteTensor* recurrent_to_input_weights =
    213       GetOptionalInputTensor(context, node, recurrent_to_input_weights_tensor);
    214   if (recurrent_to_input_weights != nullptr) {
    215     TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
    216     TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
    217                       n_cell);
    218     TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
    219                       n_output);
    220     TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->type,
    221                       input_to_forget_weights->type);
    222   }
    223 
    224   const TfLiteTensor* recurrent_to_forget_weights =
    225       GetInput(context, node, recurrent_to_forget_weights_tensor);
    226   TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
    227   TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
    228                     n_cell);
    229   TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
    230                     n_output);
    231   TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->type,
    232                     input_to_forget_weights->type);
    233 
    234   const TfLiteTensor* recurrent_to_cell_weights =
    235       GetInput(context, node, recurrent_to_cell_weights_tensor);
    236   TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
    237   TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
    238   TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
    239                     n_output);
    240   TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->type,
    241                     input_to_forget_weights->type);
    242 
    243   // We make sure the input-gate's parameters are either both present (regular
    244   // LSTM) or not at all (CIFG-LSTM).
    245   const bool cifg_weights_all_or_none =
    246       ((input_to_input_weights != nullptr) &&
    247        (recurrent_to_input_weights != nullptr)) ||
    248       ((input_to_input_weights == nullptr) &&
    249        (recurrent_to_input_weights == nullptr));
    250   TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);
    251 
    252   const TfLiteTensor* cell_to_input_weights =
    253       GetOptionalInputTensor(context, node, cell_to_input_weights_tensor);
    254   if (cell_to_input_weights != nullptr) {
    255     TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
    256     TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
    257     TF_LITE_ENSURE_EQ(context, cell_to_input_weights->type,
    258                       input_to_forget_weights->type);
    259   }
    260 
    261   const TfLiteTensor* cell_to_forget_weights =
    262       GetOptionalInputTensor(context, node, cell_to_forget_weights_tensor);
    263   if (cell_to_forget_weights != nullptr) {
    264     TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
    265     TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
    266     TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->type,
    267                       input_to_forget_weights->type);
    268   }
    269 
    270   const TfLiteTensor* cell_to_output_weights =
    271       GetOptionalInputTensor(context, node, cell_to_output_weights_tensor);
    272   if (cell_to_output_weights != nullptr) {
    273     TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
    274     TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
    275     TF_LITE_ENSURE_EQ(context, cell_to_output_weights->type,
    276                       input_to_forget_weights->type);
    277   }
    278 
    279   // Making sure the peephole weights are there all or none.
    280   const bool use_cifg = (input_to_input_weights == nullptr);
    281   const bool peephole_weights_all_or_none =
    282       ((cell_to_input_weights != nullptr || use_cifg) &&
    283        (cell_to_forget_weights != nullptr) &&
    284        (cell_to_output_weights != nullptr)) ||
    285       ((cell_to_input_weights == nullptr) &&
    286        (cell_to_forget_weights == nullptr) &&
    287        (cell_to_output_weights == nullptr));
    288   TF_LITE_ENSURE(context, peephole_weights_all_or_none == true);
    289 
    290   // Make sure the input gate bias is present only when not a CIFG-LSTM.
    291   const TfLiteTensor* input_gate_bias =
    292       GetOptionalInputTensor(context, node, input_gate_bias_tensor);
    293   if (use_cifg) {
    294     TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
    295   } else {
    296     TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
    297     TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
    298     TF_LITE_ENSURE_EQ(context, input_gate_bias->type, kTfLiteFloat32);
    299   }
    300 
    301   const TfLiteTensor* forget_gate_bias =
    302       GetInput(context, node, forget_gate_bias_tensor);
    303   TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
    304   TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
    305   TF_LITE_ENSURE_EQ(context, forget_gate_bias->type, kTfLiteFloat32);
    306 
    307   const TfLiteTensor* cell_bias =
    308       GetInput(context, node, cell_gate_bias_tensor);
    309   TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1);
    310   TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell);
    311   TF_LITE_ENSURE_EQ(context, cell_bias->type, kTfLiteFloat32);
    312 
    313   const TfLiteTensor* output_gate_bias =
    314       GetInput(context, node, output_gate_bias_tensor);
    315   TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
    316   TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
    317   TF_LITE_ENSURE_EQ(context, output_gate_bias->type, kTfLiteFloat32);
    318 
    319   const TfLiteTensor* projection_weights =
    320       GetOptionalInputTensor(context, node, projection_weights_tensor);
    321   if (projection_weights != nullptr) {
    322     TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
    323     TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
    324     TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
    325     TF_LITE_ENSURE_EQ(context, projection_weights->type,
    326                       input_to_forget_weights->type);
    327   }
    328 
    329   const TfLiteTensor* projection_bias =
    330       GetOptionalInputTensor(context, node, projection_bias_tensor);
    331   if (projection_bias != nullptr) {
    332     TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
    333     TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
    334     TF_LITE_ENSURE_EQ(context, projection_bias->type, kTfLiteFloat32);
    335   }
    336 
    337   // Making sure the projection tensors are consistent:
    338   // 1) If projection weight is not present, then projection bias should not be
    339   // present.
    340   // 2) If projection weight is present, then projection bias is optional.
    341   // TODO(ghodrat): make sure this is correct.
    342   const bool projecton_tensors_consistent =
    343       ((projection_weights != nullptr) || (projection_bias == nullptr));
    344   TF_LITE_ENSURE(context, projecton_tensors_consistent == true);
    345 
    346   return kTfLiteOk;
    347 }
    348 
// Validates both the forward and the backward LSTM cell tensors by running
// CheckLstmTensorDimensionsAndTypes once per direction with that direction's
// tensor indices. The argument lists are positional and must stay in exactly
// this order (input-to-*, recurrent-to-*, cell-to-*, gate biases, projection).
// Returns kTfLiteOk if both directions pass.
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
                                        TfLiteNode* node, int n_input,
                                        int n_output, int n_cell) {
  // Forward cell.
  TF_LITE_ENSURE_OK(
      context,
      CheckLstmTensorDimensionsAndTypes(
          context, node, n_input, n_output, n_cell,
          kFwInputToInputWeightsTensor, kFwInputToForgetWeightsTensor,
          kFwInputToCellWeightsTensor, kFwInputToOutputWeightsTensor,
          kFwRecurrentToInputWeightsTensor, kFwRecurrentToForgetWeightsTensor,
          kFwRecurrentToCellWeightsTensor, kFwRecurrentToOutputWeightsTensor,
          kFwCellToInputWeightsTensor, kFwCellToForgetWeightsTensor,
          kFwCellToOutputWeightsTensor, kFwInputGateBiasTensor,
          kFwForgetGateBiasTensor, kFwCellGateBiasTensor,
          kFwOutputGateBiasTensor, kFwProjectionWeightsTensor,
          kFwProjectionBiasTensor));

  // Backward cell.
  TF_LITE_ENSURE_OK(
      context,
      CheckLstmTensorDimensionsAndTypes(
          context, node, n_input, n_output, n_cell,
          kBwInputToInputWeightsTensor, kBwInputToForgetWeightsTensor,
          kBwInputToCellWeightsTensor, kBwInputToOutputWeightsTensor,
          kBwRecurrentToInputWeightsTensor, kBwRecurrentToForgetWeightsTensor,
          kBwRecurrentToCellWeightsTensor, kBwRecurrentToOutputWeightsTensor,
          kBwCellToInputWeightsTensor, kBwCellToForgetWeightsTensor,
          kBwCellToOutputWeightsTensor, kBwInputGateBiasTensor,
          kBwForgetGateBiasTensor, kBwCellGateBiasTensor,
          kBwOutputGateBiasTensor, kBwProjectionWeightsTensor,
          kBwProjectionBiasTensor));

  // Check if Forward and Backward tensors match along required dimensions.
  // NOTE(review): no cross-direction checks are performed here; Prepare()
  // separately ensures the backward weight types match the forward ones.
  return kTfLiteOk;
}
    383 
    384 // Resize the output and scratch tensors based on the sizes of the input
    385 // tensors. Also check that the size of the input tensors match each other.
    386 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    387   int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
    388   const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
    389       node->builtin_data);
    390 
    391   // Check we have all the inputs and outputs we need.
    392   TF_LITE_ENSURE_EQ(context, node->inputs->size, 48);
    393   TF_LITE_ENSURE_EQ(context, node->outputs->size,
    394                     params->merge_outputs ? 1 : 2);
    395 
    396   // Inferring batch size, number of outputs and sequence length and
    397   // number of cells from the input tensors.
    398   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
    399   TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
    400   TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
    401   const bool time_major = params->time_major;
    402   const int max_time = time_major ? input->dims->data[0] : input->dims->data[1];
    403   const int n_batch = time_major ? input->dims->data[1] : input->dims->data[0];
    404   const int n_input = input->dims->data[2];
    405 
    406   const TfLiteTensor* fw_input_to_output_weights =
    407       GetInput(context, node, kFwInputToOutputWeightsTensor);
    408   const int n_fw_cell = fw_input_to_output_weights->dims->data[0];
    409   TF_LITE_ENSURE_EQ(context, fw_input_to_output_weights->dims->size, 2);
    410   TF_LITE_ENSURE_EQ(context, fw_input_to_output_weights->dims->data[1],
    411                     n_input);
    412 
    413   const TfLiteTensor* bw_input_to_output_weights =
    414       GetInput(context, node, kBwInputToOutputWeightsTensor);
    415   const int n_bw_cell = bw_input_to_output_weights->dims->data[0];
    416   TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->size, 2);
    417   TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->data[1],
    418                     n_input);
    419   TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->type,
    420                     fw_input_to_output_weights->type);
    421 
    422   const TfLiteTensor* fw_recurrent_to_output_weights =
    423       GetInput(context, node, kFwRecurrentToOutputWeightsTensor);
    424   TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->dims->size, 2);
    425   TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->dims->data[0],
    426                     n_fw_cell);
    427   TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->type,
    428                     fw_input_to_output_weights->type);
    429   const int n_fw_output = fw_recurrent_to_output_weights->dims->data[1];
    430 
    431   const TfLiteTensor* bw_recurrent_to_output_weights =
    432       GetInput(context, node, kBwRecurrentToOutputWeightsTensor);
    433   TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->size, 2);
    434   TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->data[0],
    435                     n_bw_cell);
    436   TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->type,
    437                     fw_input_to_output_weights->type);
    438   const int n_bw_output = bw_recurrent_to_output_weights->dims->data[1];
    439 
    440   // Check that input tensor dimensions matches with each other.
    441   TF_LITE_ENSURE_OK(
    442       context, CheckInputTensorDimensions(context, node, n_input, n_fw_output,
    443                                           n_fw_cell));
    444 
    445   // Get (optional) auxiliary inputs and weights.
    446   const TfLiteTensor* aux_input =
    447       GetOptionalInputTensor(context, node, kAuxInputTensor);
    448   const TfLiteTensor* fw_aux_input_to_input_weights =
    449       GetOptionalInputTensor(context, node, kFwAuxInputToInputWeightsTensor);
    450   const TfLiteTensor* fw_aux_input_to_forget_weights =
    451       GetOptionalInputTensor(context, node, kFwAuxInputToForgetWeightsTensor);
    452   const TfLiteTensor* fw_aux_input_to_cell_weights =
    453       GetOptionalInputTensor(context, node, kFwAuxInputToCellWeightsTensor);
    454   const TfLiteTensor* fw_aux_input_to_output_weights =
    455       GetOptionalInputTensor(context, node, kFwAuxInputToOutputWeightsTensor);
    456   const TfLiteTensor* bw_aux_input_to_input_weights =
    457       GetOptionalInputTensor(context, node, kBwAuxInputToInputWeightsTensor);
    458   const TfLiteTensor* bw_aux_input_to_forget_weights =
    459       GetOptionalInputTensor(context, node, kBwAuxInputToForgetWeightsTensor);
    460   const TfLiteTensor* bw_aux_input_to_cell_weights =
    461       GetOptionalInputTensor(context, node, kBwAuxInputToCellWeightsTensor);
    462   const TfLiteTensor* bw_aux_input_to_output_weights =
    463       GetOptionalInputTensor(context, node, kBwAuxInputToOutputWeightsTensor);
    464 
    465   const bool aux_inputs_weights_all_or_none =
    466       ((fw_aux_input_to_cell_weights != nullptr) &&
    467        (fw_aux_input_to_forget_weights != nullptr) &&
    468        (fw_aux_input_to_output_weights != nullptr) &&
    469        (bw_aux_input_to_cell_weights != nullptr) &&
    470        (bw_aux_input_to_forget_weights != nullptr) &&
    471        (bw_aux_input_to_output_weights != nullptr)) ||
    472       ((fw_aux_input_to_cell_weights == nullptr) &&
    473        (fw_aux_input_to_forget_weights == nullptr) &&
    474        (fw_aux_input_to_output_weights == nullptr) &&
    475        (bw_aux_input_to_cell_weights == nullptr) &&
    476        (bw_aux_input_to_forget_weights == nullptr) &&
    477        (bw_aux_input_to_output_weights == nullptr));
    478   TF_LITE_ENSURE(context, aux_inputs_weights_all_or_none);
    479 
    480   const bool has_aux_input = (fw_aux_input_to_forget_weights != nullptr);
    481 
    482   if (has_aux_input) {
    483     // Check that aux_input has the same dimensions (except last) as the input.
    484     TF_LITE_ASSERT_EQ(aux_input->dims->data[0], input->dims->data[0]);
    485     TF_LITE_ASSERT_EQ(aux_input->dims->data[1], input->dims->data[1]);
    486   }
    487 
    488   // Get the pointer to output, activation_state and cell_state buffer tensors.
    489   TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
    490   TfLiteTensor* fw_activation_state =
    491       GetVariableInput(context, node, kFwInputActivationStateTensor);
    492   TfLiteTensor* fw_cell_state =
    493       GetVariableInput(context, node, kFwInputCellStateTensor);
    494 
    495   // Check the shape of input state tensors.
    496   // These tensor may be 1D or 2D. It's fine as long as the total size is
    497   // correct.
    498   TF_LITE_ENSURE_EQ(context, NumElements(fw_activation_state),
    499                     n_batch * n_fw_output);
    500   TF_LITE_ENSURE_EQ(context, NumElements(fw_cell_state), n_batch * n_fw_cell);
    501 
    502   // Resize the output tensors.
    503   TfLiteIntArray* fw_output_size = TfLiteIntArrayCreate(3);
    504   fw_output_size->data[0] = time_major ? max_time : n_batch;
    505   fw_output_size->data[1] = time_major ? n_batch : max_time;
    506   fw_output_size->data[2] =
    507       params->merge_outputs ? n_bw_output + n_fw_output : n_fw_output;
    508   TF_LITE_ENSURE_OK(context,
    509                     context->ResizeTensor(context, fw_output, fw_output_size));
    510 
    511   // The weights are of consistent type, so it suffices to check one.
    512   const bool is_hybrid_op = (fw_input_to_output_weights->type == kTfLiteUInt8 ||
    513                              fw_input_to_output_weights->type == kTfLiteInt8);
    514 
    515   TfLiteIntArrayFree(node->temporaries);
    516   if (is_hybrid_op) {
    517     node->temporaries = TfLiteIntArrayCreate(
    518         has_aux_input ? kNumTemporaryTensors : kNumTemporaryTensors - 1);
    519   } else {
    520     node->temporaries = TfLiteIntArrayCreate(2);  // the two scratch buffers.
    521   }
    522   // Create a scratch buffer tensor.
    523   node->temporaries->data[kFwScratchBuffer] = *scratch_tensor_index;
    524   TfLiteTensor* fw_scratch_buffer =
    525       GetTemporary(context, node, kFwScratchBuffer);
    526   fw_scratch_buffer->type = input->type;
    527   fw_scratch_buffer->allocation_type = kTfLiteArenaRw;
    528 
    529   const TfLiteTensor* fw_input_to_input_weights =
    530       GetOptionalInputTensor(context, node, kFwInputToInputWeightsTensor);
    531   const bool fw_use_cifg = (fw_input_to_input_weights == nullptr);
    532   if (has_aux_input && !fw_use_cifg) {
    533     TF_LITE_ENSURE_EQ(context, fw_aux_input_to_input_weights->dims->data[0],
    534                       fw_input_to_input_weights->dims->data[0]);
    535   }
    536   TfLiteIntArray* fw_scratch_buffer_size = TfLiteIntArrayCreate(2);
    537   fw_scratch_buffer_size->data[0] = n_batch;
    538   if (fw_use_cifg) {
    539     // Reserving space for Cell, Forget, Output gates
    540     fw_scratch_buffer_size->data[1] = n_fw_cell * 3;
    541   } else {
    542     // Reserving space for Input, Cell, Forget, Output gates
    543     fw_scratch_buffer_size->data[1] = n_fw_cell * 4;
    544   }
    545   TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_scratch_buffer,
    546                                                    fw_scratch_buffer_size));
    547   // Same for the backward cell.
    548 
    549   // Check that input tensor dimensions matches with each other.
    550   TF_LITE_ENSURE_OK(
    551       context, CheckInputTensorDimensions(context, node, n_input, n_bw_output,
    552                                           n_bw_cell));
    553 
    554   // Get the pointer to activation_state and cell_state buffer tensors.
    555   TfLiteTensor* bw_activation_state =
    556       GetVariableInput(context, node, kBwInputActivationStateTensor);
    557   TfLiteTensor* bw_cell_state =
    558       GetVariableInput(context, node, kBwInputCellStateTensor);
    559 
    560   // Resize the output tensors.
    561   if (!params->merge_outputs) {
    562     TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
    563     TfLiteIntArray* bw_output_size = TfLiteIntArrayCreate(3);
    564     bw_output_size->data[0] = time_major ? max_time : n_batch;
    565     bw_output_size->data[1] = time_major ? n_batch : max_time;
    566     bw_output_size->data[2] = n_bw_output;
    567     TF_LITE_ENSURE_OK(
    568         context, context->ResizeTensor(context, bw_output, bw_output_size));
    569   }
    570 
    571   // Check the shape of input state tensors.
    572   // These tensor may be 1D or 2D. It's fine as long as the total size is
    573   // correct.
    574   TF_LITE_ENSURE_EQ(context, NumElements(bw_activation_state),
    575                     n_batch * n_bw_output);
    576   TF_LITE_ENSURE_EQ(context, NumElements(bw_cell_state), n_batch * n_bw_cell);
    577 
    578   // Create a scratch buffer tensor.
    579   node->temporaries->data[kBwScratchBuffer] =
    580       *(scratch_tensor_index) + kBwScratchBuffer;
    581   TfLiteTensor* bw_scratch_buffer =
    582       GetTemporary(context, node, kBwScratchBuffer);
    583   bw_scratch_buffer->type = input->type;
    584   bw_scratch_buffer->allocation_type = kTfLiteArenaRw;
    585 
    586   const TfLiteTensor* bw_input_to_input_weights =
    587       GetOptionalInputTensor(context, node, kBwInputToInputWeightsTensor);
    588   const bool bw_use_cifg = (bw_input_to_input_weights == nullptr);
    589   if (has_aux_input && !bw_use_cifg) {
    590     TF_LITE_ENSURE_EQ(context, bw_aux_input_to_input_weights->dims->data[0],
    591                       bw_input_to_input_weights->dims->data[0]);
    592   }
    593   TfLiteIntArray* bw_scratch_buffer_size = TfLiteIntArrayCreate(2);
    594   bw_scratch_buffer_size->data[0] = n_batch;
    595   if (bw_use_cifg) {
    596     // Reserving space for Cell, Forget, Output gates
    597     bw_scratch_buffer_size->data[1] = n_bw_cell * 3;
    598   } else {
    599     // Reserving space for Input, Cell, Forget, Output gates
    600     bw_scratch_buffer_size->data[1] = n_bw_cell * 4;
    601   }
    602   TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_scratch_buffer,
    603                                                    bw_scratch_buffer_size));
    604   if (is_hybrid_op) {
    605     // Allocate temporary tensors to store quantized values of input, aux_input
    606     // (if present), activation_state and cell_state tensors.
    607     node->temporaries->data[kInputQuantized] =
    608         *scratch_tensor_index + kInputQuantized;
    609     TfLiteTensor* input_quantized =
    610         GetTemporary(context, node, kInputQuantized);
    611     input_quantized->type = fw_input_to_output_weights->type;
    612     input_quantized->allocation_type = kTfLiteArenaRw;
    613     if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
    614       TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
    615       TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
    616                                                        input_quantized_size));
    617     }
    618 
    619     node->temporaries->data[kFwActivationStateQuantized] =
    620         *scratch_tensor_index + kFwActivationStateQuantized;
    621     TfLiteTensor* fw_activation_state_quantized =
    622         GetTemporary(context, node, kFwActivationStateQuantized);
    623     fw_activation_state_quantized->type = fw_input_to_output_weights->type;
    624     fw_activation_state_quantized->allocation_type = kTfLiteArenaRw;
    625     if (!TfLiteIntArrayEqual(fw_activation_state_quantized->dims,
    626                              fw_activation_state->dims)) {
    627       TfLiteIntArray* fw_activation_state_quantized_size =
    628           TfLiteIntArrayCopy(fw_activation_state->dims);
    629       TF_LITE_ENSURE_OK(
    630           context, context->ResizeTensor(context, fw_activation_state_quantized,
    631                                          fw_activation_state_quantized_size));
    632     }
    633     node->temporaries->data[kBwActivationStateQuantized] =
    634         *scratch_tensor_index + kBwActivationStateQuantized;
    635     TfLiteTensor* bw_activation_state_quantized =
    636         GetTemporary(context, node, kBwActivationStateQuantized);
    637     bw_activation_state_quantized->type = fw_input_to_output_weights->type;
    638     bw_activation_state_quantized->allocation_type = kTfLiteArenaRw;
    639     if (!TfLiteIntArrayEqual(bw_activation_state_quantized->dims,
    640                              bw_activation_state->dims)) {
    641       TfLiteIntArray* bw_activation_state_quantized_size =
    642           TfLiteIntArrayCopy(bw_activation_state->dims);
    643       TF_LITE_ENSURE_OK(
    644           context, context->ResizeTensor(context, bw_activation_state_quantized,
    645                                          bw_activation_state_quantized_size));
    646     }
    647     node->temporaries->data[kFwCellStateQuantized] =
    648         *scratch_tensor_index + kFwCellStateQuantized;
    649     TfLiteTensor* fw_cell_state_quantized =
    650         GetTemporary(context, node, kFwCellStateQuantized);
    651     fw_cell_state_quantized->type = fw_input_to_output_weights->type;
    652     fw_cell_state_quantized->allocation_type = kTfLiteArenaRw;
    653     if (!TfLiteIntArrayEqual(fw_cell_state_quantized->dims,
    654                              fw_cell_state->dims)) {
    655       TfLiteIntArray* fw_cell_state_quantized_size =
    656           TfLiteIntArrayCopy(fw_cell_state->dims);
    657       TF_LITE_ENSURE_OK(context,
    658                         context->ResizeTensor(context, fw_cell_state_quantized,
    659                                               fw_cell_state_quantized_size));
    660     }
    661     node->temporaries->data[kBwCellStateQuantized] =
    662         *scratch_tensor_index + kBwCellStateQuantized;
    663     TfLiteTensor* bw_cell_state_quantized =
    664         GetTemporary(context, node, kBwCellStateQuantized);
    665     bw_cell_state_quantized->type = fw_input_to_output_weights->type;
    666     bw_cell_state_quantized->allocation_type = kTfLiteArenaRw;
    667     if (!TfLiteIntArrayEqual(bw_cell_state_quantized->dims,
    668                              bw_cell_state->dims)) {
    669       TfLiteIntArray* bw_cell_state_quantized_size =
    670           TfLiteIntArrayCopy(bw_cell_state->dims);
    671       TF_LITE_ENSURE_OK(context,
    672                         context->ResizeTensor(context, bw_cell_state_quantized,
    673                                               bw_cell_state_quantized_size));
    674     }
    675 
    676     // Allocate temporary tensors to store scaling factors and product scaling
    677     // factors. The latter is a convenience storage which allows to quantize
    678     // a vector once (which produces the scaling factors) and multiply it with
    679     // different matrices (which requires multiplying the scaling factors with
    680     // the scaling factor of the matrix).
    681     node->temporaries->data[kScalingFactors] =
    682         *scratch_tensor_index + kScalingFactors;
    683     TfLiteTensor* scaling_factors =
    684         GetTemporary(context, node, kScalingFactors);
    685     scaling_factors->type = kTfLiteFloat32;
    686     scaling_factors->allocation_type = kTfLiteArenaRw;
    687     int scaling_dims[1] = {n_batch};
    688     if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
    689       TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
    690       scaling_factors_size->data[0] = n_batch;
    691       TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
    692                                                        scaling_factors_size));
    693     }
    694     node->temporaries->data[kProductScalingFactors] =
    695         *scratch_tensor_index + kProductScalingFactors;
    696     TfLiteTensor* prod_scaling_factors =
    697         GetTemporary(context, node, kProductScalingFactors);
    698     prod_scaling_factors->type = kTfLiteFloat32;
    699     prod_scaling_factors->allocation_type = kTfLiteArenaRw;
    700     if (!TfLiteIntArrayEqualsArray(prod_scaling_factors->dims, 1,
    701                                    scaling_dims)) {
    702       TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
    703       prod_scaling_factors_size->data[0] = n_batch;
    704       TF_LITE_ENSURE_OK(context,
    705                         context->ResizeTensor(context, prod_scaling_factors,
    706                                               prod_scaling_factors_size));
    707     }
    708 
    709     // Allocate a temporary tensor to store the recovered cell weights. Since
    710     // this is used for diagonal matrices, only need to store n_cell values.
    711     node->temporaries->data[kRecoveredCellWeights] =
    712         *scratch_tensor_index + kRecoveredCellWeights;
    713     TfLiteTensor* recovered_cell_weights =
    714         GetTemporary(context, node, kRecoveredCellWeights);
    715     recovered_cell_weights->type = kTfLiteFloat32;
    716     recovered_cell_weights->allocation_type = kTfLiteArenaRw;
    717     int recovered_cell_dims[1] = {n_fw_cell};
    718     if (!TfLiteIntArrayEqualsArray(recovered_cell_weights->dims, 1,
    719                                    recovered_cell_dims)) {
    720       TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1);
    721       recovered_cell_weights_size->data[0] = n_fw_cell;
    722       TF_LITE_ENSURE_OK(context,
    723                         context->ResizeTensor(context, recovered_cell_weights,
    724                                               recovered_cell_weights_size));
    725     }
    726 
    727     // Only allocate a temporary tensor for quantized auxiliary input if we are
    728     // actually going to use it.
    729     if (has_aux_input) {
    730       node->temporaries->data[kAuxInputQuantized] =
    731           *scratch_tensor_index + kAuxInputQuantized;
    732       TfLiteTensor* aux_input_quantized =
    733           GetTemporary(context, node, kAuxInputQuantized);
    734       aux_input_quantized->type = fw_input_to_output_weights->type;
    735       aux_input_quantized->allocation_type = kTfLiteArenaRw;
    736       if (!TfLiteIntArrayEqual(aux_input_quantized->dims, aux_input->dims)) {
    737         TfLiteIntArray* aux_input_quantized_size =
    738             TfLiteIntArrayCopy(aux_input->dims);
    739         TF_LITE_ENSURE_OK(context,
    740                           context->ResizeTensor(context, aux_input_quantized,
    741                                                 aux_input_quantized_size));
    742       }
    743     }
    744   }
    745   return kTfLiteOk;
    746 }
    747 
// The LSTM Op engine.
//
// Runs one full forward-direction pass over the sequence followed by one full
// backward-direction pass, dispatching on the type of the forward
// input-to-output weights: kTfLiteFloat32 -> lstm_eval::EvalFloat;
// kTfLiteUInt8/kTfLiteInt8 -> lstm_eval::EvalHybrid (quantized weights with
// float activations, using the quantization temporaries allocated in
// Prepare); any other type reports an error.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
      node->builtin_data);

  // Input tensor.
  const TfLiteTensor* input = GetInput(context, node, kInputTensor);

  // Tensors for the forward cell.
  const TfLiteTensor* fw_input_to_input_weights =
      GetOptionalInputTensor(context, node, kFwInputToInputWeightsTensor);
  const TfLiteTensor* fw_input_to_forget_weights =
      GetInput(context, node, kFwInputToForgetWeightsTensor);
  const TfLiteTensor* fw_input_to_cell_weights =
      GetInput(context, node, kFwInputToCellWeightsTensor);
  const TfLiteTensor* fw_input_to_output_weights =
      GetInput(context, node, kFwInputToOutputWeightsTensor);

  const TfLiteTensor* fw_recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, kFwRecurrentToInputWeightsTensor);
  const TfLiteTensor* fw_recurrent_to_forget_weights =
      GetInput(context, node, kFwRecurrentToForgetWeightsTensor);
  const TfLiteTensor* fw_recurrent_to_cell_weights =
      GetInput(context, node, kFwRecurrentToCellWeightsTensor);
  const TfLiteTensor* fw_recurrent_to_output_weights =
      GetInput(context, node, kFwRecurrentToOutputWeightsTensor);

  // Peephole weights are optional (nullptr when absent).
  const TfLiteTensor* fw_cell_to_input_weights =
      GetOptionalInputTensor(context, node, kFwCellToInputWeightsTensor);
  const TfLiteTensor* fw_cell_to_forget_weights =
      GetOptionalInputTensor(context, node, kFwCellToForgetWeightsTensor);
  const TfLiteTensor* fw_cell_to_output_weights =
      GetOptionalInputTensor(context, node, kFwCellToOutputWeightsTensor);

  const TfLiteTensor* fw_input_gate_bias =
      GetOptionalInputTensor(context, node, kFwInputGateBiasTensor);
  const TfLiteTensor* fw_forget_gate_bias =
      GetInput(context, node, kFwForgetGateBiasTensor);
  const TfLiteTensor* fw_cell_bias =
      GetInput(context, node, kFwCellGateBiasTensor);
  const TfLiteTensor* fw_output_gate_bias =
      GetInput(context, node, kFwOutputGateBiasTensor);

  const TfLiteTensor* fw_projection_weights =
      GetOptionalInputTensor(context, node, kFwProjectionWeightsTensor);
  const TfLiteTensor* fw_projection_bias =
      GetOptionalInputTensor(context, node, kFwProjectionBiasTensor);

  // Forward state tensors are variable tensors, updated in place across
  // invocations.
  TfLiteTensor* fw_activation_state =
      GetVariableInput(context, node, kFwInputActivationStateTensor);
  TfLiteTensor* fw_cell_state =
      GetVariableInput(context, node, kFwInputCellStateTensor);
  TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);

  // Tensors for the backward cell.
  const TfLiteTensor* bw_input_to_input_weights =
      GetOptionalInputTensor(context, node, kBwInputToInputWeightsTensor);
  const TfLiteTensor* bw_input_to_forget_weights =
      GetInput(context, node, kBwInputToForgetWeightsTensor);
  const TfLiteTensor* bw_input_to_cell_weights =
      GetInput(context, node, kBwInputToCellWeightsTensor);
  const TfLiteTensor* bw_input_to_output_weights =
      GetInput(context, node, kBwInputToOutputWeightsTensor);

  const TfLiteTensor* bw_recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, kBwRecurrentToInputWeightsTensor);
  const TfLiteTensor* bw_recurrent_to_forget_weights =
      GetInput(context, node, kBwRecurrentToForgetWeightsTensor);
  const TfLiteTensor* bw_recurrent_to_cell_weights =
      GetInput(context, node, kBwRecurrentToCellWeightsTensor);
  const TfLiteTensor* bw_recurrent_to_output_weights =
      GetInput(context, node, kBwRecurrentToOutputWeightsTensor);

  const TfLiteTensor* bw_cell_to_input_weights =
      GetOptionalInputTensor(context, node, kBwCellToInputWeightsTensor);
  const TfLiteTensor* bw_cell_to_forget_weights =
      GetOptionalInputTensor(context, node, kBwCellToForgetWeightsTensor);
  const TfLiteTensor* bw_cell_to_output_weights =
      GetOptionalInputTensor(context, node, kBwCellToOutputWeightsTensor);

  const TfLiteTensor* bw_input_gate_bias =
      GetOptionalInputTensor(context, node, kBwInputGateBiasTensor);
  const TfLiteTensor* bw_forget_gate_bias =
      GetInput(context, node, kBwForgetGateBiasTensor);
  const TfLiteTensor* bw_cell_bias =
      GetInput(context, node, kBwCellGateBiasTensor);
  const TfLiteTensor* bw_output_gate_bias =
      GetInput(context, node, kBwOutputGateBiasTensor);

  const TfLiteTensor* bw_projection_weights =
      GetOptionalInputTensor(context, node, kBwProjectionWeightsTensor);
  const TfLiteTensor* bw_projection_bias =
      GetOptionalInputTensor(context, node, kBwProjectionBiasTensor);

  // State tensors.
  TfLiteTensor* bw_activation_state =
      GetVariableInput(context, node, kBwInputActivationStateTensor);
  TfLiteTensor* bw_cell_state =
      GetVariableInput(context, node, kBwInputCellStateTensor);
  // When outputs are merged there is no separate backward output tensor; the
  // backward results are written into fw_output instead (see actual_bw_output
  // below).
  TfLiteTensor* bw_output = params->merge_outputs
                                ? nullptr
                                : GetOutput(context, node, kBwOutputTensor);

  // Temporary tensors.
  TfLiteTensor* fw_scratch_buffer =
      GetTemporary(context, node, kFwScratchBuffer);
  TfLiteTensor* bw_scratch_buffer =
      GetTemporary(context, node, kBwScratchBuffer);

  // (Optional) auxiliary inputs.
  const TfLiteTensor* aux_input =
      GetOptionalInputTensor(context, node, kAuxInputTensor);
  const TfLiteTensor* fw_aux_input_to_input_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToInputWeightsTensor);
  const TfLiteTensor* fw_aux_input_to_forget_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToForgetWeightsTensor);
  const TfLiteTensor* fw_aux_input_to_cell_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToCellWeightsTensor);
  const TfLiteTensor* fw_aux_input_to_output_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToOutputWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_input_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToInputWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_forget_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToForgetWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_cell_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToCellWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_output_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToOutputWeightsTensor);

  // An aux input tensor present without fw aux weights means the aux input is
  // really the previous layer's backward output (stacking without
  // cross-links); both flags together select the mode below.
  const bool has_previous_bw_output = (aux_input != nullptr);
  const bool use_aux_input = (fw_aux_input_to_forget_weights != nullptr);

  // Populate a TfLiteLSTMParams struct for the evaluation functions.
  TfLiteLSTMParams lstm_params = {params->activation, params->cell_clip,
                                  params->proj_clip, kTfLiteLSTMFullKernel};

  // When merging, the backward pass writes into fw_output at a column offset
  // equal to the forward cell's output size (dim 1 of the fw recurrent
  // weights); otherwise it writes into bw_output at offset 0.
  const int bw_output_offset =
      params->merge_outputs ? fw_recurrent_to_output_weights->dims->data[1] : 0;
  const auto actual_bw_output = params->merge_outputs ? fw_output : bw_output;

  const bool time_major = params->time_major;

  // We want to cover the following cases:
  //
  // If not stacking (not connected after other bidi lstms):
  //   both fw & bw will just use `input`; aux_input will be null.
  //
  // If stacking with cross_links, TensorFlow equivalent
  // (tf.contrib.rnn.stack_bidirectional_rnn):
  //   both fw & bw will use `input`, but aux_input will be non-null.
  //   Note, this time, whether connected after other bidi lstms both works.
  //
  // If stacking without cross_links, but connected after other bidi lstms,
  // TensorFlow equivalent (tf.nn.static_bidirectional_rnn):
  //   fw will use `input`, bw will use aux_input, and the `real aux_input`
  //   will be null.

  const bool non_stacking_mode = !use_aux_input && has_previous_bw_output;
  const TfLiteTensor* bw_input = non_stacking_mode ? aux_input : input;
  const TfLiteTensor* real_aux_input = non_stacking_mode ? nullptr : aux_input;

  switch (fw_input_to_output_weights->type) {
    case kTfLiteFloat32: {
      TfLiteStatus fw_pass_status = lstm_eval::EvalFloat(
          input, fw_input_to_input_weights, fw_input_to_forget_weights,
          fw_input_to_cell_weights, fw_input_to_output_weights,
          fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
          fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights,
          fw_cell_to_input_weights, fw_cell_to_forget_weights,
          fw_cell_to_output_weights,
          /*input_layer_norm_coefficients=*/nullptr,
          /*forget_layer_norm_coefficients=*/nullptr,
          /*cell_layer_norm_coefficients=*/nullptr,
          /*output_layer_norm_coefficients=*/nullptr, real_aux_input,
          fw_aux_input_to_input_weights, fw_aux_input_to_forget_weights,
          fw_aux_input_to_cell_weights, fw_aux_input_to_output_weights,
          fw_input_gate_bias, fw_forget_gate_bias, fw_cell_bias,
          fw_output_gate_bias, fw_projection_weights, fw_projection_bias,
          &lstm_params,
          /*forward_sequence=*/true, time_major, /*output_offset=*/0,
          fw_scratch_buffer, fw_activation_state, fw_cell_state, fw_output);
      TF_LITE_ENSURE_OK(context, fw_pass_status);

      // Backward pass: same kernel with forward_sequence=false, writing to
      // actual_bw_output at bw_output_offset (handles the merged case).
      TfLiteStatus bw_pass_status = lstm_eval::EvalFloat(
          bw_input, bw_input_to_input_weights, bw_input_to_forget_weights,
          bw_input_to_cell_weights, bw_input_to_output_weights,
          bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
          bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights,
          bw_cell_to_input_weights, bw_cell_to_forget_weights,
          bw_cell_to_output_weights,
          /*input_layer_norm_coefficients=*/nullptr,
          /*forget_layer_norm_coefficients=*/nullptr,
          /*cell_layer_norm_coefficients=*/nullptr,
          /*output_layer_norm_coefficients=*/nullptr, real_aux_input,
          bw_aux_input_to_input_weights, bw_aux_input_to_forget_weights,
          bw_aux_input_to_cell_weights, bw_aux_input_to_output_weights,
          bw_input_gate_bias, bw_forget_gate_bias, bw_cell_bias,
          bw_output_gate_bias, bw_projection_weights, bw_projection_bias,
          &lstm_params,
          /*forward_sequence=*/false, time_major, bw_output_offset,
          bw_scratch_buffer, bw_activation_state, bw_cell_state,
          actual_bw_output);
      TF_LITE_ENSURE_OK(context, bw_pass_status);
      return kTfLiteOk;
    }
    case kTfLiteUInt8:
    case kTfLiteInt8: {
      // Quantization temporaries allocated by Prepare for the hybrid path.
      TfLiteTensor* input_quantized =
          GetTemporary(context, node, kInputQuantized);
      TfLiteTensor* fw_activation_state_quantized =
          GetTemporary(context, node, kFwActivationStateQuantized);
      TfLiteTensor* bw_activation_state_quantized =
          GetTemporary(context, node, kBwActivationStateQuantized);
      TfLiteTensor* fw_cell_state_quantized =
          GetTemporary(context, node, kFwCellStateQuantized);
      TfLiteTensor* bw_cell_state_quantized =
          GetTemporary(context, node, kBwCellStateQuantized);
      TfLiteTensor* scaling_factors =
          GetTemporary(context, node, kScalingFactors);
      TfLiteTensor* prod_scaling_factors =
          GetTemporary(context, node, kProductScalingFactors);
      TfLiteTensor* recovered_cell_weights =
          GetTemporary(context, node, kRecoveredCellWeights);
      // The aux-input quantization buffer is only allocated when aux weights
      // exist (see Prepare), so only fetch it in that case.
      TfLiteTensor* aux_input_quantized =
          use_aux_input ? GetTemporary(context, node, kAuxInputQuantized)
                        : nullptr;

      // NOTE(review): unlike the float path above, both hybrid calls hardcode
      // time_major=true and ignore params->time_major — confirm whether
      // batch-major input is supported on the hybrid path.
      TfLiteStatus fw_pass_status = lstm_eval::EvalHybrid(
          input, fw_input_to_input_weights, fw_input_to_forget_weights,
          fw_input_to_cell_weights, fw_input_to_output_weights,
          fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
          fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights,
          fw_cell_to_input_weights, fw_cell_to_forget_weights,
          fw_cell_to_output_weights,
          /*input_layer_norm_coefficients=*/nullptr,
          /*forget_layer_norm_coefficients=*/nullptr,
          /*cell_layer_norm_coefficients=*/nullptr,
          /*output_layer_norm_coefficients=*/nullptr, real_aux_input,
          fw_aux_input_to_input_weights, fw_aux_input_to_forget_weights,
          fw_aux_input_to_cell_weights, fw_aux_input_to_output_weights,
          fw_input_gate_bias, fw_forget_gate_bias, fw_cell_bias,
          fw_output_gate_bias, fw_projection_weights, fw_projection_bias,
          &lstm_params,
          /*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0,
          fw_scratch_buffer, scaling_factors, prod_scaling_factors,
          recovered_cell_weights, input_quantized, aux_input_quantized,
          fw_activation_state_quantized, fw_cell_state_quantized,
          fw_activation_state, fw_cell_state, fw_output);
      TF_LITE_ENSURE_OK(context, fw_pass_status);

      TfLiteStatus bw_pass_status = lstm_eval::EvalHybrid(
          bw_input, bw_input_to_input_weights, bw_input_to_forget_weights,
          bw_input_to_cell_weights, bw_input_to_output_weights,
          bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
          bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights,
          bw_cell_to_input_weights, bw_cell_to_forget_weights,
          bw_cell_to_output_weights,
          /*input_layer_norm_coefficients=*/nullptr,
          /*forget_layer_norm_coefficients=*/nullptr,
          /*cell_layer_norm_coefficients=*/nullptr,
          /*output_layer_norm_coefficients=*/nullptr, real_aux_input,
          bw_aux_input_to_input_weights, bw_aux_input_to_forget_weights,
          bw_aux_input_to_cell_weights, bw_aux_input_to_output_weights,
          bw_input_gate_bias, bw_forget_gate_bias, bw_cell_bias,
          bw_output_gate_bias, bw_projection_weights, bw_projection_bias,
          &lstm_params,
          /*forward_sequence=*/false, /*time_major=*/true, bw_output_offset,
          bw_scratch_buffer, scaling_factors, prod_scaling_factors,
          recovered_cell_weights, input_quantized, aux_input_quantized,
          bw_activation_state_quantized, bw_cell_state_quantized,
          bw_activation_state, bw_cell_state, actual_bw_output);
      TF_LITE_ENSURE_OK(context, bw_pass_status);
      return kTfLiteOk;
    }
    default:
      context->ReportError(context, "Type %d is not currently supported.",
                           fw_input_to_output_weights->type);
      return kTfLiteError;
  }
  // Unreachable: every switch case returns above; retained for compilers
  // that warn about a missing trailing return.
  return kTfLiteOk;
}
   1029 
   1030 }  // namespace bidirectional_sequence_lstm
   1031 
   1032 TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_LSTM() {
   1033   static TfLiteRegistration r = {
   1034       bidirectional_sequence_lstm::Init, bidirectional_sequence_lstm::Free,
   1035       bidirectional_sequence_lstm::Prepare, bidirectional_sequence_lstm::Eval};
   1036   return &r;
   1037 }
   1038 
   1039 }  // namespace builtin
   1040 }  // namespace ops
   1041 }  // namespace tflite
   1042