1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 // Unit test for TFLite LSTM op. 16 17 #include <iomanip> 18 #include <memory> 19 #include <vector> 20 21 #include <gmock/gmock.h> 22 #include <gtest/gtest.h> 23 #include "tensorflow/lite/interpreter.h" 24 #include "tensorflow/lite/kernels/register.h" 25 #include "tensorflow/lite/kernels/test_util.h" 26 #include "tensorflow/lite/model.h" 27 28 namespace tflite { 29 namespace { 30 31 class LSTMOpModel : public SingleOpModel { 32 public: 33 LSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, bool use_cifg, 34 bool use_peephole, bool use_projection_weights, 35 bool use_projection_bias, float cell_clip, float proj_clip, 36 const std::vector<std::vector<int>>& input_shapes) 37 : n_batch_(n_batch), 38 n_input_(n_input), 39 n_cell_(n_cell), 40 n_output_(n_output) { 41 input_ = AddInput(TensorType_FLOAT32); 42 43 if (use_cifg) { 44 input_to_input_weights_ = AddNullInput(); 45 } else { 46 input_to_input_weights_ = AddInput(TensorType_FLOAT32); 47 } 48 49 input_to_forget_weights_ = AddInput(TensorType_FLOAT32); 50 input_to_cell_weights_ = AddInput(TensorType_FLOAT32); 51 input_to_output_weights_ = AddInput(TensorType_FLOAT32); 52 53 if (use_cifg) { 54 recurrent_to_input_weights_ = AddNullInput(); 55 } else { 56 recurrent_to_input_weights_ = AddInput(TensorType_FLOAT32); 57 } 58 59 recurrent_to_forget_weights_ = AddInput(TensorType_FLOAT32); 60 recurrent_to_cell_weights_ = AddInput(TensorType_FLOAT32); 61 recurrent_to_output_weights_ = AddInput(TensorType_FLOAT32); 62 63 if (use_peephole) { 64 if (use_cifg) { 65 cell_to_input_weights_ = AddNullInput(); 66 } else { 67 cell_to_input_weights_ = AddInput(TensorType_FLOAT32); 68 } 69 cell_to_forget_weights_ = AddInput(TensorType_FLOAT32); 70 cell_to_output_weights_ = AddInput(TensorType_FLOAT32); 71 } else { 72 cell_to_input_weights_ = AddNullInput(); 73 cell_to_forget_weights_ = AddNullInput(); 74 cell_to_output_weights_ = AddNullInput(); 75 } 76 77 if (use_cifg) { 78 input_gate_bias_ = AddNullInput(); 79 } else { 80 input_gate_bias_ = AddInput(TensorType_FLOAT32); 81 } 82 forget_gate_bias_ = AddInput(TensorType_FLOAT32); 83 cell_bias_ = AddInput(TensorType_FLOAT32); 84 output_gate_bias_ = AddInput(TensorType_FLOAT32); 85 86 if (use_projection_weights) { 87 projection_weights_ = AddInput(TensorType_FLOAT32); 88 if (use_projection_bias) { 89 projection_bias_ = AddInput(TensorType_FLOAT32); 90 } else { 91 projection_bias_ = AddNullInput(); 92 } 93 } else { 94 projection_weights_ = AddNullInput(); 95 projection_bias_ = AddNullInput(); 96 } 97 98 // Adding the 2 input state tensors. 99 input_activation_state_ = 100 AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}}, true); 101 input_cell_state_ = 102 AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true); 103 104 output_ = AddOutput(TensorType_FLOAT32); 105 106 SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions, 107 CreateLSTMOptions(builder_, ActivationFunctionType_TANH, 108 cell_clip, proj_clip) 109 .Union()); 110 BuildInterpreter(input_shapes); 111 } 112 113 void SetInputToInputWeights(std::initializer_list<float> f) { 114 PopulateTensor(input_to_input_weights_, f); 115 } 116 117 void SetInputToForgetWeights(std::initializer_list<float> f) { 118 PopulateTensor(input_to_forget_weights_, f); 119 } 120 121 void SetInputToCellWeights(std::initializer_list<float> f) { 122 PopulateTensor(input_to_cell_weights_, f); 123 } 124 125 void SetInputToOutputWeights(std::initializer_list<float> f) { 126 PopulateTensor(input_to_output_weights_, f); 127 } 128 129 void SetRecurrentToInputWeights(std::initializer_list<float> f) { 130 PopulateTensor(recurrent_to_input_weights_, f); 131 } 132 133 void SetRecurrentToForgetWeights(std::initializer_list<float> f) { 134 PopulateTensor(recurrent_to_forget_weights_, f); 135 } 136 137 void SetRecurrentToCellWeights(std::initializer_list<float> f) { 138 PopulateTensor(recurrent_to_cell_weights_, f); 139 } 140 141 void SetRecurrentToOutputWeights(std::initializer_list<float> f) { 142 PopulateTensor(recurrent_to_output_weights_, f); 143 } 144 145 void SetCellToInputWeights(std::initializer_list<float> f) { 146 PopulateTensor(cell_to_input_weights_, f); 147 } 148 149 void SetCellToForgetWeights(std::initializer_list<float> f) { 150 PopulateTensor(cell_to_forget_weights_, f); 151 } 152 153 void SetCellToOutputWeights(std::initializer_list<float> f) { 154 PopulateTensor(cell_to_output_weights_, f); 155 } 156 157 void SetInputGateBias(std::initializer_list<float> f) { 158 PopulateTensor(input_gate_bias_, f); 159 } 160 161 void SetForgetGateBias(std::initializer_list<float> f) { 162 PopulateTensor(forget_gate_bias_, f); 163 } 164 165 void SetCellBias(std::initializer_list<float> f) { 166 PopulateTensor(cell_bias_, f); 167 } 168 169 void SetOutputGateBias(std::initializer_list<float> f) { 170 PopulateTensor(output_gate_bias_, f); 171 } 172 173 void SetProjectionWeights(std::initializer_list<float> f) { 174 PopulateTensor(projection_weights_, f); 175 } 176 177 void SetProjectionBias(std::initializer_list<float> f) { 178 PopulateTensor(projection_bias_, f); 179 } 180 181 void SetInput(int offset, float* begin, float* end) { 182 PopulateTensor(input_, offset, begin, end); 183 } 184 185 std::vector<float> GetOutput() { return ExtractVector<float>(output_); } 186 void Verify() { 187 auto model = tflite::UnPackModel(builder_.GetBufferPointer()); 188 EXPECT_NE(model, nullptr); 189 } 190 191 int num_inputs() { return n_input_; } 192 int num_outputs() { return n_output_; } 193 int num_cells() { return n_cell_; } 194 int num_batches() { return n_batch_; } 195 196 private: 197 int input_; 198 int input_to_input_weights_; 199 int input_to_forget_weights_; 200 int input_to_cell_weights_; 201 int input_to_output_weights_; 202 203 int recurrent_to_input_weights_; 204 int recurrent_to_forget_weights_; 205 int recurrent_to_cell_weights_; 206 int recurrent_to_output_weights_; 207 208 int cell_to_input_weights_; 209 int cell_to_forget_weights_; 210 int cell_to_output_weights_; 211 212 int input_gate_bias_; 213 int forget_gate_bias_; 214 int cell_bias_; 215 int output_gate_bias_; 216 217 int projection_weights_; 218 int projection_bias_; 219 int input_activation_state_; 220 int input_cell_state_; 221 222 int output_; 223 224 int n_batch_; 225 int n_input_; 226 int n_cell_; 227 int n_output_; 228 }; 229 230 TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { 231 const int n_batch = 1; 232 const int n_input = 2; 233 // n_cell and n_output have the same size when there is no projection. 234 const int n_cell = 4; 235 const int n_output = 4; 236 237 LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, 238 /*use_cifg=*/true, /*use_peephole=*/true, 239 /*use_projection_weights=*/false, 240 /*use_projection_bias=*/false, 241 /*cell_clip=*/0.0, /*proj_clip=*/0.0, 242 { 243 {n_batch, n_input}, // input tensor 244 245 {0, 0}, // input_to_input_weight tensor 246 {n_cell, n_input}, // input_to_forget_weight tensor 247 {n_cell, n_input}, // input_to_cell_weight tensor 248 {n_cell, n_input}, // input_to_output_weight tensor 249 250 {0, 0}, // recurrent_to_input_weight tensor 251 {n_cell, n_output}, // recurrent_to_forget_weight tensor 252 {n_cell, n_output}, // recurrent_to_cell_weight tensor 253 {n_cell, n_output}, // recurrent_to_output_weight tensor 254 255 {0}, // cell_to_input_weight tensor 256 {n_cell}, // cell_to_forget_weight tensor 257 {n_cell}, // cell_to_output_weight tensor 258 259 {0}, // input_gate_bias tensor 260 {n_cell}, // forget_gate_bias tensor 261 {n_cell}, // cell_bias tensor 262 {n_cell}, // output_gate_bias tensor 263 264 {0, 0}, // projection_weight tensor 265 {0}, // projection_bias tensor 266 }); 267 268 lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781, 269 0.04717243, 0.48944736, -0.38535351, 270 -0.17212132}); 271 272 lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988, 273 -0.3633365, -0.22755712, 0.28253698, 0.24407166, 274 0.33826375}); 275 276 lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593, 277 -0.09426838, -0.44257352, 0.54939759, 278 0.01533556, 0.42751634}); 279 280 lstm.SetCellBias({0., 0., 0., 0.}); 281 282 lstm.SetForgetGateBias({1., 1., 1., 1.}); 283 284 lstm.SetOutputGateBias({0., 0., 0., 0.}); 285 286 lstm.SetRecurrentToCellWeights( 287 {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711, 288 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004, 289 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288, 290 0.21193194}); 291 292 lstm.SetRecurrentToForgetWeights( 293 {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827, 294 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795, 295 -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349}); 296 297 lstm.SetRecurrentToOutputWeights( 298 {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908, 299 -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835, 300 0.50248802, 0.26114327, -0.43736315, 0.33149987}); 301 302 lstm.SetCellToForgetWeights( 303 {0.47485286, -0.51955009, -0.24458408, 0.31544167}); 304 lstm.SetCellToOutputWeights( 305 {-0.17135078, 0.82760304, 0.85573703, -0.77109635}); 306 307 // Verify the model by unpacking it. 308 lstm.Verify(); 309 } 310 311 } // namespace 312 } // namespace tflite 313 314 int main(int argc, char** argv) { 315 ::tflite::LogToStderr(); 316 ::testing::InitGoogleTest(&argc, argv); 317 return RUN_ALL_TESTS(); 318 } 319