/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "QuantizedLSTM.h"

#include "NeuralNetworksWrapper.h"
#include "gmock/gmock-matchers.h"
#include "gtest/gtest.h"

#include <iostream>

namespace android {
namespace nn {
namespace wrapper {

namespace {

struct OperandTypeParams {
    Type type;
    std::vector<uint32_t> shape;
    float scale;
    int32_t zeroPoint;

    OperandTypeParams(Type type, std::vector<uint32_t> shape, float scale, int32_t zeroPoint)
        : type(type), shape(shape), scale(scale), zeroPoint(zeroPoint) {}
};

}  // namespace

using ::testing::Each;
using ::testing::ElementsAreArray;
using ::testing::FloatNear;
using ::testing::Matcher;

class QuantizedLSTMOpModel {
   public:
    QuantizedLSTMOpModel(const std::vector<OperandTypeParams>& inputOperandTypeParams) {
        std::vector<uint32_t> inputs;

        for (int i = 0; i < NUM_INPUTS; ++i) {
            const auto& curOTP = inputOperandTypeParams[i];
            OperandType curType(curOTP.type, curOTP.shape, curOTP.scale, curOTP.zeroPoint);
            inputs.push_back(model_.addOperand(&curType));
        }

        // The input tensor is shaped {numBatches, inputSize}.
        const uint32_t numBatches = inputOperandTypeParams[0].shape[0];
        inputSize_ = inputOperandTypeParams[0].shape[1];
        const uint32_t outputSize =
                inputOperandTypeParams[QuantizedLSTMCell::kPrevCellStateTensor].shape[1];
        outputSize_ = outputSize;

        std::vector<uint32_t> outputs;
        OperandType cellStateOutOperandType(Type::TENSOR_QUANT16_SYMM, {numBatches, outputSize},
                                            1. / 2048., 0);
        outputs.push_back(model_.addOperand(&cellStateOutOperandType));
        OperandType outputOperandType(Type::TENSOR_QUANT8_ASYMM, {numBatches, outputSize},
                                      1. / 128., 128);
        outputs.push_back(model_.addOperand(&outputOperandType));

        model_.addOperation(ANEURALNETWORKS_QUANTIZED_16BIT_LSTM, inputs, outputs);
        model_.identifyInputsAndOutputs(inputs, outputs);

        initializeInputData(inputOperandTypeParams[QuantizedLSTMCell::kInputTensor], &input_);
        initializeInputData(inputOperandTypeParams[QuantizedLSTMCell::kPrevOutputTensor],
                            &prevOutput_);
        initializeInputData(inputOperandTypeParams[QuantizedLSTMCell::kPrevCellStateTensor],
                            &prevCellState_);

        cellStateOut_.resize(numBatches * outputSize, 0);
        output_.resize(numBatches * outputSize, 0);

        model_.finish();
    }
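
    // For reference, the output operand types created above pin down the
    // quantization of the two outputs: the cell state is TENSOR_QUANT16_SYMM
    // with scale 1/2048, so a stored int16 value q represents q / 2048
    // (a range of roughly [-16, 16)), and the output is TENSOR_QUANT8_ASYMM
    // with scale 1/128 and zero point 128, so a stored uint8 value q
    // represents (q - 128) / 128, i.e. [-1, 1), matching the tanh-bounded
    // LSTM output.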
    void invoke() {
        ASSERT_TRUE(model_.isValid());

        Compilation compilation(&model_);
        ASSERT_EQ(compilation.finish(), Result::NO_ERROR);
        Execution execution(&compilation);

        // Set all the inputs.
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputTensor, input_),
                  Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputToInputWeightsTensor,
                                 inputToInputWeights_),
                  Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputToForgetWeightsTensor,
                                 inputToForgetWeights_),
                  Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputToCellWeightsTensor,
                                 inputToCellWeights_),
                  Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputToOutputWeightsTensor,
                                 inputToOutputWeights_),
                  Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kRecurrentToInputWeightsTensor,
                                 recurrentToInputWeights_),
                  Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kRecurrentToForgetWeightsTensor,
                                 recurrentToForgetWeights_),
                  Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kRecurrentToCellWeightsTensor,
                                 recurrentToCellWeights_),
                  Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kRecurrentToOutputWeightsTensor,
                                 recurrentToOutputWeights_),
                  Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputGateBiasTensor,
                                 inputGateBias_),
                  Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kForgetGateBiasTensor,
                                 forgetGateBias_),
                  Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kCellGateBiasTensor,
                                 cellGateBias_),
                  Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kOutputGateBiasTensor,
                                 outputGateBias_),
                  Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kPrevCellStateTensor,
                                 prevCellState_),
                  Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kPrevOutputTensor, prevOutput_),
                  Result::NO_ERROR);
        // Set all the outputs.
        ASSERT_EQ(setOutputTensor(&execution, QuantizedLSTMCell::kCellStateOutTensor,
                                  &cellStateOut_),
                  Result::NO_ERROR);
        ASSERT_EQ(setOutputTensor(&execution, QuantizedLSTMCell::kOutputTensor, &output_),
                  Result::NO_ERROR);

        ASSERT_EQ(execution.compute(), Result::NO_ERROR);

        // Put the state outputs back into the state inputs for the next step.
        prevOutput_ = output_;
        prevCellState_ = cellStateOut_;
    }
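
    // Because invoke() feeds the state outputs back into the state inputs,
    // stepping the cell through a sequence is just a setInput()/invoke() pair
    // per time step; VerifyGoldens below relies on this.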
    int inputSize() { return inputSize_; }

    int outputSize() { return outputSize_; }

    void setInput(const std::vector<uint8_t>& input) { input_ = input; }

    void setWeightsAndBiases(std::vector<uint8_t> inputToInputWeights,
                             std::vector<uint8_t> inputToForgetWeights,
                             std::vector<uint8_t> inputToCellWeights,
                             std::vector<uint8_t> inputToOutputWeights,
                             std::vector<uint8_t> recurrentToInputWeights,
                             std::vector<uint8_t> recurrentToForgetWeights,
                             std::vector<uint8_t> recurrentToCellWeights,
                             std::vector<uint8_t> recurrentToOutputWeights,
                             std::vector<int32_t> inputGateBias,
                             std::vector<int32_t> forgetGateBias,
                             std::vector<int32_t> cellGateBias,
                             std::vector<int32_t> outputGateBias) {
        inputToInputWeights_ = inputToInputWeights;
        inputToForgetWeights_ = inputToForgetWeights;
        inputToCellWeights_ = inputToCellWeights;
        inputToOutputWeights_ = inputToOutputWeights;
        recurrentToInputWeights_ = recurrentToInputWeights;
        recurrentToForgetWeights_ = recurrentToForgetWeights;
        recurrentToCellWeights_ = recurrentToCellWeights;
        recurrentToOutputWeights_ = recurrentToOutputWeights;
        inputGateBias_ = inputGateBias;
        forgetGateBias_ = forgetGateBias;
        cellGateBias_ = cellGateBias;
        outputGateBias_ = outputGateBias;
    }

    // Fills the vector with the operand's zero point, i.e. the quantized
    // representation of the real value 0.
    template <typename T>
    void initializeInputData(const OperandTypeParams& params, std::vector<T>* vec) {
        int size = 1;
        for (uint32_t d : params.shape) {
            size *= d;
        }
        vec->clear();
        vec->resize(size, params.zeroPoint);
    }

    std::vector<uint8_t> getOutput() { return output_; }

   private:
    static constexpr int NUM_INPUTS = 15;
    static constexpr int NUM_OUTPUTS = 2;

    Model model_;
    // Inputs
    std::vector<uint8_t> input_;
    std::vector<uint8_t> inputToInputWeights_;
    std::vector<uint8_t> inputToForgetWeights_;
    std::vector<uint8_t> inputToCellWeights_;
    std::vector<uint8_t> inputToOutputWeights_;
    std::vector<uint8_t> recurrentToInputWeights_;
    std::vector<uint8_t> recurrentToForgetWeights_;
    std::vector<uint8_t> recurrentToCellWeights_;
    std::vector<uint8_t> recurrentToOutputWeights_;
    std::vector<int32_t> inputGateBias_;
    std::vector<int32_t> forgetGateBias_;
    std::vector<int32_t> cellGateBias_;
    std::vector<int32_t> outputGateBias_;
    std::vector<int16_t> prevCellState_;
    std::vector<uint8_t> prevOutput_;
    // Outputs
    std::vector<int16_t> cellStateOut_;
    std::vector<uint8_t> output_;

    int inputSize_;
    int outputSize_;

    template <typename T>
    Result setInputTensor(Execution* execution, int tensor, const std::vector<T>& data) {
        return execution->setInput(tensor, data.data(), sizeof(T) * data.size());
    }
    template <typename T>
    Result setOutputTensor(Execution* execution, int tensor, std::vector<T>* data) {
        return execution->setOutput(tensor, data->data(), sizeof(T) * data->size());
    }
};
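
// A reminder on the asymmetric quantization used for the 8-bit tensors in
// this test: a stored byte q with scale s and zero point z represents the
// real value s * (q - z). For example, an input byte of 154 with scale 1/128
// and zero point 128 decodes to (154 - 128) / 128 = 0.203125.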
class QuantizedLstmTest : public ::testing::Test {
   protected:
    void VerifyGoldens(const std::vector<std::vector<uint8_t>>& input,
                       const std::vector<std::vector<uint8_t>>& output,
                       QuantizedLSTMOpModel* lstm) {
        const int numBatches = input.size();
        EXPECT_GT(numBatches, 0);
        const int inputSize = lstm->inputSize();
        EXPECT_GT(inputSize, 0);
        const int inputSequenceSize = input[0].size() / inputSize;
        EXPECT_GT(inputSequenceSize, 0);
        for (int i = 0; i < inputSequenceSize; ++i) {
            std::vector<uint8_t> inputStep;
            for (int b = 0; b < numBatches; ++b) {
                const uint8_t* batchStart = input[b].data() + i * inputSize;
                const uint8_t* batchEnd = batchStart + inputSize;
                inputStep.insert(inputStep.end(), batchStart, batchEnd);
            }
            lstm->setInput(inputStep);
            lstm->invoke();

            const int outputSize = lstm->outputSize();
            std::vector<float> expected;
            for (int b = 0; b < numBatches; ++b) {
                const uint8_t* goldenBatchStart = output[b].data() + i * outputSize;
                const uint8_t* goldenBatchEnd = goldenBatchStart + outputSize;
                expected.insert(expected.end(), goldenBatchStart, goldenBatchEnd);
            }
            EXPECT_THAT(lstm->getOutput(), ElementsAreArray(expected));
        }
    }
};
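
// VerifyGoldens walks the data time-major: step i concatenates the slice
// input[b][i * inputSize, (i + 1) * inputSize) across batches b before each
// invoke(). With the inputs below (inputSize = 2, sequence length 3), step 0
// therefore feeds {154, 166, 100, 200} to the cell.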
// Inputs and weights in this test are random; the test only checks that the
// outputs are equal to the outputs obtained from running the TF Lite version
// of quantized LSTM on the same inputs.
TEST_F(QuantizedLstmTest, BasicQuantizedLstmTest) {
    const int numBatches = 2;
    const int inputSize = 2;
    const int outputSize = 4;

    float weightsScale = 0.00408021;
    int weightsZeroPoint = 100;

    QuantizedLSTMOpModel lstm({
            // input
            OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {numBatches, inputSize}, 1. / 128., 128),
            // inputToInputWeights
            // inputToForgetWeights
            // inputToCellWeights
            // inputToOutputWeights
            OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, inputSize}, weightsScale,
                              weightsZeroPoint),
            OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, inputSize}, weightsScale,
                              weightsZeroPoint),
            OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, inputSize}, weightsScale,
                              weightsZeroPoint),
            OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, inputSize}, weightsScale,
                              weightsZeroPoint),
            // recurrentToInputWeights
            // recurrentToForgetWeights
            // recurrentToCellWeights
            // recurrentToOutputWeights
            OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, outputSize}, weightsScale,
                              weightsZeroPoint),
            OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, outputSize}, weightsScale,
                              weightsZeroPoint),
            OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, outputSize}, weightsScale,
                              weightsZeroPoint),
            OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, outputSize}, weightsScale,
                              weightsZeroPoint),
            // inputGateBias
            // forgetGateBias
            // cellGateBias
            // outputGateBias
            OperandTypeParams(Type::TENSOR_INT32, {outputSize}, weightsScale / 128., 0),
            OperandTypeParams(Type::TENSOR_INT32, {outputSize}, weightsScale / 128., 0),
            OperandTypeParams(Type::TENSOR_INT32, {outputSize}, weightsScale / 128., 0),
            OperandTypeParams(Type::TENSOR_INT32, {outputSize}, weightsScale / 128., 0),
            // prevCellState
            OperandTypeParams(Type::TENSOR_QUANT16_SYMM, {numBatches, outputSize}, 1. / 2048., 0),
            // prevOutput
            OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {numBatches, outputSize}, 1. / 128., 128),
    });

    lstm.setWeightsAndBiases(
            // inputToInputWeights
            {146, 250, 235, 171, 10, 218, 171, 108},
            // inputToForgetWeights
            {24, 50, 132, 179, 158, 110, 3, 169},
            // inputToCellWeights
            {133, 34, 29, 49, 206, 109, 54, 183},
            // inputToOutputWeights
            {195, 187, 11, 99, 109, 10, 218, 48},
            // recurrentToInputWeights
            {254, 206, 77, 168, 71, 20, 215, 6, 223, 7, 118, 225, 59, 130, 174, 26},
            // recurrentToForgetWeights
            {137, 240, 103, 52, 68, 51, 237, 112, 0, 220, 89, 23, 69, 4, 207, 253},
            // recurrentToCellWeights
            {172, 60, 205, 65, 14, 0, 140, 168, 240, 223, 133, 56, 142, 64, 246, 216},
            // recurrentToOutputWeights
            {106, 214, 67, 23, 59, 158, 45, 3, 119, 132, 49, 205, 129, 218, 11, 98},
            // inputGateBias
            {-7876, 13488, -726, 32839},
            // forgetGateBias
            {9206, -46884, -11693, -38724},
            // cellGateBias
            {39481, 48624, 48976, -21419},
            // outputGateBias
            {-58999, -17050, -41852, -40538});

    // LSTM input is stored as a numBatches x (sequenceLength x inputSize) vector.
    std::vector<std::vector<uint8_t>> lstmInput;
    // clang-format off
    lstmInput = {{154, 166,
                  166, 179,
                  141, 141},
                 {100, 200,
                  50, 150,
                  111, 222}};
    // clang-format on

    // LSTM output is stored as a numBatches x (sequenceLength x outputSize) vector.
    std::vector<std::vector<uint8_t>> lstmGoldenOutput;
    // clang-format off
    lstmGoldenOutput = {{136, 150, 140, 115,
                         140, 151, 146, 112,
                         139, 153, 146, 114},
                        {135, 152, 138, 112,
                         136, 156, 142, 112,
                         141, 154, 146, 108}};
    // clang-format on
    VerifyGoldens(lstmInput, lstmGoldenOutput, &lstm);
}

}  // namespace wrapper
}  // namespace nn
}  // namespace android