// NOTE: code-viewer navigation header (was: "Home | History | Annotate | Download | only in operations")
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "QuantizedLSTM.h"
     18 
     19 #include "NeuralNetworksWrapper.h"
     20 #include "gmock/gmock-matchers.h"
     21 #include "gtest/gtest.h"
     22 
     23 #include <iostream>
     24 
     25 namespace android {
     26 namespace nn {
     27 namespace wrapper {
     28 
     29 namespace {
     30 
     31 struct OperandTypeParams {
     32     Type type;
     33     std::vector<uint32_t> shape;
     34     float scale;
     35     int32_t zeroPoint;
     36 
     37     OperandTypeParams(Type type, std::vector<uint32_t> shape, float scale, int32_t zeroPoint)
     38         : type(type), shape(shape), scale(scale), zeroPoint(zeroPoint) {}
     39 };
     40 
     41 }  // namespace
     42 
     43 using ::testing::Each;
     44 using ::testing::ElementsAreArray;
     45 using ::testing::FloatNear;
     46 using ::testing::Matcher;
     47 
     48 class QuantizedLSTMOpModel {
     49    public:
     50     QuantizedLSTMOpModel(const std::vector<OperandTypeParams>& inputOperandTypeParams) {
     51         std::vector<uint32_t> inputs;
     52 
     53         for (int i = 0; i < NUM_INPUTS; ++i) {
     54             const auto& curOTP = inputOperandTypeParams[i];
     55             OperandType curType(curOTP.type, curOTP.shape, curOTP.scale, curOTP.zeroPoint);
     56             inputs.push_back(model_.addOperand(&curType));
     57         }
     58 
     59         const uint32_t numBatches = inputOperandTypeParams[0].shape[0];
     60         inputSize_ = inputOperandTypeParams[0].shape[0];
     61         const uint32_t outputSize =
     62                 inputOperandTypeParams[QuantizedLSTMCell::kPrevCellStateTensor].shape[1];
     63         outputSize_ = outputSize;
     64 
     65         std::vector<uint32_t> outputs;
     66         OperandType cellStateOutOperandType(Type::TENSOR_QUANT16_SYMM, {numBatches, outputSize},
     67                                             1. / 2048., 0);
     68         outputs.push_back(model_.addOperand(&cellStateOutOperandType));
     69         OperandType outputOperandType(Type::TENSOR_QUANT8_ASYMM, {numBatches, outputSize},
     70                                       1. / 128., 128);
     71         outputs.push_back(model_.addOperand(&outputOperandType));
     72 
     73         model_.addOperation(ANEURALNETWORKS_QUANTIZED_16BIT_LSTM, inputs, outputs);
     74         model_.identifyInputsAndOutputs(inputs, outputs);
     75 
     76         initializeInputData(inputOperandTypeParams[QuantizedLSTMCell::kInputTensor], &input_);
     77         initializeInputData(inputOperandTypeParams[QuantizedLSTMCell::kPrevOutputTensor],
     78                             &prevOutput_);
     79         initializeInputData(inputOperandTypeParams[QuantizedLSTMCell::kPrevCellStateTensor],
     80                             &prevCellState_);
     81 
     82         cellStateOut_.resize(numBatches * outputSize, 0);
     83         output_.resize(numBatches * outputSize, 0);
     84 
     85         model_.finish();
     86     }
     87 
     88     void invoke() {
     89         ASSERT_TRUE(model_.isValid());
     90 
     91         Compilation compilation(&model_);
     92         compilation.finish();
     93         Execution execution(&compilation);
     94 
     95         // Set all the inputs.
     96         ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputTensor, input_),
     97                   Result::NO_ERROR);
     98         ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputToInputWeightsTensor,
     99                                  inputToInputWeights_),
    100                   Result::NO_ERROR);
    101         ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputToForgetWeightsTensor,
    102                                  inputToForgetWeights_),
    103                   Result::NO_ERROR);
    104         ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputToCellWeightsTensor,
    105                                  inputToCellWeights_),
    106                   Result::NO_ERROR);
    107         ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputToOutputWeightsTensor,
    108                                  inputToOutputWeights_),
    109                   Result::NO_ERROR);
    110         ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kRecurrentToInputWeightsTensor,
    111                                  recurrentToInputWeights_),
    112                   Result::NO_ERROR);
    113         ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kRecurrentToForgetWeightsTensor,
    114                                  recurrentToForgetWeights_),
    115                   Result::NO_ERROR);
    116         ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kRecurrentToCellWeightsTensor,
    117                                  recurrentToCellWeights_),
    118                   Result::NO_ERROR);
    119         ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kRecurrentToOutputWeightsTensor,
    120                                  recurrentToOutputWeights_),
    121                   Result::NO_ERROR);
    122         ASSERT_EQ(
    123                 setInputTensor(&execution, QuantizedLSTMCell::kInputGateBiasTensor, inputGateBias_),
    124                 Result::NO_ERROR);
    125         ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kForgetGateBiasTensor,
    126                                  forgetGateBias_),
    127                   Result::NO_ERROR);
    128         ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kCellGateBiasTensor, cellGateBias_),
    129                   Result::NO_ERROR);
    130         ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kOutputGateBiasTensor,
    131                                  outputGateBias_),
    132                   Result::NO_ERROR);
    133         ASSERT_EQ(
    134                 setInputTensor(&execution, QuantizedLSTMCell::kPrevCellStateTensor, prevCellState_),
    135                 Result::NO_ERROR);
    136         ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kPrevOutputTensor, prevOutput_),
    137                   Result::NO_ERROR);
    138         // Set all the outputs.
    139         ASSERT_EQ(
    140                 setOutputTensor(&execution, QuantizedLSTMCell::kCellStateOutTensor, &cellStateOut_),
    141                 Result::NO_ERROR);
    142         ASSERT_EQ(setOutputTensor(&execution, QuantizedLSTMCell::kOutputTensor, &output_),
    143                   Result::NO_ERROR);
    144 
    145         ASSERT_EQ(execution.compute(), Result::NO_ERROR);
    146 
    147         // Put state outputs into inputs for the next step
    148         prevOutput_ = output_;
    149         prevCellState_ = cellStateOut_;
    150     }
    151 
    152     int inputSize() { return inputSize_; }
    153 
    154     int outputSize() { return outputSize_; }
    155 
    156     void setInput(const std::vector<uint8_t>& input) { input_ = input; }
    157 
    158     void setWeightsAndBiases(std::vector<uint8_t> inputToInputWeights,
    159                              std::vector<uint8_t> inputToForgetWeights,
    160                              std::vector<uint8_t> inputToCellWeights,
    161                              std::vector<uint8_t> inputToOutputWeights,
    162                              std::vector<uint8_t> recurrentToInputWeights,
    163                              std::vector<uint8_t> recurrentToForgetWeights,
    164                              std::vector<uint8_t> recurrentToCellWeights,
    165                              std::vector<uint8_t> recurrentToOutputWeights,
    166                              std::vector<int32_t> inputGateBias,
    167                              std::vector<int32_t> forgetGateBias,
    168                              std::vector<int32_t> cellGateBias,  //
    169                              std::vector<int32_t> outputGateBias) {
    170         inputToInputWeights_ = inputToInputWeights;
    171         inputToForgetWeights_ = inputToForgetWeights;
    172         inputToCellWeights_ = inputToCellWeights;
    173         inputToOutputWeights_ = inputToOutputWeights;
    174         recurrentToInputWeights_ = recurrentToInputWeights;
    175         recurrentToForgetWeights_ = recurrentToForgetWeights;
    176         recurrentToCellWeights_ = recurrentToCellWeights;
    177         recurrentToOutputWeights_ = recurrentToOutputWeights;
    178         inputGateBias_ = inputGateBias;
    179         forgetGateBias_ = forgetGateBias;
    180         cellGateBias_ = cellGateBias;
    181         outputGateBias_ = outputGateBias;
    182     }
    183 
    184     template <typename T>
    185     void initializeInputData(OperandTypeParams params, std::vector<T>* vec) {
    186         int size = 1;
    187         for (int d : params.shape) {
    188             size *= d;
    189         }
    190         vec->clear();
    191         vec->resize(size, params.zeroPoint);
    192     }
    193 
    194     std::vector<uint8_t> getOutput() { return output_; }
    195 
    196    private:
    197     static constexpr int NUM_INPUTS = 15;
    198     static constexpr int NUM_OUTPUTS = 2;
    199 
    200     Model model_;
    201     // Inputs
    202     std::vector<uint8_t> input_;
    203     std::vector<uint8_t> inputToInputWeights_;
    204     std::vector<uint8_t> inputToForgetWeights_;
    205     std::vector<uint8_t> inputToCellWeights_;
    206     std::vector<uint8_t> inputToOutputWeights_;
    207     std::vector<uint8_t> recurrentToInputWeights_;
    208     std::vector<uint8_t> recurrentToForgetWeights_;
    209     std::vector<uint8_t> recurrentToCellWeights_;
    210     std::vector<uint8_t> recurrentToOutputWeights_;
    211     std::vector<int32_t> inputGateBias_;
    212     std::vector<int32_t> forgetGateBias_;
    213     std::vector<int32_t> cellGateBias_;
    214     std::vector<int32_t> outputGateBias_;
    215     std::vector<int16_t> prevCellState_;
    216     std::vector<uint8_t> prevOutput_;
    217     // Outputs
    218     std::vector<int16_t> cellStateOut_;
    219     std::vector<uint8_t> output_;
    220 
    221     int inputSize_;
    222     int outputSize_;
    223 
    224     template <typename T>
    225     Result setInputTensor(Execution* execution, int tensor, const std::vector<T>& data) {
    226         return execution->setInput(tensor, data.data(), sizeof(T) * data.size());
    227     }
    228     template <typename T>
    229     Result setOutputTensor(Execution* execution, int tensor, std::vector<T>* data) {
    230         return execution->setOutput(tensor, data->data(), sizeof(T) * data->size());
    231     }
    232 };
    233 
    234 class QuantizedLstmTest : public ::testing::Test {
    235    protected:
    236     void VerifyGoldens(const std::vector<std::vector<uint8_t>>& input,
    237                        const std::vector<std::vector<uint8_t>>& output,
    238                        QuantizedLSTMOpModel* lstm) {
    239         const int numBatches = input.size();
    240         EXPECT_GT(numBatches, 0);
    241         const int inputSize = lstm->inputSize();
    242         EXPECT_GT(inputSize, 0);
    243         const int inputSequenceSize = input[0].size() / inputSize;
    244         EXPECT_GT(inputSequenceSize, 0);
    245         for (int i = 0; i < inputSequenceSize; ++i) {
    246             std::vector<uint8_t> inputStep;
    247             for (int b = 0; b < numBatches; ++b) {
    248                 const uint8_t* batchStart = input[b].data() + i * inputSize;
    249                 const uint8_t* batchEnd = batchStart + inputSize;
    250                 inputStep.insert(inputStep.end(), batchStart, batchEnd);
    251             }
    252             lstm->setInput(inputStep);
    253             lstm->invoke();
    254 
    255             const int outputSize = lstm->outputSize();
    256             std::vector<float> expected;
    257             for (int b = 0; b < numBatches; ++b) {
    258                 const uint8_t* goldenBatchStart = output[b].data() + i * outputSize;
    259                 const uint8_t* goldenBatchEnd = goldenBatchStart + outputSize;
    260                 expected.insert(expected.end(), goldenBatchStart, goldenBatchEnd);
    261             }
    262             EXPECT_THAT(lstm->getOutput(), ElementsAreArray(expected));
    263         }
    264     }
    265 };
    266 
    267 // Inputs and weights in this test are random and the test only checks that the
    268 // outputs are equal to outputs obtained from running TF Lite version of
    269 // quantized LSTM on the same inputs.
    270 TEST_F(QuantizedLstmTest, BasicQuantizedLstmTest) {
    271     const int numBatches = 2;
    272     const int inputSize = 2;
    273     const int outputSize = 4;
    274 
    275     float weightsScale = 0.00408021;
    276     int weightsZeroPoint = 100;
    277     // OperandType biasOperandType(Type::TENSOR_INT32, input_shapes[3],
    278     // weightsScale / 128., 0);
    279     // inputs.push_back(model_.addOperand(&biasOperandType));
    280     // OperandType prevCellStateOperandType(Type::TENSOR_QUANT16_SYMM, input_shapes[4],
    281     // 1. / 2048., 0);
    282     // inputs.push_back(model_.addOperand(&prevCellStateOperandType));
    283 
    284     QuantizedLSTMOpModel lstm({
    285             // input
    286             OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {numBatches, inputSize}, 1. / 128., 128),
    287             // inputToInputWeights
    288             // inputToForgetWeights
    289             // inputToCellWeights
    290             // inputToOutputWeights
    291             OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, inputSize}, weightsScale,
    292                               weightsZeroPoint),
    293             OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, inputSize}, weightsScale,
    294                               weightsZeroPoint),
    295             OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, inputSize}, weightsScale,
    296                               weightsZeroPoint),
    297             OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, inputSize}, weightsScale,
    298                               weightsZeroPoint),
    299             // recurrentToInputWeights
    300             // recurrentToForgetWeights
    301             // recurrentToCellWeights
    302             // recurrentToOutputWeights
    303             OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, outputSize}, weightsScale,
    304                               weightsZeroPoint),
    305             OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, outputSize}, weightsScale,
    306                               weightsZeroPoint),
    307             OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, outputSize}, weightsScale,
    308                               weightsZeroPoint),
    309             OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, outputSize}, weightsScale,
    310                               weightsZeroPoint),
    311             // inputGateBias
    312             // forgetGateBias
    313             // cellGateBias
    314             // outputGateBias
    315             OperandTypeParams(Type::TENSOR_INT32, {outputSize}, weightsScale / 128., 0),
    316             OperandTypeParams(Type::TENSOR_INT32, {outputSize}, weightsScale / 128., 0),
    317             OperandTypeParams(Type::TENSOR_INT32, {outputSize}, weightsScale / 128., 0),
    318             OperandTypeParams(Type::TENSOR_INT32, {outputSize}, weightsScale / 128., 0),
    319             // prevCellState
    320             OperandTypeParams(Type::TENSOR_QUANT16_SYMM, {numBatches, outputSize}, 1. / 2048., 0),
    321             // prevOutput
    322             OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {numBatches, outputSize}, 1. / 128., 128),
    323     });
    324 
    325     lstm.setWeightsAndBiases(
    326             // inputToInputWeights
    327             {146, 250, 235, 171, 10, 218, 171, 108},
    328             // inputToForgetWeights
    329             {24, 50, 132, 179, 158, 110, 3, 169},
    330             // inputToCellWeights
    331             {133, 34, 29, 49, 206, 109, 54, 183},
    332             // inputToOutputWeights
    333             {195, 187, 11, 99, 109, 10, 218, 48},
    334             // recurrentToInputWeights
    335             {254, 206, 77, 168, 71, 20, 215, 6, 223, 7, 118, 225, 59, 130, 174, 26},
    336             // recurrentToForgetWeights
    337             {137, 240, 103, 52, 68, 51, 237, 112, 0, 220, 89, 23, 69, 4, 207, 253},
    338             // recurrentToCellWeights
    339             {172, 60, 205, 65, 14, 0, 140, 168, 240, 223, 133, 56, 142, 64, 246, 216},
    340             // recurrentToOutputWeights
    341             {106, 214, 67, 23, 59, 158, 45, 3, 119, 132, 49, 205, 129, 218, 11, 98},
    342             // inputGateBias
    343             {-7876, 13488, -726, 32839},
    344             // forgetGateBias
    345             {9206, -46884, -11693, -38724},
    346             // cellGateBias
    347             {39481, 48624, 48976, -21419},
    348             // outputGateBias
    349             {-58999, -17050, -41852, -40538});
    350 
    351     // LSTM input is stored as numBatches x (sequenceLength x inputSize) vector.
    352     std::vector<std::vector<uint8_t>> lstmInput;
    353     // clang-format off
    354     lstmInput = {{154, 166,
    355                   166, 179,
    356                   141, 141},
    357                  {100, 200,
    358                   50,  150,
    359                   111, 222}};
    360     // clang-format on
    361 
    362     // LSTM output is stored as numBatches x (sequenceLength x outputSize) vector.
    363     std::vector<std::vector<uint8_t>> lstmGoldenOutput;
    364     // clang-format off
    365     lstmGoldenOutput = {{136, 150, 140, 115,
    366                          140, 151, 146, 112,
    367                          139, 153, 146, 114},
    368                         {135, 152, 138, 112,
    369                          136, 156, 142, 112,
    370                          141, 154, 146, 108}};
    371     // clang-format on
    372     VerifyGoldens(lstmInput, lstmGoldenOutput, &lstm);
    373 };
    374 
    375 }  // namespace wrapper
    376 }  // namespace nn
    377 }  // namespace android
    378