/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <unistd.h>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <limits>

#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace unidirectional_sequence_lstm {

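// This kernel runs a standard LSTM cell over the time dimension of the input.
// Informally, for each time step t it computes
//   i_t = sigmoid(W_xi * x_t + W_hi * h_{t-1} + w_ci .* c_{t-1} + b_i)
//   f_t = sigmoid(W_xf * x_t + W_hf * h_{t-1} + w_cf .* c_{t-1} + b_f)
//   c_t = f_t .* c_{t-1} + i_t .* g(W_xc * x_t + W_hc * h_{t-1} + b_c)
//   o_t = sigmoid(W_xo * x_t + W_ho * h_{t-1} + w_co .* c_t + b_o)
//   h_t = o_t .* g(c_t), optionally followed by a clipped projection.
// The peephole terms (w_c*) and the projection are optional, and in the CIFG
// variant the input gate is coupled to the forget gate as i_t = 1 - f_t.
// The notation above is informal; see Eval() below for the exact operations.
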
// Input Tensors of size {max_time, n_batch, n_input}
constexpr int kInputTensor = 0;

// Input weight tensors of size: {n_cell, n_input}
constexpr int kInputToInputWeightsTensor = 1;  // Optional
constexpr int kInputToForgetWeightsTensor = 2;
constexpr int kInputToCellWeightsTensor = 3;
constexpr int kInputToOutputWeightsTensor = 4;

// Recurrent weight tensors of size {n_cell, n_output}
constexpr int kRecurrentToInputWeightsTensor = 5;  // Optional
constexpr int kRecurrentToForgetWeightsTensor = 6;
constexpr int kRecurrentToCellWeightsTensor = 7;
constexpr int kRecurrentToOutputWeightsTensor = 8;

// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
constexpr int kCellToInputWeightsTensor = 9;    // Optional
constexpr int kCellToForgetWeightsTensor = 10;  // Optional
constexpr int kCellToOutputWeightsTensor = 11;  // Optional

// Gates bias tensors of size {n_cell}
constexpr int kInputGateBiasTensor = 12;  // Optional
constexpr int kForgetGateBiasTensor = 13;
constexpr int kCellGateBiasTensor = 14;
constexpr int kOutputGateBiasTensor = 15;

// Projection weight tensor of size {n_output, n_cell}
constexpr int kProjectionWeightsTensor = 16;  // Optional
// Projection bias tensor of size {n_output}
constexpr int kProjectionBiasTensor = 17;  // Optional

// Output tensors.
constexpr int kScratchBufferTensor = 0;
constexpr int kOutputStateTensor = 1;
constexpr int kCellStateTensor = 2;
constexpr int kOutputTensor = 3;

// Check that the input tensor dimensions match each other.
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
                                        TfLiteNode* node, int n_input,
                                        int n_output, int n_cell) {
  auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);

  // Making sure clipping parameters have valid values.
  // == 0 means no clipping
  //  > 0 means clipping
  TF_LITE_ENSURE(context, params->cell_clip >= 0);
  TF_LITE_ENSURE(context, params->proj_clip >= 0);

  TfLiteTensor* input_to_input_weights =
      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
  if (input_to_input_weights) {
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
  }

  TfLiteTensor* input_to_forget_weights =
      GetInput(context, node, kInputToForgetWeightsTensor);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);

  TfLiteTensor* input_to_cell_weights =
      GetInput(context, node, kInputToCellWeightsTensor);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);

  TfLiteTensor* recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
  if (recurrent_to_input_weights) {
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
                      n_cell);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
                      n_output);
  }

  TfLiteTensor* recurrent_to_forget_weights =
      GetInput(context, node, kRecurrentToForgetWeightsTensor);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
                    n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
                    n_output);

  TfLiteTensor* recurrent_to_cell_weights =
      GetInput(context, node, kRecurrentToCellWeightsTensor);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
                    n_output);

  // We make sure the input-gate's parameters are either both present (regular
  // LSTM) or both absent (CIFG-LSTM).
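  // (CIFG stands for Coupled Input and Forget Gate: the input gate is derived
  // from the forget gate as i_t = 1 - f_t, so the input-gate weights and bias
  // are omitted; see the use_cifg path in Eval().)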
  const bool cifg_weights_all_or_none =
      ((input_to_input_weights != nullptr) &&
       (recurrent_to_input_weights != nullptr)) ||
      ((input_to_input_weights == nullptr) &&
       (recurrent_to_input_weights == nullptr));
  TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);

  TfLiteTensor* cell_to_input_weights =
      GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
  if (cell_to_input_weights) {
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
  }

  TfLiteTensor* cell_to_forget_weights =
      GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
  if (cell_to_forget_weights) {
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
  }

  TfLiteTensor* cell_to_output_weights =
      GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
  if (cell_to_output_weights) {
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
  }

  // Making sure the peephole weights are either all present or all absent
  // (under CIFG the cell-to-input peephole weights are not used and therefore
  // not required).
  const bool use_cifg = (input_to_input_weights == nullptr);
  const bool peephole_weights_all_or_none =
      ((cell_to_input_weights != nullptr || use_cifg) &&
       (cell_to_forget_weights != nullptr) &&
       (cell_to_output_weights != nullptr)) ||
      ((cell_to_input_weights == nullptr) &&
       (cell_to_forget_weights == nullptr) &&
       (cell_to_output_weights == nullptr));
  TF_LITE_ENSURE(context, peephole_weights_all_or_none == true);

  // Make sure the input gate bias is present only when not a CIFG-LSTM.
  TfLiteTensor* input_gate_bias =
      GetOptionalInputTensor(context, node, kInputGateBiasTensor);
  if (use_cifg) {
    TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
  } else {
    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
  }

  TfLiteTensor* forget_gate_bias =
      GetInput(context, node, kForgetGateBiasTensor);
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);

  TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
  TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell);

  TfLiteTensor* output_gate_bias =
      GetInput(context, node, kOutputGateBiasTensor);
  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);

  TfLiteTensor* projection_weights =
      GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
  if (projection_weights) {
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
  }

  TfLiteTensor* projection_bias =
      GetOptionalInputTensor(context, node, kProjectionBiasTensor);
  if (projection_bias) {
    TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
  }

  // Making sure the projection tensors are consistent:
  // 1) If projection weight is not present, then projection bias should not be
  // present.
  // 2) If projection weight is present, then projection bias is optional.
  // TODO(ghodrat): make sure this is correct.
  const bool projection_tensors_consistent =
      ((projection_weights != nullptr) || (projection_bias == nullptr));
  TF_LITE_ENSURE(context, projection_tensors_consistent == true);

  return kTfLiteOk;
}

// Resize the output, state and scratch tensors based on the sizes of the input
// tensors. Also check that the sizes of the input tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // Check we have all the inputs and outputs we need.
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 18);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 4);

  // Inferring batch size, sequence length, number of cells and number of
  // outputs from the input tensors.
  TfLiteTensor* input = GetInput(context, node, kInputTensor);
  TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
  const int max_time = input->dims->data[0];
  const int n_batch = input->dims->data[1];
  const int n_input = input->dims->data[2];

  TfLiteTensor* input_to_output_weights =
      GetInput(context, node, kInputToOutputWeightsTensor);
  const int n_cell = input_to_output_weights->dims->data[0];
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);

  TfLiteTensor* recurrent_to_output_weights =
      GetInput(context, node, kRecurrentToOutputWeightsTensor);
  TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
                    n_cell);
  const int n_output = recurrent_to_output_weights->dims->data[1];

  // Check that the input tensor dimensions match each other.
  TF_LITE_ENSURE_OK(
      context,
      CheckInputTensorDimensions(context, node, n_input, n_output, n_cell));

  // Get the pointer to output, state and scratch buffer tensors.
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
  TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
  // TODO(ghodrat): Modify this as soon as we have a finalized method for
  // scratch buffers.
  TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor);

  // Resize the output and output_state tensors.
  TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
  output_size->data[0] = max_time;
  output_size->data[1] = n_batch;
  output_size->data[2] = n_output;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_size));

  TfLiteIntArray* output_state_size = TfLiteIntArrayCreate(2);
  output_state_size->data[0] = n_batch;
  output_state_size->data[1] = n_output;
  TF_LITE_ENSURE_OK(
      context, context->ResizeTensor(context, output_state, output_state_size));

  // Resize the cell state tensor.
  TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2);
  cell_size->data[0] = n_batch;
  cell_size->data[1] = n_cell;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, cell_state, cell_size));

  // Mark state tensors as persistent tensors.
  output_state->allocation_type = kTfLiteArenaRwPersistent;
  cell_state->allocation_type = kTfLiteArenaRwPersistent;
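  // Persistent arena tensors keep their contents between invocations, so
  // output_state (h) and cell_state (c) carry the LSTM state forward across
  // calls to Eval().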

  TfLiteTensor* input_to_input_weights =
      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
  const bool use_cifg = (input_to_input_weights == nullptr);
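  // Resize the scratch buffer tensor. It holds the pre-activation gate values
  // for all batches at one time step: 3 segments of n_cell values per batch
  // with CIFG (no input gate), 4 segments otherwise.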
  if (use_cifg) {
    TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
    scratch_buffer_size->data[0] = n_batch;
    // Reserving space for Cell, Forget, Output gates
    scratch_buffer_size->data[1] = n_cell * 3;
    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
                                                     scratch_buffer_size));
  } else {
    TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
    scratch_buffer_size->data[0] = n_batch;
    // Reserving space for Input, Cell, Forget, Output gates
    scratch_buffer_size->data[1] = n_cell * 4;
    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
                                                     scratch_buffer_size));
  }
  return kTfLiteOk;
}

// The LSTM Op engine.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
  TfLiteTensor* input = GetInput(context, node, kInputTensor);

  TfLiteTensor* input_to_input_weights =
      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
  TfLiteTensor* input_to_forget_weights =
      GetInput(context, node, kInputToForgetWeightsTensor);
  TfLiteTensor* input_to_cell_weights =
      GetInput(context, node, kInputToCellWeightsTensor);
  TfLiteTensor* input_to_output_weights =
      GetInput(context, node, kInputToOutputWeightsTensor);

  TfLiteTensor* recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
  TfLiteTensor* recurrent_to_forget_weights =
      GetInput(context, node, kRecurrentToForgetWeightsTensor);
  TfLiteTensor* recurrent_to_cell_weights =
      GetInput(context, node, kRecurrentToCellWeightsTensor);
  TfLiteTensor* recurrent_to_output_weights =
      GetInput(context, node, kRecurrentToOutputWeightsTensor);

  TfLiteTensor* cell_to_input_weights =
      GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
  TfLiteTensor* cell_to_forget_weights =
      GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
  TfLiteTensor* cell_to_output_weights =
      GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);

  TfLiteTensor* input_gate_bias =
      GetOptionalInputTensor(context, node, kInputGateBiasTensor);
  TfLiteTensor* forget_gate_bias =
      GetInput(context, node, kForgetGateBiasTensor);
  TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
  TfLiteTensor* output_gate_bias =
      GetInput(context, node, kOutputGateBiasTensor);

  TfLiteTensor* projection_weights =
      GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
  TfLiteTensor* projection_bias =
      GetOptionalInputTensor(context, node, kProjectionBiasTensor);

  TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
  TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  const int max_time = input->dims->data[0];
  const int n_batch = input->dims->data[1];
  const int n_input = input->dims->data[2];
  // n_cell and n_output will be the same size when there is no projection.
  const int n_cell = input_to_output_weights->dims->data[0];
  const int n_output = recurrent_to_output_weights->dims->data[1];

  // Since we have already checked that the weights are either all present or
  // all absent, we only need to check one of them to get the condition.
  const bool use_cifg = (input_to_input_weights == nullptr);
  const bool use_peephole = (cell_to_output_weights != nullptr);

  // Set up pointers into the shared scratch buffer tensor.
  TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor);
  float* input_gate_scratch = nullptr;
  float* cell_scratch = nullptr;
  float* forget_gate_scratch = nullptr;
  float* output_gate_scratch = nullptr;
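  // The scratch buffer is laid out as consecutive segments of n_cell * n_batch
  // floats, one per gate; with CIFG the input-gate segment is absent.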
  if (use_cifg) {
    cell_scratch = scratch_buffer->data.f;
    forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
    output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
  } else {
    input_gate_scratch = scratch_buffer->data.f;
    cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
    forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
    output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
  }

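  // Walk over the sequence one time step at a time; output_state and
  // cell_state carry h_{t-1} and c_{t-1} from one iteration (and invocation)
  // to the next.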
  for (int t = 0; t < max_time; t++) {
    const float* input_ptr_time = input->data.f + t * n_batch * n_input;
    // Initialize scratch buffers with bias.
    if (!use_cifg) {
      tensor_utils::VectorBatchVectorAssign(input_gate_bias->data.f, n_cell,
                                            n_batch, input_gate_scratch);
    }
    tensor_utils::VectorBatchVectorAssign(forget_gate_bias->data.f, n_cell,
                                          n_batch, forget_gate_scratch);
    tensor_utils::VectorBatchVectorAssign(cell_bias->data.f, n_cell, n_batch,
                                          cell_scratch);
    tensor_utils::VectorBatchVectorAssign(output_gate_bias->data.f, n_cell,
                                          n_batch, output_gate_scratch);

    // For each batch and cell: compute input_weight * input.
    if (!use_cifg) {
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          input_to_input_weights->data.f, n_cell, n_input, input_ptr_time,
          n_batch, input_gate_scratch, /*result_stride=*/1);
    }
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        input_to_forget_weights->data.f, n_cell, n_input, input_ptr_time,
        n_batch, forget_gate_scratch, /*result_stride=*/1);
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        input_to_cell_weights->data.f, n_cell, n_input, input_ptr_time, n_batch,
        cell_scratch, /*result_stride=*/1);
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        input_to_output_weights->data.f, n_cell, n_input, input_ptr_time,
        n_batch, output_gate_scratch, /*result_stride=*/1);

    // For each batch and cell: compute recurrent_weight * output_state.
    if (!use_cifg) {
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          recurrent_to_input_weights->data.f, n_cell, n_output,
          output_state->data.f, n_batch, input_gate_scratch,
          /*result_stride=*/1);
    }
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        recurrent_to_forget_weights->data.f, n_cell, n_output,
        output_state->data.f, n_batch, forget_gate_scratch,
        /*result_stride=*/1);
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        recurrent_to_cell_weights->data.f, n_cell, n_output,
        output_state->data.f, n_batch, cell_scratch, /*result_stride=*/1);
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        recurrent_to_output_weights->data.f, n_cell, n_output,
        output_state->data.f, n_batch, output_gate_scratch,
        /*result_stride=*/1);

    // For each batch and cell: update input gate.
    if (!use_cifg) {
      if (use_peephole) {
        tensor_utils::VectorBatchVectorCwiseProductAccumulate(
            cell_to_input_weights->data.f, n_cell, cell_state->data.f, n_batch,
            input_gate_scratch);
      }
      tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
                                         input_gate_scratch);
    }

    // For each batch and cell: update forget gate.
    if (use_peephole) {
      tensor_utils::VectorBatchVectorCwiseProductAccumulate(
          cell_to_forget_weights->data.f, n_cell, cell_state->data.f, n_batch,
          forget_gate_scratch);
    }
    tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
                                       forget_gate_scratch);

    // For each batch and cell: update the cell.
    tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch,
                                           cell_state->data.f, n_batch * n_cell,
                                           cell_state->data.f);
    tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
                                          params->activation, cell_scratch);
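    // Under CIFG the input gate is coupled to the forget gate: the new cell
    // input is scaled by (1 - forget_gate) instead of by a separate input
    // gate.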
    if (use_cifg) {
      tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
                               forget_gate_scratch);
      tensor_utils::VectorVectorCwiseProductAccumulate(
          cell_scratch, forget_gate_scratch, n_batch * n_cell,
          cell_state->data.f);
    } else {
      tensor_utils::VectorVectorCwiseProductAccumulate(
          cell_scratch, input_gate_scratch, n_batch * n_cell,
          cell_state->data.f);
    }
    if (params->cell_clip > 0.0) {
      tensor_utils::ClipVector(cell_state->data.f, n_batch * n_cell,
                               params->cell_clip, cell_state->data.f);
    }

    // For each batch and cell: update the output gate.
    if (use_peephole) {
      tensor_utils::VectorBatchVectorCwiseProductAccumulate(
          cell_to_output_weights->data.f, n_cell, cell_state->data.f, n_batch,
          output_gate_scratch);
    }
    tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
                                       output_gate_scratch);
    tensor_utils::ApplyActivationToVector(cell_state->data.f, n_batch * n_cell,
                                          params->activation, cell_scratch);
    tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
                                           n_batch * n_cell,
                                           output_gate_scratch);

    // For each batch: update the projection and output_state.
    const bool use_projection_weight = (projection_weights != nullptr);
    const bool use_projection_bias = (projection_bias != nullptr);
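    // With a projection layer the emitted output is
    // projection_weights * (output_gate .* g(cell)) + projection_bias,
    // optionally clipped; without one the gated cell activation is emitted
    // directly.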
    float* output_ptr_time = output->data.f + t * n_batch * n_output;
    if (use_projection_weight) {
      if (use_projection_bias) {
        tensor_utils::VectorBatchVectorAssign(projection_bias->data.f, n_output,
                                              n_batch, output_ptr_time);
      } else {
        tensor_utils::ZeroVector(output_ptr_time, n_batch * n_output);
      }
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          projection_weights->data.f, n_output, n_cell, output_gate_scratch,
          n_batch, output_ptr_time, /*result_stride=*/1);
      if (params->proj_clip > 0.0) {
        tensor_utils::ClipVector(output_ptr_time, n_batch * n_output,
                                 params->proj_clip, output_ptr_time);
      }
    } else {
      tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
                               output_ptr_time);
    }
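    // The freshly computed output also becomes h_{t-1} for the next time step.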
    tensor_utils::CopyVector(output_ptr_time, n_batch * n_output,
                             output_state->data.f);
  }
  return kTfLiteOk;
}

}  // namespace unidirectional_sequence_lstm

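// The kernel keeps no per-node state beyond its TfLite tensors, so no
// init/free callbacks are needed; only Prepare and Eval are registered.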
TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
                                 unidirectional_sequence_lstm::Prepare,
                                 unidirectional_sequence_lstm::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite