Home | History | Annotate | Download | only in operations
      1 /*
      2  * Copyright (C) 2019 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "BidirectionalSequenceLSTM.h"
     18 
     19 #include "CpuExecutor.h"
     20 #include "CpuOperationUtils.h"
     21 #include "HalInterfaces.h"
     22 #include "OperationsUtils.h"
     23 
     24 #include "Tracing.h"
     25 
     26 namespace android {
     27 namespace nn {
     28 
     29 namespace {
     30 
     31 template <typename T>
     32 inline T* GetBuffer(RunTimeOperandInfo* operand) {
     33     return reinterpret_cast<T*>(operand->buffer);
     34 }
     35 
     36 template <typename T>
     37 inline const T* GetBuffer(const RunTimeOperandInfo* operand) {
     38     return reinterpret_cast<const T*>(operand->buffer);
     39 }
     40 
     41 template <typename T>
     42 inline const T* GetOptionalBuffer(const RunTimeOperandInfo* operand) {
     43     return !IsNullInput(operand) ? reinterpret_cast<const T*>(operand->buffer) : nullptr;
     44 }
     45 
     46 }  // anonymous namespace
     47 
     48 BidirectionalSequenceLSTM::BidirectionalSequenceLSTM(const Operation& operation,
     49                                                      std::vector<RunTimeOperandInfo>& operands) {
     50     input_ = GetInput(operation, operands, kInputTensor);
     51 
     52     fw_input_to_input_weights_ =
     53             GetInput(operation, operands, kFwInputToInputWeightsTensor);  // optional
     54     fw_input_to_forget_weights_ = GetInput(operation, operands, kFwInputToForgetWeightsTensor);
     55     fw_input_to_cell_weights_ = GetInput(operation, operands, kFwInputToCellWeightsTensor);
     56     fw_input_to_output_weights_ = GetInput(operation, operands, kFwInputToOutputWeightsTensor);
     57 
     58     fw_recurrent_to_input_weights_ =
     59             GetInput(operation, operands, kFwRecurrentToInputWeightsTensor);  // optional
     60     fw_recurrent_to_forget_weights_ =
     61             GetInput(operation, operands, kFwRecurrentToForgetWeightsTensor);
     62     fw_recurrent_to_cell_weights_ = GetInput(operation, operands, kFwRecurrentToCellWeightsTensor);
     63     fw_recurrent_to_output_weights_ =
     64             GetInput(operation, operands, kFwRecurrentToOutputWeightsTensor);
     65 
     66     fw_cell_to_input_weights_ =
     67             GetInput(operation, operands, kFwCellToInputWeightsTensor);  // optional
     68     fw_cell_to_forget_weights_ =
     69             GetInput(operation, operands, kFwCellToForgetWeightsTensor);  // optional
     70     fw_cell_to_output_weights_ =
     71             GetInput(operation, operands, kFwCellToOutputWeightsTensor);  // optional
     72 
     73     fw_input_gate_bias_ = GetInput(operation, operands, kFwInputGateBiasTensor);
     74     fw_forget_gate_bias_ = GetInput(operation, operands, kFwForgetGateBiasTensor);
     75     fw_cell_bias_ = GetInput(operation, operands, kFwCellGateBiasTensor);
     76     fw_output_gate_bias_ = GetInput(operation, operands, kFwOutputGateBiasTensor);
     77 
     78     fw_projection_weights_ = GetInput(operation, operands, kFwProjectionWeightsTensor);  // optional
     79     fw_projection_bias_ = GetInput(operation, operands, kFwProjectionBiasTensor);        // optional
     80 
     81     fw_activation_state_ = GetInput(operation, operands, kFwInputActivationStateTensor);
     82     fw_cell_state_ = GetInput(operation, operands, kFwInputCellStateTensor);
     83 
     84     bw_input_to_input_weights_ =
     85             GetInput(operation, operands, kBwInputToInputWeightsTensor);  // optional
     86     bw_input_to_forget_weights_ = GetInput(operation, operands, kBwInputToForgetWeightsTensor);
     87     bw_input_to_cell_weights_ = GetInput(operation, operands, kBwInputToCellWeightsTensor);
     88     bw_input_to_output_weights_ = GetInput(operation, operands, kBwInputToOutputWeightsTensor);
     89 
     90     bw_recurrent_to_input_weights_ =
     91             GetInput(operation, operands, kBwRecurrentToInputWeightsTensor);  // optional
     92     bw_recurrent_to_forget_weights_ =
     93             GetInput(operation, operands, kBwRecurrentToForgetWeightsTensor);
     94     bw_recurrent_to_cell_weights_ = GetInput(operation, operands, kBwRecurrentToCellWeightsTensor);
     95     bw_recurrent_to_output_weights_ =
     96             GetInput(operation, operands, kBwRecurrentToOutputWeightsTensor);
     97 
     98     bw_cell_to_input_weights_ =
     99             GetInput(operation, operands, kBwCellToInputWeightsTensor);  // optional
    100     bw_cell_to_forget_weights_ =
    101             GetInput(operation, operands, kBwCellToForgetWeightsTensor);  // optional
    102     bw_cell_to_output_weights_ =
    103             GetInput(operation, operands, kBwCellToOutputWeightsTensor);  // optional
    104 
    105     bw_input_gate_bias_ = GetInput(operation, operands, kBwInputGateBiasTensor);
    106     bw_forget_gate_bias_ = GetInput(operation, operands, kBwForgetGateBiasTensor);
    107     bw_cell_bias_ = GetInput(operation, operands, kBwCellGateBiasTensor);
    108     bw_output_gate_bias_ = GetInput(operation, operands, kBwOutputGateBiasTensor);
    109 
    110     bw_projection_weights_ = GetInput(operation, operands, kBwProjectionWeightsTensor);  // optional
    111     bw_projection_bias_ = GetInput(operation, operands, kBwProjectionBiasTensor);        // optional
    112 
    113     bw_activation_state_ = GetInput(operation, operands, kBwInputActivationStateTensor);
    114     bw_cell_state_ = GetInput(operation, operands, kBwInputCellStateTensor);
    115 
    116     aux_input_ = GetInput(operation, operands, kAuxInputTensor);
    117     fw_aux_input_to_input_weights_ = GetInput(operation, operands, kFwAuxInputToInputWeightsTensor);
    118     fw_aux_input_to_forget_weights_ =
    119             GetInput(operation, operands, kFwAuxInputToForgetWeightsTensor);
    120     fw_aux_input_to_cell_weights_ = GetInput(operation, operands, kFwAuxInputToCellWeightsTensor);
    121     fw_aux_input_to_output_weights_ =
    122             GetInput(operation, operands, kFwAuxInputToOutputWeightsTensor);
    123     bw_aux_input_to_input_weights_ = GetInput(operation, operands, kBwAuxInputToInputWeightsTensor);
    124     bw_aux_input_to_forget_weights_ =
    125             GetInput(operation, operands, kBwAuxInputToForgetWeightsTensor);
    126     bw_aux_input_to_cell_weights_ = GetInput(operation, operands, kBwAuxInputToCellWeightsTensor);
    127     bw_aux_input_to_output_weights_ =
    128             GetInput(operation, operands, kBwAuxInputToOutputWeightsTensor);
    129 
    130     fw_input_layer_norm_weights_ = GetInput(operation, operands, kFwInputLayerNormWeightsTensor);
    131     fw_forget_layer_norm_weights_ = GetInput(operation, operands, kFwForgetLayerNormWeightsTensor);
    132     fw_cell_layer_norm_weights_ = GetInput(operation, operands, kFwCellLayerNormWeightsTensor);
    133     fw_output_layer_norm_weights_ = GetInput(operation, operands, kFwOutputLayerNormWeightsTensor);
    134     bw_input_layer_norm_weights_ = GetInput(operation, operands, kBwInputLayerNormWeightsTensor);
    135     bw_forget_layer_norm_weights_ = GetInput(operation, operands, kBwForgetLayerNormWeightsTensor);
    136     bw_cell_layer_norm_weights_ = GetInput(operation, operands, kBwCellLayerNormWeightsTensor);
    137     bw_output_layer_norm_weights_ = GetInput(operation, operands, kBwOutputLayerNormWeightsTensor);
    138 
    139     params_.activation = static_cast<TfLiteFusedActivation>(
    140             getScalarData<int32_t>(*GetInput(operation, operands, kActivationParam)));
    141     if (input_->type == OperandType::TENSOR_FLOAT32) {
    142         params_.cell_clip = getScalarData<float>(*GetInput(operation, operands, kCellClipParam));
    143         params_.proj_clip = getScalarData<float>(*GetInput(operation, operands, kProjClipParam));
    144     } else {
    145         params_.cell_clip = static_cast<float>(
    146                 getScalarData<_Float16>(*GetInput(operation, operands, kCellClipParam)));
    147         params_.proj_clip = static_cast<float>(
    148                 getScalarData<_Float16>(*GetInput(operation, operands, kProjClipParam)));
    149     }
    150     params_.merge_outputs = getScalarData<bool>(*GetInput(operation, operands, kMergeOutputsParam));
    151     params_.time_major = getScalarData<bool>(*GetInput(operation, operands, kTimeMajorParam));
    152     params_.use_layer_norm = !IsNullInput(fw_input_layer_norm_weights_);
    153 
    154     fw_output_ = GetOutput(operation, operands, kFwOutputTensor);
    155     if (!params_.merge_outputs) {
    156         bw_output_ = GetOutput(operation, operands, kBwOutputTensor);
    157     }
    158 }
    159 
    160 bool BidirectionalSequenceLSTM::Prepare(const Operation& operation,
    161                                         std::vector<RunTimeOperandInfo>& operands,
    162                                         Shape* fwOutputShape, Shape* bwOutputShape) {
    163     // Inferring batch size, number of outputs and number of cells from the
    164     // input tensors.
    165     NN_CHECK(NumDimensions(input_) == 3);
    166     const uint32_t max_time = SizeOfDimension(input_, params_.time_major ? 0 : 1);
    167     const uint32_t n_batch = SizeOfDimension(input_, params_.time_major ? 1 : 0);
    168     const uint32_t n_input = SizeOfDimension(input_, 2);
    169 
    170     const uint32_t n_fw_cell = SizeOfDimension(fw_input_to_output_weights_, 0);
    171     NN_CHECK_EQ(NumDimensions(fw_input_to_output_weights_), 2);
    172     NN_CHECK_EQ(SizeOfDimension(fw_input_to_output_weights_, 1), n_input);
    173 
    174     NN_CHECK_EQ(NumDimensions(fw_recurrent_to_output_weights_), 2);
    175     NN_CHECK_EQ(SizeOfDimension(fw_recurrent_to_output_weights_, 0), n_fw_cell);
    176     const uint32_t n_fw_output = SizeOfDimension(fw_recurrent_to_output_weights_, 1);
    177 
    178     // Check that input tensor dimensions matches with each other.
    179     if (!LSTMCell::CheckInputTensorDimensions(
    180                 input_, fw_input_to_input_weights_, fw_input_to_forget_weights_,
    181                 fw_input_to_cell_weights_, fw_input_to_output_weights_,
    182                 fw_recurrent_to_input_weights_, fw_recurrent_to_forget_weights_,
    183                 fw_recurrent_to_cell_weights_, fw_recurrent_to_output_weights_,
    184                 fw_cell_to_input_weights_, fw_cell_to_forget_weights_, fw_cell_to_output_weights_,
    185                 fw_input_gate_bias_, fw_forget_gate_bias_, fw_cell_bias_, fw_output_gate_bias_,
    186                 fw_projection_weights_, fw_projection_bias_, fw_input_layer_norm_weights_,
    187                 fw_forget_layer_norm_weights_, fw_cell_layer_norm_weights_,
    188                 fw_output_layer_norm_weights_, n_input, n_fw_output, n_fw_cell, &params_)) {
    189         return false;
    190     }
    191 
    192     const bool aux_inputs_all_or_none =
    193             (!IsNullInput(aux_input_) && !IsNullInput(fw_aux_input_to_cell_weights_) &&
    194              !IsNullInput(fw_aux_input_to_forget_weights_) &&
    195              !IsNullInput(fw_aux_input_to_output_weights_) &&
    196              !IsNullInput(bw_aux_input_to_cell_weights_) &&
    197              !IsNullInput(bw_aux_input_to_forget_weights_) &&
    198              !IsNullInput(bw_aux_input_to_output_weights_)) ||
    199             (IsNullInput(fw_aux_input_to_cell_weights_) &&
    200              IsNullInput(fw_aux_input_to_forget_weights_) &&
    201              IsNullInput(fw_aux_input_to_output_weights_) &&
    202              IsNullInput(bw_aux_input_to_cell_weights_) &&
    203              IsNullInput(bw_aux_input_to_forget_weights_) &&
    204              IsNullInput(bw_aux_input_to_output_weights_));
    205     NN_CHECK(aux_inputs_all_or_none);
    206     if (!IsNullInput(aux_input_)) {
    207         // Check that aux_input has the same dimensions (except last) as the input.
    208         NN_CHECK_EQ(aux_input_->shape().dimensions[0], input_->shape().dimensions[0]);
    209         NN_CHECK_EQ(aux_input_->shape().dimensions[1], input_->shape().dimensions[1]);
    210     }
    211 
    212     const uint32_t n_bw_cell = SizeOfDimension(bw_input_to_output_weights_, 0);
    213     NN_CHECK_EQ(NumDimensions(bw_input_to_output_weights_), 2);
    214     NN_CHECK_EQ(SizeOfDimension(bw_input_to_output_weights_, 1), n_input);
    215 
    216     NN_CHECK_EQ(NumDimensions(bw_recurrent_to_output_weights_), 2);
    217     NN_CHECK_EQ(SizeOfDimension(bw_recurrent_to_output_weights_, 0), n_bw_cell);
    218     const uint32_t n_bw_output = SizeOfDimension(bw_recurrent_to_output_weights_, 1);
    219 
    220     const Shape& inputShape = input_->shape();
    221     fwOutputShape->type = inputShape.type;
    222     fwOutputShape->offset = inputShape.offset;
    223     fwOutputShape->scale = inputShape.scale;
    224     fwOutputShape->dimensions.resize(3);
    225     fwOutputShape->dimensions[0] = params_.time_major ? max_time : n_batch;
    226     fwOutputShape->dimensions[1] = params_.time_major ? n_batch : max_time;
    227     fwOutputShape->dimensions[2] = params_.merge_outputs ? n_fw_output + n_bw_output : n_fw_output;
    228 
    229     // Check that input tensor dimensions matches with each other.
    230     if (!LSTMCell::CheckInputTensorDimensions(
    231                 input_, bw_input_to_input_weights_, bw_input_to_forget_weights_,
    232                 bw_input_to_cell_weights_, bw_input_to_output_weights_,
    233                 bw_recurrent_to_input_weights_, bw_recurrent_to_forget_weights_,
    234                 bw_recurrent_to_cell_weights_, bw_recurrent_to_output_weights_,
    235                 bw_cell_to_input_weights_, bw_cell_to_forget_weights_, bw_cell_to_output_weights_,
    236                 bw_input_gate_bias_, bw_forget_gate_bias_, bw_cell_bias_, bw_output_gate_bias_,
    237                 bw_projection_weights_, bw_projection_bias_, bw_input_layer_norm_weights_,
    238                 bw_forget_layer_norm_weights_, bw_cell_layer_norm_weights_,
    239                 bw_output_layer_norm_weights_, n_input, n_bw_output, n_bw_cell, &params_)) {
    240         return false;
    241     }
    242 
    243     if (!params_.merge_outputs) {
    244         bwOutputShape->type = inputShape.type;
    245         bwOutputShape->offset = inputShape.offset;
    246         bwOutputShape->scale = inputShape.scale;
    247         bwOutputShape->dimensions.resize(3);
    248         bwOutputShape->dimensions[0] = params_.time_major ? max_time : n_batch;
    249         bwOutputShape->dimensions[1] = params_.time_major ? n_batch : max_time;
    250         bwOutputShape->dimensions[2] = n_bw_output;
    251     }
    252 
    253     if (params_.use_cifg) {
    254         fw_scratch_shape_.dimensions = {n_batch, n_fw_cell * 3};
    255         bw_scratch_shape_.dimensions = {n_batch, n_bw_cell * 3};
    256     } else {
    257         fw_scratch_shape_.dimensions = {n_batch, n_fw_cell * 4};
    258         bw_scratch_shape_.dimensions = {n_batch, n_bw_cell * 4};
    259     }
    260     fw_scratch_shape_.type = bw_scratch_shape_.type = inputShape.type;
    261     fw_scratch_shape_.offset = bw_scratch_shape_.offset = inputShape.offset;
    262     fw_scratch_shape_.scale = bw_scratch_shape_.scale = inputShape.scale;
    263 
    264     return true;
    265 }
    266 
    267 bool BidirectionalSequenceLSTM::Eval() {
    268     const uint32_t n_fw_output = SizeOfDimension(fw_recurrent_to_output_weights_, 1);
    269     const uint32_t n_bw_output = SizeOfDimension(bw_recurrent_to_output_weights_, 1);
    270     std::vector<uint32_t> fw_output_dims = input_->shape().dimensions;
    271     fw_output_dims[2] = n_fw_output;
    272     std::vector<uint32_t> bw_output_dims = fw_output_dims;
    273     bw_output_dims[2] = n_bw_output;
    274     const uint32_t n_fw_output_elements = fw_output_dims[0] * fw_output_dims[1] * fw_output_dims[2];
    275     const uint32_t n_output_elements =
    276             fw_output_dims[0] * fw_output_dims[1] * (fw_output_dims[2] + bw_output_dims[2]);
    277 
    278     switch (input_->type) {
    279         case OperandType::TENSOR_FLOAT32: {
    280             std::vector<float> fw_scratch_buffer(getNumberOfElements(fw_scratch_shape_));
    281             const bool kForwardSequence = true;
    282             LSTMCell::LSTMEvalFloat32(
    283                     params_, GetBuffer<const float>(input_), input_->shape(),
    284                     GetBuffer<const float>(fw_input_to_input_weights_),
    285                     GetBuffer<const float>(fw_input_to_forget_weights_),
    286                     GetBuffer<const float>(fw_input_to_cell_weights_),
    287                     GetBuffer<const float>(fw_input_to_output_weights_),
    288                     fw_input_to_output_weights_->shape(),
    289                     GetBuffer<const float>(fw_recurrent_to_input_weights_),
    290                     GetBuffer<const float>(fw_recurrent_to_forget_weights_),
    291                     GetBuffer<const float>(fw_recurrent_to_cell_weights_),
    292                     GetBuffer<const float>(fw_recurrent_to_output_weights_),
    293                     fw_recurrent_to_output_weights_->shape(),
    294                     GetBuffer<const float>(fw_cell_to_input_weights_),
    295                     GetBuffer<const float>(fw_cell_to_forget_weights_),
    296                     GetBuffer<const float>(fw_cell_to_output_weights_),
    297                     GetOptionalBuffer<const float>(aux_input_),
    298                     GetOptionalBuffer<const float>(fw_aux_input_to_input_weights_),
    299                     GetOptionalBuffer<const float>(fw_aux_input_to_forget_weights_),
    300                     GetOptionalBuffer<const float>(fw_aux_input_to_cell_weights_),
    301                     GetOptionalBuffer<const float>(fw_aux_input_to_output_weights_),
    302                     GetBuffer<const float>(fw_input_gate_bias_),
    303                     GetBuffer<const float>(fw_forget_gate_bias_),
    304                     GetBuffer<const float>(fw_cell_bias_),
    305                     GetBuffer<const float>(fw_output_gate_bias_),
    306                     GetBuffer<const float>(fw_projection_weights_),
    307                     GetBuffer<const float>(fw_projection_bias_),
    308                     GetBuffer<const float>(fw_activation_state_),
    309                     GetBuffer<const float>(fw_cell_state_),
    310                     GetOptionalBuffer<const float>(fw_input_layer_norm_weights_),
    311                     GetOptionalBuffer<const float>(fw_forget_layer_norm_weights_),
    312                     GetOptionalBuffer<const float>(fw_cell_layer_norm_weights_),
    313                     GetOptionalBuffer<const float>(fw_output_layer_norm_weights_),
    314                     GetBuffer<float>(fw_activation_state_), GetBuffer<float>(fw_cell_state_),
    315                     GetBuffer<float>(fw_output_), fw_scratch_buffer.data(), params_.time_major,
    316                     kForwardSequence);
    317 
    318             std::vector<float> bw_scratch_buffer(getNumberOfElements(bw_scratch_shape_));
    319             const bool kBackwardSequence = false;
    320             LSTMCell::LSTMEvalFloat32(
    321                     params_, GetBuffer<const float>(input_), input_->shape(),
    322                     GetBuffer<const float>(bw_input_to_input_weights_),
    323                     GetBuffer<const float>(bw_input_to_forget_weights_),
    324                     GetBuffer<const float>(bw_input_to_cell_weights_),
    325                     GetBuffer<const float>(bw_input_to_output_weights_),
    326                     bw_input_to_output_weights_->shape(),
    327                     GetBuffer<const float>(bw_recurrent_to_input_weights_),
    328                     GetBuffer<const float>(bw_recurrent_to_forget_weights_),
    329                     GetBuffer<const float>(bw_recurrent_to_cell_weights_),
    330                     GetBuffer<const float>(bw_recurrent_to_output_weights_),
    331                     bw_recurrent_to_output_weights_->shape(),
    332                     GetBuffer<const float>(bw_cell_to_input_weights_),
    333                     GetBuffer<const float>(bw_cell_to_forget_weights_),
    334                     GetBuffer<const float>(bw_cell_to_output_weights_),
    335                     GetOptionalBuffer<const float>(aux_input_),
    336                     GetOptionalBuffer<const float>(bw_aux_input_to_input_weights_),
    337                     GetOptionalBuffer<const float>(bw_aux_input_to_forget_weights_),
    338                     GetOptionalBuffer<const float>(bw_aux_input_to_cell_weights_),
    339                     GetOptionalBuffer<const float>(bw_aux_input_to_output_weights_),
    340                     GetBuffer<const float>(bw_input_gate_bias_),
    341                     GetBuffer<const float>(bw_forget_gate_bias_),
    342                     GetBuffer<const float>(bw_cell_bias_),
    343                     GetBuffer<const float>(bw_output_gate_bias_),
    344                     GetBuffer<const float>(bw_projection_weights_),
    345                     GetBuffer<const float>(bw_projection_bias_),
    346                     GetBuffer<const float>(bw_activation_state_),
    347                     GetBuffer<const float>(bw_cell_state_),
    348                     GetOptionalBuffer<const float>(bw_input_layer_norm_weights_),
    349                     GetOptionalBuffer<const float>(bw_forget_layer_norm_weights_),
    350                     GetOptionalBuffer<const float>(bw_cell_layer_norm_weights_),
    351                     GetOptionalBuffer<const float>(bw_output_layer_norm_weights_),
    352                     GetBuffer<float>(bw_activation_state_), GetBuffer<float>(bw_cell_state_),
    353                     params_.merge_outputs ? GetBuffer<float>(fw_output_) + n_fw_output_elements
    354                                           : GetBuffer<float>(bw_output_),
    355                     bw_scratch_buffer.data(), params_.time_major, kBackwardSequence);
    356             if (params_.merge_outputs) {
    357                 std::vector<float> temp(n_output_elements);
    358                 mergeThirdDimension(GetBuffer<float>(fw_output_), fw_output_dims,
    359                                     GetBuffer<float>(fw_output_) + n_fw_output_elements,
    360                                     bw_output_dims, temp.data());
    361                 std::copy(temp.data(), temp.data() + n_output_elements,
    362                           GetBuffer<float>(fw_output_));
    363             }
    364         } break;
    365         case OperandType::TENSOR_FLOAT16: {
    366             std::vector<_Float16> fw_scratch_buffer(getNumberOfElements(fw_scratch_shape_));
    367             const bool kForwardSequence = true;
    368             LSTMCell::LSTMEvalFloat16(
    369                     params_, GetBuffer<const _Float16>(input_), input_->shape(),
    370                     GetOptionalBuffer<const _Float16>(fw_input_to_input_weights_),
    371                     GetBuffer<const _Float16>(fw_input_to_forget_weights_),
    372                     GetBuffer<const _Float16>(fw_input_to_cell_weights_),
    373                     GetBuffer<const _Float16>(fw_input_to_output_weights_),
    374                     fw_input_to_output_weights_->shape(),
    375                     GetOptionalBuffer<const _Float16>(fw_recurrent_to_input_weights_),
    376                     GetBuffer<const _Float16>(fw_recurrent_to_forget_weights_),
    377                     GetBuffer<const _Float16>(fw_recurrent_to_cell_weights_),
    378                     GetBuffer<const _Float16>(fw_recurrent_to_output_weights_),
    379                     fw_recurrent_to_output_weights_->shape(),
    380                     GetOptionalBuffer<const _Float16>(fw_cell_to_input_weights_),
    381                     GetOptionalBuffer<const _Float16>(fw_cell_to_forget_weights_),
    382                     GetOptionalBuffer<const _Float16>(fw_cell_to_output_weights_),
    383                     GetOptionalBuffer<const _Float16>(aux_input_),
    384                     GetOptionalBuffer<const _Float16>(fw_aux_input_to_input_weights_),
    385                     GetOptionalBuffer<const _Float16>(fw_aux_input_to_forget_weights_),
    386                     GetOptionalBuffer<const _Float16>(fw_aux_input_to_cell_weights_),
    387                     GetOptionalBuffer<const _Float16>(fw_aux_input_to_output_weights_),
    388                     GetOptionalBuffer<const _Float16>(fw_input_gate_bias_),
    389                     GetBuffer<const _Float16>(fw_forget_gate_bias_),
    390                     GetBuffer<const _Float16>(fw_cell_bias_),
    391                     GetBuffer<const _Float16>(fw_output_gate_bias_),
    392                     GetOptionalBuffer<const _Float16>(fw_projection_weights_),
    393                     GetOptionalBuffer<const _Float16>(fw_projection_bias_),
    394                     GetBuffer<const _Float16>(fw_activation_state_),
    395                     GetBuffer<const _Float16>(fw_cell_state_),
    396                     GetOptionalBuffer<const _Float16>(fw_input_layer_norm_weights_),
    397                     GetOptionalBuffer<const _Float16>(fw_forget_layer_norm_weights_),
    398                     GetOptionalBuffer<const _Float16>(fw_cell_layer_norm_weights_),
    399                     GetOptionalBuffer<const _Float16>(fw_output_layer_norm_weights_),
    400                     GetBuffer<_Float16>(fw_activation_state_), GetBuffer<_Float16>(fw_cell_state_),
    401                     GetBuffer<_Float16>(fw_output_), fw_scratch_buffer.data(), params_.time_major,
    402                     kForwardSequence);
    403 
    404             std::vector<_Float16> bw_scratch_buffer(getNumberOfElements(bw_scratch_shape_));
    405             const bool kBackwardSequence = false;
    406             LSTMCell::LSTMEvalFloat16(
    407                     params_, GetBuffer<const _Float16>(input_), input_->shape(),
    408                     GetOptionalBuffer<const _Float16>(bw_input_to_input_weights_),
    409                     GetBuffer<const _Float16>(bw_input_to_forget_weights_),
    410                     GetBuffer<const _Float16>(bw_input_to_cell_weights_),
    411                     GetBuffer<const _Float16>(bw_input_to_output_weights_),
    412                     bw_input_to_output_weights_->shape(),
    413                     GetOptionalBuffer<const _Float16>(bw_recurrent_to_input_weights_),
    414                     GetBuffer<const _Float16>(bw_recurrent_to_forget_weights_),
    415                     GetBuffer<const _Float16>(bw_recurrent_to_cell_weights_),
    416                     GetBuffer<const _Float16>(bw_recurrent_to_output_weights_),
    417                     bw_recurrent_to_output_weights_->shape(),
    418                     GetOptionalBuffer<const _Float16>(bw_cell_to_input_weights_),
    419                     GetOptionalBuffer<const _Float16>(bw_cell_to_forget_weights_),
    420                     GetOptionalBuffer<const _Float16>(bw_cell_to_output_weights_),
    421                     GetOptionalBuffer<const _Float16>(aux_input_),
    422                     GetOptionalBuffer<const _Float16>(bw_aux_input_to_input_weights_),
    423                     GetOptionalBuffer<const _Float16>(bw_aux_input_to_forget_weights_),
    424                     GetOptionalBuffer<const _Float16>(bw_aux_input_to_cell_weights_),
    425                     GetOptionalBuffer<const _Float16>(bw_aux_input_to_output_weights_),
    426                     GetOptionalBuffer<const _Float16>(bw_input_gate_bias_),
    427                     GetBuffer<const _Float16>(bw_forget_gate_bias_),
    428                     GetBuffer<const _Float16>(bw_cell_bias_),
    429                     GetBuffer<const _Float16>(bw_output_gate_bias_),
    430                     GetOptionalBuffer<const _Float16>(bw_projection_weights_),
    431                     GetOptionalBuffer<const _Float16>(bw_projection_bias_),
    432                     GetBuffer<const _Float16>(bw_activation_state_),
    433                     GetBuffer<const _Float16>(bw_cell_state_),
    434                     GetOptionalBuffer<const _Float16>(bw_input_layer_norm_weights_),
    435                     GetOptionalBuffer<const _Float16>(bw_forget_layer_norm_weights_),
    436                     GetOptionalBuffer<const _Float16>(bw_cell_layer_norm_weights_),
    437                     GetOptionalBuffer<const _Float16>(bw_output_layer_norm_weights_),
    438                     GetBuffer<_Float16>(bw_activation_state_), GetBuffer<_Float16>(bw_cell_state_),
    439                     params_.merge_outputs ? GetBuffer<_Float16>(fw_output_) + n_fw_output_elements
    440                                           : GetBuffer<_Float16>(bw_output_),
    441                     bw_scratch_buffer.data(), params_.time_major, kBackwardSequence);
    442             if (params_.merge_outputs) {
    443                 std::vector<_Float16> temp(n_output_elements);
    444                 mergeThirdDimension(GetBuffer<_Float16>(fw_output_), fw_output_dims,
    445                                     GetBuffer<_Float16>(fw_output_) + n_fw_output_elements,
    446                                     bw_output_dims, temp.data());
    447                 std::copy(temp.data(), temp.data() + n_output_elements,
    448                           GetBuffer<_Float16>(fw_output_));
    449             }
    450         } break;
    451         default: {
    452             LOG(ERROR) << "Unsupported data type: " << static_cast<int>(input_->type);
    453             return false;
    454         }
    455     }
    456     return true;
    457 }
    458 
    459 }  // namespace nn
    460 }  // namespace android
    461