/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <limits>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/activation_functor.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/lstm_eval.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace unidirectional_sequence_lstm {

// Input tensor of size {max_time, n_batch, n_input}
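// (When the time_major parameter is false, the layout is
// {n_batch, max_time, n_input} instead; see the shape handling in Prepare.)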
constexpr int kInputTensor = 0;

// Input weight tensors of size: {n_cell, n_input}
constexpr int kInputToInputWeightsTensor = 1;  // Optional
constexpr int kInputToForgetWeightsTensor = 2;
constexpr int kInputToCellWeightsTensor = 3;
constexpr int kInputToOutputWeightsTensor = 4;

// Recurrent weight tensors of size {n_cell, n_output}
constexpr int kRecurrentToInputWeightsTensor = 5;  // Optional
constexpr int kRecurrentToForgetWeightsTensor = 6;
constexpr int kRecurrentToCellWeightsTensor = 7;
constexpr int kRecurrentToOutputWeightsTensor = 8;

// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
constexpr int kCellToInputWeightsTensor = 9;    // Optional
constexpr int kCellToForgetWeightsTensor = 10;  // Optional
constexpr int kCellToOutputWeightsTensor = 11;  // Optional

// Gate bias tensors of size {n_cell}
constexpr int kInputGateBiasTensor = 12;  // Optional
constexpr int kForgetGateBiasTensor = 13;
constexpr int kCellGateBiasTensor = 14;
constexpr int kOutputGateBiasTensor = 15;

// Projection weight tensor of size {n_output, n_cell}
constexpr int kProjectionWeightsTensor = 16;  // Optional
// Projection bias tensor of size {n_output}
constexpr int kProjectionBiasTensor = 17;  // Optional

// Stateful input tensors that are variables and will be modified by the Op.
// Activation state tensor of size {n_batch, n_output}
constexpr int kInputActivationStateTensor = 18;
// Cell state tensor of size {n_batch, n_cell}
constexpr int kInputCellStateTensor = 19;

// Output tensors.
constexpr int kOutputTensor = 0;
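// The output has the same shape as the input, with the innermost dimension
// replaced by n_output (see the resize logic in Prepare).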

// Temporary tensors
enum TemporaryTensor {
  kScratchBuffer = 0,
  kInputQuantized = 1,
  kOutputStateQuantized = 2,
  kCellStateQuantized = 3,
  kScalingFactors = 4,
  kProductScalingFactors = 5,
  kRecoveredCellWeights = 6,
  kNumTemporaryTensors = 7
};

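// Init requests kNumTemporaryTensors tensor slots from the context and stores
// the index of the first one in the node's user_data; Free releases that
// index.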
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* scratch_tensor_index = new int();
  context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index);
  return scratch_tensor_index;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<int*>(buffer);
}

// Check that the input tensor dimensions match each other.
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
                                        TfLiteNode* node, int n_input,
                                        int n_output, int n_cell) {
  const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);

  // Make sure the clipping parameters have valid values:
  // == 0 means no clipping,
  //  > 0 means clipping.
  TF_LITE_ENSURE(context, params->cell_clip >= 0);
  TF_LITE_ENSURE(context, params->proj_clip >= 0);

  const TfLiteTensor* input_to_input_weights =
      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
  if (input_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
  }

  const TfLiteTensor* input_to_forget_weights =
      GetInput(context, node, kInputToForgetWeightsTensor);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);

  const TfLiteTensor* input_to_cell_weights =
      GetInput(context, node, kInputToCellWeightsTensor);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);

  const TfLiteTensor* recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
  if (recurrent_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
                      n_cell);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
                      n_output);
  }

  const TfLiteTensor* recurrent_to_forget_weights =
      GetInput(context, node, kRecurrentToForgetWeightsTensor);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
                    n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
                    n_output);

  const TfLiteTensor* recurrent_to_cell_weights =
      GetInput(context, node, kRecurrentToCellWeightsTensor);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
                    n_output);

  // We make sure the input-gate parameters are either both present (regular
  // LSTM) or both absent (CIFG-LSTM, where the input gate is coupled to the
  // forget gate).
  const bool cifg_weights_all_or_none =
      ((input_to_input_weights != nullptr) &&
       (recurrent_to_input_weights != nullptr)) ||
      ((input_to_input_weights == nullptr) &&
       (recurrent_to_input_weights == nullptr));
  TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);

  const TfLiteTensor* cell_to_input_weights =
      GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
  if (cell_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
  }

  const TfLiteTensor* cell_to_forget_weights =
      GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
  if (cell_to_forget_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
  }

  const TfLiteTensor* cell_to_output_weights =
      GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
  if (cell_to_output_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
  }

  // Make sure the peephole weights are either all present or all absent.
  const bool use_cifg = (input_to_input_weights == nullptr);
  const bool peephole_weights_all_or_none =
      ((cell_to_input_weights != nullptr || use_cifg) &&
       (cell_to_forget_weights != nullptr) &&
       (cell_to_output_weights != nullptr)) ||
      ((cell_to_input_weights == nullptr) &&
       (cell_to_forget_weights == nullptr) &&
       (cell_to_output_weights == nullptr));
  TF_LITE_ENSURE(context, peephole_weights_all_or_none == true);

  // Make sure the input gate bias is present only when not a CIFG-LSTM.
  const TfLiteTensor* input_gate_bias =
      GetOptionalInputTensor(context, node, kInputGateBiasTensor);
  if (use_cifg) {
    TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
  } else {
    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
  }

  const TfLiteTensor* forget_gate_bias =
      GetInput(context, node, kForgetGateBiasTensor);
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);

  const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
  TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell);

  const TfLiteTensor* output_gate_bias =
      GetInput(context, node, kOutputGateBiasTensor);
  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);

  const TfLiteTensor* projection_weights =
      GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
  if (projection_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
  }

  const TfLiteTensor* projection_bias =
      GetOptionalInputTensor(context, node, kProjectionBiasTensor);
  if (projection_bias != nullptr) {
    TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
  }

  // Make sure the projection tensors are consistent:
  // 1) If the projection weights are not present, then the projection bias
  //    should not be present either.
  // 2) If the projection weights are present, then the projection bias is
  //    optional.
  // TODO(ghodrat): make sure this is correct.
  const bool projection_tensors_consistent =
      ((projection_weights != nullptr) || (projection_bias == nullptr));
  TF_LITE_ENSURE(context, projection_tensors_consistent == true);

  return kTfLiteOk;
}

// Resize the output and state tensors based on the sizes of the input tensors.
// Allocate a temporary scratch tensor. Also check that the sizes of the input
// tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);

  // Check that we have all the inputs and outputs we need.
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 20);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);

  // Infer the batch size, number of inputs, number of cells and number of
  // outputs from the input tensors.
  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
  TF_LITE_ENSURE(context, input->dims->size > 1);
  const auto* params =
      reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
          node->builtin_data);
  const bool time_major = params->time_major;
  const int n_batch = time_major ? input->dims->data[1] : input->dims->data[0];
  const int n_input = input->dims->data[2];

  const TfLiteTensor* input_to_output_weights =
      GetInput(context, node, kInputToOutputWeightsTensor);
  const int n_cell = input_to_output_weights->dims->data[0];
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);

  const TfLiteTensor* recurrent_to_output_weights =
      GetInput(context, node, kRecurrentToOutputWeightsTensor);
  TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
                    n_cell);
  const int n_output = recurrent_to_output_weights->dims->data[1];

  // Check that the input tensor dimensions match each other.
  TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input,
                                                        n_output, n_cell));

  // Get the pointer to output, activation_state and cell_state buffer tensors.
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  TfLiteTensor* activation_state =
      GetVariableInput(context, node, kInputActivationStateTensor);
  TfLiteTensor* cell_state =
      GetVariableInput(context, node, kInputCellStateTensor);

  // Check the shape of the input state tensors.
  // These tensors may be 1D or 2D. It's fine as long as the total size is
  // correct.
  TF_LITE_ENSURE_EQ(context, NumElements(activation_state), n_batch * n_output);
  TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);

  // Resize the output tensor.
  TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims);
  output_size->data[input->dims->size - 1] = n_output;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_size));

  // The weights are of consistent type, so it suffices to check one.
  // TODO(mirkov): create a utility/macro for this check, so all Ops can use it.
  const bool is_hybrid_op = ((input_to_output_weights->type == kTfLiteUInt8 ||
                              input_to_output_weights->type == kTfLiteInt8) &&
                             input->type == kTfLiteFloat32);
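  // "Hybrid" here means float inputs/activations with quantized (uint8/int8)
  // weights; in Eval the float inputs are quantized on the fly and the
  // accumulations are rescaled back to float.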

  TfLiteIntArrayFree(node->temporaries);
  if (is_hybrid_op) {
    node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
  } else {
    node->temporaries = TfLiteIntArrayCreate(1);
  }
  node->temporaries->data[0] = *scratch_tensor_index;

  // Create a scratch buffer tensor.
  TfLiteTensor* scratch_buffer = GetTemporary(context, node, kScratchBuffer);
  scratch_buffer->type = input->type;
  scratch_buffer->allocation_type = kTfLiteArenaRw;

  const TfLiteTensor* input_to_input_weights =
      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
  const bool use_cifg = (input_to_input_weights == nullptr);
  TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
  scratch_buffer_size->data[0] = n_batch;
  if (use_cifg) {
    // Reserving space for Cell, Forget, Output gates
    scratch_buffer_size->data[1] = n_cell * 3;
  } else {
    // Reserving space for Input, Cell, Forget, Output gates
    scratch_buffer_size->data[1] = n_cell * 4;
  }
  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
                                                   scratch_buffer_size));

  if (is_hybrid_op) {
    // Allocate temporary tensors to store quantized values of input,
    // activation_state and cell_state tensors.
    node->temporaries->data[kInputQuantized] =
        *scratch_tensor_index + kInputQuantized;
    TfLiteTensor* input_quantized =
        GetTemporary(context, node, kInputQuantized);
    input_quantized->type = input_to_output_weights->type;
    input_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
                                                       input_quantized_size));
    }
    node->temporaries->data[kOutputStateQuantized] =
        *scratch_tensor_index + kOutputStateQuantized;
    TfLiteTensor* activation_state_quantized =
        GetTemporary(context, node, kOutputStateQuantized);
    activation_state_quantized->type = input_to_output_weights->type;
    activation_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(activation_state_quantized->dims,
                             activation_state->dims)) {
      TfLiteIntArray* activation_state_quantized_size =
          TfLiteIntArrayCopy(activation_state->dims);
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, activation_state_quantized,
                                         activation_state_quantized_size));
    }
    node->temporaries->data[kCellStateQuantized] =
        *scratch_tensor_index + kCellStateQuantized;
    TfLiteTensor* cell_state_quantized =
        GetTemporary(context, node, kCellStateQuantized);
    cell_state_quantized->type = input_to_output_weights->type;
    cell_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) {
      TfLiteIntArray* cell_state_quantized_size =
          TfLiteIntArrayCopy(cell_state->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, cell_state_quantized,
                                              cell_state_quantized_size));
    }

    // Allocate temporary tensors to store the scaling factors and product
    // scaling factors. The latter is convenience storage that lets us quantize
    // a vector once (which produces the scaling factors) and multiply it with
    // different matrices (which requires multiplying the scaling factors by
    // the scaling factor of each matrix).
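    // Roughly, a hybrid matrix-vector product is recovered as
    //   float_result ~= (int8_matrix * int8_vector) * matrix_scale * vector_scale,
    // so scaling_factors caches one vector_scale per batch row and
    // prod_scaling_factors caches matrix_scale * vector_scale for reuse.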
    node->temporaries->data[kScalingFactors] =
        *scratch_tensor_index + kScalingFactors;
    TfLiteTensor* scaling_factors =
        GetTemporary(context, node, kScalingFactors);
    scaling_factors->type = kTfLiteFloat32;
    scaling_factors->allocation_type = kTfLiteArenaRw;
    int scaling_dims[1] = {n_batch};
    if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
      TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
      scaling_factors_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
                                                       scaling_factors_size));
    }
    node->temporaries->data[kProductScalingFactors] =
        *scratch_tensor_index + kProductScalingFactors;
    TfLiteTensor* prod_scaling_factors =
        GetTemporary(context, node, kProductScalingFactors);
    prod_scaling_factors->type = kTfLiteFloat32;
    prod_scaling_factors->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(prod_scaling_factors->dims, 1,
                                   scaling_dims)) {
      TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
      prod_scaling_factors_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, prod_scaling_factors,
                                              prod_scaling_factors_size));
    }

    // Allocate a temporary tensor to store the recovered cell weights. Since
    // this is used for diagonal matrices, we only need to store n_cell values.
    node->temporaries->data[kRecoveredCellWeights] =
        *scratch_tensor_index + kRecoveredCellWeights;
    TfLiteTensor* recovered_cell_weights =
        GetTemporary(context, node, kRecoveredCellWeights);
    recovered_cell_weights->type = kTfLiteFloat32;
    recovered_cell_weights->allocation_type = kTfLiteArenaRw;
    int recovered_cell_dims[1] = {n_cell};
    if (!TfLiteIntArrayEqualsArray(recovered_cell_weights->dims, 1,
                                   recovered_cell_dims)) {
      TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1);
      recovered_cell_weights_size->data[0] = n_cell;
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, recovered_cell_weights,
                                              recovered_cell_weights_size));
    }
  }
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const auto* params =
      reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
          node->builtin_data);
  const bool time_major = params->time_major;
  const TfLiteTensor* input = GetInput(context, node, kInputTensor);

  const TfLiteTensor* input_to_input_weights =
      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
  const TfLiteTensor* input_to_forget_weights =
      GetInput(context, node, kInputToForgetWeightsTensor);
  const TfLiteTensor* input_to_cell_weights =
      GetInput(context, node, kInputToCellWeightsTensor);
  const TfLiteTensor* input_to_output_weights =
      GetInput(context, node, kInputToOutputWeightsTensor);

  const TfLiteTensor* recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
  const TfLiteTensor* recurrent_to_forget_weights =
      GetInput(context, node, kRecurrentToForgetWeightsTensor);
  const TfLiteTensor* recurrent_to_cell_weights =
      GetInput(context, node, kRecurrentToCellWeightsTensor);
  const TfLiteTensor* recurrent_to_output_weights =
      GetInput(context, node, kRecurrentToOutputWeightsTensor);

  const TfLiteTensor* cell_to_input_weights =
      GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
  const TfLiteTensor* cell_to_forget_weights =
      GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
  const TfLiteTensor* cell_to_output_weights =
      GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);

  const TfLiteTensor* input_gate_bias =
      GetOptionalInputTensor(context, node, kInputGateBiasTensor);
  const TfLiteTensor* forget_gate_bias =
      GetInput(context, node, kForgetGateBiasTensor);
  const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
  const TfLiteTensor* output_gate_bias =
      GetInput(context, node, kOutputGateBiasTensor);

  const TfLiteTensor* projection_weights =
      GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
  const TfLiteTensor* projection_bias =
      GetOptionalInputTensor(context, node, kProjectionBiasTensor);

  // Fetch the scratch buffer tensor (the first temporary).
  TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);

  TfLiteTensor* activation_state =
      GetVariableInput(context, node, kInputActivationStateTensor);
  TfLiteTensor* cell_state =
      GetVariableInput(context, node, kInputCellStateTensor);

  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  // Copy out the LSTM-specific params so they can be passed to the eval
  // functions.
  TfLiteLSTMParams lstm_params;
  lstm_params.activation = params->activation;
  lstm_params.cell_clip = params->cell_clip;
  lstm_params.proj_clip = params->proj_clip;

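  // Dispatch on the weight type: float weights use the float kernel, while
  // uint8/int8 weights use the hybrid (quantized-weight) kernel.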
  switch (input_to_output_weights->type) {
    case kTfLiteFloat32: {
      return lstm_eval::EvalFloat(
          input, input_to_input_weights, input_to_forget_weights,
          input_to_cell_weights, input_to_output_weights,
          recurrent_to_input_weights, recurrent_to_forget_weights,
          recurrent_to_cell_weights, recurrent_to_output_weights,
          cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
          /*input_layer_norm_coefficients=*/nullptr,
          /*forget_layer_norm_coefficients=*/nullptr,
          /*cell_layer_norm_coefficients=*/nullptr,
          /*output_layer_norm_coefficients=*/nullptr,
          /*aux_input=*/nullptr,
          /*aux_input_to_input_weights=*/nullptr,
          /*aux_input_to_forget_weights=*/nullptr,
          /*aux_input_to_cell_weights=*/nullptr,
          /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
          forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
          projection_bias, &lstm_params, /*forward_sequence=*/true, time_major,
          /*output_offset=*/0, scratch_buffer, activation_state, cell_state,
          output);
    }
    case kTfLiteUInt8:
    case kTfLiteInt8: {
      TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
      TfLiteTensor* activation_state_quantized =
          GetTemporary(context, node, /*index=*/2);
      TfLiteTensor* cell_state_quantized =
          GetTemporary(context, node, /*index=*/3);
      TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4);
      TfLiteTensor* prod_scaling_factors =
          GetTemporary(context, node, /*index=*/5);
      TfLiteTensor* recovered_cell_weights =
          GetTemporary(context, node, /*index=*/6);
      return lstm_eval::EvalHybrid(
          input, input_to_input_weights, input_to_forget_weights,
          input_to_cell_weights, input_to_output_weights,
          recurrent_to_input_weights, recurrent_to_forget_weights,
          recurrent_to_cell_weights, recurrent_to_output_weights,
          cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
          /*input_layer_norm_coefficients=*/nullptr,
          /*forget_layer_norm_coefficients=*/nullptr,
          /*cell_layer_norm_coefficients=*/nullptr,
          /*output_layer_norm_coefficients=*/nullptr,
          /*aux_input=*/nullptr,
          /*aux_input_to_input_weights=*/nullptr,
          /*aux_input_to_forget_weights=*/nullptr,
          /*aux_input_to_cell_weights=*/nullptr,
          /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
          forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
          projection_bias, &lstm_params, /*forward_sequence=*/true, time_major,
          /*output_offset=*/0, scratch_buffer, scaling_factors,
          prod_scaling_factors, recovered_cell_weights, input_quantized,
          /*aux_input_quantized=*/nullptr, activation_state_quantized,
          cell_state_quantized, activation_state, cell_state, output);
    }
    default:
      context->ReportError(context, "Type %d is not currently supported.",
                           input_to_output_weights->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}
}  // namespace unidirectional_sequence_lstm

TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM() {
  static TfLiteRegistration r = {unidirectional_sequence_lstm::Init,
                                 unidirectional_sequence_lstm::Free,
                                 unidirectional_sequence_lstm::Prepare,
                                 unidirectional_sequence_lstm::Eval};
  return &r;
}
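// For reference, a client-side resolver could expose this kernel roughly like
// this (a minimal sketch; for builtin models the standard BuiltinOpResolver is
// normally expected to register it already):
//
//   tflite::MutableOpResolver resolver;
//   resolver.AddBuiltin(
//       tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
//       tflite::ops::builtin::Register_UNIDIRECTIONAL_SEQUENCE_LSTM());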

}  // namespace builtin
}  // namespace ops
}  // namespace tflite