/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <limits>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/activation_functor.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace bidirectional_sequence_rnn {

namespace {

// Returns an int8 view of the tensor's data. Hybrid kernels consume weights
// and quantized activations as int8, so uint8-backed tensors are
// reinterpreted in place rather than copied.
int8_t* GetInt8DataPtr(const TfLiteTensor* tensor, const bool is_uint8) {
  if (is_uint8) {
    return reinterpret_cast<int8_t*>(tensor->data.uint8);
  } else {
    return tensor->data.int8;
  }
}

}  // namespace

constexpr int kInputTensor = 0;
// Forward and backward cell tensors.
constexpr int kFwWeightsTensor = 1;
constexpr int kFwRecurrentWeightsTensor = 2;
constexpr int kFwBiasTensor = 3;
constexpr int kFwHiddenStateTensor = 4;
constexpr int kBwWeightsTensor = 5;
constexpr int kBwRecurrentWeightsTensor = 6;
constexpr int kBwBiasTensor = 7;
constexpr int kBwHiddenStateTensor = 8;
// Used as auxiliary input and weights when stacking for the
// tf.contrib.rnn.stack_bidirectional_rnn case (with cross links); used as
// input to the backward cell when stacking for the
// tf.nn.static_bidirectional_rnn case (without cross links).
constexpr int kAuxInputTensor = 9;       // Optional.
constexpr int kFwAuxWeightsTensor = 10;  // Optional.
constexpr int kBwAuxWeightsTensor = 11;  // Optional.
// Output tensors.
constexpr int kFwOutputTensor = 0;
constexpr int kBwOutputTensor = 1;  // Only if merge_outputs is false.
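// With merge_outputs, the backward activations are concatenated onto the
// forward activations along the last dimension of kFwOutputTensor, and
// kBwOutputTensor is not produced.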

// Temporary tensors.
enum TemporaryTensor {
  kInputQuantized = 0,
  kFwHiddenStateQuantized = 1,
  kBwHiddenStateQuantized = 2,
  kScalingFactors = 3,
  kAuxInputQuantized = 4,
  kNumTemporaryTensors = 5
};
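// These values serve both as offsets from the scratch base index that Init()
// obtains via AddTensors() and as positions in node->temporaries, so the two
// layouts must stay in sync.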

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* scratch_tensor_index = new int;
  context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index);
  return scratch_tensor_index;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<int*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceRNNParams*>(
      node->builtin_data);

  // Check we have all the inputs and outputs we need.
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 12);
  TF_LITE_ENSURE_EQ(context, node->outputs->size,
                    params->merge_outputs ? 1 : 2);

  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  const TfLiteTensor* fw_input_weights =
      GetInput(context, node, kFwWeightsTensor);
  const TfLiteTensor* fw_recurrent_weights =
      GetInput(context, node, kFwRecurrentWeightsTensor);
  const TfLiteTensor* fw_bias = GetInput(context, node, kFwBiasTensor);
  const TfLiteTensor* fw_hidden_state =
      GetInput(context, node, kFwHiddenStateTensor);
  const TfLiteTensor* bw_input_weights =
      GetInput(context, node, kBwWeightsTensor);
  const TfLiteTensor* bw_recurrent_weights =
      GetInput(context, node, kBwRecurrentWeightsTensor);
  const TfLiteTensor* bw_bias = GetInput(context, node, kBwBiasTensor);
  const TfLiteTensor* bw_hidden_state =
      GetInput(context, node, kBwHiddenStateTensor);

  const TfLiteTensor* aux_input =
      GetOptionalInputTensor(context, node, kAuxInputTensor);
  const TfLiteTensor* fw_aux_input_weights =
      GetOptionalInputTensor(context, node, kFwAuxWeightsTensor);
  const TfLiteTensor* bw_aux_input_weights =
      GetOptionalInputTensor(context, node, kBwAuxWeightsTensor);

  const bool aux_inputs_weights_or_none =
      ((fw_aux_input_weights != nullptr) &&
       (bw_aux_input_weights != nullptr)) ||
      ((fw_aux_input_weights == nullptr) && (bw_aux_input_weights == nullptr));
  TF_LITE_ENSURE(context, aux_inputs_weights_or_none);
  const bool has_aux_input = (fw_aux_input_weights != nullptr);

  // Check that all tensor parameters are consistent with each other and with
  // the input configuration.
  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);

  TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
  const bool time_major = params->time_major;
  const int batch_size =
      (time_major) ? input->dims->data[1] : input->dims->data[0];
  const int max_time =
      (time_major) ? input->dims->data[0] : input->dims->data[1];
  const int fw_num_units = fw_input_weights->dims->data[0];
  const int bw_num_units = bw_input_weights->dims->data[0];
  TF_LITE_ENSURE_EQ(context, input->dims->data[2],
                    fw_input_weights->dims->data[1]);
  TF_LITE_ENSURE_EQ(context, input->dims->data[2],
                    bw_input_weights->dims->data[1]);
  TF_LITE_ENSURE_EQ(context, fw_input_weights->dims->data[0],
                    fw_bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, bw_input_weights->dims->data[0],
                    bw_bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, fw_recurrent_weights->dims->data[0],
                    fw_bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, bw_recurrent_weights->dims->data[1],
                    bw_bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, NumDimensions(fw_hidden_state), 2);
  TF_LITE_ENSURE_EQ(context, fw_hidden_state->dims->data[0], batch_size);
  TF_LITE_ENSURE_EQ(context, fw_hidden_state->dims->data[1], fw_num_units);
  TF_LITE_ENSURE_EQ(context, NumDimensions(bw_hidden_state), 2);
  TF_LITE_ENSURE_EQ(context, bw_hidden_state->dims->data[0], batch_size);
  TF_LITE_ENSURE_EQ(context, bw_hidden_state->dims->data[1], bw_num_units);

  if (has_aux_input) {
    // Check that aux_input has the same dimensions (except last) as the input.
    TF_LITE_ASSERT_EQ(aux_input->dims->data[0], input->dims->data[0]);
    TF_LITE_ASSERT_EQ(aux_input->dims->data[1], input->dims->data[1]);
    // Check that aux_input_weights has the same dimensions (except last) as
    // the input_weights.
    TF_LITE_ASSERT_EQ(fw_aux_input_weights->dims->data[0], fw_num_units);
    TF_LITE_ASSERT_EQ(bw_aux_input_weights->dims->data[0], bw_num_units);
    TF_LITE_ASSERT_EQ(aux_input->dims->data[2],
                      fw_aux_input_weights->dims->data[1]);
    TF_LITE_ASSERT_EQ(aux_input->dims->data[2],
                      bw_aux_input_weights->dims->data[1]);
  }

  const bool is_hybrid_op = ((fw_input_weights->type == kTfLiteUInt8 ||
                              fw_input_weights->type == kTfLiteInt8) &&
                             input->type == kTfLiteFloat32);
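  // A hybrid op keeps float activations but stores its weights quantized; the
  // matrix multiplications run in int8 and the results are scaled back to
  // float.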

  if (is_hybrid_op) {
    int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);

    TfLiteIntArrayFree(node->temporaries);
    if (has_aux_input) {
      node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
    } else {
      // No need to create a temporary tensor for the non-existent aux_input.
      node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors - 1);
    }
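    // kAuxInputQuantized is deliberately the last enum entry, so the shorter
    // array still lines up with the indices of the other temporaries.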

    node->temporaries->data[kInputQuantized] =
        *scratch_tensor_index + kInputQuantized;
    TfLiteTensor* input_quantized =
        GetTemporary(context, node, kInputQuantized);
    input_quantized->type = fw_input_weights->type;
    input_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
                                                       input_quantized_size));
    }

    node->temporaries->data[kFwHiddenStateQuantized] =
        *scratch_tensor_index + kFwHiddenStateQuantized;
    TfLiteTensor* fw_hidden_state_quantized =
        GetTemporary(context, node, kFwHiddenStateQuantized);
    fw_hidden_state_quantized->type = fw_input_weights->type;
    fw_hidden_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(fw_hidden_state_quantized->dims,
                             fw_hidden_state->dims)) {
      TfLiteIntArray* fw_hidden_state_quantized_size =
          TfLiteIntArrayCopy(fw_hidden_state->dims);
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, fw_hidden_state_quantized,
                                         fw_hidden_state_quantized_size));
    }

    node->temporaries->data[kBwHiddenStateQuantized] =
        *scratch_tensor_index + kBwHiddenStateQuantized;
    TfLiteTensor* bw_hidden_state_quantized =
        GetTemporary(context, node, kBwHiddenStateQuantized);
    bw_hidden_state_quantized->type = fw_input_weights->type;
    bw_hidden_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(bw_hidden_state_quantized->dims,
                             bw_hidden_state->dims)) {
      TfLiteIntArray* bw_hidden_state_quantized_size =
          TfLiteIntArrayCopy(bw_hidden_state->dims);
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, bw_hidden_state_quantized,
                                         bw_hidden_state_quantized_size));
    }

    // Allocate a temporary tensor to store the scaling factors used for
    // quantization.
    node->temporaries->data[kScalingFactors] =
        *scratch_tensor_index + kScalingFactors;
    TfLiteTensor* scaling_factors =
        GetTemporary(context, node, kScalingFactors);
    scaling_factors->type = kTfLiteFloat32;
    scaling_factors->allocation_type = kTfLiteArenaRw;
    int scaling_dims[1] = {batch_size};
    if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
      TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
      scaling_factors_size->data[0] = batch_size;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
                                                       scaling_factors_size));
    }

    if (has_aux_input) {
      node->temporaries->data[kAuxInputQuantized] =
          *scratch_tensor_index + kAuxInputQuantized;
      TfLiteTensor* aux_input_quantized =
          GetTemporary(context, node, kAuxInputQuantized);
      aux_input_quantized->type = fw_input_weights->type;
      aux_input_quantized->allocation_type = kTfLiteArenaRw;
      if (!TfLiteIntArrayEqual(aux_input_quantized->dims, aux_input->dims)) {
        TfLiteIntArray* aux_input_quantized_size =
            TfLiteIntArrayCopy(aux_input->dims);
        TF_LITE_ENSURE_OK(context,
                          context->ResizeTensor(context, aux_input_quantized,
                                                aux_input_quantized_size));
      }
    }
  }

  // Resize outputs.
  TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
  TfLiteIntArray* fw_output_size_array = TfLiteIntArrayCreate(3);
  fw_output_size_array->data[0] = (time_major) ? max_time : batch_size;
  fw_output_size_array->data[1] = (time_major) ? batch_size : max_time;
  fw_output_size_array->data[2] =
      params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
  TF_LITE_ENSURE_OK(
      context, context->ResizeTensor(context, fw_output, fw_output_size_array));
  if (!params->merge_outputs) {
    // The backward output has the same time/batch layout as the forward one.
    TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
    TfLiteIntArray* bw_output_size_array = TfLiteIntArrayCreate(3);
    bw_output_size_array->data[0] = (time_major) ? max_time : batch_size;
    bw_output_size_array->data[1] = (time_major) ? batch_size : max_time;
    bw_output_size_array->data[2] = bw_num_units;
    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_output,
                                                     bw_output_size_array));
  }

  return kTfLiteOk;
}

TfLiteStatus EvalFloat(const TfLiteTensor* input, const TfLiteTensor* bw_input,
                       const TfLiteTensor* fw_input_weights,
                       const TfLiteTensor* fw_recurrent_weights,
                       const TfLiteTensor* fw_bias,
                       const TfLiteTensor* bw_input_weights,
                       const TfLiteTensor* bw_recurrent_weights,
                       const TfLiteTensor* bw_bias,
                       const TfLiteTensor* aux_input,
                       const TfLiteTensor* fw_aux_input_weights,
                       const TfLiteTensor* bw_aux_input_weights,
                       const TfLiteBidirectionalSequenceRNNParams* params,
                       TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
                       TfLiteTensor* bw_hidden_state, TfLiteTensor* bw_output) {
  const bool time_major = params->time_major;
  const int batch_size =
      (time_major) ? input->dims->data[1] : input->dims->data[0];
  const int max_time =
      (time_major) ? input->dims->data[0] : input->dims->data[1];
  const int input_size = input->dims->data[2];
  const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;

  const int fw_num_units = fw_input_weights->dims->data[0];
  const float* fw_bias_ptr = fw_bias->data.f;
  const float* fw_input_weights_ptr = fw_input_weights->data.f;
  const float* fw_recurrent_weights_ptr = fw_recurrent_weights->data.f;

  const int bw_num_units = bw_input_weights->dims->data[0];
  const float* bw_bias_ptr = bw_bias->data.f;
  const float* bw_input_weights_ptr = bw_input_weights->data.f;
  const float* bw_recurrent_weights_ptr = bw_recurrent_weights->data.f;

  const float* fw_aux_input_weights_ptr = (fw_aux_input_weights != nullptr)
                                              ? fw_aux_input_weights->data.f
                                              : nullptr;
  const float* bw_aux_input_weights_ptr = (bw_aux_input_weights != nullptr)
                                              ? bw_aux_input_weights->data.f
                                              : nullptr;

  const int fw_output_step =
      params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
  const int bw_output_step =
      params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;
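  // When outputs are merged, both cells write interleaved slices of the same
  // buffer: each timestep/batch row holds fw_num_units forward activations
  // followed by bw_num_units backward activations, so both cells advance by
  // the combined width.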
  if (time_major) {
    // Forward cell.
    float* fw_hidden_state_ptr_batch = fw_hidden_state->data.f;
    for (int s = 0; s < max_time; s++) {
      const float* input_ptr_batch =
          input->data.f + s * input_size * batch_size;
      const float* aux_input_ptr_batch =
          (aux_input != nullptr)
              ? aux_input->data.f + s * aux_input_size * batch_size
              : nullptr;
      float* output_ptr_batch =
          fw_output->data.f + s * fw_output_step * batch_size;

      kernel_utils::RnnBatchStep(
          input_ptr_batch, fw_input_weights_ptr, aux_input_ptr_batch,
          fw_aux_input_weights_ptr, fw_recurrent_weights_ptr, fw_bias_ptr,
          input_size, aux_input_size, fw_num_units, batch_size, fw_output_step,
          params->activation, fw_hidden_state_ptr_batch, output_ptr_batch);
    }
    // Backward cell.
    float* bw_hidden_state_ptr_batch = bw_hidden_state->data.f;
    for (int s = max_time - 1; s >= 0; s--) {
      const float* input_ptr_batch =
          bw_input->data.f + s * input_size * batch_size;
      const float* aux_input_ptr_batch =
          (aux_input != nullptr)
              ? aux_input->data.f + s * aux_input_size * batch_size
              : nullptr;
      float* output_ptr_batch =
          (params->merge_outputs ? fw_output->data.f + fw_num_units
                                 : bw_output->data.f) +
          s * bw_output_step * batch_size;

      kernel_utils::RnnBatchStep(
          input_ptr_batch, bw_input_weights_ptr, aux_input_ptr_batch,
          bw_aux_input_weights_ptr, bw_recurrent_weights_ptr, bw_bias_ptr,
          input_size, aux_input_size, bw_num_units, batch_size, bw_output_step,
          params->activation, bw_hidden_state_ptr_batch, output_ptr_batch);
    }
  } else {
    for (int b = 0; b < batch_size; b++) {
      // Forward cell.
      float* fw_hidden_state_ptr_batch =
          fw_hidden_state->data.f + b * fw_num_units;
      float* fw_output_offset =
          fw_output->data.f + b * fw_output_step * max_time;
      for (int s = 0; s < max_time; s++) {
        const float* input_ptr_batch =
            input->data.f + b * input_size * max_time + s * input_size;
        const float* aux_input_ptr_batch =
            (aux_input != nullptr)
                ? aux_input->data.f + b * aux_input_size * max_time +
                      s * aux_input_size
                : nullptr;
        float* output_ptr_batch = fw_output_offset + s * fw_output_step;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, fw_input_weights_ptr, aux_input_ptr_batch,
            fw_aux_input_weights_ptr, fw_recurrent_weights_ptr, fw_bias_ptr,
            input_size, aux_input_size, fw_num_units, /*batch_size=*/1,
            fw_output_step, params->activation, fw_hidden_state_ptr_batch,
            output_ptr_batch);
      }
      // Backward cell.
      float* bw_hidden_state_ptr_batch =
          bw_hidden_state->data.f + b * bw_num_units;
      float* bw_output_offset =
          params->merge_outputs
              ? fw_output->data.f + b * bw_output_step * max_time + fw_num_units
              : bw_output->data.f + b * bw_output_step * max_time;
      for (int s = max_time - 1; s >= 0; s--) {
        const float* input_ptr_batch =
            bw_input->data.f + b * input_size * max_time + s * input_size;
        const float* aux_input_ptr_batch =
            (aux_input != nullptr)
                ? aux_input->data.f + b * aux_input_size * max_time +
                      s * aux_input_size
                : nullptr;
        float* output_ptr_batch = bw_output_offset + s * bw_output_step;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, bw_input_weights_ptr, aux_input_ptr_batch,
            bw_aux_input_weights_ptr, bw_recurrent_weights_ptr, bw_bias_ptr,
            input_size, aux_input_size, bw_num_units, /*batch_size=*/1,
            bw_output_step, params->activation, bw_hidden_state_ptr_batch,
            output_ptr_batch);
      }
    }
  }
  return kTfLiteOk;
}

TfLiteStatus EvalHybrid(
    const TfLiteTensor* input, const TfLiteTensor* bw_input,
    const TfLiteTensor* fw_input_weights,
    const TfLiteTensor* fw_recurrent_weights, const TfLiteTensor* fw_bias,
    const TfLiteTensor* bw_input_weights,
    const TfLiteTensor* bw_recurrent_weights, const TfLiteTensor* bw_bias,
    const TfLiteTensor* aux_input, const TfLiteTensor* aux_fw_input_weights,
    const TfLiteTensor* aux_bw_input_weights,
    const TfLiteBidirectionalSequenceRNNParams* params,
    TfLiteTensor* scaling_factors, TfLiteTensor* input_quantized,
    TfLiteTensor* aux_input_quantized, TfLiteTensor* fw_hidden_state_quantized,
    TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
    TfLiteTensor* bw_hidden_state_quantized, TfLiteTensor* bw_hidden_state,
    TfLiteTensor* bw_output) {
  const bool is_uint8_hybrid = fw_input_weights->type == kTfLiteUInt8;
  const bool time_major = params->time_major;
  const int batch_size =
      (time_major) ? input->dims->data[1] : input->dims->data[0];
  const int max_time =
      (time_major) ? input->dims->data[0] : input->dims->data[1];
  const int input_size = input->dims->data[2];
  const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;

  const int fw_num_units = fw_input_weights->dims->data[0];
  const float* fw_bias_ptr = fw_bias->data.f;
  const int8_t* fw_input_weights_ptr =
      GetInt8DataPtr(fw_input_weights, is_uint8_hybrid);
  float fw_input_weights_scale = fw_input_weights->params.scale;
  const int8_t* fw_recurrent_weights_ptr =
      GetInt8DataPtr(fw_recurrent_weights, is_uint8_hybrid);
  float fw_recurrent_weights_scale = fw_recurrent_weights->params.scale;

  const int bw_num_units = bw_input_weights->dims->data[0];
  const float* bw_bias_ptr = bw_bias->data.f;
  const int8_t* bw_input_weights_ptr =
      GetInt8DataPtr(bw_input_weights, is_uint8_hybrid);
  float bw_input_weights_scale = bw_input_weights->params.scale;
  const int8_t* bw_recurrent_weights_ptr =
      GetInt8DataPtr(bw_recurrent_weights, is_uint8_hybrid);
  float bw_recurrent_weights_scale = bw_recurrent_weights->params.scale;
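  // Hybrid evaluation dequantizes the int8 matmul results with the per-tensor
  // weight scales recorded in each tensor's quantization params at conversion
  // time.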

  // Set the auxiliary pointers and scales if needed.
  int8_t* aux_fw_input_weights_ptr = nullptr;
  float aux_fw_input_weights_scale = 0.0f;
  int8_t* aux_bw_input_weights_ptr = nullptr;
  float aux_bw_input_weights_scale = 0.0f;
  int8_t* aux_quantized_input_ptr = nullptr;
  if (aux_input_size > 0) {
    aux_fw_input_weights_ptr =
        GetInt8DataPtr(aux_fw_input_weights, is_uint8_hybrid);
    aux_fw_input_weights_scale = aux_fw_input_weights->params.scale;
    aux_bw_input_weights_ptr =
        GetInt8DataPtr(aux_bw_input_weights, is_uint8_hybrid);
    aux_bw_input_weights_scale = aux_bw_input_weights->params.scale;
    aux_quantized_input_ptr =
        GetInt8DataPtr(aux_input_quantized, is_uint8_hybrid);
  }

  // Initialize temporary storage for quantized values.
  int8_t* quantized_input_ptr =
      GetInt8DataPtr(input_quantized, is_uint8_hybrid);
  int8_t* fw_quantized_hidden_state_ptr =
      GetInt8DataPtr(fw_hidden_state_quantized, is_uint8_hybrid);
  int8_t* bw_quantized_hidden_state_ptr =
      GetInt8DataPtr(bw_hidden_state_quantized, is_uint8_hybrid);
  float* scaling_factors_ptr = scaling_factors->data.f;
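  // scaling_factors holds one entry per batch row (sized that way in Prepare)
  // and is passed to RnnBatchStep as scratch for the per-row input
  // quantization scales.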

  const int fw_output_step =
      params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
  const int bw_output_step =
      params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;
  if (time_major) {
    // Forward cell.
    float* fw_hidden_state_ptr_batch = fw_hidden_state->data.f;
    for (int s = 0; s < max_time; s++) {
      const float* input_ptr_batch =
          input->data.f + s * input_size * batch_size;
      const float* aux_input_ptr_batch =
          (aux_input != nullptr)
              ? aux_input->data.f + s * aux_input_size * batch_size
              : nullptr;
      float* output_ptr_batch =
          fw_output->data.f + s * fw_output_step * batch_size;

      kernel_utils::RnnBatchStep(
          input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale,
          aux_input_ptr_batch, aux_fw_input_weights_ptr,
          aux_fw_input_weights_scale, fw_recurrent_weights_ptr,
          fw_recurrent_weights_scale, fw_bias_ptr, input_size, aux_input_size,
          fw_num_units, batch_size, fw_output_step, params->activation,
          quantized_input_ptr, aux_quantized_input_ptr,
          fw_quantized_hidden_state_ptr, scaling_factors_ptr,
          fw_hidden_state_ptr_batch, output_ptr_batch);
    }
    // Backward cell.
    float* bw_hidden_state_ptr_batch = bw_hidden_state->data.f;
    for (int s = max_time - 1; s >= 0; s--) {
      const float* input_ptr_batch =
          bw_input->data.f + s * input_size * batch_size;
      const float* aux_input_ptr_batch =
          (aux_input != nullptr)
              ? aux_input->data.f + s * aux_input_size * batch_size
              : nullptr;
      float* output_ptr_batch =
          (params->merge_outputs ? fw_output->data.f + fw_num_units
                                 : bw_output->data.f) +
          s * bw_output_step * batch_size;

      kernel_utils::RnnBatchStep(
          input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale,
          aux_input_ptr_batch, aux_bw_input_weights_ptr,
          aux_bw_input_weights_scale, bw_recurrent_weights_ptr,
          bw_recurrent_weights_scale, bw_bias_ptr, input_size, aux_input_size,
          bw_num_units, batch_size, bw_output_step, params->activation,
          quantized_input_ptr, aux_quantized_input_ptr,
          bw_quantized_hidden_state_ptr, scaling_factors_ptr,
          bw_hidden_state_ptr_batch, output_ptr_batch);
    }
  } else {
    for (int b = 0; b < batch_size; b++) {
      // Forward cell.
      float* fw_hidden_state_ptr_batch =
          fw_hidden_state->data.f + b * fw_num_units;
      float* fw_output_offset =
          fw_output->data.f + b * fw_output_step * max_time;
      for (int s = 0; s < max_time; s++) {
        const float* input_ptr_batch =
            input->data.f + b * input_size * max_time + s * input_size;
        const float* aux_input_ptr_batch =
            (aux_input != nullptr)
                ? aux_input->data.f + b * aux_input_size * max_time +
                      s * aux_input_size
                : nullptr;
        float* output_ptr_batch = fw_output_offset + s * fw_output_step;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale,
            aux_input_ptr_batch, aux_fw_input_weights_ptr,
            aux_fw_input_weights_scale, fw_recurrent_weights_ptr,
            fw_recurrent_weights_scale, fw_bias_ptr, input_size, aux_input_size,
            fw_num_units, /*batch_size=*/1, fw_output_step, params->activation,
            quantized_input_ptr, aux_quantized_input_ptr,
            fw_quantized_hidden_state_ptr, scaling_factors_ptr,
            fw_hidden_state_ptr_batch, output_ptr_batch);
      }
      // Backward cell.
      float* bw_hidden_state_ptr_batch =
          bw_hidden_state->data.f + b * bw_num_units;
      float* bw_output_offset =
          params->merge_outputs
              ? fw_output->data.f + b * bw_output_step * max_time + fw_num_units
              : bw_output->data.f + b * bw_output_step * max_time;
      for (int s = max_time - 1; s >= 0; s--) {
        const float* input_ptr_batch =
            bw_input->data.f + b * input_size * max_time + s * input_size;
        const float* aux_input_ptr_batch =
            (aux_input != nullptr)
                ? aux_input->data.f + b * aux_input_size * max_time +
                      s * aux_input_size
                : nullptr;
        float* output_ptr_batch = bw_output_offset + s * bw_output_step;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale,
            aux_input_ptr_batch, aux_bw_input_weights_ptr,
            aux_bw_input_weights_scale, bw_recurrent_weights_ptr,
            bw_recurrent_weights_scale, bw_bias_ptr, input_size, aux_input_size,
            bw_num_units, /*batch_size=*/1, bw_output_step, params->activation,
            quantized_input_ptr, aux_quantized_input_ptr,
            bw_quantized_hidden_state_ptr, scaling_factors_ptr,
            bw_hidden_state_ptr_batch, output_ptr_batch);
      }
    }
  }
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceRNNParams*>(
      node->builtin_data);

  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  const TfLiteTensor* fw_input_weights =
      GetInput(context, node, kFwWeightsTensor);
  const TfLiteTensor* fw_recurrent_weights =
      GetInput(context, node, kFwRecurrentWeightsTensor);
  const TfLiteTensor* fw_bias = GetInput(context, node, kFwBiasTensor);
  const TfLiteTensor* bw_input_weights =
      GetInput(context, node, kBwWeightsTensor);
  const TfLiteTensor* bw_recurrent_weights =
      GetInput(context, node, kBwRecurrentWeightsTensor);
  const TfLiteTensor* bw_bias = GetInput(context, node, kBwBiasTensor);

  // Get auxiliary inputs.
  const TfLiteTensor* aux_input =
      GetOptionalInputTensor(context, node, kAuxInputTensor);
  const TfLiteTensor* fw_aux_input_weights =
      GetOptionalInputTensor(context, node, kFwAuxWeightsTensor);
  const TfLiteTensor* bw_aux_input_weights =
      GetOptionalInputTensor(context, node, kBwAuxWeightsTensor);

  TfLiteTensor* fw_hidden_state =
      GetVariableInput(context, node, kFwHiddenStateTensor);
  TfLiteTensor* bw_hidden_state =
      GetVariableInput(context, node, kBwHiddenStateTensor);

  TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
  TfLiteTensor* bw_output = params->merge_outputs
                                ? nullptr
                                : GetOutput(context, node, kBwOutputTensor);

  const bool has_previous_bw_output = (aux_input != nullptr);
  const bool use_aux_input = (fw_aux_input_weights != nullptr);

  // We want to cover the following cases:
  //
  // If not stacking (not connected after other bidirectional layers):
  //   both fw & bw will just use `input`; aux_input will be null.
  //
  // If stacking with cross links, TensorFlow equivalent
  // (tf.contrib.rnn.stack_bidirectional_rnn):
  //   both fw & bw will use `input`, but aux_input will be non-null. Note
  //   that in this case it does not matter whether the layer is connected
  //   after other bidirectional layers.
  //
  // If stacking without cross links, but connected after other bidirectional
  // layers, TensorFlow equivalent (tf.nn.static_bidirectional_rnn):
  //   fw will use `input`, bw will use aux_input, and the `real aux_input`
  //   will be null.

  const bool non_stacking_mode = !use_aux_input && has_previous_bw_output;
  const TfLiteTensor* bw_input = non_stacking_mode ? aux_input : input;
  const TfLiteTensor* real_aux_input = non_stacking_mode ? nullptr : aux_input;

  switch (fw_input_weights->type) {
    case kTfLiteFloat32:
      return EvalFloat(input, bw_input, fw_input_weights, fw_recurrent_weights,
                       fw_bias, bw_input_weights, bw_recurrent_weights, bw_bias,
                       real_aux_input, fw_aux_input_weights,
                       bw_aux_input_weights, params, fw_hidden_state, fw_output,
                       bw_hidden_state, bw_output);
    case kTfLiteUInt8:
    case kTfLiteInt8: {
      TfLiteTensor* input_quantized =
          GetTemporary(context, node, kInputQuantized);
      TfLiteTensor* fw_hidden_state_quantized =
          GetTemporary(context, node, kFwHiddenStateQuantized);
      TfLiteTensor* bw_hidden_state_quantized =
          GetTemporary(context, node, kBwHiddenStateQuantized);
      TfLiteTensor* scaling_factors =
          GetTemporary(context, node, kScalingFactors);
      TfLiteTensor* aux_input_quantized =
          use_aux_input ? GetTemporary(context, node, kAuxInputQuantized)
                        : nullptr;

      return EvalHybrid(input, bw_input, fw_input_weights, fw_recurrent_weights,
                        fw_bias, bw_input_weights, bw_recurrent_weights,
                        bw_bias, real_aux_input, fw_aux_input_weights,
                        bw_aux_input_weights, params, scaling_factors,
                        input_quantized, aux_input_quantized,
                        fw_hidden_state_quantized, fw_hidden_state, fw_output,
                        bw_hidden_state_quantized, bw_hidden_state, bw_output);
    }
    default:
      context->ReportError(context, "Type not currently supported.");
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace bidirectional_sequence_rnn

TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_RNN() {
  static TfLiteRegistration r = {
      bidirectional_sequence_rnn::Init, bidirectional_sequence_rnn::Free,
      bidirectional_sequence_rnn::Prepare, bidirectional_sequence_rnn::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite