Home | History | Annotate | Download | only in operations
      1 /*
      2  * Copyright (C) 2019 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef FRAMEWORKS_ML_NN_BIDIRECTIONAL_SEQUENCE_LSTM_H
     18 #define FRAMEWORKS_ML_NN_BIDIRECTIONAL_SEQUENCE_LSTM_H
     19 
     20 #include "ActivationFunctor.h"
     21 #include "HalOperation.h"
     22 #include "LSTM.h"
     23 #include "OperationsUtils.h"
     24 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
     25 
     26 #include <algorithm>
     27 #include <cmath>
     28 
     29 namespace android {
     30 namespace nn {
     31 
     32 struct RunTimeOperandInfo;
     33 
     34 class BidirectionalSequenceLSTM {
     35    public:
     36     BidirectionalSequenceLSTM(const Operation& operation,
     37                               std::vector<RunTimeOperandInfo>& operands);
     38 
     39     bool Prepare(const Operation& operation, std::vector<RunTimeOperandInfo>& operands,
     40                  Shape* fwOutputShape, Shape* bwOutputShape);
     41     bool Eval();
     42 
     43     // Input Tensors of size {max_time, n_batch, n_input}
     44     static constexpr int kInputTensor = 0;
     45 
     46     // Forward LSTM cell tensors.
     47     // Input weight tensors of size: {n_cell, n_input}
     48     static constexpr int kFwInputToInputWeightsTensor = 1;  // Optional
     49     static constexpr int kFwInputToForgetWeightsTensor = 2;
     50     static constexpr int kFwInputToCellWeightsTensor = 3;
     51     static constexpr int kFwInputToOutputWeightsTensor = 4;
     52 
     53     // Recurrent weight tensors of size {n_cell, n_output}
     54     static constexpr int kFwRecurrentToInputWeightsTensor = 5;  // Optional
     55     static constexpr int kFwRecurrentToForgetWeightsTensor = 6;
     56     static constexpr int kFwRecurrentToCellWeightsTensor = 7;
     57     static constexpr int kFwRecurrentToOutputWeightsTensor = 8;
     58 
     59     // Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
     60     static constexpr int kFwCellToInputWeightsTensor = 9;    // Optional
     61     static constexpr int kFwCellToForgetWeightsTensor = 10;  // Optional
     62     static constexpr int kFwCellToOutputWeightsTensor = 11;  // Optional
     63 
     64     // Gates bias tensors of size {n_cell}
     65     static constexpr int kFwInputGateBiasTensor = 12;  // Optional
     66     static constexpr int kFwForgetGateBiasTensor = 13;
     67     static constexpr int kFwCellGateBiasTensor = 14;
     68     static constexpr int kFwOutputGateBiasTensor = 15;
     69 
     70     // Projection weight tensor of size {n_output, n_cell}
     71     static constexpr int kFwProjectionWeightsTensor = 16;  // Optional
     72     // Projection bias tensor of size {n_output}
     73     static constexpr int kFwProjectionBiasTensor = 17;  // Optional
     74 
     75     // Backward LSTM cell tensors.
     76     // Input weight tensors of size: {n_cell, n_input}
     77     static constexpr int kBwInputToInputWeightsTensor = 18;  // Optional
     78     static constexpr int kBwInputToForgetWeightsTensor = 19;
     79     static constexpr int kBwInputToCellWeightsTensor = 20;
     80     static constexpr int kBwInputToOutputWeightsTensor = 21;
     81 
     82     // Recurrent weight tensors of size {n_cell, n_output}
     83     static constexpr int kBwRecurrentToInputWeightsTensor = 22;  // Optional
     84     static constexpr int kBwRecurrentToForgetWeightsTensor = 23;
     85     static constexpr int kBwRecurrentToCellWeightsTensor = 24;
     86     static constexpr int kBwRecurrentToOutputWeightsTensor = 25;
     87 
     88     // Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
     89     static constexpr int kBwCellToInputWeightsTensor = 26;   // Optional
     90     static constexpr int kBwCellToForgetWeightsTensor = 27;  // Optional
     91     static constexpr int kBwCellToOutputWeightsTensor = 28;  // Optional
     92 
     93     // Gates bias tensors of size {n_cell}
     94     static constexpr int kBwInputGateBiasTensor = 29;  // Optional
     95     static constexpr int kBwForgetGateBiasTensor = 30;
     96     static constexpr int kBwCellGateBiasTensor = 31;
     97     static constexpr int kBwOutputGateBiasTensor = 32;
     98 
     99     // Projection weight tensor of size {n_output, n_cell}
    100     static constexpr int kBwProjectionWeightsTensor = 33;  // Optional
    101     // Projection bias tensor of size {n_output}
    102     static constexpr int kBwProjectionBiasTensor = 34;  // Optional
    103 
    104     // Stateful input tensors that are variables and will be modified by the Op.
    105     // Activation state tensors of size {n_batch, n_output}
    106     static constexpr int kFwInputActivationStateTensor = 35;
    107     // Cell state tensors of size {n_batch, n_cell}
    108     static constexpr int kFwInputCellStateTensor = 36;
    109     // Activation state tensors of size {n_batch, n_output}
    110     static constexpr int kBwInputActivationStateTensor = 37;
    111     // Cell state tensors of size {n_batch, n_cell}
    112     static constexpr int kBwInputCellStateTensor = 38;
    113 
    114     // Used as auxiliary input and weights when stacking for
    115     // tf.contrib.rnn.stack_bidirectional_rnn case (with cross links); Used as input
    116     // to the backward cell when stacking for tf.nn.static_bidirectional_rnn case
    117     // (without cross links).
    118     static constexpr int kAuxInputTensor = 39;  // Optional
    119     // Forward weights.
    120     static constexpr int kFwAuxInputToInputWeightsTensor = 40;   // Optional
    121     static constexpr int kFwAuxInputToForgetWeightsTensor = 41;  // Optional
    122     static constexpr int kFwAuxInputToCellWeightsTensor = 42;    // Optional
    123     static constexpr int kFwAuxInputToOutputWeightsTensor = 43;  // Optional
    124     // Backward weights.
    125     static constexpr int kBwAuxInputToInputWeightsTensor = 44;   // Optional
    126     static constexpr int kBwAuxInputToForgetWeightsTensor = 45;  // Optional
    127     static constexpr int kBwAuxInputToCellWeightsTensor = 46;    // Optional
    128     static constexpr int kBwAuxInputToOutputWeightsTensor = 47;  // Optional
    129 
    130     static constexpr int kActivationParam = 48;
    131     static constexpr int kCellClipParam = 49;
    132     static constexpr int kProjClipParam = 50;
    133     static constexpr int kMergeOutputsParam = 51;
    134     static constexpr int kTimeMajorParam = 52;
    135 
    136     // Forward layer norm weights tensors of size {n_cell}, representing a diagonal matrix.
    137     static constexpr int kFwInputLayerNormWeightsTensor = 53;   // Optional
    138     static constexpr int kFwForgetLayerNormWeightsTensor = 54;  // Optional
    139     static constexpr int kFwCellLayerNormWeightsTensor = 55;    // Optional
    140     static constexpr int kFwOutputLayerNormWeightsTensor = 56;  // Optional
    141     // Backward layer norm weights tensors of size {n_cell}, representing a diagonal matrix.
    142     static constexpr int kBwInputLayerNormWeightsTensor = 57;   // Optional
    143     static constexpr int kBwForgetLayerNormWeightsTensor = 58;  // Optional
    144     static constexpr int kBwCellLayerNormWeightsTensor = 59;    // Optional
    145     static constexpr int kBwOutputLayerNormWeightsTensor = 60;  // Optional
    146 
    147     // Output tensors.
    148     static constexpr int kFwOutputTensor = 0;
    149     static constexpr int kBwOutputTensor = 1;  // Ignored if merge_outputs is set.
    150 
    151    private:
    152     LSTMParams params_;
    153     Shape fw_scratch_shape_;
    154     Shape bw_scratch_shape_;
    155 
    156     const RunTimeOperandInfo* input_;
    157 
    158     const RunTimeOperandInfo* aux_input_;
    159     const RunTimeOperandInfo* fw_aux_input_to_input_weights_;
    160     const RunTimeOperandInfo* fw_aux_input_to_forget_weights_;
    161     const RunTimeOperandInfo* fw_aux_input_to_cell_weights_;
    162     const RunTimeOperandInfo* fw_aux_input_to_output_weights_;
    163     const RunTimeOperandInfo* bw_aux_input_to_input_weights_;
    164     const RunTimeOperandInfo* bw_aux_input_to_forget_weights_;
    165     const RunTimeOperandInfo* bw_aux_input_to_cell_weights_;
    166     const RunTimeOperandInfo* bw_aux_input_to_output_weights_;
    167 
    168     const RunTimeOperandInfo* fw_input_to_input_weights_;
    169     const RunTimeOperandInfo* fw_input_to_forget_weights_;
    170     const RunTimeOperandInfo* fw_input_to_cell_weights_;
    171     const RunTimeOperandInfo* fw_input_to_output_weights_;
    172 
    173     const RunTimeOperandInfo* fw_recurrent_to_input_weights_;
    174     const RunTimeOperandInfo* fw_recurrent_to_forget_weights_;
    175     const RunTimeOperandInfo* fw_recurrent_to_cell_weights_;
    176     const RunTimeOperandInfo* fw_recurrent_to_output_weights_;
    177 
    178     const RunTimeOperandInfo* fw_cell_to_input_weights_;
    179     const RunTimeOperandInfo* fw_cell_to_forget_weights_;
    180     const RunTimeOperandInfo* fw_cell_to_output_weights_;
    181 
    182     const RunTimeOperandInfo* fw_input_gate_bias_;
    183     const RunTimeOperandInfo* fw_forget_gate_bias_;
    184     const RunTimeOperandInfo* fw_cell_bias_;
    185     const RunTimeOperandInfo* fw_output_gate_bias_;
    186 
    187     const RunTimeOperandInfo* fw_projection_weights_;
    188     const RunTimeOperandInfo* fw_projection_bias_;
    189 
    190     const RunTimeOperandInfo* fw_input_layer_norm_weights_;
    191     const RunTimeOperandInfo* fw_forget_layer_norm_weights_;
    192     const RunTimeOperandInfo* fw_cell_layer_norm_weights_;
    193     const RunTimeOperandInfo* fw_output_layer_norm_weights_;
    194 
    195     RunTimeOperandInfo* fw_activation_state_;
    196     RunTimeOperandInfo* fw_cell_state_;
    197     RunTimeOperandInfo* fw_output_;
    198 
    199     const RunTimeOperandInfo* bw_input_to_input_weights_;
    200     const RunTimeOperandInfo* bw_input_to_forget_weights_;
    201     const RunTimeOperandInfo* bw_input_to_cell_weights_;
    202     const RunTimeOperandInfo* bw_input_to_output_weights_;
    203 
    204     const RunTimeOperandInfo* bw_recurrent_to_input_weights_;
    205     const RunTimeOperandInfo* bw_recurrent_to_forget_weights_;
    206     const RunTimeOperandInfo* bw_recurrent_to_cell_weights_;
    207     const RunTimeOperandInfo* bw_recurrent_to_output_weights_;
    208 
    209     const RunTimeOperandInfo* bw_cell_to_input_weights_;
    210     const RunTimeOperandInfo* bw_cell_to_forget_weights_;
    211     const RunTimeOperandInfo* bw_cell_to_output_weights_;
    212 
    213     const RunTimeOperandInfo* bw_input_gate_bias_;
    214     const RunTimeOperandInfo* bw_forget_gate_bias_;
    215     const RunTimeOperandInfo* bw_cell_bias_;
    216     const RunTimeOperandInfo* bw_output_gate_bias_;
    217 
    218     const RunTimeOperandInfo* bw_projection_weights_;
    219     const RunTimeOperandInfo* bw_projection_bias_;
    220 
    221     const RunTimeOperandInfo* bw_input_layer_norm_weights_;
    222     const RunTimeOperandInfo* bw_forget_layer_norm_weights_;
    223     const RunTimeOperandInfo* bw_cell_layer_norm_weights_;
    224     const RunTimeOperandInfo* bw_output_layer_norm_weights_;
    225 
    226     RunTimeOperandInfo* bw_activation_state_;
    227     RunTimeOperandInfo* bw_cell_state_;
    228     RunTimeOperandInfo* bw_output_;
    229 };
    230 
    231 }  // namespace nn
    232 }  // namespace android
    233 
    234 #endif  // FRAMEWORKS_ML_NN_BIDIRECTIONAL_SEQUENCE_LSTM_H
    235