Home | History | Annotate | Download | only in operations
      1 #ifndef FRAMEWORKS_ML_NN_QUANTIZEDLSTM_H
      2 #define FRAMEWORKS_ML_NN_QUANTIZEDLSTM_H
      3 
      4 #include "HalOperation.h"
      5 #include "OperationsUtils.h"
      6 
      7 #include <vector>
      8 
      9 namespace android {
     10 namespace nn {
     11 
     12 struct RunTimeOperandInfo;
     13 
     14 class QuantizedLSTMCell {
     15    public:
     16     QuantizedLSTMCell(const android::hardware::neuralnetworks::V1_2::Operation& operation,
     17                       std::vector<RunTimeOperandInfo>& operands);
     18 
     19     static bool prepare(const android::hardware::neuralnetworks::V1_2::Operation& operation,
     20                         std::vector<RunTimeOperandInfo>& operands, Shape* cellStateShape,
     21                         Shape* outputShape);
     22     bool eval();
     23 
     24     // Inputs:
     25     static constexpr int kInputTensor = 0;
     26     // Input weight tensors of size: {n_cell, n_input}
     27     static constexpr int kInputToInputWeightsTensor = 1;
     28     static constexpr int kInputToForgetWeightsTensor = 2;
     29     static constexpr int kInputToCellWeightsTensor = 3;
     30     static constexpr int kInputToOutputWeightsTensor = 4;
     31 
     32     // Recurrent weight tensors of size {n_cell, n_output}
     33     static constexpr int kRecurrentToInputWeightsTensor = 5;
     34     static constexpr int kRecurrentToForgetWeightsTensor = 6;
     35     static constexpr int kRecurrentToCellWeightsTensor = 7;
     36     static constexpr int kRecurrentToOutputWeightsTensor = 8;
     37 
     38     // Gates bias tensors of size {n_cell}
     39     static constexpr int kInputGateBiasTensor = 9;
     40     static constexpr int kForgetGateBiasTensor = 10;
     41     static constexpr int kCellGateBiasTensor = 11;
     42     static constexpr int kOutputGateBiasTensor = 12;
     43 
     44     static constexpr int kPrevCellStateTensor = 13;
     45     static constexpr int kPrevOutputTensor = 14;
     46 
     47     // Outputs:
     48     static constexpr int kCellStateOutTensor = 0;
     49     static constexpr int kOutputTensor = 1;
     50 
     51    private:
     52     const RunTimeOperandInfo* input_;
     53 
     54     const RunTimeOperandInfo* inputToInputWeights_;
     55     const RunTimeOperandInfo* inputToForgetWeights_;
     56     const RunTimeOperandInfo* inputToCellWeights_;
     57     const RunTimeOperandInfo* inputToOutputWeights_;
     58 
     59     const RunTimeOperandInfo* recurrentToInputWeights_;
     60     const RunTimeOperandInfo* recurrentToForgetWeights_;
     61     const RunTimeOperandInfo* recurrentToCellWeights_;
     62     const RunTimeOperandInfo* recurrentToOutputWeights_;
     63 
     64     const RunTimeOperandInfo* inputGateBias_;
     65     const RunTimeOperandInfo* forgetGateBias_;
     66     const RunTimeOperandInfo* cellGateBias_;
     67     const RunTimeOperandInfo* outputGateBias_;
     68 
     69     const RunTimeOperandInfo* prevCellState_;
     70     const RunTimeOperandInfo* prevOutput_;
     71 
     72     RunTimeOperandInfo* cellStateOut_;
     73     RunTimeOperandInfo* output_;
     74 
     75     void concatenateWeights(const std::vector<uint32_t>& weightsDims, uint8_t* weights);
     76     void concatenateBiases(uint32_t outputSize, int32_t* bias);
     77 };
     78 
     79 }  // namespace nn
     80 }  // namespace android
     81 
     82 #endif  // FRAMEWORKS_ML_NN_QUANTIZEDLSTM_H
     83