/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "LSTM.h"

#include <android-base/logging.h>

#include "NeuralNetworksWrapper.h"
#include "gmock/gmock-matchers.h"
#include "gtest/gtest.h"

#include <algorithm>
#include <sstream>
#include <string>
#include <vector>

namespace android {
namespace nn {
namespace wrapper {

using ::testing::Each;
using ::testing::ElementsAreArray;
using ::testing::FloatNear;
using ::testing::Matcher;

namespace {

// Builds one FloatNear matcher per expected value so a whole output vector
// can be compared against golden data with a small absolute tolerance.
std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
                                           float max_abs_error = 1.e-6) {
    std::vector<Matcher<float>> matchers;
    matchers.reserve(values.size());
    for (const float& v : values) {
        matchers.emplace_back(FloatNear(v, max_abs_error));
    }
    return matchers;
}

}  // anonymous namespace

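// The FOR_ALL_* X-macros below list every LSTM operand once; expanding them
// with different ACTIONs keeps the operand declarations, per-tensor setters,
// backing vectors, and execution bindings in sync.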
#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
    ACTION(Input)                                \
    ACTION(InputToInputWeights)                  \
    ACTION(InputToCellWeights)                   \
    ACTION(InputToForgetWeights)                 \
    ACTION(InputToOutputWeights)                 \
    ACTION(RecurrentToInputWeights)              \
    ACTION(RecurrentToCellWeights)               \
    ACTION(RecurrentToForgetWeights)             \
    ACTION(RecurrentToOutputWeights)             \
    ACTION(CellToInputWeights)                   \
    ACTION(CellToForgetWeights)                  \
    ACTION(CellToOutputWeights)                  \
    ACTION(InputGateBias)                        \
    ACTION(CellGateBias)                         \
    ACTION(ForgetGateBias)                       \
    ACTION(OutputGateBias)                       \
    ACTION(ProjectionWeights)                    \
    ACTION(ProjectionBias)                       \
    ACTION(OutputStateIn)                        \
    ACTION(CellStateIn)

#define FOR_ALL_LAYER_NORM_WEIGHTS(ACTION) \
    ACTION(InputLayerNormWeights)          \
    ACTION(ForgetLayerNormWeights)         \
    ACTION(CellLayerNormWeights)           \
    ACTION(OutputLayerNormWeights)

// For all output and intermediate states
#define FOR_ALL_OUTPUT_TENSORS(ACTION) \
    ACTION(ScratchBuffer)              \
    ACTION(OutputStateOut)             \
    ACTION(CellStateOut)               \
    ACTION(Output)

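// Test harness that assembles an ANEURALNETWORKS_LSTM model with layer
// normalization, owns the backing storage for every operand, and runs one
// LSTM time step per Invoke() call.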
class LayerNormLSTMOpModel {
   public:
    LayerNormLSTMOpModel(uint32_t n_batch, uint32_t n_input, uint32_t n_cell, uint32_t n_output,
                         bool use_cifg, bool use_peephole, bool use_projection_weights,
                         bool use_projection_bias, float cell_clip, float proj_clip,
                         const std::vector<std::vector<uint32_t>>& input_shapes0)
        : n_input_(n_input),
          n_output_(n_output),
          use_cifg_(use_cifg),
          use_peephole_(use_peephole),
          use_projection_weights_(use_projection_weights),
          use_projection_bias_(use_projection_bias),
          activation_(ActivationFn::kActivationTanh),
          cell_clip_(cell_clip),
          proj_clip_(proj_clip) {
        std::vector<uint32_t> inputs;
        std::vector<std::vector<uint32_t>> input_shapes(input_shapes0);

        auto it = input_shapes.begin();

        // Input and weights
#define AddInput(X)                                     \
    CHECK(it != input_shapes.end());                    \
    OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it++); \
    inputs.push_back(model_.addOperand(&X##OpndTy));

        FOR_ALL_INPUT_AND_WEIGHT_TENSORS(AddInput);

        // Parameters
        OperandType ActivationOpndTy(Type::INT32, {});
        inputs.push_back(model_.addOperand(&ActivationOpndTy));
        OperandType CellClipOpndTy(Type::FLOAT32, {});
        inputs.push_back(model_.addOperand(&CellClipOpndTy));
        OperandType ProjClipOpndTy(Type::FLOAT32, {});
        inputs.push_back(model_.addOperand(&ProjClipOpndTy));

        // The layer normalization weights follow the scalar parameters in the
        // operand list.
        FOR_ALL_LAYER_NORM_WEIGHTS(AddInput);

#undef AddInput

        // Output and other intermediate state
        std::vector<std::vector<uint32_t>> output_shapes{
                {n_batch, n_cell * (use_cifg ? 3 : 4)},  // ScratchBuffer: 4 gates (3 with CIFG)
                {n_batch, n_output},                     // OutputStateOut
                {n_batch, n_cell},                       // CellStateOut
                {n_batch, n_output},                     // Output
        };
        std::vector<uint32_t> outputs;

        auto it2 = output_shapes.begin();

#define AddOutput(X)                                     \
    CHECK(it2 != output_shapes.end());                   \
    OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it2++); \
    outputs.push_back(model_.addOperand(&X##OpndTy));

        FOR_ALL_OUTPUT_TENSORS(AddOutput);

#undef AddOutput

        model_.addOperation(ANEURALNETWORKS_LSTM, inputs, outputs);
        model_.identifyInputsAndOutputs(inputs, outputs);

        Input_.insert(Input_.end(), n_batch * n_input, 0.f);
        OutputStateIn_.insert(OutputStateIn_.end(), n_batch * n_output, 0.f);
        CellStateIn_.insert(CellStateIn_.end(), n_batch * n_cell, 0.f);

        auto multiAll = [](const std::vector<uint32_t>& dims) -> uint32_t {
            uint32_t sz = 1;
            for (uint32_t d : dims) {
                sz *= d;
            }
            return sz;
        };

        it2 = output_shapes.begin();

// Size each output buffer to match its declared shape.
#define ReserveOutput(X) X##_.insert(X##_.end(), multiAll(*it2++), 0.f);

        FOR_ALL_OUTPUT_TENSORS(ReserveOutput);

#undef ReserveOutput

        model_.finish();
    }

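// Generates a Set<Tensor>() helper for every input, weight, and layer
// normalization tensor; the values are appended to the tensor's backing
// vector and bound to the execution in Invoke().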
#define DefineSetter(X) \
    void Set##X(const std::vector<float>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); }

    FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
    FOR_ALL_LAYER_NORM_WEIGHTS(DefineSetter);

#undef DefineSetter

    void ResetOutputState() {
        std::fill(OutputStateIn_.begin(), OutputStateIn_.end(), 0.f);
        std::fill(OutputStateOut_.begin(), OutputStateOut_.end(), 0.f);
    }

    void ResetCellState() {
        std::fill(CellStateIn_.begin(), CellStateIn_.end(), 0.f);
        std::fill(CellStateOut_.begin(), CellStateOut_.end(), 0.f);
    }

    void SetInput(int offset, const float* begin, const float* end) {
        for (; begin != end; begin++, offset++) {
            Input_[offset] = *begin;
        }
    }

    uint32_t num_inputs() const { return n_input_; }
    uint32_t num_outputs() const { return n_output_; }

    const std::vector<float>& GetOutput() const { return Output_; }

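    // Binds every operand buffer, marks the optional inputs that are unused
    // in the current configuration as omitted, and runs one LSTM time step.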
    void Invoke() {
        ASSERT_TRUE(model_.isValid());

        // Feed the previous step's output and cell state back in as the next
        // step's input state.
        OutputStateIn_.swap(OutputStateOut_);
        CellStateIn_.swap(CellStateOut_);

        Compilation compilation(&model_);
        compilation.finish();
        Execution execution(&compilation);
#define SetInputOrWeight(X)                                                                       \
    ASSERT_EQ(                                                                                    \
            execution.setInput(LSTMCell::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
            Result::NO_ERROR);

        FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
        FOR_ALL_LAYER_NORM_WEIGHTS(SetInputOrWeight);

#undef SetInputOrWeight

#define SetOutput(X)                                                                               \
    ASSERT_EQ(                                                                                     \
            execution.setOutput(LSTMCell::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
            Result::NO_ERROR);

        FOR_ALL_OUTPUT_TENSORS(SetOutput);

#undef SetOutput

        // With CIFG the input gate weights are not used, so they are passed
        // as omitted (null) optional operands.
        if (use_cifg_) {
            execution.setInput(LSTMCell::kInputToInputWeightsTensor, nullptr, 0);
            execution.setInput(LSTMCell::kRecurrentToInputWeightsTensor, nullptr, 0);
        }

        // Peephole connections are optional; omit whichever cell-to-gate
        // weights this configuration does not use.
        if (use_peephole_) {
            if (use_cifg_) {
                execution.setInput(LSTMCell::kCellToInputWeightsTensor, nullptr, 0);
            }
        } else {
            execution.setInput(LSTMCell::kCellToInputWeightsTensor, nullptr, 0);
            execution.setInput(LSTMCell::kCellToForgetWeightsTensor, nullptr, 0);
            execution.setInput(LSTMCell::kCellToOutputWeightsTensor, nullptr, 0);
        }

        if (use_projection_weights_) {
            if (!use_projection_bias_) {
                execution.setInput(LSTMCell::kProjectionBiasTensor, nullptr, 0);
            }
        } else {
            execution.setInput(LSTMCell::kProjectionWeightsTensor, nullptr, 0);
            execution.setInput(LSTMCell::kProjectionBiasTensor, nullptr, 0);
        }

        ASSERT_EQ(execution.setInput(LSTMCell::kActivationParam, &activation_, sizeof(activation_)),
                  Result::NO_ERROR);
        ASSERT_EQ(execution.setInput(LSTMCell::kCellClipParam, &cell_clip_, sizeof(cell_clip_)),
                  Result::NO_ERROR);
        ASSERT_EQ(execution.setInput(LSTMCell::kProjClipParam, &proj_clip_, sizeof(proj_clip_)),
                  Result::NO_ERROR);

        ASSERT_EQ(execution.compute(), Result::NO_ERROR);
    }

   private:
    Model model_;
    // Execution execution_;
    const uint32_t n_input_;
    const uint32_t n_output_;

    const bool use_cifg_;
    const bool use_peephole_;
    const bool use_projection_weights_;
    const bool use_projection_bias_;

    const int activation_;
    const float cell_clip_;
    const float proj_clip_;

#define DefineTensor(X) std::vector<float> X##_;

    FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
    FOR_ALL_LAYER_NORM_WEIGHTS(DefineTensor);
    FOR_ALL_OUTPUT_TENSORS(DefineTensor);

#undef DefineTensor
};

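// Exercises a full-featured cell: no CIFG, with peephole connections, a
// projection layer without bias, layer normalization, and no cell or
// projection clipping.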
TEST(LSTMOpTest, LayerNormNoCifgPeepholeProjectionNoClipping) {
    const int n_batch = 2;
    const int n_input = 5;
    // n_cell and n_output have the same size when there is no projection.
    const int n_cell = 4;
    const int n_output = 3;

    LayerNormLSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
                              /*use_cifg=*/false, /*use_peephole=*/true,
                              /*use_projection_weights=*/true,
                              /*use_projection_bias=*/false,
                              /*cell_clip=*/0.0, /*proj_clip=*/0.0,
                              // Shapes are listed in the order they are
                              // consumed: FOR_ALL_INPUT_AND_WEIGHT_TENSORS,
                              // then FOR_ALL_LAYER_NORM_WEIGHTS.
                              {
                                      {n_batch, n_input},  // input tensor

                                      {n_cell, n_input},  // input_to_input_weight tensor
                                      {n_cell, n_input},  // input_to_cell_weight tensor
                                      {n_cell, n_input},  // input_to_forget_weight tensor
                                      {n_cell, n_input},  // input_to_output_weight tensor

                                      {n_cell, n_output},  // recurrent_to_input_weight tensor
                                      {n_cell, n_output},  // recurrent_to_cell_weight tensor
                                      {n_cell, n_output},  // recurrent_to_forget_weight tensor
                                      {n_cell, n_output},  // recurrent_to_output_weight tensor

                                      {n_cell},  // cell_to_input_weight tensor
                                      {n_cell},  // cell_to_forget_weight tensor
                                      {n_cell},  // cell_to_output_weight tensor

                                      {n_cell},  // input_gate_bias tensor
                                      {n_cell},  // cell_gate_bias tensor
                                      {n_cell},  // forget_gate_bias tensor
                                      {n_cell},  // output_gate_bias tensor

                                      {n_output, n_cell},  // projection_weight tensor
                                      {0},                 // projection_bias tensor

                                      {n_batch, n_output},  // output_state_in tensor
                                      {n_batch, n_cell},    // cell_state_in tensor

                                      {n_cell},  // input_layer_norm_weights tensor
                                      {n_cell},  // forget_layer_norm_weights tensor
                                      {n_cell},  // cell_layer_norm_weights tensor
                                      {n_cell},  // output_layer_norm_weights tensor
                              });

    lstm.SetInputToInputWeights({0.5,  0.6, 0.7,  -0.8, -0.9, 0.1,  0.2,  0.3,  -0.4, 0.5,
                                 -0.8, 0.7, -0.6, 0.5,  -0.4, -0.5, -0.4, -0.3, -0.2, -0.1});

    lstm.SetInputToForgetWeights({-0.6, -0.1, 0.3,  0.2,  0.9,  -0.5, -0.2, -0.4, 0.3,  -0.8,
                                  -0.4, 0.3,  -0.5, -0.4, -0.6, 0.3,  -0.4, -0.6, -0.5, -0.5});

    lstm.SetInputToCellWeights({-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, -0.3, -0.2, -0.6,
                                0.6,  -0.1, -0.4, -0.3, -0.7, 0.7, -0.9, -0.5, 0.8,  0.6});

    lstm.SetInputToOutputWeights({-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, -0.3, -0.8, -0.2,
                                  0.6,  -0.2, 0.4,  -0.7, -0.3, -0.5, 0.1, 0.5,  -0.6, -0.4});

    lstm.SetInputGateBias({0.03, 0.15, 0.22, 0.38});

    lstm.SetForgetGateBias({0.1, -0.3, -0.2, 0.1});

    lstm.SetCellGateBias({-0.05, 0.72, 0.25, 0.08});

    lstm.SetOutputGateBias({0.05, -0.01, 0.2, 0.1});

    lstm.SetRecurrentToInputWeights(
            {-0.2, -0.3, 0.4, 0.1, -0.5, 0.9, -0.2, -0.3, -0.7, 0.05, -0.2, -0.6});

    lstm.SetRecurrentToCellWeights(
            {-0.3, 0.2, 0.1, -0.3, 0.8, -0.08, -0.2, 0.3, 0.8, -0.6, -0.1, 0.2});

    lstm.SetRecurrentToForgetWeights(
            {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4, 0.9, 0.3, -0.1, 0.2, 0.5, 0.2});

    lstm.SetRecurrentToOutputWeights(
            {0.3, -0.1, 0.1, -0.2, -0.5, -0.7, -0.2, -0.6, -0.1, -0.4, -0.7, -0.2});

    lstm.SetCellToInputWeights({0.05, 0.1, 0.25, 0.15});
    lstm.SetCellToForgetWeights({-0.02, -0.15, -0.25, -0.03});
    lstm.SetCellToOutputWeights({0.1, -0.1, -0.5, 0.05});

    lstm.SetProjectionWeights({-0.1, 0.2, 0.01, -0.2, 0.1, 0.5, 0.3, 0.08, 0.07, 0.2, -0.4, 0.2});

    lstm.SetInputLayerNormWeights({0.1, 0.2, 0.3, 0.5});
    lstm.SetForgetLayerNormWeights({0.2, 0.2, 0.4, 0.3});
    lstm.SetCellLayerNormWeights({0.7, 0.2, 0.3, 0.8});
    lstm.SetOutputLayerNormWeights({0.6, 0.2, 0.2, 0.5});

    const std::vector<std::vector<float>> lstm_input = {
            {                           // Batch0: 3 (input_sequence_size) * 5 (n_input)
             0.7, 0.8, 0.1, 0.2, 0.3,   // seq 0
             0.8, 0.1, 0.2, 0.4, 0.5,   // seq 1
             0.2, 0.7, 0.7, 0.1, 0.7},  // seq 2

            {                           // Batch1: 3 (input_sequence_size) * 5 (n_input)
             0.3, 0.2, 0.9, 0.8, 0.1,   // seq 0
             0.1, 0.5, 0.2, 0.4, 0.2,   // seq 1
             0.6, 0.9, 0.2, 0.5, 0.7},  // seq 2
    };

    const std::vector<std::vector<float>> lstm_golden_output = {
            {
                    // Batch0: 3 (input_sequence_size) * 3 (n_output)
                    0.0244077, 0.128027, -0.00170918,  // seq 0
                    0.0137642, 0.140751, 0.0395835,    // seq 1
                    -0.00459231, 0.155278, 0.0837377,  // seq 2
            },
            {
                    // Batch1: 3 (input_sequence_size) * 3 (n_output)
                    -0.00692428, 0.0848741, 0.063445,  // seq 0
                    -0.00403912, 0.139963, 0.072681,   // seq 1
                    0.00752706, 0.161903, 0.0561371,   // seq 2
            }};

    // Resetting cell_state and output_state
    lstm.ResetCellState();
    lstm.ResetOutputState();

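    // Feed the sequence one time step at a time; each Invoke() consumes the
    // state produced by the previous step and its output is compared against
    // the golden values for that step.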
    const int input_sequence_size = lstm_input[0].size() / n_input;
    for (int i = 0; i < input_sequence_size; i++) {
        for (int b = 0; b < n_batch; ++b) {
            const float* batch_start = lstm_input[b].data() + i * n_input;
            const float* batch_end = batch_start + n_input;

            lstm.SetInput(b * n_input, batch_start, batch_end);
        }

        lstm.Invoke();

        std::vector<float> expected;
        for (int b = 0; b < n_batch; ++b) {
            const float* golden_start = lstm_golden_output[b].data() + i * n_output;
            const float* golden_end = golden_start + n_output;
            expected.insert(expected.end(), golden_start, golden_end);
        }
        EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
    }
}

}  // namespace wrapper
}  // namespace nn
}  // namespace android