1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 // Unit test for TFLite Bidirectional LSTM op. 16 17 #include <initializer_list> 18 #include <iomanip> 19 #include <memory> 20 #include <vector> 21 22 #include <gmock/gmock.h> 23 #include <gtest/gtest.h> 24 #include "tensorflow/lite/interpreter.h" 25 #include "tensorflow/lite/kernels/register.h" 26 #include "tensorflow/lite/kernels/test_util.h" 27 #include "tensorflow/lite/model.h" 28 #include "tensorflow/lite/schema/schema_generated.h" 29 30 namespace tflite { 31 namespace { 32 33 using ::testing::ElementsAreArray; 34 35 class BidirectionalLSTMOpModel : public SingleOpModel { 36 public: 37 BidirectionalLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, 38 int sequence_length, bool use_cifg, 39 bool use_peephole, bool use_projection_weights, 40 bool use_projection_bias, bool merge_outputs, 41 bool use_aux_input, float cell_clip, float proj_clip, 42 bool quantize_weights, bool time_major, 43 const std::vector<std::vector<int>>& input_shapes) 44 : n_batch_(n_batch), 45 n_input_(n_input), 46 n_fw_cell_(n_cell), 47 n_bw_cell_(n_cell), 48 n_fw_output_(n_output), 49 n_bw_output_(n_output), 50 sequence_length_(sequence_length), 51 quantize_weights_(quantize_weights) { 52 input_ = AddInput(TensorType_FLOAT32); 53 const auto weight_type = 54 quantize_weights_ ? TensorType_UINT8 : TensorType_FLOAT32; 55 56 if (use_cifg) { 57 fw_input_to_input_weights_ = AddNullInput(); 58 } else { 59 fw_input_to_input_weights_ = AddInput(weight_type); 60 } 61 62 fw_input_to_forget_weights_ = AddInput(weight_type); 63 fw_input_to_cell_weights_ = AddInput(weight_type); 64 fw_input_to_output_weights_ = AddInput(weight_type); 65 66 if (use_cifg) { 67 fw_recurrent_to_input_weights_ = AddNullInput(); 68 } else { 69 fw_recurrent_to_input_weights_ = AddInput(weight_type); 70 } 71 72 fw_recurrent_to_forget_weights_ = AddInput(weight_type); 73 fw_recurrent_to_cell_weights_ = AddInput(weight_type); 74 fw_recurrent_to_output_weights_ = AddInput(weight_type); 75 76 if (use_peephole) { 77 if (use_cifg) { 78 fw_cell_to_input_weights_ = AddNullInput(); 79 } else { 80 fw_cell_to_input_weights_ = AddInput(weight_type); 81 } 82 fw_cell_to_forget_weights_ = AddInput(weight_type); 83 fw_cell_to_output_weights_ = AddInput(weight_type); 84 } else { 85 fw_cell_to_input_weights_ = AddNullInput(); 86 fw_cell_to_forget_weights_ = AddNullInput(); 87 fw_cell_to_output_weights_ = AddNullInput(); 88 } 89 90 if (use_cifg) { 91 fw_input_gate_bias_ = AddNullInput(); 92 } else { 93 fw_input_gate_bias_ = AddInput(TensorType_FLOAT32); 94 } 95 fw_forget_gate_bias_ = AddInput(TensorType_FLOAT32); 96 fw_cell_bias_ = AddInput(TensorType_FLOAT32); 97 fw_output_gate_bias_ = AddInput(TensorType_FLOAT32); 98 99 if (use_projection_weights) { 100 fw_projection_weights_ = AddInput(TensorType_FLOAT32); 101 if (use_projection_bias) { 102 fw_projection_bias_ = AddInput(TensorType_FLOAT32); 103 } else { 104 fw_projection_bias_ = AddNullInput(); 105 } 106 } else { 107 fw_projection_weights_ = AddNullInput(); 108 fw_projection_bias_ = AddNullInput(); 109 } 110 111 if (use_cifg) { 112 bw_input_to_input_weights_ = AddNullInput(); 113 } else { 114 bw_input_to_input_weights_ = AddInput(weight_type); 115 } 116 117 bw_input_to_forget_weights_ = AddInput(weight_type); 118 bw_input_to_cell_weights_ = AddInput(weight_type); 119 bw_input_to_output_weights_ = AddInput(weight_type); 120 121 if (use_cifg) { 122 bw_recurrent_to_input_weights_ = AddNullInput(); 123 } else { 124 bw_recurrent_to_input_weights_ = AddInput(weight_type); 125 } 126 127 bw_recurrent_to_forget_weights_ = AddInput(weight_type); 128 bw_recurrent_to_cell_weights_ = AddInput(weight_type); 129 bw_recurrent_to_output_weights_ = AddInput(weight_type); 130 131 if (use_peephole) { 132 if (use_cifg) { 133 bw_cell_to_input_weights_ = AddNullInput(); 134 } else { 135 bw_cell_to_input_weights_ = AddInput(weight_type); 136 } 137 bw_cell_to_forget_weights_ = AddInput(weight_type); 138 bw_cell_to_output_weights_ = AddInput(weight_type); 139 } else { 140 bw_cell_to_input_weights_ = AddNullInput(); 141 bw_cell_to_forget_weights_ = AddNullInput(); 142 bw_cell_to_output_weights_ = AddNullInput(); 143 } 144 145 if (use_cifg) { 146 bw_input_gate_bias_ = AddNullInput(); 147 } else { 148 bw_input_gate_bias_ = AddInput(TensorType_FLOAT32); 149 } 150 bw_forget_gate_bias_ = AddInput(TensorType_FLOAT32); 151 bw_cell_bias_ = AddInput(TensorType_FLOAT32); 152 bw_output_gate_bias_ = AddInput(TensorType_FLOAT32); 153 154 if (use_projection_weights) { 155 bw_projection_weights_ = AddInput(weight_type); 156 if (use_projection_bias) { 157 bw_projection_bias_ = AddInput(TensorType_FLOAT32); 158 } else { 159 bw_projection_bias_ = AddNullInput(); 160 } 161 } else { 162 bw_projection_weights_ = AddNullInput(); 163 bw_projection_bias_ = AddNullInput(); 164 } 165 166 // Adding the 2 input state tensors. 167 fw_input_activation_state_ = 168 AddInput(TensorData{TensorType_FLOAT32, {n_fw_output_ * n_batch_}}, 169 /*is_variable=*/true); 170 fw_input_cell_state_ = 171 AddInput(TensorData{TensorType_FLOAT32, {n_fw_cell_ * n_batch_}}, 172 /*is_variable=*/true); 173 174 // Adding the 2 input state tensors. 175 bw_input_activation_state_ = 176 AddInput(TensorData{TensorType_FLOAT32, {n_bw_output_ * n_batch_}}, 177 /*is_variable=*/true); 178 bw_input_cell_state_ = 179 AddInput(TensorData{TensorType_FLOAT32, {n_bw_cell_ * n_batch_}}, 180 /*is_variable=*/true); 181 182 fw_output_ = AddOutput(TensorType_FLOAT32); 183 184 if (!merge_outputs) { 185 bw_output_ = AddOutput(TensorType_FLOAT32); 186 } 187 188 if (use_aux_input) { 189 aux_input_ = AddInput(TensorType_FLOAT32); 190 } else { 191 aux_input_ = AddNullInput(); 192 } 193 fw_aux_input_to_input_weights_ = AddNullInput(); 194 fw_aux_input_to_forget_weights_ = AddNullInput(); 195 fw_aux_input_to_cell_weights_ = AddNullInput(); 196 fw_aux_input_to_output_weights_ = AddNullInput(); 197 bw_aux_input_to_input_weights_ = AddNullInput(); 198 bw_aux_input_to_forget_weights_ = AddNullInput(); 199 bw_aux_input_to_cell_weights_ = AddNullInput(); 200 bw_aux_input_to_output_weights_ = AddNullInput(); 201 202 SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, 203 BuiltinOptions_BidirectionalSequenceLSTMOptions, 204 CreateBidirectionalSequenceLSTMOptions( 205 builder_, ActivationFunctionType_TANH, cell_clip, 206 proj_clip, merge_outputs, time_major) 207 .Union()); 208 BuildInterpreter(input_shapes); 209 } 210 211 void PopulateWeightTensor(int tensor_id, const std::vector<float>& f) { 212 if (quantize_weights_) { 213 SymmetricQuantizeAndPopulate(tensor_id, f); 214 } else { 215 PopulateTensor(tensor_id, f); 216 } 217 } 218 219 // Set weights in forward and backward cells to be the same. 220 void SetInputToInputWeights(const std::vector<float>& f) { 221 PopulateWeightTensor(fw_input_to_input_weights_, f); 222 PopulateWeightTensor(bw_input_to_input_weights_, f); 223 } 224 225 void SetInputToForgetWeights(const std::vector<float>& f) { 226 PopulateWeightTensor(fw_input_to_forget_weights_, f); 227 PopulateWeightTensor(bw_input_to_forget_weights_, f); 228 } 229 230 void SetInputToCellWeights(const std::vector<float>& f) { 231 PopulateWeightTensor(fw_input_to_cell_weights_, f); 232 PopulateWeightTensor(bw_input_to_cell_weights_, f); 233 } 234 235 void SetInputToOutputWeights(const std::vector<float>& f) { 236 PopulateWeightTensor(fw_input_to_output_weights_, f); 237 PopulateWeightTensor(bw_input_to_output_weights_, f); 238 } 239 240 void SetRecurrentToInputWeights(const std::vector<float>& f) { 241 PopulateWeightTensor(fw_recurrent_to_input_weights_, f); 242 PopulateWeightTensor(bw_recurrent_to_input_weights_, f); 243 } 244 245 void SetRecurrentToForgetWeights(const std::vector<float>& f) { 246 PopulateWeightTensor(fw_recurrent_to_forget_weights_, f); 247 PopulateWeightTensor(bw_recurrent_to_forget_weights_, f); 248 } 249 250 void SetRecurrentToCellWeights(const std::vector<float>& f) { 251 PopulateWeightTensor(fw_recurrent_to_cell_weights_, f); 252 PopulateWeightTensor(bw_recurrent_to_cell_weights_, f); 253 } 254 255 void SetRecurrentToOutputWeights(const std::vector<float>& f) { 256 PopulateWeightTensor(fw_recurrent_to_output_weights_, f); 257 PopulateWeightTensor(bw_recurrent_to_output_weights_, f); 258 } 259 260 void SetCellToInputWeights(const std::vector<float>& f) { 261 PopulateWeightTensor(fw_cell_to_input_weights_, f); 262 PopulateWeightTensor(bw_cell_to_input_weights_, f); 263 } 264 265 void SetCellToForgetWeights(const std::vector<float>& f) { 266 PopulateWeightTensor(fw_cell_to_forget_weights_, f); 267 PopulateWeightTensor(bw_cell_to_forget_weights_, f); 268 } 269 270 void SetCellToOutputWeights(const std::vector<float>& f) { 271 PopulateWeightTensor(fw_cell_to_output_weights_, f); 272 PopulateWeightTensor(bw_cell_to_output_weights_, f); 273 } 274 275 void SetInputGateBias(const std::vector<float>& f) { 276 PopulateTensor(fw_input_gate_bias_, f); 277 PopulateTensor(bw_input_gate_bias_, f); 278 } 279 280 void SetForgetGateBias(const std::vector<float>& f) { 281 PopulateTensor(fw_forget_gate_bias_, f); 282 PopulateTensor(bw_forget_gate_bias_, f); 283 } 284 285 void SetCellBias(const std::vector<float>& f) { 286 PopulateTensor(fw_cell_bias_, f); 287 PopulateTensor(bw_cell_bias_, f); 288 } 289 290 void SetOutputGateBias(const std::vector<float>& f) { 291 PopulateTensor(fw_output_gate_bias_, f); 292 PopulateTensor(bw_output_gate_bias_, f); 293 } 294 295 void SetProjectionWeights(const std::vector<float>& f) { 296 PopulateWeightTensor(fw_projection_weights_, f); 297 PopulateWeightTensor(bw_projection_weights_, f); 298 } 299 300 void SetProjectionBias(const std::vector<float>& f) { 301 PopulateTensor(fw_projection_bias_, f); 302 PopulateTensor(bw_projection_bias_, f); 303 } 304 305 void SetInput(int offset, float* begin, float* end) { 306 PopulateTensor(input_, offset, begin, end); 307 } 308 309 void SetAuxInput(int offset, float* begin, float* end) { 310 PopulateTensor(aux_input_, offset, begin, end); 311 } 312 313 std::vector<float> GetFwOutput() { return ExtractVector<float>(fw_output_); } 314 std::vector<float> GetBwOutput() { return ExtractVector<float>(bw_output_); } 315 316 int num_inputs() { return n_input_; } 317 int num_fw_outputs() { return n_fw_output_; } 318 int num_bw_outputs() { return n_bw_output_; } 319 int num_fw_cells() { return n_fw_cell_; } 320 int num_bw_cells() { return n_bw_cell_; } 321 int num_batches() { return n_batch_; } 322 int sequence_length() { return sequence_length_; } 323 324 private: 325 int input_; 326 int fw_input_to_input_weights_; 327 int fw_input_to_forget_weights_; 328 int fw_input_to_cell_weights_; 329 int fw_input_to_output_weights_; 330 331 int fw_recurrent_to_input_weights_; 332 int fw_recurrent_to_forget_weights_; 333 int fw_recurrent_to_cell_weights_; 334 int fw_recurrent_to_output_weights_; 335 336 int fw_cell_to_input_weights_; 337 int fw_cell_to_forget_weights_; 338 int fw_cell_to_output_weights_; 339 340 int fw_input_gate_bias_; 341 int fw_forget_gate_bias_; 342 int fw_cell_bias_; 343 int fw_output_gate_bias_; 344 345 int fw_projection_weights_; 346 int fw_projection_bias_; 347 348 int bw_input_to_input_weights_; 349 int bw_input_to_forget_weights_; 350 int bw_input_to_cell_weights_; 351 int bw_input_to_output_weights_; 352 353 int bw_recurrent_to_input_weights_; 354 int bw_recurrent_to_forget_weights_; 355 int bw_recurrent_to_cell_weights_; 356 int bw_recurrent_to_output_weights_; 357 358 int bw_cell_to_input_weights_; 359 int bw_cell_to_forget_weights_; 360 int bw_cell_to_output_weights_; 361 362 int bw_input_gate_bias_; 363 int bw_forget_gate_bias_; 364 int bw_cell_bias_; 365 int bw_output_gate_bias_; 366 367 int bw_projection_weights_; 368 int bw_projection_bias_; 369 370 int fw_input_activation_state_; 371 int fw_input_cell_state_; 372 int bw_input_activation_state_; 373 int bw_input_cell_state_; 374 375 int fw_output_; 376 int bw_output_; 377 378 int aux_input_; 379 int fw_aux_input_to_input_weights_; 380 int fw_aux_input_to_forget_weights_; 381 int fw_aux_input_to_cell_weights_; 382 int fw_aux_input_to_output_weights_; 383 int bw_aux_input_to_input_weights_; 384 int bw_aux_input_to_forget_weights_; 385 int bw_aux_input_to_cell_weights_; 386 int bw_aux_input_to_output_weights_; 387 388 int n_batch_; 389 int n_input_; 390 int n_fw_cell_; 391 int n_bw_cell_; 392 int n_fw_output_; 393 int n_bw_output_; 394 int sequence_length_; 395 396 bool quantize_weights_; 397 }; 398 399 // Declare LSTMOpTest as a parameterized test, where the parameter is a boolean 400 // indicating whether to use quantization or not. 401 class LSTMOpTest : public ::testing::TestWithParam<bool> {}; 402 403 INSTANTIATE_TEST_SUITE_P(QuantizationOrNot, LSTMOpTest, ::testing::Bool()); 404 405 TEST_P(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { 406 const int n_batch = 1; 407 const int n_input = 2; 408 // n_cell and n_output have the same size when there is no projection. 409 const int n_cell = 4; 410 const int n_output = 4; 411 const int sequence_length = 3; 412 const bool quantize_weights = GetParam(); 413 414 BidirectionalLSTMOpModel lstm( 415 n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, 416 /*use_peephole=*/false, /*use_projection_weights=*/false, 417 /*use_projection_bias=*/false, /*merge_outputs=*/false, 418 /*use_aux_input=*/false, /*cell_clip=*/0.0, 419 /*proj_clip=*/0.0, quantize_weights, /*time_major=*/true, 420 { 421 {sequence_length, n_batch, n_input}, // input tensor 422 423 // Forward cell 424 {n_cell, n_input}, // input_to_input_weight tensor 425 {n_cell, n_input}, // input_to_forget_weight tensor 426 {n_cell, n_input}, // input_to_cell_weight tensor 427 {n_cell, n_input}, // input_to_output_weight tensor 428 429 {n_cell, n_output}, // recurrent_to_input_weight tensor 430 {n_cell, n_output}, // recurrent_to_forget_weight tensor 431 {n_cell, n_output}, // recurrent_to_cell_weight tensor 432 {n_cell, n_output}, // recurrent_to_output_weight tensor 433 434 {0}, // cell_to_input_weight tensor 435 {0}, // cell_to_forget_weight tensor 436 {0}, // cell_to_output_weight tensor 437 438 {n_cell}, // input_gate_bias tensor 439 {n_cell}, // forget_gate_bias tensor 440 {n_cell}, // cell_bias tensor 441 {n_cell}, // output_gate_bias tensor 442 443 {0, 0}, // projection_weight tensor 444 {0}, // projection_bias tensor 445 446 // Backward cell 447 {n_cell, n_input}, // input_to_input_weight tensor 448 {n_cell, n_input}, // input_to_forget_weight tensor 449 {n_cell, n_input}, // input_to_cell_weight tensor 450 {n_cell, n_input}, // input_to_output_weight tensor 451 452 {n_cell, n_output}, // recurrent_to_input_weight tensor 453 {n_cell, n_output}, // recurrent_to_forget_weight tensor 454 {n_cell, n_output}, // recurrent_to_cell_weight tensor 455 {n_cell, n_output}, // recurrent_to_output_weight tensor 456 457 {0}, // cell_to_input_weight tensor 458 {0}, // cell_to_forget_weight tensor 459 {0}, // cell_to_output_weight tensor 460 461 {n_cell}, // input_gate_bias tensor 462 {n_cell}, // forget_gate_bias tensor 463 {n_cell}, // cell_bias tensor 464 {n_cell}, // output_gate_bias tensor 465 466 {0, 0}, // projection_weight tensor 467 {0}, // projection_bias tensor 468 469 {n_batch, n_output}, // activation_state tensor 470 {n_batch, n_cell}, // cell_state tensor 471 472 {n_batch, n_output}, // activation_state tensor 473 {n_batch, n_cell}, // cell_state tensor 474 475 // TODO(b/121134029): Update tests so tensor shapes after state tensor 476 // are used. They are currently ignored by test_util. 477 {sequence_length, n_batch, 0}, // aux_input tensor 478 {n_cell, 0}, // aux_fw_input_to_input tensor 479 {n_cell, 0}, // aux_fw_input_to_forget tensor 480 {n_cell, 0}, // aux_fw_input_to_cell tensor 481 {n_cell, 0}, // aux_fw_input_to_output tensor 482 {n_cell, 0}, // aux_bw_input_to_input tensor 483 {n_cell, 0}, // aux_bw_input_to_forget tensor 484 {n_cell, 0}, // aux_bw_input_to_cell tensor 485 {n_cell, 0}, // aux_bw_input_to_output tensor 486 }); 487 488 lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, 489 -0.34550029, 0.04266912, -0.15680569, 490 -0.34856534, 0.43890524}); 491 492 lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163, 493 -0.20583314, 0.44344562, 0.22077113, 494 -0.29909778}); 495 496 lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935, 497 -0.31343272, -0.40032279, 0.44781327, 498 0.01387155, -0.35593212}); 499 500 lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829, 501 0.40525138, 0.44272184, 0.03897077, -0.1556896, 502 0.19487578}); 503 504 lstm.SetInputGateBias({0., 0., 0., 0.}); 505 506 lstm.SetCellBias({0., 0., 0., 0.}); 507 508 lstm.SetForgetGateBias({1., 1., 1., 1.}); 509 510 lstm.SetOutputGateBias({0., 0., 0., 0.}); 511 512 lstm.SetRecurrentToInputWeights( 513 {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324, 514 -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322, 515 -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296}); 516 517 lstm.SetRecurrentToCellWeights( 518 {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841, 519 -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659, 520 -0.46367589, 0.26016325, -0.03894562, -0.16368064}); 521 522 lstm.SetRecurrentToForgetWeights( 523 {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892, 524 -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436, 525 0.28053468, 0.01560611, -0.20127171, -0.01140004}); 526 527 lstm.SetRecurrentToOutputWeights( 528 {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793, 529 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421, 530 -0.51818722, -0.15390486, 0.0468148, 0.39922136}); 531 532 // Input should have n_input * sequence_length many values. 533 static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; 534 static float lstm_fw_golden_output[] = { 535 -0.02973187, 0.1229473, 0.20885126, -0.15358765, 536 -0.03716109, 0.12507336, 0.41193449, -0.20860538, 537 -0.15053082, 0.09120187, 0.24278517, -0.12222792}; 538 static float lstm_bw_golden_output[] = { 539 -0.0806187, 0.139077, 0.400476, -0.197842, -0.0332076, 0.123838, 540 0.309777, -0.17621, -0.0490733, 0.0739237, 0.067706, -0.0208124}; 541 542 float* batch0_start = lstm_input; 543 float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length(); 544 545 lstm.SetInput(0, batch0_start, batch0_end); 546 547 lstm.Invoke(); 548 549 float* fw_golden_start = lstm_fw_golden_output; 550 float* fw_golden_end = 551 fw_golden_start + lstm.num_fw_outputs() * lstm.sequence_length(); 552 std::vector<float> fw_expected; 553 fw_expected.insert(fw_expected.end(), fw_golden_start, fw_golden_end); 554 EXPECT_THAT(lstm.GetFwOutput(), 555 ElementsAreArray( 556 ArrayFloatNear(fw_expected, quantize_weights ? 1e-2 : 1e-5))); 557 558 float* bw_golden_start = lstm_bw_golden_output; 559 float* bw_golden_end = 560 bw_golden_start + lstm.num_bw_outputs() * lstm.sequence_length(); 561 std::vector<float> bw_expected; 562 bw_expected.insert(bw_expected.end(), bw_golden_start, bw_golden_end); 563 EXPECT_THAT(lstm.GetBwOutput(), 564 ElementsAreArray( 565 ArrayFloatNear(bw_expected, quantize_weights ? 1e-2 : 1e-5))); 566 } 567 568 // Same as the previous test, yet with a single merged output tensor and n_batch 569 // of 2. 570 TEST_P(LSTMOpTest, BlackBoxTestMergedOutput) { 571 const int n_batch = 2; 572 const int n_input = 2; 573 // n_cell and n_output have the same size when there is no projection. 574 const int n_cell = 4; 575 const int n_output = 4; 576 const int sequence_length = 3; 577 const bool quantize_weights = GetParam(); 578 579 BidirectionalLSTMOpModel lstm( 580 n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, 581 /*use_peephole=*/false, /*use_projection_weights=*/false, 582 /*use_projection_bias=*/false, /*merge_outputs=*/true, 583 /*use_aux_input=*/false, /*cell_clip=*/0.0, 584 /*proj_clip=*/0.0, quantize_weights, /*time_major=*/true, 585 { 586 {sequence_length, n_batch, n_input}, // input tensor 587 588 // Forward cell 589 {n_cell, n_input}, // input_to_input_weight tensor 590 {n_cell, n_input}, // input_to_forget_weight tensor 591 {n_cell, n_input}, // input_to_cell_weight tensor 592 {n_cell, n_input}, // input_to_output_weight tensor 593 594 {n_cell, n_output}, // recurrent_to_input_weight tensor 595 {n_cell, n_output}, // recurrent_to_forget_weight tensor 596 {n_cell, n_output}, // recurrent_to_cell_weight tensor 597 {n_cell, n_output}, // recurrent_to_output_weight tensor 598 599 {0}, // cell_to_input_weight tensor 600 {0}, // cell_to_forget_weight tensor 601 {0}, // cell_to_output_weight tensor 602 603 {n_cell}, // input_gate_bias tensor 604 {n_cell}, // forget_gate_bias tensor 605 {n_cell}, // cell_bias tensor 606 {n_cell}, // output_gate_bias tensor 607 608 {0, 0}, // projection_weight tensor 609 {0}, // projection_bias tensor 610 611 // Backward cell 612 {n_cell, n_input}, // input_to_input_weight tensor 613 {n_cell, n_input}, // input_to_forget_weight tensor 614 {n_cell, n_input}, // input_to_cell_weight tensor 615 {n_cell, n_input}, // input_to_output_weight tensor 616 617 {n_cell, n_output}, // recurrent_to_input_weight tensor 618 {n_cell, n_output}, // recurrent_to_forget_weight tensor 619 {n_cell, n_output}, // recurrent_to_cell_weight tensor 620 {n_cell, n_output}, // recurrent_to_output_weight tensor 621 622 {0}, // cell_to_input_weight tensor 623 {0}, // cell_to_forget_weight tensor 624 {0}, // cell_to_output_weight tensor 625 626 {n_cell}, // input_gate_bias tensor 627 {n_cell}, // forget_gate_bias tensor 628 {n_cell}, // cell_bias tensor 629 {n_cell}, // output_gate_bias tensor 630 631 {0, 0}, // projection_weight tensor 632 {0}, // projection_bias tensor 633 634 {n_batch, n_output}, // activation_state tensor 635 {n_batch, n_cell}, // cell_state tensor 636 637 {n_batch, n_output}, // activation_state tensor 638 {n_batch, n_cell}, // cell_state tensor 639 640 // TODO(b/121134029): Update tests so tensor shapes after state tensor 641 // are used. They are currently ignored by test_util. 642 {sequence_length, n_batch, 0}, // aux_input tensor 643 {n_cell, 0}, // aux_fw_input_to_input tensor 644 {n_cell, 0}, // aux_fw_input_to_forget tensor 645 {n_cell, 0}, // aux_fw_input_to_cell tensor 646 {n_cell, 0}, // aux_fw_input_to_output tensor 647 {n_cell, 0}, // aux_bw_input_to_input tensor 648 {n_cell, 0}, // aux_bw_input_to_forget tensor 649 {n_cell, 0}, // aux_bw_input_to_cell tensor 650 {n_cell, 0}, // aux_bw_input_to_output tensor 651 }); 652 653 lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, 654 -0.34550029, 0.04266912, -0.15680569, 655 -0.34856534, 0.43890524}); 656 657 lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163, 658 -0.20583314, 0.44344562, 0.22077113, 659 -0.29909778}); 660 661 lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935, 662 -0.31343272, -0.40032279, 0.44781327, 663 0.01387155, -0.35593212}); 664 665 lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829, 666 0.40525138, 0.44272184, 0.03897077, -0.1556896, 667 0.19487578}); 668 669 lstm.SetInputGateBias({0., 0., 0., 0.}); 670 671 lstm.SetCellBias({0., 0., 0., 0.}); 672 673 lstm.SetForgetGateBias({1., 1., 1., 1.}); 674 675 lstm.SetOutputGateBias({0., 0., 0., 0.}); 676 677 lstm.SetRecurrentToInputWeights( 678 {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324, 679 -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322, 680 -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296}); 681 682 lstm.SetRecurrentToCellWeights( 683 {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841, 684 -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659, 685 -0.46367589, 0.26016325, -0.03894562, -0.16368064}); 686 687 lstm.SetRecurrentToForgetWeights( 688 {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892, 689 -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436, 690 0.28053468, 0.01560611, -0.20127171, -0.01140004}); 691 692 lstm.SetRecurrentToOutputWeights( 693 {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793, 694 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421, 695 -0.51818722, -0.15390486, 0.0468148, 0.39922136}); 696 697 // Input should have n_input * sequence_length many values. 698 static float lstm_input[] = {2., 3., 2., 3., 3., 4., 3., 4., 1., 1., 1., 1.}; 699 static float lstm_fw_golden_output[] = { 700 -0.02973187, 0.1229473, 0.20885126, -0.15358765, -0.02973187, 701 0.1229473, 0.20885126, -0.15358765, -0.03716109, 0.12507336, 702 0.41193449, -0.20860538, -0.03716109, 0.12507336, 0.41193449, 703 -0.20860538, -0.15053082, 0.09120187, 0.24278517, -0.12222792, 704 -0.15053082, 0.09120187, 0.24278517, -0.12222792}; 705 static float lstm_bw_golden_output[] = { 706 -0.0806187, 0.139077, 0.400476, -0.197842, -0.0806187, 0.139077, 707 0.400476, -0.197842, -0.0332076, 0.123838, 0.309777, -0.17621, 708 -0.0332076, 0.123838, 0.309777, -0.17621, -0.0490733, 0.0739237, 709 0.067706, -0.0208124, -0.0490733, 0.0739237, 0.067706, -0.0208124}; 710 711 float* batch0_start = lstm_input; 712 float* batch0_end = batch0_start + lstm.num_inputs() * lstm.num_batches() * 713 lstm.sequence_length(); 714 715 lstm.SetInput(0, batch0_start, batch0_end); 716 717 lstm.Invoke(); 718 719 std::vector<float> merged_expected; 720 for (int k = 0; k < lstm.sequence_length() * lstm.num_batches(); k++) { 721 merged_expected.insert( 722 merged_expected.end(), 723 lstm_fw_golden_output + k * lstm.num_fw_outputs(), 724 lstm_fw_golden_output + (k + 1) * lstm.num_fw_outputs()); 725 merged_expected.insert( 726 merged_expected.end(), 727 lstm_bw_golden_output + k * lstm.num_bw_outputs(), 728 lstm_bw_golden_output + (k + 1) * lstm.num_bw_outputs()); 729 } 730 EXPECT_THAT(lstm.GetFwOutput(), 731 ElementsAreArray(ArrayFloatNear(merged_expected, 732 quantize_weights ? 1e-2 : 1e-5))); 733 } 734 735 TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClippingReverse) { 736 const int n_batch = 1; 737 const int n_input = 2; 738 // n_cell and n_output have the same size when there is no projection. 739 const int n_cell = 4; 740 const int n_output = 4; 741 const int sequence_length = 3; 742 743 BidirectionalLSTMOpModel lstm( 744 n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, 745 /*use_peephole=*/false, /*use_projection_weights=*/false, 746 /*use_projection_bias=*/false, /*merge_outputs=*/false, 747 /*use_aux_input=*/false, /*cell_clip=*/0.0, 748 /*proj_clip=*/0.0, /*quantize_weights=*/false, /*time_major=*/true, 749 { 750 {sequence_length, n_batch, n_input}, // input tensor 751 752 // Forward cell 753 {n_cell, n_input}, // input_to_input_weight tensor 754 {n_cell, n_input}, // input_to_forget_weight tensor 755 {n_cell, n_input}, // input_to_cell_weight tensor 756 {n_cell, n_input}, // input_to_output_weight tensor 757 758 {n_cell, n_output}, // recurrent_to_input_weight tensor 759 {n_cell, n_output}, // recurrent_to_forget_weight tensor 760 {n_cell, n_output}, // recurrent_to_cell_weight tensor 761 {n_cell, n_output}, // recurrent_to_output_weight tensor 762 763 {0}, // cell_to_input_weight tensor 764 {0}, // cell_to_forget_weight tensor 765 {0}, // cell_to_output_weight tensor 766 767 {n_cell}, // input_gate_bias tensor 768 {n_cell}, // forget_gate_bias tensor 769 {n_cell}, // cell_bias tensor 770 {n_cell}, // output_gate_bias tensor 771 772 {0, 0}, // projection_weight tensor 773 {0}, // projection_bias tensor 774 775 // Backward cell 776 {n_cell, n_input}, // input_to_input_weight tensor 777 {n_cell, n_input}, // input_to_forget_weight tensor 778 {n_cell, n_input}, // input_to_cell_weight tensor 779 {n_cell, n_input}, // input_to_output_weight tensor 780 781 {n_cell, n_output}, // recurrent_to_input_weight tensor 782 {n_cell, n_output}, // recurrent_to_forget_weight tensor 783 {n_cell, n_output}, // recurrent_to_cell_weight tensor 784 {n_cell, n_output}, // recurrent_to_output_weight tensor 785 786 {0}, // cell_to_input_weight tensor 787 {0}, // cell_to_forget_weight tensor 788 {0}, // cell_to_output_weight tensor 789 790 {n_cell}, // input_gate_bias tensor 791 {n_cell}, // forget_gate_bias tensor 792 {n_cell}, // cell_bias tensor 793 {n_cell}, // output_gate_bias tensor 794 795 {0, 0}, // projection_weight tensor 796 {0}, // projection_bias tensor 797 798 {n_batch, n_output}, // activation_state tensor 799 {n_batch, n_cell}, // cell_state tensor 800 801 {n_batch, n_output}, // activation_state tensor 802 {n_batch, n_cell}, // cell_state tensor 803 804 // TODO(b/121134029): Update tests so tensor shapes after state tensor 805 // are used. They are currently ignored by test_util. 806 {sequence_length, n_batch, 0}, // aux_input tensor 807 {n_cell, 0}, // aux_fw_input_to_input tensor 808 {n_cell, 0}, // aux_fw_input_to_forget tensor 809 {n_cell, 0}, // aux_fw_input_to_cell tensor 810 {n_cell, 0}, // aux_fw_input_to_output tensor 811 {n_cell, 0}, // aux_bw_input_to_input tensor 812 {n_cell, 0}, // aux_bw_input_to_forget tensor 813 {n_cell, 0}, // aux_bw_input_to_cell tensor 814 {n_cell, 0}, // aux_bw_input_to_output tensor 815 }); 816 817 lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, 818 -0.34550029, 0.04266912, -0.15680569, 819 -0.34856534, 0.43890524}); 820 821 lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163, 822 -0.20583314, 0.44344562, 0.22077113, 823 -0.29909778}); 824 825 lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935, 826 -0.31343272, -0.40032279, 0.44781327, 827 0.01387155, -0.35593212}); 828 829 lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829, 830 0.40525138, 0.44272184, 0.03897077, -0.1556896, 831 0.19487578}); 832 833 lstm.SetInputGateBias({0., 0., 0., 0.}); 834 835 lstm.SetCellBias({0., 0., 0., 0.}); 836 837 lstm.SetForgetGateBias({1., 1., 1., 1.}); 838 839 lstm.SetOutputGateBias({0., 0., 0., 0.}); 840 841 lstm.SetRecurrentToInputWeights( 842 {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324, 843 -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322, 844 -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296}); 845 846 lstm.SetRecurrentToCellWeights( 847 {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841, 848 -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659, 849 -0.46367589, 0.26016325, -0.03894562, -0.16368064}); 850 851 lstm.SetRecurrentToForgetWeights( 852 {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892, 853 -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436, 854 0.28053468, 0.01560611, -0.20127171, -0.01140004}); 855 856 lstm.SetRecurrentToOutputWeights( 857 {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793, 858 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421, 859 -0.51818722, -0.15390486, 0.0468148, 0.39922136}); 860 861 // Input should have n_input * sequence_length many values. 862 // Check reversed inputs. 863 static float lstm_input_reversed[] = {1., 1., 3., 4., 2., 3.}; 864 static float lstm_fw_golden_output[] = { 865 -0.02973187, 0.1229473, 0.20885126, -0.15358765, 866 -0.03716109, 0.12507336, 0.41193449, -0.20860538, 867 -0.15053082, 0.09120187, 0.24278517, -0.12222792}; 868 static float lstm_bw_golden_output[] = { 869 -0.0806187, 0.139077, 0.400476, -0.197842, -0.0332076, 0.123838, 870 0.309777, -0.17621, -0.0490733, 0.0739237, 0.067706, -0.0208124}; 871 872 float* batch0_start = lstm_input_reversed; 873 float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length(); 874 875 lstm.SetInput(0, batch0_start, batch0_end); 876 877 lstm.Invoke(); 878 879 std::vector<float> fw_expected; 880 for (int s = 0; s < lstm.sequence_length(); s++) { 881 float* fw_golden_start = lstm_fw_golden_output + s * lstm.num_fw_outputs(); 882 float* fw_golden_end = fw_golden_start + lstm.num_fw_outputs(); 883 fw_expected.insert(fw_expected.begin(), fw_golden_start, fw_golden_end); 884 } 885 EXPECT_THAT(lstm.GetBwOutput(), 886 ElementsAreArray(ArrayFloatNear(fw_expected))); 887 888 std::vector<float> bw_expected; 889 for (int s = 0; s < lstm.sequence_length(); s++) { 890 float* bw_golden_start = lstm_bw_golden_output + s * lstm.num_bw_outputs(); 891 float* bw_golden_end = bw_golden_start + lstm.num_bw_outputs(); 892 bw_expected.insert(bw_expected.begin(), bw_golden_start, bw_golden_end); 893 } 894 EXPECT_THAT(lstm.GetFwOutput(), 895 ElementsAreArray(ArrayFloatNear(bw_expected))); 896 } 897 898 TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { 899 const int n_batch = 1; 900 const int n_input = 2; 901 // n_cell and n_output have the same size when there is no projection. 902 const int n_cell = 4; 903 const int n_output = 4; 904 const int sequence_length = 3; 905 906 BidirectionalLSTMOpModel lstm( 907 n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true, 908 /*use_peephole=*/true, /*use_projection_weights=*/false, 909 /*use_projection_bias=*/false, /*merge_outputs=*/false, 910 /*use_aux_input=*/false, /*cell_clip=*/0.0, 911 /*proj_clip=*/0.0, /*quantize_weights=*/false, /*time_major=*/true, 912 { 913 {sequence_length, n_batch, n_input}, // input tensor 914 915 {0, 0}, // input_to_input_weight tensor 916 {n_cell, n_input}, // input_to_forget_weight tensor 917 {n_cell, n_input}, // input_to_cell_weight tensor 918 {n_cell, n_input}, // input_to_output_weight tensor 919 920 {0, 0}, // recurrent_to_input_weight tensor 921 {n_cell, n_output}, // recurrent_to_forget_weight tensor 922 {n_cell, n_output}, // recurrent_to_cell_weight tensor 923 {n_cell, n_output}, // recurrent_to_output_weight tensor 924 925 {0}, // cell_to_input_weight tensor 926 {n_cell}, // cell_to_forget_weight tensor 927 {n_cell}, // cell_to_output_weight tensor 928 929 {0}, // input_gate_bias tensor 930 {n_cell}, // forget_gate_bias tensor 931 {n_cell}, // cell_bias tensor 932 {n_cell}, // output_gate_bias tensor 933 934 {0, 0}, // projection_weight tensor 935 {0}, // projection_bias tensor 936 937 {0, 0}, // input_to_input_weight tensor 938 {n_cell, n_input}, // input_to_forget_weight tensor 939 {n_cell, n_input}, // input_to_cell_weight tensor 940 {n_cell, n_input}, // input_to_output_weight tensor 941 942 {0, 0}, // recurrent_to_input_weight tensor 943 {n_cell, n_output}, // recurrent_to_forget_weight tensor 944 {n_cell, n_output}, // recurrent_to_cell_weight tensor 945 {n_cell, n_output}, // recurrent_to_output_weight tensor 946 947 {0}, // cell_to_input_weight tensor 948 {n_cell}, // cell_to_forget_weight tensor 949 {n_cell}, // cell_to_output_weight tensor 950 951 {0}, // input_gate_bias tensor 952 {n_cell}, // forget_gate_bias tensor 953 {n_cell}, // cell_bias tensor 954 {n_cell}, // output_gate_bias tensor 955 956 {0, 0}, // projection_weight tensor 957 {0}, // projection_bias tensor 958 959 {n_batch, n_output}, // activation_state tensor 960 {n_batch, n_cell}, // cell_state tensor 961 962 {n_batch, n_output}, // activation_state tensor 963 {n_batch, n_cell}, // cell_state tensor 964 965 // TODO(b/121134029): Update tests so tensor shapes after state tensor 966 // are used. They are currently ignored by test_util. 967 {sequence_length, n_batch, 0}, // aux_input tensor 968 {n_cell, 0}, // aux_fw_input_to_input tensor 969 {n_cell, 0}, // aux_fw_input_to_forget tensor 970 {n_cell, 0}, // aux_fw_input_to_cell tensor 971 {n_cell, 0}, // aux_fw_input_to_output tensor 972 {n_cell, 0}, // aux_bw_input_to_input tensor 973 {n_cell, 0}, // aux_bw_input_to_forget tensor 974 {n_cell, 0}, // aux_bw_input_to_cell tensor 975 {n_cell, 0}, // aux_bw_input_to_output tensor 976 }); 977 978 lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781, 979 0.04717243, 0.48944736, -0.38535351, 980 -0.17212132}); 981 982 lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988, 983 -0.3633365, -0.22755712, 0.28253698, 0.24407166, 984 0.33826375}); 985 986 lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593, 987 -0.09426838, -0.44257352, 0.54939759, 988 0.01533556, 0.42751634}); 989 990 lstm.SetCellBias({0., 0., 0., 0.}); 991 992 lstm.SetForgetGateBias({1., 1., 1., 1.}); 993 994 lstm.SetOutputGateBias({0., 0., 0., 0.}); 995 996 lstm.SetRecurrentToCellWeights( 997 {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711, 998 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004, 999 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288, 1000 0.21193194}); 1001 1002 lstm.SetRecurrentToForgetWeights( 1003 {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827, 1004 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795, 1005 -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349}); 1006 1007 lstm.SetRecurrentToOutputWeights( 1008 {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908, 1009 -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835, 1010 0.50248802, 0.26114327, -0.43736315, 0.33149987}); 1011 1012 lstm.SetCellToForgetWeights( 1013 {0.47485286, -0.51955009, -0.24458408, 0.31544167}); 1014 lstm.SetCellToOutputWeights( 1015 {-0.17135078, 0.82760304, 0.85573703, -0.77109635}); 1016 1017 static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; 1018 static float lstm_fw_golden_output[] = { 1019 -0.36444446, -0.00352185, 0.12886585, -0.05163646, 1020 -0.42312205, -0.01218222, 0.24201041, -0.08124574, 1021 -0.358325, -0.04621704, 0.21641694, -0.06471302}; 1022 static float lstm_bw_golden_output[] = { 1023 -0.401685, -0.0232794, 0.288642, -0.123074, -0.42915, -0.00871577, 1024 0.20912, -0.103567, -0.166398, -0.00486649, 0.0697471, -0.0537578}; 1025 1026 float* batch0_start = lstm_input; 1027 float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length(); 1028 1029 lstm.SetInput(0, batch0_start, batch0_end); 1030 1031 lstm.Invoke(); 1032 1033 float* fw_golden_start = lstm_fw_golden_output; 1034 float* fw_golden_end = 1035 fw_golden_start + lstm.num_fw_outputs() * lstm.sequence_length(); 1036 std::vector<float> fw_expected; 1037 fw_expected.insert(fw_expected.end(), fw_golden_start, fw_golden_end); 1038 EXPECT_THAT(lstm.GetFwOutput(), 1039 ElementsAreArray(ArrayFloatNear(fw_expected))); 1040 1041 float* bw_golden_start = lstm_bw_golden_output; 1042 float* bw_golden_end = 1043 bw_golden_start + lstm.num_bw_outputs() * lstm.sequence_length(); 1044 std::vector<float> bw_expected; 1045 bw_expected.insert(bw_expected.end(), bw_golden_start, bw_golden_end); 1046 EXPECT_THAT(lstm.GetBwOutput(), 1047 ElementsAreArray(ArrayFloatNear(bw_expected))); 1048 } 1049 1050 TEST(LSTMOpTest, 1051 BlackBoxTestWithCifgWithPeepholeNoProjectionNoClippingReversed) { 1052 const int n_batch = 1; 1053 const int n_input = 2; 1054 // n_cell and n_output have the same size when there is no projection. 1055 const int n_cell = 4; 1056 const int n_output = 4; 1057 const int sequence_length = 3; 1058 1059 BidirectionalLSTMOpModel lstm( 1060 n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true, 1061 /*use_peephole=*/true, /*use_projection_weights=*/false, 1062 /*use_projection_bias=*/false, /*merge_outputs=*/false, 1063 /*use_aux_input=*/false, /*cell_clip=*/0.0, 1064 /*proj_clip=*/0.0, /*quantize_weights=*/false, /*time_major=*/true, 1065 { 1066 {sequence_length, n_batch, n_input}, // input tensor 1067 1068 {0, 0}, // input_to_input_weight tensor 1069 {n_cell, n_input}, // input_to_forget_weight tensor 1070 {n_cell, n_input}, // input_to_cell_weight tensor 1071 {n_cell, n_input}, // input_to_output_weight tensor 1072 1073 {0, 0}, // recurrent_to_input_weight tensor 1074 {n_cell, n_output}, // recurrent_to_forget_weight tensor 1075 {n_cell, n_output}, // recurrent_to_cell_weight tensor 1076 {n_cell, n_output}, // recurrent_to_output_weight tensor 1077 1078 {0}, // cell_to_input_weight tensor 1079 {n_cell}, // cell_to_forget_weight tensor 1080 {n_cell}, // cell_to_output_weight tensor 1081 1082 {0}, // input_gate_bias tensor 1083 {n_cell}, // forget_gate_bias tensor 1084 {n_cell}, // cell_bias tensor 1085 {n_cell}, // output_gate_bias tensor 1086 1087 {0, 0}, // projection_weight tensor 1088 {0}, // projection_bias tensor 1089 1090 {0, 0}, // input_to_input_weight tensor 1091 {n_cell, n_input}, // input_to_forget_weight tensor 1092 {n_cell, n_input}, // input_to_cell_weight tensor 1093 {n_cell, n_input}, // input_to_output_weight tensor 1094 1095 {0, 0}, // recurrent_to_input_weight tensor 1096 {n_cell, n_output}, // recurrent_to_forget_weight tensor 1097 {n_cell, n_output}, // recurrent_to_cell_weight tensor 1098 {n_cell, n_output}, // recurrent_to_output_weight tensor 1099 1100 {0}, // cell_to_input_weight tensor 1101 {n_cell}, // cell_to_forget_weight tensor 1102 {n_cell}, // cell_to_output_weight tensor 1103 1104 {0}, // input_gate_bias tensor 1105 {n_cell}, // forget_gate_bias tensor 1106 {n_cell}, // cell_bias tensor 1107 {n_cell}, // output_gate_bias tensor 1108 1109 {0, 0}, // projection_weight tensor 1110 {0}, // projection_bias tensor 1111 1112 {n_batch, n_output}, // activation_state tensor 1113 {n_batch, n_cell}, // cell_state tensor 1114 1115 {n_batch, n_output}, // activation_state tensor 1116 {n_batch, n_cell}, // cell_state tensor 1117 1118 // TODO(b/121134029): Update tests so tensor shapes after state tensor 1119 // are used. They are currently ignored by test_util. 1120 {sequence_length, n_batch, 0}, // aux_input tensor 1121 {n_cell, 0}, // aux_fw_input_to_input tensor 1122 {n_cell, 0}, // aux_fw_input_to_forget tensor 1123 {n_cell, 0}, // aux_fw_input_to_cell tensor 1124 {n_cell, 0}, // aux_fw_input_to_output tensor 1125 {n_cell, 0}, // aux_bw_input_to_input tensor 1126 {n_cell, 0}, // aux_bw_input_to_forget tensor 1127 {n_cell, 0}, // aux_bw_input_to_cell tensor 1128 {n_cell, 0}, // aux_bw_input_to_output tensor 1129 }); 1130 1131 lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781, 1132 0.04717243, 0.48944736, -0.38535351, 1133 -0.17212132}); 1134 1135 lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988, 1136 -0.3633365, -0.22755712, 0.28253698, 0.24407166, 1137 0.33826375}); 1138 1139 lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593, 1140 -0.09426838, -0.44257352, 0.54939759, 1141 0.01533556, 0.42751634}); 1142 1143 lstm.SetCellBias({0., 0., 0., 0.}); 1144 1145 lstm.SetForgetGateBias({1., 1., 1., 1.}); 1146 1147 lstm.SetOutputGateBias({0., 0., 0., 0.}); 1148 1149 lstm.SetRecurrentToCellWeights( 1150 {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711, 1151 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004, 1152 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288, 1153 0.21193194}); 1154 1155 lstm.SetRecurrentToForgetWeights( 1156 {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827, 1157 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795, 1158 -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349}); 1159 1160 lstm.SetRecurrentToOutputWeights( 1161 {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908, 1162 -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835, 1163 0.50248802, 0.26114327, -0.43736315, 0.33149987}); 1164 1165 lstm.SetCellToForgetWeights( 1166 {0.47485286, -0.51955009, -0.24458408, 0.31544167}); 1167 lstm.SetCellToOutputWeights( 1168 {-0.17135078, 0.82760304, 0.85573703, -0.77109635}); 1169 1170 static float lstm_input_reversed[] = {1., 1., 3., 4., 2., 3.}; 1171 static float lstm_fw_golden_output[] = { 1172 -0.36444446, -0.00352185, 0.12886585, -0.05163646, 1173 -0.42312205, -0.01218222, 0.24201041, -0.08124574, 1174 -0.358325, -0.04621704, 0.21641694, -0.06471302}; 1175 static float lstm_bw_golden_output[] = { 1176 -0.401685, -0.0232794, 0.288642, -0.123074, -0.42915, -0.00871577, 1177 0.20912, -0.103567, -0.166398, -0.00486649, 0.0697471, -0.0537578}; 1178 1179 float* batch0_start = lstm_input_reversed; 1180 float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length(); 1181 1182 lstm.SetInput(0, batch0_start, batch0_end); 1183 1184 lstm.Invoke(); 1185 1186 std::vector<float> fw_expected; 1187 for (int s = 0; s < lstm.sequence_length(); s++) { 1188 float* fw_golden_start = lstm_fw_golden_output + s * lstm.num_fw_outputs(); 1189 float* fw_golden_end = fw_golden_start + lstm.num_fw_outputs(); 1190 fw_expected.insert(fw_expected.begin(), fw_golden_start, fw_golden_end); 1191 } 1192 EXPECT_THAT(lstm.GetBwOutput(), 1193 ElementsAreArray(ArrayFloatNear(fw_expected))); 1194 1195 std::vector<float> bw_expected; 1196 for (int s = 0; s < lstm.sequence_length(); s++) { 1197 float* bw_golden_start = lstm_bw_golden_output + s * lstm.num_bw_outputs(); 1198 float* bw_golden_end = bw_golden_start + lstm.num_bw_outputs(); 1199 bw_expected.insert(bw_expected.begin(), bw_golden_start, bw_golden_end); 1200 } 1201 EXPECT_THAT(lstm.GetFwOutput(), 1202 ElementsAreArray(ArrayFloatNear(bw_expected))); 1203 } 1204 1205 TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { 1206 const int n_batch = 2; 1207 const int n_input = 5; 1208 const int n_cell = 20; 1209 const int n_output = 16; 1210 const int sequence_length = 4; 1211 1212 BidirectionalLSTMOpModel lstm( 1213 n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, 1214 /*use_peephole=*/true, /*use_projection_weights=*/true, 1215 /*use_projection_bias=*/false, /*merge_outputs=*/false, 1216 /*use_aux_input=*/false, /*cell_clip=*/0.0, 1217 /*proj_clip=*/0.0, /*quantize_weights=*/false, /*time_major=*/true, 1218 { 1219 {sequence_length, n_batch, n_input}, // input tensor 1220 1221 {n_cell, n_input}, // input_to_input_weight tensor 1222 {n_cell, n_input}, // input_to_forget_weight tensor 1223 {n_cell, n_input}, // input_to_cell_weight tensor 1224 {n_cell, n_input}, // input_to_output_weight tensor 1225 1226 {n_cell, n_output}, // recurrent_to_input_weight tensor 1227 {n_cell, n_output}, // recurrent_to_forget_weight tensor 1228 {n_cell, n_output}, // recurrent_to_cell_weight tensor 1229 {n_cell, n_output}, // recurrent_to_output_weight tensor 1230 1231 {n_cell}, // cell_to_input_weight tensor 1232 {n_cell}, // cell_to_forget_weight tensor 1233 {n_cell}, // cell_to_output_weight tensor 1234 1235 {n_cell}, // input_gate_bias tensor 1236 {n_cell}, // forget_gate_bias tensor 1237 {n_cell}, // cell_bias tensor 1238 {n_cell}, // output_gate_bias tensor 1239 1240 {n_output, n_cell}, // projection_weight tensor 1241 {0}, // projection_bias tensor 1242 1243 {n_cell, n_input}, // input_to_input_weight tensor 1244 {n_cell, n_input}, // input_to_forget_weight tensor 1245 {n_cell, n_input}, // input_to_cell_weight tensor 1246 {n_cell, n_input}, // input_to_output_weight tensor 1247 1248 {n_cell, n_output}, // recurrent_to_input_weight tensor 1249 {n_cell, n_output}, // recurrent_to_forget_weight tensor 1250 {n_cell, n_output}, // recurrent_to_cell_weight tensor 1251 {n_cell, n_output}, // recurrent_to_output_weight tensor 1252 1253 {n_cell}, // cell_to_input_weight tensor 1254 {n_cell}, // cell_to_forget_weight tensor 1255 {n_cell}, // cell_to_output_weight tensor 1256 1257 {n_cell}, // input_gate_bias tensor 1258 {n_cell}, // forget_gate_bias tensor 1259 {n_cell}, // cell_bias tensor 1260 {n_cell}, // output_gate_bias tensor 1261 1262 {n_output, n_cell}, // projection_weight tensor 1263 {0}, // projection_bias tensor 1264 1265 {n_batch, n_output}, // activation_state tensor 1266 {n_batch, n_cell}, // cell_state tensor 1267 1268 {n_batch, n_output}, // activation_state tensor 1269 {n_batch, n_cell}, // cell_state tensor 1270 1271 // TODO(b/121134029): Update tests so tensor shapes after state tensor 1272 // are used. They are currently ignored by test_util. 1273 {sequence_length, n_batch, 0}, // aux_input tensor 1274 {n_cell, 0}, // aux_fw_input_to_input tensor 1275 {n_cell, 0}, // aux_fw_input_to_forget tensor 1276 {n_cell, 0}, // aux_fw_input_to_cell tensor 1277 {n_cell, 0}, // aux_fw_input_to_output tensor 1278 {n_cell, 0}, // aux_bw_input_to_input tensor 1279 {n_cell, 0}, // aux_bw_input_to_forget tensor 1280 {n_cell, 0}, // aux_bw_input_to_cell tensor 1281 {n_cell, 0}, // aux_bw_input_to_output tensor 1282 }); 1283 1284 lstm.SetInputToInputWeights( 1285 {0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, 1286 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048, 1287 -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385, 1288 -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282, 1289 -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627, 1290 -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226, 1291 -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059, 1292 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698, 1293 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206, 1294 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585, 1295 -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063, 1296 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603, 1297 -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682, 1298 -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988, 1299 -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764, 1300 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476, 1301 -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012, 1302 -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604, 1303 -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654, 1304 -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677}); 1305 1306 lstm.SetInputToForgetWeights( 1307 {-0.0018401089, -0.004852237, 0.03698424, 0.014181704, 0.028273236, 1308 -0.016726194, -0.05249759, -0.10204261, 0.00861066, -0.040979505, 1309 -0.009899187, 0.01923892, -0.028177269, -0.08535103, -0.14585495, 1310 0.10662567, -0.01909731, -0.017883534, -0.0047269356, -0.045103323, 1311 0.0030784295, 0.076784775, 0.07463696, 0.094531395, 0.0814421, 1312 -0.12257899, -0.033945758, -0.031303465, 0.045630626, 0.06843887, 1313 -0.13492945, -0.012480007, -0.0811829, -0.07224499, -0.09628791, 1314 0.045100946, 0.0012300825, 0.013964662, 0.099372394, 0.02543059, 1315 0.06958324, 0.034257296, 0.0482646, 0.06267997, 0.052625068, 1316 0.12784666, 0.07077897, 0.025725935, 0.04165009, 0.07241905, 1317 0.018668644, -0.037377294, -0.06277783, -0.08833636, -0.040120605, 1318 -0.011405586, -0.007808335, -0.010301386, -0.005102167, 0.027717464, 1319 0.05483423, 0.11449111, 0.11289652, 0.10939839, 0.13396506, 1320 -0.08402166, -0.01901462, -0.044678304, -0.07720565, 0.014350063, 1321 -0.11757958, -0.0652038, -0.08185733, -0.076754324, -0.092614375, 1322 0.10405491, 0.052960336, 0.035755895, 0.035839386, -0.012540553, 1323 0.036881298, 0.02913376, 0.03420159, 0.05448447, -0.054523353, 1324 0.02582715, 0.02327355, -0.011857179, -0.0011980024, -0.034641717, 1325 -0.026125094, -0.17582615, -0.15923657, -0.27486774, -0.0006143371, 1326 0.0001771948, -8.470171e-05, 0.02651807, 0.045790765, 0.06956496}); 1327 1328 lstm.SetInputToCellWeights( 1329 {-0.04580283, -0.09549462, -0.032418985, -0.06454633, 1330 -0.043528453, 0.043018587, -0.049152344, -0.12418144, 1331 -0.078985475, -0.07596889, 0.019484362, -0.11434962, 1332 -0.0074034138, -0.06314844, -0.092981495, 0.0062155537, 1333 -0.025034338, -0.0028890965, 0.048929527, 0.06235075, 1334 0.10665918, -0.032036792, -0.08505916, -0.10843358, 1335 -0.13002433, -0.036816437, -0.02130134, -0.016518239, 1336 0.0047691227, -0.0025825808, 0.066017866, 0.029991534, 1337 -0.10652836, -0.1037554, -0.13056071, -0.03266643, 1338 -0.033702414, -0.006473424, -0.04611692, 0.014419339, 1339 -0.025174323, 0.0396852, 0.081777506, 0.06157468, 1340 0.10210095, -0.009658194, 0.046511717, 0.03603906, 1341 0.0069369148, 0.015960095, -0.06507666, 0.09551598, 1342 0.053568836, 0.06408714, 0.12835667, -0.008714329, 1343 -0.20211966, -0.12093674, 0.029450472, 0.2849013, 1344 -0.029227901, 0.1164364, -0.08560263, 0.09941786, 1345 -0.036999565, -0.028842626, -0.0033637602, -0.017012902, 1346 -0.09720865, -0.11193351, -0.029155117, -0.017936034, 1347 -0.009768936, -0.04223324, -0.036159635, 0.06505112, 1348 -0.021742892, -0.023377212, -0.07221364, -0.06430552, 1349 0.05453865, 0.091149814, 0.06387331, 0.007518393, 1350 0.055960953, 0.069779344, 0.046411168, 0.10509911, 1351 0.07463894, 0.0075130584, 0.012850982, 0.04555431, 1352 0.056955688, 0.06555285, 0.050801456, -0.009862683, 1353 0.00826772, -0.026555609, -0.0073611983, -0.0014897042}); 1354 1355 lstm.SetInputToOutputWeights( 1356 {-0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918, 1357 -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534, 1358 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722, 1359 -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761, 1360 -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394, 1361 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154, 1362 -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135, 1363 -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564, 1364 -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047, 1365 -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304, 1366 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946, 1367 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646, 1368 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813, 1369 -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403, 1370 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415, 1371 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495, 1372 -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158, 1373 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295, 1374 -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739, 1375 -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956}); 1376 1377 lstm.SetInputGateBias( 1378 {0.02234832, 0.14757581, 0.18176508, 0.10380666, 0.053110216, 1379 -0.06928846, -0.13942584, -0.11816189, 0.19483899, 0.03652339, 1380 -0.10250295, 0.036714908, -0.18426876, 0.036065217, 0.21810818, 1381 0.02383196, -0.043370757, 0.08690144, -0.04444982, 0.00030581196}); 1382 1383 lstm.SetForgetGateBias({0.035185695, -0.042891346, -0.03032477, 0.23027696, 1384 0.11098921, 0.15378423, 0.09263801, 0.09790885, 1385 0.09508917, 0.061199076, 0.07665568, -0.015443159, 1386 -0.03499149, 0.046190713, 0.08895977, 0.10899629, 1387 0.40694186, 0.06030037, 0.012413437, -0.06108739}); 1388 1389 lstm.SetCellBias({-0.024379363, 0.0055531194, 0.23377132, 0.033463873, 1390 -0.1483596, -0.10639995, -0.091433935, 0.058573797, 1391 -0.06809782, -0.07889636, -0.043246906, -0.09829136, 1392 -0.4279842, 0.034901652, 0.18797937, 0.0075234566, 1393 0.016178843, 0.1749513, 0.13975595, 0.92058027}); 1394 1395 lstm.SetOutputGateBias( 1396 {0.046159424, -0.0012809046, 0.03563469, 0.12648113, 0.027195795, 1397 0.35373217, -0.018957434, 0.008907322, -0.0762701, 0.12018895, 1398 0.04216877, 0.0022856654, 0.040952638, 0.3147856, 0.08225149, 1399 -0.057416286, -0.14995944, -0.008040261, 0.13208859, 0.029760877}); 1400 1401 lstm.SetRecurrentToInputWeights( 1402 {-0.001374326, -0.078856036, 0.10672688, 0.029162422, 1403 -0.11585556, 0.02557986, -0.13446963, -0.035785314, 1404 -0.01244275, 0.025961924, -0.02337298, -0.044228926, 1405 -0.055839065, -0.046598054, -0.010546039, -0.06900766, 1406 0.027239809, 0.022582639, -0.013296484, -0.05459212, 1407 0.08981, -0.045407712, 0.08682226, -0.06867011, 1408 -0.14390695, -0.02916037, 0.000996957, 0.091420636, 1409 0.14283475, -0.07390571, -0.06402044, 0.062524505, 1410 -0.093129106, 0.04860203, -0.08364217, -0.08119002, 1411 0.009352075, 0.22920375, 0.0016303885, 0.11583097, 1412 -0.13732095, 0.012405723, -0.07551853, 0.06343048, 1413 0.12162708, -0.031923793, -0.014335606, 0.01790974, 1414 -0.10650317, -0.0724401, 0.08554849, -0.05727212, 1415 0.06556731, -0.042729504, -0.043227166, 0.011683251, 1416 -0.013082158, -0.029302018, -0.010899579, -0.062036745, 1417 -0.022509435, -0.00964907, -0.01567329, 0.04260106, 1418 -0.07787477, -0.11576462, 0.017356863, 0.048673786, 1419 -0.017577527, -0.05527947, -0.082487635, -0.040137455, 1420 -0.10820036, -0.04666372, 0.022746278, -0.07851417, 1421 0.01068115, 0.032956902, 0.022433773, 0.0026891115, 1422 0.08944216, -0.0685835, 0.010513544, 0.07228705, 1423 0.02032331, -0.059686817, -0.0005566496, -0.086984694, 1424 0.040414046, -0.1380399, 0.094208956, -0.05722982, 1425 0.012092817, -0.04989123, -0.086576, -0.003399834, 1426 -0.04696032, -0.045747425, 0.10091314, 0.048676282, 1427 -0.029037097, 0.031399418, -0.0040285117, 0.047237843, 1428 0.09504992, 0.041799378, -0.049185462, -0.031518843, 1429 -0.10516937, 0.026374253, 0.10058866, -0.0033195973, 1430 -0.041975245, 0.0073591834, 0.0033782164, -0.004325073, 1431 -0.10167381, 0.042500053, -0.01447153, 0.06464186, 1432 -0.017142897, 0.03312627, 0.009205989, 0.024138335, 1433 -0.011337001, 0.035530265, -0.010912711, 0.0706555, 1434 -0.005894094, 0.051841937, -0.1401738, -0.02351249, 1435 0.0365468, 0.07590991, 0.08838724, 0.021681072, 1436 -0.10086113, 0.019608743, -0.06195883, 0.077335775, 1437 0.023646897, -0.095322326, 0.02233014, 0.09756986, 1438 -0.048691444, -0.009579111, 0.07595467, 0.11480546, 1439 -0.09801813, 0.019894179, 0.08502348, 0.004032281, 1440 0.037211012, 0.068537936, -0.048005626, -0.091520436, 1441 -0.028379958, -0.01556313, 0.06554592, -0.045599163, 1442 -0.01672207, -0.020169014, -0.011877351, -0.20212261, 1443 0.010889619, 0.0047078193, 0.038385306, 0.08540671, 1444 -0.017140968, -0.0035865551, 0.016678626, 0.005633034, 1445 0.015963363, 0.00871737, 0.060130805, 0.028611384, 1446 0.10109069, -0.015060172, -0.07894427, 0.06401885, 1447 0.011584063, -0.024466386, 0.0047652307, -0.09041358, 1448 0.030737216, -0.0046374933, 0.14215417, -0.11823516, 1449 0.019899689, 0.006106124, -0.027092824, 0.0786356, 1450 0.05052217, -0.058925, -0.011402121, -0.024987547, 1451 -0.0013661642, -0.06832946, -0.015667673, -0.1083353, 1452 -0.00096863037, -0.06988685, -0.053350925, -0.027275559, 1453 -0.033664223, -0.07978348, -0.025200296, -0.017207067, 1454 -0.058403496, -0.055697463, 0.005798788, 0.12965427, 1455 -0.062582195, 0.0013350133, -0.10482091, 0.0379771, 1456 0.072521195, -0.0029455067, -0.13797039, -0.03628521, 1457 0.013806405, -0.017858358, -0.01008298, -0.07700066, 1458 -0.017081132, 0.019358726, 0.0027079724, 0.004635139, 1459 0.062634714, -0.02338735, -0.039547626, -0.02050681, 1460 0.03385117, -0.083611414, 0.002862572, -0.09421313, 1461 0.058618143, -0.08598433, 0.00972939, 0.023867095, 1462 -0.053934585, -0.023203006, 0.07452513, -0.048767887, 1463 -0.07314807, -0.056307215, -0.10433547, -0.06440842, 1464 0.04328182, 0.04389765, -0.020006588, -0.09076438, 1465 -0.11652589, -0.021705797, 0.03345259, -0.010329105, 1466 -0.025767034, 0.013057034, -0.07316461, -0.10145612, 1467 0.06358255, 0.18531723, 0.07759293, 0.12006465, 1468 0.1305557, 0.058638252, -0.03393652, 0.09622831, 1469 -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845, 1470 -0.005644518, 0.06857898, -0.12598175, -0.035084512, 1471 0.03156317, -0.12794146, -0.031963028, 0.04692781, 1472 0.030070418, 0.0071660685, -0.095516115, -0.004643372, 1473 0.040170413, -0.062104587, -0.0037324072, 0.0554317, 1474 0.08184801, -0.019164372, 0.06791302, 0.034257166, 1475 -0.10307039, 0.021943003, 0.046745934, 0.0790918, 1476 -0.0265588, -0.007824208, 0.042546265, -0.00977924, 1477 -0.0002440307, -0.017384544, -0.017990116, 0.12252321, 1478 -0.014512694, -0.08251313, 0.08861942, 0.13589665, 1479 0.026351685, 0.012641483, 0.07466548, 0.044301085, 1480 -0.045414884, -0.051112458, 0.03444247, -0.08502782, 1481 -0.04106223, -0.028126027, 0.028473156, 0.10467447}); 1482 1483 lstm.SetRecurrentToForgetWeights( 1484 {-0.057784554, -0.026057621, -0.068447545, -0.022581743, 1485 0.14811787, 0.10826372, 0.09471067, 0.03987225, 1486 -0.0039523416, 0.00030638507, 0.053185795, 0.10572994, 1487 0.08414449, -0.022036452, -0.00066928595, -0.09203576, 1488 0.032950465, -0.10985798, -0.023809856, 0.0021431844, 1489 -0.02196096, -0.00326074, 0.00058621005, -0.074678116, 1490 -0.06193199, 0.055729095, 0.03736828, 0.020123724, 1491 0.061878487, -0.04729229, 0.034919553, -0.07585433, 1492 -0.04421272, -0.044019096, 0.085488975, 0.04058006, 1493 -0.06890133, -0.030951202, -0.024628663, -0.07672815, 1494 0.034293607, 0.08556707, -0.05293577, -0.033561368, 1495 -0.04899627, 0.0241671, 0.015736353, -0.095442444, 1496 -0.029564252, 0.016493602, -0.035026584, 0.022337519, 1497 -0.026871363, 0.004780428, 0.0077918363, -0.03601621, 1498 0.016435321, -0.03263031, -0.09543275, -0.047392778, 1499 0.013454138, 0.028934088, 0.01685226, -0.086110644, 1500 -0.046250615, -0.01847454, 0.047608484, 0.07339695, 1501 0.034546845, -0.04881143, 0.009128804, -0.08802852, 1502 0.03761666, 0.008096139, -0.014454086, 0.014361001, 1503 -0.023502491, -0.0011840804, -0.07607001, 0.001856849, 1504 -0.06509276, -0.006021153, -0.08570962, -0.1451793, 1505 0.060212336, 0.055259194, 0.06974018, 0.049454916, 1506 -0.027794661, -0.08077226, -0.016179763, 0.1169753, 1507 0.17213494, -0.0056326236, -0.053934924, -0.0124349, 1508 -0.11520337, 0.05409887, 0.088759385, 0.0019655675, 1509 0.0042065294, 0.03881498, 0.019844765, 0.041858196, 1510 -0.05695512, 0.047233116, 0.038937137, -0.06542224, 1511 0.014429736, -0.09719407, 0.13908425, -0.05379757, 1512 0.012321099, 0.082840554, -0.029899208, 0.044217527, 1513 0.059855383, 0.07711018, -0.045319796, 0.0948846, 1514 -0.011724666, -0.0033288454, -0.033542685, -0.04764985, 1515 -0.13873616, 0.040668588, 0.034832682, -0.015319203, 1516 -0.018715994, 0.046002675, 0.0599172, -0.043107376, 1517 0.0294216, -0.002314414, -0.022424703, 0.0030315618, 1518 0.0014641669, 0.0029166266, -0.11878115, 0.013738511, 1519 0.12375372, -0.0006038222, 0.029104086, 0.087442465, 1520 0.052958444, 0.07558703, 0.04817258, 0.044462286, 1521 -0.015213451, -0.08783778, -0.0561384, -0.003008196, 1522 0.047060397, -0.002058388, 0.03429439, -0.018839769, 1523 0.024734668, 0.024614193, -0.042046934, 0.09597743, 1524 -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786, 1525 -0.02558259, -0.022822596, -0.023273505, -0.02464396, 1526 -0.10991725, -0.006240552, 0.0074488563, 0.024044557, 1527 0.04383914, -0.046476185, 0.028658995, 0.060410924, 1528 0.050786525, 0.009452605, -0.0073054377, -0.024810238, 1529 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517, 1530 0.015898481, 0.021362653, -0.030262267, 0.016587038, 1531 -0.011442813, 0.041154444, -0.007631438, -0.03423484, 1532 -0.010977775, 0.036152758, 0.0066366293, 0.11915515, 1533 0.02318443, -0.041350313, 0.021485701, -0.10906167, 1534 -0.028218046, -0.00954771, 0.020531068, -0.11995105, 1535 -0.03672871, 0.024019798, 0.014255957, -0.05221243, 1536 -0.00661567, -0.04630967, 0.033188973, 0.10107534, 1537 -0.014027541, 0.030796422, -0.10270911, -0.035999842, 1538 0.15443139, 0.07684145, 0.036571592, -0.035900835, 1539 -0.0034699554, 0.06209149, 0.015920248, -0.031122351, 1540 -0.03858649, 0.01849943, 0.13872518, 0.01503974, 1541 0.069941424, -0.06948533, -0.0088794185, 0.061282158, 1542 -0.047401894, 0.03100163, -0.041533746, -0.10430945, 1543 0.044574402, -0.01425562, -0.024290353, 0.034563623, 1544 0.05866852, 0.023947537, -0.09445152, 0.035450947, 1545 0.02247216, -0.0042998926, 0.061146557, -0.10250651, 1546 0.020881841, -0.06747029, 0.10062043, -0.0023941975, 1547 0.03532124, -0.016341697, 0.09685456, -0.016764693, 1548 0.051808182, 0.05875331, -0.04536488, 0.001626336, 1549 -0.028892258, -0.01048663, -0.009793449, -0.017093895, 1550 0.010987891, 0.02357273, -0.00010856845, 0.0099760275, 1551 -0.001845119, -0.03551521, 0.0018358806, 0.05763657, 1552 -0.01769146, 0.040995963, 0.02235177, -0.060430344, 1553 0.11475477, -0.023854522, 0.10071741, 0.0686208, 1554 -0.014250481, 0.034261297, 0.047418304, 0.08562733, 1555 -0.030519066, 0.0060542435, 0.014653856, -0.038836084, 1556 0.04096551, 0.032249358, -0.08355519, -0.026823482, 1557 0.056386515, -0.010401743, -0.028396193, 0.08507674, 1558 0.014410365, 0.020995233, 0.17040324, 0.11511526, 1559 0.02459721, 0.0066619175, 0.025853224, -0.023133837, 1560 -0.081302024, 0.017264642, -0.009585969, 0.09491168, 1561 -0.051313367, 0.054532815, -0.014298593, 0.10657464, 1562 0.007076659, 0.10964551, 0.0409152, 0.008275321, 1563 -0.07283536, 0.07937492, 0.04192024, -0.1075027}); 1564 1565 lstm.SetRecurrentToCellWeights( 1566 {-0.037322544, 0.018592842, 0.0056175636, -0.06253426, 1567 0.055647098, -0.05713207, -0.05626563, 0.005559383, 1568 0.03375411, -0.025757805, -0.088049285, 0.06017052, 1569 -0.06570978, 0.007384076, 0.035123326, -0.07920549, 1570 0.053676967, 0.044480428, -0.07663568, 0.0071805613, 1571 0.08089997, 0.05143358, 0.038261272, 0.03339287, 1572 -0.027673481, 0.044746667, 0.028349208, 0.020090483, 1573 -0.019443132, -0.030755889, -0.0040000007, 0.04465846, 1574 -0.021585021, 0.0031670958, 0.0053199246, -0.056117613, 1575 -0.10893326, 0.076739706, -0.08509834, -0.027997585, 1576 0.037871376, 0.01449768, -0.09002357, -0.06111149, 1577 -0.046195522, 0.0422062, -0.005683705, -0.1253618, 1578 -0.012925729, -0.04890792, 0.06985068, 0.037654128, 1579 0.03398274, -0.004781977, 0.007032333, -0.031787455, 1580 0.010868644, -0.031489216, 0.09525667, 0.013939797, 1581 0.0058680447, 0.0167067, 0.02668468, -0.04797466, 1582 -0.048885044, -0.12722108, 0.035304096, 0.06554885, 1583 0.00972396, -0.039238118, -0.05159735, -0.11329045, 1584 0.1613692, -0.03750952, 0.06529313, -0.071974665, 1585 -0.11769596, 0.015524369, -0.0013754242, -0.12446318, 1586 0.02786344, -0.014179351, 0.005264273, 0.14376344, 1587 0.015983658, 0.03406988, -0.06939408, 0.040699873, 1588 0.02111075, 0.09669095, 0.041345075, -0.08316494, 1589 -0.07684199, -0.045768797, 0.032298047, -0.041805092, 1590 0.0119405, 0.0061010392, 0.12652606, 0.0064572375, 1591 -0.024950314, 0.11574242, 0.04508852, -0.04335324, 1592 0.06760663, -0.027437469, 0.07216407, 0.06977076, 1593 -0.05438599, 0.034033038, -0.028602652, 0.05346137, 1594 0.043184172, -0.037189785, 0.10420091, 0.00882477, 1595 -0.054019816, -0.074273005, -0.030617684, -0.0028467078, 1596 0.024302477, -0.0038869337, 0.005332455, 0.0013399826, 1597 0.04361412, -0.007001822, 0.09631092, -0.06702025, 1598 -0.042049985, -0.035070654, -0.04103342, -0.10273396, 1599 0.0544271, 0.037184782, -0.13150354, -0.0058036847, 1600 -0.008264958, 0.042035464, 0.05891794, 0.029673764, 1601 0.0063542654, 0.044788733, 0.054816857, 0.062257513, 1602 -0.00093483756, 0.048938446, -0.004952862, -0.007730018, 1603 -0.04043371, -0.017094059, 0.07229206, -0.023670016, 1604 -0.052195564, -0.025616996, -0.01520939, 0.045104615, 1605 -0.007376126, 0.003533447, 0.006570588, 0.056037236, 1606 0.12436656, 0.051817212, 0.028532185, -0.08686856, 1607 0.11868599, 0.07663395, -0.07323171, 0.03463402, 1608 -0.050708205, -0.04458982, -0.11590894, 0.021273347, 1609 0.1251325, -0.15313013, -0.12224372, 0.17228661, 1610 0.023029093, 0.086124025, 0.006445803, -0.03496501, 1611 0.028332196, 0.04449512, -0.042436164, -0.026587414, 1612 -0.006041347, -0.09292539, -0.05678812, 0.03897832, 1613 0.09465633, 0.008115513, -0.02171956, 0.08304309, 1614 0.071401566, 0.019622514, 0.032163795, -0.004167056, 1615 0.02295182, 0.030739572, 0.056506045, 0.004612461, 1616 0.06524936, 0.059999723, 0.046395954, -0.0045512207, 1617 -0.1335546, -0.030136576, 0.11584653, -0.014678886, 1618 0.0020118146, -0.09688814, -0.0790206, 0.039770417, 1619 -0.0329582, 0.07922767, 0.029322514, 0.026405897, 1620 0.04207835, -0.07073373, 0.063781224, 0.0859677, 1621 -0.10925287, -0.07011058, 0.048005477, 0.03438226, 1622 -0.09606514, -0.006669445, -0.043381985, 0.04240257, 1623 -0.06955775, -0.06769346, 0.043903265, -0.026784198, 1624 -0.017840602, 0.024307009, -0.040079936, -0.019946516, 1625 0.045318738, -0.12233574, 0.026170589, 0.0074471775, 1626 0.15978073, 0.10185836, 0.10298046, -0.015476589, 1627 -0.039390966, -0.072174534, 0.0739445, -0.1211869, 1628 -0.0347889, -0.07943156, 0.014809798, -0.12412325, 1629 -0.0030663363, 0.039695457, 0.0647603, -0.08291318, 1630 -0.018529687, -0.004423833, 0.0037507233, 0.084633216, 1631 -0.01514876, -0.056505352, -0.012800942, -0.06994386, 1632 0.012962922, -0.031234352, 0.07029052, 0.016418684, 1633 0.03618972, 0.055686004, -0.08663945, -0.017404709, 1634 -0.054761406, 0.029065743, 0.052404847, 0.020238016, 1635 0.0048197987, -0.0214882, 0.07078733, 0.013016777, 1636 0.06262858, 0.009184685, 0.020785125, -0.043904778, 1637 -0.0270329, -0.03299152, -0.060088247, -0.015162964, 1638 -0.001828936, 0.12642565, -0.056757294, 0.013586685, 1639 0.09232601, -0.035886683, 0.06000002, 0.05229691, 1640 -0.052580316, -0.082029596, -0.010794592, 0.012947712, 1641 -0.036429964, -0.085508935, -0.13127148, -0.017744139, 1642 0.031502828, 0.036232427, -0.031581745, 0.023051167, 1643 -0.05325106, -0.03421577, 0.028793324, -0.034633752, 1644 -0.009881397, -0.043551125, -0.018609839, 0.0019097115, 1645 -0.008799762, 0.056595087, 0.0022273948, 0.055752404}); 1646 1647 lstm.SetRecurrentToOutputWeights({ 1648 0.025825322, -0.05813119, 0.09495884, -0.045984812, -0.01255415, 1649 -0.0026479573, -0.08196161, -0.054914974, -0.0046604523, -0.029587349, 1650 -0.044576716, -0.07480124, -0.082868785, 0.023254942, 0.027502948, 1651 -0.0039728214, -0.08683098, -0.08116779, -0.014675607, -0.037924774, 1652 -0.023314456, -0.007401714, -0.09255757, 0.029460307, -0.08829125, 1653 -0.005139627, -0.08989442, -0.0555066, 0.13596267, -0.025062224, 1654 -0.048351806, -0.03850004, 0.07266485, -0.022414139, 0.05940088, 1655 0.075114764, 0.09597592, -0.010211725, -0.0049794707, -0.011523867, 1656 -0.025980417, 0.072999895, 0.11091378, -0.081685916, 0.014416728, 1657 0.043229222, 0.034178585, -0.07530371, 0.035837382, -0.085607, 1658 -0.007721233, -0.03287832, -0.043848954, -0.06404588, -0.06632928, 1659 -0.073643476, 0.008214239, -0.045984086, 0.039764922, 0.03474462, 1660 0.060612556, -0.080590084, 0.049127717, 0.04151091, -0.030063879, 1661 0.008801774, -0.023021035, -0.019558564, 0.05158114, -0.010947698, 1662 -0.011825728, 0.0075720972, 0.0699727, -0.0039981045, 0.069350146, 1663 0.08799282, 0.016156472, 0.035502106, 0.11695009, 0.006217345, 1664 0.13392477, -0.037875112, 0.025745004, 0.08940699, -0.00924166, 1665 0.0046702605, -0.036598757, -0.08811812, 0.10522024, -0.032441203, 1666 0.008176899, -0.04454919, 0.07058152, 0.0067963637, 0.039206743, 1667 0.03259838, 0.03725492, -0.09515802, 0.013326398, -0.052055415, 1668 -0.025676316, 0.03198509, -0.015951829, -0.058556724, 0.036879618, 1669 0.043357447, 0.028362012, -0.05908629, 0.0059240665, -0.04995891, 1670 -0.019187413, 0.0276265, -0.01628143, 0.0025863599, 0.08800015, 1671 0.035250366, -0.022165963, -0.07328642, -0.009415526, -0.07455109, 1672 0.11690406, 0.0363299, 0.07411125, 0.042103454, -0.009660886, 1673 0.019076364, 0.018299393, -0.046004917, 0.08891175, 0.0431396, 1674 -0.026327137, -0.051502608, 0.08979574, -0.051670972, 0.04940282, 1675 -0.07491107, -0.021240504, 0.022596184, -0.034280192, 0.060163025, 1676 -0.058211457, -0.051837247, -0.01349775, -0.04639988, -0.035936575, 1677 -0.011681591, 0.064818054, 0.0073146066, -0.021745546, -0.043124277, 1678 -0.06471268, -0.07053354, -0.029321948, -0.05330136, 0.016933719, 1679 -0.053782392, 0.13747959, -0.1361751, -0.11569455, 0.0033329215, 1680 0.05693899, -0.053219706, 0.063698, 0.07977434, -0.07924483, 1681 0.06936997, 0.0034815092, -0.007305279, -0.037325785, -0.07251102, 1682 -0.033633437, -0.08677009, 0.091591336, -0.14165086, 0.021752775, 1683 0.019683983, 0.0011612234, -0.058154266, 0.049996935, 0.0288841, 1684 -0.0024567875, -0.14345716, 0.010955264, -0.10234828, 0.1183656, 1685 -0.0010731248, -0.023590032, -0.072285876, -0.0724771, -0.026382286, 1686 -0.0014920527, 0.042667855, 0.0018776858, 0.02986552, 0.009814309, 1687 0.0733756, 0.12289186, 0.018043943, -0.0458958, 0.049412545, 1688 0.033632483, 0.05495232, 0.036686596, -0.013781798, -0.010036754, 1689 0.02576849, -0.08307328, 0.010112348, 0.042521734, -0.05869831, 1690 -0.071689695, 0.03876447, -0.13275425, -0.0352966, -0.023077697, 1691 0.10285965, 0.084736146, 0.15568255, -0.00040734606, 0.027835453, 1692 -0.10292561, -0.032401145, 0.10053256, -0.026142767, -0.08271222, 1693 -0.0030240538, -0.016368777, 0.1070414, 0.042672627, 0.013456989, 1694 -0.0437609, -0.022309763, 0.11576483, 0.04108048, 0.061026827, 1695 -0.0190714, -0.0869359, 0.037901703, 0.0610107, 0.07202949, 1696 0.01675338, 0.086139716, -0.08795751, -0.014898893, -0.023771819, 1697 -0.01965048, 0.007955471, -0.043740474, 0.03346837, -0.10549954, 1698 0.090567775, 0.042013682, -0.03176985, 0.12569028, -0.02421228, 1699 -0.029526481, 0.023851605, 0.031539805, 0.05292009, -0.02344001, 1700 -0.07811758, -0.08834428, 0.10094801, 0.16594367, -0.06861939, 1701 -0.021256343, -0.041093912, -0.06669611, 0.035498552, 0.021757556, 1702 -0.09302526, -0.015403468, -0.06614931, -0.051798206, -0.013874718, 1703 0.03630673, 0.010412845, -0.08077351, 0.046185967, 0.0035662893, 1704 0.03541868, -0.094149634, -0.034814864, 0.003128424, -0.020674974, 1705 -0.03944324, -0.008110165, -0.11113267, 0.08484226, 0.043586485, 1706 0.040582247, 0.0968012, -0.065249965, -0.028036479, 0.0050708856, 1707 0.0017462453, 0.0326779, 0.041296225, 0.09164146, -0.047743853, 1708 -0.015952192, -0.034451712, 0.084197424, -0.05347844, -0.11768019, 1709 0.085926116, -0.08251791, -0.045081906, 0.0948852, 0.068401024, 1710 0.024856757, 0.06978981, -0.057309967, -0.012775832, -0.0032452994, 1711 0.01977615, -0.041040014, -0.024264973, 0.063464895, 0.05431621, 1712 }); 1713 1714 lstm.SetCellToInputWeights( 1715 {0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458, 1716 -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174, 1717 -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047, 1718 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175}); 1719 1720 lstm.SetCellToForgetWeights( 1721 {-0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276, 1722 -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766, 1723 -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774, 1724 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355}); 1725 1726 lstm.SetCellToOutputWeights( 1727 {0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764, 1728 -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544, 1729 -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817, 1730 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733}); 1731 1732 lstm.SetProjectionWeights( 1733 {-0.009802181, 0.09401916, 0.0717386, -0.13895074, 0.09641832, 1734 0.060420845, 0.08539281, 0.054285463, 0.061395317, 0.034448683, 1735 -0.042991187, 0.019801661, -0.16840284, -0.015726732, -0.23041931, 1736 -0.024478018, -0.10959692, -0.013875541, 0.18600968, -0.061274476, 1737 0.0138165, -0.08160894, -0.07661644, 0.032372914, 0.16169067, 1738 0.22465782, -0.03993472, -0.004017731, 0.08633481, -0.28869787, 1739 0.08682067, 0.17240396, 0.014975425, 0.056431185, 0.031037588, 1740 0.16702051, 0.0077946745, 0.15140012, 0.29405436, 0.120285, 1741 -0.188994, -0.027265169, 0.043389652, -0.022061434, 0.014777949, 1742 -0.20203483, 0.094781205, 0.19100232, 0.13987629, -0.036132768, 1743 -0.06426278, -0.05108664, 0.13221376, 0.009441198, -0.16715929, 1744 0.15859416, -0.040437475, 0.050779544, -0.022187516, 0.012166504, 1745 0.027685808, -0.07675938, -0.0055694645, -0.09444123, 0.0046453946, 1746 0.050794356, 0.10770313, -0.20790008, -0.07149004, -0.11425117, 1747 0.008225835, -0.035802525, 0.14374903, 0.15262283, 0.048710253, 1748 0.1847461, -0.007487823, 0.11000021, -0.09542012, 0.22619456, 1749 -0.029149994, 0.08527916, 0.009043713, 0.0042746216, 0.016261552, 1750 0.022461696, 0.12689082, -0.043589946, -0.12035478, -0.08361797, 1751 -0.050666027, -0.1248618, -0.1275799, -0.071875185, 0.07377272, 1752 0.09944291, -0.18897448, -0.1593054, -0.06526116, -0.040107165, 1753 -0.004618631, -0.067624845, -0.007576253, 0.10727444, 0.041546922, 1754 -0.20424393, 0.06907816, 0.050412357, 0.00724631, 0.039827548, 1755 0.12449835, 0.10747581, 0.13708383, 0.09134148, -0.12617786, 1756 -0.06428341, 0.09956831, 0.1208086, -0.14676677, -0.0727722, 1757 0.1126304, 0.010139365, 0.015571211, -0.038128063, 0.022913318, 1758 -0.042050496, 0.16842307, -0.060597885, 0.10531834, -0.06411776, 1759 -0.07451711, -0.03410368, -0.13393489, 0.06534304, 0.003620307, 1760 0.04490757, 0.05970546, 0.05197996, 0.02839995, 0.10434969, 1761 -0.013699693, -0.028353551, -0.07260381, 0.047201227, -0.024575593, 1762 -0.036445823, 0.07155557, 0.009672501, -0.02328883, 0.009533515, 1763 -0.03606021, -0.07421458, -0.028082801, -0.2678904, -0.13221288, 1764 0.18419984, -0.13012612, -0.014588381, -0.035059117, -0.04824723, 1765 0.07830115, -0.056184657, 0.03277091, 0.025466874, 0.14494097, 1766 -0.12522776, -0.098633975, -0.10766018, -0.08317623, 0.08594209, 1767 0.07749552, 0.039474737, 0.1776665, -0.07409566, -0.0477268, 1768 0.29323658, 0.10801441, 0.1154011, 0.013952499, 0.10739139, 1769 0.10708251, -0.051456142, 0.0074137426, -0.10430189, 0.10034707, 1770 0.045594677, 0.0635285, -0.0715442, -0.089667566, -0.10811871, 1771 0.00026344223, 0.08298446, -0.009525053, 0.006585689, -0.24567553, 1772 -0.09450807, 0.09648481, 0.026996298, -0.06419476, -0.04752702, 1773 -0.11063944, -0.23441927, -0.17608605, -0.052156363, 0.067035615, 1774 0.19271925, -0.0032889997, -0.043264326, 0.09663576, -0.057112187, 1775 -0.10100678, 0.0628376, 0.04447668, 0.017961001, -0.10094388, 1776 -0.10190601, 0.18335468, 0.10494553, -0.052095775, -0.0026118709, 1777 0.10539724, -0.04383912, -0.042349473, 0.08438151, -0.1947263, 1778 0.02251204, 0.11216432, -0.10307853, 0.17351969, -0.039091777, 1779 0.08066188, -0.00561982, 0.12633002, 0.11335965, -0.0088127935, 1780 -0.019777594, 0.06864014, -0.059751723, 0.016233567, -0.06894641, 1781 -0.28651384, -0.004228674, 0.019708522, -0.16305895, -0.07468996, 1782 -0.0855457, 0.099339016, -0.07580735, -0.13775392, 0.08434318, 1783 0.08330512, -0.12131499, 0.031935584, 0.09180414, -0.08876437, 1784 -0.08049874, 0.008753825, 0.03498998, 0.030215185, 0.03907079, 1785 0.089751154, 0.029194152, -0.03337423, -0.019092513, 0.04331237, 1786 0.04299654, -0.036394123, -0.12915532, 0.09793732, 0.07512415, 1787 -0.11319543, -0.032502122, 0.15661901, 0.07671967, -0.005491124, 1788 -0.19379048, -0.218606, 0.21448623, 0.017840758, 0.1416943, 1789 -0.07051762, 0.19488361, 0.02664691, -0.18104725, -0.09334311, 1790 0.15026465, -0.15493552, -0.057762887, -0.11604192, -0.262013, 1791 -0.01391798, 0.012185008, 0.11156489, -0.07483202, 0.06693364, 1792 -0.26151478, 0.046425626, 0.036540434, -0.16435726, 0.17338543, 1793 -0.21401681, -0.11385144, -0.08283257, -0.069031075, 0.030635102, 1794 0.010969227, 0.11109743, 0.010919218, 0.027526086, 0.13519906, 1795 0.01891392, -0.046839405, -0.040167913, 0.017953383, -0.09700955, 1796 0.0061885654, -0.07000971, 0.026893595, -0.038844477, 0.14543656}); 1797 1798 static float lstm_input[][20] = { 1799 {// Batch0: 4 (input_sequence_size) * 5 (n_input) 1800 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, 0.596268, 0.998386, 1801 0.568695, 0.864524, 0.571277, 0.073204, 0.296072, 0.743333, 0.069199, 1802 0.045348, 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, 1803 1804 {// Batch1: 4 (input_sequence_size) * 5 (n_input) 1805 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, 0.642421, 0.524260, 1806 0.134799, 0.003639, 0.162482, 0.640394, 0.930399, 0.050782, 0.432485, 1807 0.988078, 0.082922, 0.563329, 0.865614, 0.333232, 0.259916}}; 1808 1809 static float lstm_fw_golden_output[][64] = { 1810 {// Batch0: 4 (input_sequence_size) * 16 (n_output) 1811 -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576, 1812 -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004, 1813 -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147, 1814 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363, 1815 -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322, 1816 -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308, 1817 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794, 1818 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474, 1819 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827, 1820 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512, 1821 -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407, 1822 -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193, 1823 0.0286833, 0.00824207, 0.0264887, 0.0305169}, 1824 {// Batch1: 4 (input_sequence_size) * 16 (n_output) 1825 -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926, 1826 -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232, 1827 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954, 1828 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507, 1829 -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039, 1830 -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233, 1831 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378, 1832 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034, 1833 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789, 1834 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855, 1835 -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679, 1836 -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181, 1837 0.0412031, 0.0118723, 0.0239643, 0.0394009}}; 1838 1839 static float lstm_combined_golden_output[][64] = { 1840 {-0.022014, 0.073544, -0.002235, 0.040068, -0.037136, -0.052788, 1841 0.075325, -0.029378, 0.024298, -0.07733, -0.030674, -0.060229, 1842 0.040599, 0.011608, 0.042005, 0.045977, -0.039225, 0.076294, 1843 0.000735, 0.032852, -0.069869, -0.053312, 0.073527, -0.028136, 1844 0.021585, -0.102679, -0.004327, -0.043304, 0.072861, 0.027077, 1845 0.034558, 0.068292, -0.036292, 0.069832, -0.003032, 0.053829, 1846 -0.043821, -0.072713, 0.085029, -0.040374, 0.020014, -0.104521, 1847 -0.034504, -0.059759, 0.062569, 0.025652, 0.049306, 0.061189, 1848 -0.025146, 0.079643, -0.005188, 0.033080, -0.048079, -0.048082, 1849 0.069369, -0.028900, 0.024572, -0.077547, -0.022517, -0.054477, 1850 0.038857, 0.013336, 0.043234, 0.044788}, 1851 {-0.039186, 0.070792, -0.005913, 0.02642, -0.068274, -0.05022, 1852 0.061444, -0.031241, 0.014996, -0.094544, -0.004146, -0.03464, 1853 0.058981, 0.026097, 0.039781, 0.058408, -0.031887, 0.069252, 1854 0.00576, 0.054062, -0.042801, -0.059974, 0.085272, -0.034453, 1855 0.026097, -0.0959, -0.031164, -0.058699, 0.06839, 0.020512, 1856 0.044727, 0.063609, -0.039863, 0.084819, -0.003909, 0.028666, 1857 -0.075677, -0.045125, 0.070379, -0.033895, 0.022111, -0.097184, 1858 -0.004921, -0.040851, 0.062316, 0.017435, 0.041437, 0.064568, 1859 -0.039656, 0.060726, -0.003402, 0.036854, -0.056503, -0.058554, 1860 0.068588, -0.034879, 0.01352, -0.09962, -0.01434, -0.039505, 1861 0.065133, 0.024321, 0.038473, 0.062438}}; 1862 1863 for (int i = 0; i < lstm.sequence_length(); i++) { 1864 float* batch0_start = lstm_input[0] + i * lstm.num_inputs(); 1865 float* batch0_end = batch0_start + lstm.num_inputs(); 1866 1867 lstm.SetInput(2 * i * lstm.num_inputs(), batch0_start, batch0_end); 1868 1869 float* batch1_start = lstm_input[1] + i * lstm.num_inputs(); 1870 float* batch1_end = batch1_start + lstm.num_inputs(); 1871 lstm.SetInput((2 * i + 1) * lstm.num_inputs(), batch1_start, batch1_end); 1872 } 1873 1874 lstm.Invoke(); 1875 1876 std::vector<float> expected; 1877 for (int i = 0; i < lstm.sequence_length(); i++) { 1878 float* golden_start_batch0 = 1879 lstm_fw_golden_output[0] + i * lstm.num_fw_outputs(); 1880 float* golden_end_batch0 = golden_start_batch0 + lstm.num_fw_outputs(); 1881 float* golden_start_batch1 = 1882 lstm_fw_golden_output[1] + i * lstm.num_fw_outputs(); 1883 float* golden_end_batch1 = golden_start_batch1 + lstm.num_fw_outputs(); 1884 expected.insert(expected.end(), golden_start_batch0, golden_end_batch0); 1885 expected.insert(expected.end(), golden_start_batch1, golden_end_batch1); 1886 } 1887 EXPECT_THAT(lstm.GetFwOutput(), ElementsAreArray(ArrayFloatNear(expected))); 1888 1889 // Check if the sum of forward backward matches the golden. 1890 expected.clear(); 1891 for (int i = 0; i < lstm.sequence_length(); i++) { 1892 float* golden_start_batch0 = 1893 lstm_combined_golden_output[0] + i * lstm.num_fw_outputs(); 1894 float* golden_end_batch0 = golden_start_batch0 + lstm.num_fw_outputs(); 1895 float* golden_start_batch1 = 1896 lstm_combined_golden_output[1] + i * lstm.num_fw_outputs(); 1897 float* golden_end_batch1 = golden_start_batch1 + lstm.num_fw_outputs(); 1898 expected.insert(expected.end(), golden_start_batch0, golden_end_batch0); 1899 expected.insert(expected.end(), golden_start_batch1, golden_end_batch1); 1900 } 1901 1902 std::vector<float> combined; 1903 for (int i = 0; i < lstm.GetFwOutput().size(); ++i) { 1904 combined.push_back(lstm.GetFwOutput()[i] + lstm.GetBwOutput()[i]); 1905 } 1906 EXPECT_THAT(combined, ElementsAreArray(ArrayFloatNear(expected))); 1907 } 1908 1909 // Same as above but with batch_major input/output. 1910 TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClippingBatchMajor) { 1911 const int n_batch = 2; 1912 const int n_input = 5; 1913 const int n_cell = 20; 1914 const int n_output = 16; 1915 const int sequence_length = 4; 1916 1917 BidirectionalLSTMOpModel lstm( 1918 n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, 1919 /*use_peephole=*/true, /*use_projection_weights=*/true, 1920 /*use_projection_bias=*/false, /*merge_outputs=*/false, 1921 /*use_aux_input=*/false, /*cell_clip=*/0.0, 1922 /*proj_clip=*/0.0, /*quantize_weights=*/false, /*time_major=*/false, 1923 { 1924 {n_batch, sequence_length, n_input}, // input tensor 1925 1926 {n_cell, n_input}, // input_to_input_weight tensor 1927 {n_cell, n_input}, // input_to_forget_weight tensor 1928 {n_cell, n_input}, // input_to_cell_weight tensor 1929 {n_cell, n_input}, // input_to_output_weight tensor 1930 1931 {n_cell, n_output}, // recurrent_to_input_weight tensor 1932 {n_cell, n_output}, // recurrent_to_forget_weight tensor 1933 {n_cell, n_output}, // recurrent_to_cell_weight tensor 1934 {n_cell, n_output}, // recurrent_to_output_weight tensor 1935 1936 {n_cell}, // cell_to_input_weight tensor 1937 {n_cell}, // cell_to_forget_weight tensor 1938 {n_cell}, // cell_to_output_weight tensor 1939 1940 {n_cell}, // input_gate_bias tensor 1941 {n_cell}, // forget_gate_bias tensor 1942 {n_cell}, // cell_bias tensor 1943 {n_cell}, // output_gate_bias tensor 1944 1945 {n_output, n_cell}, // projection_weight tensor 1946 {0}, // projection_bias tensor 1947 1948 {n_cell, n_input}, // input_to_input_weight tensor 1949 {n_cell, n_input}, // input_to_forget_weight tensor 1950 {n_cell, n_input}, // input_to_cell_weight tensor 1951 {n_cell, n_input}, // input_to_output_weight tensor 1952 1953 {n_cell, n_output}, // recurrent_to_input_weight tensor 1954 {n_cell, n_output}, // recurrent_to_forget_weight tensor 1955 {n_cell, n_output}, // recurrent_to_cell_weight tensor 1956 {n_cell, n_output}, // recurrent_to_output_weight tensor 1957 1958 {n_cell}, // cell_to_input_weight tensor 1959 {n_cell}, // cell_to_forget_weight tensor 1960 {n_cell}, // cell_to_output_weight tensor 1961 1962 {n_cell}, // input_gate_bias tensor 1963 {n_cell}, // forget_gate_bias tensor 1964 {n_cell}, // cell_bias tensor 1965 {n_cell}, // output_gate_bias tensor 1966 1967 {n_output, n_cell}, // projection_weight tensor 1968 {0}, // projection_bias tensor 1969 1970 {n_batch, n_output}, // activation_state tensor 1971 {n_batch, n_cell}, // cell_state tensor 1972 1973 {n_batch, n_output}, // activation_state tensor 1974 {n_batch, n_cell}, // cell_state tensor 1975 1976 {n_batch, sequence_length, 0}, // aux_input tensor 1977 {n_cell, 0}, // aux_fw_input_to_input tensor 1978 {n_cell, 0}, // aux_fw_input_to_forget tensor 1979 {n_cell, 0}, // aux_fw_input_to_cell tensor 1980 {n_cell, 0}, // aux_fw_input_to_output tensor 1981 {n_cell, 0}, // aux_bw_input_to_input tensor 1982 {n_cell, 0}, // aux_bw_input_to_forget tensor 1983 {n_cell, 0}, // aux_bw_input_to_cell tensor 1984 {n_cell, 0}, // aux_bw_input_to_output tensor 1985 }); 1986 1987 lstm.SetInputToInputWeights( 1988 {0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, 1989 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048, 1990 -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385, 1991 -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282, 1992 -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627, 1993 -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226, 1994 -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059, 1995 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698, 1996 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206, 1997 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585, 1998 -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063, 1999 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603, 2000 -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682, 2001 -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988, 2002 -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764, 2003 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476, 2004 -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012, 2005 -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604, 2006 -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654, 2007 -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677}); 2008 2009 lstm.SetInputToForgetWeights( 2010 {-0.0018401089, -0.004852237, 0.03698424, 0.014181704, 0.028273236, 2011 -0.016726194, -0.05249759, -0.10204261, 0.00861066, -0.040979505, 2012 -0.009899187, 0.01923892, -0.028177269, -0.08535103, -0.14585495, 2013 0.10662567, -0.01909731, -0.017883534, -0.0047269356, -0.045103323, 2014 0.0030784295, 0.076784775, 0.07463696, 0.094531395, 0.0814421, 2015 -0.12257899, -0.033945758, -0.031303465, 0.045630626, 0.06843887, 2016 -0.13492945, -0.012480007, -0.0811829, -0.07224499, -0.09628791, 2017 0.045100946, 0.0012300825, 0.013964662, 0.099372394, 0.02543059, 2018 0.06958324, 0.034257296, 0.0482646, 0.06267997, 0.052625068, 2019 0.12784666, 0.07077897, 0.025725935, 0.04165009, 0.07241905, 2020 0.018668644, -0.037377294, -0.06277783, -0.08833636, -0.040120605, 2021 -0.011405586, -0.007808335, -0.010301386, -0.005102167, 0.027717464, 2022 0.05483423, 0.11449111, 0.11289652, 0.10939839, 0.13396506, 2023 -0.08402166, -0.01901462, -0.044678304, -0.07720565, 0.014350063, 2024 -0.11757958, -0.0652038, -0.08185733, -0.076754324, -0.092614375, 2025 0.10405491, 0.052960336, 0.035755895, 0.035839386, -0.012540553, 2026 0.036881298, 0.02913376, 0.03420159, 0.05448447, -0.054523353, 2027 0.02582715, 0.02327355, -0.011857179, -0.0011980024, -0.034641717, 2028 -0.026125094, -0.17582615, -0.15923657, -0.27486774, -0.0006143371, 2029 0.0001771948, -8.470171e-05, 0.02651807, 0.045790765, 0.06956496}); 2030 2031 lstm.SetInputToCellWeights( 2032 {-0.04580283, -0.09549462, -0.032418985, -0.06454633, 2033 -0.043528453, 0.043018587, -0.049152344, -0.12418144, 2034 -0.078985475, -0.07596889, 0.019484362, -0.11434962, 2035 -0.0074034138, -0.06314844, -0.092981495, 0.0062155537, 2036 -0.025034338, -0.0028890965, 0.048929527, 0.06235075, 2037 0.10665918, -0.032036792, -0.08505916, -0.10843358, 2038 -0.13002433, -0.036816437, -0.02130134, -0.016518239, 2039 0.0047691227, -0.0025825808, 0.066017866, 0.029991534, 2040 -0.10652836, -0.1037554, -0.13056071, -0.03266643, 2041 -0.033702414, -0.006473424, -0.04611692, 0.014419339, 2042 -0.025174323, 0.0396852, 0.081777506, 0.06157468, 2043 0.10210095, -0.009658194, 0.046511717, 0.03603906, 2044 0.0069369148, 0.015960095, -0.06507666, 0.09551598, 2045 0.053568836, 0.06408714, 0.12835667, -0.008714329, 2046 -0.20211966, -0.12093674, 0.029450472, 0.2849013, 2047 -0.029227901, 0.1164364, -0.08560263, 0.09941786, 2048 -0.036999565, -0.028842626, -0.0033637602, -0.017012902, 2049 -0.09720865, -0.11193351, -0.029155117, -0.017936034, 2050 -0.009768936, -0.04223324, -0.036159635, 0.06505112, 2051 -0.021742892, -0.023377212, -0.07221364, -0.06430552, 2052 0.05453865, 0.091149814, 0.06387331, 0.007518393, 2053 0.055960953, 0.069779344, 0.046411168, 0.10509911, 2054 0.07463894, 0.0075130584, 0.012850982, 0.04555431, 2055 0.056955688, 0.06555285, 0.050801456, -0.009862683, 2056 0.00826772, -0.026555609, -0.0073611983, -0.0014897042}); 2057 2058 lstm.SetInputToOutputWeights( 2059 {-0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918, 2060 -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534, 2061 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722, 2062 -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761, 2063 -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394, 2064 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154, 2065 -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135, 2066 -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564, 2067 -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047, 2068 -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304, 2069 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946, 2070 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646, 2071 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813, 2072 -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403, 2073 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415, 2074 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495, 2075 -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158, 2076 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295, 2077 -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739, 2078 -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956}); 2079 2080 lstm.SetInputGateBias( 2081 {0.02234832, 0.14757581, 0.18176508, 0.10380666, 0.053110216, 2082 -0.06928846, -0.13942584, -0.11816189, 0.19483899, 0.03652339, 2083 -0.10250295, 0.036714908, -0.18426876, 0.036065217, 0.21810818, 2084 0.02383196, -0.043370757, 0.08690144, -0.04444982, 0.00030581196}); 2085 2086 lstm.SetForgetGateBias({0.035185695, -0.042891346, -0.03032477, 0.23027696, 2087 0.11098921, 0.15378423, 0.09263801, 0.09790885, 2088 0.09508917, 0.061199076, 0.07665568, -0.015443159, 2089 -0.03499149, 0.046190713, 0.08895977, 0.10899629, 2090 0.40694186, 0.06030037, 0.012413437, -0.06108739}); 2091 2092 lstm.SetCellBias({-0.024379363, 0.0055531194, 0.23377132, 0.033463873, 2093 -0.1483596, -0.10639995, -0.091433935, 0.058573797, 2094 -0.06809782, -0.07889636, -0.043246906, -0.09829136, 2095 -0.4279842, 0.034901652, 0.18797937, 0.0075234566, 2096 0.016178843, 0.1749513, 0.13975595, 0.92058027}); 2097 2098 lstm.SetOutputGateBias( 2099 {0.046159424, -0.0012809046, 0.03563469, 0.12648113, 0.027195795, 2100 0.35373217, -0.018957434, 0.008907322, -0.0762701, 0.12018895, 2101 0.04216877, 0.0022856654, 0.040952638, 0.3147856, 0.08225149, 2102 -0.057416286, -0.14995944, -0.008040261, 0.13208859, 0.029760877}); 2103 2104 lstm.SetRecurrentToInputWeights( 2105 {-0.001374326, -0.078856036, 0.10672688, 0.029162422, 2106 -0.11585556, 0.02557986, -0.13446963, -0.035785314, 2107 -0.01244275, 0.025961924, -0.02337298, -0.044228926, 2108 -0.055839065, -0.046598054, -0.010546039, -0.06900766, 2109 0.027239809, 0.022582639, -0.013296484, -0.05459212, 2110 0.08981, -0.045407712, 0.08682226, -0.06867011, 2111 -0.14390695, -0.02916037, 0.000996957, 0.091420636, 2112 0.14283475, -0.07390571, -0.06402044, 0.062524505, 2113 -0.093129106, 0.04860203, -0.08364217, -0.08119002, 2114 0.009352075, 0.22920375, 0.0016303885, 0.11583097, 2115 -0.13732095, 0.012405723, -0.07551853, 0.06343048, 2116 0.12162708, -0.031923793, -0.014335606, 0.01790974, 2117 -0.10650317, -0.0724401, 0.08554849, -0.05727212, 2118 0.06556731, -0.042729504, -0.043227166, 0.011683251, 2119 -0.013082158, -0.029302018, -0.010899579, -0.062036745, 2120 -0.022509435, -0.00964907, -0.01567329, 0.04260106, 2121 -0.07787477, -0.11576462, 0.017356863, 0.048673786, 2122 -0.017577527, -0.05527947, -0.082487635, -0.040137455, 2123 -0.10820036, -0.04666372, 0.022746278, -0.07851417, 2124 0.01068115, 0.032956902, 0.022433773, 0.0026891115, 2125 0.08944216, -0.0685835, 0.010513544, 0.07228705, 2126 0.02032331, -0.059686817, -0.0005566496, -0.086984694, 2127 0.040414046, -0.1380399, 0.094208956, -0.05722982, 2128 0.012092817, -0.04989123, -0.086576, -0.003399834, 2129 -0.04696032, -0.045747425, 0.10091314, 0.048676282, 2130 -0.029037097, 0.031399418, -0.0040285117, 0.047237843, 2131 0.09504992, 0.041799378, -0.049185462, -0.031518843, 2132 -0.10516937, 0.026374253, 0.10058866, -0.0033195973, 2133 -0.041975245, 0.0073591834, 0.0033782164, -0.004325073, 2134 -0.10167381, 0.042500053, -0.01447153, 0.06464186, 2135 -0.017142897, 0.03312627, 0.009205989, 0.024138335, 2136 -0.011337001, 0.035530265, -0.010912711, 0.0706555, 2137 -0.005894094, 0.051841937, -0.1401738, -0.02351249, 2138 0.0365468, 0.07590991, 0.08838724, 0.021681072, 2139 -0.10086113, 0.019608743, -0.06195883, 0.077335775, 2140 0.023646897, -0.095322326, 0.02233014, 0.09756986, 2141 -0.048691444, -0.009579111, 0.07595467, 0.11480546, 2142 -0.09801813, 0.019894179, 0.08502348, 0.004032281, 2143 0.037211012, 0.068537936, -0.048005626, -0.091520436, 2144 -0.028379958, -0.01556313, 0.06554592, -0.045599163, 2145 -0.01672207, -0.020169014, -0.011877351, -0.20212261, 2146 0.010889619, 0.0047078193, 0.038385306, 0.08540671, 2147 -0.017140968, -0.0035865551, 0.016678626, 0.005633034, 2148 0.015963363, 0.00871737, 0.060130805, 0.028611384, 2149 0.10109069, -0.015060172, -0.07894427, 0.06401885, 2150 0.011584063, -0.024466386, 0.0047652307, -0.09041358, 2151 0.030737216, -0.0046374933, 0.14215417, -0.11823516, 2152 0.019899689, 0.006106124, -0.027092824, 0.0786356, 2153 0.05052217, -0.058925, -0.011402121, -0.024987547, 2154 -0.0013661642, -0.06832946, -0.015667673, -0.1083353, 2155 -0.00096863037, -0.06988685, -0.053350925, -0.027275559, 2156 -0.033664223, -0.07978348, -0.025200296, -0.017207067, 2157 -0.058403496, -0.055697463, 0.005798788, 0.12965427, 2158 -0.062582195, 0.0013350133, -0.10482091, 0.0379771, 2159 0.072521195, -0.0029455067, -0.13797039, -0.03628521, 2160 0.013806405, -0.017858358, -0.01008298, -0.07700066, 2161 -0.017081132, 0.019358726, 0.0027079724, 0.004635139, 2162 0.062634714, -0.02338735, -0.039547626, -0.02050681, 2163 0.03385117, -0.083611414, 0.002862572, -0.09421313, 2164 0.058618143, -0.08598433, 0.00972939, 0.023867095, 2165 -0.053934585, -0.023203006, 0.07452513, -0.048767887, 2166 -0.07314807, -0.056307215, -0.10433547, -0.06440842, 2167 0.04328182, 0.04389765, -0.020006588, -0.09076438, 2168 -0.11652589, -0.021705797, 0.03345259, -0.010329105, 2169 -0.025767034, 0.013057034, -0.07316461, -0.10145612, 2170 0.06358255, 0.18531723, 0.07759293, 0.12006465, 2171 0.1305557, 0.058638252, -0.03393652, 0.09622831, 2172 -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845, 2173 -0.005644518, 0.06857898, -0.12598175, -0.035084512, 2174 0.03156317, -0.12794146, -0.031963028, 0.04692781, 2175 0.030070418, 0.0071660685, -0.095516115, -0.004643372, 2176 0.040170413, -0.062104587, -0.0037324072, 0.0554317, 2177 0.08184801, -0.019164372, 0.06791302, 0.034257166, 2178 -0.10307039, 0.021943003, 0.046745934, 0.0790918, 2179 -0.0265588, -0.007824208, 0.042546265, -0.00977924, 2180 -0.0002440307, -0.017384544, -0.017990116, 0.12252321, 2181 -0.014512694, -0.08251313, 0.08861942, 0.13589665, 2182 0.026351685, 0.012641483, 0.07466548, 0.044301085, 2183 -0.045414884, -0.051112458, 0.03444247, -0.08502782, 2184 -0.04106223, -0.028126027, 0.028473156, 0.10467447}); 2185 2186 lstm.SetRecurrentToForgetWeights( 2187 {-0.057784554, -0.026057621, -0.068447545, -0.022581743, 2188 0.14811787, 0.10826372, 0.09471067, 0.03987225, 2189 -0.0039523416, 0.00030638507, 0.053185795, 0.10572994, 2190 0.08414449, -0.022036452, -0.00066928595, -0.09203576, 2191 0.032950465, -0.10985798, -0.023809856, 0.0021431844, 2192 -0.02196096, -0.00326074, 0.00058621005, -0.074678116, 2193 -0.06193199, 0.055729095, 0.03736828, 0.020123724, 2194 0.061878487, -0.04729229, 0.034919553, -0.07585433, 2195 -0.04421272, -0.044019096, 0.085488975, 0.04058006, 2196 -0.06890133, -0.030951202, -0.024628663, -0.07672815, 2197 0.034293607, 0.08556707, -0.05293577, -0.033561368, 2198 -0.04899627, 0.0241671, 0.015736353, -0.095442444, 2199 -0.029564252, 0.016493602, -0.035026584, 0.022337519, 2200 -0.026871363, 0.004780428, 0.0077918363, -0.03601621, 2201 0.016435321, -0.03263031, -0.09543275, -0.047392778, 2202 0.013454138, 0.028934088, 0.01685226, -0.086110644, 2203 -0.046250615, -0.01847454, 0.047608484, 0.07339695, 2204 0.034546845, -0.04881143, 0.009128804, -0.08802852, 2205 0.03761666, 0.008096139, -0.014454086, 0.014361001, 2206 -0.023502491, -0.0011840804, -0.07607001, 0.001856849, 2207 -0.06509276, -0.006021153, -0.08570962, -0.1451793, 2208 0.060212336, 0.055259194, 0.06974018, 0.049454916, 2209 -0.027794661, -0.08077226, -0.016179763, 0.1169753, 2210 0.17213494, -0.0056326236, -0.053934924, -0.0124349, 2211 -0.11520337, 0.05409887, 0.088759385, 0.0019655675, 2212 0.0042065294, 0.03881498, 0.019844765, 0.041858196, 2213 -0.05695512, 0.047233116, 0.038937137, -0.06542224, 2214 0.014429736, -0.09719407, 0.13908425, -0.05379757, 2215 0.012321099, 0.082840554, -0.029899208, 0.044217527, 2216 0.059855383, 0.07711018, -0.045319796, 0.0948846, 2217 -0.011724666, -0.0033288454, -0.033542685, -0.04764985, 2218 -0.13873616, 0.040668588, 0.034832682, -0.015319203, 2219 -0.018715994, 0.046002675, 0.0599172, -0.043107376, 2220 0.0294216, -0.002314414, -0.022424703, 0.0030315618, 2221 0.0014641669, 0.0029166266, -0.11878115, 0.013738511, 2222 0.12375372, -0.0006038222, 0.029104086, 0.087442465, 2223 0.052958444, 0.07558703, 0.04817258, 0.044462286, 2224 -0.015213451, -0.08783778, -0.0561384, -0.003008196, 2225 0.047060397, -0.002058388, 0.03429439, -0.018839769, 2226 0.024734668, 0.024614193, -0.042046934, 0.09597743, 2227 -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786, 2228 -0.02558259, -0.022822596, -0.023273505, -0.02464396, 2229 -0.10991725, -0.006240552, 0.0074488563, 0.024044557, 2230 0.04383914, -0.046476185, 0.028658995, 0.060410924, 2231 0.050786525, 0.009452605, -0.0073054377, -0.024810238, 2232 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517, 2233 0.015898481, 0.021362653, -0.030262267, 0.016587038, 2234 -0.011442813, 0.041154444, -0.007631438, -0.03423484, 2235 -0.010977775, 0.036152758, 0.0066366293, 0.11915515, 2236 0.02318443, -0.041350313, 0.021485701, -0.10906167, 2237 -0.028218046, -0.00954771, 0.020531068, -0.11995105, 2238 -0.03672871, 0.024019798, 0.014255957, -0.05221243, 2239 -0.00661567, -0.04630967, 0.033188973, 0.10107534, 2240 -0.014027541, 0.030796422, -0.10270911, -0.035999842, 2241 0.15443139, 0.07684145, 0.036571592, -0.035900835, 2242 -0.0034699554, 0.06209149, 0.015920248, -0.031122351, 2243 -0.03858649, 0.01849943, 0.13872518, 0.01503974, 2244 0.069941424, -0.06948533, -0.0088794185, 0.061282158, 2245 -0.047401894, 0.03100163, -0.041533746, -0.10430945, 2246 0.044574402, -0.01425562, -0.024290353, 0.034563623, 2247 0.05866852, 0.023947537, -0.09445152, 0.035450947, 2248 0.02247216, -0.0042998926, 0.061146557, -0.10250651, 2249 0.020881841, -0.06747029, 0.10062043, -0.0023941975, 2250 0.03532124, -0.016341697, 0.09685456, -0.016764693, 2251 0.051808182, 0.05875331, -0.04536488, 0.001626336, 2252 -0.028892258, -0.01048663, -0.009793449, -0.017093895, 2253 0.010987891, 0.02357273, -0.00010856845, 0.0099760275, 2254 -0.001845119, -0.03551521, 0.0018358806, 0.05763657, 2255 -0.01769146, 0.040995963, 0.02235177, -0.060430344, 2256 0.11475477, -0.023854522, 0.10071741, 0.0686208, 2257 -0.014250481, 0.034261297, 0.047418304, 0.08562733, 2258 -0.030519066, 0.0060542435, 0.014653856, -0.038836084, 2259 0.04096551, 0.032249358, -0.08355519, -0.026823482, 2260 0.056386515, -0.010401743, -0.028396193, 0.08507674, 2261 0.014410365, 0.020995233, 0.17040324, 0.11511526, 2262 0.02459721, 0.0066619175, 0.025853224, -0.023133837, 2263 -0.081302024, 0.017264642, -0.009585969, 0.09491168, 2264 -0.051313367, 0.054532815, -0.014298593, 0.10657464, 2265 0.007076659, 0.10964551, 0.0409152, 0.008275321, 2266 -0.07283536, 0.07937492, 0.04192024, -0.1075027}); 2267 2268 lstm.SetRecurrentToCellWeights( 2269 {-0.037322544, 0.018592842, 0.0056175636, -0.06253426, 2270 0.055647098, -0.05713207, -0.05626563, 0.005559383, 2271 0.03375411, -0.025757805, -0.088049285, 0.06017052, 2272 -0.06570978, 0.007384076, 0.035123326, -0.07920549, 2273 0.053676967, 0.044480428, -0.07663568, 0.0071805613, 2274 0.08089997, 0.05143358, 0.038261272, 0.03339287, 2275 -0.027673481, 0.044746667, 0.028349208, 0.020090483, 2276 -0.019443132, -0.030755889, -0.0040000007, 0.04465846, 2277 -0.021585021, 0.0031670958, 0.0053199246, -0.056117613, 2278 -0.10893326, 0.076739706, -0.08509834, -0.027997585, 2279 0.037871376, 0.01449768, -0.09002357, -0.06111149, 2280 -0.046195522, 0.0422062, -0.005683705, -0.1253618, 2281 -0.012925729, -0.04890792, 0.06985068, 0.037654128, 2282 0.03398274, -0.004781977, 0.007032333, -0.031787455, 2283 0.010868644, -0.031489216, 0.09525667, 0.013939797, 2284 0.0058680447, 0.0167067, 0.02668468, -0.04797466, 2285 -0.048885044, -0.12722108, 0.035304096, 0.06554885, 2286 0.00972396, -0.039238118, -0.05159735, -0.11329045, 2287 0.1613692, -0.03750952, 0.06529313, -0.071974665, 2288 -0.11769596, 0.015524369, -0.0013754242, -0.12446318, 2289 0.02786344, -0.014179351, 0.005264273, 0.14376344, 2290 0.015983658, 0.03406988, -0.06939408, 0.040699873, 2291 0.02111075, 0.09669095, 0.041345075, -0.08316494, 2292 -0.07684199, -0.045768797, 0.032298047, -0.041805092, 2293 0.0119405, 0.0061010392, 0.12652606, 0.0064572375, 2294 -0.024950314, 0.11574242, 0.04508852, -0.04335324, 2295 0.06760663, -0.027437469, 0.07216407, 0.06977076, 2296 -0.05438599, 0.034033038, -0.028602652, 0.05346137, 2297 0.043184172, -0.037189785, 0.10420091, 0.00882477, 2298 -0.054019816, -0.074273005, -0.030617684, -0.0028467078, 2299 0.024302477, -0.0038869337, 0.005332455, 0.0013399826, 2300 0.04361412, -0.007001822, 0.09631092, -0.06702025, 2301 -0.042049985, -0.035070654, -0.04103342, -0.10273396, 2302 0.0544271, 0.037184782, -0.13150354, -0.0058036847, 2303 -0.008264958, 0.042035464, 0.05891794, 0.029673764, 2304 0.0063542654, 0.044788733, 0.054816857, 0.062257513, 2305 -0.00093483756, 0.048938446, -0.004952862, -0.007730018, 2306 -0.04043371, -0.017094059, 0.07229206, -0.023670016, 2307 -0.052195564, -0.025616996, -0.01520939, 0.045104615, 2308 -0.007376126, 0.003533447, 0.006570588, 0.056037236, 2309 0.12436656, 0.051817212, 0.028532185, -0.08686856, 2310 0.11868599, 0.07663395, -0.07323171, 0.03463402, 2311 -0.050708205, -0.04458982, -0.11590894, 0.021273347, 2312 0.1251325, -0.15313013, -0.12224372, 0.17228661, 2313 0.023029093, 0.086124025, 0.006445803, -0.03496501, 2314 0.028332196, 0.04449512, -0.042436164, -0.026587414, 2315 -0.006041347, -0.09292539, -0.05678812, 0.03897832, 2316 0.09465633, 0.008115513, -0.02171956, 0.08304309, 2317 0.071401566, 0.019622514, 0.032163795, -0.004167056, 2318 0.02295182, 0.030739572, 0.056506045, 0.004612461, 2319 0.06524936, 0.059999723, 0.046395954, -0.0045512207, 2320 -0.1335546, -0.030136576, 0.11584653, -0.014678886, 2321 0.0020118146, -0.09688814, -0.0790206, 0.039770417, 2322 -0.0329582, 0.07922767, 0.029322514, 0.026405897, 2323 0.04207835, -0.07073373, 0.063781224, 0.0859677, 2324 -0.10925287, -0.07011058, 0.048005477, 0.03438226, 2325 -0.09606514, -0.006669445, -0.043381985, 0.04240257, 2326 -0.06955775, -0.06769346, 0.043903265, -0.026784198, 2327 -0.017840602, 0.024307009, -0.040079936, -0.019946516, 2328 0.045318738, -0.12233574, 0.026170589, 0.0074471775, 2329 0.15978073, 0.10185836, 0.10298046, -0.015476589, 2330 -0.039390966, -0.072174534, 0.0739445, -0.1211869, 2331 -0.0347889, -0.07943156, 0.014809798, -0.12412325, 2332 -0.0030663363, 0.039695457, 0.0647603, -0.08291318, 2333 -0.018529687, -0.004423833, 0.0037507233, 0.084633216, 2334 -0.01514876, -0.056505352, -0.012800942, -0.06994386, 2335 0.012962922, -0.031234352, 0.07029052, 0.016418684, 2336 0.03618972, 0.055686004, -0.08663945, -0.017404709, 2337 -0.054761406, 0.029065743, 0.052404847, 0.020238016, 2338 0.0048197987, -0.0214882, 0.07078733, 0.013016777, 2339 0.06262858, 0.009184685, 0.020785125, -0.043904778, 2340 -0.0270329, -0.03299152, -0.060088247, -0.015162964, 2341 -0.001828936, 0.12642565, -0.056757294, 0.013586685, 2342 0.09232601, -0.035886683, 0.06000002, 0.05229691, 2343 -0.052580316, -0.082029596, -0.010794592, 0.012947712, 2344 -0.036429964, -0.085508935, -0.13127148, -0.017744139, 2345 0.031502828, 0.036232427, -0.031581745, 0.023051167, 2346 -0.05325106, -0.03421577, 0.028793324, -0.034633752, 2347 -0.009881397, -0.043551125, -0.018609839, 0.0019097115, 2348 -0.008799762, 0.056595087, 0.0022273948, 0.055752404}); 2349 2350 lstm.SetRecurrentToOutputWeights({ 2351 0.025825322, -0.05813119, 0.09495884, -0.045984812, -0.01255415, 2352 -0.0026479573, -0.08196161, -0.054914974, -0.0046604523, -0.029587349, 2353 -0.044576716, -0.07480124, -0.082868785, 0.023254942, 0.027502948, 2354 -0.0039728214, -0.08683098, -0.08116779, -0.014675607, -0.037924774, 2355 -0.023314456, -0.007401714, -0.09255757, 0.029460307, -0.08829125, 2356 -0.005139627, -0.08989442, -0.0555066, 0.13596267, -0.025062224, 2357 -0.048351806, -0.03850004, 0.07266485, -0.022414139, 0.05940088, 2358 0.075114764, 0.09597592, -0.010211725, -0.0049794707, -0.011523867, 2359 -0.025980417, 0.072999895, 0.11091378, -0.081685916, 0.014416728, 2360 0.043229222, 0.034178585, -0.07530371, 0.035837382, -0.085607, 2361 -0.007721233, -0.03287832, -0.043848954, -0.06404588, -0.06632928, 2362 -0.073643476, 0.008214239, -0.045984086, 0.039764922, 0.03474462, 2363 0.060612556, -0.080590084, 0.049127717, 0.04151091, -0.030063879, 2364 0.008801774, -0.023021035, -0.019558564, 0.05158114, -0.010947698, 2365 -0.011825728, 0.0075720972, 0.0699727, -0.0039981045, 0.069350146, 2366 0.08799282, 0.016156472, 0.035502106, 0.11695009, 0.006217345, 2367 0.13392477, -0.037875112, 0.025745004, 0.08940699, -0.00924166, 2368 0.0046702605, -0.036598757, -0.08811812, 0.10522024, -0.032441203, 2369 0.008176899, -0.04454919, 0.07058152, 0.0067963637, 0.039206743, 2370 0.03259838, 0.03725492, -0.09515802, 0.013326398, -0.052055415, 2371 -0.025676316, 0.03198509, -0.015951829, -0.058556724, 0.036879618, 2372 0.043357447, 0.028362012, -0.05908629, 0.0059240665, -0.04995891, 2373 -0.019187413, 0.0276265, -0.01628143, 0.0025863599, 0.08800015, 2374 0.035250366, -0.022165963, -0.07328642, -0.009415526, -0.07455109, 2375 0.11690406, 0.0363299, 0.07411125, 0.042103454, -0.009660886, 2376 0.019076364, 0.018299393, -0.046004917, 0.08891175, 0.0431396, 2377 -0.026327137, -0.051502608, 0.08979574, -0.051670972, 0.04940282, 2378 -0.07491107, -0.021240504, 0.022596184, -0.034280192, 0.060163025, 2379 -0.058211457, -0.051837247, -0.01349775, -0.04639988, -0.035936575, 2380 -0.011681591, 0.064818054, 0.0073146066, -0.021745546, -0.043124277, 2381 -0.06471268, -0.07053354, -0.029321948, -0.05330136, 0.016933719, 2382 -0.053782392, 0.13747959, -0.1361751, -0.11569455, 0.0033329215, 2383 0.05693899, -0.053219706, 0.063698, 0.07977434, -0.07924483, 2384 0.06936997, 0.0034815092, -0.007305279, -0.037325785, -0.07251102, 2385 -0.033633437, -0.08677009, 0.091591336, -0.14165086, 0.021752775, 2386 0.019683983, 0.0011612234, -0.058154266, 0.049996935, 0.0288841, 2387 -0.0024567875, -0.14345716, 0.010955264, -0.10234828, 0.1183656, 2388 -0.0010731248, -0.023590032, -0.072285876, -0.0724771, -0.026382286, 2389 -0.0014920527, 0.042667855, 0.0018776858, 0.02986552, 0.009814309, 2390 0.0733756, 0.12289186, 0.018043943, -0.0458958, 0.049412545, 2391 0.033632483, 0.05495232, 0.036686596, -0.013781798, -0.010036754, 2392 0.02576849, -0.08307328, 0.010112348, 0.042521734, -0.05869831, 2393 -0.071689695, 0.03876447, -0.13275425, -0.0352966, -0.023077697, 2394 0.10285965, 0.084736146, 0.15568255, -0.00040734606, 0.027835453, 2395 -0.10292561, -0.032401145, 0.10053256, -0.026142767, -0.08271222, 2396 -0.0030240538, -0.016368777, 0.1070414, 0.042672627, 0.013456989, 2397 -0.0437609, -0.022309763, 0.11576483, 0.04108048, 0.061026827, 2398 -0.0190714, -0.0869359, 0.037901703, 0.0610107, 0.07202949, 2399 0.01675338, 0.086139716, -0.08795751, -0.014898893, -0.023771819, 2400 -0.01965048, 0.007955471, -0.043740474, 0.03346837, -0.10549954, 2401 0.090567775, 0.042013682, -0.03176985, 0.12569028, -0.02421228, 2402 -0.029526481, 0.023851605, 0.031539805, 0.05292009, -0.02344001, 2403 -0.07811758, -0.08834428, 0.10094801, 0.16594367, -0.06861939, 2404 -0.021256343, -0.041093912, -0.06669611, 0.035498552, 0.021757556, 2405 -0.09302526, -0.015403468, -0.06614931, -0.051798206, -0.013874718, 2406 0.03630673, 0.010412845, -0.08077351, 0.046185967, 0.0035662893, 2407 0.03541868, -0.094149634, -0.034814864, 0.003128424, -0.020674974, 2408 -0.03944324, -0.008110165, -0.11113267, 0.08484226, 0.043586485, 2409 0.040582247, 0.0968012, -0.065249965, -0.028036479, 0.0050708856, 2410 0.0017462453, 0.0326779, 0.041296225, 0.09164146, -0.047743853, 2411 -0.015952192, -0.034451712, 0.084197424, -0.05347844, -0.11768019, 2412 0.085926116, -0.08251791, -0.045081906, 0.0948852, 0.068401024, 2413 0.024856757, 0.06978981, -0.057309967, -0.012775832, -0.0032452994, 2414 0.01977615, -0.041040014, -0.024264973, 0.063464895, 0.05431621, 2415 }); 2416 2417 lstm.SetCellToInputWeights( 2418 {0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458, 2419 -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174, 2420 -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047, 2421 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175}); 2422 2423 lstm.SetCellToForgetWeights( 2424 {-0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276, 2425 -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766, 2426 -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774, 2427 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355}); 2428 2429 lstm.SetCellToOutputWeights( 2430 {0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764, 2431 -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544, 2432 -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817, 2433 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733}); 2434 2435 lstm.SetProjectionWeights( 2436 {-0.009802181, 0.09401916, 0.0717386, -0.13895074, 0.09641832, 2437 0.060420845, 0.08539281, 0.054285463, 0.061395317, 0.034448683, 2438 -0.042991187, 0.019801661, -0.16840284, -0.015726732, -0.23041931, 2439 -0.024478018, -0.10959692, -0.013875541, 0.18600968, -0.061274476, 2440 0.0138165, -0.08160894, -0.07661644, 0.032372914, 0.16169067, 2441 0.22465782, -0.03993472, -0.004017731, 0.08633481, -0.28869787, 2442 0.08682067, 0.17240396, 0.014975425, 0.056431185, 0.031037588, 2443 0.16702051, 0.0077946745, 0.15140012, 0.29405436, 0.120285, 2444 -0.188994, -0.027265169, 0.043389652, -0.022061434, 0.014777949, 2445 -0.20203483, 0.094781205, 0.19100232, 0.13987629, -0.036132768, 2446 -0.06426278, -0.05108664, 0.13221376, 0.009441198, -0.16715929, 2447 0.15859416, -0.040437475, 0.050779544, -0.022187516, 0.012166504, 2448 0.027685808, -0.07675938, -0.0055694645, -0.09444123, 0.0046453946, 2449 0.050794356, 0.10770313, -0.20790008, -0.07149004, -0.11425117, 2450 0.008225835, -0.035802525, 0.14374903, 0.15262283, 0.048710253, 2451 0.1847461, -0.007487823, 0.11000021, -0.09542012, 0.22619456, 2452 -0.029149994, 0.08527916, 0.009043713, 0.0042746216, 0.016261552, 2453 0.022461696, 0.12689082, -0.043589946, -0.12035478, -0.08361797, 2454 -0.050666027, -0.1248618, -0.1275799, -0.071875185, 0.07377272, 2455 0.09944291, -0.18897448, -0.1593054, -0.06526116, -0.040107165, 2456 -0.004618631, -0.067624845, -0.007576253, 0.10727444, 0.041546922, 2457 -0.20424393, 0.06907816, 0.050412357, 0.00724631, 0.039827548, 2458 0.12449835, 0.10747581, 0.13708383, 0.09134148, -0.12617786, 2459 -0.06428341, 0.09956831, 0.1208086, -0.14676677, -0.0727722, 2460 0.1126304, 0.010139365, 0.015571211, -0.038128063, 0.022913318, 2461 -0.042050496, 0.16842307, -0.060597885, 0.10531834, -0.06411776, 2462 -0.07451711, -0.03410368, -0.13393489, 0.06534304, 0.003620307, 2463 0.04490757, 0.05970546, 0.05197996, 0.02839995, 0.10434969, 2464 -0.013699693, -0.028353551, -0.07260381, 0.047201227, -0.024575593, 2465 -0.036445823, 0.07155557, 0.009672501, -0.02328883, 0.009533515, 2466 -0.03606021, -0.07421458, -0.028082801, -0.2678904, -0.13221288, 2467 0.18419984, -0.13012612, -0.014588381, -0.035059117, -0.04824723, 2468 0.07830115, -0.056184657, 0.03277091, 0.025466874, 0.14494097, 2469 -0.12522776, -0.098633975, -0.10766018, -0.08317623, 0.08594209, 2470 0.07749552, 0.039474737, 0.1776665, -0.07409566, -0.0477268, 2471 0.29323658, 0.10801441, 0.1154011, 0.013952499, 0.10739139, 2472 0.10708251, -0.051456142, 0.0074137426, -0.10430189, 0.10034707, 2473 0.045594677, 0.0635285, -0.0715442, -0.089667566, -0.10811871, 2474 0.00026344223, 0.08298446, -0.009525053, 0.006585689, -0.24567553, 2475 -0.09450807, 0.09648481, 0.026996298, -0.06419476, -0.04752702, 2476 -0.11063944, -0.23441927, -0.17608605, -0.052156363, 0.067035615, 2477 0.19271925, -0.0032889997, -0.043264326, 0.09663576, -0.057112187, 2478 -0.10100678, 0.0628376, 0.04447668, 0.017961001, -0.10094388, 2479 -0.10190601, 0.18335468, 0.10494553, -0.052095775, -0.0026118709, 2480 0.10539724, -0.04383912, -0.042349473, 0.08438151, -0.1947263, 2481 0.02251204, 0.11216432, -0.10307853, 0.17351969, -0.039091777, 2482 0.08066188, -0.00561982, 0.12633002, 0.11335965, -0.0088127935, 2483 -0.019777594, 0.06864014, -0.059751723, 0.016233567, -0.06894641, 2484 -0.28651384, -0.004228674, 0.019708522, -0.16305895, -0.07468996, 2485 -0.0855457, 0.099339016, -0.07580735, -0.13775392, 0.08434318, 2486 0.08330512, -0.12131499, 0.031935584, 0.09180414, -0.08876437, 2487 -0.08049874, 0.008753825, 0.03498998, 0.030215185, 0.03907079, 2488 0.089751154, 0.029194152, -0.03337423, -0.019092513, 0.04331237, 2489 0.04299654, -0.036394123, -0.12915532, 0.09793732, 0.07512415, 2490 -0.11319543, -0.032502122, 0.15661901, 0.07671967, -0.005491124, 2491 -0.19379048, -0.218606, 0.21448623, 0.017840758, 0.1416943, 2492 -0.07051762, 0.19488361, 0.02664691, -0.18104725, -0.09334311, 2493 0.15026465, -0.15493552, -0.057762887, -0.11604192, -0.262013, 2494 -0.01391798, 0.012185008, 0.11156489, -0.07483202, 0.06693364, 2495 -0.26151478, 0.046425626, 0.036540434, -0.16435726, 0.17338543, 2496 -0.21401681, -0.11385144, -0.08283257, -0.069031075, 0.030635102, 2497 0.010969227, 0.11109743, 0.010919218, 0.027526086, 0.13519906, 2498 0.01891392, -0.046839405, -0.040167913, 0.017953383, -0.09700955, 2499 0.0061885654, -0.07000971, 0.026893595, -0.038844477, 0.14543656}); 2500 2501 static float lstm_input[][20] = { 2502 {// Batch0: 4 (input_sequence_size) * 5 (n_input) 2503 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, 0.596268, 0.998386, 2504 0.568695, 0.864524, 0.571277, 0.073204, 0.296072, 0.743333, 0.069199, 2505 0.045348, 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, 2506 2507 {// Batch1: 4 (input_sequence_size) * 5 (n_input) 2508 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, 0.642421, 0.524260, 2509 0.134799, 0.003639, 0.162482, 0.640394, 0.930399, 0.050782, 0.432485, 2510 0.988078, 0.082922, 0.563329, 0.865614, 0.333232, 0.259916}}; 2511 2512 static float lstm_fw_golden_output[][64] = { 2513 {// Batch0: 4 (input_sequence_size) * 16 (n_output) 2514 -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576, 2515 -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004, 2516 -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147, 2517 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363, 2518 -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322, 2519 -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308, 2520 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794, 2521 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474, 2522 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827, 2523 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512, 2524 -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407, 2525 -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193, 2526 0.0286833, 0.00824207, 0.0264887, 0.0305169}, 2527 {// Batch1: 4 (input_sequence_size) * 16 (n_output) 2528 -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926, 2529 -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232, 2530 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954, 2531 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507, 2532 -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039, 2533 -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233, 2534 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378, 2535 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034, 2536 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789, 2537 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855, 2538 -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679, 2539 -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181, 2540 0.0412031, 0.0118723, 0.0239643, 0.0394009}}; 2541 2542 static float lstm_combined_golden_output[][64] = { 2543 {-0.022014, 0.073544, -0.002235, 0.040068, -0.037136, -0.052788, 2544 0.075325, -0.029378, 0.024298, -0.07733, -0.030674, -0.060229, 2545 0.040599, 0.011608, 0.042005, 0.045977, -0.039225, 0.076294, 2546 0.000735, 0.032852, -0.069869, -0.053312, 0.073527, -0.028136, 2547 0.021585, -0.102679, -0.004327, -0.043304, 0.072861, 0.027077, 2548 0.034558, 0.068292, -0.036292, 0.069832, -0.003032, 0.053829, 2549 -0.043821, -0.072713, 0.085029, -0.040374, 0.020014, -0.104521, 2550 -0.034504, -0.059759, 0.062569, 0.025652, 0.049306, 0.061189, 2551 -0.025146, 0.079643, -0.005188, 0.033080, -0.048079, -0.048082, 2552 0.069369, -0.028900, 0.024572, -0.077547, -0.022517, -0.054477, 2553 0.038857, 0.013336, 0.043234, 0.044788}, 2554 {-0.039186, 0.070792, -0.005913, 0.02642, -0.068274, -0.05022, 2555 0.061444, -0.031241, 0.014996, -0.094544, -0.004146, -0.03464, 2556 0.058981, 0.026097, 0.039781, 0.058408, -0.031887, 0.069252, 2557 0.00576, 0.054062, -0.042801, -0.059974, 0.085272, -0.034453, 2558 0.026097, -0.0959, -0.031164, -0.058699, 0.06839, 0.020512, 2559 0.044727, 0.063609, -0.039863, 0.084819, -0.003909, 0.028666, 2560 -0.075677, -0.045125, 0.070379, -0.033895, 0.022111, -0.097184, 2561 -0.004921, -0.040851, 0.062316, 0.017435, 0.041437, 0.064568, 2562 -0.039656, 0.060726, -0.003402, 0.036854, -0.056503, -0.058554, 2563 0.068588, -0.034879, 0.01352, -0.09962, -0.01434, -0.039505, 2564 0.065133, 0.024321, 0.038473, 0.062438}}; 2565 2566 const int input_sequence_size = lstm.sequence_length() * lstm.num_inputs(); 2567 EXPECT_EQ(input_sequence_size, 20); 2568 float* batch0_start = lstm_input[0]; 2569 float* batch0_end = batch0_start + input_sequence_size; 2570 lstm.SetInput(0, batch0_start, batch0_end); 2571 2572 float* batch1_start = lstm_input[1]; 2573 float* batch1_end = batch1_start + input_sequence_size; 2574 lstm.SetInput(input_sequence_size, batch1_start, batch1_end); 2575 2576 lstm.Invoke(); 2577 2578 const int output_sequence_size = 2579 lstm.sequence_length() * lstm.num_fw_outputs(); 2580 EXPECT_EQ(output_sequence_size, 64); 2581 std::vector<float> expected; 2582 const float* golden_start_batch0 = lstm_fw_golden_output[0]; 2583 const float* golden_end_batch0 = golden_start_batch0 + output_sequence_size; 2584 expected.insert(expected.end(), golden_start_batch0, golden_end_batch0); 2585 2586 const float* golden_start_batch1 = lstm_fw_golden_output[1]; 2587 const float* golden_end_batch1 = golden_start_batch1 + output_sequence_size; 2588 expected.insert(expected.end(), golden_start_batch1, golden_end_batch1); 2589 EXPECT_THAT(lstm.GetFwOutput(), ElementsAreArray(ArrayFloatNear(expected))); 2590 2591 // Check if the sum of forward backward matches the golden. 2592 expected.clear(); 2593 golden_start_batch0 = lstm_combined_golden_output[0]; 2594 golden_end_batch0 = golden_start_batch0 + output_sequence_size; 2595 expected.insert(expected.end(), golden_start_batch0, golden_end_batch0); 2596 2597 golden_start_batch1 = lstm_combined_golden_output[1]; 2598 golden_end_batch1 = golden_start_batch1 + output_sequence_size; 2599 expected.insert(expected.end(), golden_start_batch1, golden_end_batch1); 2600 2601 std::vector<float> combined; 2602 for (int i = 0; i < lstm.GetFwOutput().size(); ++i) { 2603 combined.push_back(lstm.GetFwOutput()[i] + lstm.GetBwOutput()[i]); 2604 } 2605 EXPECT_THAT(combined, ElementsAreArray(ArrayFloatNear(expected))); 2606 } 2607 2608 // Same as the no cifg no peephole no projection no clipping test, but have an 2609 // aux input (without aux input weights), this is the case when stacking but no 2610 // cross-links. 2611 TEST_P(LSTMOpTest, BlackBoxTestWithAuxInput) { 2612 const int n_batch = 1; 2613 const int n_input = 2; 2614 // n_cell and n_output have the same size when there is no projection. 2615 const int n_cell = 4; 2616 const int n_output = 4; 2617 const int sequence_length = 3; 2618 const bool quantize_weights = GetParam(); 2619 2620 BidirectionalLSTMOpModel lstm( 2621 n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, 2622 /*use_peephole=*/false, /*use_projection_weights=*/false, 2623 /*use_projection_bias=*/false, /*merge_outputs=*/false, 2624 /*use_aux_input=*/true, /*cell_clip=*/0.0, 2625 /*proj_clip=*/0.0, quantize_weights, /*time_major=*/true, 2626 { 2627 {sequence_length, n_batch, n_input}, // input tensor 2628 2629 // Forward cell 2630 {n_cell, n_input}, // input_to_input_weight tensor 2631 {n_cell, n_input}, // input_to_forget_weight tensor 2632 {n_cell, n_input}, // input_to_cell_weight tensor 2633 {n_cell, n_input}, // input_to_output_weight tensor 2634 2635 {n_cell, n_output}, // recurrent_to_input_weight tensor 2636 {n_cell, n_output}, // recurrent_to_forget_weight tensor 2637 {n_cell, n_output}, // recurrent_to_cell_weight tensor 2638 {n_cell, n_output}, // recurrent_to_output_weight tensor 2639 2640 {0}, // cell_to_input_weight tensor 2641 {0}, // cell_to_forget_weight tensor 2642 {0}, // cell_to_output_weight tensor 2643 2644 {n_cell}, // input_gate_bias tensor 2645 {n_cell}, // forget_gate_bias tensor 2646 {n_cell}, // cell_bias tensor 2647 {n_cell}, // output_gate_bias tensor 2648 2649 {0, 0}, // projection_weight tensor 2650 {0}, // projection_bias tensor 2651 2652 // Backward cell 2653 {n_cell, n_input}, // input_to_input_weight tensor 2654 {n_cell, n_input}, // input_to_forget_weight tensor 2655 {n_cell, n_input}, // input_to_cell_weight tensor 2656 {n_cell, n_input}, // input_to_output_weight tensor 2657 2658 {n_cell, n_output}, // recurrent_to_input_weight tensor 2659 {n_cell, n_output}, // recurrent_to_forget_weight tensor 2660 {n_cell, n_output}, // recurrent_to_cell_weight tensor 2661 {n_cell, n_output}, // recurrent_to_output_weight tensor 2662 2663 {0}, // cell_to_input_weight tensor 2664 {0}, // cell_to_forget_weight tensor 2665 {0}, // cell_to_output_weight tensor 2666 2667 {n_cell}, // input_gate_bias tensor 2668 {n_cell}, // forget_gate_bias tensor 2669 {n_cell}, // cell_bias tensor 2670 {n_cell}, // output_gate_bias tensor 2671 2672 {0, 0}, // projection_weight tensor 2673 {0}, // projection_bias tensor 2674 2675 {n_batch, n_output}, // activation_state tensor 2676 {n_batch, n_cell}, // cell_state tensor 2677 2678 {n_batch, n_output}, // activation_state tensor 2679 {n_batch, n_cell}, // cell_state tensor 2680 2681 // TODO(b/121134029): Update tests so tensor shapes after state tensor 2682 // are used. They are currently ignored by test_util. 2683 {sequence_length, n_batch, n_input}, // aux_input tensor 2684 {n_cell, 0}, // aux_fw_input_to_input tensor 2685 {n_cell, 0}, // aux_fw_input_to_forget tensor 2686 {n_cell, 0}, // aux_fw_input_to_cell tensor 2687 {n_cell, 0}, // aux_fw_input_to_output tensor 2688 {n_cell, 0}, // aux_bw_input_to_input tensor 2689 {n_cell, 0}, // aux_bw_input_to_forget tensor 2690 {n_cell, 0}, // aux_bw_input_to_cell tensor 2691 {n_cell, 0}, // aux_bw_input_to_output tensor 2692 }); 2693 2694 lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, 2695 -0.34550029, 0.04266912, -0.15680569, 2696 -0.34856534, 0.43890524}); 2697 2698 lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163, 2699 -0.20583314, 0.44344562, 0.22077113, 2700 -0.29909778}); 2701 2702 lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935, 2703 -0.31343272, -0.40032279, 0.44781327, 2704 0.01387155, -0.35593212}); 2705 2706 lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829, 2707 0.40525138, 0.44272184, 0.03897077, -0.1556896, 2708 0.19487578}); 2709 2710 lstm.SetInputGateBias({0., 0., 0., 0.}); 2711 2712 lstm.SetCellBias({0., 0., 0., 0.}); 2713 2714 lstm.SetForgetGateBias({1., 1., 1., 1.}); 2715 2716 lstm.SetOutputGateBias({0., 0., 0., 0.}); 2717 2718 lstm.SetRecurrentToInputWeights( 2719 {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324, 2720 -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322, 2721 -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296}); 2722 2723 lstm.SetRecurrentToCellWeights( 2724 {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841, 2725 -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659, 2726 -0.46367589, 0.26016325, -0.03894562, -0.16368064}); 2727 2728 lstm.SetRecurrentToForgetWeights( 2729 {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892, 2730 -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436, 2731 0.28053468, 0.01560611, -0.20127171, -0.01140004}); 2732 2733 lstm.SetRecurrentToOutputWeights( 2734 {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793, 2735 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421, 2736 -0.51818722, -0.15390486, 0.0468148, 0.39922136}); 2737 2738 // Input should have n_input * sequence_length many values. 2739 static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; 2740 static float lstm_fw_golden_output[] = { 2741 -0.02973187, 0.1229473, 0.20885126, -0.15358765, 2742 -0.03716109, 0.12507336, 0.41193449, -0.20860538, 2743 -0.15053082, 0.09120187, 0.24278517, -0.12222792}; 2744 static float lstm_bw_golden_output[] = { 2745 -0.0806187, 0.139077, 0.400476, -0.197842, -0.0332076, 0.123838, 2746 0.309777, -0.17621, -0.0490733, 0.0739237, 0.067706, -0.0208124}; 2747 2748 float* batch0_start = lstm_input; 2749 float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length(); 2750 2751 lstm.SetInput(0, batch0_start, batch0_end); 2752 // Aux input and input are the same, so we should observe the same outputs 2753 // as there's no aux input. 2754 lstm.SetAuxInput(0, batch0_start, batch0_end); 2755 2756 lstm.Invoke(); 2757 2758 float* fw_golden_start = lstm_fw_golden_output; 2759 float* fw_golden_end = 2760 fw_golden_start + lstm.num_fw_outputs() * lstm.sequence_length(); 2761 std::vector<float> fw_expected; 2762 fw_expected.insert(fw_expected.end(), fw_golden_start, fw_golden_end); 2763 EXPECT_THAT(lstm.GetFwOutput(), 2764 ElementsAreArray( 2765 ArrayFloatNear(fw_expected, quantize_weights ? 1e-2 : 1e-5))); 2766 2767 float* bw_golden_start = lstm_bw_golden_output; 2768 float* bw_golden_end = 2769 bw_golden_start + lstm.num_bw_outputs() * lstm.sequence_length(); 2770 std::vector<float> bw_expected; 2771 bw_expected.insert(bw_expected.end(), bw_golden_start, bw_golden_end); 2772 EXPECT_THAT(lstm.GetBwOutput(), 2773 ElementsAreArray( 2774 ArrayFloatNear(bw_expected, quantize_weights ? 1e-2 : 1e-5))); 2775 } 2776 2777 } // namespace 2778 } // namespace tflite 2779 2780 int main(int argc, char** argv) { 2781 ::tflite::LogToStderr(); 2782 ::testing::InitGoogleTest(&argc, argv); 2783 return RUN_ALL_TESTS(); 2784 } 2785