/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <limits>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/activation_functor.h"
#include "tensorflow/lite/kernels/gemm_support.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/lstm_eval.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace lstm {

struct OpData {
  // Which kernel type to use: the full kernel (24 inputs) or the basic
  // kernel (5 inputs).
  // Please note the 20-input full kernel is deprecated and only kept
  // here for backward compatibility.
  TfLiteLSTMKernelType kernel_type;

  // Whether the LSTM variant uses layer normalization.
  bool is_layer_norm_lstm;

  // These fields are only used by the full kernel.
  int activation_state_tensor_index;
  int cell_state_tensor_index;
  int scratch_tensor_index;
};

// For the full-inputs kernel (24 inputs).
// Please note the 20-input full kernel is deprecated and only kept
// here for backward compatibility.
namespace full {

// Input tensor of size {n_batch, n_input}
constexpr int kInputTensor = 0;

// Input weight tensors of size: {n_cell, n_input}
constexpr int kInputToInputWeightsTensor = 1;  // Optional
constexpr int kInputToForgetWeightsTensor = 2;
constexpr int kInputToCellWeightsTensor = 3;
constexpr int kInputToOutputWeightsTensor = 4;

// Recurrent weight tensors of size {n_cell, n_output}
constexpr int kRecurrentToInputWeightsTensor = 5;  // Optional
constexpr int kRecurrentToForgetWeightsTensor = 6;
constexpr int kRecurrentToCellWeightsTensor = 7;
constexpr int kRecurrentToOutputWeightsTensor = 8;

// Peephole weight tensors of size {n_cell}, representing a diagonal matrix.
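// (Peephole connections feed the cell state directly into the gate
// activations. Because the connection matrix is diagonal, only its n_cell
// diagonal entries need to be stored.)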
constexpr int kCellToInputWeightsTensor = 9;    // Optional
constexpr int kCellToForgetWeightsTensor = 10;  // Optional
constexpr int kCellToOutputWeightsTensor = 11;  // Optional

// Gate bias tensors of size {n_cell}
constexpr int kInputGateBiasTensor = 12;  // Optional
constexpr int kForgetGateBiasTensor = 13;
constexpr int kCellGateBiasTensor = 14;
constexpr int kOutputGateBiasTensor = 15;

// Projection weight tensor of size {n_output, n_cell}
constexpr int kProjectionWeightsTensor = 16;  // Optional
// Projection bias tensor of size {n_output}
constexpr int kProjectionBiasTensor = 17;  // Optional

// These state tensors are defined as variable tensors, and will be modified
// by this op.
constexpr int kInputActivationStateTensor = 18;
constexpr int kInputCellStateTensor = 19;

// Layer norm coefficient tensors of size {n_cell}, representing a diagonal
// matrix.
constexpr int kInputLayerNormCoefficientsTensor = 20;   // Optional
constexpr int kForgetLayerNormCoefficientsTensor = 21;  // Optional
constexpr int kCellLayerNormCoefficientsTensor = 22;    // Optional
constexpr int kOutputLayerNormCoefficientsTensor = 23;  // Optional

// Output tensors.
constexpr int kOutputTensor = 0;

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData();
  op_data->kernel_type = kTfLiteLSTMFullKernel;
  context->AddTensors(context, /*tensors_to_add=*/7,
                      &op_data->scratch_tensor_index);
  return op_data;
}

// Check that input tensor dimensions match each other.
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
                                        TfLiteNode* node, int n_input,
                                        int n_output, int n_cell,
                                        bool is_layer_norm_lstm) {
  const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);

  // Make sure the clipping parameters have valid values.
  // == 0 means no clipping
  //  > 0 means clipping
  TF_LITE_ENSURE(context, params->cell_clip >= 0);
  TF_LITE_ENSURE(context, params->proj_clip >= 0);

  const TfLiteTensor* input_to_input_weights =
      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
  const bool use_cifg = (input_to_input_weights == nullptr);
  if (!use_cifg) {
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
  }

  const TfLiteTensor* input_to_forget_weights =
      GetInput(context, node, kInputToForgetWeightsTensor);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);

  const TfLiteTensor* input_to_cell_weights =
      GetInput(context, node, kInputToCellWeightsTensor);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);

  const TfLiteTensor* recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
  if (recurrent_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
                      n_cell);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
                      n_output);
  }

  const TfLiteTensor* recurrent_to_forget_weights =
      GetInput(context, node, kRecurrentToForgetWeightsTensor);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
                    n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
                    n_output);

  const TfLiteTensor* recurrent_to_cell_weights =
      GetInput(context, node, kRecurrentToCellWeightsTensor);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
                    n_output);

  // We make sure the input-gate's parameters are either both present (regular
  // LSTM) or both absent (CIFG-LSTM).
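  // (In a CIFG-LSTM the input gate is coupled to the forget gate as
  // i_t = 1 - f_t, so no separate input-gate weights exist; hence the two
  // tensors must be present or absent together.)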
  const bool cifg_weights_all_or_none =
      ((input_to_input_weights != nullptr) &&
       (recurrent_to_input_weights != nullptr)) ||
      ((input_to_input_weights == nullptr) &&
       (recurrent_to_input_weights == nullptr));
  TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);

  const TfLiteTensor* cell_to_input_weights =
      GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
  if (cell_to_input_weights) {
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
  }

  const TfLiteTensor* cell_to_forget_weights =
      GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
  if (cell_to_forget_weights) {
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
  }

  const TfLiteTensor* cell_to_output_weights =
      GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
  if (cell_to_output_weights) {
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
  }

  // Make sure the peephole weights are all present or all absent.
  const bool peephole_weights_all_or_none =
      ((cell_to_input_weights != nullptr || use_cifg) &&
       (cell_to_forget_weights != nullptr) &&
       (cell_to_output_weights != nullptr)) ||
      ((cell_to_input_weights == nullptr) &&
       (cell_to_forget_weights == nullptr) &&
       (cell_to_output_weights == nullptr));
  TF_LITE_ENSURE(context, peephole_weights_all_or_none == true);

  // Make sure the input gate bias is present only when not a CIFG-LSTM.
  const TfLiteTensor* input_gate_bias =
      GetOptionalInputTensor(context, node, kInputGateBiasTensor);
  if (use_cifg) {
    TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
  } else {
    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
  }

  const TfLiteTensor* forget_gate_bias =
      GetInput(context, node, kForgetGateBiasTensor);
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);

  const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
  TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell);

  const TfLiteTensor* output_gate_bias =
      GetInput(context, node, kOutputGateBiasTensor);
  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);

  const TfLiteTensor* projection_weights =
      GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
  if (projection_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
  }

  const TfLiteTensor* projection_bias =
      GetOptionalInputTensor(context, node, kProjectionBiasTensor);
  if (projection_bias != nullptr) {
    TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
  }
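
  // (The optional projection layer maps the n_cell-dimensional cell output
  // down to n_output dimensions, as in the LSTMP architecture; that is why
  // the projection weights are {n_output, n_cell} while the recurrent
  // weights expect an n_output-dimensional state.)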

  // Make sure the projection tensors are consistent:
  // 1) If the projection weights are not present, then the projection bias
  //    should not be present either.
  // 2) If the projection weights are present, then the projection bias is
  //    optional.
  // TODO(ghodrat): make sure this is correct.
  const bool projection_tensors_consistent =
      ((projection_weights != nullptr) || (projection_bias == nullptr));
  TF_LITE_ENSURE(context, projection_tensors_consistent == true);

  if (is_layer_norm_lstm) {
    const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor(
        context, node, kInputLayerNormCoefficientsTensor);
    if (use_cifg) {
      TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients, nullptr);
    } else {
      TF_LITE_ENSURE(context, input_layer_norm_coefficients != nullptr);
      TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->size, 1);
      TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->data[0],
                        n_cell);
    }

    const TfLiteTensor* forget_layer_norm_coefficients =
        GetInput(context, node, kForgetLayerNormCoefficientsTensor);
    TF_LITE_ENSURE(context, forget_layer_norm_coefficients != nullptr);
    TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->data[0],
                      n_cell);

    const TfLiteTensor* cell_layer_norm_coefficients =
        GetInput(context, node, kCellLayerNormCoefficientsTensor);
    TF_LITE_ENSURE(context, cell_layer_norm_coefficients != nullptr);
    TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->data[0],
                      n_cell);

    const TfLiteTensor* output_layer_norm_coefficients =
        GetInput(context, node, kOutputLayerNormCoefficientsTensor);
    TF_LITE_ENSURE(context, output_layer_norm_coefficients != nullptr);
    TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->data[0],
                      n_cell);
  }

  return kTfLiteOk;
}

// Resize the output and state tensors based on the sizes of the input
// tensors. Allocate a temporary scratch tensor. Also check that the sizes of
// the input tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);

  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
  // Logic for determining regular LSTM vs. layer-norm LSTM:
  //   inputs->size | forget_layer_norm_tensor (#20) | is_layer_norm?
  //   -------------+--------------------------------+---------------
  //   20           | N/A                            | No
  //   24           | null                           | No
  //   24           | not null                       | Yes
  // The 20-input LSTM is deprecated and only kept here for backward
  // compatibility.
  if (node->inputs->size == 24) {
    const TfLiteTensor* forget_layer_norm_coefficients =
        GetInput(context, node, kForgetLayerNormCoefficientsTensor);
    if (forget_layer_norm_coefficients == nullptr) {
      op_data->is_layer_norm_lstm = false;
    } else {
      op_data->is_layer_norm_lstm = true;
    }
  } else if (node->inputs->size == 20) {
    // This is deprecated and only kept here for backward compatibility.
    op_data->is_layer_norm_lstm = false;
  } else {
    context->ReportError(
        context, "The LSTM Full kernel expects 20 or 24 inputs. Got %d inputs",
        node->inputs->size);
    return kTfLiteError;
  }

  const bool is_layer_norm_lstm = op_data->is_layer_norm_lstm;
  op_data->activation_state_tensor_index =
      node->inputs->data[kInputActivationStateTensor];
  op_data->cell_state_tensor_index = node->inputs->data[kInputCellStateTensor];

  // Infer the batch size, the number of outputs and the number of cells from
  // the input tensors.
  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
  TF_LITE_ENSURE(context, input->dims->size > 1);
  const int n_batch = input->dims->data[0];
  const int n_input = input->dims->data[1];

  const TfLiteTensor* input_to_output_weights =
      GetInput(context, node, kInputToOutputWeightsTensor);
  const int n_cell = input_to_output_weights->dims->data[0];
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);

  const TfLiteTensor* recurrent_to_output_weights =
      GetInput(context, node, kRecurrentToOutputWeightsTensor);
  TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
                    n_cell);
  const int n_output = recurrent_to_output_weights->dims->data[1];

  // Check that input tensor dimensions match each other.
  TF_LITE_ENSURE_OK(context,
                    CheckInputTensorDimensions(context, node, n_input,
                                               n_output, n_cell,
                                               is_layer_norm_lstm));

  // Get pointers to the output, activation_state and cell_state tensors.
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  TfLiteTensor* activation_state =
      &context->tensors[op_data->activation_state_tensor_index];
  TfLiteTensor* cell_state =
      &context->tensors[op_data->cell_state_tensor_index];

  // Check the shape of the input state tensors.
  // These tensors may be 1D or 2D; any shape is fine as long as the total
  // size is correct.
  TF_LITE_ENSURE_EQ(context, NumElements(activation_state),
                    n_batch * n_output);
  TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);

  // Resize the output tensors.
  TfLiteIntArray* output_size = TfLiteIntArrayCreate(2);
  output_size->data[0] = n_batch;
  output_size->data[1] = n_output;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_size));

  // The weights are of consistent type, so it suffices to check one.
  // TODO(mirkov): create a utility/macro for this check, so all Ops can use
  // it.
  const bool is_hybrid_op = ((input_to_output_weights->type == kTfLiteUInt8 ||
                              input_to_output_weights->type == kTfLiteInt8) &&
                             input->type == kTfLiteFloat32);

  TfLiteIntArrayFree(node->temporaries);
  if (is_hybrid_op) {
    node->temporaries = TfLiteIntArrayCreate(7);
  } else {
    node->temporaries = TfLiteIntArrayCreate(1);
  }
  node->temporaries->data[0] = op_data->scratch_tensor_index;
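
  // (The seven scratch tensor indices were reserved contiguously by the
  // AddTensors() call in Init(). The float path only uses index 0, the gate
  // scratch buffer; the hybrid path additionally uses indices 1-6 for the
  // quantization-related buffers set up below.)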

  // Create the scratch buffer tensor.
  TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
  scratch_buffer->type = input->type;
  scratch_buffer->allocation_type = kTfLiteArenaRw;

  const TfLiteTensor* input_to_input_weights =
      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
  const bool use_cifg = (input_to_input_weights == nullptr);
  TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
  scratch_buffer_size->data[0] = n_batch;
  if (use_cifg) {
    // Reserve space for the Cell, Forget and Output gates.
    scratch_buffer_size->data[1] = n_cell * 3;
  } else {
    // Reserve space for the Input, Cell, Forget and Output gates.
    scratch_buffer_size->data[1] = n_cell * 4;
  }
  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
                                                   scratch_buffer_size));

  if (is_hybrid_op) {
    // Allocate temporary tensors to store quantized values of the input,
    // activation_state and cell_state tensors.
    node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
    TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
    input_quantized->type = input_to_output_weights->type;
    input_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, input_quantized,
                                              input_quantized_size));
    }
    node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
    TfLiteTensor* activation_state_quantized =
        GetTemporary(context, node, /*index=*/2);
    activation_state_quantized->type = input_to_output_weights->type;
    activation_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(activation_state_quantized->dims,
                             activation_state->dims)) {
      TfLiteIntArray* activation_state_quantized_size =
          TfLiteIntArrayCopy(activation_state->dims);
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, activation_state_quantized,
                                         activation_state_quantized_size));
    }
    node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
    TfLiteTensor* cell_state_quantized =
        GetTemporary(context, node, /*index=*/3);
    cell_state_quantized->type = input_to_output_weights->type;
    cell_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) {
      TfLiteIntArray* cell_state_quantized_size =
          TfLiteIntArrayCopy(cell_state->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, cell_state_quantized,
                                              cell_state_quantized_size));
    }

    // Allocate temporary tensors to store the scaling factors and product
    // scaling factors. The latter is a convenience buffer that lets us
    // quantize a vector once (which produces the scaling factors) and
    // multiply it with different matrices (which requires multiplying the
    // scaling factors with the scaling factor of the matrix).
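    // (Roughly, in the hybrid scheme each float batch row is symmetrically
    // quantized to 8 bits with its own scale s_v; the integer accumulators
    // of a product against weights with scale s_w are converted back to
    // float by multiplying with the product scale s_v * s_w.)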
    node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
    TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4);
    scaling_factors->type = kTfLiteFloat32;
    scaling_factors->allocation_type = kTfLiteArenaRw;
    int scaling_dims[1] = {n_batch};
    if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
      TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
      scaling_factors_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, scaling_factors,
                                              scaling_factors_size));
    }
    node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
    TfLiteTensor* prod_scaling_factors =
        GetTemporary(context, node, /*index=*/5);
    prod_scaling_factors->type = kTfLiteFloat32;
    prod_scaling_factors->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(prod_scaling_factors->dims, 1,
                                   scaling_dims)) {
      TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
      prod_scaling_factors_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, prod_scaling_factors,
                                              prod_scaling_factors_size));
    }

    // Allocate a temporary tensor to store the recovered cell weights. Since
    // this is used for diagonal matrices, we only need to store n_cell
    // values.
    node->temporaries->data[6] = op_data->scratch_tensor_index + 6;
    TfLiteTensor* recovered_cell_weights =
        GetTemporary(context, node, /*index=*/6);
    recovered_cell_weights->type = kTfLiteFloat32;
    recovered_cell_weights->allocation_type = kTfLiteArenaRw;
    int recovered_cell_dims[1] = {n_cell};
    if (!TfLiteIntArrayEqualsArray(recovered_cell_weights->dims, 1,
                                   recovered_cell_dims)) {
      TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1);
      recovered_cell_weights_size->data[0] = n_cell;
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, recovered_cell_weights,
                                              recovered_cell_weights_size));
    }
  }
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  const bool is_layer_norm_lstm = op_data->is_layer_norm_lstm;

  const TfLiteTensor* input = GetInput(context, node, kInputTensor);

  const TfLiteTensor* input_to_input_weights =
      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
  const TfLiteTensor* input_to_forget_weights =
      GetInput(context, node, kInputToForgetWeightsTensor);
  const TfLiteTensor* input_to_cell_weights =
      GetInput(context, node, kInputToCellWeightsTensor);
  const TfLiteTensor* input_to_output_weights =
      GetInput(context, node, kInputToOutputWeightsTensor);

  const TfLiteTensor* recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
  const TfLiteTensor* recurrent_to_forget_weights =
      GetInput(context, node, kRecurrentToForgetWeightsTensor);
  const TfLiteTensor* recurrent_to_cell_weights =
      GetInput(context, node, kRecurrentToCellWeightsTensor);
  const TfLiteTensor* recurrent_to_output_weights =
      GetInput(context, node, kRecurrentToOutputWeightsTensor);

  const TfLiteTensor* cell_to_input_weights =
      GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
  const TfLiteTensor* cell_to_forget_weights =
      GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
  const TfLiteTensor* cell_to_output_weights =
      GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);

  const TfLiteTensor* input_layer_norm_coefficients =
      is_layer_norm_lstm ? GetOptionalInputTensor(
                               context, node, kInputLayerNormCoefficientsTensor)
                         : nullptr;
  const TfLiteTensor* forget_layer_norm_coefficients =
      is_layer_norm_lstm
          ? GetInput(context, node, kForgetLayerNormCoefficientsTensor)
          : nullptr;
  const TfLiteTensor* cell_layer_norm_coefficients =
      is_layer_norm_lstm
          ? GetInput(context, node, kCellLayerNormCoefficientsTensor)
          : nullptr;
  const TfLiteTensor* output_layer_norm_coefficients =
      is_layer_norm_lstm
          ? GetInput(context, node, kOutputLayerNormCoefficientsTensor)
          : nullptr;

  const TfLiteTensor* input_gate_bias =
      GetOptionalInputTensor(context, node, kInputGateBiasTensor);
  const TfLiteTensor* forget_gate_bias =
      GetInput(context, node, kForgetGateBiasTensor);
  const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
  const TfLiteTensor* output_gate_bias =
      GetInput(context, node, kOutputGateBiasTensor);

  const TfLiteTensor* projection_weights =
      GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
  const TfLiteTensor* projection_bias =
      GetOptionalInputTensor(context, node, kProjectionBiasTensor);

  // Index the scratch buffer pointers into the global scratch buffer.
  TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);

  TfLiteTensor* activation_state =
      &context->tensors[op_data->activation_state_tensor_index];
  TfLiteTensor* cell_state =
      &context->tensors[op_data->cell_state_tensor_index];

  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  // TODO(mirkov): add a check that weights are all uint8s or all floats.
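  // (Dispatch on the weight type: float weights take the pure float path,
  // while uint8/int8 weights with float activations take the hybrid path,
  // which quantizes the activations on the fly into the temporaries
  // allocated in Prepare().)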
  switch (input_to_output_weights->type) {
    case kTfLiteFloat32: {
      return lstm_eval::EvalFloat(
          input, input_to_input_weights, input_to_forget_weights,
          input_to_cell_weights, input_to_output_weights,
          recurrent_to_input_weights, recurrent_to_forget_weights,
          recurrent_to_cell_weights, recurrent_to_output_weights,
          cell_to_input_weights, cell_to_forget_weights,
          cell_to_output_weights, input_layer_norm_coefficients,
          forget_layer_norm_coefficients, cell_layer_norm_coefficients,
          output_layer_norm_coefficients,
          /*aux_input=*/nullptr,
          /*aux_input_to_input_weights=*/nullptr,
          /*aux_input_to_forget_weights=*/nullptr,
          /*aux_input_to_cell_weights=*/nullptr,
          /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
          forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
          projection_bias, params, /*forward_sequence=*/true,
          /*time_major=*/true,
          /*output_offset=*/0, scratch_buffer, activation_state, cell_state,
          output);
    }
    case kTfLiteUInt8:
    case kTfLiteInt8: {
      TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
      TfLiteTensor* activation_state_quantized =
          GetTemporary(context, node, /*index=*/2);
      TfLiteTensor* cell_state_quantized =
          GetTemporary(context, node, /*index=*/3);
      TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4);
      TfLiteTensor* prod_scaling_factors =
          GetTemporary(context, node, /*index=*/5);
      TfLiteTensor* recovered_cell_weights =
          GetTemporary(context, node, /*index=*/6);
      return lstm_eval::EvalHybrid(
          input, input_to_input_weights, input_to_forget_weights,
          input_to_cell_weights, input_to_output_weights,
          recurrent_to_input_weights, recurrent_to_forget_weights,
          recurrent_to_cell_weights, recurrent_to_output_weights,
          cell_to_input_weights, cell_to_forget_weights,
          cell_to_output_weights, input_layer_norm_coefficients,
          forget_layer_norm_coefficients, cell_layer_norm_coefficients,
          output_layer_norm_coefficients,
          /*aux_input=*/nullptr,
          /*aux_input_to_input_weights=*/nullptr,
          /*aux_input_to_forget_weights=*/nullptr,
          /*aux_input_to_cell_weights=*/nullptr,
          /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
          forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
          projection_bias, params, /*forward_sequence=*/true,
          /*time_major=*/true, /*output_offset=*/0, scratch_buffer,
          scaling_factors, prod_scaling_factors, recovered_cell_weights,
          input_quantized,
          /*aux_input_quantized=*/nullptr, activation_state_quantized,
          cell_state_quantized, activation_state, cell_state, output);
    }
    default:
      context->ReportError(context, "Type %d is not currently supported.",
                           input_to_output_weights->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace full

// For the basic kernel (5 inputs).
namespace basic {

enum InputTensor {
  kInputData = 0,
  kInputPrevActivation = 1,
  kInputWeights = 2,
  kInputBiases = 3,
  kInputPrevState = 4,
  kInputNum = 5,
};

enum OutputTensor {
  kOutputActivation = 0,
  kOutputState = 1,
  kOutputConcatTemp = 2,
  kOutputActivationTemp = 3,
  kOutputNum = 4,
};
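
// (Unlike the full kernel, the basic kernel exposes its two scratch buffers
// as outputs (kOutputConcatTemp, kOutputActivationTemp) instead of node
// temporaries, so no scratch tensors need to be registered in Init().)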

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData();
  op_data->kernel_type = kTfLiteLSTMBasicKernel;
  // `scratch_tensor_index` is unused in this kernel.
  op_data->scratch_tensor_index = -1;
  return op_data;
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE(context, node->inputs->size == kInputNum);
  TF_LITE_ENSURE(context, node->outputs->size == kOutputNum);

  const TfLiteTensor* input = GetInput(context, node, kInputData);
  const TfLiteTensor* prev_activation =
      GetInput(context, node, kInputPrevActivation);
  const TfLiteTensor* weights = GetInput(context, node, kInputWeights);
  const TfLiteTensor* bias = GetInput(context, node, kInputBiases);
  const TfLiteTensor* prev_state = GetInput(context, node, kInputPrevState);

  TF_LITE_ENSURE_EQ(context, input->dims->size, 2);
  const int num_batches = input->dims->data[0];
  const int input_depth = input->dims->data[1];

  TF_LITE_ENSURE_EQ(context, prev_activation->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, prev_activation->dims->data[0], num_batches);
  const int activation_depth = prev_activation->dims->data[1];
  const int total_depth = input_depth + activation_depth;

  TF_LITE_ENSURE_EQ(context, weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, weights->dims->data[0], 4 * activation_depth);
  TF_LITE_ENSURE_EQ(context, weights->dims->data[1], total_depth);

  TF_LITE_ENSURE_EQ(context, bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, bias->dims->data[0], 4 * activation_depth);

  TF_LITE_ENSURE_EQ(context, prev_state->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, prev_state->dims->data[0], num_batches);
  TF_LITE_ENSURE_EQ(context, prev_state->dims->data[1], activation_depth);

  TfLiteTensor* activation_out = GetOutput(context, node, kOutputActivation);
  TfLiteTensor* state_out = GetOutput(context, node, kOutputState);
  TfLiteTensor* concat_temp = GetOutput(context, node, kOutputConcatTemp);
  TfLiteTensor* activation_temp =
      GetOutput(context, node, kOutputActivationTemp);

  TF_LITE_ENSURE_OK(context, context->ResizeTensor(
                                 context, activation_out,
                                 TfLiteIntArrayCopy(prev_activation->dims)));
  TF_LITE_ENSURE_OK(
      context, context->ResizeTensor(context, state_out,
                                     TfLiteIntArrayCopy(prev_state->dims)));

  TfLiteIntArray* concat_temp_size = TfLiteIntArrayCreate(2);
  concat_temp_size->data[0] = num_batches;
  concat_temp_size->data[1] = total_depth;
  TF_LITE_ENSURE_OK(
      context, context->ResizeTensor(context, concat_temp, concat_temp_size));
  TfLiteIntArray* activation_temp_size = TfLiteIntArrayCreate(2);
  activation_temp_size->data[0] = num_batches;
  activation_temp_size->data[1] = 4 * activation_depth;
  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_temp,
                                                   activation_temp_size));
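
  // (concat_temp holds [input, prev_activation] concatenated along the depth
  // axis, hence {num_batches, total_depth}; activation_temp holds the
  // pre-activation values of all four gates, hence
  // {num_batches, 4 * activation_depth}.)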

  // Set the state tensors as persistent.
  for (auto index : {kInputPrevActivation, kInputPrevState}) {
    TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]];
    tensor->allocation_type = kTfLiteArenaRwPersistent;
  }
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input = GetInput(context, node, kInputData);
  const TfLiteTensor* prev_activation =
      GetInput(context, node, kInputPrevActivation);
  const TfLiteTensor* weights = GetInput(context, node, kInputWeights);
  const TfLiteTensor* bias = GetInput(context, node, kInputBiases);
  const TfLiteTensor* prev_state = GetInput(context, node, kInputPrevState);

  TfLiteTensor* activation_out = GetOutput(context, node, kOutputActivation);
  TfLiteTensor* state_out = GetOutput(context, node, kOutputState);
  TfLiteTensor* concat_temp = GetOutput(context, node, kOutputConcatTemp);
  TfLiteTensor* activation_temp =
      GetOutput(context, node, kOutputActivationTemp);

  if (input->type == kTfLiteFloat32 &&
      prev_activation->type == kTfLiteFloat32 &&
      weights->type == kTfLiteFloat32 && bias->type == kTfLiteFloat32 &&
      prev_state->type == kTfLiteFloat32 &&
      state_out->type == kTfLiteFloat32 &&
      activation_out->type == kTfLiteFloat32 &&
      concat_temp->type == kTfLiteFloat32 &&
      activation_temp->type == kTfLiteFloat32) {
    tflite::LstmCellParams op_params;
    // The float LSTM cell does not need any parameters to be set: leave
    // op_params untouched.
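    // (The basic cell computes all four gates with a single GEMM: the
    // {4 * activation_depth, total_depth} weight matrix is applied to the
    // concatenation of the input and the previous activation.)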
    optimized_ops::LstmCell(
        op_params,
        // Inputs.
        GetTensorShape(input), GetTensorData<float>(input),
        GetTensorShape(prev_activation), GetTensorData<float>(prev_activation),
        GetTensorShape(weights), GetTensorData<float>(weights),
        GetTensorShape(bias), GetTensorData<float>(bias),
        GetTensorShape(prev_state), GetTensorData<float>(prev_state),
        // Outputs.
        GetTensorShape(state_out), GetTensorData<float>(state_out),
        GetTensorShape(activation_out), GetTensorData<float>(activation_out),
        GetTensorShape(concat_temp), GetTensorData<float>(concat_temp),
        GetTensorShape(activation_temp),
        GetTensorData<float>(activation_temp));
  } else if (input->type == kTfLiteUInt8 &&
             prev_activation->type == kTfLiteUInt8 &&
             weights->type == kTfLiteUInt8 && bias->type == kTfLiteInt32 &&
             prev_state->type == kTfLiteInt16 &&
             state_out->type == kTfLiteInt16 &&
             activation_out->type == kTfLiteUInt8 &&
             concat_temp->type == kTfLiteUInt8 &&
             activation_temp->type == kTfLiteInt16) {
    gemmlowp::GemmContext* gemm_context =
        gemm_support::GetFromContext(context);
    int state_scale_log2_rounded;
    if (!CheckedLog2(state_out->params.scale, &state_scale_log2_rounded)) {
      context->ReportError(
          context,
          "The internal state of an LSTM cell must have a power-of-two "
          "scale.");
      return kTfLiteError;
    }
    // An int16 tensor with a scale of 2^k can represent magnitudes up to
    // 2^(15 + k), i.e. it has 15 + k integer bits.
    const int state_integer_bits = 15 + state_scale_log2_rounded;
    if (state_integer_bits != 4) {
      context->ReportError(context,
                           "The only case of quantized LstmCell currently "
                           "supported is with StateIntegerBits==4");
      return kTfLiteError;
    }

    double real_accum_multiplier = 4096 * bias->params.scale;
    int32 accum_multiplier;
    int accum_shift;
    tflite::QuantizeMultiplier(real_accum_multiplier, &accum_multiplier,
                               &accum_shift);
    tflite::LstmCellParams op_params;
    op_params.weights_zero_point = weights->params.zero_point;
    op_params.accum_multiplier = accum_multiplier;
    op_params.accum_shift = accum_shift;
    optimized_ops::LstmCell<4>(
        op_params,
        // Inputs.
        GetTensorShape(input), GetTensorData<uint8_t>(input),
        GetTensorShape(prev_activation),
        GetTensorData<uint8_t>(prev_activation), GetTensorShape(weights),
        GetTensorData<uint8_t>(weights), GetTensorShape(bias),
        GetTensorData<int32_t>(bias), GetTensorShape(prev_state),
        GetTensorData<int16_t>(prev_state),
        // Outputs.
        GetTensorShape(state_out), GetTensorData<int16_t>(state_out),
        GetTensorShape(activation_out), GetTensorData<uint8_t>(activation_out),
        GetTensorShape(concat_temp), GetTensorData<uint8_t>(concat_temp),
        GetTensorShape(activation_temp),
        GetTensorData<int16_t>(activation_temp), gemm_context);
  } else {
    context->ReportError(context,
                         "Unsupported combination of data types for LstmCell");
    return kTfLiteError;
  }

  // TODO(ycling): Investigate if this copy can be avoided with the 5-inputs
  // LSTM kernel.
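  // (The basic kernel models state with separate input and output tensors
  // rather than variable tensors, so the new state is copied back into the
  // persistent prev_* inputs here to carry it to the next invocation.)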
  memcpy(prev_activation->data.raw, activation_out->data.raw,
         activation_out->bytes);
  memcpy(prev_state->data.raw, state_out->data.raw, state_out->bytes);

  return kTfLiteOk;
}

}  // namespace basic

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  gemm_support::IncrementUsageCounter(context);

  const auto* params = reinterpret_cast<const TfLiteLSTMParams*>(buffer);
  switch (params->kernel_type) {
    case kTfLiteLSTMFullKernel:
      return full::Init(context, buffer, length);
    case kTfLiteLSTMBasicKernel:
      return basic::Init(context, buffer, length);
    default:
      return nullptr;
  }
}

void Free(TfLiteContext* context, void* buffer) {
  gemm_support::DecrementUsageCounter(context);

  delete reinterpret_cast<OpData*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  const auto* op_data = reinterpret_cast<const OpData*>(node->user_data);
  switch (op_data->kernel_type) {
    case kTfLiteLSTMFullKernel:
      return full::Prepare(context, node);
    case kTfLiteLSTMBasicKernel:
      return basic::Prepare(context, node);
    default:
      return kTfLiteError;
  }
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const auto* op_data = reinterpret_cast<const OpData*>(node->user_data);
  switch (op_data->kernel_type) {
    case kTfLiteLSTMFullKernel:
      return full::Eval(context, node);
    case kTfLiteLSTMBasicKernel:
      return basic::Eval(context, node);
    default:
      return kTfLiteError;
  }
}

}  // namespace lstm

TfLiteRegistration* Register_LSTM() {
  static TfLiteRegistration r = {lstm::Init, lstm::Free, lstm::Prepare,
                                 lstm::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite