Home | History | Annotate | Download | only in operations
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef FRAMEWORKS_ML_NN_LSTMCELL_H
     18 #define FRAMEWORKS_ML_NN_LSTMCELL_H
     19 
#include "ActivationFunctor.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>
     25 
     26 namespace android {
     27 namespace hardware {
     28 namespace neuralnetworks {
     29 namespace V1_1 {
     30 struct Operation;
     31 }
     32 }  // namespace neuralnetworks
     33 }  // namespace hardware
     34 }  // namespace android
     35 
     36 namespace android {
     37 namespace nn {
     38 
     39 struct LSTMParams {
     40   TfLiteFusedActivation activation_;
     41   float cell_clip_;
     42   float proj_clip_;
     43 };
     44 
     45 struct RunTimeOperandInfo;
     46 struct Shape;
     47 
     48 class LSTMCell {
     49  public:
     50   LSTMCell(const android::hardware::neuralnetworks::V1_1::Operation &operation,
     51            std::vector<RunTimeOperandInfo> &operands);
     52 
     53   static bool Prepare(const android::hardware::neuralnetworks::V1_1::Operation &operation,
     54                       std::vector<RunTimeOperandInfo> &operands,
     55                       Shape *scratchShape,
     56                       Shape *outputStateShape,
     57                       Shape *cellStateShape,
     58                       Shape *outputShape);
     59   bool Eval();
     60 
     61   // Input Tensors of size {n_batch, n_input}
     62   static constexpr int kInputTensor = 0;
     63 
     64   // Input weight tensors of size: {n_cell, n_input}
     65   static constexpr int kInputToInputWeightsTensor = 1;  // Optional
     66   static constexpr int kInputToForgetWeightsTensor = 2;
     67   static constexpr int kInputToCellWeightsTensor = 3;
     68   static constexpr int kInputToOutputWeightsTensor = 4;
     69 
     70   // Recurrent weight tensors of size {n_cell, n_output}
     71   static constexpr int kRecurrentToInputWeightsTensor = 5;  // Optional
     72   static constexpr int kRecurrentToForgetWeightsTensor = 6;
     73   static constexpr int kRecurrentToCellWeightsTensor = 7;
     74   static constexpr int kRecurrentToOutputWeightsTensor = 8;
     75 
     76   // Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
     77   static constexpr int kCellToInputWeightsTensor = 9;    // Optional
     78   static constexpr int kCellToForgetWeightsTensor = 10;  // Optional
     79   static constexpr int kCellToOutputWeightsTensor = 11;  // Optional
     80 
     81   // Gates bias tensors of size {n_cell}
     82   static constexpr int kInputGateBiasTensor = 12;  // Optional
     83   static constexpr int kForgetGateBiasTensor = 13;
     84   static constexpr int kCellGateBiasTensor = 14;
     85   static constexpr int kOutputGateBiasTensor = 15;
     86 
     87   // Projection weight tensor of size {n_output, n_cell}
     88   static constexpr int kProjectionWeightsTensor = 16;  // Optional
     89   // Projection bias tensor of size {n_output}
     90   static constexpr int kProjectionBiasTensor = 17;  // Optional
     91 
     92   static constexpr int kOutputStateInTensor = 18;
     93   static constexpr int kCellStateInTensor = 19;
     94 
     95   static constexpr int kActivationParam = 20;
     96   static constexpr int kCellClipParam = 21;
     97   static constexpr int kProjClipParam = 22;
     98 
     99   // Output tensors.
    100   static constexpr int kScratchBufferTensor = 0;
    101   static constexpr int kOutputStateOutTensor = 1;
    102   static constexpr int kCellStateOutTensor = 2;
    103   static constexpr int kOutputTensor = 3;
    104 
    105  private:
    106   static bool CheckInputTensorDimensions(
    107       const android::hardware::neuralnetworks::V1_1::Operation &operation,
    108       std::vector<RunTimeOperandInfo> &operands, uint32_t n_input,
    109       uint32_t n_output, uint32_t n_cell);
    110   LSTMParams params_;
    111 
    112   const RunTimeOperandInfo *input_;
    113 
    114   const RunTimeOperandInfo *input_to_input_weights_;
    115   const RunTimeOperandInfo *input_to_forget_weights_;
    116   const RunTimeOperandInfo *input_to_cell_weights_;
    117   const RunTimeOperandInfo *input_to_output_weights_;
    118 
    119   const RunTimeOperandInfo *recurrent_to_input_weights_;
    120   const RunTimeOperandInfo *recurrent_to_forget_weights_;
    121   const RunTimeOperandInfo *recurrent_to_cell_weights_;
    122   const RunTimeOperandInfo *recurrent_to_output_weights_;
    123 
    124   const RunTimeOperandInfo *cell_to_input_weights_;
    125   const RunTimeOperandInfo *cell_to_forget_weights_;
    126   const RunTimeOperandInfo *cell_to_output_weights_;
    127 
    128   const RunTimeOperandInfo *input_gate_bias_;
    129   const RunTimeOperandInfo *forget_gate_bias_;
    130   const RunTimeOperandInfo *cell_bias_;
    131   const RunTimeOperandInfo *output_gate_bias_;
    132 
    133   const RunTimeOperandInfo *projection_weights_;
    134   const RunTimeOperandInfo *projection_bias_;
    135 
    136   const RunTimeOperandInfo *output_state_in_;
    137   const RunTimeOperandInfo *cell_state_in_;
    138 
    139   RunTimeOperandInfo *output_state_out_;
    140   RunTimeOperandInfo *cell_state_out_;
    141   RunTimeOperandInfo *output_;
    142 
    143   RunTimeOperandInfo *scratch_buffer_;
    144 };
    145 
    146 }  // namespace nn
    147 }  // namespace android
    148 
    149 #endif  // FRAMEWORKS_ML_NN_LSTMCELL_H
    150