/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <limits>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/activation_functor.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/lstm_eval.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace bidirectional_sequence_lstm {

// Input tensor of size {max_time, n_batch, n_input}
constexpr int kInputTensor = 0;

// Forward LSTM cell tensors.
// Input weight tensors of size: {n_cell, n_input}
constexpr int kFwInputToInputWeightsTensor = 1;  // Optional
constexpr int kFwInputToForgetWeightsTensor = 2;
constexpr int kFwInputToCellWeightsTensor = 3;
constexpr int kFwInputToOutputWeightsTensor = 4;

// Recurrent weight tensors of size {n_cell, n_output}
constexpr int kFwRecurrentToInputWeightsTensor = 5;  // Optional
constexpr int kFwRecurrentToForgetWeightsTensor = 6;
constexpr int kFwRecurrentToCellWeightsTensor = 7;
constexpr int kFwRecurrentToOutputWeightsTensor = 8;

// Peephole weight tensors of size {n_cell}, representing a diagonal matrix.
constexpr int kFwCellToInputWeightsTensor = 9;    // Optional
constexpr int kFwCellToForgetWeightsTensor = 10;  // Optional
constexpr int kFwCellToOutputWeightsTensor = 11;  // Optional

// Gate bias tensors of size {n_cell}
constexpr int kFwInputGateBiasTensor = 12;  // Optional
constexpr int kFwForgetGateBiasTensor = 13;
constexpr int kFwCellGateBiasTensor = 14;
constexpr int kFwOutputGateBiasTensor = 15;

// Projection weight tensor of size {n_output, n_cell}
constexpr int kFwProjectionWeightsTensor = 16;  // Optional
// Projection bias tensor of size {n_output}
constexpr int kFwProjectionBiasTensor = 17;  // Optional

// Backward LSTM cell tensors.
// Input weight tensors of size: {n_cell, n_input}
constexpr int kBwInputToInputWeightsTensor = 18;  // Optional
constexpr int kBwInputToForgetWeightsTensor = 19;
constexpr int kBwInputToCellWeightsTensor = 20;
constexpr int kBwInputToOutputWeightsTensor = 21;

// Recurrent weight tensors of size {n_cell, n_output}
constexpr int kBwRecurrentToInputWeightsTensor = 22;  // Optional
constexpr int kBwRecurrentToForgetWeightsTensor = 23;
constexpr int kBwRecurrentToCellWeightsTensor = 24;
constexpr int kBwRecurrentToOutputWeightsTensor = 25;

// Peephole weight tensors of size {n_cell}, representing a diagonal matrix.
constexpr int kBwCellToInputWeightsTensor = 26;   // Optional
constexpr int kBwCellToForgetWeightsTensor = 27;  // Optional
constexpr int kBwCellToOutputWeightsTensor = 28;  // Optional

// Gate bias tensors of size {n_cell}
constexpr int kBwInputGateBiasTensor = 29;  // Optional
constexpr int kBwForgetGateBiasTensor = 30;
constexpr int kBwCellGateBiasTensor = 31;
constexpr int kBwOutputGateBiasTensor = 32;

// Projection weight tensor of size {n_output, n_cell}
constexpr int kBwProjectionWeightsTensor = 33;  // Optional
// Projection bias tensor of size {n_output}
constexpr int kBwProjectionBiasTensor = 34;  // Optional

// Stateful input tensors that are variables and will be modified by the Op.
// Activation state tensors of size {n_batch, n_output}
constexpr int kFwInputActivationStateTensor = 35;
// Cell state tensors of size {n_batch, n_cell}
constexpr int kFwInputCellStateTensor = 36;
// Activation state tensors of size {n_batch, n_output}
constexpr int kBwInputActivationStateTensor = 37;
// Cell state tensors of size {n_batch, n_cell}
constexpr int kBwInputCellStateTensor = 38;

// Used as auxiliary input and weights when stacking for the
// tf.contrib.rnn.stack_bidirectional_rnn case (with cross links); used as
// input to the backward cell when stacking for the
// tf.nn.static_bidirectional_rnn case (without cross links).
constexpr int kAuxInputTensor = 39;  // Optional
// Forward weights.
constexpr int kFwAuxInputToInputWeightsTensor = 40;   // Optional
constexpr int kFwAuxInputToForgetWeightsTensor = 41;  // Optional
constexpr int kFwAuxInputToCellWeightsTensor = 42;    // Optional
constexpr int kFwAuxInputToOutputWeightsTensor = 43;  // Optional
// Backward weights.
constexpr int kBwAuxInputToInputWeightsTensor = 44;   // Optional
constexpr int kBwAuxInputToForgetWeightsTensor = 45;  // Optional
constexpr int kBwAuxInputToCellWeightsTensor = 46;    // Optional
constexpr int kBwAuxInputToOutputWeightsTensor = 47;  // Optional

// Output tensors.
constexpr int kFwOutputTensor = 0;
constexpr int kBwOutputTensor = 1;  // Ignored if merge_outputs is set.

// Temporary tensors.
enum TemporaryTensor {
  // Scratch buffers for the input, forget, etc. gates.
  kFwScratchBuffer = 0,
  kBwScratchBuffer = 1,
  // Quantized tensors needed for the hybrid kernel.
  kInputQuantized = 2,
  kFwActivationStateQuantized = 3,
  kBwActivationStateQuantized = 4,
  kFwCellStateQuantized = 5,
  kBwCellStateQuantized = 6,
  kScalingFactors = 7,
  kProductScalingFactors = 8,
  kRecoveredCellWeights = 9,
  kAuxInputQuantized = 10,  // Optional, quantized tensor for auxiliary input.
  kNumTemporaryTensors = 11
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* scratch_tensor_index = new int;
  context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index);
  return scratch_tensor_index;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<int*>(buffer);
}
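
// Init registers kNumTemporaryTensors tensor slots with the context and
// remembers the index of the first one; Prepare later addresses each
// temporary relative to that base, e.g.
//   node->temporaries->data[kScalingFactors] =
//       *scratch_tensor_index + kScalingFactors;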

// Check that the tensor dimensions and types of one LSTM cell match each
// other.
TfLiteStatus CheckLstmTensorDimensionsAndTypes(
    TfLiteContext* context, TfLiteNode* node, int n_input, int n_output,
    int n_cell, int input_to_input_weights_tensor,
    int input_to_forget_weights_tensor, int input_to_cell_weights_tensor,
    int input_to_output_weights_tensor, int recurrent_to_input_weights_tensor,
    int recurrent_to_forget_weights_tensor,
    int recurrent_to_cell_weights_tensor,
    int recurrent_to_output_weights_tensor, int cell_to_input_weights_tensor,
    int cell_to_forget_weights_tensor, int cell_to_output_weights_tensor,
    int input_gate_bias_tensor, int forget_gate_bias_tensor,
    int cell_gate_bias_tensor, int output_gate_bias_tensor,
    int projection_weights_tensor, int projection_bias_tensor) {
  const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
      node->builtin_data);

  // Make sure the clipping parameters have valid values:
  //   == 0 means no clipping,
  //    > 0 means clipping.
  TF_LITE_ENSURE(context, params->cell_clip >= 0);
  TF_LITE_ENSURE(context, params->proj_clip >= 0);

  const TfLiteTensor* input_to_forget_weights =
      GetInput(context, node, input_to_forget_weights_tensor);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
  TF_LITE_ENSURE(context, (input_to_forget_weights->type == kTfLiteFloat32) ||
                              (input_to_forget_weights->type == kTfLiteUInt8));

  const TfLiteTensor* input_to_input_weights =
      GetOptionalInputTensor(context, node, input_to_input_weights_tensor);
  if (input_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->type,
                      input_to_forget_weights->type);
  }

  const TfLiteTensor* input_to_cell_weights =
      GetInput(context, node, input_to_cell_weights_tensor);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->type,
                    input_to_forget_weights->type);

  const TfLiteTensor* input_to_output_weights =
      GetInput(context, node, input_to_output_weights_tensor);
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->type,
                    input_to_forget_weights->type);

  const TfLiteTensor* recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, recurrent_to_input_weights_tensor);
  if (recurrent_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
                      n_cell);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
                      n_output);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->type,
                      input_to_forget_weights->type);
  }

  const TfLiteTensor* recurrent_to_forget_weights =
      GetInput(context, node, recurrent_to_forget_weights_tensor);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
                    n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
                    n_output);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->type,
                    input_to_forget_weights->type);

  const TfLiteTensor* recurrent_to_cell_weights =
      GetInput(context, node, recurrent_to_cell_weights_tensor);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
                    n_output);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->type,
                    input_to_forget_weights->type);

  // Make sure the input-gate's parameters are either both present (regular
  // LSTM) or both absent (CIFG-LSTM, which couples the input gate to the
  // forget gate and so needs no input-gate parameters of its own).
  const bool cifg_weights_all_or_none =
      ((input_to_input_weights != nullptr) &&
       (recurrent_to_input_weights != nullptr)) ||
      ((input_to_input_weights == nullptr) &&
       (recurrent_to_input_weights == nullptr));
  TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);

  const TfLiteTensor* cell_to_input_weights =
      GetOptionalInputTensor(context, node, cell_to_input_weights_tensor);
  if (cell_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->type,
                      input_to_forget_weights->type);
  }

  const TfLiteTensor* cell_to_forget_weights =
      GetOptionalInputTensor(context, node, cell_to_forget_weights_tensor);
  if (cell_to_forget_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->type,
                      input_to_forget_weights->type);
  }

  const TfLiteTensor* cell_to_output_weights =
      GetOptionalInputTensor(context, node, cell_to_output_weights_tensor);
  if (cell_to_output_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->type,
                      input_to_forget_weights->type);
  }

  // Make sure the peephole weights are either all present or all absent.
  const bool use_cifg = (input_to_input_weights == nullptr);
  const bool peephole_weights_all_or_none =
      ((cell_to_input_weights != nullptr || use_cifg) &&
       (cell_to_forget_weights != nullptr) &&
       (cell_to_output_weights != nullptr)) ||
      ((cell_to_input_weights == nullptr) &&
       (cell_to_forget_weights == nullptr) &&
       (cell_to_output_weights == nullptr));
  TF_LITE_ENSURE(context, peephole_weights_all_or_none == true);
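
  // With CIFG there is no separate input gate, so cell_to_input_weights may
  // be absent even when the forget and output peephole weights are present;
  // that is why the predicate above tolerates a null cell_to_input_weights
  // whenever use_cifg is true.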

  // Make sure the input gate bias is present only when not a CIFG-LSTM.
  const TfLiteTensor* input_gate_bias =
      GetOptionalInputTensor(context, node, input_gate_bias_tensor);
  if (use_cifg) {
    TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
  } else {
    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
    TF_LITE_ENSURE_EQ(context, input_gate_bias->type, kTfLiteFloat32);
  }

  const TfLiteTensor* forget_gate_bias =
      GetInput(context, node, forget_gate_bias_tensor);
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->type, kTfLiteFloat32);

  const TfLiteTensor* cell_bias =
      GetInput(context, node, cell_gate_bias_tensor);
  TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, cell_bias->type, kTfLiteFloat32);

  const TfLiteTensor* output_gate_bias =
      GetInput(context, node, output_gate_bias_tensor);
  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, output_gate_bias->type, kTfLiteFloat32);

  const TfLiteTensor* projection_weights =
      GetOptionalInputTensor(context, node, projection_weights_tensor);
  if (projection_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
    TF_LITE_ENSURE_EQ(context, projection_weights->type,
                      input_to_forget_weights->type);
  }

  const TfLiteTensor* projection_bias =
      GetOptionalInputTensor(context, node, projection_bias_tensor);
  if (projection_bias != nullptr) {
    TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
    TF_LITE_ENSURE_EQ(context, projection_bias->type, kTfLiteFloat32);
  }

  // Make sure the projection tensors are consistent:
  // 1) If the projection weights are not present, then the projection bias
  //    should not be present either.
  // 2) If the projection weights are present, then the projection bias is
  //    optional.
  // TODO(ghodrat): make sure this is correct.
  const bool projection_tensors_consistent =
      ((projection_weights != nullptr) || (projection_bias == nullptr));
  TF_LITE_ENSURE(context, projection_tensors_consistent == true);

  return kTfLiteOk;
}

TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
                                        TfLiteNode* node, int n_input,
                                        int n_output, int n_cell) {
  TF_LITE_ENSURE_OK(
      context,
      CheckLstmTensorDimensionsAndTypes(
          context, node, n_input, n_output, n_cell,
          kFwInputToInputWeightsTensor, kFwInputToForgetWeightsTensor,
          kFwInputToCellWeightsTensor, kFwInputToOutputWeightsTensor,
          kFwRecurrentToInputWeightsTensor, kFwRecurrentToForgetWeightsTensor,
          kFwRecurrentToCellWeightsTensor, kFwRecurrentToOutputWeightsTensor,
          kFwCellToInputWeightsTensor, kFwCellToForgetWeightsTensor,
          kFwCellToOutputWeightsTensor, kFwInputGateBiasTensor,
          kFwForgetGateBiasTensor, kFwCellGateBiasTensor,
          kFwOutputGateBiasTensor, kFwProjectionWeightsTensor,
          kFwProjectionBiasTensor));

  TF_LITE_ENSURE_OK(
      context,
      CheckLstmTensorDimensionsAndTypes(
          context, node, n_input, n_output, n_cell,
          kBwInputToInputWeightsTensor, kBwInputToForgetWeightsTensor,
          kBwInputToCellWeightsTensor, kBwInputToOutputWeightsTensor,
          kBwRecurrentToInputWeightsTensor, kBwRecurrentToForgetWeightsTensor,
          kBwRecurrentToCellWeightsTensor, kBwRecurrentToOutputWeightsTensor,
          kBwCellToInputWeightsTensor, kBwCellToForgetWeightsTensor,
          kBwCellToOutputWeightsTensor, kBwInputGateBiasTensor,
          kBwForgetGateBiasTensor, kBwCellGateBiasTensor,
          kBwOutputGateBiasTensor, kBwProjectionWeightsTensor,
          kBwProjectionBiasTensor));

  // Check if forward and backward tensors match along required dimensions.
  return kTfLiteOk;
}

// Resize the output and scratch tensors based on the sizes of the input
// tensors. Also check that the sizes of the input tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
  const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
      node->builtin_data);

  // Check that we have all the inputs and outputs we need.
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 48);
  TF_LITE_ENSURE_EQ(context, node->outputs->size,
                    params->merge_outputs ? 1 : 2);

  // Infer the batch size, sequence length, number of outputs and number of
  // cells from the input tensors.
  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
  TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
  const bool time_major = params->time_major;
  const int max_time =
      time_major ? input->dims->data[0] : input->dims->data[1];
  const int n_batch =
      time_major ? input->dims->data[1] : input->dims->data[0];
  const int n_input = input->dims->data[2];

  const TfLiteTensor* fw_input_to_output_weights =
      GetInput(context, node, kFwInputToOutputWeightsTensor);
  const int n_fw_cell = fw_input_to_output_weights->dims->data[0];
  TF_LITE_ENSURE_EQ(context, fw_input_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, fw_input_to_output_weights->dims->data[1],
                    n_input);

  const TfLiteTensor* bw_input_to_output_weights =
      GetInput(context, node, kBwInputToOutputWeightsTensor);
  const int n_bw_cell = bw_input_to_output_weights->dims->data[0];
  TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->data[1],
                    n_input);
  TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->type,
                    fw_input_to_output_weights->type);

  const TfLiteTensor* fw_recurrent_to_output_weights =
      GetInput(context, node, kFwRecurrentToOutputWeightsTensor);
  TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->dims->data[0],
                    n_fw_cell);
  TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->type,
                    fw_input_to_output_weights->type);
  const int n_fw_output = fw_recurrent_to_output_weights->dims->data[1];

  const TfLiteTensor* bw_recurrent_to_output_weights =
      GetInput(context, node, kBwRecurrentToOutputWeightsTensor);
  TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->data[0],
                    n_bw_cell);
  TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->type,
                    fw_input_to_output_weights->type);
  const int n_bw_output = bw_recurrent_to_output_weights->dims->data[1];

  // Check that the input tensor dimensions match each other.
  TF_LITE_ENSURE_OK(
      context, CheckInputTensorDimensions(context, node, n_input, n_fw_output,
                                          n_fw_cell));

  // Get the (optional) auxiliary inputs and weights.
  const TfLiteTensor* aux_input =
      GetOptionalInputTensor(context, node, kAuxInputTensor);
  const TfLiteTensor* fw_aux_input_to_input_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToInputWeightsTensor);
  const TfLiteTensor* fw_aux_input_to_forget_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToForgetWeightsTensor);
  const TfLiteTensor* fw_aux_input_to_cell_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToCellWeightsTensor);
  const TfLiteTensor* fw_aux_input_to_output_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToOutputWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_input_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToInputWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_forget_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToForgetWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_cell_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToCellWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_output_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToOutputWeightsTensor);

  const bool aux_inputs_weights_all_or_none =
      ((fw_aux_input_to_cell_weights != nullptr) &&
       (fw_aux_input_to_forget_weights != nullptr) &&
       (fw_aux_input_to_output_weights != nullptr) &&
       (bw_aux_input_to_cell_weights != nullptr) &&
       (bw_aux_input_to_forget_weights != nullptr) &&
       (bw_aux_input_to_output_weights != nullptr)) ||
      ((fw_aux_input_to_cell_weights == nullptr) &&
       (fw_aux_input_to_forget_weights == nullptr) &&
       (fw_aux_input_to_output_weights == nullptr) &&
       (bw_aux_input_to_cell_weights == nullptr) &&
       (bw_aux_input_to_forget_weights == nullptr) &&
       (bw_aux_input_to_output_weights == nullptr));
  TF_LITE_ENSURE(context, aux_inputs_weights_all_or_none);

  const bool has_aux_input = (fw_aux_input_to_forget_weights != nullptr);

  if (has_aux_input) {
    // Check that aux_input has the same dimensions (except the last) as the
    // input.
    TF_LITE_ASSERT_EQ(aux_input->dims->data[0], input->dims->data[0]);
    TF_LITE_ASSERT_EQ(aux_input->dims->data[1], input->dims->data[1]);
  }

  // Get pointers to the output, activation_state and cell_state buffer
  // tensors.
  TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
  TfLiteTensor* fw_activation_state =
      GetVariableInput(context, node, kFwInputActivationStateTensor);
  TfLiteTensor* fw_cell_state =
      GetVariableInput(context, node, kFwInputCellStateTensor);

  // Check the shape of the input state tensors.
  // These tensors may be 1D or 2D; any shape is fine as long as the total
  // size is correct.
  TF_LITE_ENSURE_EQ(context, NumElements(fw_activation_state),
                    n_batch * n_fw_output);
  TF_LITE_ENSURE_EQ(context, NumElements(fw_cell_state), n_batch * n_fw_cell);

  // Resize the output tensors.
  TfLiteIntArray* fw_output_size = TfLiteIntArrayCreate(3);
  fw_output_size->data[0] = time_major ? max_time : n_batch;
  fw_output_size->data[1] = time_major ? n_batch : max_time;
  fw_output_size->data[2] =
      params->merge_outputs ? n_bw_output + n_fw_output : n_fw_output;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, fw_output, fw_output_size));
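
  // When merge_outputs is set, the forward and backward cells share this
  // tensor: the forward pass fills the first n_fw_output values of each
  // timestep and the backward pass later writes its n_bw_output values at an
  // offset of n_fw_output (see bw_output_offset in Eval).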

  // The weights are of consistent type, so it suffices to check one.
  const bool is_hybrid_op = (fw_input_to_output_weights->type == kTfLiteUInt8 ||
                             fw_input_to_output_weights->type == kTfLiteInt8);

  TfLiteIntArrayFree(node->temporaries);
  if (is_hybrid_op) {
    node->temporaries = TfLiteIntArrayCreate(
        has_aux_input ? kNumTemporaryTensors : kNumTemporaryTensors - 1);
  } else {
    node->temporaries = TfLiteIntArrayCreate(2);  // The two scratch buffers.
  }
  // Create a scratch buffer tensor for the forward cell.
  node->temporaries->data[kFwScratchBuffer] = *scratch_tensor_index;
  TfLiteTensor* fw_scratch_buffer =
      GetTemporary(context, node, kFwScratchBuffer);
  fw_scratch_buffer->type = input->type;
  fw_scratch_buffer->allocation_type = kTfLiteArenaRw;

  const TfLiteTensor* fw_input_to_input_weights =
      GetOptionalInputTensor(context, node, kFwInputToInputWeightsTensor);
  const bool fw_use_cifg = (fw_input_to_input_weights == nullptr);
  if (has_aux_input && !fw_use_cifg) {
    TF_LITE_ENSURE_EQ(context, fw_aux_input_to_input_weights->dims->data[0],
                      fw_input_to_input_weights->dims->data[0]);
  }
  TfLiteIntArray* fw_scratch_buffer_size = TfLiteIntArrayCreate(2);
  fw_scratch_buffer_size->data[0] = n_batch;
  if (fw_use_cifg) {
    // Reserving space for Cell, Forget, Output gates
    fw_scratch_buffer_size->data[1] = n_fw_cell * 3;
  } else {
    // Reserving space for Input, Cell, Forget, Output gates
    fw_scratch_buffer_size->data[1] = n_fw_cell * 4;
  }
  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_scratch_buffer,
                                                   fw_scratch_buffer_size));
  // Same for the backward cell.

  // Check that the input tensor dimensions match each other.
  TF_LITE_ENSURE_OK(
      context, CheckInputTensorDimensions(context, node, n_input, n_bw_output,
                                          n_bw_cell));

  // Get pointers to the activation_state and cell_state buffer tensors.
  TfLiteTensor* bw_activation_state =
      GetVariableInput(context, node, kBwInputActivationStateTensor);
  TfLiteTensor* bw_cell_state =
      GetVariableInput(context, node, kBwInputCellStateTensor);

  // Resize the output tensors.
  if (!params->merge_outputs) {
    TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
    TfLiteIntArray* bw_output_size = TfLiteIntArrayCreate(3);
    bw_output_size->data[0] = time_major ? max_time : n_batch;
    bw_output_size->data[1] = time_major ? n_batch : max_time;
    bw_output_size->data[2] = n_bw_output;
    TF_LITE_ENSURE_OK(
        context, context->ResizeTensor(context, bw_output, bw_output_size));
  }

  // Check the shape of the input state tensors.
  // These tensors may be 1D or 2D; any shape is fine as long as the total
  // size is correct.
  TF_LITE_ENSURE_EQ(context, NumElements(bw_activation_state),
                    n_batch * n_bw_output);
  TF_LITE_ENSURE_EQ(context, NumElements(bw_cell_state), n_batch * n_bw_cell);

  // Create a scratch buffer tensor for the backward cell.
  node->temporaries->data[kBwScratchBuffer] =
      *(scratch_tensor_index) + kBwScratchBuffer;
  TfLiteTensor* bw_scratch_buffer =
      GetTemporary(context, node, kBwScratchBuffer);
  bw_scratch_buffer->type = input->type;
  bw_scratch_buffer->allocation_type = kTfLiteArenaRw;

  const TfLiteTensor* bw_input_to_input_weights =
      GetOptionalInputTensor(context, node, kBwInputToInputWeightsTensor);
  const bool bw_use_cifg = (bw_input_to_input_weights == nullptr);
  if (has_aux_input && !bw_use_cifg) {
    TF_LITE_ENSURE_EQ(context, bw_aux_input_to_input_weights->dims->data[0],
                      bw_input_to_input_weights->dims->data[0]);
  }
  TfLiteIntArray* bw_scratch_buffer_size = TfLiteIntArrayCreate(2);
  bw_scratch_buffer_size->data[0] = n_batch;
  if (bw_use_cifg) {
    // Reserving space for Cell, Forget, Output gates
    bw_scratch_buffer_size->data[1] = n_bw_cell * 3;
  } else {
    // Reserving space for Input, Cell, Forget, Output gates
    bw_scratch_buffer_size->data[1] = n_bw_cell * 4;
  }
  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_scratch_buffer,
                                                   bw_scratch_buffer_size));
  if (is_hybrid_op) {
    // Allocate temporary tensors to store quantized values of the input,
    // aux_input (if present), activation_state and cell_state tensors.
    node->temporaries->data[kInputQuantized] =
        *scratch_tensor_index + kInputQuantized;
    TfLiteTensor* input_quantized =
        GetTemporary(context, node, kInputQuantized);
    input_quantized->type = fw_input_to_output_weights->type;
    input_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, input_quantized,
                                              input_quantized_size));
    }

    node->temporaries->data[kFwActivationStateQuantized] =
        *scratch_tensor_index + kFwActivationStateQuantized;
    TfLiteTensor* fw_activation_state_quantized =
        GetTemporary(context, node, kFwActivationStateQuantized);
    fw_activation_state_quantized->type = fw_input_to_output_weights->type;
    fw_activation_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(fw_activation_state_quantized->dims,
                             fw_activation_state->dims)) {
      TfLiteIntArray* fw_activation_state_quantized_size =
          TfLiteIntArrayCopy(fw_activation_state->dims);
      TF_LITE_ENSURE_OK(
          context,
          context->ResizeTensor(context, fw_activation_state_quantized,
                                fw_activation_state_quantized_size));
    }
    node->temporaries->data[kBwActivationStateQuantized] =
        *scratch_tensor_index + kBwActivationStateQuantized;
    TfLiteTensor* bw_activation_state_quantized =
        GetTemporary(context, node, kBwActivationStateQuantized);
    bw_activation_state_quantized->type = fw_input_to_output_weights->type;
    bw_activation_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(bw_activation_state_quantized->dims,
                             bw_activation_state->dims)) {
      TfLiteIntArray* bw_activation_state_quantized_size =
          TfLiteIntArrayCopy(bw_activation_state->dims);
      TF_LITE_ENSURE_OK(
          context,
          context->ResizeTensor(context, bw_activation_state_quantized,
                                bw_activation_state_quantized_size));
    }
    node->temporaries->data[kFwCellStateQuantized] =
        *scratch_tensor_index + kFwCellStateQuantized;
    TfLiteTensor* fw_cell_state_quantized =
        GetTemporary(context, node, kFwCellStateQuantized);
    fw_cell_state_quantized->type = fw_input_to_output_weights->type;
    fw_cell_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(fw_cell_state_quantized->dims,
                             fw_cell_state->dims)) {
      TfLiteIntArray* fw_cell_state_quantized_size =
          TfLiteIntArrayCopy(fw_cell_state->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, fw_cell_state_quantized,
                                              fw_cell_state_quantized_size));
    }
    node->temporaries->data[kBwCellStateQuantized] =
        *scratch_tensor_index + kBwCellStateQuantized;
    TfLiteTensor* bw_cell_state_quantized =
        GetTemporary(context, node, kBwCellStateQuantized);
    bw_cell_state_quantized->type = fw_input_to_output_weights->type;
    bw_cell_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(bw_cell_state_quantized->dims,
                             bw_cell_state->dims)) {
      TfLiteIntArray* bw_cell_state_quantized_size =
          TfLiteIntArrayCopy(bw_cell_state->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, bw_cell_state_quantized,
                                              bw_cell_state_quantized_size));
    }

    // Allocate temporary tensors to store the scaling factors and product
    // scaling factors. The latter is a convenience storage that lets us
    // quantize a vector once (which produces the scaling factors) and
    // multiply it with different matrices (which requires multiplying the
    // scaling factors with the scaling factor of the matrix).
    node->temporaries->data[kScalingFactors] =
        *scratch_tensor_index + kScalingFactors;
    TfLiteTensor* scaling_factors =
        GetTemporary(context, node, kScalingFactors);
    scaling_factors->type = kTfLiteFloat32;
    scaling_factors->allocation_type = kTfLiteArenaRw;
    int scaling_dims[1] = {n_batch};
    if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
      TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
      scaling_factors_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, scaling_factors,
                                              scaling_factors_size));
    }
    node->temporaries->data[kProductScalingFactors] =
        *scratch_tensor_index + kProductScalingFactors;
    TfLiteTensor* prod_scaling_factors =
        GetTemporary(context, node, kProductScalingFactors);
    prod_scaling_factors->type = kTfLiteFloat32;
    prod_scaling_factors->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(prod_scaling_factors->dims, 1,
                                   scaling_dims)) {
      TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
      prod_scaling_factors_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, prod_scaling_factors,
                                              prod_scaling_factors_size));
    }
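
    // Concretely, for batch row b and a weight matrix with per-tensor scale
    // w_scale, the hybrid kernels can compute
    //   prod_scaling_factors[b] = scaling_factors[b] * w_scale
    // once per matrix instead of re-quantizing the input vector for every
    // matrix multiplication.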

    // Allocate a temporary tensor to store the recovered cell weights. Since
    // this is used for diagonal matrices, we only need to store n_cell
    // values.
    node->temporaries->data[kRecoveredCellWeights] =
        *scratch_tensor_index + kRecoveredCellWeights;
    TfLiteTensor* recovered_cell_weights =
        GetTemporary(context, node, kRecoveredCellWeights);
    recovered_cell_weights->type = kTfLiteFloat32;
    recovered_cell_weights->allocation_type = kTfLiteArenaRw;
    int recovered_cell_dims[1] = {n_fw_cell};
    if (!TfLiteIntArrayEqualsArray(recovered_cell_weights->dims, 1,
                                   recovered_cell_dims)) {
      TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1);
      recovered_cell_weights_size->data[0] = n_fw_cell;
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, recovered_cell_weights,
                                              recovered_cell_weights_size));
    }

    // Only allocate a temporary tensor for the quantized auxiliary input if
    // we are actually going to use it.
    if (has_aux_input) {
      node->temporaries->data[kAuxInputQuantized] =
          *scratch_tensor_index + kAuxInputQuantized;
      TfLiteTensor* aux_input_quantized =
          GetTemporary(context, node, kAuxInputQuantized);
      aux_input_quantized->type = fw_input_to_output_weights->type;
      aux_input_quantized->allocation_type = kTfLiteArenaRw;
      if (!TfLiteIntArrayEqual(aux_input_quantized->dims, aux_input->dims)) {
        TfLiteIntArray* aux_input_quantized_size =
            TfLiteIntArrayCopy(aux_input->dims);
        TF_LITE_ENSURE_OK(context,
                          context->ResizeTensor(context, aux_input_quantized,
                                                aux_input_quantized_size));
      }
    }
  }
  return kTfLiteOk;
}

// The LSTM Op engine.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
      node->builtin_data);

  // Input tensor.
  const TfLiteTensor* input = GetInput(context, node, kInputTensor);

  // Tensors for the forward cell.
  const TfLiteTensor* fw_input_to_input_weights =
      GetOptionalInputTensor(context, node, kFwInputToInputWeightsTensor);
  const TfLiteTensor* fw_input_to_forget_weights =
      GetInput(context, node, kFwInputToForgetWeightsTensor);
  const TfLiteTensor* fw_input_to_cell_weights =
      GetInput(context, node, kFwInputToCellWeightsTensor);
  const TfLiteTensor* fw_input_to_output_weights =
      GetInput(context, node, kFwInputToOutputWeightsTensor);

  const TfLiteTensor* fw_recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, kFwRecurrentToInputWeightsTensor);
  const TfLiteTensor* fw_recurrent_to_forget_weights =
      GetInput(context, node, kFwRecurrentToForgetWeightsTensor);
  const TfLiteTensor* fw_recurrent_to_cell_weights =
      GetInput(context, node, kFwRecurrentToCellWeightsTensor);
  const TfLiteTensor* fw_recurrent_to_output_weights =
      GetInput(context, node, kFwRecurrentToOutputWeightsTensor);

  const TfLiteTensor* fw_cell_to_input_weights =
      GetOptionalInputTensor(context, node, kFwCellToInputWeightsTensor);
  const TfLiteTensor* fw_cell_to_forget_weights =
      GetOptionalInputTensor(context, node, kFwCellToForgetWeightsTensor);
  const TfLiteTensor* fw_cell_to_output_weights =
      GetOptionalInputTensor(context, node, kFwCellToOutputWeightsTensor);

  const TfLiteTensor* fw_input_gate_bias =
      GetOptionalInputTensor(context, node, kFwInputGateBiasTensor);
  const TfLiteTensor* fw_forget_gate_bias =
      GetInput(context, node, kFwForgetGateBiasTensor);
  const TfLiteTensor* fw_cell_bias =
      GetInput(context, node, kFwCellGateBiasTensor);
  const TfLiteTensor* fw_output_gate_bias =
      GetInput(context, node, kFwOutputGateBiasTensor);

  const TfLiteTensor* fw_projection_weights =
      GetOptionalInputTensor(context, node, kFwProjectionWeightsTensor);
  const TfLiteTensor* fw_projection_bias =
      GetOptionalInputTensor(context, node, kFwProjectionBiasTensor);

  TfLiteTensor* fw_activation_state =
      GetVariableInput(context, node, kFwInputActivationStateTensor);
  TfLiteTensor* fw_cell_state =
      GetVariableInput(context, node, kFwInputCellStateTensor);
  TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);

  // Tensors for the backward cell.
  const TfLiteTensor* bw_input_to_input_weights =
      GetOptionalInputTensor(context, node, kBwInputToInputWeightsTensor);
  const TfLiteTensor* bw_input_to_forget_weights =
      GetInput(context, node, kBwInputToForgetWeightsTensor);
  const TfLiteTensor* bw_input_to_cell_weights =
      GetInput(context, node, kBwInputToCellWeightsTensor);
  const TfLiteTensor* bw_input_to_output_weights =
      GetInput(context, node, kBwInputToOutputWeightsTensor);

  const TfLiteTensor* bw_recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, kBwRecurrentToInputWeightsTensor);
  const TfLiteTensor* bw_recurrent_to_forget_weights =
      GetInput(context, node, kBwRecurrentToForgetWeightsTensor);
  const TfLiteTensor* bw_recurrent_to_cell_weights =
      GetInput(context, node, kBwRecurrentToCellWeightsTensor);
  const TfLiteTensor* bw_recurrent_to_output_weights =
      GetInput(context, node, kBwRecurrentToOutputWeightsTensor);

  const TfLiteTensor* bw_cell_to_input_weights =
      GetOptionalInputTensor(context, node, kBwCellToInputWeightsTensor);
  const TfLiteTensor* bw_cell_to_forget_weights =
      GetOptionalInputTensor(context, node, kBwCellToForgetWeightsTensor);
  const TfLiteTensor* bw_cell_to_output_weights =
      GetOptionalInputTensor(context, node, kBwCellToOutputWeightsTensor);

  const TfLiteTensor* bw_input_gate_bias =
      GetOptionalInputTensor(context, node, kBwInputGateBiasTensor);
  const TfLiteTensor* bw_forget_gate_bias =
      GetInput(context, node, kBwForgetGateBiasTensor);
  const TfLiteTensor* bw_cell_bias =
      GetInput(context, node, kBwCellGateBiasTensor);
  const TfLiteTensor* bw_output_gate_bias =
      GetInput(context, node, kBwOutputGateBiasTensor);

  const TfLiteTensor* bw_projection_weights =
      GetOptionalInputTensor(context, node, kBwProjectionWeightsTensor);
  const TfLiteTensor* bw_projection_bias =
      GetOptionalInputTensor(context, node, kBwProjectionBiasTensor);

  // State tensors.
  TfLiteTensor* bw_activation_state =
      GetVariableInput(context, node, kBwInputActivationStateTensor);
  TfLiteTensor* bw_cell_state =
      GetVariableInput(context, node, kBwInputCellStateTensor);
  TfLiteTensor* bw_output = params->merge_outputs
                                ? nullptr
                                : GetOutput(context, node, kBwOutputTensor);

  // Temporary tensors.
  TfLiteTensor* fw_scratch_buffer =
      GetTemporary(context, node, kFwScratchBuffer);
  TfLiteTensor* bw_scratch_buffer =
      GetTemporary(context, node, kBwScratchBuffer);

  // (Optional) auxiliary inputs.
  const TfLiteTensor* aux_input =
      GetOptionalInputTensor(context, node, kAuxInputTensor);
  const TfLiteTensor* fw_aux_input_to_input_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToInputWeightsTensor);
  const TfLiteTensor* fw_aux_input_to_forget_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToForgetWeightsTensor);
  const TfLiteTensor* fw_aux_input_to_cell_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToCellWeightsTensor);
  const TfLiteTensor* fw_aux_input_to_output_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToOutputWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_input_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToInputWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_forget_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToForgetWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_cell_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToCellWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_output_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToOutputWeightsTensor);

  const bool has_previous_bw_output = (aux_input != nullptr);
  const bool use_aux_input = (fw_aux_input_to_forget_weights != nullptr);

  // Populate a TfLiteLSTMParams struct for the evaluation functions.
  TfLiteLSTMParams lstm_params = {params->activation, params->cell_clip,
                                  params->proj_clip, kTfLiteLSTMFullKernel};

  const int bw_output_offset =
      params->merge_outputs ? fw_recurrent_to_output_weights->dims->data[1]
                            : 0;
  const auto actual_bw_output = params->merge_outputs ? fw_output : bw_output;

  const bool time_major = params->time_major;

  // We want to cover the following cases:
  //
  // If not stacking (not connected after other bidi lstms):
  //   both fw & bw will just use `input`; aux_input will be null.
  //
  // If stacking with cross_links, TensorFlow equivalent
  // (tf.contrib.rnn.stack_bidirectional_rnn):
  //   both fw & bw will use `input`, and aux_input will be non-null.
  //   Note that in this case it works whether or not we are connected after
  //   other bidi lstms.
  //
  // If stacking without cross_links but connected after other bidi lstms,
  // TensorFlow equivalent (tf.nn.static_bidirectional_rnn):
  //   fw will use `input`, bw will use aux_input, and the `real aux_input`
  //   will be null.
  const bool non_stacking_mode = !use_aux_input && has_previous_bw_output;
  const TfLiteTensor* bw_input = non_stacking_mode ? aux_input : input;
  const TfLiteTensor* real_aux_input =
      non_stacking_mode ? nullptr : aux_input;

  switch (fw_input_to_output_weights->type) {
    case kTfLiteFloat32: {
      TfLiteStatus fw_pass_status = lstm_eval::EvalFloat(
          input, fw_input_to_input_weights, fw_input_to_forget_weights,
          fw_input_to_cell_weights, fw_input_to_output_weights,
          fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
          fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights,
          fw_cell_to_input_weights, fw_cell_to_forget_weights,
          fw_cell_to_output_weights,
          /*input_layer_norm_coefficients=*/nullptr,
          /*forget_layer_norm_coefficients=*/nullptr,
          /*cell_layer_norm_coefficients=*/nullptr,
          /*output_layer_norm_coefficients=*/nullptr, real_aux_input,
          fw_aux_input_to_input_weights, fw_aux_input_to_forget_weights,
          fw_aux_input_to_cell_weights, fw_aux_input_to_output_weights,
          fw_input_gate_bias, fw_forget_gate_bias, fw_cell_bias,
          fw_output_gate_bias, fw_projection_weights, fw_projection_bias,
          &lstm_params,
          /*forward_sequence=*/true, time_major, /*output_offset=*/0,
          fw_scratch_buffer, fw_activation_state, fw_cell_state, fw_output);
      TF_LITE_ENSURE_OK(context, fw_pass_status);

      TfLiteStatus bw_pass_status = lstm_eval::EvalFloat(
          bw_input, bw_input_to_input_weights, bw_input_to_forget_weights,
          bw_input_to_cell_weights, bw_input_to_output_weights,
          bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
          bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights,
          bw_cell_to_input_weights, bw_cell_to_forget_weights,
          bw_cell_to_output_weights,
          /*input_layer_norm_coefficients=*/nullptr,
          /*forget_layer_norm_coefficients=*/nullptr,
          /*cell_layer_norm_coefficients=*/nullptr,
          /*output_layer_norm_coefficients=*/nullptr, real_aux_input,
          bw_aux_input_to_input_weights, bw_aux_input_to_forget_weights,
          bw_aux_input_to_cell_weights, bw_aux_input_to_output_weights,
          bw_input_gate_bias, bw_forget_gate_bias, bw_cell_bias,
          bw_output_gate_bias, bw_projection_weights, bw_projection_bias,
          &lstm_params,
          /*forward_sequence=*/false, time_major, bw_output_offset,
          bw_scratch_buffer, bw_activation_state, bw_cell_state,
          actual_bw_output);
      TF_LITE_ENSURE_OK(context, bw_pass_status);
      return kTfLiteOk;
    }
    case kTfLiteUInt8:
    case kTfLiteInt8: {
      TfLiteTensor* input_quantized =
          GetTemporary(context, node, kInputQuantized);
      TfLiteTensor* fw_activation_state_quantized =
          GetTemporary(context, node, kFwActivationStateQuantized);
      TfLiteTensor* bw_activation_state_quantized =
          GetTemporary(context, node, kBwActivationStateQuantized);
      TfLiteTensor* fw_cell_state_quantized =
          GetTemporary(context, node, kFwCellStateQuantized);
      TfLiteTensor* bw_cell_state_quantized =
          GetTemporary(context, node, kBwCellStateQuantized);
      TfLiteTensor* scaling_factors =
          GetTemporary(context, node, kScalingFactors);
      TfLiteTensor* prod_scaling_factors =
          GetTemporary(context, node, kProductScalingFactors);
      TfLiteTensor* recovered_cell_weights =
          GetTemporary(context, node, kRecoveredCellWeights);
      TfLiteTensor* aux_input_quantized =
          use_aux_input ? GetTemporary(context, node, kAuxInputQuantized)
                        : nullptr;
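
      // EvalHybrid quantizes the float input, auxiliary input and state
      // tensors into the temporaries above, runs the weight matrix
      // multiplications in 8 bit, and rescales the accumulators back to
      // float using the (product) scaling factors.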
      TfLiteStatus fw_pass_status = lstm_eval::EvalHybrid(
          input, fw_input_to_input_weights, fw_input_to_forget_weights,
          fw_input_to_cell_weights, fw_input_to_output_weights,
          fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
          fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights,
          fw_cell_to_input_weights, fw_cell_to_forget_weights,
          fw_cell_to_output_weights,
          /*input_layer_norm_coefficients=*/nullptr,
          /*forget_layer_norm_coefficients=*/nullptr,
          /*cell_layer_norm_coefficients=*/nullptr,
          /*output_layer_norm_coefficients=*/nullptr, real_aux_input,
          fw_aux_input_to_input_weights, fw_aux_input_to_forget_weights,
          fw_aux_input_to_cell_weights, fw_aux_input_to_output_weights,
          fw_input_gate_bias, fw_forget_gate_bias, fw_cell_bias,
          fw_output_gate_bias, fw_projection_weights, fw_projection_bias,
          &lstm_params,
          /*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0,
          fw_scratch_buffer, scaling_factors, prod_scaling_factors,
          recovered_cell_weights, input_quantized, aux_input_quantized,
          fw_activation_state_quantized, fw_cell_state_quantized,
          fw_activation_state, fw_cell_state, fw_output);
      TF_LITE_ENSURE_OK(context, fw_pass_status);

      TfLiteStatus bw_pass_status = lstm_eval::EvalHybrid(
          bw_input, bw_input_to_input_weights, bw_input_to_forget_weights,
          bw_input_to_cell_weights, bw_input_to_output_weights,
          bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
          bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights,
          bw_cell_to_input_weights, bw_cell_to_forget_weights,
          bw_cell_to_output_weights,
          /*input_layer_norm_coefficients=*/nullptr,
          /*forget_layer_norm_coefficients=*/nullptr,
          /*cell_layer_norm_coefficients=*/nullptr,
          /*output_layer_norm_coefficients=*/nullptr, real_aux_input,
          bw_aux_input_to_input_weights, bw_aux_input_to_forget_weights,
          bw_aux_input_to_cell_weights, bw_aux_input_to_output_weights,
          bw_input_gate_bias, bw_forget_gate_bias, bw_cell_bias,
          bw_output_gate_bias, bw_projection_weights, bw_projection_bias,
          &lstm_params,
          /*forward_sequence=*/false, /*time_major=*/true, bw_output_offset,
          bw_scratch_buffer, scaling_factors, prod_scaling_factors,
          recovered_cell_weights, input_quantized, aux_input_quantized,
          bw_activation_state_quantized, bw_cell_state_quantized,
          bw_activation_state, bw_cell_state, actual_bw_output);
      TF_LITE_ENSURE_OK(context, bw_pass_status);
      return kTfLiteOk;
    }
    default:
      context->ReportError(context, "Type %d is not currently supported.",
                           fw_input_to_output_weights->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace bidirectional_sequence_lstm

TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_LSTM() {
  static TfLiteRegistration r = {
      bidirectional_sequence_lstm::Init, bidirectional_sequence_lstm::Free,
      bidirectional_sequence_lstm::Prepare, bidirectional_sequence_lstm::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite