/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "QuantizedLSTM.h"

#include "CpuExecutor.h"
#include "CpuOperationUtils.h"

#include "Tracing.h"

#include "public/gemmlowp.h"
#include "tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h"

namespace android {
namespace nn {

namespace {

template <typename T>
inline T* GetBuffer(RunTimeOperandInfo* operand) {
    return reinterpret_cast<T*>(operand->buffer);
}

template <typename T>
inline const T* GetBuffer(const RunTimeOperandInfo* operand) {
    return reinterpret_cast<const T*>(operand->buffer);
}

using tflite::Dims;

// The function below is taken from the TF Lite implementation in order to
// decouple NN API from the TF Lite dependency. The original function, with a
// description of its parameters and types, can be found at:
// https://github.com/tensorflow/tensorflow/blob/0d697e5fc4c05c699eea0764364104ea500ccc68/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h#L1926
//
// clang-format off
template <int StateIntegerBits>
void quantizedLstmStep(const uint8_t* input_data_uint8, const Dims<4>& input_dims,
                       const uint8_t* prev_activ_data_uint8,
                       const Dims<4>& prev_activ_dims, const uint8_t* weights_data_uint8,
                       const Dims<4>& weights_dims, const int32_t* bias_data_int32,
                       const Dims<4>& bias_dims, const int16_t* prevCellState_data_int16,
                       const Dims<4>& prevCellState_dims, int16_t* output_state_data_int16,
                       const Dims<4>& output_state_dims, uint8_t* output_activ_data_uint8,
                       const Dims<4>& output_activ_dims, uint8_t* concat_temp_data_uint8,
                       const Dims<4>& concat_temp_dims, int16_t* activ_temp_data_int16,
                       const Dims<4>& activ_temp_dims, int32_t weights_zero_point,
                       int32_t accum_multiplier, int accum_shift) {
  // Gather dimensions information, and perform consistency checks.
  const int outer_size =
      MatchingFlatSizeSkipDim(input_dims, 0, prev_activ_dims, prevCellState_dims,
                              output_state_dims, output_activ_dims);
  TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1);
  TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1);
  const int input_depth = ArraySize(input_dims, 0);
  const int prev_activ_depth = ArraySize(prev_activ_dims, 0);
  const int total_input_depth = prev_activ_depth + input_depth;
  TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth);
  TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3),
                  1);
  const int intern_activ_depth =
      MatchingArraySize(weights_dims, 1, bias_dims, 0);
  TFLITE_CHECK_EQ(intern_activ_depth % 4, 0);
  const int output_depth =
      MatchingArraySize(prevCellState_dims, 0, prev_activ_dims, 0,
                        output_state_dims, 0, output_activ_dims, 0);
  TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4);
  const int fc_batches = FlatSizeSkipDim(activ_temp_dims, 0);
  const int fc_output_depth =
      MatchingArraySize(weights_dims, 1, activ_temp_dims, 0);
  const int fc_accum_depth = ArraySize(weights_dims, 0);
  TFLITE_CHECK_EQ(fc_output_depth, 4 * output_depth);

  // Depth-concatenate prev_activ and input data together.
  uint8_t const* concat_input_arrays_data[2] = {input_data_uint8,
                                                prev_activ_data_uint8};
  Dims<4> const* concat_input_arrays_dims[2] = {&input_dims, &prev_activ_dims};
  tflite::reference_ops::Concatenation<tflite::FusedActivationFunctionType::kNone, uint8_t>(
      0, concat_input_arrays_data, concat_input_arrays_dims, 2,
      concat_temp_data_uint8, concat_temp_dims);

  // Implementation of the fully connected node inside the LSTM cell.
  // The operands are 8-bit integers, the accumulators are internally 32-bit
  // integers, and the output is 16-bit fixed-point with 3 integer bits so
  // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that
  // is explained in the function comment above.
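  // For example, with 3 integer bits a raw int16 value v represents the real
  // value v * 2^-12, so a raw accumulator value of 4096 corresponds to 1.0.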
  for (int b = 0; b < fc_batches; ++b) {
    for (int out_c = 0; out_c < fc_output_depth; ++out_c) {
      // Internal accumulation.
      // Initialize accumulator with the bias-value.
      int32_t accum = bias_data_int32[out_c];
      // Accumulation loop.
      for (int d = 0; d < fc_accum_depth; ++d) {
        int16_t input_val = concat_temp_data_uint8[b * fc_accum_depth + d] - 128;
        int16_t weights_val =
            weights_data_uint8[out_c * fc_accum_depth + d] - weights_zero_point;
        accum += input_val * weights_val;
      }
      // Down-scale the final int32 accumulator to the scale used by our
      // (16-bit, using 3 integer bits) fixed-point format. The quantized
      // multiplier and shift here have been pre-computed offline
      // (e.g. by toco).
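      // Roughly, MultiplyByQuantizedMultiplier computes
      // accum * accum_multiplier * 2^(accum_shift - 31) with rounding, i.e. a
      // fixed-point multiplication by the real multiplier computed in eval().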
      accum =
          tflite::MultiplyByQuantizedMultiplier(accum, accum_multiplier, accum_shift);
      // Saturate, cast to int16, and store to the temporary activations array.
      accum = std::max(-32768, std::min(32767, accum));
      activ_temp_data_int16[out_c + fc_output_depth * b] = accum;
    }
  }

  // Rest of the LSTM cell: tanh and logistic math functions, and some adds
  // and muls, all done in 16-bit fixed-point.
  for (int b = 0; b < outer_size; ++b) {
    for (int c = 0; c < output_depth; ++c) {
      // Define the fixed-point data types that we will use here. All use
      // int16 as the underlying integer type i.e. all are 16-bit fixed-point.
      // They only differ by the number of integral vs. fractional bits,
      // determining the range of values that they can represent.
      //
      // F0 uses 0 integer bits, range [-1, 1].
      // This is the return type of math functions such as tanh, logistic,
      // whose range is in [-1, 1].
      using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
      // F3 uses 3 integer bits, range [-8, 8].
      // This is the range of the previous fully-connected node's output,
      // which is our input here.
      using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
      // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
      // 2^StateIntegerBits]. It's used to represent the internal state, whose
      // number of integer bits is currently dictated by the model. See comment
      // on the StateIntegerBits template parameter above.
      using FS = gemmlowp::FixedPoint<std::int16_t, StateIntegerBits>;
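      // For example, an F3 raw value of 4096 represents 1.0 (real value =
      // raw * 2^-12), while an F0 raw value of 16384 represents 0.5 (real
      // value = raw * 2^-15).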
      // Implementation of input gate, using fixed-point logistic function.
      F3 input_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 0 * output_depth + c]);
      F0 input_gate_output = gemmlowp::logistic(input_gate_input);
      // Implementation of input modulation gate, using fixed-point tanh
      // function.
      F3 input_modulation_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 1 * output_depth + c]);
      F0 input_modulation_gate_output =
          gemmlowp::tanh(input_modulation_gate_input);
      // Implementation of forget gate, using fixed-point logistic function.
      F3 forget_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 2 * output_depth + c]);
      F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
      // Implementation of output gate, using fixed-point logistic function.
      F3 output_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 3 * output_depth + c]);
      F0 output_gate_output = gemmlowp::logistic(output_gate_input);
      // Implementation of internal multiplication nodes, still in fixed-point.
      F0 input_times_input_modulation =
          input_gate_output * input_modulation_gate_output;
      FS prevCellState = FS::FromRaw(prevCellState_data_int16[b * output_depth + c]);
      FS prevCellState_times_forget_state = forget_gate_output * prevCellState;
      // Implementation of internal addition node, saturating.
      FS new_state = gemmlowp::SaturatingAdd(
          gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
          prevCellState_times_forget_state);
      // Implementation of last internal Tanh node, still in fixed-point.
      // Since a Tanh fixed-point implementation is specialized for a given
      // number of integer bits, and each specialization can have a substantial
      // code size, and we already used above a Tanh on an input with 3 integer
      // bits, and per the table in the above function comment there is no
      // significant accuracy to be lost by clamping to [-8, +8] for a
      // 3-integer-bits representation, let us just do that. This helps people
      // porting this to targets where code footprint must be minimized.
      F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
      F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
      // Store the new internal state back to memory, as 16-bit integers.
      // Note: here we store the original value with StateIntegerBits, not
      // the rescaled 3-integer-bits value fed to tanh.
      output_state_data_int16[b * output_depth + c] = new_state.raw();
      // Down-scale the output activations to 8-bit integers, saturating,
      // and store back to memory.
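      // The F0 raw value spans [-32768, 32767] for real values in [-1, 1), so
      // dividing by 2^8 maps it onto [-128, 127]; adding 128 below yields a
      // uint8 output with zero point 128 and scale 1/128.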
      int16_t rescaled_output_activ =
          gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
      int16_t clamped_output_activ =
          std::max<int16_t>(-128, std::min<int16_t>(127, rescaled_output_activ));
      output_activ_data_uint8[b * output_depth + c] =
          128 + clamped_output_activ;
    }
  }
}
// clang-format on

// The function assigns a 2D matrix to a submatrix of the weights at the given
// row and column offsets.
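// For example, concatenateWeights() below calls it with offsets (0, outputSize)
// to place inputToInputWeights_ into the top-right block of the concatenated
// weights matrix.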
void assignWeightsSubmatrix(const RunTimeOperandInfo* submatrix, const int32_t offset_row,
                            const int32_t offset_column, const std::vector<uint32_t>& weightsDims,
                            uint8_t* weights) {
    const uint8_t* submatrixValues = GetBuffer<uint8_t>(submatrix);
    const std::vector<uint32_t> submatrixDims = submatrix->shape().dimensions;
    for (uint32_t i = 0; i < submatrixDims[0] * submatrixDims[1]; ++i) {
        const uint32_t row = i / submatrixDims[1];
        const uint32_t column = i % submatrixDims[1];
        weights[(row + offset_row) * weightsDims[1] + column + offset_column] = submatrixValues[i];
    }
}

}  // namespace

QuantizedLSTMCell::QuantizedLSTMCell(const Operation& operation,
                                     std::vector<RunTimeOperandInfo>& operands) {
    input_ = GetInput(operation, operands, kInputTensor);

    inputToInputWeights_ = GetInput(operation, operands, kInputToInputWeightsTensor);
    inputToForgetWeights_ = GetInput(operation, operands, kInputToForgetWeightsTensor);
    inputToCellWeights_ = GetInput(operation, operands, kInputToCellWeightsTensor);
    inputToOutputWeights_ = GetInput(operation, operands, kInputToOutputWeightsTensor);

    recurrentToInputWeights_ = GetInput(operation, operands, kRecurrentToInputWeightsTensor);
    recurrentToForgetWeights_ = GetInput(operation, operands, kRecurrentToForgetWeightsTensor);
    recurrentToCellWeights_ = GetInput(operation, operands, kRecurrentToCellWeightsTensor);
    recurrentToOutputWeights_ = GetInput(operation, operands, kRecurrentToOutputWeightsTensor);

    inputGateBias_ = GetInput(operation, operands, kInputGateBiasTensor);
    forgetGateBias_ = GetInput(operation, operands, kForgetGateBiasTensor);
    cellGateBias_ = GetInput(operation, operands, kCellGateBiasTensor);
    outputGateBias_ = GetInput(operation, operands, kOutputGateBiasTensor);

    prevCellState_ = GetInput(operation, operands, kPrevCellStateTensor);
    prevOutput_ = GetInput(operation, operands, kPrevOutputTensor);

    cellStateOut_ = GetOutput(operation, operands, kCellStateOutTensor);
    output_ = GetOutput(operation, operands, kOutputTensor);
}

bool QuantizedLSTMCell::prepare(const Operation& operation,
                                std::vector<RunTimeOperandInfo>& operands, Shape* cellStateOutShape,
                                Shape* outputShape) {
    auto input = GetInput(operation, operands, kInputTensor);
    NN_RET_CHECK_EQ(NumDimensions(input), 2);
    NN_RET_CHECK_EQ(input->scale, 1. / 128.0);
    NN_RET_CHECK_EQ(input->zeroPoint, 128);
    const uint32_t numBatches = SizeOfDimension(input, 0);
    const uint32_t inputSize = SizeOfDimension(input, 1);

    auto prevOutput = GetInput(operation, operands, kPrevOutputTensor);
    NN_RET_CHECK_EQ(NumDimensions(prevOutput), 2);
    NN_RET_CHECK_EQ(SizeOfDimension(prevOutput, 0), numBatches);
    NN_RET_CHECK_EQ(prevOutput->scale, 1. / 128.0);
    NN_RET_CHECK_EQ(prevOutput->zeroPoint, 128);
    const uint32_t outputSize = SizeOfDimension(prevOutput, 1);

    auto inputToInputWeights = GetInput(operation, operands, kInputToInputWeightsTensor);
    const float weightsScale = inputToInputWeights->scale;
    NN_RET_CHECK(weightsScale != 0);
    const float weightsZeroPoint = inputToInputWeights->zeroPoint;

    auto checkWeightsShape = [&](const RunTimeOperandInfo* weights, uint32_t columns) -> bool {
        NN_RET_CHECK_EQ(NumDimensions(weights), 2);
        NN_RET_CHECK_EQ(SizeOfDimension(weights, 0), outputSize);
        NN_RET_CHECK_EQ(SizeOfDimension(weights, 1), columns);
        NN_RET_CHECK_EQ(weights->scale, weightsScale);
        NN_RET_CHECK_EQ(weights->zeroPoint, weightsZeroPoint);
        return true;
    };

    auto inputToForgetWeights = GetInput(operation, operands, kInputToForgetWeightsTensor);
    auto inputToCellWeights = GetInput(operation, operands, kInputToCellWeightsTensor);
    auto inputToOutputWeights = GetInput(operation, operands, kInputToOutputWeightsTensor);
    NN_RET_CHECK(checkWeightsShape(inputToInputWeights, inputSize));
    NN_RET_CHECK(checkWeightsShape(inputToForgetWeights, inputSize));
    NN_RET_CHECK(checkWeightsShape(inputToCellWeights, inputSize));
    NN_RET_CHECK(checkWeightsShape(inputToOutputWeights, inputSize));

    auto recurrentToInputWeights = GetInput(operation, operands, kRecurrentToInputWeightsTensor);
    auto recurrentToForgetWeights = GetInput(operation, operands, kRecurrentToForgetWeightsTensor);
    auto recurrentToCellWeights = GetInput(operation, operands, kRecurrentToCellWeightsTensor);
    auto recurrentToOutputWeights = GetInput(operation, operands, kRecurrentToOutputWeightsTensor);
    NN_RET_CHECK(checkWeightsShape(recurrentToInputWeights, outputSize));
    NN_RET_CHECK(checkWeightsShape(recurrentToForgetWeights, outputSize));
    NN_RET_CHECK(checkWeightsShape(recurrentToCellWeights, outputSize));
    NN_RET_CHECK(checkWeightsShape(recurrentToOutputWeights, outputSize));

    auto inputGateBias = GetInput(operation, operands, kInputGateBiasTensor);
    const float biasScale = inputGateBias->scale;
    NN_RET_CHECK_EQ(biasScale, weightsScale / 128.0);
    const float biasZeroPoint = inputGateBias->zeroPoint;
    NN_RET_CHECK_EQ(biasZeroPoint, 0);

    auto checkBiasShape = [&](const RunTimeOperandInfo* bias) -> bool {
        NN_RET_CHECK_EQ(NumDimensions(bias), 1);
        NN_RET_CHECK_EQ(SizeOfDimension(bias, 0), outputSize);
        NN_RET_CHECK_EQ(bias->scale, biasScale);
        NN_RET_CHECK_EQ(bias->zeroPoint, biasZeroPoint);
        return true;
    };

    auto forgetGateBias = GetInput(operation, operands, kForgetGateBiasTensor);
    auto cellGateBias = GetInput(operation, operands, kCellGateBiasTensor);
    auto outputGateBias = GetInput(operation, operands, kOutputGateBiasTensor);
    NN_RET_CHECK(checkBiasShape(inputGateBias));
    NN_RET_CHECK(checkBiasShape(forgetGateBias));
    NN_RET_CHECK(checkBiasShape(cellGateBias));
    NN_RET_CHECK(checkBiasShape(outputGateBias));

    auto prevCellState = GetInput(operation, operands, kPrevCellStateTensor);
    NN_CHECK_EQ(NumDimensions(prevCellState), 2);
    NN_CHECK_EQ(SizeOfDimension(prevCellState, 0), numBatches);
    NN_CHECK_EQ(SizeOfDimension(prevCellState, 1), outputSize);
    NN_CHECK_EQ(prevCellState->zeroPoint, 0);
    // Cell state range for quantized LSTM is a function of StateIntegerBits and
    // can be calculated as:
    // [-2^StateIntegerBits, 2^StateIntegerBits * 32767/32768].
    // Therefore, for a fixed StateIntegerBits parameter, cell state scale is
    // equal to 2^StateIntegerBits * 2^(-15) = 2^(StateIntegerBits - 15) and
    // therefore:
    // StateIntegerBits = log2(cell state scale) + 15
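    // For example, the only supported configuration below, StateIntegerBits == 4,
    // corresponds to a cell state scale of 2^(4 - 15) = 2^-11.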
    int stateScaleLog2Rounded;
    NN_CHECK(tflite::CheckedLog2(prevCellState->scale, &stateScaleLog2Rounded));
    const int stateIntegerBits = 15 + stateScaleLog2Rounded;
    // We only support StateIntegerBits == 4
    NN_CHECK(stateIntegerBits == 4);

    *cellStateOutShape = prevCellState->shape();
    *outputShape = prevOutput->shape();
    return true;
}

// The function concatenates 8 input weight matrices into one. The resulting
// matrix has shape [4 * outputSize, outputSize + inputSize]. The matrix is
// constructed as follows:
// +-----------------------------------+
// | recurrentToInput  | inputToInput  |
// |-------------------+---------------|
// | recurrentToCell   | inputToCell   |
// |-------------------+---------------|
// | recurrentToForget | inputToForget |
// |-------------------+---------------|
// | recurrentToOutput | inputToOutput |
// +-----------------------------------+
void QuantizedLSTMCell::concatenateWeights(const std::vector<uint32_t>& weightsDims,
                                           uint8_t* weights) {
    const int outputSize = SizeOfDimension(inputToInputWeights_, 0);

    assignWeightsSubmatrix(inputToInputWeights_, 0 * outputSize, outputSize, weightsDims, weights);
    assignWeightsSubmatrix(inputToCellWeights_, 1 * outputSize, outputSize, weightsDims, weights);
    assignWeightsSubmatrix(inputToForgetWeights_, 2 * outputSize, outputSize, weightsDims, weights);
    assignWeightsSubmatrix(inputToOutputWeights_, 3 * outputSize, outputSize, weightsDims, weights);
    assignWeightsSubmatrix(recurrentToInputWeights_, 0 * outputSize, 0, weightsDims, weights);
    assignWeightsSubmatrix(recurrentToCellWeights_, 1 * outputSize, 0, weightsDims, weights);
    assignWeightsSubmatrix(recurrentToForgetWeights_, 2 * outputSize, 0, weightsDims, weights);
    assignWeightsSubmatrix(recurrentToOutputWeights_, 3 * outputSize, 0, weightsDims, weights);
}

// The function concatenates four bias vectors of shape [outputSize] into one
// vector of shape [4 * outputSize].
void QuantizedLSTMCell::concatenateBiases(uint32_t outputSize, int32_t* bias) {
    memcpy(bias + 0 * outputSize, GetBuffer<int32_t>(inputGateBias_), sizeof(int32_t) * outputSize);
    memcpy(bias + 1 * outputSize, GetBuffer<int32_t>(cellGateBias_), sizeof(int32_t) * outputSize);
    memcpy(bias + 2 * outputSize, GetBuffer<int32_t>(forgetGateBias_),
           sizeof(int32_t) * outputSize);
    memcpy(bias + 3 * outputSize, GetBuffer<int32_t>(outputGateBias_),
           sizeof(int32_t) * outputSize);
}

bool QuantizedLSTMCell::eval() {
    NNTRACE_COMP("QuantizedLSTM::eval");

    Shape weightsShape;
    weightsShape.dimensions = {4 * SizeOfDimension(prevOutput_, 1),
                               SizeOfDimension(input_, 1) + SizeOfDimension(prevOutput_, 1)};
    std::vector<uint8_t> weights(getNumberOfElements(weightsShape));
    concatenateWeights(weightsShape.dimensions, weights.data());

    Shape biasShape;
    biasShape.dimensions = {getSizeOfDimension(weightsShape, 0)};
    std::vector<int32_t> bias(getNumberOfElements(biasShape));
    concatenateBiases(SizeOfDimension(prevOutput_, 1), bias.data());

    Shape concatTempShape;
    concatTempShape.dimensions = {SizeOfDimension(input_, 0), getSizeOfDimension(weightsShape, 1)};

    Shape activationTempShape;
    activationTempShape.dimensions = {SizeOfDimension(input_, 0),
                                      getSizeOfDimension(weightsShape, 0)};

    std::vector<uint8_t> concatTemp(getNumberOfElements(concatTempShape));
    std::vector<int16_t> activationTemp(getNumberOfElements(activationTempShape));

    // From https://arxiv.org/pdf/1712.05877, for a fully-connected layer the
    // accumulator multiplier is equal to:
    // (input scale) * (weights scale) / (fully-connected output scale)
    // In our case the fully-connected output scale is fixed and equal to
    // 2^(-12) (see the LSTMCell definition in TF Lite for more details on that).
    // The bias scale is set to (input scale) * (weights scale) (also from the
    // paper), so we can multiply it by the inverse of the fc-output scale to get
    // the multiplier value:
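    // For example, with the input scale of 1/128 checked in prepare() and the
    // bias scale equal to (1/128) * weightsScale, the real multiplier is
    // 4096 * weightsScale / 128 = 32 * weightsScale.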
    double realAccumMultiplier = 4096 * inputGateBias_->scale;
    int32_t accumMultiplier;
    int accumShift;
    tflite::QuantizeMultiplier(realAccumMultiplier, &accumMultiplier, &accumShift);
    quantizedLstmStep<4>(
            // Inputs.
            GetBuffer<const uint8_t>(input_), convertShapeToDims(input_->shape()),
            GetBuffer<const uint8_t>(prevOutput_), convertShapeToDims(prevOutput_->shape()),
            weights.data(), convertShapeToDims(weightsShape), bias.data(),
            convertShapeToDims(biasShape), GetBuffer<const int16_t>(prevCellState_),
            convertShapeToDims(prevCellState_->shape()),
            // Outputs.
            GetBuffer<int16_t>(cellStateOut_), convertShapeToDims(cellStateOut_->shape()),
            GetBuffer<uint8_t>(output_), convertShapeToDims(output_->shape()), concatTemp.data(),
            convertShapeToDims(concatTempShape), activationTemp.data(),
            convertShapeToDims(activationTempShape), inputToInputWeights_->zeroPoint,
            accumMultiplier, accumShift);
    return true;
}

}  // namespace nn
}  // namespace android