1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <cassert> 17 #include <cmath> 18 #include <cstdio> 19 #include <cstdlib> 20 #include <iostream> 21 #include <limits> 22 23 #include "tensorflow/lite/c/builtin_op_data.h" 24 #include "tensorflow/lite/c/c_api_internal.h" 25 #include "tensorflow/lite/kernels/activation_functor.h" 26 #include "tensorflow/lite/kernels/internal/kernel_utils.h" 27 #include "tensorflow/lite/kernels/internal/tensor_utils.h" 28 #include "tensorflow/lite/kernels/kernel_util.h" 29 #include "tensorflow/lite/kernels/lstm_eval.h" 30 #include "tensorflow/lite/kernels/op_macros.h" 31 32 namespace tflite { 33 namespace ops { 34 namespace builtin { 35 namespace unidirectional_sequence_lstm { 36 37 // Input Tensors of size {max_time, n_batch, n_input} 38 constexpr int kInputTensor = 0; 39 40 // Input weight tensors of size: {n_cell, n_input} 41 constexpr int kInputToInputWeightsTensor = 1; // Optional 42 constexpr int kInputToForgetWeightsTensor = 2; 43 constexpr int kInputToCellWeightsTensor = 3; 44 constexpr int kInputToOutputWeightsTensor = 4; 45 46 // Recurrent weight tensors of size {n_cell, n_output} 47 constexpr int kRecurrentToInputWeightsTensor = 5; // Optional 48 constexpr int kRecurrentToForgetWeightsTensor = 6; 49 constexpr int kRecurrentToCellWeightsTensor = 7; 50 constexpr int kRecurrentToOutputWeightsTensor = 8; 51 52 // Peephole weights tensors of size {n_cell}, representing a diagonal matrix. 53 constexpr int kCellToInputWeightsTensor = 9; // Optional 54 constexpr int kCellToForgetWeightsTensor = 10; // Optional 55 constexpr int kCellToOutputWeightsTensor = 11; // Optional 56 57 // Gates bias tensors of size {n_cell} 58 constexpr int kInputGateBiasTensor = 12; // Optional 59 constexpr int kForgetGateBiasTensor = 13; 60 constexpr int kCellGateBiasTensor = 14; 61 constexpr int kOutputGateBiasTensor = 15; 62 63 // Projection weight tensor of size {n_output, n_cell} 64 constexpr int kProjectionWeightsTensor = 16; // Optional 65 // Projection bias tensor of size {n_output} 66 constexpr int kProjectionBiasTensor = 17; // Optional 67 68 // Stateful input tensors that are variables and will be modified by the Op. 69 // Activation state tensor of size {n_batch, n_output} 70 constexpr int kInputActivationStateTensor = 18; 71 // Cell state tensor of size {n_batch, n_cell} 72 constexpr int kInputCellStateTensor = 19; 73 74 // Output tensors. 75 constexpr int kOutputTensor = 0; 76 77 // Temporary tensors 78 enum TemporaryTensor { 79 kScratchBuffer = 0, 80 kInputQuantized = 1, 81 kOutputStateQuantized = 2, 82 kCellStateQuantized = 3, 83 kScalingFactors = 4, 84 kProductScalingFactors = 5, 85 kRecoveredCellWeights = 6, 86 kNumTemporaryTensors = 7 87 }; 88 89 void* Init(TfLiteContext* context, const char* buffer, size_t length) { 90 auto* scratch_tensor_index = new int(); 91 context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index); 92 return scratch_tensor_index; 93 } 94 95 void Free(TfLiteContext* context, void* buffer) { 96 delete reinterpret_cast<int*>(buffer); 97 } 98 99 // Check that input tensor dimensions matches with each other. 100 TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, 101 TfLiteNode* node, int n_input, 102 int n_output, int n_cell) { 103 const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data); 104 105 // Making sure clipping parameters have valid values. 106 // == 0 means no clipping 107 // > 0 means clipping 108 TF_LITE_ENSURE(context, params->cell_clip >= 0); 109 TF_LITE_ENSURE(context, params->proj_clip >= 0); 110 111 const TfLiteTensor* input_to_input_weights = 112 GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); 113 if (input_to_input_weights != nullptr) { 114 TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2); 115 TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell); 116 TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input); 117 } 118 119 const TfLiteTensor* input_to_forget_weights = 120 GetInput(context, node, kInputToForgetWeightsTensor); 121 TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2); 122 TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell); 123 TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input); 124 125 const TfLiteTensor* input_to_cell_weights = 126 GetInput(context, node, kInputToCellWeightsTensor); 127 TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2); 128 TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell); 129 TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input); 130 131 const TfLiteTensor* recurrent_to_input_weights = 132 GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); 133 if (recurrent_to_input_weights != nullptr) { 134 TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2); 135 TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0], 136 n_cell); 137 TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1], 138 n_output); 139 } 140 141 const TfLiteTensor* recurrent_to_forget_weights = 142 GetInput(context, node, kRecurrentToForgetWeightsTensor); 143 TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2); 144 TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0], 145 n_cell); 146 TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1], 147 n_output); 148 149 const TfLiteTensor* recurrent_to_cell_weights = 150 GetInput(context, node, kRecurrentToCellWeightsTensor); 151 TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2); 152 TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell); 153 TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1], 154 n_output); 155 156 // We make sure the input-gate's parameters are either both present (regular 157 // LSTM) or not at all (CIFG-LSTM). 158 const bool cifg_weights_all_or_none = 159 ((input_to_input_weights != nullptr) && 160 (recurrent_to_input_weights != nullptr)) || 161 ((input_to_input_weights == nullptr) && 162 (recurrent_to_input_weights == nullptr)); 163 TF_LITE_ENSURE(context, cifg_weights_all_or_none == true); 164 165 const TfLiteTensor* cell_to_input_weights = 166 GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); 167 if (cell_to_input_weights != nullptr) { 168 TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1); 169 TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell); 170 } 171 172 const TfLiteTensor* cell_to_forget_weights = 173 GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); 174 if (cell_to_forget_weights != nullptr) { 175 TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1); 176 TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell); 177 } 178 179 const TfLiteTensor* cell_to_output_weights = 180 GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); 181 if (cell_to_output_weights != nullptr) { 182 TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1); 183 TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell); 184 } 185 186 // Making sure the peephole weights are there all or none. 187 const bool use_cifg = (input_to_input_weights == nullptr); 188 const bool peephole_weights_all_or_none = 189 ((cell_to_input_weights != nullptr || use_cifg) && 190 (cell_to_forget_weights != nullptr) && 191 (cell_to_output_weights != nullptr)) || 192 ((cell_to_input_weights == nullptr) && 193 (cell_to_forget_weights == nullptr) && 194 (cell_to_output_weights == nullptr)); 195 TF_LITE_ENSURE(context, peephole_weights_all_or_none == true); 196 197 // Make sure the input gate bias is present only when not a CIFG-LSTM. 198 const TfLiteTensor* input_gate_bias = 199 GetOptionalInputTensor(context, node, kInputGateBiasTensor); 200 if (use_cifg) { 201 TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr); 202 } else { 203 TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1); 204 TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell); 205 } 206 207 const TfLiteTensor* forget_gate_bias = 208 GetInput(context, node, kForgetGateBiasTensor); 209 TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1); 210 TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell); 211 212 const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); 213 TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1); 214 TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell); 215 216 const TfLiteTensor* output_gate_bias = 217 GetInput(context, node, kOutputGateBiasTensor); 218 TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1); 219 TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell); 220 221 const TfLiteTensor* projection_weights = 222 GetOptionalInputTensor(context, node, kProjectionWeightsTensor); 223 if (projection_weights != nullptr) { 224 TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2); 225 TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output); 226 TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell); 227 } 228 229 const TfLiteTensor* projection_bias = 230 GetOptionalInputTensor(context, node, kProjectionBiasTensor); 231 if (projection_bias != nullptr) { 232 TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1); 233 TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output); 234 } 235 236 // Making sure the projection tensors are consistent: 237 // 1) If projection weight is not present, then projection bias should not be 238 // present. 239 // 2) If projection weight is present, then projection bias is optional. 240 // TODO(ghodrat): make sure this is correct. 241 const bool projecton_tensors_consistent = 242 ((projection_weights != nullptr) || (projection_bias == nullptr)); 243 TF_LITE_ENSURE(context, projecton_tensors_consistent == true); 244 245 return kTfLiteOk; 246 } 247 248 // Resize the output and state tensors based on the sizes of the input tensors. 249 // Allocate a temporary scratch tensor. Also check that the sizes of the input 250 // tensors match each other. 251 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { 252 int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data); 253 254 // Check we have all the inputs and outputs we need. 255 TF_LITE_ENSURE_EQ(context, node->inputs->size, 20); 256 TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); 257 258 // Inferring batch size, number of outputs and sequence length and 259 // number of cells from the input tensors. 260 const TfLiteTensor* input = GetInput(context, node, kInputTensor); 261 TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); 262 TF_LITE_ENSURE(context, input->dims->size > 1); 263 const auto* params = 264 reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>( 265 node->builtin_data); 266 const bool time_major = params->time_major; 267 const int n_batch = time_major ? input->dims->data[1] : input->dims->data[0]; 268 const int n_input = input->dims->data[2]; 269 270 const TfLiteTensor* input_to_output_weights = 271 GetInput(context, node, kInputToOutputWeightsTensor); 272 const int n_cell = input_to_output_weights->dims->data[0]; 273 TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2); 274 TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input); 275 276 const TfLiteTensor* recurrent_to_output_weights = 277 GetInput(context, node, kRecurrentToOutputWeightsTensor); 278 TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2); 279 TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0], 280 n_cell); 281 const int n_output = recurrent_to_output_weights->dims->data[1]; 282 283 // Check that input tensor dimensions matches with each other. 284 TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input, 285 n_output, n_cell)); 286 287 // Get the pointer to output, activation_state and cell_state buffer tensors. 288 TfLiteTensor* output = GetOutput(context, node, kOutputTensor); 289 290 TfLiteTensor* activation_state = 291 GetVariableInput(context, node, kInputActivationStateTensor); 292 TfLiteTensor* cell_state = 293 GetVariableInput(context, node, kInputCellStateTensor); 294 295 // Check the shape of input state tensors. 296 // These tensor may be 1D or 2D. It's fine as long as the total size is 297 // correct. 298 TF_LITE_ENSURE_EQ(context, NumElements(activation_state), n_batch * n_output); 299 TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell); 300 301 // Resize the output tensors. 302 TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims); 303 output_size->data[input->dims->size - 1] = n_output; 304 TF_LITE_ENSURE_OK(context, 305 context->ResizeTensor(context, output, output_size)); 306 307 // The weights are of consistent type, so it suffices to check one. 308 // TODO(mirkov): create a utility/macro for this check, so all Ops can use it. 309 const bool is_hybrid_op = ((input_to_output_weights->type == kTfLiteUInt8 || 310 input_to_output_weights->type == kTfLiteInt8) && 311 input->type == kTfLiteFloat32); 312 313 TfLiteIntArrayFree(node->temporaries); 314 if (is_hybrid_op) { 315 node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors); 316 } else { 317 node->temporaries = TfLiteIntArrayCreate(1); 318 } 319 node->temporaries->data[0] = *scratch_tensor_index; 320 321 // Create a scratch buffer tensor. 322 TfLiteTensor* scratch_buffer = GetTemporary(context, node, kScratchBuffer); 323 scratch_buffer->type = input->type; 324 scratch_buffer->allocation_type = kTfLiteArenaRw; 325 326 const TfLiteTensor* input_to_input_weights = 327 GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); 328 const bool use_cifg = (input_to_input_weights == nullptr); 329 TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); 330 scratch_buffer_size->data[0] = n_batch; 331 if (use_cifg) { 332 // Reserving space for Cell, Forget, Output gates 333 scratch_buffer_size->data[1] = n_cell * 3; 334 } else { 335 // Reserving space for Input, Cell, Forget, Output gates 336 scratch_buffer_size->data[1] = n_cell * 4; 337 } 338 TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, 339 scratch_buffer_size)); 340 341 if (is_hybrid_op) { 342 // Allocate temporary tensors to store quantized values of input, 343 // activation_state and cell_state tensors. 344 node->temporaries->data[kInputQuantized] = 345 *scratch_tensor_index + kInputQuantized; 346 TfLiteTensor* input_quantized = 347 GetTemporary(context, node, kInputQuantized); 348 input_quantized->type = input_to_output_weights->type; 349 input_quantized->allocation_type = kTfLiteArenaRw; 350 if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) { 351 TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims); 352 TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, 353 input_quantized_size)); 354 } 355 node->temporaries->data[kOutputStateQuantized] = 356 *scratch_tensor_index + kOutputStateQuantized; 357 TfLiteTensor* activation_state_quantized = 358 GetTemporary(context, node, kOutputStateQuantized); 359 activation_state_quantized->type = input_to_output_weights->type; 360 activation_state_quantized->allocation_type = kTfLiteArenaRw; 361 if (!TfLiteIntArrayEqual(activation_state_quantized->dims, 362 activation_state->dims)) { 363 TfLiteIntArray* activation_state_quantized_size = 364 TfLiteIntArrayCopy(activation_state->dims); 365 TF_LITE_ENSURE_OK( 366 context, context->ResizeTensor(context, activation_state_quantized, 367 activation_state_quantized_size)); 368 } 369 node->temporaries->data[kCellStateQuantized] = 370 *scratch_tensor_index + kCellStateQuantized; 371 TfLiteTensor* cell_state_quantized = 372 GetTemporary(context, node, kCellStateQuantized); 373 cell_state_quantized->type = input_to_output_weights->type; 374 cell_state_quantized->allocation_type = kTfLiteArenaRw; 375 if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) { 376 TfLiteIntArray* cell_state_quantized_size = 377 TfLiteIntArrayCopy(cell_state->dims); 378 TF_LITE_ENSURE_OK(context, 379 context->ResizeTensor(context, cell_state_quantized, 380 cell_state_quantized_size)); 381 } 382 383 // Allocate temporary tensors to store scaling factors and product scaling 384 // factors. The latter is a convenience storage which allows to quantize 385 // a vector once (which produces the scaling factors) and multiply it with 386 // different matrices (which requires multiplying the scaling factors with 387 // the scaling factor of the matrix). 388 node->temporaries->data[kScalingFactors] = 389 *scratch_tensor_index + kScalingFactors; 390 TfLiteTensor* scaling_factors = 391 GetTemporary(context, node, kScalingFactors); 392 scaling_factors->type = kTfLiteFloat32; 393 scaling_factors->allocation_type = kTfLiteArenaRw; 394 int scaling_dims[1] = {n_batch}; 395 if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) { 396 TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); 397 scaling_factors_size->data[0] = n_batch; 398 TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, 399 scaling_factors_size)); 400 } 401 node->temporaries->data[kProductScalingFactors] = 402 *scratch_tensor_index + kProductScalingFactors; 403 TfLiteTensor* prod_scaling_factors = 404 GetTemporary(context, node, kProductScalingFactors); 405 prod_scaling_factors->type = kTfLiteFloat32; 406 prod_scaling_factors->allocation_type = kTfLiteArenaRw; 407 if (!TfLiteIntArrayEqualsArray(prod_scaling_factors->dims, 1, 408 scaling_dims)) { 409 TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1); 410 prod_scaling_factors_size->data[0] = n_batch; 411 TF_LITE_ENSURE_OK(context, 412 context->ResizeTensor(context, prod_scaling_factors, 413 prod_scaling_factors_size)); 414 } 415 416 // Allocate a temporary tensor to store the recovered cell weights. Since 417 // this is used for diagonal matrices, only need to store n_cell values. 418 node->temporaries->data[kRecoveredCellWeights] = 419 *scratch_tensor_index + kRecoveredCellWeights; 420 TfLiteTensor* recovered_cell_weights = 421 GetTemporary(context, node, kRecoveredCellWeights); 422 recovered_cell_weights->type = kTfLiteFloat32; 423 recovered_cell_weights->allocation_type = kTfLiteArenaRw; 424 int recovered_cell_dims[1] = {n_cell}; 425 if (!TfLiteIntArrayEqualsArray(recovered_cell_weights->dims, 1, 426 recovered_cell_dims)) { 427 TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1); 428 recovered_cell_weights_size->data[0] = n_cell; 429 TF_LITE_ENSURE_OK(context, 430 context->ResizeTensor(context, recovered_cell_weights, 431 recovered_cell_weights_size)); 432 } 433 } 434 return kTfLiteOk; 435 } 436 437 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { 438 const auto* params = 439 reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>( 440 node->builtin_data); 441 const bool time_major = params->time_major; 442 const TfLiteTensor* input = GetInput(context, node, kInputTensor); 443 444 const TfLiteTensor* input_to_input_weights = 445 GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); 446 const TfLiteTensor* input_to_forget_weights = 447 GetInput(context, node, kInputToForgetWeightsTensor); 448 const TfLiteTensor* input_to_cell_weights = 449 GetInput(context, node, kInputToCellWeightsTensor); 450 const TfLiteTensor* input_to_output_weights = 451 GetInput(context, node, kInputToOutputWeightsTensor); 452 453 const TfLiteTensor* recurrent_to_input_weights = 454 GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); 455 const TfLiteTensor* recurrent_to_forget_weights = 456 GetInput(context, node, kRecurrentToForgetWeightsTensor); 457 const TfLiteTensor* recurrent_to_cell_weights = 458 GetInput(context, node, kRecurrentToCellWeightsTensor); 459 const TfLiteTensor* recurrent_to_output_weights = 460 GetInput(context, node, kRecurrentToOutputWeightsTensor); 461 462 const TfLiteTensor* cell_to_input_weights = 463 GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); 464 const TfLiteTensor* cell_to_forget_weights = 465 GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); 466 const TfLiteTensor* cell_to_output_weights = 467 GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); 468 469 const TfLiteTensor* input_gate_bias = 470 GetOptionalInputTensor(context, node, kInputGateBiasTensor); 471 const TfLiteTensor* forget_gate_bias = 472 GetInput(context, node, kForgetGateBiasTensor); 473 const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); 474 const TfLiteTensor* output_gate_bias = 475 GetInput(context, node, kOutputGateBiasTensor); 476 477 const TfLiteTensor* projection_weights = 478 GetOptionalInputTensor(context, node, kProjectionWeightsTensor); 479 const TfLiteTensor* projection_bias = 480 GetOptionalInputTensor(context, node, kProjectionBiasTensor); 481 482 // Index the scratch buffers pointers to the global scratch buffer. 483 TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); 484 485 TfLiteTensor* activation_state = 486 GetVariableInput(context, node, kInputActivationStateTensor); 487 TfLiteTensor* cell_state = 488 GetVariableInput(context, node, kInputCellStateTensor); 489 490 TfLiteTensor* output = GetOutput(context, node, kOutputTensor); 491 492 // Copy out the LSTM specific params so they can be passed in the function. 493 TfLiteLSTMParams lstm_params; 494 lstm_params.activation = params->activation; 495 lstm_params.cell_clip = params->cell_clip; 496 lstm_params.proj_clip = params->proj_clip; 497 498 switch (input_to_output_weights->type) { 499 case kTfLiteFloat32: { 500 return lstm_eval::EvalFloat( 501 input, input_to_input_weights, input_to_forget_weights, 502 input_to_cell_weights, input_to_output_weights, 503 recurrent_to_input_weights, recurrent_to_forget_weights, 504 recurrent_to_cell_weights, recurrent_to_output_weights, 505 cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights, 506 /*input_layer_norm_coefficients=*/nullptr, 507 /*forget_layer_norm_coefficients=*/nullptr, 508 /*cell_layer_norm_coefficients=*/nullptr, 509 /*output_layer_norm_coefficients=*/nullptr, 510 /*aux_input=*/nullptr, 511 /*aux_input_to_input_weights=*/nullptr, 512 /*aux_input_to_forget_weights=*/nullptr, 513 /*aux_input_to_cell_weights=*/nullptr, 514 /*aux_input_to_output_weights=*/nullptr, input_gate_bias, 515 forget_gate_bias, cell_bias, output_gate_bias, projection_weights, 516 projection_bias, &lstm_params, /*forward_sequence=*/true, time_major, 517 /*output_offset=*/0, scratch_buffer, activation_state, cell_state, 518 output); 519 } 520 case kTfLiteUInt8: 521 case kTfLiteInt8: { 522 TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); 523 TfLiteTensor* activation_state_quantized = 524 GetTemporary(context, node, /*index=*/2); 525 TfLiteTensor* cell_state_quantized = 526 GetTemporary(context, node, /*index=*/3); 527 TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4); 528 TfLiteTensor* prod_scaling_factors = 529 GetTemporary(context, node, /*index=*/5); 530 TfLiteTensor* recovered_cell_weights = 531 GetTemporary(context, node, /*index=*/6); 532 return lstm_eval::EvalHybrid( 533 input, input_to_input_weights, input_to_forget_weights, 534 input_to_cell_weights, input_to_output_weights, 535 recurrent_to_input_weights, recurrent_to_forget_weights, 536 recurrent_to_cell_weights, recurrent_to_output_weights, 537 cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights, 538 /*input_layer_norm_coefficients=*/nullptr, 539 /*forget_layer_norm_coefficients=*/nullptr, 540 /*cell_layer_norm_coefficients=*/nullptr, 541 /*output_layer_norm_coefficients=*/nullptr, 542 /*aux_input=*/nullptr, 543 /*aux_input_to_input_weights=*/nullptr, 544 /*aux_input_to_forget_weights=*/nullptr, 545 /*aux_input_to_cell_weights=*/nullptr, 546 /*aux_input_to_output_weights=*/nullptr, input_gate_bias, 547 forget_gate_bias, cell_bias, output_gate_bias, projection_weights, 548 projection_bias, &lstm_params, /*forward_sequence=*/true, time_major, 549 /*output_offset=*/0, scratch_buffer, scaling_factors, 550 prod_scaling_factors, recovered_cell_weights, input_quantized, 551 /*aux_input_quantized=*/nullptr, activation_state_quantized, 552 cell_state_quantized, activation_state, cell_state, output); 553 } 554 default: 555 context->ReportError(context, "Type %d is not currently supported.", 556 input_to_output_weights->type); 557 return kTfLiteError; 558 } 559 return kTfLiteOk; 560 } 561 } // namespace unidirectional_sequence_lstm 562 563 TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM() { 564 static TfLiteRegistration r = {unidirectional_sequence_lstm::Init, 565 unidirectional_sequence_lstm::Free, 566 unidirectional_sequence_lstm::Prepare, 567 unidirectional_sequence_lstm::Eval}; 568 return &r; 569 } 570 571 } // namespace builtin 572 } // namespace ops 573 } // namespace tflite 574