Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
// Unit tests for the TFLite unidirectional sequence LSTM op.
     16 
     17 #include <iomanip>
     18 #include <memory>
     19 #include <vector>
     20 
     21 #include <gmock/gmock.h>
     22 #include <gtest/gtest.h>
     23 #include "tensorflow/contrib/lite/interpreter.h"
     24 #include "tensorflow/contrib/lite/kernels/register.h"
     25 #include "tensorflow/contrib/lite/kernels/test_util.h"
     26 #include "tensorflow/contrib/lite/model.h"
     27 
     28 namespace tflite {
     29 namespace {
     30 
     31 using ::testing::ElementsAreArray;
     32 
     33 class UnidirectionalLSTMOpModel : public SingleOpModel {
     34  public:
     35   UnidirectionalLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
     36                             int sequence_length, bool use_cifg,
     37                             bool use_peephole, bool use_projection_weights,
     38                             bool use_projection_bias, float cell_clip,
     39                             float proj_clip,
     40                             const std::vector<std::vector<int>>& input_shapes)
     41       : n_batch_(n_batch),
     42         n_input_(n_input),
     43         n_cell_(n_cell),
     44         n_output_(n_output),
     45         sequence_length_(sequence_length) {
     46     input_ = AddInput(TensorType_FLOAT32);
     47 
     48     if (use_cifg) {
     49       input_to_input_weights_ = AddNullInput();
     50     } else {
     51       input_to_input_weights_ = AddInput(TensorType_FLOAT32);
     52     }
     53 
     54     input_to_forget_weights_ = AddInput(TensorType_FLOAT32);
     55     input_to_cell_weights_ = AddInput(TensorType_FLOAT32);
     56     input_to_output_weights_ = AddInput(TensorType_FLOAT32);
     57 
     58     if (use_cifg) {
     59       recurrent_to_input_weights_ = AddNullInput();
     60     } else {
     61       recurrent_to_input_weights_ = AddInput(TensorType_FLOAT32);
     62     }
     63 
     64     recurrent_to_forget_weights_ = AddInput(TensorType_FLOAT32);
     65     recurrent_to_cell_weights_ = AddInput(TensorType_FLOAT32);
     66     recurrent_to_output_weights_ = AddInput(TensorType_FLOAT32);
     67 
     68     if (use_peephole) {
     69       if (use_cifg) {
     70         cell_to_input_weights_ = AddNullInput();
     71       } else {
     72         cell_to_input_weights_ = AddInput(TensorType_FLOAT32);
     73       }
     74       cell_to_forget_weights_ = AddInput(TensorType_FLOAT32);
     75       cell_to_output_weights_ = AddInput(TensorType_FLOAT32);
     76     } else {
     77       cell_to_input_weights_ = AddNullInput();
     78       cell_to_forget_weights_ = AddNullInput();
     79       cell_to_output_weights_ = AddNullInput();
     80     }
     81 
     82     if (use_cifg) {
     83       input_gate_bias_ = AddNullInput();
     84     } else {
     85       input_gate_bias_ = AddInput(TensorType_FLOAT32);
     86     }
     87     forget_gate_bias_ = AddInput(TensorType_FLOAT32);
     88     cell_bias_ = AddInput(TensorType_FLOAT32);
     89     output_gate_bias_ = AddInput(TensorType_FLOAT32);
     90 
     91     if (use_projection_weights) {
     92       projection_weights_ = AddInput(TensorType_FLOAT32);
     93       if (use_projection_bias) {
     94         projection_bias_ = AddInput(TensorType_FLOAT32);
     95       } else {
     96         projection_bias_ = AddNullInput();
     97       }
     98     } else {
     99       projection_weights_ = AddNullInput();
    100       projection_bias_ = AddNullInput();
    101     }
    102 
    103     scratch_buffer_ = AddOutput(TensorType_FLOAT32);
    104     // TODO(ghodrat): Modify these states when we have a permanent solution for
    105     // persistent buffer.
    106     output_state_ = AddOutput(TensorType_FLOAT32);
    107     cell_state_ = AddOutput(TensorType_FLOAT32);
    108     output_ = AddOutput(TensorType_FLOAT32);
    109 
    110     SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
    111                  BuiltinOptions_LSTMOptions,
    112                  CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
    113                                    cell_clip, proj_clip)
    114                      .Union());
    115     BuildInterpreter(input_shapes);
    116   }
    117 
    118   void SetInputToInputWeights(std::initializer_list<float> f) {
    119     PopulateTensor(input_to_input_weights_, f);
    120   }
    121 
    122   void SetInputToForgetWeights(std::initializer_list<float> f) {
    123     PopulateTensor(input_to_forget_weights_, f);
    124   }
    125 
    126   void SetInputToCellWeights(std::initializer_list<float> f) {
    127     PopulateTensor(input_to_cell_weights_, f);
    128   }
    129 
    130   void SetInputToOutputWeights(std::initializer_list<float> f) {
    131     PopulateTensor(input_to_output_weights_, f);
    132   }
    133 
    134   void SetRecurrentToInputWeights(std::initializer_list<float> f) {
    135     PopulateTensor(recurrent_to_input_weights_, f);
    136   }
    137 
    138   void SetRecurrentToForgetWeights(std::initializer_list<float> f) {
    139     PopulateTensor(recurrent_to_forget_weights_, f);
    140   }
    141 
    142   void SetRecurrentToCellWeights(std::initializer_list<float> f) {
    143     PopulateTensor(recurrent_to_cell_weights_, f);
    144   }
    145 
    146   void SetRecurrentToOutputWeights(std::initializer_list<float> f) {
    147     PopulateTensor(recurrent_to_output_weights_, f);
    148   }
    149 
    150   void SetCellToInputWeights(std::initializer_list<float> f) {
    151     PopulateTensor(cell_to_input_weights_, f);
    152   }
    153 
    154   void SetCellToForgetWeights(std::initializer_list<float> f) {
    155     PopulateTensor(cell_to_forget_weights_, f);
    156   }
    157 
    158   void SetCellToOutputWeights(std::initializer_list<float> f) {
    159     PopulateTensor(cell_to_output_weights_, f);
    160   }
    161 
    162   void SetInputGateBias(std::initializer_list<float> f) {
    163     PopulateTensor(input_gate_bias_, f);
    164   }
    165 
    166   void SetForgetGateBias(std::initializer_list<float> f) {
    167     PopulateTensor(forget_gate_bias_, f);
    168   }
    169 
    170   void SetCellBias(std::initializer_list<float> f) {
    171     PopulateTensor(cell_bias_, f);
    172   }
    173 
    174   void SetOutputGateBias(std::initializer_list<float> f) {
    175     PopulateTensor(output_gate_bias_, f);
    176   }
    177 
    178   void SetProjectionWeights(std::initializer_list<float> f) {
    179     PopulateTensor(projection_weights_, f);
    180   }
    181 
    182   void SetProjectionBias(std::initializer_list<float> f) {
    183     PopulateTensor(projection_bias_, f);
    184   }
    185 
    186   void ResetOutputState() {
    187     const int zero_buffer_size = n_cell_ * n_batch_;
    188     std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
    189     memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
    190     PopulateTensor(output_state_, 0, zero_buffer.get(),
    191                    zero_buffer.get() + zero_buffer_size);
    192   }
    193 
    194   void ResetCellState() {
    195     const int zero_buffer_size = n_cell_ * n_batch_;
    196     std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
    197     memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
    198     PopulateTensor(cell_state_, 0, zero_buffer.get(),
    199                    zero_buffer.get() + zero_buffer_size);
    200   }
    201 
    202   void SetInput(int offset, float* begin, float* end) {
    203     PopulateTensor(input_, offset, begin, end);
    204   }
    205 
    206   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
    207 
    208   int num_inputs() { return n_input_; }
    209   int num_outputs() { return n_output_; }
    210   int num_cells() { return n_cell_; }
    211   int num_batches() { return n_batch_; }
    212   int sequence_length() { return sequence_length_; }
    213 
    214  private:
    215   int input_;
    216   int input_to_input_weights_;
    217   int input_to_forget_weights_;
    218   int input_to_cell_weights_;
    219   int input_to_output_weights_;
    220 
    221   int recurrent_to_input_weights_;
    222   int recurrent_to_forget_weights_;
    223   int recurrent_to_cell_weights_;
    224   int recurrent_to_output_weights_;
    225 
    226   int cell_to_input_weights_;
    227   int cell_to_forget_weights_;
    228   int cell_to_output_weights_;
    229 
    230   int input_gate_bias_;
    231   int forget_gate_bias_;
    232   int cell_bias_;
    233   int output_gate_bias_;
    234 
    235   int projection_weights_;
    236   int projection_bias_;
    237 
    238   int output_;
    239   int output_state_;
    240   int cell_state_;
    241   int scratch_buffer_;
    242 
    243   int n_batch_;
    244   int n_input_;
    245   int n_cell_;
    246   int n_output_;
    247   int sequence_length_;
    248 };
    249 
    250 TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
    251   const int n_batch = 1;
    252   const int n_input = 2;
    253   // n_cell and n_output have the same size when there is no projection.
    254   const int n_cell = 4;
    255   const int n_output = 4;
    256   const int sequence_length = 3;
    257 
    258   UnidirectionalLSTMOpModel lstm(
    259       n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
    260       /*use_peephole=*/false, /*use_projection_weights=*/false,
    261       /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
    262       {
    263           {sequence_length, n_batch, n_input},  // input tensor
    264 
    265           {n_cell, n_input},  // input_to_input_weight tensor
    266           {n_cell, n_input},  // input_to_forget_weight tensor
    267           {n_cell, n_input},  // input_to_cell_weight tensor
    268           {n_cell, n_input},  // input_to_output_weight tensor
    269 
    270           {n_cell, n_output},  // recurrent_to_input_weight tensor
    271           {n_cell, n_output},  // recurrent_to_forget_weight tensor
    272           {n_cell, n_output},  // recurrent_to_cell_weight tensor
    273           {n_cell, n_output},  // recurrent_to_output_weight tensor
    274 
    275           {0},  // cell_to_input_weight tensor
    276           {0},  // cell_to_forget_weight tensor
    277           {0},  // cell_to_output_weight tensor
    278 
    279           {n_cell},  // input_gate_bias tensor
    280           {n_cell},  // forget_gate_bias tensor
    281           {n_cell},  // cell_bias tensor
    282           {n_cell},  // output_gate_bias tensor
    283 
    284           {0, 0},  // projection_weight tensor
    285           {0},     // projection_bias tensor
    286       });
    287 
    288   lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
    289                                -0.34550029, 0.04266912, -0.15680569,
    290                                -0.34856534, 0.43890524});
    291 
    292   lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163,
    293                               -0.20583314, 0.44344562, 0.22077113,
    294                               -0.29909778});
    295 
    296   lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935,
    297                                 -0.31343272, -0.40032279, 0.44781327,
    298                                 0.01387155, -0.35593212});
    299 
    300   lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829,
    301                                 0.40525138, 0.44272184, 0.03897077, -0.1556896,
    302                                 0.19487578});
    303 
    304   lstm.SetInputGateBias({0., 0., 0., 0.});
    305 
    306   lstm.SetCellBias({0., 0., 0., 0.});
    307 
    308   lstm.SetForgetGateBias({1., 1., 1., 1.});
    309 
    310   lstm.SetOutputGateBias({0., 0., 0., 0.});
    311 
    312   lstm.SetRecurrentToInputWeights(
    313       {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324,
    314        -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322,
    315        -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296});
    316 
    317   lstm.SetRecurrentToCellWeights(
    318       {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841,
    319        -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659,
    320        -0.46367589, 0.26016325, -0.03894562, -0.16368064});
    321 
    322   lstm.SetRecurrentToForgetWeights(
    323       {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892,
    324        -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436,
    325        0.28053468, 0.01560611, -0.20127171, -0.01140004});
    326 
    327   lstm.SetRecurrentToOutputWeights(
    328       {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793,
    329        0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421,
    330        -0.51818722, -0.15390486, 0.0468148, 0.39922136});
    331 
    332   // Input should have n_input * sequence_length many values.
    333   static float lstm_input[] = {2., 3., 3., 4., 1., 1.};
    334   static float lstm_golden_output[] = {-0.02973187, 0.1229473,   0.20885126,
    335                                        -0.15358765, -0.03716109, 0.12507336,
    336                                        0.41193449,  -0.20860538, -0.15053082,
    337                                        0.09120187,  0.24278517,  -0.12222792};
    338 
    339   // Resetting cell_state and output_state
    340   lstm.ResetCellState();
    341   lstm.ResetOutputState();
    342 
    343   float* batch0_start = lstm_input;
    344   float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
    345 
    346   lstm.SetInput(0, batch0_start, batch0_end);
    347 
    348   lstm.Invoke();
    349 
    350   float* golden_start = lstm_golden_output;
    351   float* golden_end =
    352       golden_start + lstm.num_outputs() * lstm.sequence_length();
    353   std::vector<float> expected;
    354   expected.insert(expected.end(), golden_start, golden_end);
    355   EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
    356 }
    357 
    358 TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
    359   const int n_batch = 1;
    360   const int n_input = 2;
    361   // n_cell and n_output have the same size when there is no projection.
    362   const int n_cell = 4;
    363   const int n_output = 4;
    364   const int sequence_length = 3;
    365 
    366   UnidirectionalLSTMOpModel lstm(
    367       n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true,
    368       /*use_peephole=*/true, /*use_projection_weights=*/false,
    369       /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
    370       {
    371           {sequence_length, n_batch, n_input},  // input tensor
    372 
    373           {0, 0},             // input_to_input_weight tensor
    374           {n_cell, n_input},  // input_to_forget_weight tensor
    375           {n_cell, n_input},  // input_to_cell_weight tensor
    376           {n_cell, n_input},  // input_to_output_weight tensor
    377 
    378           {0, 0},              // recurrent_to_input_weight tensor
    379           {n_cell, n_output},  // recurrent_to_forget_weight tensor
    380           {n_cell, n_output},  // recurrent_to_cell_weight tensor
    381           {n_cell, n_output},  // recurrent_to_output_weight tensor
    382 
    383           {0},       // cell_to_input_weight tensor
    384           {n_cell},  // cell_to_forget_weight tensor
    385           {n_cell},  // cell_to_output_weight tensor
    386 
    387           {0},       // input_gate_bias tensor
    388           {n_cell},  // forget_gate_bias tensor
    389           {n_cell},  // cell_bias tensor
    390           {n_cell},  // output_gate_bias tensor
    391 
    392           {0, 0},  // projection_weight tensor
    393           {0},     // projection_bias tensor
    394       });
    395 
    396   lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781,
    397                               0.04717243, 0.48944736, -0.38535351,
    398                               -0.17212132});
    399 
    400   lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988,
    401                                 -0.3633365, -0.22755712, 0.28253698, 0.24407166,
    402                                 0.33826375});
    403 
    404   lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593,
    405                                 -0.09426838, -0.44257352, 0.54939759,
    406                                 0.01533556, 0.42751634});
    407 
    408   lstm.SetCellBias({0., 0., 0., 0.});
    409 
    410   lstm.SetForgetGateBias({1., 1., 1., 1.});
    411 
    412   lstm.SetOutputGateBias({0., 0., 0., 0.});
    413 
    414   lstm.SetRecurrentToCellWeights(
    415       {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711,
    416        0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004,
    417        0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288,
    418        0.21193194});
    419 
    420   lstm.SetRecurrentToForgetWeights(
    421       {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827,
    422        0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795,
    423        -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349});
    424 
    425   lstm.SetRecurrentToOutputWeights(
    426       {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908,
    427        -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835,
    428        0.50248802, 0.26114327, -0.43736315, 0.33149987});
    429 
    430   lstm.SetCellToForgetWeights(
    431       {0.47485286, -0.51955009, -0.24458408, 0.31544167});
    432   lstm.SetCellToOutputWeights(
    433       {-0.17135078, 0.82760304, 0.85573703, -0.77109635});
    434 
    435   static float lstm_input[] = {2., 3., 3., 4., 1., 1.};
    436   static float lstm_golden_output[] = {-0.36444446, -0.00352185, 0.12886585,
    437                                        -0.05163646, -0.42312205, -0.01218222,
    438                                        0.24201041,  -0.08124574, -0.358325,
    439                                        -0.04621704, 0.21641694,  -0.06471302};
    440 
    441   // Resetting cell_state and output_state
    442   lstm.ResetCellState();
    443   lstm.ResetOutputState();
    444 
    445   float* batch0_start = lstm_input;
    446   float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
    447 
    448   lstm.SetInput(0, batch0_start, batch0_end);
    449 
    450   lstm.Invoke();
    451 
    452   float* golden_start = lstm_golden_output;
    453   float* golden_end =
    454       golden_start + lstm.num_outputs() * lstm.sequence_length();
    455   std::vector<float> expected;
    456   expected.insert(expected.end(), golden_start, golden_end);
    457   EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
    458 }
    459 
    460 TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
    461   const int n_batch = 2;
    462   const int n_input = 5;
    463   const int n_cell = 20;
    464   const int n_output = 16;
    465   const int sequence_length = 4;
    466 
    467   UnidirectionalLSTMOpModel lstm(
    468       n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
    469       /*use_peephole=*/true, /*use_projection_weights=*/true,
    470       /*use_projection_bias=*/false,
    471       /*cell_clip=*/0.0, /*proj_clip=*/0.0,
    472       {
    473           {sequence_length, n_batch, n_input},  // input tensor
    474 
    475           {n_cell, n_input},  // input_to_input_weight tensor
    476           {n_cell, n_input},  // input_to_forget_weight tensor
    477           {n_cell, n_input},  // input_to_cell_weight tensor
    478           {n_cell, n_input},  // input_to_output_weight tensor
    479 
    480           {n_cell, n_output},  // recurrent_to_input_weight tensor
    481           {n_cell, n_output},  // recurrent_to_forget_weight tensor
    482           {n_cell, n_output},  // recurrent_to_cell_weight tensor
    483           {n_cell, n_output},  // recurrent_to_output_weight tensor
    484 
    485           {n_cell},  // cell_to_input_weight tensor
    486           {n_cell},  // cell_to_forget_weight tensor
    487           {n_cell},  // cell_to_output_weight tensor
    488 
    489           {n_cell},  // input_gate_bias tensor
    490           {n_cell},  // forget_gate_bias tensor
    491           {n_cell},  // cell_bias tensor
    492           {n_cell},  // output_gate_bias tensor
    493 
    494           {n_output, n_cell},  // projection_weight tensor
    495           {0},                 // projection_bias tensor
    496       });
    497 
    498   lstm.SetInputToInputWeights(
    499       {0.021393683,  0.06124551,    0.046905167,  -0.014657677,  -0.03149463,
    500        0.09171803,   0.14647801,    0.10797193,   -0.0057968358, 0.0019193048,
    501        -0.2726754,   0.10154029,    -0.018539885, 0.080349885,   -0.10262385,
    502        -0.022599787, -0.09121155,   -0.008675967, -0.045206103,  -0.0821282,
    503        -0.008045952, 0.015478081,   0.055217247,  0.038719587,   0.044153627,
    504        -0.06453243,  0.05031825,    -0.046935108, -0.008164439,  0.014574226,
    505        -0.1671009,   -0.15519552,   -0.16819797,  -0.13971269,   -0.11953059,
    506        0.25005487,   -0.22790983,   0.009855087,  -0.028140958,  -0.11200698,
    507        0.11295408,   -0.0035217577, 0.054485075,  0.05184695,    0.064711206,
    508        0.10989193,   0.11674786,    0.03490607,   0.07727357,    0.11390585,
    509        -0.1863375,   -0.1034451,    -0.13945189,  -0.049401227,  -0.18767063,
    510        0.042483903,  0.14233552,    0.13832581,   0.18350165,    0.14545603,
    511        -0.028545704, 0.024939531,   0.050929718,  0.0076203286,  -0.0029723682,
    512        -0.042484224, -0.11827596,   -0.09171104,  -0.10808628,   -0.16327988,
    513        -0.2273378,   -0.0993647,    -0.017155107, 0.0023917493,  0.049272764,
    514        0.0038534778, 0.054764505,   0.089753784,  0.06947234,    0.08014476,
    515        -0.04544234,  -0.0497073,    -0.07135631,  -0.048929106,  -0.004042012,
    516        -0.009284026, 0.018042054,   0.0036860977, -0.07427302,   -0.11434604,
    517        -0.018995456, 0.031487543,   0.012834908,  0.019977754,   0.044256654,
    518        -0.39292613,  -0.18519334,   -0.11651281,  -0.06809892,   0.011373677});
    519 
    520   lstm.SetInputToForgetWeights(
    521       {-0.0018401089, -0.004852237,  0.03698424,   0.014181704,   0.028273236,
    522        -0.016726194,  -0.05249759,   -0.10204261,  0.00861066,    -0.040979505,
    523        -0.009899187,  0.01923892,    -0.028177269, -0.08535103,   -0.14585495,
    524        0.10662567,    -0.01909731,   -0.017883534, -0.0047269356, -0.045103323,
    525        0.0030784295,  0.076784775,   0.07463696,   0.094531395,   0.0814421,
    526        -0.12257899,   -0.033945758,  -0.031303465, 0.045630626,   0.06843887,
    527        -0.13492945,   -0.012480007,  -0.0811829,   -0.07224499,   -0.09628791,
    528        0.045100946,   0.0012300825,  0.013964662,  0.099372394,   0.02543059,
    529        0.06958324,    0.034257296,   0.0482646,    0.06267997,    0.052625068,
    530        0.12784666,    0.07077897,    0.025725935,  0.04165009,    0.07241905,
    531        0.018668644,   -0.037377294,  -0.06277783,  -0.08833636,   -0.040120605,
    532        -0.011405586,  -0.007808335,  -0.010301386, -0.005102167,  0.027717464,
    533        0.05483423,    0.11449111,    0.11289652,   0.10939839,    0.13396506,
    534        -0.08402166,   -0.01901462,   -0.044678304, -0.07720565,   0.014350063,
    535        -0.11757958,   -0.0652038,    -0.08185733,  -0.076754324,  -0.092614375,
    536        0.10405491,    0.052960336,   0.035755895,  0.035839386,   -0.012540553,
    537        0.036881298,   0.02913376,    0.03420159,   0.05448447,    -0.054523353,
    538        0.02582715,    0.02327355,    -0.011857179, -0.0011980024, -0.034641717,
    539        -0.026125094,  -0.17582615,   -0.15923657,  -0.27486774,   -0.0006143371,
    540        0.0001771948,  -8.470171e-05, 0.02651807,   0.045790765,   0.06956496});
    541 
    542   lstm.SetInputToCellWeights(
    543       {-0.04580283,   -0.09549462,   -0.032418985,  -0.06454633,
    544        -0.043528453,  0.043018587,   -0.049152344,  -0.12418144,
    545        -0.078985475,  -0.07596889,   0.019484362,   -0.11434962,
    546        -0.0074034138, -0.06314844,   -0.092981495,  0.0062155537,
    547        -0.025034338,  -0.0028890965, 0.048929527,   0.06235075,
    548        0.10665918,    -0.032036792,  -0.08505916,   -0.10843358,
    549        -0.13002433,   -0.036816437,  -0.02130134,   -0.016518239,
    550        0.0047691227,  -0.0025825808, 0.066017866,   0.029991534,
    551        -0.10652836,   -0.1037554,    -0.13056071,   -0.03266643,
    552        -0.033702414,  -0.006473424,  -0.04611692,   0.014419339,
    553        -0.025174323,  0.0396852,     0.081777506,   0.06157468,
    554        0.10210095,    -0.009658194,  0.046511717,   0.03603906,
    555        0.0069369148,  0.015960095,   -0.06507666,   0.09551598,
    556        0.053568836,   0.06408714,    0.12835667,    -0.008714329,
    557        -0.20211966,   -0.12093674,   0.029450472,   0.2849013,
    558        -0.029227901,  0.1164364,     -0.08560263,   0.09941786,
    559        -0.036999565,  -0.028842626,  -0.0033637602, -0.017012902,
    560        -0.09720865,   -0.11193351,   -0.029155117,  -0.017936034,
    561        -0.009768936,  -0.04223324,   -0.036159635,  0.06505112,
    562        -0.021742892,  -0.023377212,  -0.07221364,   -0.06430552,
    563        0.05453865,    0.091149814,   0.06387331,    0.007518393,
    564        0.055960953,   0.069779344,   0.046411168,   0.10509911,
    565        0.07463894,    0.0075130584,  0.012850982,   0.04555431,
    566        0.056955688,   0.06555285,    0.050801456,   -0.009862683,
    567        0.00826772,    -0.026555609,  -0.0073611983, -0.0014897042});
    568 
    569   lstm.SetInputToOutputWeights(
    570       {-0.0998932,   -0.07201956,  -0.052803773,  -0.15629593,  -0.15001918,
    571        -0.07650751,  0.02359855,   -0.075155355,  -0.08037709,  -0.15093534,
    572        0.029517552,  -0.04751393,  0.010350531,   -0.02664851,  -0.016839722,
    573        -0.023121163, 0.0077019283, 0.012851257,   -0.05040649,  -0.0129761,
    574        -0.021737747, -0.038305793, -0.06870586,   -0.01481247,  -0.001285394,
    575        0.10124236,   0.083122835,  0.053313006,   -0.062235646, -0.075637154,
    576        -0.027833903, 0.029774971,  0.1130802,     0.09218906,   0.09506135,
    577        -0.086665764, -0.037162706, -0.038880914,  -0.035832845, -0.014481564,
    578        -0.09825003,  -0.12048569,  -0.097665586,  -0.05287633,  -0.0964047,
    579        -0.11366429,  0.035777505,  0.13568819,    0.052451383,  0.050649304,
    580        0.05798951,   -0.021852335, -0.099848844,  0.014740475,  -0.078897946,
    581        0.04974699,   0.014160473,  0.06973932,    0.04964942,   0.033364646,
    582        0.08190124,   0.025535367,  0.050893165,   0.048514254,  0.06945813,
    583        -0.078907564, -0.06707616,  -0.11844508,   -0.09986688,  -0.07509403,
    584        0.06263226,   0.14925587,   0.20188436,    0.12098451,   0.14639415,
    585        0.0015017595, -0.014267382, -0.03417257,   0.012711468,  0.0028300495,
    586        -0.024758482, -0.05098548,  -0.0821182,    0.014225672,  0.021544158,
    587        0.08949725,   0.07505268,   -0.0020780868, 0.04908258,   0.06476295,
    588        -0.022907063, 0.027562456,  0.040185735,   0.019567577,  -0.015598739,
    589        -0.049097303, -0.017121866, -0.083368234,  -0.02332002,  -0.0840956});
    590 
    591   lstm.SetInputGateBias(
    592       {0.02234832,  0.14757581,   0.18176508,  0.10380666,  0.053110216,
    593        -0.06928846, -0.13942584,  -0.11816189, 0.19483899,  0.03652339,
    594        -0.10250295, 0.036714908,  -0.18426876, 0.036065217, 0.21810818,
    595        0.02383196,  -0.043370757, 0.08690144,  -0.04444982, 0.00030581196});
    596 
    597   lstm.SetForgetGateBias({0.035185695, -0.042891346, -0.03032477, 0.23027696,
    598                           0.11098921,  0.15378423,   0.09263801,  0.09790885,
    599                           0.09508917,  0.061199076,  0.07665568,  -0.015443159,
    600                           -0.03499149, 0.046190713,  0.08895977,  0.10899629,
    601                           0.40694186,  0.06030037,   0.012413437, -0.06108739});
    602 
    603   lstm.SetCellBias({-0.024379363, 0.0055531194, 0.23377132,   0.033463873,
    604                     -0.1483596,   -0.10639995,  -0.091433935, 0.058573797,
    605                     -0.06809782,  -0.07889636,  -0.043246906, -0.09829136,
    606                     -0.4279842,   0.034901652,  0.18797937,   0.0075234566,
    607                     0.016178843,  0.1749513,    0.13975595,   0.92058027});
    608 
    609   lstm.SetOutputGateBias(
    610       {0.046159424,  -0.0012809046, 0.03563469,   0.12648113, 0.027195795,
    611        0.35373217,   -0.018957434,  0.008907322,  -0.0762701, 0.12018895,
    612        0.04216877,   0.0022856654,  0.040952638,  0.3147856,  0.08225149,
    613        -0.057416286, -0.14995944,   -0.008040261, 0.13208859, 0.029760877});
    614 
    615   lstm.SetRecurrentToInputWeights(
    616       {-0.001374326,   -0.078856036,   0.10672688,    0.029162422,
    617        -0.11585556,    0.02557986,     -0.13446963,   -0.035785314,
    618        -0.01244275,    0.025961924,    -0.02337298,   -0.044228926,
    619        -0.055839065,   -0.046598054,   -0.010546039,  -0.06900766,
    620        0.027239809,    0.022582639,    -0.013296484,  -0.05459212,
    621        0.08981,        -0.045407712,   0.08682226,    -0.06867011,
    622        -0.14390695,    -0.02916037,    0.000996957,   0.091420636,
    623        0.14283475,     -0.07390571,    -0.06402044,   0.062524505,
    624        -0.093129106,   0.04860203,     -0.08364217,   -0.08119002,
    625        0.009352075,    0.22920375,     0.0016303885,  0.11583097,
    626        -0.13732095,    0.012405723,    -0.07551853,   0.06343048,
    627        0.12162708,     -0.031923793,   -0.014335606,  0.01790974,
    628        -0.10650317,    -0.0724401,     0.08554849,    -0.05727212,
    629        0.06556731,     -0.042729504,   -0.043227166,  0.011683251,
    630        -0.013082158,   -0.029302018,   -0.010899579,  -0.062036745,
    631        -0.022509435,   -0.00964907,    -0.01567329,   0.04260106,
    632        -0.07787477,    -0.11576462,    0.017356863,   0.048673786,
    633        -0.017577527,   -0.05527947,    -0.082487635,  -0.040137455,
    634        -0.10820036,    -0.04666372,    0.022746278,   -0.07851417,
    635        0.01068115,     0.032956902,    0.022433773,   0.0026891115,
    636        0.08944216,     -0.0685835,     0.010513544,   0.07228705,
    637        0.02032331,     -0.059686817,   -0.0005566496, -0.086984694,
    638        0.040414046,    -0.1380399,     0.094208956,   -0.05722982,
    639        0.012092817,    -0.04989123,    -0.086576,     -0.003399834,
    640        -0.04696032,    -0.045747425,   0.10091314,    0.048676282,
    641        -0.029037097,   0.031399418,    -0.0040285117, 0.047237843,
    642        0.09504992,     0.041799378,    -0.049185462,  -0.031518843,
    643        -0.10516937,    0.026374253,    0.10058866,    -0.0033195973,
    644        -0.041975245,   0.0073591834,   0.0033782164,  -0.004325073,
    645        -0.10167381,    0.042500053,    -0.01447153,   0.06464186,
    646        -0.017142897,   0.03312627,     0.009205989,   0.024138335,
    647        -0.011337001,   0.035530265,    -0.010912711,  0.0706555,
    648        -0.005894094,   0.051841937,    -0.1401738,    -0.02351249,
    649        0.0365468,      0.07590991,     0.08838724,    0.021681072,
    650        -0.10086113,    0.019608743,    -0.06195883,   0.077335775,
    651        0.023646897,    -0.095322326,   0.02233014,    0.09756986,
    652        -0.048691444,   -0.009579111,   0.07595467,    0.11480546,
    653        -0.09801813,    0.019894179,    0.08502348,    0.004032281,
    654        0.037211012,    0.068537936,    -0.048005626,  -0.091520436,
    655        -0.028379958,   -0.01556313,    0.06554592,    -0.045599163,
    656        -0.01672207,    -0.020169014,   -0.011877351,  -0.20212261,
    657        0.010889619,    0.0047078193,   0.038385306,   0.08540671,
    658        -0.017140968,   -0.0035865551,  0.016678626,   0.005633034,
    659        0.015963363,    0.00871737,     0.060130805,   0.028611384,
    660        0.10109069,     -0.015060172,   -0.07894427,   0.06401885,
    661        0.011584063,    -0.024466386,   0.0047652307,  -0.09041358,
    662        0.030737216,    -0.0046374933,  0.14215417,    -0.11823516,
    663        0.019899689,    0.006106124,    -0.027092824,  0.0786356,
    664        0.05052217,     -0.058925,      -0.011402121,  -0.024987547,
    665        -0.0013661642,  -0.06832946,    -0.015667673,  -0.1083353,
    666        -0.00096863037, -0.06988685,    -0.053350925,  -0.027275559,
    667        -0.033664223,   -0.07978348,    -0.025200296,  -0.017207067,
    668        -0.058403496,   -0.055697463,   0.005798788,   0.12965427,
    669        -0.062582195,   0.0013350133,   -0.10482091,   0.0379771,
    670        0.072521195,    -0.0029455067,  -0.13797039,   -0.03628521,
    671        0.013806405,    -0.017858358,   -0.01008298,   -0.07700066,
    672        -0.017081132,   0.019358726,    0.0027079724,  0.004635139,
    673        0.062634714,    -0.02338735,    -0.039547626,  -0.02050681,
    674        0.03385117,     -0.083611414,   0.002862572,   -0.09421313,
    675        0.058618143,    -0.08598433,    0.00972939,    0.023867095,
    676        -0.053934585,   -0.023203006,   0.07452513,    -0.048767887,
    677        -0.07314807,    -0.056307215,   -0.10433547,   -0.06440842,
    678        0.04328182,     0.04389765,     -0.020006588,  -0.09076438,
    679        -0.11652589,    -0.021705797,   0.03345259,    -0.010329105,
    680        -0.025767034,   0.013057034,    -0.07316461,   -0.10145612,
    681        0.06358255,     0.18531723,     0.07759293,    0.12006465,
    682        0.1305557,      0.058638252,    -0.03393652,   0.09622831,
    683        -0.16253184,    -2.4580743e-06, 0.079869635,   -0.070196845,
    684        -0.005644518,   0.06857898,     -0.12598175,   -0.035084512,
    685        0.03156317,     -0.12794146,    -0.031963028,  0.04692781,
    686        0.030070418,    0.0071660685,   -0.095516115,  -0.004643372,
    687        0.040170413,    -0.062104587,   -0.0037324072, 0.0554317,
    688        0.08184801,     -0.019164372,   0.06791302,    0.034257166,
    689        -0.10307039,    0.021943003,    0.046745934,   0.0790918,
    690        -0.0265588,     -0.007824208,   0.042546265,   -0.00977924,
    691        -0.0002440307,  -0.017384544,   -0.017990116,  0.12252321,
    692        -0.014512694,   -0.08251313,    0.08861942,    0.13589665,
    693        0.026351685,    0.012641483,    0.07466548,    0.044301085,
    694        -0.045414884,   -0.051112458,   0.03444247,    -0.08502782,
    695        -0.04106223,    -0.028126027,   0.028473156,   0.10467447});
    696 
    697   lstm.SetRecurrentToForgetWeights(
    698       {-0.057784554,  -0.026057621,  -0.068447545,   -0.022581743,
    699        0.14811787,    0.10826372,    0.09471067,     0.03987225,
    700        -0.0039523416, 0.00030638507, 0.053185795,    0.10572994,
    701        0.08414449,    -0.022036452,  -0.00066928595, -0.09203576,
    702        0.032950465,   -0.10985798,   -0.023809856,   0.0021431844,
    703        -0.02196096,   -0.00326074,   0.00058621005,  -0.074678116,
    704        -0.06193199,   0.055729095,   0.03736828,     0.020123724,
    705        0.061878487,   -0.04729229,   0.034919553,    -0.07585433,
    706        -0.04421272,   -0.044019096,  0.085488975,    0.04058006,
    707        -0.06890133,   -0.030951202,  -0.024628663,   -0.07672815,
    708        0.034293607,   0.08556707,    -0.05293577,    -0.033561368,
    709        -0.04899627,   0.0241671,     0.015736353,    -0.095442444,
    710        -0.029564252,  0.016493602,   -0.035026584,   0.022337519,
    711        -0.026871363,  0.004780428,   0.0077918363,   -0.03601621,
    712        0.016435321,   -0.03263031,   -0.09543275,    -0.047392778,
    713        0.013454138,   0.028934088,   0.01685226,     -0.086110644,
    714        -0.046250615,  -0.01847454,   0.047608484,    0.07339695,
    715        0.034546845,   -0.04881143,   0.009128804,    -0.08802852,
    716        0.03761666,    0.008096139,   -0.014454086,   0.014361001,
    717        -0.023502491,  -0.0011840804, -0.07607001,    0.001856849,
    718        -0.06509276,   -0.006021153,  -0.08570962,    -0.1451793,
    719        0.060212336,   0.055259194,   0.06974018,     0.049454916,
    720        -0.027794661,  -0.08077226,   -0.016179763,   0.1169753,
    721        0.17213494,    -0.0056326236, -0.053934924,   -0.0124349,
    722        -0.11520337,   0.05409887,    0.088759385,    0.0019655675,
    723        0.0042065294,  0.03881498,    0.019844765,    0.041858196,
    724        -0.05695512,   0.047233116,   0.038937137,    -0.06542224,
    725        0.014429736,   -0.09719407,   0.13908425,     -0.05379757,
    726        0.012321099,   0.082840554,   -0.029899208,   0.044217527,
    727        0.059855383,   0.07711018,    -0.045319796,   0.0948846,
    728        -0.011724666,  -0.0033288454, -0.033542685,   -0.04764985,
    729        -0.13873616,   0.040668588,   0.034832682,    -0.015319203,
    730        -0.018715994,  0.046002675,   0.0599172,      -0.043107376,
    731        0.0294216,     -0.002314414,  -0.022424703,   0.0030315618,
    732        0.0014641669,  0.0029166266,  -0.11878115,    0.013738511,
    733        0.12375372,    -0.0006038222, 0.029104086,    0.087442465,
    734        0.052958444,   0.07558703,    0.04817258,     0.044462286,
    735        -0.015213451,  -0.08783778,   -0.0561384,     -0.003008196,
    736        0.047060397,   -0.002058388,  0.03429439,     -0.018839769,
    737        0.024734668,   0.024614193,   -0.042046934,   0.09597743,
    738        -0.0043254104, 0.04320769,    0.0064070094,   -0.0019131786,
    739        -0.02558259,   -0.022822596,  -0.023273505,   -0.02464396,
    740        -0.10991725,   -0.006240552,  0.0074488563,   0.024044557,
    741        0.04383914,    -0.046476185,  0.028658995,    0.060410924,
    742        0.050786525,   0.009452605,   -0.0073054377,  -0.024810238,
    743        0.0052906186,  0.0066939713,  -0.0020913032,  0.014515517,
    744        0.015898481,   0.021362653,   -0.030262267,   0.016587038,
    745        -0.011442813,  0.041154444,   -0.007631438,   -0.03423484,
    746        -0.010977775,  0.036152758,   0.0066366293,   0.11915515,
    747        0.02318443,    -0.041350313,  0.021485701,    -0.10906167,
    748        -0.028218046,  -0.00954771,   0.020531068,    -0.11995105,
    749        -0.03672871,   0.024019798,   0.014255957,    -0.05221243,
    750        -0.00661567,   -0.04630967,   0.033188973,    0.10107534,
    751        -0.014027541,  0.030796422,   -0.10270911,    -0.035999842,
    752        0.15443139,    0.07684145,    0.036571592,    -0.035900835,
    753        -0.0034699554, 0.06209149,    0.015920248,    -0.031122351,
    754        -0.03858649,   0.01849943,    0.13872518,     0.01503974,
    755        0.069941424,   -0.06948533,   -0.0088794185,  0.061282158,
    756        -0.047401894,  0.03100163,    -0.041533746,   -0.10430945,
    757        0.044574402,   -0.01425562,   -0.024290353,   0.034563623,
    758        0.05866852,    0.023947537,   -0.09445152,    0.035450947,
    759        0.02247216,    -0.0042998926, 0.061146557,    -0.10250651,
    760        0.020881841,   -0.06747029,   0.10062043,     -0.0023941975,
    761        0.03532124,    -0.016341697,  0.09685456,     -0.016764693,
    762        0.051808182,   0.05875331,    -0.04536488,    0.001626336,
    763        -0.028892258,  -0.01048663,   -0.009793449,   -0.017093895,
    764        0.010987891,   0.02357273,    -0.00010856845, 0.0099760275,
    765        -0.001845119,  -0.03551521,   0.0018358806,   0.05763657,
    766        -0.01769146,   0.040995963,   0.02235177,     -0.060430344,
    767        0.11475477,    -0.023854522,  0.10071741,     0.0686208,
    768        -0.014250481,  0.034261297,   0.047418304,    0.08562733,
    769        -0.030519066,  0.0060542435,  0.014653856,    -0.038836084,
    770        0.04096551,    0.032249358,   -0.08355519,    -0.026823482,
    771        0.056386515,   -0.010401743,  -0.028396193,   0.08507674,
    772        0.014410365,   0.020995233,   0.17040324,     0.11511526,
    773        0.02459721,    0.0066619175,  0.025853224,    -0.023133837,
    774        -0.081302024,  0.017264642,   -0.009585969,   0.09491168,
    775        -0.051313367,  0.054532815,   -0.014298593,   0.10657464,
    776        0.007076659,   0.10964551,    0.0409152,      0.008275321,
    777        -0.07283536,   0.07937492,    0.04192024,     -0.1075027});
    778 
    779   lstm.SetRecurrentToCellWeights(
    780       {-0.037322544,   0.018592842,   0.0056175636,  -0.06253426,
    781        0.055647098,    -0.05713207,   -0.05626563,   0.005559383,
    782        0.03375411,     -0.025757805,  -0.088049285,  0.06017052,
    783        -0.06570978,    0.007384076,   0.035123326,   -0.07920549,
    784        0.053676967,    0.044480428,   -0.07663568,   0.0071805613,
    785        0.08089997,     0.05143358,    0.038261272,   0.03339287,
    786        -0.027673481,   0.044746667,   0.028349208,   0.020090483,
    787        -0.019443132,   -0.030755889,  -0.0040000007, 0.04465846,
    788        -0.021585021,   0.0031670958,  0.0053199246,  -0.056117613,
    789        -0.10893326,    0.076739706,   -0.08509834,   -0.027997585,
    790        0.037871376,    0.01449768,    -0.09002357,   -0.06111149,
    791        -0.046195522,   0.0422062,     -0.005683705,  -0.1253618,
    792        -0.012925729,   -0.04890792,   0.06985068,    0.037654128,
    793        0.03398274,     -0.004781977,  0.007032333,   -0.031787455,
    794        0.010868644,    -0.031489216,  0.09525667,    0.013939797,
    795        0.0058680447,   0.0167067,     0.02668468,    -0.04797466,
    796        -0.048885044,   -0.12722108,   0.035304096,   0.06554885,
    797        0.00972396,     -0.039238118,  -0.05159735,   -0.11329045,
    798        0.1613692,      -0.03750952,   0.06529313,    -0.071974665,
    799        -0.11769596,    0.015524369,   -0.0013754242, -0.12446318,
    800        0.02786344,     -0.014179351,  0.005264273,   0.14376344,
    801        0.015983658,    0.03406988,    -0.06939408,   0.040699873,
    802        0.02111075,     0.09669095,    0.041345075,   -0.08316494,
    803        -0.07684199,    -0.045768797,  0.032298047,   -0.041805092,
    804        0.0119405,      0.0061010392,  0.12652606,    0.0064572375,
    805        -0.024950314,   0.11574242,    0.04508852,    -0.04335324,
    806        0.06760663,     -0.027437469,  0.07216407,    0.06977076,
    807        -0.05438599,    0.034033038,   -0.028602652,  0.05346137,
    808        0.043184172,    -0.037189785,  0.10420091,    0.00882477,
    809        -0.054019816,   -0.074273005,  -0.030617684,  -0.0028467078,
    810        0.024302477,    -0.0038869337, 0.005332455,   0.0013399826,
    811        0.04361412,     -0.007001822,  0.09631092,    -0.06702025,
    812        -0.042049985,   -0.035070654,  -0.04103342,   -0.10273396,
    813        0.0544271,      0.037184782,   -0.13150354,   -0.0058036847,
    814        -0.008264958,   0.042035464,   0.05891794,    0.029673764,
    815        0.0063542654,   0.044788733,   0.054816857,   0.062257513,
    816        -0.00093483756, 0.048938446,   -0.004952862,  -0.007730018,
    817        -0.04043371,    -0.017094059,  0.07229206,    -0.023670016,
    818        -0.052195564,   -0.025616996,  -0.01520939,   0.045104615,
    819        -0.007376126,   0.003533447,   0.006570588,   0.056037236,
    820        0.12436656,     0.051817212,   0.028532185,   -0.08686856,
    821        0.11868599,     0.07663395,    -0.07323171,   0.03463402,
    822        -0.050708205,   -0.04458982,   -0.11590894,   0.021273347,
    823        0.1251325,      -0.15313013,   -0.12224372,   0.17228661,
    824        0.023029093,    0.086124025,   0.006445803,   -0.03496501,
    825        0.028332196,    0.04449512,    -0.042436164,  -0.026587414,
    826        -0.006041347,   -0.09292539,   -0.05678812,   0.03897832,
    827        0.09465633,     0.008115513,   -0.02171956,   0.08304309,
    828        0.071401566,    0.019622514,   0.032163795,   -0.004167056,
    829        0.02295182,     0.030739572,   0.056506045,   0.004612461,
    830        0.06524936,     0.059999723,   0.046395954,   -0.0045512207,
    831        -0.1335546,     -0.030136576,  0.11584653,    -0.014678886,
    832        0.0020118146,   -0.09688814,   -0.0790206,    0.039770417,
    833        -0.0329582,     0.07922767,    0.029322514,   0.026405897,
    834        0.04207835,     -0.07073373,   0.063781224,   0.0859677,
    835        -0.10925287,    -0.07011058,   0.048005477,   0.03438226,
    836        -0.09606514,    -0.006669445,  -0.043381985,  0.04240257,
    837        -0.06955775,    -0.06769346,   0.043903265,   -0.026784198,
    838        -0.017840602,   0.024307009,   -0.040079936,  -0.019946516,
    839        0.045318738,    -0.12233574,   0.026170589,   0.0074471775,
    840        0.15978073,     0.10185836,    0.10298046,    -0.015476589,
    841        -0.039390966,   -0.072174534,  0.0739445,     -0.1211869,
    842        -0.0347889,     -0.07943156,   0.014809798,   -0.12412325,
    843        -0.0030663363,  0.039695457,   0.0647603,     -0.08291318,
    844        -0.018529687,   -0.004423833,  0.0037507233,  0.084633216,
    845        -0.01514876,    -0.056505352,  -0.012800942,  -0.06994386,
    846        0.012962922,    -0.031234352,  0.07029052,    0.016418684,
    847        0.03618972,     0.055686004,   -0.08663945,   -0.017404709,
    848        -0.054761406,   0.029065743,   0.052404847,   0.020238016,
    849        0.0048197987,   -0.0214882,    0.07078733,    0.013016777,
    850        0.06262858,     0.009184685,   0.020785125,   -0.043904778,
    851        -0.0270329,     -0.03299152,   -0.060088247,  -0.015162964,
    852        -0.001828936,   0.12642565,    -0.056757294,  0.013586685,
    853        0.09232601,     -0.035886683,  0.06000002,    0.05229691,
    854        -0.052580316,   -0.082029596,  -0.010794592,  0.012947712,
    855        -0.036429964,   -0.085508935,  -0.13127148,   -0.017744139,
    856        0.031502828,    0.036232427,   -0.031581745,  0.023051167,
    857        -0.05325106,    -0.03421577,   0.028793324,   -0.034633752,
    858        -0.009881397,   -0.043551125,  -0.018609839,  0.0019097115,
    859        -0.008799762,   0.056595087,   0.0022273948,  0.055752404});
    860 
    861   lstm.SetRecurrentToOutputWeights({
    862       0.025825322,   -0.05813119,  0.09495884,   -0.045984812,   -0.01255415,
    863       -0.0026479573, -0.08196161,  -0.054914974, -0.0046604523,  -0.029587349,
    864       -0.044576716,  -0.07480124,  -0.082868785, 0.023254942,    0.027502948,
    865       -0.0039728214, -0.08683098,  -0.08116779,  -0.014675607,   -0.037924774,
    866       -0.023314456,  -0.007401714, -0.09255757,  0.029460307,    -0.08829125,
    867       -0.005139627,  -0.08989442,  -0.0555066,   0.13596267,     -0.025062224,
    868       -0.048351806,  -0.03850004,  0.07266485,   -0.022414139,   0.05940088,
    869       0.075114764,   0.09597592,   -0.010211725, -0.0049794707,  -0.011523867,
    870       -0.025980417,  0.072999895,  0.11091378,   -0.081685916,   0.014416728,
    871       0.043229222,   0.034178585,  -0.07530371,  0.035837382,    -0.085607,
    872       -0.007721233,  -0.03287832,  -0.043848954, -0.06404588,    -0.06632928,
    873       -0.073643476,  0.008214239,  -0.045984086, 0.039764922,    0.03474462,
    874       0.060612556,   -0.080590084, 0.049127717,  0.04151091,     -0.030063879,
    875       0.008801774,   -0.023021035, -0.019558564, 0.05158114,     -0.010947698,
    876       -0.011825728,  0.0075720972, 0.0699727,    -0.0039981045,  0.069350146,
    877       0.08799282,    0.016156472,  0.035502106,  0.11695009,     0.006217345,
    878       0.13392477,    -0.037875112, 0.025745004,  0.08940699,     -0.00924166,
    879       0.0046702605,  -0.036598757, -0.08811812,  0.10522024,     -0.032441203,
    880       0.008176899,   -0.04454919,  0.07058152,   0.0067963637,   0.039206743,
    881       0.03259838,    0.03725492,   -0.09515802,  0.013326398,    -0.052055415,
    882       -0.025676316,  0.03198509,   -0.015951829, -0.058556724,   0.036879618,
    883       0.043357447,   0.028362012,  -0.05908629,  0.0059240665,   -0.04995891,
    884       -0.019187413,  0.0276265,    -0.01628143,  0.0025863599,   0.08800015,
    885       0.035250366,   -0.022165963, -0.07328642,  -0.009415526,   -0.07455109,
    886       0.11690406,    0.0363299,    0.07411125,   0.042103454,    -0.009660886,
    887       0.019076364,   0.018299393,  -0.046004917, 0.08891175,     0.0431396,
    888       -0.026327137,  -0.051502608, 0.08979574,   -0.051670972,   0.04940282,
    889       -0.07491107,   -0.021240504, 0.022596184,  -0.034280192,   0.060163025,
    890       -0.058211457,  -0.051837247, -0.01349775,  -0.04639988,    -0.035936575,
    891       -0.011681591,  0.064818054,  0.0073146066, -0.021745546,   -0.043124277,
    892       -0.06471268,   -0.07053354,  -0.029321948, -0.05330136,    0.016933719,
    893       -0.053782392,  0.13747959,   -0.1361751,   -0.11569455,    0.0033329215,
    894       0.05693899,    -0.053219706, 0.063698,     0.07977434,     -0.07924483,
    895       0.06936997,    0.0034815092, -0.007305279, -0.037325785,   -0.07251102,
    896       -0.033633437,  -0.08677009,  0.091591336,  -0.14165086,    0.021752775,
    897       0.019683983,   0.0011612234, -0.058154266, 0.049996935,    0.0288841,
    898       -0.0024567875, -0.14345716,  0.010955264,  -0.10234828,    0.1183656,
    899       -0.0010731248, -0.023590032, -0.072285876, -0.0724771,     -0.026382286,
    900       -0.0014920527, 0.042667855,  0.0018776858, 0.02986552,     0.009814309,
    901       0.0733756,     0.12289186,   0.018043943,  -0.0458958,     0.049412545,
    902       0.033632483,   0.05495232,   0.036686596,  -0.013781798,   -0.010036754,
    903       0.02576849,    -0.08307328,  0.010112348,  0.042521734,    -0.05869831,
    904       -0.071689695,  0.03876447,   -0.13275425,  -0.0352966,     -0.023077697,
    905       0.10285965,    0.084736146,  0.15568255,   -0.00040734606, 0.027835453,
    906       -0.10292561,   -0.032401145, 0.10053256,   -0.026142767,   -0.08271222,
    907       -0.0030240538, -0.016368777, 0.1070414,    0.042672627,    0.013456989,
    908       -0.0437609,    -0.022309763, 0.11576483,   0.04108048,     0.061026827,
    909       -0.0190714,    -0.0869359,   0.037901703,  0.0610107,      0.07202949,
    910       0.01675338,    0.086139716,  -0.08795751,  -0.014898893,   -0.023771819,
    911       -0.01965048,   0.007955471,  -0.043740474, 0.03346837,     -0.10549954,
    912       0.090567775,   0.042013682,  -0.03176985,  0.12569028,     -0.02421228,
    913       -0.029526481,  0.023851605,  0.031539805,  0.05292009,     -0.02344001,
    914       -0.07811758,   -0.08834428,  0.10094801,   0.16594367,     -0.06861939,
    915       -0.021256343,  -0.041093912, -0.06669611,  0.035498552,    0.021757556,
    916       -0.09302526,   -0.015403468, -0.06614931,  -0.051798206,   -0.013874718,
    917       0.03630673,    0.010412845,  -0.08077351,  0.046185967,    0.0035662893,
    918       0.03541868,    -0.094149634, -0.034814864, 0.003128424,    -0.020674974,
    919       -0.03944324,   -0.008110165, -0.11113267,  0.08484226,     0.043586485,
    920       0.040582247,   0.0968012,    -0.065249965, -0.028036479,   0.0050708856,
    921       0.0017462453,  0.0326779,    0.041296225,  0.09164146,     -0.047743853,
    922       -0.015952192,  -0.034451712, 0.084197424,  -0.05347844,    -0.11768019,
    923       0.085926116,   -0.08251791,  -0.045081906, 0.0948852,      0.068401024,
    924       0.024856757,   0.06978981,   -0.057309967, -0.012775832,   -0.0032452994,
    925       0.01977615,    -0.041040014, -0.024264973, 0.063464895,    0.05431621,
    926   });
    927 
    928   lstm.SetCellToInputWeights(
    929       {0.040369894, 0.030746894,  0.24704495,  0.018586371,  -0.037586458,
    930        -0.15312155, -0.11812848,  -0.11465643, 0.20259799,   0.11418174,
    931        -0.10116027, -0.011334949, 0.12411352,  -0.076769054, -0.052169047,
    932        0.21198851,  -0.38871562,  -0.09061183, -0.09683246,  -0.21929175});
    933 
    934   lstm.SetCellToForgetWeights(
    935       {-0.01998659,  -0.15568835,  -0.24248174,   -0.012770197, 0.041331276,
    936        -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766,
    937        -0.047248036, 0.021479502,  0.033189066,   0.11952997,   -0.020432774,
    938        0.64658105,   -0.06650122,  -0.03467612,   0.095340036,  0.23647355});
    939 
    940   lstm.SetCellToOutputWeights(
    941       {0.08286371,  -0.08261836, -0.51210177, 0.002913762, 0.17764764,
    942        -0.5495371,  -0.08460716, -0.24552552, 0.030037103, 0.04123544,
    943        -0.11940523, 0.007358328, 0.1890978,   0.4833202,   -0.34441817,
    944        0.36312827,  -0.26375428, 0.1457655,   -0.19724406, 0.15548733});
    945 
    946   lstm.SetProjectionWeights(
    947       {-0.009802181,  0.09401916,    0.0717386,     -0.13895074,  0.09641832,
    948        0.060420845,   0.08539281,    0.054285463,   0.061395317,  0.034448683,
    949        -0.042991187,  0.019801661,   -0.16840284,   -0.015726732, -0.23041931,
    950        -0.024478018,  -0.10959692,   -0.013875541,  0.18600968,   -0.061274476,
    951        0.0138165,     -0.08160894,   -0.07661644,   0.032372914,  0.16169067,
    952        0.22465782,    -0.03993472,   -0.004017731,  0.08633481,   -0.28869787,
    953        0.08682067,    0.17240396,    0.014975425,   0.056431185,  0.031037588,
    954        0.16702051,    0.0077946745,  0.15140012,    0.29405436,   0.120285,
    955        -0.188994,     -0.027265169,  0.043389652,   -0.022061434, 0.014777949,
    956        -0.20203483,   0.094781205,   0.19100232,    0.13987629,   -0.036132768,
    957        -0.06426278,   -0.05108664,   0.13221376,    0.009441198,  -0.16715929,
    958        0.15859416,    -0.040437475,  0.050779544,   -0.022187516, 0.012166504,
    959        0.027685808,   -0.07675938,   -0.0055694645, -0.09444123,  0.0046453946,
    960        0.050794356,   0.10770313,    -0.20790008,   -0.07149004,  -0.11425117,
    961        0.008225835,   -0.035802525,  0.14374903,    0.15262283,   0.048710253,
    962        0.1847461,     -0.007487823,  0.11000021,    -0.09542012,  0.22619456,
    963        -0.029149994,  0.08527916,    0.009043713,   0.0042746216, 0.016261552,
    964        0.022461696,   0.12689082,    -0.043589946,  -0.12035478,  -0.08361797,
    965        -0.050666027,  -0.1248618,    -0.1275799,    -0.071875185, 0.07377272,
    966        0.09944291,    -0.18897448,   -0.1593054,    -0.06526116,  -0.040107165,
    967        -0.004618631,  -0.067624845,  -0.007576253,  0.10727444,   0.041546922,
    968        -0.20424393,   0.06907816,    0.050412357,   0.00724631,   0.039827548,
    969        0.12449835,    0.10747581,    0.13708383,    0.09134148,   -0.12617786,
    970        -0.06428341,   0.09956831,    0.1208086,     -0.14676677,  -0.0727722,
    971        0.1126304,     0.010139365,   0.015571211,   -0.038128063, 0.022913318,
    972        -0.042050496,  0.16842307,    -0.060597885,  0.10531834,   -0.06411776,
    973        -0.07451711,   -0.03410368,   -0.13393489,   0.06534304,   0.003620307,
    974        0.04490757,    0.05970546,    0.05197996,    0.02839995,   0.10434969,
    975        -0.013699693,  -0.028353551,  -0.07260381,   0.047201227,  -0.024575593,
    976        -0.036445823,  0.07155557,    0.009672501,   -0.02328883,  0.009533515,
    977        -0.03606021,   -0.07421458,   -0.028082801,  -0.2678904,   -0.13221288,
    978        0.18419984,    -0.13012612,   -0.014588381,  -0.035059117, -0.04824723,
    979        0.07830115,    -0.056184657,  0.03277091,    0.025466874,  0.14494097,
    980        -0.12522776,   -0.098633975,  -0.10766018,   -0.08317623,  0.08594209,
    981        0.07749552,    0.039474737,   0.1776665,     -0.07409566,  -0.0477268,
    982        0.29323658,    0.10801441,    0.1154011,     0.013952499,  0.10739139,
    983        0.10708251,    -0.051456142,  0.0074137426,  -0.10430189,  0.10034707,
    984        0.045594677,   0.0635285,     -0.0715442,    -0.089667566, -0.10811871,
    985        0.00026344223, 0.08298446,    -0.009525053,  0.006585689,  -0.24567553,
    986        -0.09450807,   0.09648481,    0.026996298,   -0.06419476,  -0.04752702,
    987        -0.11063944,   -0.23441927,   -0.17608605,   -0.052156363, 0.067035615,
    988        0.19271925,    -0.0032889997, -0.043264326,  0.09663576,   -0.057112187,
    989        -0.10100678,   0.0628376,     0.04447668,    0.017961001,  -0.10094388,
    990        -0.10190601,   0.18335468,    0.10494553,    -0.052095775, -0.0026118709,
    991        0.10539724,    -0.04383912,   -0.042349473,  0.08438151,   -0.1947263,
    992        0.02251204,    0.11216432,    -0.10307853,   0.17351969,   -0.039091777,
    993        0.08066188,    -0.00561982,   0.12633002,    0.11335965,   -0.0088127935,
    994        -0.019777594,  0.06864014,    -0.059751723,  0.016233567,  -0.06894641,
    995        -0.28651384,   -0.004228674,  0.019708522,   -0.16305895,  -0.07468996,
    996        -0.0855457,    0.099339016,   -0.07580735,   -0.13775392,  0.08434318,
    997        0.08330512,    -0.12131499,   0.031935584,   0.09180414,   -0.08876437,
    998        -0.08049874,   0.008753825,   0.03498998,    0.030215185,  0.03907079,
    999        0.089751154,   0.029194152,   -0.03337423,   -0.019092513, 0.04331237,
   1000        0.04299654,    -0.036394123,  -0.12915532,   0.09793732,   0.07512415,
   1001        -0.11319543,   -0.032502122,  0.15661901,    0.07671967,   -0.005491124,
   1002        -0.19379048,   -0.218606,     0.21448623,    0.017840758,  0.1416943,
   1003        -0.07051762,   0.19488361,    0.02664691,    -0.18104725,  -0.09334311,
   1004        0.15026465,    -0.15493552,   -0.057762887,  -0.11604192,  -0.262013,
   1005        -0.01391798,   0.012185008,   0.11156489,    -0.07483202,  0.06693364,
   1006        -0.26151478,   0.046425626,   0.036540434,   -0.16435726,  0.17338543,
   1007        -0.21401681,   -0.11385144,   -0.08283257,   -0.069031075, 0.030635102,
   1008        0.010969227,   0.11109743,    0.010919218,   0.027526086,  0.13519906,
   1009        0.01891392,    -0.046839405,  -0.040167913,  0.017953383,  -0.09700955,
   1010        0.0061885654,  -0.07000971,   0.026893595,   -0.038844477, 0.14543656});
   1011 
   1012   static float lstm_input[][20] = {
   1013       {// Batch0: 4 (input_sequence_size) * 5 (n_input)
   1014        0.787926, 0.151646, 0.071352, 0.118426, 0.458058, 0.596268, 0.998386,
   1015        0.568695, 0.864524, 0.571277, 0.073204, 0.296072, 0.743333, 0.069199,
   1016        0.045348, 0.867394, 0.291279, 0.013714, 0.482521, 0.626339},
   1017 
   1018       {// Batch1: 4 (input_sequence_size) * 5 (n_input)
   1019        0.295743, 0.544053, 0.690064, 0.858138, 0.497181, 0.642421, 0.524260,
   1020        0.134799, 0.003639, 0.162482, 0.640394, 0.930399, 0.050782, 0.432485,
   1021        0.988078, 0.082922, 0.563329, 0.865614, 0.333232, 0.259916}};
   1022 
   1023   static float lstm_golden_output[][64] = {
   1024       {// Batch0: 4 (input_sequence_size) * 16 (n_output)
   1025        -0.00396806, 0.029352,     -0.00279226, 0.0159977,   -0.00835576,
   1026        -0.0211779,  0.0283512,    -0.0114597,  0.00907307,  -0.0244004,
   1027        -0.0152191,  -0.0259063,   0.00914318,  0.00415118,  0.017147,
   1028        0.0134203,   -0.0166936,   0.0381209,   0.000889694, 0.0143363,
   1029        -0.0328911,  -0.0234288,   0.0333051,   -0.012229,   0.0110322,
   1030        -0.0457725,  -0.000832209, -0.0202817,  0.0327257,   0.0121308,
   1031        0.0155969,   0.0312091,    -0.0213783,  0.0350169,   0.000324794,
   1032        0.0276012,   -0.0263374,   -0.0371449,  0.0446149,   -0.0205474,
   1033        0.0103729,   -0.0576349,   -0.0150052,  -0.0292043,  0.0376827,
   1034        0.0136115,   0.0243435,    0.0354492,   -0.0189322,  0.0464512,
   1035        -0.00251373, 0.0225745,    -0.0308346,  -0.0317124,  0.0460407,
   1036        -0.0189395,  0.0149363,    -0.0530162,  -0.0150767,  -0.0340193,
   1037        0.0286833,   0.00824207,   0.0264887,   0.0305169},
   1038       {// Batch1: 4 (input_sequence_size) * 16 (n_output)
   1039        -0.013869,    0.0287268,   -0.00334693, 0.00733398,  -0.0287926,
   1040        -0.0186926,   0.0193662,   -0.0115437,  0.00422612,  -0.0345232,
   1041        0.00223253,   -0.00957321, 0.0210624,   0.013331,    0.0150954,
   1042        0.02168,      -0.0141913,  0.0322082,   0.00227024,  0.0260507,
   1043        -0.0188721,   -0.0296489,  0.0399134,   -0.0160509,  0.0116039,
   1044        -0.0447318,   -0.0150515,  -0.0277406,  0.0316596,   0.0118233,
   1045        0.0214762,    0.0293641,   -0.0204549,  0.0450315,   -0.00117378,
   1046        0.0167673,    -0.0375007,  -0.0238314,  0.038784,    -0.0174034,
   1047        0.0131743,    -0.0506589,  -0.0048447,  -0.0240239,  0.0325789,
   1048        0.00790065,   0.0220157,   0.0333314,   -0.0264787,  0.0387855,
   1049        -0.000764675, 0.0217599,   -0.037537,   -0.0335206,  0.0431679,
   1050        -0.0211424,   0.010203,    -0.062785,   -0.00832363, -0.025181,
   1051        0.0412031,    0.0118723,   0.0239643,   0.0394009}};
   1052 
   1053   // Resetting cell_state and output_state
   1054   lstm.ResetCellState();
   1055   lstm.ResetOutputState();
   1056 
   1057   for (int i = 0; i < lstm.sequence_length(); i++) {
   1058     float* batch0_start = lstm_input[0] + i * lstm.num_inputs();
   1059     float* batch0_end = batch0_start + lstm.num_inputs();
   1060 
   1061     lstm.SetInput(2 * i * lstm.num_inputs(), batch0_start, batch0_end);
   1062 
   1063     float* batch1_start = lstm_input[1] + i * lstm.num_inputs();
   1064     float* batch1_end = batch1_start + lstm.num_inputs();
   1065     lstm.SetInput((2 * i + 1) * lstm.num_inputs(), batch1_start, batch1_end);
   1066   }
   1067 
   1068   lstm.Invoke();
   1069 
   1070   std::vector<float> expected;
   1071   for (int i = 0; i < lstm.sequence_length(); i++) {
   1072     float* golden_start_batch0 = lstm_golden_output[0] + i * lstm.num_outputs();
   1073     float* golden_end_batch0 = golden_start_batch0 + lstm.num_outputs();
   1074     float* golden_start_batch1 = lstm_golden_output[1] + i * lstm.num_outputs();
   1075     float* golden_end_batch1 = golden_start_batch1 + lstm.num_outputs();
   1076     expected.insert(expected.end(), golden_start_batch0, golden_end_batch0);
   1077     expected.insert(expected.end(), golden_start_batch1, golden_end_batch1);
   1078   }
   1079   EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
   1080 }
   1081 
   1082 }  // namespace
   1083 }  // namespace tflite
   1084 
   1085 int main(int argc, char** argv) {
   1086   ::tflite::LogToStderr();
   1087   ::testing::InitGoogleTest(&argc, argv);
   1088   return RUN_ALL_TESTS();
   1089 }
   1090