/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <limits>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/activation_functor.h"
#include "tensorflow/lite/kernels/gemm_support.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/lstm_eval.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace lstm {

struct OpData {
  // Which kernel type to use: the full kernel (24 inputs) or the basic kernel
  // (5 inputs).
  // Note that the 20-input full kernel is deprecated and kept only for
  // backward compatibility.
  TfLiteLSTMKernelType kernel_type;

  // Whether this LSTM uses layer normalization.
  bool is_layer_norm_lstm;

  // These fields are only used by the full kernel.
  int activation_state_tensor_index;
  int cell_state_tensor_index;
  int scratch_tensor_index;
};

// For the full kernel (24 inputs).
// Note that the 20-input full kernel is deprecated and kept only for
// backward compatibility.
namespace full {

// Input tensor of size {n_batch, n_input}
constexpr int kInputTensor = 0;

// Input weight tensors of size: {n_cell, n_input}
constexpr int kInputToInputWeightsTensor = 1;  // Optional
constexpr int kInputToForgetWeightsTensor = 2;
constexpr int kInputToCellWeightsTensor = 3;
constexpr int kInputToOutputWeightsTensor = 4;

// Recurrent weight tensors of size {n_cell, n_output}
constexpr int kRecurrentToInputWeightsTensor = 5;  // Optional
constexpr int kRecurrentToForgetWeightsTensor = 6;
constexpr int kRecurrentToCellWeightsTensor = 7;
constexpr int kRecurrentToOutputWeightsTensor = 8;

// Peephole weight tensors of size {n_cell}, representing a diagonal matrix.
constexpr int kCellToInputWeightsTensor = 9;    // Optional
constexpr int kCellToForgetWeightsTensor = 10;  // Optional
constexpr int kCellToOutputWeightsTensor = 11;  // Optional

// Gate bias tensors of size {n_cell}
constexpr int kInputGateBiasTensor = 12;  // Optional
constexpr int kForgetGateBiasTensor = 13;
constexpr int kCellGateBiasTensor = 14;
constexpr int kOutputGateBiasTensor = 15;

// Projection weight tensor of size {n_output, n_cell}
constexpr int kProjectionWeightsTensor = 16;  // Optional
// Projection bias tensor of size {n_output}
constexpr int kProjectionBiasTensor = 17;  // Optional

// These state tensors are defined as variable tensors, and will be modified by
// this op.
constexpr int kInputActivationStateTensor = 18;
constexpr int kInputCellStateTensor = 19;

// Layer norm coefficient tensors of size {n_cell}, representing a diagonal
// matrix.
constexpr int kInputLayerNormCoefficientsTensor = 20;   // Optional
constexpr int kForgetLayerNormCoefficientsTensor = 21;  // Optional
constexpr int kCellLayerNormCoefficientsTensor = 22;    // Optional
constexpr int kOutputLayerNormCoefficientsTensor = 23;  // Optional

// Output tensors.
constexpr int kOutputTensor = 0;
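// The output tensor is resized to {n_batch, n_output} in Prepare().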

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData();
  op_data->kernel_type = kTfLiteLSTMFullKernel;
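  // Reserve 7 consecutive temporary tensor slots and remember the base index:
  // one scratch buffer for the gate computations, plus six buffers used only
  // by the hybrid (quantized-weight) path. Prepare() wires node->temporaries
  // to these slots.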
  context->AddTensors(context, /*tensors_to_add=*/7,
                      &op_data->scratch_tensor_index);
  return op_data;
}

// Check that the input tensor dimensions are consistent with each other.
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
                                        TfLiteNode* node, int n_input,
                                        int n_output, int n_cell,
                                        bool is_layer_norm_lstm) {
  const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);

  // Make sure the clipping parameters have valid values:
  // == 0 means no clipping,
  //  > 0 means clipping.
  TF_LITE_ENSURE(context, params->cell_clip >= 0);
  TF_LITE_ENSURE(context, params->proj_clip >= 0);

  const TfLiteTensor* input_to_input_weights =
      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
  const bool use_cifg = (input_to_input_weights == nullptr);
  if (!use_cifg) {
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
  }

  const TfLiteTensor* input_to_forget_weights =
      GetInput(context, node, kInputToForgetWeightsTensor);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);

  const TfLiteTensor* input_to_cell_weights =
      GetInput(context, node, kInputToCellWeightsTensor);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);

  const TfLiteTensor* recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
  if (recurrent_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
                      n_cell);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
                      n_output);
  }

  const TfLiteTensor* recurrent_to_forget_weights =
      GetInput(context, node, kRecurrentToForgetWeightsTensor);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
                    n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
                    n_output);

  const TfLiteTensor* recurrent_to_cell_weights =
      GetInput(context, node, kRecurrentToCellWeightsTensor);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
                    n_output);

  // Make sure the input gate's parameters are either both present (regular
  // LSTM) or both absent (CIFG-LSTM).
  const bool cifg_weights_all_or_none =
      ((input_to_input_weights != nullptr) &&
       (recurrent_to_input_weights != nullptr)) ||
      ((input_to_input_weights == nullptr) &&
       (recurrent_to_input_weights == nullptr));
  TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);

  const TfLiteTensor* cell_to_input_weights =
      GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
  if (cell_to_input_weights) {
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
  }

  const TfLiteTensor* cell_to_forget_weights =
      GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
  if (cell_to_forget_weights) {
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
  }

  const TfLiteTensor* cell_to_output_weights =
      GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
  if (cell_to_output_weights) {
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
  }

  // Make sure the peephole weights are either all present or all absent.
  const bool peephole_weights_all_or_none =
      ((cell_to_input_weights != nullptr || use_cifg) &&
       (cell_to_forget_weights != nullptr) &&
       (cell_to_output_weights != nullptr)) ||
      ((cell_to_input_weights == nullptr) &&
       (cell_to_forget_weights == nullptr) &&
       (cell_to_output_weights == nullptr));
  TF_LITE_ENSURE(context, peephole_weights_all_or_none == true);

  // Make sure the input gate bias is present only when this is not a
  // CIFG-LSTM.
  const TfLiteTensor* input_gate_bias =
      GetOptionalInputTensor(context, node, kInputGateBiasTensor);
  if (use_cifg) {
    TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
  } else {
    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
  }

  const TfLiteTensor* forget_gate_bias =
      GetInput(context, node, kForgetGateBiasTensor);
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);

  const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
  TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell);

  const TfLiteTensor* output_gate_bias =
      GetInput(context, node, kOutputGateBiasTensor);
  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);

  const TfLiteTensor* projection_weights =
      GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
  if (projection_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
  }

  const TfLiteTensor* projection_bias =
      GetOptionalInputTensor(context, node, kProjectionBiasTensor);
  if (projection_bias != nullptr) {
    TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
  }

  // Make sure the projection tensors are consistent:
  // 1) If the projection weight is not present, then the projection bias
  //    should not be present either.
  // 2) If the projection weight is present, then the projection bias is
  //    optional.
  // TODO(ghodrat): make sure this is correct.
  const bool projection_tensors_consistent =
      ((projection_weights != nullptr) || (projection_bias == nullptr));
  TF_LITE_ENSURE(context, projection_tensors_consistent == true);

  if (is_layer_norm_lstm) {
    const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor(
        context, node, kInputLayerNormCoefficientsTensor);
    if (use_cifg) {
      TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients, nullptr);
    } else {
      TF_LITE_ENSURE(context, input_layer_norm_coefficients != nullptr);
      TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->size, 1);
      TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->data[0],
                        n_cell);
    }

    const TfLiteTensor* forget_layer_norm_coefficients =
        GetInput(context, node, kForgetLayerNormCoefficientsTensor);
    TF_LITE_ENSURE(context, forget_layer_norm_coefficients != nullptr);
    TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->data[0],
                      n_cell);

    const TfLiteTensor* cell_layer_norm_coefficients =
        GetInput(context, node, kCellLayerNormCoefficientsTensor);
    TF_LITE_ENSURE(context, cell_layer_norm_coefficients != nullptr);
    TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->data[0],
                      n_cell);

    const TfLiteTensor* output_layer_norm_coefficients =
        GetInput(context, node, kOutputLayerNormCoefficientsTensor);
    TF_LITE_ENSURE(context, output_layer_norm_coefficients != nullptr);
    TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->data[0],
                      n_cell);
  }

  return kTfLiteOk;
}

// Resize the output and state tensors based on the sizes of the input
// tensors. Allocate a temporary scratch tensor. Also check that the sizes of
// the input tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);

  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
  // Logic for determining whether this is a regular LSTM or a layer-norm
  // LSTM, based on the number of inputs and on whether the forget-gate layer
  // norm tensor is present:
  //   inputs->size   forget layer norm tensor   is_layer_norm_lstm
  //   20             N/A                        false
  //   24             null                       false
  //   24             not null                   true
  // 20-input LSTMs are deprecated and are only kept here for backward
  // compatibility.
  if (node->inputs->size == 24) {
    const TfLiteTensor* forget_layer_norm_coefficients =
        GetInput(context, node, kForgetLayerNormCoefficientsTensor);
    if (forget_layer_norm_coefficients == nullptr) {
      op_data->is_layer_norm_lstm = false;
    } else {
      op_data->is_layer_norm_lstm = true;
    }
  } else if (node->inputs->size == 20) {
    // This path is deprecated and is only kept here for backward
    // compatibility.
    op_data->is_layer_norm_lstm = false;
  } else {
    context->ReportError(
        context, "The LSTM Full kernel expects 20 or 24 inputs. Got %d inputs",
        node->inputs->size);
    return kTfLiteError;
  }

  const bool is_layer_norm_lstm = op_data->is_layer_norm_lstm;
  op_data->activation_state_tensor_index =
      node->inputs->data[kInputActivationStateTensor];
  op_data->cell_state_tensor_index = node->inputs->data[kInputCellStateTensor];

  // Infer the batch size, number of outputs and number of cells from the
  // input tensors.
  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
  TF_LITE_ENSURE(context, input->dims->size > 1);
  const int n_batch = input->dims->data[0];
  const int n_input = input->dims->data[1];

  const TfLiteTensor* input_to_output_weights =
      GetInput(context, node, kInputToOutputWeightsTensor);
  const int n_cell = input_to_output_weights->dims->data[0];
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);

  const TfLiteTensor* recurrent_to_output_weights =
      GetInput(context, node, kRecurrentToOutputWeightsTensor);
  TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
                    n_cell);
  const int n_output = recurrent_to_output_weights->dims->data[1];

  // Check that the input tensor dimensions are consistent with each other.
  TF_LITE_ENSURE_OK(context,
                    CheckInputTensorDimensions(context, node, n_input, n_output,
                                               n_cell, is_layer_norm_lstm));

  // Get pointers to the output, activation_state and cell_state tensors.
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  TfLiteTensor* activation_state =
      &context->tensors[op_data->activation_state_tensor_index];
  TfLiteTensor* cell_state =
      &context->tensors[op_data->cell_state_tensor_index];

  // Check the shape of the input state tensors.
  // These tensors may be 1D or 2D; this is fine as long as the total size is
  // correct.
  TF_LITE_ENSURE_EQ(context, NumElements(activation_state), n_batch * n_output);
  TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);

  // Resize the output tensor.
  TfLiteIntArray* output_size = TfLiteIntArrayCreate(2);
  output_size->data[0] = n_batch;
  output_size->data[1] = n_output;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_size));

  // The weights are of a consistent type, so it suffices to check one.
  // TODO(mirkov): create a utility/macro for this check, so all ops can use it.
  const bool is_hybrid_op = ((input_to_output_weights->type == kTfLiteUInt8 ||
                              input_to_output_weights->type == kTfLiteInt8) &&
                             input->type == kTfLiteFloat32);

  TfLiteIntArrayFree(node->temporaries);
  if (is_hybrid_op) {
    node->temporaries = TfLiteIntArrayCreate(7);
  } else {
    node->temporaries = TfLiteIntArrayCreate(1);
  }
  node->temporaries->data[0] = op_data->scratch_tensor_index;

  // Create a scratch buffer tensor.
  TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
  scratch_buffer->type = input->type;
  scratch_buffer->allocation_type = kTfLiteArenaRw;

  const TfLiteTensor* input_to_input_weights =
      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
  const bool use_cifg = (input_to_input_weights == nullptr);
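  // The scratch buffer holds the per-gate intermediate results for a single
  // time step: {n_batch, 3 * n_cell} with CIFG (no separate input gate),
  // otherwise {n_batch, 4 * n_cell}.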
  TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
  scratch_buffer_size->data[0] = n_batch;
  if (use_cifg) {
    // Reserve space for the Cell, Forget and Output gates.
    scratch_buffer_size->data[1] = n_cell * 3;
  } else {
    // Reserve space for the Input, Cell, Forget and Output gates.
    scratch_buffer_size->data[1] = n_cell * 4;
  }
  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
                                                   scratch_buffer_size));

  if (is_hybrid_op) {
    // Allocate temporary tensors to store the quantized values of the input,
    // activation_state and cell_state tensors.
    node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
    TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
    input_quantized->type = input_to_output_weights->type;
    input_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
                                                       input_quantized_size));
    }
    node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
    TfLiteTensor* activation_state_quantized =
        GetTemporary(context, node, /*index=*/2);
    activation_state_quantized->type = input_to_output_weights->type;
    activation_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(activation_state_quantized->dims,
                             activation_state->dims)) {
      TfLiteIntArray* activation_state_quantized_size =
          TfLiteIntArrayCopy(activation_state->dims);
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, activation_state_quantized,
                                         activation_state_quantized_size));
    }
    node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
    TfLiteTensor* cell_state_quantized =
        GetTemporary(context, node, /*index=*/3);
    cell_state_quantized->type = input_to_output_weights->type;
    cell_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) {
      TfLiteIntArray* cell_state_quantized_size =
          TfLiteIntArrayCopy(cell_state->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, cell_state_quantized,
                                              cell_state_quantized_size));
    }

    // Allocate temporary tensors to store the scaling factors and product
    // scaling factors. The latter is a convenience storage that makes it
    // possible to quantize a vector once (which produces the scaling factors)
    // and multiply it with different matrices (which requires multiplying the
    // scaling factors by the scaling factor of each matrix).
    node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
    TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4);
    scaling_factors->type = kTfLiteFloat32;
    scaling_factors->allocation_type = kTfLiteArenaRw;
    int scaling_dims[1] = {n_batch};
    if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
      TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
      scaling_factors_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
                                                       scaling_factors_size));
    }
    node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
    TfLiteTensor* prod_scaling_factors =
        GetTemporary(context, node, /*index=*/5);
    prod_scaling_factors->type = kTfLiteFloat32;
    prod_scaling_factors->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(prod_scaling_factors->dims, 1,
                                   scaling_dims)) {
      TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
      prod_scaling_factors_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, prod_scaling_factors,
                                              prod_scaling_factors_size));
    }

    // Allocate a temporary tensor to store the recovered cell weights. Since
    // this is used for diagonal matrices, only n_cell values need to be
    // stored.
    node->temporaries->data[6] = op_data->scratch_tensor_index + 6;
    TfLiteTensor* recovered_cell_weights =
        GetTemporary(context, node, /*index=*/6);
    recovered_cell_weights->type = kTfLiteFloat32;
    recovered_cell_weights->allocation_type = kTfLiteArenaRw;
    int recovered_cell_dims[1] = {n_cell};
    if (!TfLiteIntArrayEqualsArray(recovered_cell_weights->dims, 1,
                                   recovered_cell_dims)) {
      TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1);
      recovered_cell_weights_size->data[0] = n_cell;
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, recovered_cell_weights,
                                              recovered_cell_weights_size));
    }
  }
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  const bool is_layer_norm_lstm = op_data->is_layer_norm_lstm;

  const TfLiteTensor* input = GetInput(context, node, kInputTensor);

  const TfLiteTensor* input_to_input_weights =
      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
  const TfLiteTensor* input_to_forget_weights =
      GetInput(context, node, kInputToForgetWeightsTensor);
  const TfLiteTensor* input_to_cell_weights =
      GetInput(context, node, kInputToCellWeightsTensor);
  const TfLiteTensor* input_to_output_weights =
      GetInput(context, node, kInputToOutputWeightsTensor);

  const TfLiteTensor* recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
  const TfLiteTensor* recurrent_to_forget_weights =
      GetInput(context, node, kRecurrentToForgetWeightsTensor);
  const TfLiteTensor* recurrent_to_cell_weights =
      GetInput(context, node, kRecurrentToCellWeightsTensor);
  const TfLiteTensor* recurrent_to_output_weights =
      GetInput(context, node, kRecurrentToOutputWeightsTensor);

  const TfLiteTensor* cell_to_input_weights =
      GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
  const TfLiteTensor* cell_to_forget_weights =
      GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
  const TfLiteTensor* cell_to_output_weights =
      GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);

  const TfLiteTensor* input_layer_norm_coefficients =
      is_layer_norm_lstm ? GetOptionalInputTensor(
                               context, node, kInputLayerNormCoefficientsTensor)
                         : nullptr;
  const TfLiteTensor* forget_layer_norm_coefficients =
      is_layer_norm_lstm
          ? GetInput(context, node, kForgetLayerNormCoefficientsTensor)
          : nullptr;
  const TfLiteTensor* cell_layer_norm_coefficients =
      is_layer_norm_lstm
          ? GetInput(context, node, kCellLayerNormCoefficientsTensor)
          : nullptr;
  const TfLiteTensor* output_layer_norm_coefficients =
      is_layer_norm_lstm
          ? GetInput(context, node, kOutputLayerNormCoefficientsTensor)
          : nullptr;

  const TfLiteTensor* input_gate_bias =
      GetOptionalInputTensor(context, node, kInputGateBiasTensor);
  const TfLiteTensor* forget_gate_bias =
      GetInput(context, node, kForgetGateBiasTensor);
  const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
  const TfLiteTensor* output_gate_bias =
      GetInput(context, node, kOutputGateBiasTensor);

  const TfLiteTensor* projection_weights =
      GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
  const TfLiteTensor* projection_bias =
      GetOptionalInputTensor(context, node, kProjectionBiasTensor);

  // Index the scratch buffer pointers into the global scratch buffer.
  TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);

  TfLiteTensor* activation_state =
      &context->tensors[op_data->activation_state_tensor_index];
  TfLiteTensor* cell_state =
      &context->tensors[op_data->cell_state_tensor_index];

  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  // TODO(mirkov): add a check that weights are all uint8s or all floats.
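  // Dispatch on the weight type: float weights run the float kernel, while
  // uint8/int8 weights with a float input run the hybrid kernel, which
  // quantizes the activations on the fly using the temporaries allocated in
  // Prepare().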
  switch (input_to_output_weights->type) {
    case kTfLiteFloat32: {
      return lstm_eval::EvalFloat(
          input, input_to_input_weights, input_to_forget_weights,
          input_to_cell_weights, input_to_output_weights,
          recurrent_to_input_weights, recurrent_to_forget_weights,
          recurrent_to_cell_weights, recurrent_to_output_weights,
          cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
          input_layer_norm_coefficients, forget_layer_norm_coefficients,
          cell_layer_norm_coefficients, output_layer_norm_coefficients,
          /*aux_input=*/nullptr,
          /*aux_input_to_input_weights=*/nullptr,
          /*aux_input_to_forget_weights=*/nullptr,
          /*aux_input_to_cell_weights=*/nullptr,
          /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
          forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
          projection_bias, params, /*forward_sequence=*/true,
          /*time_major=*/true,
          /*output_offset=*/0, scratch_buffer, activation_state, cell_state,
          output);
    }
    case kTfLiteUInt8:
    case kTfLiteInt8: {
      TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
      TfLiteTensor* activation_state_quantized =
          GetTemporary(context, node, /*index=*/2);
      TfLiteTensor* cell_state_quantized =
          GetTemporary(context, node, /*index=*/3);
      TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4);
      TfLiteTensor* prod_scaling_factors =
          GetTemporary(context, node, /*index=*/5);
      TfLiteTensor* recovered_cell_weights =
          GetTemporary(context, node, /*index=*/6);
      return lstm_eval::EvalHybrid(
          input, input_to_input_weights, input_to_forget_weights,
          input_to_cell_weights, input_to_output_weights,
          recurrent_to_input_weights, recurrent_to_forget_weights,
          recurrent_to_cell_weights, recurrent_to_output_weights,
          cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
          input_layer_norm_coefficients, forget_layer_norm_coefficients,
          cell_layer_norm_coefficients, output_layer_norm_coefficients,
          /*aux_input=*/nullptr,
          /*aux_input_to_input_weights=*/nullptr,
          /*aux_input_to_forget_weights=*/nullptr,
          /*aux_input_to_cell_weights=*/nullptr,
          /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
          forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
          projection_bias, params, /*forward_sequence=*/true,
          /*time_major=*/true, /*output_offset=*/0, scratch_buffer,
          scaling_factors, prod_scaling_factors, recovered_cell_weights,
          input_quantized,
          /*aux_input_quantized=*/nullptr, activation_state_quantized,
          cell_state_quantized, activation_state, cell_state, output);
    }
    default:
      context->ReportError(context, "Type %d is not currently supported.",
                           input_to_output_weights->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace full

// For the basic kernel (5 inputs).
namespace basic {

enum InputTensor {
  kInputData = 0,
  kInputPrevActivation = 1,
  kInputWeights = 2,
  kInputBiases = 3,
  kInputPrevState = 4,
  kInputNum = 5,
};

enum OutputTensor {
  kOutputActivation = 0,
  kOutputState = 1,
  kOutputConcatTemp = 2,
  kOutputActivationTemp = 3,
  kOutputNum = 4,
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData();
  op_data->kernel_type = kTfLiteLSTMBasicKernel;
  // `scratch_tensor_index` is unused in this kernel.
  op_data->scratch_tensor_index = -1;
  return op_data;
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE(context, node->inputs->size == kInputNum);
  TF_LITE_ENSURE(context, node->outputs->size == kOutputNum);

  const TfLiteTensor* input = GetInput(context, node, kInputData);
  const TfLiteTensor* prev_activation =
      GetInput(context, node, kInputPrevActivation);
  const TfLiteTensor* weights = GetInput(context, node, kInputWeights);
  const TfLiteTensor* bias = GetInput(context, node, kInputBiases);
  const TfLiteTensor* prev_state = GetInput(context, node, kInputPrevState);

  TF_LITE_ENSURE_EQ(context, input->dims->size, 2);
  const int num_batches = input->dims->data[0];
  const int input_depth = input->dims->data[1];

  TF_LITE_ENSURE_EQ(context, prev_activation->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, prev_activation->dims->data[0], num_batches);
  const int activation_depth = prev_activation->dims->data[1];
  const int total_depth = input_depth + activation_depth;
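  // The basic kernel concatenates the input and the previous activation along
  // the depth axis, so the weight matrix has shape
  // {4 * activation_depth, total_depth}, i.e. one row block per gate.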

  TF_LITE_ENSURE_EQ(context, weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, weights->dims->data[0], 4 * activation_depth);
  TF_LITE_ENSURE_EQ(context, weights->dims->data[1], total_depth);

  TF_LITE_ENSURE_EQ(context, bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, bias->dims->data[0], 4 * activation_depth);

  TF_LITE_ENSURE_EQ(context, prev_state->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, prev_state->dims->data[0], num_batches);
  TF_LITE_ENSURE_EQ(context, prev_state->dims->data[1], activation_depth);

  TfLiteTensor* activation_out = GetOutput(context, node, kOutputActivation);
  TfLiteTensor* state_out = GetOutput(context, node, kOutputState);
  TfLiteTensor* concat_temp = GetOutput(context, node, kOutputConcatTemp);
  TfLiteTensor* activation_temp =
      GetOutput(context, node, kOutputActivationTemp);

  TF_LITE_ENSURE_OK(context, context->ResizeTensor(
                                 context, activation_out,
                                 TfLiteIntArrayCopy(prev_activation->dims)));
  TF_LITE_ENSURE_OK(
      context, context->ResizeTensor(context, state_out,
                                     TfLiteIntArrayCopy(prev_state->dims)));

  TfLiteIntArray* concat_temp_size = TfLiteIntArrayCreate(2);
  concat_temp_size->data[0] = num_batches;
  concat_temp_size->data[1] = total_depth;
  TF_LITE_ENSURE_OK(
      context, context->ResizeTensor(context, concat_temp, concat_temp_size));
  TfLiteIntArray* activation_temp_size = TfLiteIntArrayCreate(2);
  activation_temp_size->data[0] = num_batches;
  activation_temp_size->data[1] = 4 * activation_depth;
  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_temp,
                                                   activation_temp_size));

  // Set the state tensors as persistent.
  for (auto index : {kInputPrevActivation, kInputPrevState}) {
    TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]];
    tensor->allocation_type = kTfLiteArenaRwPersistent;
  }
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input = GetInput(context, node, kInputData);
  const TfLiteTensor* prev_activation =
      GetInput(context, node, kInputPrevActivation);
  const TfLiteTensor* weights = GetInput(context, node, kInputWeights);
  const TfLiteTensor* bias = GetInput(context, node, kInputBiases);
  const TfLiteTensor* prev_state = GetInput(context, node, kInputPrevState);

  TfLiteTensor* activation_out = GetOutput(context, node, kOutputActivation);
  TfLiteTensor* state_out = GetOutput(context, node, kOutputState);
  TfLiteTensor* concat_temp = GetOutput(context, node, kOutputConcatTemp);
  TfLiteTensor* activation_temp =
      GetOutput(context, node, kOutputActivationTemp);

  if (input->type == kTfLiteFloat32 &&
      prev_activation->type == kTfLiteFloat32 &&
      weights->type == kTfLiteFloat32 && bias->type == kTfLiteFloat32 &&
      prev_state->type == kTfLiteFloat32 && state_out->type == kTfLiteFloat32 &&
      activation_out->type == kTfLiteFloat32 &&
      concat_temp->type == kTfLiteFloat32 &&
      activation_temp->type == kTfLiteFloat32) {
    tflite::LstmCellParams op_params;
    // The float LSTM cell does not need any parameters to be set: leave them
    // untouched.
    optimized_ops::LstmCell(
        op_params,
        // Inputs.
        GetTensorShape(input), GetTensorData<float>(input),
        GetTensorShape(prev_activation), GetTensorData<float>(prev_activation),
        GetTensorShape(weights), GetTensorData<float>(weights),
        GetTensorShape(bias), GetTensorData<float>(bias),
        GetTensorShape(prev_state), GetTensorData<float>(prev_state),
        // Outputs.
        GetTensorShape(state_out), GetTensorData<float>(state_out),
        GetTensorShape(activation_out), GetTensorData<float>(activation_out),
        GetTensorShape(concat_temp), GetTensorData<float>(concat_temp),
        GetTensorShape(activation_temp), GetTensorData<float>(activation_temp));
  } else if (input->type == kTfLiteUInt8 &&
             prev_activation->type == kTfLiteUInt8 &&
             weights->type == kTfLiteUInt8 && bias->type == kTfLiteInt32 &&
             prev_state->type == kTfLiteInt16 &&
             state_out->type == kTfLiteInt16 &&
             activation_out->type == kTfLiteUInt8 &&
             concat_temp->type == kTfLiteUInt8 &&
             activation_temp->type == kTfLiteInt16) {
    gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context);
    int state_scale_log2_rounded;
    if (!CheckedLog2(state_out->params.scale, &state_scale_log2_rounded)) {
      context->ReportError(
          context,
          "The internal state of an LSTM cell must have a power-of-two scale.");
      return kTfLiteError;
    }
    const int state_integer_bits = 15 + state_scale_log2_rounded;
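    // The int16 state is interpreted as a fixed-point value with
    // `state_integer_bits` integer bits and 15 - state_integer_bits fractional
    // bits; e.g. a scale of 2^-11 yields 15 + (-11) = 4 integer bits.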
    if (state_integer_bits != 4) {
      context->ReportError(context,
                           "The only case of quantized LstmCell currently "
                           "supported is with StateIntegerBits==4");
      return kTfLiteError;
    }

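    // Fold the bias scale into a quantized multiplier: QuantizeMultiplier
    // represents real_accum_multiplier as an int32 fixed-point multiplier plus
    // a power-of-two shift, which is the form the quantized LstmCell expects.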
    double real_accum_multiplier = 4096 * bias->params.scale;
    int32 accum_multiplier;
    int accum_shift;
    tflite::QuantizeMultiplier(real_accum_multiplier, &accum_multiplier,
                               &accum_shift);
    tflite::LstmCellParams op_params;
    op_params.weights_zero_point = weights->params.zero_point;
    op_params.accum_multiplier = accum_multiplier;
    op_params.accum_shift = accum_shift;
    optimized_ops::LstmCell<4>(
        op_params,
        // Inputs.
        GetTensorShape(input), GetTensorData<uint8_t>(input),
        GetTensorShape(prev_activation),
        GetTensorData<uint8_t>(prev_activation), GetTensorShape(weights),
        GetTensorData<uint8_t>(weights), GetTensorShape(bias),
        GetTensorData<int32_t>(bias), GetTensorShape(prev_state),
        GetTensorData<int16_t>(prev_state),
        // Outputs.
        GetTensorShape(state_out), GetTensorData<int16_t>(state_out),
        GetTensorShape(activation_out), GetTensorData<uint8_t>(activation_out),
        GetTensorShape(concat_temp), GetTensorData<uint8_t>(concat_temp),
        GetTensorShape(activation_temp),
        GetTensorData<int16_t>(activation_temp), gemm_context);
  } else {
    context->ReportError(context,
                         "Unsupported combination of data types for LstmCell");
    return kTfLiteError;
  }

  // TODO(ycling): Investigate if this copy can be avoided with the 5-input
  // LSTM kernel.
  memcpy(prev_activation->data.raw, activation_out->data.raw,
         activation_out->bytes);
  memcpy(prev_state->data.raw, state_out->data.raw, state_out->bytes);

  return kTfLiteOk;
}

}  // namespace basic

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  gemm_support::IncrementUsageCounter(context);

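  // For builtin ops, `buffer` points at the parsed TfLiteLSTMParams for this
  // node, so the requested kernel type can be read before dispatching to the
  // matching sub-namespace.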
  const auto* params = reinterpret_cast<const TfLiteLSTMParams*>(buffer);
  switch (params->kernel_type) {
    case kTfLiteLSTMFullKernel:
      return full::Init(context, buffer, length);
    case kTfLiteLSTMBasicKernel:
      return basic::Init(context, buffer, length);
    default:
      return nullptr;
  }
}
void Free(TfLiteContext* context, void* buffer) {
  gemm_support::DecrementUsageCounter(context);

  delete reinterpret_cast<OpData*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  const auto* op_data = reinterpret_cast<const OpData*>(node->user_data);
  switch (op_data->kernel_type) {
    case kTfLiteLSTMFullKernel:
      return full::Prepare(context, node);
    case kTfLiteLSTMBasicKernel:
      return basic::Prepare(context, node);
    default:
      return kTfLiteError;
  }
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const auto* op_data = reinterpret_cast<const OpData*>(node->user_data);
  switch (op_data->kernel_type) {
    case kTfLiteLSTMFullKernel:
      return full::Eval(context, node);
    case kTfLiteLSTMBasicKernel:
      return basic::Eval(context, node);
    default:
      return kTfLiteError;
  }
}

}  // namespace lstm

TfLiteRegistration* Register_LSTM() {
  static TfLiteRegistration r = {lstm::Init, lstm::Free, lstm::Prepare,
                                 lstm::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite