1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include "LSTM.h" 17 18 #include <android-base/logging.h> 19 20 #include "NeuralNetworksWrapper.h" 21 #include "gmock/gmock-matchers.h" 22 #include "gtest/gtest.h" 23 24 #include <sstream> 25 #include <string> 26 #include <vector> 27 28 namespace android { 29 namespace nn { 30 namespace wrapper { 31 32 using ::testing::Each; 33 using ::testing::FloatNear; 34 using ::testing::Matcher; 35 36 namespace { 37 38 std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values, 39 float max_abs_error = 1.e-6) { 40 std::vector<Matcher<float>> matchers; 41 matchers.reserve(values.size()); 42 for (const float& v : values) { 43 matchers.emplace_back(FloatNear(v, max_abs_error)); 44 } 45 return matchers; 46 } 47 48 } // anonymous namespace 49 50 #define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \ 51 ACTION(Input) \ 52 ACTION(InputToInputWeights) \ 53 ACTION(InputToCellWeights) \ 54 ACTION(InputToForgetWeights) \ 55 ACTION(InputToOutputWeights) \ 56 ACTION(RecurrentToInputWeights) \ 57 ACTION(RecurrentToCellWeights) \ 58 ACTION(RecurrentToForgetWeights) \ 59 ACTION(RecurrentToOutputWeights) \ 60 ACTION(CellToInputWeights) \ 61 ACTION(CellToForgetWeights) \ 62 ACTION(CellToOutputWeights) \ 63 ACTION(InputGateBias) \ 64 ACTION(CellGateBias) \ 65 ACTION(ForgetGateBias) \ 66 ACTION(OutputGateBias) \ 67 ACTION(ProjectionWeights) \ 68 ACTION(ProjectionBias) \ 69 ACTION(OutputStateIn) \ 70 ACTION(CellStateIn) 71 72 #define FOR_ALL_LAYER_NORM_WEIGHTS(ACTION) \ 73 ACTION(InputLayerNormWeights) \ 74 ACTION(ForgetLayerNormWeights) \ 75 ACTION(CellLayerNormWeights) \ 76 ACTION(OutputLayerNormWeights) 77 78 // For all output and intermediate states 79 #define FOR_ALL_OUTPUT_TENSORS(ACTION) \ 80 ACTION(ScratchBuffer) \ 81 ACTION(OutputStateOut) \ 82 ACTION(CellStateOut) \ 83 ACTION(Output) 84 85 class LayerNormLSTMOpModel { 86 public: 87 LayerNormLSTMOpModel(uint32_t n_batch, uint32_t n_input, uint32_t n_cell, uint32_t n_output, 88 bool use_cifg, bool use_peephole, bool use_projection_weights, 89 bool use_projection_bias, float cell_clip, float proj_clip, 90 const std::vector<std::vector<uint32_t>>& input_shapes0) 91 : n_input_(n_input), 92 n_output_(n_output), 93 use_cifg_(use_cifg), 94 use_peephole_(use_peephole), 95 use_projection_weights_(use_projection_weights), 96 use_projection_bias_(use_projection_bias), 97 activation_(ActivationFn::kActivationTanh), 98 cell_clip_(cell_clip), 99 proj_clip_(proj_clip) { 100 std::vector<uint32_t> inputs; 101 std::vector<std::vector<uint32_t>> input_shapes(input_shapes0); 102 103 auto it = input_shapes.begin(); 104 105 // Input and weights 106 #define AddInput(X) \ 107 CHECK(it != input_shapes.end()); \ 108 OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it++); \ 109 inputs.push_back(model_.addOperand(&X##OpndTy)); 110 111 FOR_ALL_INPUT_AND_WEIGHT_TENSORS(AddInput); 112 113 // Parameters 114 OperandType ActivationOpndTy(Type::INT32, {}); 115 inputs.push_back(model_.addOperand(&ActivationOpndTy)); 116 OperandType CellClipOpndTy(Type::FLOAT32, {}); 117 inputs.push_back(model_.addOperand(&CellClipOpndTy)); 118 OperandType ProjClipOpndTy(Type::FLOAT32, {}); 119 inputs.push_back(model_.addOperand(&ProjClipOpndTy)); 120 121 FOR_ALL_LAYER_NORM_WEIGHTS(AddInput); 122 123 #undef AddOperand 124 125 // Output and other intermediate state 126 std::vector<std::vector<uint32_t>> output_shapes{ 127 {n_batch, n_cell * (use_cifg ? 3 : 4)}, 128 {n_batch, n_output}, 129 {n_batch, n_cell}, 130 {n_batch, n_output}, 131 }; 132 std::vector<uint32_t> outputs; 133 134 auto it2 = output_shapes.begin(); 135 136 #define AddOutput(X) \ 137 CHECK(it2 != output_shapes.end()); \ 138 OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it2++); \ 139 outputs.push_back(model_.addOperand(&X##OpndTy)); 140 141 FOR_ALL_OUTPUT_TENSORS(AddOutput); 142 143 #undef AddOutput 144 145 model_.addOperation(ANEURALNETWORKS_LSTM, inputs, outputs); 146 model_.identifyInputsAndOutputs(inputs, outputs); 147 148 Input_.insert(Input_.end(), n_batch * n_input, 0.f); 149 OutputStateIn_.insert(OutputStateIn_.end(), n_batch * n_output, 0.f); 150 CellStateIn_.insert(CellStateIn_.end(), n_batch * n_cell, 0.f); 151 152 auto multiAll = [](const std::vector<uint32_t>& dims) -> uint32_t { 153 uint32_t sz = 1; 154 for (uint32_t d : dims) { 155 sz *= d; 156 } 157 return sz; 158 }; 159 160 it2 = output_shapes.begin(); 161 162 #define ReserveOutput(X) X##_.insert(X##_.end(), multiAll(*it2++), 0.f); 163 164 FOR_ALL_OUTPUT_TENSORS(ReserveOutput); 165 166 #undef ReserveOutput 167 168 model_.finish(); 169 } 170 171 #define DefineSetter(X) \ 172 void Set##X(const std::vector<float>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); } 173 174 FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter); 175 FOR_ALL_LAYER_NORM_WEIGHTS(DefineSetter); 176 177 #undef DefineSetter 178 179 void ResetOutputState() { 180 std::fill(OutputStateIn_.begin(), OutputStateIn_.end(), 0.f); 181 std::fill(OutputStateOut_.begin(), OutputStateOut_.end(), 0.f); 182 } 183 184 void ResetCellState() { 185 std::fill(CellStateIn_.begin(), CellStateIn_.end(), 0.f); 186 std::fill(CellStateOut_.begin(), CellStateOut_.end(), 0.f); 187 } 188 189 void SetInput(int offset, const float* begin, const float* end) { 190 for (; begin != end; begin++, offset++) { 191 Input_[offset] = *begin; 192 } 193 } 194 195 uint32_t num_inputs() const { return n_input_; } 196 uint32_t num_outputs() const { return n_output_; } 197 198 const std::vector<float>& GetOutput() const { return Output_; } 199 200 void Invoke() { 201 ASSERT_TRUE(model_.isValid()); 202 203 OutputStateIn_.swap(OutputStateOut_); 204 CellStateIn_.swap(CellStateOut_); 205 206 Compilation compilation(&model_); 207 compilation.finish(); 208 Execution execution(&compilation); 209 #define SetInputOrWeight(X) \ 210 ASSERT_EQ( \ 211 execution.setInput(LSTMCell::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \ 212 Result::NO_ERROR); 213 214 FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight); 215 FOR_ALL_LAYER_NORM_WEIGHTS(SetInputOrWeight); 216 217 #undef SetInputOrWeight 218 219 #define SetOutput(X) \ 220 ASSERT_EQ( \ 221 execution.setOutput(LSTMCell::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \ 222 Result::NO_ERROR); 223 224 FOR_ALL_OUTPUT_TENSORS(SetOutput); 225 226 #undef SetOutput 227 228 if (use_cifg_) { 229 execution.setInput(LSTMCell::kInputToInputWeightsTensor, nullptr, 0); 230 execution.setInput(LSTMCell::kRecurrentToInputWeightsTensor, nullptr, 0); 231 } 232 233 if (use_peephole_) { 234 if (use_cifg_) { 235 execution.setInput(LSTMCell::kCellToInputWeightsTensor, nullptr, 0); 236 } 237 } else { 238 execution.setInput(LSTMCell::kCellToInputWeightsTensor, nullptr, 0); 239 execution.setInput(LSTMCell::kCellToForgetWeightsTensor, nullptr, 0); 240 execution.setInput(LSTMCell::kCellToOutputWeightsTensor, nullptr, 0); 241 } 242 243 if (use_projection_weights_) { 244 if (!use_projection_bias_) { 245 execution.setInput(LSTMCell::kProjectionBiasTensor, nullptr, 0); 246 } 247 } else { 248 execution.setInput(LSTMCell::kProjectionWeightsTensor, nullptr, 0); 249 execution.setInput(LSTMCell::kProjectionBiasTensor, nullptr, 0); 250 } 251 252 ASSERT_EQ(execution.setInput(LSTMCell::kActivationParam, &activation_, sizeof(activation_)), 253 Result::NO_ERROR); 254 ASSERT_EQ(execution.setInput(LSTMCell::kCellClipParam, &cell_clip_, sizeof(cell_clip_)), 255 Result::NO_ERROR); 256 ASSERT_EQ(execution.setInput(LSTMCell::kProjClipParam, &proj_clip_, sizeof(proj_clip_)), 257 Result::NO_ERROR); 258 259 ASSERT_EQ(execution.compute(), Result::NO_ERROR); 260 } 261 262 private: 263 Model model_; 264 // Execution execution_; 265 const uint32_t n_input_; 266 const uint32_t n_output_; 267 268 const bool use_cifg_; 269 const bool use_peephole_; 270 const bool use_projection_weights_; 271 const bool use_projection_bias_; 272 273 const int activation_; 274 const float cell_clip_; 275 const float proj_clip_; 276 277 #define DefineTensor(X) std::vector<float> X##_; 278 279 FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor); 280 FOR_ALL_LAYER_NORM_WEIGHTS(DefineTensor); 281 FOR_ALL_OUTPUT_TENSORS(DefineTensor); 282 283 #undef DefineTensor 284 }; 285 286 TEST(LSTMOpTest, LayerNormNoCifgPeepholeProjectionNoClipping) { 287 const int n_batch = 2; 288 const int n_input = 5; 289 // n_cell and n_output have the same size when there is no projection. 290 const int n_cell = 4; 291 const int n_output = 3; 292 293 LayerNormLSTMOpModel lstm(n_batch, n_input, n_cell, n_output, 294 /*use_cifg=*/false, /*use_peephole=*/true, 295 /*use_projection_weights=*/true, 296 /*use_projection_bias=*/false, 297 /*cell_clip=*/0.0, /*proj_clip=*/0.0, 298 { 299 {n_batch, n_input}, // input tensor 300 301 {n_cell, n_input}, // input_to_input_weight tensor 302 {n_cell, n_input}, // input_to_forget_weight tensor 303 {n_cell, n_input}, // input_to_cell_weight tensor 304 {n_cell, n_input}, // input_to_output_weight tensor 305 306 {n_cell, n_output}, // recurrent_to_input_weight tensor 307 {n_cell, n_output}, // recurrent_to_forget_weight tensor 308 {n_cell, n_output}, // recurrent_to_cell_weight tensor 309 {n_cell, n_output}, // recurrent_to_output_weight tensor 310 311 {n_cell}, // cell_to_input_weight tensor 312 {n_cell}, // cell_to_forget_weight tensor 313 {n_cell}, // cell_to_output_weight tensor 314 315 {n_cell}, // input_gate_bias tensor 316 {n_cell}, // forget_gate_bias tensor 317 {n_cell}, // cell_bias tensor 318 {n_cell}, // output_gate_bias tensor 319 320 {n_output, n_cell}, // projection_weight tensor 321 {0}, // projection_bias tensor 322 323 {n_batch, n_output}, // output_state_in tensor 324 {n_batch, n_cell}, // cell_state_in tensor 325 326 {n_cell}, // input_layer_norm_weights tensor 327 {n_cell}, // forget_layer_norm_weights tensor 328 {n_cell}, // cell_layer_norm_weights tensor 329 {n_cell}, // output_layer_norm_weights tensor 330 }); 331 332 lstm.SetInputToInputWeights({0.5, 0.6, 0.7, -0.8, -0.9, 0.1, 0.2, 0.3, -0.4, 0.5, 333 -0.8, 0.7, -0.6, 0.5, -0.4, -0.5, -0.4, -0.3, -0.2, -0.1}); 334 335 lstm.SetInputToForgetWeights({-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2, -0.4, 0.3, -0.8, 336 -0.4, 0.3, -0.5, -0.4, -0.6, 0.3, -0.4, -0.6, -0.5, -0.5}); 337 338 lstm.SetInputToCellWeights({-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, -0.3, -0.2, -0.6, 339 0.6, -0.1, -0.4, -0.3, -0.7, 0.7, -0.9, -0.5, 0.8, 0.6}); 340 341 lstm.SetInputToOutputWeights({-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, -0.3, -0.8, -0.2, 342 0.6, -0.2, 0.4, -0.7, -0.3, -0.5, 0.1, 0.5, -0.6, -0.4}); 343 344 lstm.SetInputGateBias({0.03, 0.15, 0.22, 0.38}); 345 346 lstm.SetForgetGateBias({0.1, -0.3, -0.2, 0.1}); 347 348 lstm.SetCellGateBias({-0.05, 0.72, 0.25, 0.08}); 349 350 lstm.SetOutputGateBias({0.05, -0.01, 0.2, 0.1}); 351 352 lstm.SetRecurrentToInputWeights( 353 {-0.2, -0.3, 0.4, 0.1, -0.5, 0.9, -0.2, -0.3, -0.7, 0.05, -0.2, -0.6}); 354 355 lstm.SetRecurrentToCellWeights( 356 {-0.3, 0.2, 0.1, -0.3, 0.8, -0.08, -0.2, 0.3, 0.8, -0.6, -0.1, 0.2}); 357 358 lstm.SetRecurrentToForgetWeights( 359 {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4, 0.9, 0.3, -0.1, 0.2, 0.5, 0.2}); 360 361 lstm.SetRecurrentToOutputWeights( 362 {0.3, -0.1, 0.1, -0.2, -0.5, -0.7, -0.2, -0.6, -0.1, -0.4, -0.7, -0.2}); 363 364 lstm.SetCellToInputWeights({0.05, 0.1, 0.25, 0.15}); 365 lstm.SetCellToForgetWeights({-0.02, -0.15, -0.25, -0.03}); 366 lstm.SetCellToOutputWeights({0.1, -0.1, -0.5, 0.05}); 367 368 lstm.SetProjectionWeights({-0.1, 0.2, 0.01, -0.2, 0.1, 0.5, 0.3, 0.08, 0.07, 0.2, -0.4, 0.2}); 369 370 lstm.SetInputLayerNormWeights({0.1, 0.2, 0.3, 0.5}); 371 lstm.SetForgetLayerNormWeights({0.2, 0.2, 0.4, 0.3}); 372 lstm.SetCellLayerNormWeights({0.7, 0.2, 0.3, 0.8}); 373 lstm.SetOutputLayerNormWeights({0.6, 0.2, 0.2, 0.5}); 374 375 const std::vector<std::vector<float>> lstm_input = { 376 { // Batch0: 3 (input_sequence_size) * 5 (n_input) 377 0.7, 0.8, 0.1, 0.2, 0.3, // seq 0 378 0.8, 0.1, 0.2, 0.4, 0.5, // seq 1 379 0.2, 0.7, 0.7, 0.1, 0.7}, // seq 2 380 381 { // Batch1: 3 (input_sequence_size) * 5 (n_input) 382 0.3, 0.2, 0.9, 0.8, 0.1, // seq 0 383 0.1, 0.5, 0.2, 0.4, 0.2, // seq 1 384 0.6, 0.9, 0.2, 0.5, 0.7}, // seq 2 385 }; 386 387 const std::vector<std::vector<float>> lstm_golden_output = { 388 { 389 // Batch0: 3 (input_sequence_size) * 3 (n_output) 390 0.0244077, 0.128027, -0.00170918, // seq 0 391 0.0137642, 0.140751, 0.0395835, // seq 1 392 -0.00459231, 0.155278, 0.0837377, // seq 2 393 }, 394 { 395 // Batch1: 3 (input_sequence_size) * 3 (n_output) 396 -0.00692428, 0.0848741, 0.063445, // seq 0 397 -0.00403912, 0.139963, 0.072681, // seq 1 398 0.00752706, 0.161903, 0.0561371, // seq 2 399 }}; 400 401 // Resetting cell_state and output_state 402 lstm.ResetCellState(); 403 lstm.ResetOutputState(); 404 405 const int input_sequence_size = lstm_input[0].size() / n_input; 406 for (int i = 0; i < input_sequence_size; i++) { 407 for (int b = 0; b < n_batch; ++b) { 408 const float* batch_start = lstm_input[b].data() + i * n_input; 409 const float* batch_end = batch_start + n_input; 410 411 lstm.SetInput(b * n_input, batch_start, batch_end); 412 } 413 414 lstm.Invoke(); 415 416 std::vector<float> expected; 417 for (int b = 0; b < n_batch; ++b) { 418 const float* golden_start = lstm_golden_output[b].data() + i * n_output; 419 const float* golden_end = golden_start + n_output; 420 expected.insert(expected.end(), golden_start, golden_end); 421 } 422 EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); 423 } 424 } 425 426 } // namespace wrapper 427 } // namespace nn 428 } // namespace android 429