/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <limits>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/activation_functor.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace bidirectional_sequence_rnn {

namespace {

int8_t* GetInt8DataPtr(const TfLiteTensor* tensor, const bool is_uint8) {
  if (is_uint8) {
    return reinterpret_cast<int8_t*>(tensor->data.uint8);
  } else {
    return tensor->data.int8;
  }
}

}  // namespace

constexpr int kInputTensor = 0;
// Forward and backward cell tensors.
constexpr int kFwWeightsTensor = 1;
constexpr int kFwRecurrentWeightsTensor = 2;
constexpr int kFwBiasTensor = 3;
constexpr int kFwHiddenStateTensor = 4;
constexpr int kBwWeightsTensor = 5;
constexpr int kBwRecurrentWeightsTensor = 6;
constexpr int kBwBiasTensor = 7;
constexpr int kBwHiddenStateTensor = 8;
// Used as auxiliary input and weights when stacking for
// tf.contrib.rnn.stack_bidirectional_rnn case (with cross links); Used as input
// to the backward cell when stacking for tf.nn.static_bidirectional_rnn case
// (without cross links).
constexpr int kAuxInputTensor = 9;       // Optional.
constexpr int kFwAuxWeightsTensor = 10;  // Optional.
constexpr int kBwAuxWeightsTensor = 11;  // Optional.
// Output tensors.
constexpr int kFwOutputTensor = 0;
constexpr int kBwOutputTensor = 1;  // Only if merge_outputs is false.

// Temporary tensors.
enum TemporaryTensor {
  kInputQuantized = 0,
  kFwHiddenStateQuantized = 1,
  kBwHiddenStateQuantized = 2,
  kScalingFactors = 3,
  kAuxInputQuantized = 4,
  kNumTemporaryTensors = 5
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* scratch_tensor_index = new int;
  context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index);
  return scratch_tensor_index;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<int*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceRNNParams*>(
      node->builtin_data);

  // Check we have all the inputs and outputs we need.
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 12);
  TF_LITE_ENSURE_EQ(context, node->outputs->size,
                    params->merge_outputs ? 1 : 2);
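
  // Expected shapes (see the checks below), with time_major == false:
  //   input:             [batch_size, max_time, input_size]
  //   input weights:     [num_units, input_size]
  //   recurrent weights: [num_units, num_units]
  //   bias:              [num_units]
  //   hidden state:      [batch_size, num_units]
  // for each of the forward and backward cells; with time_major == true the
  // first two dimensions of the input are swapped.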

  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  const TfLiteTensor* fw_input_weights =
      GetInput(context, node, kFwWeightsTensor);
  const TfLiteTensor* fw_recurrent_weights =
      GetInput(context, node, kFwRecurrentWeightsTensor);
  const TfLiteTensor* fw_bias = GetInput(context, node, kFwBiasTensor);
  const TfLiteTensor* fw_hidden_state =
      GetInput(context, node, kFwHiddenStateTensor);
  const TfLiteTensor* bw_input_weights =
      GetInput(context, node, kBwWeightsTensor);
  const TfLiteTensor* bw_recurrent_weights =
      GetInput(context, node, kBwRecurrentWeightsTensor);
  const TfLiteTensor* bw_bias = GetInput(context, node, kBwBiasTensor);
  const TfLiteTensor* bw_hidden_state =
      GetInput(context, node, kBwHiddenStateTensor);

  const TfLiteTensor* aux_input =
      GetOptionalInputTensor(context, node, kAuxInputTensor);
  const TfLiteTensor* fw_aux_input_weights =
      GetOptionalInputTensor(context, node, kFwAuxWeightsTensor);
  const TfLiteTensor* bw_aux_input_weights =
      GetOptionalInputTensor(context, node, kBwAuxWeightsTensor);

  const bool aux_inputs_weights_or_none =
      ((fw_aux_input_weights != nullptr) &&
       (bw_aux_input_weights != nullptr)) ||
      ((fw_aux_input_weights == nullptr) && (bw_aux_input_weights == nullptr));
  TF_LITE_ENSURE(context, aux_inputs_weights_or_none);
  const bool has_aux_input = (fw_aux_input_weights != nullptr);

  // Check all the parameters of tensor match within themselves and match the
  // input configuration.
  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);

  TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
  const bool time_major = params->time_major;
  const int batch_size =
      (time_major) ? input->dims->data[1] : input->dims->data[0];
  const int max_time =
      (time_major) ? input->dims->data[0] : input->dims->data[1];
  const int fw_num_units = fw_input_weights->dims->data[0];
  const int bw_num_units = bw_input_weights->dims->data[0];
  TF_LITE_ENSURE_EQ(context, input->dims->data[2],
                    fw_input_weights->dims->data[1]);
  TF_LITE_ENSURE_EQ(context, input->dims->data[2],
                    bw_input_weights->dims->data[1]);
  TF_LITE_ENSURE_EQ(context, fw_input_weights->dims->data[0],
                    fw_bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, bw_input_weights->dims->data[0],
                    bw_bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, fw_recurrent_weights->dims->data[0],
                    fw_bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, bw_recurrent_weights->dims->data[1],
                    bw_bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, NumDimensions(fw_hidden_state), 2);
  TF_LITE_ENSURE_EQ(context, fw_hidden_state->dims->data[0], batch_size);
  TF_LITE_ENSURE_EQ(context, fw_hidden_state->dims->data[1], fw_num_units);
  TF_LITE_ENSURE_EQ(context, NumDimensions(bw_hidden_state), 2);
  TF_LITE_ENSURE_EQ(context, bw_hidden_state->dims->data[0], batch_size);
  TF_LITE_ENSURE_EQ(context, bw_hidden_state->dims->data[1], bw_num_units);

  if (has_aux_input) {
    // Check that aux_input has the same dimensions (except last) as the input.
    TF_LITE_ASSERT_EQ(aux_input->dims->data[0], input->dims->data[0]);
    TF_LITE_ASSERT_EQ(aux_input->dims->data[1], input->dims->data[1]);
    // Check that aux_input_weights has the same dimensions (except last) as
    // the input_weights.
    TF_LITE_ASSERT_EQ(fw_aux_input_weights->dims->data[0], fw_num_units);
    TF_LITE_ASSERT_EQ(bw_aux_input_weights->dims->data[0], bw_num_units);
    TF_LITE_ASSERT_EQ(aux_input->dims->data[2],
                      fw_aux_input_weights->dims->data[1]);
    TF_LITE_ASSERT_EQ(aux_input->dims->data[2],
                      bw_aux_input_weights->dims->data[1]);
  }

  const bool is_hybrid_op = ((fw_input_weights->type == kTfLiteUInt8 ||
                              fw_input_weights->type == kTfLiteInt8) &&
                             input->type == kTfLiteFloat32);

  if (is_hybrid_op) {
    int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);

    TfLiteIntArrayFree(node->temporaries);
    if (has_aux_input) {
      node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
    } else {
      // No need to create a temporary tensor for the non-existent aux_input.
      node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors - 1);
    }

    node->temporaries->data[kInputQuantized] =
        *scratch_tensor_index + kInputQuantized;
    TfLiteTensor* input_quantized =
        GetTemporary(context, node, kInputQuantized);
    input_quantized->type = fw_input_weights->type;
    input_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
                                                       input_quantized_size));
    }

    node->temporaries->data[kFwHiddenStateQuantized] =
        *scratch_tensor_index + kFwHiddenStateQuantized;
    TfLiteTensor* fw_hidden_state_quantized =
        GetTemporary(context, node, kFwHiddenStateQuantized);
    fw_hidden_state_quantized->type = fw_input_weights->type;
    fw_hidden_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(fw_hidden_state_quantized->dims,
                             fw_hidden_state->dims)) {
      TfLiteIntArray* fw_hidden_state_quantized_size =
          TfLiteIntArrayCopy(fw_hidden_state->dims);
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, fw_hidden_state_quantized,
                                         fw_hidden_state_quantized_size));
    }

    node->temporaries->data[kBwHiddenStateQuantized] =
        *scratch_tensor_index + kBwHiddenStateQuantized;
    TfLiteTensor* bw_hidden_state_quantized =
        GetTemporary(context, node, kBwHiddenStateQuantized);
    bw_hidden_state_quantized->type = fw_input_weights->type;
    bw_hidden_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(bw_hidden_state_quantized->dims,
                             bw_hidden_state->dims)) {
      TfLiteIntArray* bw_hidden_state_quantized_size =
          TfLiteIntArrayCopy(bw_hidden_state->dims);
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, bw_hidden_state_quantized,
                                         bw_hidden_state_quantized_size));
    }
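
    // Note on the hybrid path: the weights are stored as 8-bit values while
    // the activations remain float, so the float input and hidden states are
    // quantized on the fly into the scratch tensors above before each matrix
    // multiplication, together with one scaling factor per batch row that is
    // used to turn the integer accumulators back into floats.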

    // Allocate temporary tensors to store scaling factors of quantization.
    node->temporaries->data[kScalingFactors] =
        *scratch_tensor_index + kScalingFactors;
    TfLiteTensor* scaling_factors =
        GetTemporary(context, node, kScalingFactors);
    scaling_factors->type = kTfLiteFloat32;
    scaling_factors->allocation_type = kTfLiteArenaRw;
    int scaling_dims[1] = {batch_size};
    if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
      TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
      scaling_factors_size->data[0] = batch_size;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
                                                       scaling_factors_size));
    }

    if (has_aux_input) {
      node->temporaries->data[kAuxInputQuantized] =
          *scratch_tensor_index + kAuxInputQuantized;
      TfLiteTensor* aux_input_quantized =
          GetTemporary(context, node, kAuxInputQuantized);
      aux_input_quantized->type = fw_input_weights->type;
      aux_input_quantized->allocation_type = kTfLiteArenaRw;
      if (!TfLiteIntArrayEqual(aux_input_quantized->dims, aux_input->dims)) {
        TfLiteIntArray* aux_input_quantized_size =
            TfLiteIntArrayCopy(aux_input->dims);
        TF_LITE_ENSURE_OK(context,
                          context->ResizeTensor(context, aux_input_quantized,
                                                aux_input_quantized_size));
      }
    }
  }

  // Resize outputs.
  TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
  TfLiteIntArray* fw_output_size_array = TfLiteIntArrayCreate(3);
  fw_output_size_array->data[0] = (time_major) ? max_time : batch_size;
  fw_output_size_array->data[1] = (time_major) ? batch_size : max_time;
  fw_output_size_array->data[2] =
      params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
  TF_LITE_ENSURE_OK(
      context, context->ResizeTensor(context, fw_output, fw_output_size_array));
  if (!params->merge_outputs) {
    TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
    TfLiteIntArray* bw_output_size_array = TfLiteIntArrayCreate(3);
    // Match the forward output: respect time_major when ordering the first
    // two dimensions.
    bw_output_size_array->data[0] = (time_major) ? max_time : batch_size;
    bw_output_size_array->data[1] = (time_major) ? batch_size : max_time;
    bw_output_size_array->data[2] = bw_num_units;
    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_output,
                                                     bw_output_size_array));
  }

  return kTfLiteOk;
}
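
// When merge_outputs is true, the forward and backward cells share a single
// output tensor: every time step stores fw_num_units forward activations
// followed by bw_num_units backward activations. This is why the per-step
// strides below (fw_output_step / bw_output_step) become
// fw_num_units + bw_num_units and the backward cell writes at an offset of
// fw_num_units.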
TfLiteStatus EvalFloat(const TfLiteTensor* input, const TfLiteTensor* bw_input,
                       const TfLiteTensor* fw_input_weights,
                       const TfLiteTensor* fw_recurrent_weights,
                       const TfLiteTensor* fw_bias,
                       const TfLiteTensor* bw_input_weights,
                       const TfLiteTensor* bw_recurrent_weights,
                       const TfLiteTensor* bw_bias,
                       const TfLiteTensor* aux_input,
                       const TfLiteTensor* fw_aux_input_weights,
                       const TfLiteTensor* bw_aux_input_weights,
                       const TfLiteBidirectionalSequenceRNNParams* params,
                       TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
                       TfLiteTensor* bw_hidden_state, TfLiteTensor* bw_output) {
  const bool time_major = params->time_major;
  const int batch_size =
      (time_major) ? input->dims->data[1] : input->dims->data[0];
  const int max_time =
      (time_major) ? input->dims->data[0] : input->dims->data[1];
  const int input_size = input->dims->data[2];
  const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;

  const int fw_num_units = fw_input_weights->dims->data[0];
  const float* fw_bias_ptr = fw_bias->data.f;
  const float* fw_input_weights_ptr = fw_input_weights->data.f;
  const float* fw_recurrent_weights_ptr = fw_recurrent_weights->data.f;

  const int bw_num_units = bw_input_weights->dims->data[0];
  const float* bw_bias_ptr = bw_bias->data.f;
  const float* bw_input_weights_ptr = bw_input_weights->data.f;
  const float* bw_recurrent_weights_ptr = bw_recurrent_weights->data.f;

  const float* fw_aux_input_weights_ptr = (fw_aux_input_weights != nullptr)
                                              ? fw_aux_input_weights->data.f
                                              : nullptr;
  const float* bw_aux_input_weights_ptr = (bw_aux_input_weights != nullptr)
                                              ? bw_aux_input_weights->data.f
                                              : nullptr;

  const int fw_output_step =
      params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
  const int bw_output_step =
      params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;
  if (time_major) {
    // Forward cell.
    float* fw_hidden_state_ptr_batch = fw_hidden_state->data.f;
    for (int s = 0; s < max_time; s++) {
      const float* input_ptr_batch =
          input->data.f + s * input_size * batch_size;
      const float* aux_input_ptr_batch =
          (aux_input != nullptr)
              ? aux_input->data.f + s * input_size * batch_size
              : nullptr;
      float* output_ptr_batch =
          fw_output->data.f + s * fw_output_step * batch_size;

      kernel_utils::RnnBatchStep(
          input_ptr_batch, fw_input_weights_ptr, aux_input_ptr_batch,
          fw_aux_input_weights_ptr, fw_recurrent_weights_ptr, fw_bias_ptr,
          input_size, aux_input_size, fw_num_units, batch_size, fw_output_step,
          params->activation, fw_hidden_state_ptr_batch, output_ptr_batch);
    }
    // Backward cell.
    float* bw_hidden_state_ptr_batch = bw_hidden_state->data.f;
    for (int s = max_time - 1; s >= 0; s--) {
      const float* input_ptr_batch =
          bw_input->data.f + s * input_size * batch_size;
      const float* aux_input_ptr_batch =
          (aux_input != nullptr)
              ? aux_input->data.f + s * input_size * batch_size
              : nullptr;
      float* output_ptr_batch =
          (params->merge_outputs ? fw_output->data.f + fw_num_units
                                 : bw_output->data.f) +
          s * bw_output_step * batch_size;

      kernel_utils::RnnBatchStep(
          input_ptr_batch, bw_input_weights_ptr, aux_input_ptr_batch,
          bw_aux_input_weights_ptr, bw_recurrent_weights_ptr, bw_bias_ptr,
          input_size, aux_input_size, bw_num_units, batch_size, bw_output_step,
          params->activation, bw_hidden_state_ptr_batch, output_ptr_batch);
    }
  } else {
    for (int b = 0; b < batch_size; b++) {
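      // Batch-major layout: each sequence occupies a contiguous
      // [max_time, row_size] block, so entry (b, s) of a buffer with
      // row_size values per step starts at:
      //   data + b * row_size * max_time + s * row_size
      // The pointer arithmetic below follows this pattern for the input and
      // the (possibly merged) output.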
      // Forward cell.
      float* fw_hidden_state_ptr_batch =
          fw_hidden_state->data.f + b * fw_num_units;
      float* fw_output_offset =
          fw_output->data.f + b * fw_output_step * max_time;
      for (int s = 0; s < max_time; s++) {
        const float* input_ptr_batch =
            input->data.f + b * input_size * max_time + s * input_size;
        const float* aux_input_ptr_batch =
            (aux_input != nullptr)
                ? aux_input->data.f + b * input_size * max_time + s * input_size
                : nullptr;
        float* output_ptr_batch = fw_output_offset + s * fw_output_step;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, fw_input_weights_ptr, aux_input_ptr_batch,
            fw_aux_input_weights_ptr, fw_recurrent_weights_ptr, fw_bias_ptr,
            input_size, aux_input_size, fw_num_units, /*batch_size=*/1,
            fw_output_step, params->activation, fw_hidden_state_ptr_batch,
            output_ptr_batch);
      }
      // Backward cell.
      float* bw_hidden_state_ptr_batch =
          bw_hidden_state->data.f + b * bw_num_units;
      float* bw_output_offset =
          params->merge_outputs
              ? fw_output->data.f + b * bw_output_step * max_time + fw_num_units
              : bw_output->data.f + b * bw_output_step * max_time;
      for (int s = max_time - 1; s >= 0; s--) {
        const float* input_ptr_batch =
            bw_input->data.f + b * input_size * max_time + s * input_size;
        const float* aux_input_ptr_batch =
            (aux_input != nullptr)
                ? aux_input->data.f + b * input_size * max_time + s * input_size
                : nullptr;
        float* output_ptr_batch = bw_output_offset + s * bw_output_step;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, bw_input_weights_ptr, aux_input_ptr_batch,
            bw_aux_input_weights_ptr, bw_recurrent_weights_ptr, bw_bias_ptr,
            input_size, aux_input_size, bw_num_units, /*batch_size=*/1,
            bw_output_step, params->activation, bw_hidden_state_ptr_batch,
            output_ptr_batch);
      }
    }
  }
  return kTfLiteOk;
}
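
// EvalHybrid mirrors EvalFloat, but with 8-bit weights: at every step the
// float activations are quantized into the scratch tensors, the matrix
// multiplications run in 8-bit, and kernel_utils::RnnBatchStep rescales the
// results back to float using the weight scales and the per-batch scaling
// factors.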
TfLiteStatus EvalHybrid(
    const TfLiteTensor* input, const TfLiteTensor* bw_input,
    const TfLiteTensor* fw_input_weights,
    const TfLiteTensor* fw_recurrent_weights, const TfLiteTensor* fw_bias,
    const TfLiteTensor* bw_input_weights,
    const TfLiteTensor* bw_recurrent_weights, const TfLiteTensor* bw_bias,
    const TfLiteTensor* aux_input, const TfLiteTensor* aux_fw_input_weights,
    const TfLiteTensor* aux_bw_input_weights,
    const TfLiteBidirectionalSequenceRNNParams* params,
    TfLiteTensor* scaling_factors, TfLiteTensor* input_quantized,
    TfLiteTensor* aux_input_quantized, TfLiteTensor* fw_hidden_state_quantized,
    TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
    TfLiteTensor* bw_hidden_state_quantized, TfLiteTensor* bw_hidden_state,
    TfLiteTensor* bw_output) {
  const bool is_uint8_hybrid = fw_input_weights->type == kTfLiteUInt8;
  const bool time_major = params->time_major;
  const int batch_size =
      (time_major) ? input->dims->data[1] : input->dims->data[0];
  const int max_time =
      (time_major) ? input->dims->data[0] : input->dims->data[1];
  const int input_size = input->dims->data[2];
  const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;

  const int fw_num_units = fw_input_weights->dims->data[0];
  const float* fw_bias_ptr = fw_bias->data.f;
  const int8_t* fw_input_weights_ptr =
      GetInt8DataPtr(fw_input_weights, is_uint8_hybrid);
  float fw_input_weights_scale = fw_input_weights->params.scale;
  const int8_t* fw_recurrent_weights_ptr =
      GetInt8DataPtr(fw_recurrent_weights, is_uint8_hybrid);
  float fw_recurrent_weights_scale = fw_recurrent_weights->params.scale;

  const int bw_num_units = bw_input_weights->dims->data[0];
  const float* bw_bias_ptr = bw_bias->data.f;
  const int8_t* bw_input_weights_ptr =
      GetInt8DataPtr(bw_input_weights, is_uint8_hybrid);
  float bw_input_weights_scale = bw_input_weights->params.scale;
  const int8_t* bw_recurrent_weights_ptr =
      GetInt8DataPtr(bw_recurrent_weights, is_uint8_hybrid);
  float bw_recurrent_weights_scale = bw_recurrent_weights->params.scale;

  // Set the auxiliary pointers and scales if needed.
  int8_t* aux_fw_input_weights_ptr = nullptr;
  float aux_fw_input_weights_scale = 0.0f;
  int8_t* aux_bw_input_weights_ptr = nullptr;
  float aux_bw_input_weights_scale = 0.0f;
  int8_t* aux_quantized_input_ptr = nullptr;
  if (aux_input_size > 0) {
    aux_fw_input_weights_ptr =
        GetInt8DataPtr(aux_fw_input_weights, is_uint8_hybrid);
    aux_fw_input_weights_scale = aux_fw_input_weights->params.scale;
    aux_bw_input_weights_ptr =
        GetInt8DataPtr(aux_bw_input_weights, is_uint8_hybrid);
    aux_bw_input_weights_scale = aux_bw_input_weights->params.scale;
    aux_quantized_input_ptr =
        GetInt8DataPtr(aux_input_quantized, is_uint8_hybrid);
  }

  // Initialize temporary storage for quantized values.
  int8_t* quantized_input_ptr =
      GetInt8DataPtr(input_quantized, is_uint8_hybrid);
  int8_t* fw_quantized_hidden_state_ptr =
      GetInt8DataPtr(fw_hidden_state_quantized, is_uint8_hybrid);
  int8_t* bw_quantized_hidden_state_ptr =
      GetInt8DataPtr(bw_hidden_state_quantized, is_uint8_hybrid);
  float* scaling_factors_ptr = scaling_factors->data.f;

  const int fw_output_step =
      params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
  const int bw_output_step =
      params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;
  if (time_major) {
    // Forward cell.
    float* fw_hidden_state_ptr_batch = fw_hidden_state->data.f;
    for (int s = 0; s < max_time; s++) {
      const float* input_ptr_batch =
          input->data.f + s * input_size * batch_size;
      const float* aux_input_ptr_batch =
          (aux_input != nullptr)
              ? aux_input->data.f + s * input_size * batch_size
              : nullptr;
      float* output_ptr_batch =
          fw_output->data.f + s * fw_output_step * batch_size;

      kernel_utils::RnnBatchStep(
          input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale,
          aux_input_ptr_batch, aux_fw_input_weights_ptr,
          aux_fw_input_weights_scale, fw_recurrent_weights_ptr,
          fw_recurrent_weights_scale, fw_bias_ptr, input_size, aux_input_size,
          fw_num_units, batch_size, fw_output_step, params->activation,
          quantized_input_ptr, aux_quantized_input_ptr,
          fw_quantized_hidden_state_ptr, scaling_factors_ptr,
          fw_hidden_state_ptr_batch, output_ptr_batch);
    }
    // Backward cell.
    float* bw_hidden_state_ptr_batch = bw_hidden_state->data.f;
    for (int s = max_time - 1; s >= 0; s--) {
      const float* input_ptr_batch =
          bw_input->data.f + s * input_size * batch_size;
      const float* aux_input_ptr_batch =
          (aux_input != nullptr)
              ? aux_input->data.f + s * input_size * batch_size
              : nullptr;
      float* output_ptr_batch =
          (params->merge_outputs ? fw_output->data.f + fw_num_units
                                 : bw_output->data.f) +
          s * bw_output_step * batch_size;

      kernel_utils::RnnBatchStep(
          input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale,
          aux_input_ptr_batch, aux_bw_input_weights_ptr,
          aux_bw_input_weights_scale, bw_recurrent_weights_ptr,
          bw_recurrent_weights_scale, bw_bias_ptr, input_size, aux_input_size,
          bw_num_units, batch_size, bw_output_step, params->activation,
          quantized_input_ptr, aux_quantized_input_ptr,
          bw_quantized_hidden_state_ptr, scaling_factors_ptr,
          bw_hidden_state_ptr_batch, output_ptr_batch);
    }
  } else {
    for (int b = 0; b < batch_size; b++) {
      // Forward cell.
      float* fw_hidden_state_ptr_batch =
          fw_hidden_state->data.f + b * fw_num_units;
      float* fw_output_offset =
          fw_output->data.f + b * fw_output_step * max_time;
      for (int s = 0; s < max_time; s++) {
        const float* input_ptr_batch =
            input->data.f + b * input_size * max_time + s * input_size;
        const float* aux_input_ptr_batch =
            (aux_input != nullptr)
                ? aux_input->data.f + b * input_size * max_time + s * input_size
                : nullptr;
        float* output_ptr_batch = fw_output_offset + s * fw_output_step;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale,
            aux_input_ptr_batch, aux_fw_input_weights_ptr,
            aux_fw_input_weights_scale, fw_recurrent_weights_ptr,
            fw_recurrent_weights_scale, fw_bias_ptr, input_size, aux_input_size,
            fw_num_units, /*batch_size=*/1, fw_output_step, params->activation,
            quantized_input_ptr, aux_quantized_input_ptr,
            fw_quantized_hidden_state_ptr, scaling_factors_ptr,
            fw_hidden_state_ptr_batch, output_ptr_batch);
      }
      // Backward cell.
      float* bw_hidden_state_ptr_batch =
          bw_hidden_state->data.f + b * bw_num_units;
      float* bw_output_offset =
          params->merge_outputs
              ? fw_output->data.f + b * bw_output_step * max_time + fw_num_units
              : bw_output->data.f + b * bw_output_step * max_time;
      for (int s = max_time - 1; s >= 0; s--) {
        const float* input_ptr_batch =
            bw_input->data.f + b * input_size * max_time + s * input_size;
        const float* aux_input_ptr_batch =
            (aux_input != nullptr)
                ? aux_input->data.f + b * input_size * max_time + s * input_size
                : nullptr;
        float* output_ptr_batch = bw_output_offset + s * bw_output_step;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale,
            aux_input_ptr_batch, aux_bw_input_weights_ptr,
            aux_bw_input_weights_scale, bw_recurrent_weights_ptr,
            bw_recurrent_weights_scale, bw_bias_ptr, input_size, aux_input_size,
            bw_num_units, /*batch_size=*/1, bw_output_step, params->activation,
            quantized_input_ptr, aux_quantized_input_ptr,
            bw_quantized_hidden_state_ptr, scaling_factors_ptr,
            bw_hidden_state_ptr_batch, output_ptr_batch);
      }
    }
  }
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceRNNParams*>(
      node->builtin_data);

  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  const TfLiteTensor* fw_input_weights =
      GetInput(context, node, kFwWeightsTensor);
  const TfLiteTensor* fw_recurrent_weights =
      GetInput(context, node, kFwRecurrentWeightsTensor);
  const TfLiteTensor* fw_bias = GetInput(context, node, kFwBiasTensor);
  const TfLiteTensor* bw_input_weights =
      GetInput(context, node, kBwWeightsTensor);
  const TfLiteTensor* bw_recurrent_weights =
      GetInput(context, node, kBwRecurrentWeightsTensor);
  const TfLiteTensor* bw_bias = GetInput(context, node, kBwBiasTensor);

  // Get auxiliary inputs.
  const TfLiteTensor* aux_input =
      GetOptionalInputTensor(context, node, kAuxInputTensor);
  const TfLiteTensor* fw_aux_input_weights =
      GetOptionalInputTensor(context, node, kFwAuxWeightsTensor);
  const TfLiteTensor* bw_aux_input_weights =
      GetOptionalInputTensor(context, node, kBwAuxWeightsTensor);

  TfLiteTensor* fw_hidden_state =
      GetVariableInput(context, node, kFwHiddenStateTensor);
  TfLiteTensor* bw_hidden_state =
      GetVariableInput(context, node, kBwHiddenStateTensor);

  TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
  TfLiteTensor* bw_output = params->merge_outputs
                                ? nullptr
                                : GetOutput(context, node, kBwOutputTensor);

  const bool has_previous_bw_output = (aux_input != nullptr);
  const bool use_aux_input = (fw_aux_input_weights != nullptr);

  // We want to cover the following cases:
  //
  // If not stacking (not connected after another bidirectional RNN):
  //   both fw & bw will just use `input`; aux_input will be null.
  //
  // If stacking with cross links, TensorFlow equivalent
  // (tf.contrib.rnn.stack_bidirectional_rnn):
  //   both fw & bw will use `input`, and aux_input will be non-null. In this
  //   case it does not matter whether the layer is connected after another
  //   bidirectional RNN or not.
  //
  // If stacking without cross links, but connected after another
  // bidirectional RNN, TensorFlow equivalent (tf.nn.static_bidirectional_rnn):
  //   fw will use `input`, bw will use aux_input, and the `real aux_input`
  //   will be null.

  const bool non_stacking_mode = !use_aux_input && has_previous_bw_output;
  const TfLiteTensor* bw_input = non_stacking_mode ? aux_input : input;
  const TfLiteTensor* real_aux_input = non_stacking_mode ? nullptr : aux_input;
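
  // Summary of the selection above:
  //   not stacking:             bw_input = input, real_aux_input = null.
  //   stacking, cross links:    bw_input = input, real_aux_input = aux_input.
  //   stacking, no cross links: bw_input = aux_input, real_aux_input = null.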

  switch (fw_input_weights->type) {
    case kTfLiteFloat32:
      return EvalFloat(input, bw_input, fw_input_weights, fw_recurrent_weights,
                       fw_bias, bw_input_weights, bw_recurrent_weights, bw_bias,
                       real_aux_input, fw_aux_input_weights,
                       bw_aux_input_weights, params, fw_hidden_state, fw_output,
                       bw_hidden_state, bw_output);
    case kTfLiteUInt8:
    case kTfLiteInt8: {
      TfLiteTensor* input_quantized =
          GetTemporary(context, node, kInputQuantized);
      TfLiteTensor* fw_hidden_state_quantized =
          GetTemporary(context, node, kFwHiddenStateQuantized);
      TfLiteTensor* bw_hidden_state_quantized =
          GetTemporary(context, node, kBwHiddenStateQuantized);
      TfLiteTensor* scaling_factors =
          GetTemporary(context, node, kScalingFactors);
      TfLiteTensor* aux_input_quantized =
          use_aux_input ? GetTemporary(context, node, kAuxInputQuantized)
                        : nullptr;

      return EvalHybrid(input, bw_input, fw_input_weights, fw_recurrent_weights,
                        fw_bias, bw_input_weights, bw_recurrent_weights,
                        bw_bias, real_aux_input, fw_aux_input_weights,
                        bw_aux_input_weights, params, scaling_factors,
                        input_quantized, aux_input_quantized,
                        fw_hidden_state_quantized, fw_hidden_state, fw_output,
                        bw_hidden_state_quantized, bw_hidden_state, bw_output);
    }
    default:
      context->ReportError(context, "Type not currently supported.");
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace bidirectional_sequence_rnn

TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_RNN() {
  static TfLiteRegistration r = {
      bidirectional_sequence_rnn::Init, bidirectional_sequence_rnn::Free,
      bidirectional_sequence_rnn::Prepare, bidirectional_sequence_rnn::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite