/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "QuantizedLSTM.h"

#include <algorithm>
#include <cstring>
#include <vector>

#include "CpuExecutor.h"
#include "CpuOperationUtils.h"

#include "Tracing.h"

#include "public/gemmlowp.h"
#include "tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h"

namespace android {
namespace nn {

namespace {

template <typename T>
inline T* GetBuffer(RunTimeOperandInfo* operand) {
    return reinterpret_cast<T*>(operand->buffer);
}

template <typename T>
inline const T* GetBuffer(const RunTimeOperandInfo* operand) {
    return reinterpret_cast<const T*>(operand->buffer);
}

using tflite::Dims;

// The function below is taken from the TF Lite implementation in order to decouple
// the NN API from the TF Lite dependency. The original function, with a description
// of its parameters and types, can be found at:
// https://github.com/tensorflow/tensorflow/blob/0d697e5fc4c05c699eea0764364104ea500ccc68/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h#L1926
//
// clang-format off
template <int StateIntegerBits>
void quantizedLstmStep(const uint8_t* input_data_uint8, const Dims<4>& input_dims,
                       const uint8_t* prev_activ_data_uint8,
                       const Dims<4>& prev_activ_dims, const uint8_t* weights_data_uint8,
                       const Dims<4>& weights_dims, const int32_t* bias_data_int32,
                       const Dims<4>& bias_dims, const int16_t* prevCellState_data_int16,
                       const Dims<4>& prevCellState_dims, int16_t* output_state_data_int16,
                       const Dims<4>& output_state_dims, uint8_t* output_activ_data_uint8,
                       const Dims<4>& output_activ_dims, uint8_t* concat_temp_data_uint8,
                       const Dims<4>& concat_temp_dims, int16_t* activ_temp_data_int16,
                       const Dims<4>& activ_temp_dims, int32_t weights_zero_point,
                       int32_t accum_multiplier, int accum_shift) {
  // Gather dimensions information, and perform consistency checks.
  const int outer_size =
      MatchingFlatSizeSkipDim(input_dims, 0, prev_activ_dims, prevCellState_dims,
                              output_state_dims, output_activ_dims);
  TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1);
  TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1);
  const int input_depth = ArraySize(input_dims, 0);
  const int prev_activ_depth = ArraySize(prev_activ_dims, 0);
  const int total_input_depth = prev_activ_depth + input_depth;
  TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth);
  TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3),
                  1);
  const int intern_activ_depth =
      MatchingArraySize(weights_dims, 1, bias_dims, 0);
  TFLITE_CHECK_EQ(intern_activ_depth % 4, 0);
  const int output_depth =
      MatchingArraySize(prevCellState_dims, 0, prev_activ_dims, 0,
                        output_state_dims, 0, output_activ_dims, 0);
  TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4);
  const int fc_batches = FlatSizeSkipDim(activ_temp_dims, 0);
  const int fc_output_depth =
      MatchingArraySize(weights_dims, 1, activ_temp_dims, 0);
  const int fc_accum_depth = ArraySize(weights_dims, 0);
  TFLITE_CHECK_EQ(fc_output_depth, 4 * output_depth);

  // Depth-concatenate prev_activ and input data together.
  uint8_t const* concat_input_arrays_data[2] = {input_data_uint8,
                                                prev_activ_data_uint8};
  Dims<4> const* concat_input_arrays_dims[2] = {&input_dims, &prev_activ_dims};
  tflite::reference_ops::Concatenation<tflite::FusedActivationFunctionType::kNone, uint8_t>(
      0, concat_input_arrays_data, concat_input_arrays_dims, 2,
      concat_temp_data_uint8, concat_temp_dims);

  // Implementation of the fully connected node inside the LSTM cell.
  // The operands are 8-bit integers, the accumulators are internally 32-bit
  // integers, and the output is 16-bit fixed-point with 3 integer bits so
  // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that
  // is explained in the function comment above.
  for (int b = 0; b < fc_batches; ++b) {
    for (int out_c = 0; out_c < fc_output_depth; ++out_c) {
      // Internal accumulation.
      // Initialize accumulator with the bias-value.
      int32_t accum = bias_data_int32[out_c];
      // Accumulation loop.
      for (int d = 0; d < fc_accum_depth; ++d) {
        int16_t input_val = concat_temp_data_uint8[b * fc_accum_depth + d] - 128;
        int16_t weights_val =
            weights_data_uint8[out_c * fc_accum_depth + d] - weights_zero_point;
        accum += input_val * weights_val;
      }
      // Down-scale the final int32 accumulator to the scale used by our
      // (16-bit, using 3 integer bits) fixed-point format. The quantized
      // multiplier and shift here have been pre-computed offline
      // (e.g. by toco).
      accum =
          tflite::MultiplyByQuantizedMultiplier(accum, accum_multiplier, accum_shift);
      // Saturate, cast to int16, and store to the temporary activations array.
      accum = std::max(-32768, std::min(32767, accum));
      activ_temp_data_int16[out_c + fc_output_depth * b] = accum;
    }
  }

  // Rest of the LSTM cell: tanh and logistic math functions, and some adds
  // and muls, all done in 16-bit fixed-point.
  for (int b = 0; b < outer_size; ++b) {
    for (int c = 0; c < output_depth; ++c) {
      // Define the fixed-point data types that we will use here. All use
      // int16 as the underlying integer type i.e. all are 16-bit fixed-point.
      // They only differ by the number of integral vs. fractional bits,
      // determining the range of values that they can represent.
      //
      // F0 uses 0 integer bits, range [-1, 1].
      // This is the return type of math functions such as tanh, logistic,
      // whose range is in [-1, 1].
      using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
      // F3 uses 3 integer bits, range [-8, 8].
      // This is the range of the previous fully-connected node's output,
      // which is our input here.
      using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
      // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
      // 2^StateIntegerBits]. It's used to represent the internal state, whose
      // number of integer bits is currently dictated by the model. See comment
      // on the StateIntegerBits template parameter above.
      using FS = gemmlowp::FixedPoint<std::int16_t, StateIntegerBits>;
      // Implementation of input gate, using fixed-point logistic function.
      F3 input_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 0 * output_depth + c]);
      F0 input_gate_output = gemmlowp::logistic(input_gate_input);
      // Implementation of input modulation gate, using fixed-point tanh
      // function.
      F3 input_modulation_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 1 * output_depth + c]);
      F0 input_modulation_gate_output =
          gemmlowp::tanh(input_modulation_gate_input);
      // Implementation of forget gate, using fixed-point logistic function.
      F3 forget_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 2 * output_depth + c]);
      F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
      // Implementation of output gate, using fixed-point logistic function.
      F3 output_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 3 * output_depth + c]);
      F0 output_gate_output = gemmlowp::logistic(output_gate_input);
      // Implementation of internal multiplication nodes, still in fixed-point.
      F0 input_times_input_modulation =
          input_gate_output * input_modulation_gate_output;
      FS prevCellState = FS::FromRaw(prevCellState_data_int16[b * output_depth + c]);
      FS prevCellState_times_forget_state = forget_gate_output * prevCellState;
      // Implementation of internal addition node, saturating.
      FS new_state = gemmlowp::SaturatingAdd(
          gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
          prevCellState_times_forget_state);
      // Implementation of last internal Tanh node, still in fixed-point.
      // Since a Tanh fixed-point implementation is specialized for a given
      // number of integer bits, and each specialization can have a substantial
      // code size, and we already used above a Tanh on an input with 3 integer
      // bits, and per the table in the above function comment there is no
      // significant accuracy to be lost by clamping to [-8, +8] for a
      // 3-integer-bits representation, let us just do that. This helps people
      // porting this to targets where code footprint must be minimized.
      F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
      F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
      // Store the new internal state back to memory, as 16-bit integers.
      // Note: here we store the original value with StateIntegerBits, not
      // the rescaled 3-integer-bits value fed to tanh.
      output_state_data_int16[b * output_depth + c] = new_state.raw();
      // Down-scale the output activations to 8-bit integers, saturating,
      // and store back to memory.
      int16_t rescaled_output_activ =
          gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
      int16_t clamped_output_activ =
          std::max<int16_t>(-128, std::min<int16_t>(127, rescaled_output_activ));
      output_activ_data_uint8[b * output_depth + c] =
          128 + clamped_output_activ;
    }
  }
}
// clang-format on

// The function assigns a 2D matrix to a submatrix of the weights at given row
// and column offsets.
void assignWeightsSubmatrix(const RunTimeOperandInfo* submatrix, const int32_t offset_row,
                            const int32_t offset_column, const std::vector<uint32_t>& weightsDims,
                            uint8_t* weights) {
    const uint8_t* submatrixValues = GetBuffer<uint8_t>(submatrix);
    const std::vector<uint32_t> submatrixDims = submatrix->shape().dimensions;
    for (uint32_t i = 0; i < submatrixDims[0] * submatrixDims[1]; ++i) {
        const uint32_t row = i / submatrixDims[1];
        const uint32_t column = i % submatrixDims[1];
        weights[(row + offset_row) * weightsDims[1] + column + offset_column] = submatrixValues[i];
    }
}

}  // namespace

QuantizedLSTMCell::QuantizedLSTMCell(const Operation& operation,
                                     std::vector<RunTimeOperandInfo>& operands) {
    input_ = GetInput(operation, operands, kInputTensor);

    inputToInputWeights_ = GetInput(operation, operands, kInputToInputWeightsTensor);
    inputToForgetWeights_ = GetInput(operation, operands, kInputToForgetWeightsTensor);
    inputToCellWeights_ = GetInput(operation, operands, kInputToCellWeightsTensor);
    inputToOutputWeights_ = GetInput(operation, operands, kInputToOutputWeightsTensor);

    recurrentToInputWeights_ = GetInput(operation, operands, kRecurrentToInputWeightsTensor);
    recurrentToForgetWeights_ = GetInput(operation, operands, kRecurrentToForgetWeightsTensor);
    recurrentToCellWeights_ = GetInput(operation, operands, kRecurrentToCellWeightsTensor);
    recurrentToOutputWeights_ = GetInput(operation, operands, kRecurrentToOutputWeightsTensor);

    inputGateBias_ = GetInput(operation, operands, kInputGateBiasTensor);
    forgetGateBias_ = GetInput(operation, operands, kForgetGateBiasTensor);
    cellGateBias_ = GetInput(operation, operands, kCellGateBiasTensor);
    outputGateBias_ = GetInput(operation, operands, kOutputGateBiasTensor);

    prevCellState_ = GetInput(operation, operands, kPrevCellStateTensor);
    prevOutput_ = GetInput(operation, operands, kPrevOutputTensor);

    cellStateOut_ = GetOutput(operation, operands, kCellStateOutTensor);
    output_ = GetOutput(operation, operands, kOutputTensor);
}

bool QuantizedLSTMCell::prepare(const Operation& operation,
                                std::vector<RunTimeOperandInfo>& operands, Shape* cellStateOutShape,
                                Shape* outputShape) {
    auto input = GetInput(operation, operands, kInputTensor);
    NN_RET_CHECK_EQ(NumDimensions(input), 2);
    NN_RET_CHECK_EQ(input->scale, 1. / 128.0);
    NN_RET_CHECK_EQ(input->zeroPoint, 128);
    const uint32_t numBatches = SizeOfDimension(input, 0);
    const uint32_t inputSize = SizeOfDimension(input, 1);

    auto prevOutput = GetInput(operation, operands, kPrevOutputTensor);
    NN_RET_CHECK_EQ(NumDimensions(prevOutput), 2);
    NN_RET_CHECK_EQ(SizeOfDimension(prevOutput, 0), numBatches);
    NN_RET_CHECK_EQ(prevOutput->scale, 1. / 128.0);
    NN_RET_CHECK_EQ(prevOutput->zeroPoint, 128);
    const uint32_t outputSize = SizeOfDimension(prevOutput, 1);

    auto inputToInputWeights = GetInput(operation, operands, kInputToInputWeightsTensor);
    const float weightsScale = inputToInputWeights->scale;
    NN_RET_CHECK(weightsScale != 0);
    const float weightsZeroPoint = inputToInputWeights->zeroPoint;

    auto checkWeightsShape = [&](const RunTimeOperandInfo* weights, uint32_t columns) -> bool {
        NN_RET_CHECK_EQ(NumDimensions(weights), 2);
        NN_RET_CHECK_EQ(SizeOfDimension(weights, 0), outputSize);
        NN_RET_CHECK_EQ(SizeOfDimension(weights, 1), columns);
        NN_RET_CHECK_EQ(weights->scale, weightsScale);
        NN_RET_CHECK_EQ(weights->zeroPoint, weightsZeroPoint);
        return true;
    };

    auto inputToForgetWeights = GetInput(operation, operands, kInputToForgetWeightsTensor);
    auto inputToCellWeights = GetInput(operation, operands, kInputToCellWeightsTensor);
    auto inputToOutputWeights = GetInput(operation, operands, kInputToOutputWeightsTensor);
    NN_RET_CHECK(checkWeightsShape(inputToInputWeights, inputSize));
    NN_RET_CHECK(checkWeightsShape(inputToForgetWeights, inputSize));
    NN_RET_CHECK(checkWeightsShape(inputToCellWeights, inputSize));
    NN_RET_CHECK(checkWeightsShape(inputToOutputWeights, inputSize));

    auto recurrentToInputWeights = GetInput(operation, operands, kRecurrentToInputWeightsTensor);
    auto recurrentToForgetWeights = GetInput(operation, operands, kRecurrentToForgetWeightsTensor);
    auto recurrentToCellWeights = GetInput(operation, operands, kRecurrentToCellWeightsTensor);
    auto recurrentToOutputWeights = GetInput(operation, operands, kRecurrentToOutputWeightsTensor);
    NN_RET_CHECK(checkWeightsShape(recurrentToInputWeights, outputSize));
    NN_RET_CHECK(checkWeightsShape(recurrentToForgetWeights, outputSize));
    NN_RET_CHECK(checkWeightsShape(recurrentToCellWeights, outputSize));
    NN_RET_CHECK(checkWeightsShape(recurrentToOutputWeights, outputSize));

    auto inputGateBias = GetInput(operation, operands, kInputGateBiasTensor);
    const float biasScale = inputGateBias->scale;
    NN_RET_CHECK_EQ(biasScale, weightsScale / 128.0);
    const float biasZeroPoint = inputGateBias->zeroPoint;
    NN_RET_CHECK_EQ(biasZeroPoint, 0);

    auto checkBiasShape = [&](const RunTimeOperandInfo* bias) -> bool {
        NN_RET_CHECK_EQ(NumDimensions(bias), 1);
        NN_RET_CHECK_EQ(SizeOfDimension(bias, 0), outputSize);
        NN_RET_CHECK_EQ(bias->scale, biasScale);
        NN_RET_CHECK_EQ(bias->zeroPoint, biasZeroPoint);
        return true;
    };

    auto forgetGateBias = GetInput(operation, operands, kForgetGateBiasTensor);
    auto cellGateBias = GetInput(operation, operands, kCellGateBiasTensor);
    auto outputGateBias = GetInput(operation, operands, kOutputGateBiasTensor);
    NN_RET_CHECK(checkBiasShape(inputGateBias));
    NN_RET_CHECK(checkBiasShape(forgetGateBias));
    NN_RET_CHECK(checkBiasShape(cellGateBias));
    NN_RET_CHECK(checkBiasShape(outputGateBias));

    auto prevCellState = GetInput(operation, operands, kPrevCellStateTensor);
    NN_CHECK_EQ(NumDimensions(prevCellState), 2);
    NN_CHECK_EQ(SizeOfDimension(prevCellState, 0), numBatches);
    NN_CHECK_EQ(SizeOfDimension(prevCellState, 1), outputSize);
    NN_CHECK_EQ(prevCellState->zeroPoint, 0);
    // The cell state range for quantized LSTM is a function of StateIntegerBits and
    // can be calculated as:
    // [-2^StateIntegerBits, 2^StateIntegerBits * 32767/32768].
    // Therefore, for a fixed StateIntegerBits parameter, the cell state scale is
    // equal to 2^StateIntegerBits * 2^(-15) = 2^(StateIntegerBits - 15) and
    // therefore:
    // StateIntegerBits = log2(cell state scale) + 15
    int stateScaleLog2Rounded;
    NN_CHECK(tflite::CheckedLog2(prevCellState->scale, &stateScaleLog2Rounded));
    const int stateIntegerBits = 15 + stateScaleLog2Rounded;
    // We only support StateIntegerBits == 4.
    NN_CHECK(stateIntegerBits == 4);

    *cellStateOutShape = prevCellState->shape();
    *outputShape = prevOutput->shape();
    return true;
}

// The function concatenates 8 input weight matrices into one. The resulting
// matrix has shape [4 * outputSize, outputSize + inputSize]. The matrix is
// constructed as follows:
// +-----------------------------------+
// | recurrentToInput  | inputToInput  |
// |-------------------+---------------|
// | recurrentToCell   | inputToCell   |
// |-------------------+---------------|
// | recurrentToForget | inputToForget |
// |-------------------+---------------|
// | recurrentToOutput | inputToOutput |
// +-----------------------------------+
void QuantizedLSTMCell::concatenateWeights(const std::vector<uint32_t>& weightsDims,
                                           uint8_t* weights) {
    const int outputSize = SizeOfDimension(inputToInputWeights_, 0);

    assignWeightsSubmatrix(inputToInputWeights_, 0 * outputSize, outputSize, weightsDims, weights);
    assignWeightsSubmatrix(inputToCellWeights_, 1 * outputSize, outputSize, weightsDims, weights);
    assignWeightsSubmatrix(inputToForgetWeights_, 2 * outputSize, outputSize, weightsDims, weights);
    assignWeightsSubmatrix(inputToOutputWeights_, 3 * outputSize, outputSize, weightsDims, weights);
    assignWeightsSubmatrix(recurrentToInputWeights_, 0 * outputSize, 0, weightsDims, weights);
    assignWeightsSubmatrix(recurrentToCellWeights_, 1 * outputSize, 0, weightsDims, weights);
    assignWeightsSubmatrix(recurrentToForgetWeights_, 2 * outputSize, 0, weightsDims, weights);
    assignWeightsSubmatrix(recurrentToOutputWeights_, 3 * outputSize, 0, weightsDims, weights);
}

// The function concatenates four bias vectors of shape [outputSize] into one
// vector of shape [4 * outputSize].
void QuantizedLSTMCell::concatenateBiases(uint32_t outputSize, int32_t* bias) {
    memcpy(bias + 0 * outputSize, GetBuffer<int32_t>(inputGateBias_), sizeof(int32_t) * outputSize);
    memcpy(bias + 1 * outputSize, GetBuffer<int32_t>(cellGateBias_), sizeof(int32_t) * outputSize);
    memcpy(bias + 2 * outputSize, GetBuffer<int32_t>(forgetGateBias_),
           sizeof(int32_t) * outputSize);
    memcpy(bias + 3 * outputSize, GetBuffer<int32_t>(outputGateBias_),
           sizeof(int32_t) * outputSize);
}

bool QuantizedLSTMCell::eval() {
    NNTRACE_COMP("QuantizedLSTM::eval");

    Shape weightsShape;
    weightsShape.dimensions = {4 * SizeOfDimension(prevOutput_, 1),
                               SizeOfDimension(input_, 1) + SizeOfDimension(prevOutput_, 1)};
    std::vector<uint8_t> weights(getNumberOfElements(weightsShape));
    concatenateWeights(weightsShape.dimensions, weights.data());

    Shape biasShape;
    biasShape.dimensions = {getSizeOfDimension(weightsShape, 0)};
    std::vector<int32_t> bias(getNumberOfElements(biasShape));
    concatenateBiases(SizeOfDimension(prevOutput_, 1), bias.data());

    Shape concatTempShape;
    concatTempShape.dimensions = {SizeOfDimension(input_, 0), getSizeOfDimension(weightsShape, 1)};

    Shape activationTempShape;
    activationTempShape.dimensions = {SizeOfDimension(input_, 0),
                                      getSizeOfDimension(weightsShape, 0)};

    std::vector<uint8_t> concatTemp(getNumberOfElements(concatTempShape));
    std::vector<int16_t> activationTemp(getNumberOfElements(activationTempShape));

    // From https://arxiv.org/pdf/1712.05877, for a fully-connected layer the
    // accumulator multiplier is equal to:
    // (input scale) * (weights scale) / (fully-connected output scale)
    // In our case the fully-connected output scale is fixed and equal to
    // 2^(-12) (see the LSTMCell definition in TF Lite for more details on that).
    // But the bias scale is set to (input scale) * (weights scale) (also from the
    // paper), so we can multiply it by the inverse of the fc-output scale,
    // 2^12 = 4096, to get the multiplier value:
    double realAccumMultiplier = 4096 * inputGateBias_->scale;
    int32_t accumMultiplier;
    int accumShift;
    tflite::QuantizeMultiplier(realAccumMultiplier, &accumMultiplier, &accumShift);
    quantizedLstmStep<4>(
            // Inputs.
            GetBuffer<const uint8_t>(input_), convertShapeToDims(input_->shape()),
            GetBuffer<const uint8_t>(prevOutput_), convertShapeToDims(prevOutput_->shape()),
            weights.data(), convertShapeToDims(weightsShape), bias.data(),
            convertShapeToDims(biasShape), GetBuffer<const int16_t>(prevCellState_),
            convertShapeToDims(prevCellState_->shape()),
            // Outputs.
            GetBuffer<int16_t>(cellStateOut_), convertShapeToDims(cellStateOut_->shape()),
            GetBuffer<uint8_t>(output_), convertShapeToDims(output_->shape()), concatTemp.data(),
            convertShapeToDims(concatTempShape), activationTemp.data(),
            convertShapeToDims(activationTempShape), inputToInputWeights_->zeroPoint,
            accumMultiplier, accumShift);
    return true;
}

}  // namespace nn
}  // namespace android