Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 // Unit test for TFLite LSTM op.
     16 
     17 #include <iomanip>
     18 #include <memory>
     19 #include <vector>
     20 
     21 #include <gmock/gmock.h>
     22 #include <gtest/gtest.h>
     23 #include "tensorflow/lite/interpreter.h"
     24 #include "tensorflow/lite/kernels/register.h"
     25 #include "tensorflow/lite/kernels/test_util.h"
     26 #include "tensorflow/lite/model.h"
     27 
     28 namespace tflite {
     29 namespace {
     30 
     31 class LSTMOpModel : public SingleOpModel {
     32  public:
     33   LSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, bool use_cifg,
     34               bool use_peephole, bool use_projection_weights,
     35               bool use_projection_bias, float cell_clip, float proj_clip,
     36               const std::vector<std::vector<int>>& input_shapes)
     37       : n_batch_(n_batch),
     38         n_input_(n_input),
     39         n_cell_(n_cell),
     40         n_output_(n_output) {
     41     input_ = AddInput(TensorType_FLOAT32);
     42 
     43     if (use_cifg) {
     44       input_to_input_weights_ = AddNullInput();
     45     } else {
     46       input_to_input_weights_ = AddInput(TensorType_FLOAT32);
     47     }
     48 
     49     input_to_forget_weights_ = AddInput(TensorType_FLOAT32);
     50     input_to_cell_weights_ = AddInput(TensorType_FLOAT32);
     51     input_to_output_weights_ = AddInput(TensorType_FLOAT32);
     52 
     53     if (use_cifg) {
     54       recurrent_to_input_weights_ = AddNullInput();
     55     } else {
     56       recurrent_to_input_weights_ = AddInput(TensorType_FLOAT32);
     57     }
     58 
     59     recurrent_to_forget_weights_ = AddInput(TensorType_FLOAT32);
     60     recurrent_to_cell_weights_ = AddInput(TensorType_FLOAT32);
     61     recurrent_to_output_weights_ = AddInput(TensorType_FLOAT32);
     62 
     63     if (use_peephole) {
     64       if (use_cifg) {
     65         cell_to_input_weights_ = AddNullInput();
     66       } else {
     67         cell_to_input_weights_ = AddInput(TensorType_FLOAT32);
     68       }
     69       cell_to_forget_weights_ = AddInput(TensorType_FLOAT32);
     70       cell_to_output_weights_ = AddInput(TensorType_FLOAT32);
     71     } else {
     72       cell_to_input_weights_ = AddNullInput();
     73       cell_to_forget_weights_ = AddNullInput();
     74       cell_to_output_weights_ = AddNullInput();
     75     }
     76 
     77     if (use_cifg) {
     78       input_gate_bias_ = AddNullInput();
     79     } else {
     80       input_gate_bias_ = AddInput(TensorType_FLOAT32);
     81     }
     82     forget_gate_bias_ = AddInput(TensorType_FLOAT32);
     83     cell_bias_ = AddInput(TensorType_FLOAT32);
     84     output_gate_bias_ = AddInput(TensorType_FLOAT32);
     85 
     86     if (use_projection_weights) {
     87       projection_weights_ = AddInput(TensorType_FLOAT32);
     88       if (use_projection_bias) {
     89         projection_bias_ = AddInput(TensorType_FLOAT32);
     90       } else {
     91         projection_bias_ = AddNullInput();
     92       }
     93     } else {
     94       projection_weights_ = AddNullInput();
     95       projection_bias_ = AddNullInput();
     96     }
     97 
     98     // Adding the 2 input state tensors.
     99     input_activation_state_ =
    100         AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}}, true);
    101     input_cell_state_ =
    102         AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true);
    103 
    104     output_ = AddOutput(TensorType_FLOAT32);
    105 
    106     SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
    107                  CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
    108                                    cell_clip, proj_clip)
    109                      .Union());
    110     BuildInterpreter(input_shapes);
    111   }
    112 
    113   void SetInputToInputWeights(std::initializer_list<float> f) {
    114     PopulateTensor(input_to_input_weights_, f);
    115   }
    116 
    117   void SetInputToForgetWeights(std::initializer_list<float> f) {
    118     PopulateTensor(input_to_forget_weights_, f);
    119   }
    120 
    121   void SetInputToCellWeights(std::initializer_list<float> f) {
    122     PopulateTensor(input_to_cell_weights_, f);
    123   }
    124 
    125   void SetInputToOutputWeights(std::initializer_list<float> f) {
    126     PopulateTensor(input_to_output_weights_, f);
    127   }
    128 
    129   void SetRecurrentToInputWeights(std::initializer_list<float> f) {
    130     PopulateTensor(recurrent_to_input_weights_, f);
    131   }
    132 
    133   void SetRecurrentToForgetWeights(std::initializer_list<float> f) {
    134     PopulateTensor(recurrent_to_forget_weights_, f);
    135   }
    136 
    137   void SetRecurrentToCellWeights(std::initializer_list<float> f) {
    138     PopulateTensor(recurrent_to_cell_weights_, f);
    139   }
    140 
    141   void SetRecurrentToOutputWeights(std::initializer_list<float> f) {
    142     PopulateTensor(recurrent_to_output_weights_, f);
    143   }
    144 
    145   void SetCellToInputWeights(std::initializer_list<float> f) {
    146     PopulateTensor(cell_to_input_weights_, f);
    147   }
    148 
    149   void SetCellToForgetWeights(std::initializer_list<float> f) {
    150     PopulateTensor(cell_to_forget_weights_, f);
    151   }
    152 
    153   void SetCellToOutputWeights(std::initializer_list<float> f) {
    154     PopulateTensor(cell_to_output_weights_, f);
    155   }
    156 
    157   void SetInputGateBias(std::initializer_list<float> f) {
    158     PopulateTensor(input_gate_bias_, f);
    159   }
    160 
    161   void SetForgetGateBias(std::initializer_list<float> f) {
    162     PopulateTensor(forget_gate_bias_, f);
    163   }
    164 
    165   void SetCellBias(std::initializer_list<float> f) {
    166     PopulateTensor(cell_bias_, f);
    167   }
    168 
    169   void SetOutputGateBias(std::initializer_list<float> f) {
    170     PopulateTensor(output_gate_bias_, f);
    171   }
    172 
    173   void SetProjectionWeights(std::initializer_list<float> f) {
    174     PopulateTensor(projection_weights_, f);
    175   }
    176 
    177   void SetProjectionBias(std::initializer_list<float> f) {
    178     PopulateTensor(projection_bias_, f);
    179   }
    180 
    181   void SetInput(int offset, float* begin, float* end) {
    182     PopulateTensor(input_, offset, begin, end);
    183   }
    184 
    185   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
    186   void Verify() {
    187     auto model = tflite::UnPackModel(builder_.GetBufferPointer());
    188     EXPECT_NE(model, nullptr);
    189   }
    190 
    191   int num_inputs() { return n_input_; }
    192   int num_outputs() { return n_output_; }
    193   int num_cells() { return n_cell_; }
    194   int num_batches() { return n_batch_; }
    195 
    196  private:
    197   int input_;
    198   int input_to_input_weights_;
    199   int input_to_forget_weights_;
    200   int input_to_cell_weights_;
    201   int input_to_output_weights_;
    202 
    203   int recurrent_to_input_weights_;
    204   int recurrent_to_forget_weights_;
    205   int recurrent_to_cell_weights_;
    206   int recurrent_to_output_weights_;
    207 
    208   int cell_to_input_weights_;
    209   int cell_to_forget_weights_;
    210   int cell_to_output_weights_;
    211 
    212   int input_gate_bias_;
    213   int forget_gate_bias_;
    214   int cell_bias_;
    215   int output_gate_bias_;
    216 
    217   int projection_weights_;
    218   int projection_bias_;
    219   int input_activation_state_;
    220   int input_cell_state_;
    221 
    222   int output_;
    223 
    224   int n_batch_;
    225   int n_input_;
    226   int n_cell_;
    227   int n_output_;
    228 };
    229 
    230 TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
    231   const int n_batch = 1;
    232   const int n_input = 2;
    233   // n_cell and n_output have the same size when there is no projection.
    234   const int n_cell = 4;
    235   const int n_output = 4;
    236 
    237   LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
    238                    /*use_cifg=*/true, /*use_peephole=*/true,
    239                    /*use_projection_weights=*/false,
    240                    /*use_projection_bias=*/false,
    241                    /*cell_clip=*/0.0, /*proj_clip=*/0.0,
    242                    {
    243                        {n_batch, n_input},  // input tensor
    244 
    245                        {0, 0},             // input_to_input_weight tensor
    246                        {n_cell, n_input},  // input_to_forget_weight tensor
    247                        {n_cell, n_input},  // input_to_cell_weight tensor
    248                        {n_cell, n_input},  // input_to_output_weight tensor
    249 
    250                        {0, 0},              // recurrent_to_input_weight tensor
    251                        {n_cell, n_output},  // recurrent_to_forget_weight tensor
    252                        {n_cell, n_output},  // recurrent_to_cell_weight tensor
    253                        {n_cell, n_output},  // recurrent_to_output_weight tensor
    254 
    255                        {0},       // cell_to_input_weight tensor
    256                        {n_cell},  // cell_to_forget_weight tensor
    257                        {n_cell},  // cell_to_output_weight tensor
    258 
    259                        {0},       // input_gate_bias tensor
    260                        {n_cell},  // forget_gate_bias tensor
    261                        {n_cell},  // cell_bias tensor
    262                        {n_cell},  // output_gate_bias tensor
    263 
    264                        {0, 0},  // projection_weight tensor
    265                        {0},     // projection_bias tensor
    266                    });
    267 
    268   lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781,
    269                               0.04717243, 0.48944736, -0.38535351,
    270                               -0.17212132});
    271 
    272   lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988,
    273                                 -0.3633365, -0.22755712, 0.28253698, 0.24407166,
    274                                 0.33826375});
    275 
    276   lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593,
    277                                 -0.09426838, -0.44257352, 0.54939759,
    278                                 0.01533556, 0.42751634});
    279 
    280   lstm.SetCellBias({0., 0., 0., 0.});
    281 
    282   lstm.SetForgetGateBias({1., 1., 1., 1.});
    283 
    284   lstm.SetOutputGateBias({0., 0., 0., 0.});
    285 
    286   lstm.SetRecurrentToCellWeights(
    287       {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711,
    288        0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004,
    289        0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288,
    290        0.21193194});
    291 
    292   lstm.SetRecurrentToForgetWeights(
    293       {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827,
    294        0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795,
    295        -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349});
    296 
    297   lstm.SetRecurrentToOutputWeights(
    298       {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908,
    299        -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835,
    300        0.50248802, 0.26114327, -0.43736315, 0.33149987});
    301 
    302   lstm.SetCellToForgetWeights(
    303       {0.47485286, -0.51955009, -0.24458408, 0.31544167});
    304   lstm.SetCellToOutputWeights(
    305       {-0.17135078, 0.82760304, 0.85573703, -0.77109635});
    306 
    307   // Verify the model by unpacking it.
    308   lstm.Verify();
    309 }
    310 
    311 }  // namespace
    312 }  // namespace tflite
    313 
    314 int main(int argc, char** argv) {
    315   ::tflite::LogToStderr();
    316   ::testing::InitGoogleTest(&argc, argv);
    317   return RUN_ALL_TESTS();
    318 }
    319