1 #ifndef FRAMEWORKS_ML_NN_QUANTIZEDLSTM_H 2 #define FRAMEWORKS_ML_NN_QUANTIZEDLSTM_H 3 4 #include "HalOperation.h" 5 #include "OperationsUtils.h" 6 7 #include <vector> 8 9 namespace android { 10 namespace nn { 11 12 struct RunTimeOperandInfo; 13 14 class QuantizedLSTMCell { 15 public: 16 QuantizedLSTMCell(const android::hardware::neuralnetworks::V1_2::Operation& operation, 17 std::vector<RunTimeOperandInfo>& operands); 18 19 static bool prepare(const android::hardware::neuralnetworks::V1_2::Operation& operation, 20 std::vector<RunTimeOperandInfo>& operands, Shape* cellStateShape, 21 Shape* outputShape); 22 bool eval(); 23 24 // Inputs: 25 static constexpr int kInputTensor = 0; 26 // Input weight tensors of size: {n_cell, n_input} 27 static constexpr int kInputToInputWeightsTensor = 1; 28 static constexpr int kInputToForgetWeightsTensor = 2; 29 static constexpr int kInputToCellWeightsTensor = 3; 30 static constexpr int kInputToOutputWeightsTensor = 4; 31 32 // Recurrent weight tensors of size {n_cell, n_output} 33 static constexpr int kRecurrentToInputWeightsTensor = 5; 34 static constexpr int kRecurrentToForgetWeightsTensor = 6; 35 static constexpr int kRecurrentToCellWeightsTensor = 7; 36 static constexpr int kRecurrentToOutputWeightsTensor = 8; 37 38 // Gates bias tensors of size {n_cell} 39 static constexpr int kInputGateBiasTensor = 9; 40 static constexpr int kForgetGateBiasTensor = 10; 41 static constexpr int kCellGateBiasTensor = 11; 42 static constexpr int kOutputGateBiasTensor = 12; 43 44 static constexpr int kPrevCellStateTensor = 13; 45 static constexpr int kPrevOutputTensor = 14; 46 47 // Outputs: 48 static constexpr int kCellStateOutTensor = 0; 49 static constexpr int kOutputTensor = 1; 50 51 private: 52 const RunTimeOperandInfo* input_; 53 54 const RunTimeOperandInfo* inputToInputWeights_; 55 const RunTimeOperandInfo* inputToForgetWeights_; 56 const RunTimeOperandInfo* inputToCellWeights_; 57 const RunTimeOperandInfo* inputToOutputWeights_; 58 59 const RunTimeOperandInfo* recurrentToInputWeights_; 60 const RunTimeOperandInfo* recurrentToForgetWeights_; 61 const RunTimeOperandInfo* recurrentToCellWeights_; 62 const RunTimeOperandInfo* recurrentToOutputWeights_; 63 64 const RunTimeOperandInfo* inputGateBias_; 65 const RunTimeOperandInfo* forgetGateBias_; 66 const RunTimeOperandInfo* cellGateBias_; 67 const RunTimeOperandInfo* outputGateBias_; 68 69 const RunTimeOperandInfo* prevCellState_; 70 const RunTimeOperandInfo* prevOutput_; 71 72 RunTimeOperandInfo* cellStateOut_; 73 RunTimeOperandInfo* output_; 74 75 void concatenateWeights(const std::vector<uint32_t>& weightsDims, uint8_t* weights); 76 void concatenateBiases(uint32_t outputSize, int32_t* bias); 77 }; 78 79 } // namespace nn 80 } // namespace android 81 82 #endif // FRAMEWORKS_ML_NN_QUANTIZEDLSTM_H 83