/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "LSTM.h"

#include "CpuExecutor.h"
#include "CpuOperationUtils.h"
#include "HalInterfaces.h"
#include "OperationsUtils.h"

#include "Tracing.h"
#include "Utils.h"

namespace android {
namespace nn {

namespace {

template <typename T>
inline T* GetBuffer(RunTimeOperandInfo* operand) {
    return reinterpret_cast<T*>(operand->buffer);
}

template <typename T>
inline const T* GetBuffer(const RunTimeOperandInfo* operand) {
    return reinterpret_cast<const T*>(operand->buffer);
}

template <typename T>
inline const T* GetOptionalBuffer(const RunTimeOperandInfo* operand) {
    return !IsNullInput(operand) ? reinterpret_cast<const T*>(operand->buffer) : nullptr;
}

}  // anonymous namespace

LSTMCell::LSTMCell(const Operation& operation, std::vector<RunTimeOperandInfo>& operands) {
    input_ = GetInput(operation, operands, kInputTensor);

    input_to_input_weights_ =
            GetInput(operation, operands, kInputToInputWeightsTensor);  // optional
    input_to_forget_weights_ = GetInput(operation, operands, kInputToForgetWeightsTensor);
    input_to_cell_weights_ = GetInput(operation, operands, kInputToCellWeightsTensor);
    input_to_output_weights_ = GetInput(operation, operands, kInputToOutputWeightsTensor);

    recurrent_to_input_weights_ =
            GetInput(operation, operands, kRecurrentToInputWeightsTensor);  // optional
    recurrent_to_forget_weights_ = GetInput(operation, operands, kRecurrentToForgetWeightsTensor);
    recurrent_to_cell_weights_ = GetInput(operation, operands, kRecurrentToCellWeightsTensor);
    recurrent_to_output_weights_ = GetInput(operation, operands, kRecurrentToOutputWeightsTensor);

    cell_to_input_weights_ = GetInput(operation, operands, kCellToInputWeightsTensor);  // optional
    cell_to_forget_weights_ =
            GetInput(operation, operands, kCellToForgetWeightsTensor);  // optional
    cell_to_output_weights_ =
            GetInput(operation, operands, kCellToOutputWeightsTensor);  // optional

    input_gate_bias_ = GetInput(operation, operands, kInputGateBiasTensor);
    forget_gate_bias_ = GetInput(operation, operands, kForgetGateBiasTensor);
    cell_bias_ = GetInput(operation, operands, kCellGateBiasTensor);
    output_gate_bias_ = GetInput(operation, operands, kOutputGateBiasTensor);

    projection_weights_ = GetInput(operation, operands, kProjectionWeightsTensor);  // optional
    projection_bias_ = GetInput(operation, operands, kProjectionBiasTensor);        // optional

    output_state_in_ = GetInput(operation, operands, kOutputStateInTensor);
    cell_state_in_ = GetInput(operation, operands, kCellStateInTensor);
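
    // The remaining inputs are scalars: the fused activation is always read as an int32,
    // while cell_clip and proj_clip are read as FLOAT32 or FLOAT16 to match the input
    // tensor's type and then kept as float internally.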
    params_.activation = static_cast<TfLiteFusedActivation>(
            getScalarData<int32_t>(*GetInput(operation, operands, kActivationParam)));
    if (input_->type == OperandType::TENSOR_FLOAT32) {
        params_.cell_clip = getScalarData<float>(*GetInput(operation, operands, kCellClipParam));
        params_.proj_clip = getScalarData<float>(*GetInput(operation, operands, kProjClipParam));
    } else {
        params_.cell_clip = static_cast<float>(
                getScalarData<_Float16>(*GetInput(operation, operands, kCellClipParam)));
        params_.proj_clip = static_cast<float>(
                getScalarData<_Float16>(*GetInput(operation, operands, kProjClipParam)));
    }

    // We check the version of LSTM by checking the number of inputs to the op.
    // For LSTM version 1.0 there were 23 inputs and for 1.2 there are 27.
    if (operation.inputs.size() == 27) {
        input_layer_norm_weights_ =
                GetInput(operation, operands, kInputLayerNormWeightsTensor);  // optional
        forget_layer_norm_weights_ =
                GetInput(operation, operands, kForgetLayerNormWeightsTensor);  // optional
        cell_layer_norm_weights_ =
                GetInput(operation, operands, kCellLayerNormWeightsTensor);  // optional
        output_layer_norm_weights_ =
                GetInput(operation, operands, kOutputLayerNormWeightsTensor);  // optional
    } else {
        // For LSTM from HAL v1.0, assign operands with no values.
        static RunTimeOperandInfo no_value;
        no_value.lifetime = OperandLifeTime::NO_VALUE;

        input_layer_norm_weights_ = &no_value;
        forget_layer_norm_weights_ = &no_value;
        cell_layer_norm_weights_ = &no_value;
        output_layer_norm_weights_ = &no_value;
    }

    output_state_out_ = GetOutput(operation, operands, kOutputStateOutTensor);
    cell_state_out_ = GetOutput(operation, operands, kCellStateOutTensor);
    output_ = GetOutput(operation, operands, kOutputTensor);

    scratch_buffer_ = GetOutput(operation, operands, kScratchBufferTensor);
}

// static
bool LSTMCell::CheckInputTensorDimensions(
        const RunTimeOperandInfo* input_, const RunTimeOperandInfo* input_to_input_weights,
        const RunTimeOperandInfo* input_to_forget_weights,
        const RunTimeOperandInfo* input_to_cell_weights,
        const RunTimeOperandInfo* input_to_output_weights,
        const RunTimeOperandInfo* recurrent_to_input_weights,
        const RunTimeOperandInfo* recurrent_to_forget_weights,
        const RunTimeOperandInfo* recurrent_to_cell_weights,
        const RunTimeOperandInfo* recurrent_to_output_weights,
        const RunTimeOperandInfo* cell_to_input_weights,
        const RunTimeOperandInfo* cell_to_forget_weights,
        const RunTimeOperandInfo* cell_to_output_weights, const RunTimeOperandInfo* input_gate_bias,
        const RunTimeOperandInfo* forget_gate_bias, const RunTimeOperandInfo* cell_bias,
        const RunTimeOperandInfo* output_gate_bias, const RunTimeOperandInfo* projection_weights,
        const RunTimeOperandInfo* projection_bias,
        const RunTimeOperandInfo* input_layer_norm_weights,
        const RunTimeOperandInfo* forget_layer_norm_weights,
        const RunTimeOperandInfo* cell_layer_norm_weights,
        const RunTimeOperandInfo* output_layer_norm_weights, uint32_t n_input, uint32_t n_output,
        uint32_t n_cell, LSTMParams* params) {
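    // Beyond the per-tensor rank/size checks below, this routine enforces the all-or-none
    // invariants of the optional features: CIFG (input-gate weights), peephole
    // (cell-to-gate weights), projection (bias only together with weights), and layer
    // normalization must each be consistently present or absent.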
    // Making sure clipping parameters have valid values.
    // == 0 means no clipping
    //  > 0 means clipping
    NN_CHECK(params->cell_clip >= 0);
    NN_CHECK(params->proj_clip >= 0);

    if (!IsNullInput(input_to_input_weights)) {
        NN_CHECK_EQ(NumDimensions(input_to_input_weights), 2);
        NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 0), n_cell);
        NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 1), n_input);
    }

    NN_CHECK_EQ(NumDimensions(input_to_forget_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 1), n_input);

    NN_CHECK_EQ(NumDimensions(input_to_cell_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 1), n_input);

    if (!IsNullInput(recurrent_to_input_weights)) {
        NN_CHECK_EQ(NumDimensions(recurrent_to_input_weights), 2);
        NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 0), n_cell);
        NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 1), n_output);
    }

    NN_CHECK_EQ(NumDimensions(recurrent_to_forget_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 1), n_output);

    NN_CHECK_EQ(NumDimensions(recurrent_to_cell_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 1), n_output);

    // We make sure the input-gate's parameters are either both present (regular
    // LSTM) or not at all (CIFG-LSTM).
    const bool cifg_weights_all_or_none =
            (!IsNullInput(input_to_input_weights) && !IsNullInput(recurrent_to_input_weights)) ||
            (IsNullInput(input_to_input_weights) && IsNullInput(recurrent_to_input_weights));
    NN_CHECK(cifg_weights_all_or_none);

    if (!IsNullInput(cell_to_input_weights)) {
        NN_CHECK_EQ(NumDimensions(cell_to_input_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(cell_to_input_weights, 0), n_cell);
    }

    if (!IsNullInput(cell_to_forget_weights)) {
        NN_CHECK_EQ(NumDimensions(cell_to_forget_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(cell_to_forget_weights, 0), n_cell);
    }

    if (!IsNullInput(cell_to_output_weights)) {
        NN_CHECK_EQ(NumDimensions(cell_to_output_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(cell_to_output_weights, 0), n_cell);
    }

    // Making sure the peephole weights are there all or none.
    params->use_cifg = IsNullInput(input_to_input_weights);
    const bool peephole_weights_all_or_none =
            ((!IsNullInput(cell_to_input_weights) || params->use_cifg) &&
             !IsNullInput(cell_to_forget_weights) && !IsNullInput(cell_to_output_weights)) ||
            (IsNullInput(cell_to_input_weights) && IsNullInput(cell_to_forget_weights) &&
             IsNullInput(cell_to_output_weights));
    NN_CHECK(peephole_weights_all_or_none);

    // Since we have already checked that the weights are all there or none, we
    // can check the existence of only one to get the condition.
    params->use_peephole = !IsNullInput(cell_to_output_weights);
    // Checking the output instead of the input layer norm weights because the
    // input ones can be omitted when CIFG LSTM is used.
    params->use_layer_norm = !IsNullInput(output_layer_norm_weights);

    params->use_projection_weight = (projection_weights->lifetime != OperandLifeTime::NO_VALUE);
    params->use_projection_bias = (projection_bias->lifetime != OperandLifeTime::NO_VALUE);

    // Make sure the input gate bias is present only when not a CIFG-LSTM.
    if (params->use_cifg) {
        NN_CHECK(IsNullInput(input_gate_bias));
    } else {
        NN_CHECK_EQ(NumDimensions(input_gate_bias), 1);
        NN_CHECK_EQ(SizeOfDimension(input_gate_bias, 0), n_cell);
    }

    NN_CHECK_EQ(NumDimensions(forget_gate_bias), 1);
    NN_CHECK_EQ(SizeOfDimension(forget_gate_bias, 0), n_cell);

    NN_CHECK_EQ(NumDimensions(cell_bias), 1);
    NN_CHECK_EQ(SizeOfDimension(cell_bias, 0), n_cell);

    NN_CHECK_EQ(NumDimensions(output_gate_bias), 1);
    NN_CHECK_EQ(SizeOfDimension(output_gate_bias, 0), n_cell);

    if (!IsNullInput(projection_weights)) {
        NN_CHECK_EQ(NumDimensions(projection_weights), 2);
        NN_CHECK_EQ(SizeOfDimension(projection_weights, 0), n_output);
        NN_CHECK_EQ(SizeOfDimension(projection_weights, 1), n_cell);
    }

    if (!IsNullInput(projection_bias)) {
        NN_CHECK_EQ(NumDimensions(projection_bias), 1);
        NN_CHECK_EQ(SizeOfDimension(projection_bias, 0), n_output);
    }

    // Making sure the projection tensors are consistent:
    // 1) If the projection weight is not present, then the projection bias
    //    should not be present.
    // 2) If the projection weight is present, then the projection bias is
    //    optional.
    // TODO: make sure this is correct.
    const bool projection_tensors_consistent =
            (!IsNullInput(projection_weights) || IsNullInput(projection_bias));
    NN_CHECK(projection_tensors_consistent);

    if (!IsNullInput(input_layer_norm_weights)) {
        NN_CHECK_EQ(NumDimensions(input_layer_norm_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(input_layer_norm_weights, 0), n_cell);
    }
    if (!IsNullInput(forget_layer_norm_weights)) {
        NN_CHECK_EQ(NumDimensions(forget_layer_norm_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(forget_layer_norm_weights, 0), n_cell);
    }
    if (!IsNullInput(cell_layer_norm_weights)) {
        NN_CHECK_EQ(NumDimensions(cell_layer_norm_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(cell_layer_norm_weights, 0), n_cell);
    }
    if (!IsNullInput(output_layer_norm_weights)) {
        NN_CHECK_EQ(NumDimensions(output_layer_norm_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(output_layer_norm_weights, 0), n_cell);
    }

    if (params->use_cifg) {
        NN_RET_CHECK(IsNullInput(input_layer_norm_weights))
                << "input_layer_norm_weights are provided while CIFG is used";
        const bool layer_norm_weights_all_or_none_cifg =
                (IsNullInput(forget_layer_norm_weights) && IsNullInput(cell_layer_norm_weights) &&
                 IsNullInput(output_layer_norm_weights)) ||
                (!IsNullInput(forget_layer_norm_weights) && !IsNullInput(cell_layer_norm_weights) &&
                 !IsNullInput(output_layer_norm_weights));
        NN_RET_CHECK(layer_norm_weights_all_or_none_cifg);
    } else {
        const bool layer_norm_weights_all_or_none =
                (IsNullInput(input_layer_norm_weights) && IsNullInput(forget_layer_norm_weights) &&
                 IsNullInput(cell_layer_norm_weights) && IsNullInput(output_layer_norm_weights)) ||
                (!IsNullInput(input_layer_norm_weights) &&
                 !IsNullInput(forget_layer_norm_weights) && !IsNullInput(cell_layer_norm_weights) &&
                 !IsNullInput(output_layer_norm_weights));
        NN_RET_CHECK(layer_norm_weights_all_or_none);
    }

    return true;
}
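
// Prepare infers n_batch and n_input from the input tensor, n_cell from the
// input-to-output weights, and n_output from the recurrent-to-output weights; it then
// validates every operand against those sizes and fills in the shapes of the four
// outputs.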
bool LSTMCell::Prepare(const Operation& operation, std::vector<RunTimeOperandInfo>& operands,
                       Shape* scratchShape, Shape* outputStateShape, Shape* cellStateShape,
                       Shape* outputShape) {
    // Check we have all the inputs and outputs we need.
    NN_CHECK(NumInputsWithValues(operation, operands) >= 15 &&
             NumInputsWithValues(operation, operands) <= 27);
    NN_CHECK_EQ(NumOutputs(operation), 4);

    // Inferring batch size, number of outputs and number of cells from the
    // input tensors.
    NN_CHECK(NumDimensions(input_) > 1);
    const uint32_t n_batch = SizeOfDimension(input_, 0);
    const uint32_t n_input = SizeOfDimension(input_, 1);

    const uint32_t n_cell = SizeOfDimension(input_to_output_weights_, 0);
    NN_CHECK_EQ(NumDimensions(input_to_output_weights_), 2);
    NN_CHECK_EQ(SizeOfDimension(input_to_output_weights_, 1), n_input);

    NN_CHECK_EQ(NumDimensions(recurrent_to_output_weights_), 2);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_output_weights_, 0), n_cell);
    const uint32_t n_output = SizeOfDimension(recurrent_to_output_weights_, 1);

    // Check that the input tensor dimensions match each other.
    if (!CheckInputTensorDimensions(
                input_, input_to_input_weights_, input_to_forget_weights_, input_to_cell_weights_,
                input_to_output_weights_, recurrent_to_input_weights_, recurrent_to_forget_weights_,
                recurrent_to_cell_weights_, recurrent_to_output_weights_, cell_to_input_weights_,
                cell_to_forget_weights_, cell_to_output_weights_, input_gate_bias_,
                forget_gate_bias_, cell_bias_, output_gate_bias_, projection_weights_,
                projection_bias_, input_layer_norm_weights_, forget_layer_norm_weights_,
                cell_layer_norm_weights_, output_layer_norm_weights_, n_input, n_output, n_cell,
                &params_)) {
        return false;
    }

    // Resize the output and output_state tensors.
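    // (output and output_state are both [n_batch, n_output], cell_state is
    // [n_batch, n_cell], and the scratch buffer reserves one n_cell-wide slice per gate
    // computed in LSTMStep.)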
    const Shape& inputShape = input_->shape();

    outputShape->type = inputShape.type;
    outputShape->dimensions = {n_batch, n_output};
    outputShape->offset = inputShape.offset;
    outputShape->scale = inputShape.scale;

    outputStateShape->type = inputShape.type;
    outputStateShape->dimensions = {n_batch, n_output};
    outputStateShape->offset = inputShape.offset;
    outputStateShape->scale = inputShape.scale;

    cellStateShape->type = inputShape.type;
    cellStateShape->dimensions = {n_batch, n_cell};
    cellStateShape->offset = inputShape.offset;
    cellStateShape->scale = inputShape.scale;

    if (params_.use_cifg) {
        // Reserving space for Cell, Forget, Output gates
        scratchShape->dimensions = {n_batch, n_cell * 3};
    } else {
        // Reserving space for Input, Cell, Forget, Output gates
        scratchShape->dimensions = {n_batch, n_cell * 4};
    }
    scratchShape->type = inputShape.type;
    scratchShape->offset = inputShape.offset;
    scratchShape->scale = inputShape.scale;

    return true;
}

// static
bool LSTMCell::LSTMEvalFloat32(
        const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
        const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
        const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
        const Shape& input_to_output_weights_shape, const float* recurrent_to_input_weights_buffer,
        const float* recurrent_to_forget_weights_buffer,
        const float* recurrent_to_cell_weights_buffer,
        const float* recurrent_to_output_weights_buffer,
        const Shape& recurrent_to_output_weights_shape, const float* cell_to_input_weights_buffer,
        const float* cell_to_forget_weights_buffer, const float* cell_to_output_weights_buffer,
        const float* aux_input_buffer, const float* aux_input_to_input_weights_buffer,
        const float* aux_input_to_forget_weights_buffer,
        const float* aux_input_to_cell_weights_buffer,
        const float* aux_input_to_output_weights_buffer, const float* input_gate_bias_buffer,
        const float* forget_gate_bias_buffer, const float* cell_bias_buffer,
        const float* output_gate_bias_buffer, const float* projection_weights_buffer,
        const float* projection_bias_buffer, const float* output_state_in_buffer,
        const float* cell_state_in_buffer, const float* input_layer_norm_weights_buffer,
        const float* forget_layer_norm_weights_buffer, const float* cell_layer_norm_weights_buffer,
        const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
        float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer,
        bool timeMajor, bool forwardSequence) {
    NNTRACE_COMP("LSTMCell::LSTMEvalFloat32");

    const uint32_t inputRank = getNumberOfDimensions(input_shape);
    NN_CHECK(inputRank == 2 || inputRank == 3);
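
    // A 3-D input is a sequence: time-major means [maxTime, batchSize, inputSize] and
    // batch-major means [batchSize, maxTime, inputSize]. A 2-D input is a single step,
    // so maxTime is 1.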
    const uint32_t maxTime =
            (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 0 : 1) : 1;
    const uint32_t batchSize = (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 1 : 0)
                                                : getSizeOfDimension(input_shape, 0);
    const uint32_t inputSize = getSizeOfDimension(input_shape, inputRank - 1);
    const uint32_t numCells = getSizeOfDimension(input_to_output_weights_shape, 0);
    const uint32_t outputSize = getSizeOfDimension(recurrent_to_output_weights_shape, 1);

    Shape batchInputShape = input_shape;
    batchInputShape.dimensions = {batchSize, inputSize};
    const uint32_t batchInputSize = batchSize * inputSize;
    const uint32_t batchOutputSize = batchSize * outputSize;

    std::vector<float> transposedInput;
    const bool hasAuxInput = (aux_input_buffer != nullptr);
    std::vector<float> transposedAuxInput;
    std::vector<float> transposedOutput;
    Shape transposedInputShape;
    Shape transposedOutputShape;
    if (!timeMajor) {
        transposedInput.resize(maxTime * batchInputSize);
        transposeFirstTwoDimensions<float>(input_buffer, input_shape, transposedInput.data());
        if (hasAuxInput) {
            transposedAuxInput.resize(maxTime * batchInputSize);
            transposeFirstTwoDimensions<float>(aux_input_buffer, input_shape,
                                               transposedAuxInput.data());
        }
        transposeFirstTwoDimensions(input_shape, &transposedInputShape);
        transposedOutput.resize(maxTime * batchOutputSize);
        transposedOutputShape = transposedInputShape;
        transposedOutputShape.dimensions[2] = outputSize;
    }
    const float* inputData = timeMajor ? input_buffer : transposedInput.data();
    const float* auxInputData =
            hasAuxInput ? (timeMajor ? aux_input_buffer : transposedAuxInput.data()) : nullptr;
    float* outputData = timeMajor ? output_buffer : transposedOutput.data();

    std::vector<float> outputStateInCurrentTimeStep(
            output_state_in_buffer, output_state_in_buffer + batchSize * outputSize);
    std::vector<float> cellStateInCurrentTimeStep(cell_state_in_buffer,
                                                  cell_state_in_buffer + batchSize * numCells);
    const float* inputCurrentTimeStep =
            inputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1));
    const float* auxInputCurrentTimeStep =
            hasAuxInput ? (auxInputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1)))
                        : nullptr;
    float* outputCurrentTimeStep =
            outputData + (forwardSequence ? 0 : batchOutputSize * (maxTime - 1));
    const int batchInputDelta = forwardSequence ? batchInputSize : -batchInputSize;
    const int batchOutputDelta = forwardSequence ? batchOutputSize : -batchOutputSize;
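
    // Walk the sequence one time step at a time. For a backward pass the cursors start
    // at the last time step and the deltas are negative, so the same loop serves both
    // directions.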
    for (int t = 0; t < maxTime; ++t) {
        LSTMStep(params, inputCurrentTimeStep, batchInputShape, input_to_input_weights_buffer,
                 input_to_forget_weights_buffer, input_to_cell_weights_buffer,
                 input_to_output_weights_buffer, input_to_output_weights_shape,
                 recurrent_to_input_weights_buffer, recurrent_to_forget_weights_buffer,
                 recurrent_to_cell_weights_buffer, recurrent_to_output_weights_buffer,
                 recurrent_to_output_weights_shape, cell_to_input_weights_buffer,
                 cell_to_forget_weights_buffer, cell_to_output_weights_buffer,
                 auxInputCurrentTimeStep, aux_input_to_input_weights_buffer,
                 aux_input_to_forget_weights_buffer, aux_input_to_cell_weights_buffer,
                 aux_input_to_output_weights_buffer, input_gate_bias_buffer,
                 forget_gate_bias_buffer, cell_bias_buffer, output_gate_bias_buffer,
                 projection_weights_buffer, projection_bias_buffer,
                 outputStateInCurrentTimeStep.data(), cellStateInCurrentTimeStep.data(),
                 input_layer_norm_weights_buffer, forget_layer_norm_weights_buffer,
                 cell_layer_norm_weights_buffer, output_layer_norm_weights_buffer,
                 output_state_out_buffer, cell_state_out_buffer, outputCurrentTimeStep,
                 scratch_buffer_buffer);
        inputCurrentTimeStep += batchInputDelta;
        if (hasAuxInput) {
            auxInputCurrentTimeStep += batchInputDelta;
        }
        outputCurrentTimeStep += batchOutputDelta;
        outputStateInCurrentTimeStep.assign(output_state_out_buffer,
                                            output_state_out_buffer + batchSize * outputSize);
        cellStateInCurrentTimeStep.assign(cell_state_out_buffer,
                                          cell_state_out_buffer + batchSize * numCells);
    }

    if (!timeMajor) {
        transposeFirstTwoDimensions<float>(transposedOutput.data(), transposedOutputShape,
                                           output_buffer);
    }

    return true;
}

// static
bool LSTMCell::LSTMEvalFloat16(
        const LSTMParams& params, const _Float16* input_buffer, const Shape& input_shape,
        const _Float16* input_to_input_weights_buffer,
        const _Float16* input_to_forget_weights_buffer,
        const _Float16* input_to_cell_weights_buffer,
        const _Float16* input_to_output_weights_buffer, const Shape& input_to_output_weights_shape,
        const _Float16* recurrent_to_input_weights_buffer,
        const _Float16* recurrent_to_forget_weights_buffer,
        const _Float16* recurrent_to_cell_weights_buffer,
        const _Float16* recurrent_to_output_weights_buffer,
        const Shape& recurrent_to_output_weights_shape,
        const _Float16* cell_to_input_weights_buffer, const _Float16* cell_to_forget_weights_buffer,
        const _Float16* cell_to_output_weights_buffer, const _Float16* aux_input_buffer,
        const _Float16* aux_input_to_input_weights_buffer,
        const _Float16* aux_input_to_forget_weights_buffer,
        const _Float16* aux_input_to_cell_weights_buffer,
        const _Float16* aux_input_to_output_weights_buffer, const _Float16* input_gate_bias_buffer,
        const _Float16* forget_gate_bias_buffer, const _Float16* cell_bias_buffer,
        const _Float16* output_gate_bias_buffer, const _Float16* projection_weights_buffer,
        const _Float16* projection_bias_buffer, const _Float16* output_state_in_buffer,
        const _Float16* cell_state_in_buffer, const _Float16* input_layer_norm_weights_buffer,
        const _Float16* forget_layer_norm_weights_buffer,
        const _Float16* cell_layer_norm_weights_buffer,
        const _Float16* output_layer_norm_weights_buffer, _Float16* output_state_out_buffer,
        _Float16* cell_state_out_buffer, _Float16* output_buffer, _Float16* scratch_buffer_buffer,
        bool timeMajor, bool forwardSequence) {
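    // There is no half-precision kernel: every fp16 operand is widened to fp32, the fp32
    // LSTMStep is run for each time step, and the results are narrowed back to fp16 at
    // the end.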
    NNTRACE_COMP("LSTMCell::LSTMEvalFloat16");

    const uint32_t inputRank = getNumberOfDimensions(input_shape);
    NN_CHECK(inputRank == 2 || inputRank == 3);

    const uint32_t maxTime =
            (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 0 : 1) : 1;
    const uint32_t batchSize = (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 1 : 0)
                                                : getSizeOfDimension(input_shape, 0);
    const uint32_t inputSize = getSizeOfDimension(input_shape, inputRank - 1);
    const uint32_t numCells = getSizeOfDimension(input_to_output_weights_shape, 0);
    const uint32_t outputSize = getSizeOfDimension(recurrent_to_output_weights_shape, 1);

    Shape batchInputShape = input_shape;
    batchInputShape.dimensions = {batchSize, inputSize};
    const uint32_t batchInputSize = batchSize * inputSize;
    const uint32_t batchOutputSize = batchSize * outputSize;

    std::vector<float> input_float32(maxTime * batchInputSize);
    convertFloat16ToFloat32(input_buffer, &input_float32);
    std::vector<float> input_to_input_weights_float32(numCells * inputSize);
    if (input_to_input_weights_buffer != nullptr) {
        convertFloat16ToFloat32(input_to_input_weights_buffer, &input_to_input_weights_float32);
    }
    std::vector<float> input_to_forget_weights_float32(numCells * inputSize);
    convertFloat16ToFloat32(input_to_forget_weights_buffer, &input_to_forget_weights_float32);
    std::vector<float> input_to_cell_weights_float32(numCells * inputSize);
    convertFloat16ToFloat32(input_to_cell_weights_buffer, &input_to_cell_weights_float32);
    std::vector<float> input_to_output_weights_float32(numCells * inputSize);
    convertFloat16ToFloat32(input_to_output_weights_buffer, &input_to_output_weights_float32);

    std::vector<float> recurrent_to_input_weights_float32(numCells * outputSize);
    if (recurrent_to_input_weights_buffer != nullptr) {
        convertFloat16ToFloat32(recurrent_to_input_weights_buffer,
                                &recurrent_to_input_weights_float32);
    }
    std::vector<float> recurrent_to_forget_weights_float32(numCells * outputSize);
    convertFloat16ToFloat32(recurrent_to_forget_weights_buffer,
                            &recurrent_to_forget_weights_float32);
    std::vector<float> recurrent_to_cell_weights_float32(numCells * outputSize);
    convertFloat16ToFloat32(recurrent_to_cell_weights_buffer, &recurrent_to_cell_weights_float32);
    std::vector<float> recurrent_to_output_weights_float32(numCells * outputSize);
    convertFloat16ToFloat32(recurrent_to_output_weights_buffer,
                            &recurrent_to_output_weights_float32);

    std::vector<float> cell_to_input_weights_float32(numCells);
    if (cell_to_input_weights_buffer != nullptr) {
        convertFloat16ToFloat32(cell_to_input_weights_buffer, &cell_to_input_weights_float32);
    }
    std::vector<float> cell_to_forget_weights_float32(numCells);
    if (cell_to_forget_weights_buffer != nullptr) {
        convertFloat16ToFloat32(cell_to_forget_weights_buffer, &cell_to_forget_weights_float32);
    }
    std::vector<float> cell_to_output_weights_float32(numCells);
    if (cell_to_output_weights_buffer != nullptr) {
        convertFloat16ToFloat32(cell_to_output_weights_buffer, &cell_to_output_weights_float32);
    }
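
    // Note: this op's Eval() always passes nullptr for the aux_input* buffers; the
    // parameters exist so the same step kernel can serve sequence-LSTM variants that
    // feed an auxiliary input.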
    std::vector<float> aux_input_float32(maxTime * batchInputSize);
    if (aux_input_buffer != nullptr) {
        convertFloat16ToFloat32(aux_input_buffer, &aux_input_float32);
    }
    std::vector<float> aux_input_to_input_weights_float32(numCells * inputSize);
    if (aux_input_to_input_weights_buffer != nullptr) {
        convertFloat16ToFloat32(aux_input_to_input_weights_buffer,
                                &aux_input_to_input_weights_float32);
    }
    std::vector<float> aux_input_to_forget_weights_float32(numCells * inputSize);
    if (aux_input_to_forget_weights_buffer != nullptr) {
        convertFloat16ToFloat32(aux_input_to_forget_weights_buffer,
                                &aux_input_to_forget_weights_float32);
    }
    std::vector<float> aux_input_to_cell_weights_float32(numCells * inputSize);
    if (aux_input_to_cell_weights_buffer != nullptr) {
        convertFloat16ToFloat32(aux_input_to_cell_weights_buffer,
                                &aux_input_to_cell_weights_float32);
    }
    std::vector<float> aux_input_to_output_weights_float32(numCells * inputSize);
    if (aux_input_to_output_weights_buffer != nullptr) {
        convertFloat16ToFloat32(aux_input_to_output_weights_buffer,
                                &aux_input_to_output_weights_float32);
    }

    std::vector<float> input_gate_bias_float32(numCells);
    if (input_gate_bias_buffer != nullptr) {
        convertFloat16ToFloat32(input_gate_bias_buffer, &input_gate_bias_float32);
    }
    std::vector<float> forget_gate_bias_float32(numCells);
    convertFloat16ToFloat32(forget_gate_bias_buffer, &forget_gate_bias_float32);
    std::vector<float> cell_bias_float32(numCells);
    convertFloat16ToFloat32(cell_bias_buffer, &cell_bias_float32);
    std::vector<float> output_gate_bias_float32(numCells);
    convertFloat16ToFloat32(output_gate_bias_buffer, &output_gate_bias_float32);

    std::vector<float> projection_weights_float32(numCells * outputSize);
    if (projection_weights_buffer != nullptr) {
        convertFloat16ToFloat32(projection_weights_buffer, &projection_weights_float32);
    }
    std::vector<float> projection_bias_float32(outputSize);
    if (projection_bias_buffer != nullptr) {
        convertFloat16ToFloat32(projection_bias_buffer, &projection_bias_float32);
    }

    std::vector<float> input_layer_norm_weights_float32(numCells);
    if (input_layer_norm_weights_buffer != nullptr) {
        convertFloat16ToFloat32(input_layer_norm_weights_buffer,
                                &input_layer_norm_weights_float32);
    }
    std::vector<float> forget_layer_norm_weights_float32(numCells);
    if (forget_layer_norm_weights_buffer != nullptr) {
        convertFloat16ToFloat32(forget_layer_norm_weights_buffer,
                                &forget_layer_norm_weights_float32);
    }
    std::vector<float> cell_layer_norm_weights_float32(numCells);
    if (cell_layer_norm_weights_buffer != nullptr) {
        convertFloat16ToFloat32(cell_layer_norm_weights_buffer, &cell_layer_norm_weights_float32);
    }
    std::vector<float> output_layer_norm_weights_float32(numCells);
    if (output_layer_norm_weights_buffer != nullptr) {
        convertFloat16ToFloat32(output_layer_norm_weights_buffer,
                                &output_layer_norm_weights_float32);
    }

    std::vector<float> output_state_out_float32(batchOutputSize);
    convertFloat16ToFloat32(output_state_out_buffer, &output_state_out_float32);
    std::vector<float> cell_state_out_float32(batchSize * numCells);
    convertFloat16ToFloat32(cell_state_out_buffer, &cell_state_out_float32);

    std::vector<float> output_float32(maxTime * batchOutputSize);
    convertFloat16ToFloat32(output_buffer, &output_float32);
    std::vector<float> scratch_buffer_float32(params.use_cifg ? 3 * batchSize * numCells
                                                              : 4 * batchSize * numCells);
    convertFloat16ToFloat32(scratch_buffer_buffer, &scratch_buffer_float32);

    std::vector<float> transposedInput;
    const bool hasAuxInput = (aux_input_buffer != nullptr);
    std::vector<float> transposedAuxInput;
    std::vector<float> transposedOutput;
    Shape transposedInputShape;
    Shape transposedOutputShape;
    if (!timeMajor) {
        transposedInput.resize(maxTime * batchInputSize);
        transposeFirstTwoDimensions<float>(input_float32.data(), input_shape,
                                           transposedInput.data());
        if (hasAuxInput) {
            transposedAuxInput.resize(maxTime * batchInputSize);
            transposeFirstTwoDimensions<float>(aux_input_float32.data(), input_shape,
                                               transposedAuxInput.data());
        }
        transposeFirstTwoDimensions(input_shape, &transposedInputShape);
        transposedOutput.resize(maxTime * batchOutputSize);
        transposedOutputShape = transposedInputShape;
        transposedOutputShape.dimensions[2] = outputSize;
    }
    const float* inputData = timeMajor ? input_float32.data() : transposedInput.data();
    const float* auxInputData =
            hasAuxInput ? (timeMajor ? aux_input_float32.data() : transposedAuxInput.data())
                        : nullptr;
    float* outputData = timeMajor ? output_float32.data() : transposedOutput.data();

    std::vector<float> outputStateInCurrentTimeStep(batchSize * outputSize);
    convertFloat16ToFloat32(output_state_in_buffer, &outputStateInCurrentTimeStep);
    std::vector<float> cellStateInCurrentTimeStep(batchSize * numCells);
    convertFloat16ToFloat32(cell_state_in_buffer, &cellStateInCurrentTimeStep);

    const float* inputCurrentTimeStep =
            inputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1));
    const float* auxInputCurrentTimeStep =
            hasAuxInput ? (auxInputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1)))
                        : nullptr;
    float* outputCurrentTimeStep =
            outputData + (forwardSequence ? 0 : batchOutputSize * (maxTime - 1));
    const int batchInputDelta = forwardSequence ? batchInputSize : -batchInputSize;
    const int batchOutputDelta = forwardSequence ? batchOutputSize : -batchOutputSize;
    for (int t = 0; t < maxTime; ++t) {
        LSTMStep(params, inputCurrentTimeStep, batchInputShape,
                 input_to_input_weights_float32.data(), input_to_forget_weights_float32.data(),
                 input_to_cell_weights_float32.data(), input_to_output_weights_float32.data(),
                 input_to_output_weights_shape, recurrent_to_input_weights_float32.data(),
                 recurrent_to_forget_weights_float32.data(),
                 recurrent_to_cell_weights_float32.data(),
                 recurrent_to_output_weights_float32.data(), recurrent_to_output_weights_shape,
                 cell_to_input_weights_float32.data(), cell_to_forget_weights_float32.data(),
                 cell_to_output_weights_float32.data(), auxInputCurrentTimeStep,
                 aux_input_to_input_weights_float32.data(),
                 aux_input_to_forget_weights_float32.data(),
                 aux_input_to_cell_weights_float32.data(),
                 aux_input_to_output_weights_float32.data(), input_gate_bias_float32.data(),
                 forget_gate_bias_float32.data(), cell_bias_float32.data(),
                 output_gate_bias_float32.data(), projection_weights_float32.data(),
                 projection_bias_float32.data(), outputStateInCurrentTimeStep.data(),
                 cellStateInCurrentTimeStep.data(), input_layer_norm_weights_float32.data(),
                 forget_layer_norm_weights_float32.data(), cell_layer_norm_weights_float32.data(),
                 output_layer_norm_weights_float32.data(), output_state_out_float32.data(),
                 cell_state_out_float32.data(), outputCurrentTimeStep,
                 scratch_buffer_float32.data());
        inputCurrentTimeStep += batchInputDelta;
        if (hasAuxInput) {
            auxInputCurrentTimeStep += batchInputDelta;
        }
        outputCurrentTimeStep += batchOutputDelta;
        outputStateInCurrentTimeStep = output_state_out_float32;
        cellStateInCurrentTimeStep = cell_state_out_float32;
    }

    if (!timeMajor) {
        transposeFirstTwoDimensions<float>(transposedOutput.data(), transposedOutputShape,
                                           output_float32.data());
    }

    convertFloat32ToFloat16(output_state_out_float32, output_state_out_buffer);
    convertFloat32ToFloat16(cell_state_out_float32, cell_state_out_buffer);
    convertFloat32ToFloat16(output_float32, output_buffer);
    convertFloat32ToFloat16(scratch_buffer_float32, scratch_buffer_buffer);
    return true;
}

// static
bool LSTMCell::LSTMStep(
        const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
        const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
        const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
        const Shape& input_to_output_weights_shape, const float* recurrent_to_input_weights_buffer,
        const float* recurrent_to_forget_weights_buffer,
        const float* recurrent_to_cell_weights_buffer,
        const float* recurrent_to_output_weights_buffer,
        const Shape& recurrent_to_output_weights_shape, const float* cell_to_input_weights_buffer,
        const float* cell_to_forget_weights_buffer, const float* cell_to_output_weights_buffer,
        const float* aux_input_buffer, const float* aux_input_to_input_weights_buffer,
        const float* aux_input_to_forget_weights_buffer,
        const float* aux_input_to_cell_weights_buffer,
        const float* aux_input_to_output_weights_buffer, const float* input_gate_bias_buffer,
        const float* forget_gate_bias_buffer, const float* cell_bias_buffer,
        const float* output_gate_bias_buffer, const float* projection_weights_buffer,
        const float* projection_bias_buffer, const float* output_state_in_buffer,
        const float* cell_state_in_buffer, const float* input_layer_norm_weights_buffer,
        const float* forget_layer_norm_weights_buffer, const float* cell_layer_norm_weights_buffer,
        const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
        float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer) {
    NNTRACE_COMP("LSTMCell::LSTMStep");

    const uint32_t n_batch = input_shape.dimensions[0];
    const uint32_t n_input = input_shape.dimensions[1];
    // n_cell and n_output will be the same size when there is no projection.
    const uint32_t n_cell = input_to_output_weights_shape.dimensions[0];
    const uint32_t n_output = recurrent_to_output_weights_shape.dimensions[1];
    const uint32_t n_aux_input = aux_input_buffer == nullptr ? 0 : n_input;

    // Index the scratch buffer pointers into the global scratch buffer.
    float* input_gate_scratch = nullptr;
    float* cell_scratch = nullptr;
    float* forget_gate_scratch = nullptr;
    float* output_gate_scratch = nullptr;
    if (params.use_cifg) {
        cell_scratch = scratch_buffer_buffer;
        forget_gate_scratch = cell_scratch + n_cell * n_batch;
        output_gate_scratch = cell_scratch + 2 * n_cell * n_batch;
    } else {
        input_gate_scratch = scratch_buffer_buffer;
        cell_scratch = input_gate_scratch + n_cell * n_batch;
        forget_gate_scratch = input_gate_scratch + 2 * n_cell * n_batch;
        output_gate_scratch = input_gate_scratch + 3 * n_cell * n_batch;
    }

    if (!params.use_layer_norm) {
        // Initialize scratch buffers with bias.
        if (!params.use_cifg) {
            tflite::tensor_utils::VectorBatchVectorAssign(input_gate_bias_buffer, n_cell, n_batch,
                                                          input_gate_scratch);
        }
        tflite::tensor_utils::VectorBatchVectorAssign(forget_gate_bias_buffer, n_cell, n_batch,
                                                      forget_gate_scratch);
        tflite::tensor_utils::VectorBatchVectorAssign(cell_bias_buffer, n_cell, n_batch,
                                                      cell_scratch);
        tflite::tensor_utils::VectorBatchVectorAssign(output_gate_bias_buffer, n_cell, n_batch,
                                                      output_gate_scratch);
    } else {
        // Initialize scratch buffers with zeroes.
        if (!params.use_cifg) {
            tflite::tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch);
        }
        tflite::tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch);
        tflite::tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch);
        tflite::tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch);
    }
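
    // What follows is the standard peephole-LSTM cell update, sketched here for reference
    // (".*" is elementwise, g is params.activation; bracketed terms are optional features):
    //   i_t = sigmoid(W_ii x_t + W_hi h_{t-1} [+ w_ci .* c_{t-1}] + b_i)  (CIFG: i_t = 1 - f_t)
    //   f_t = sigmoid(W_if x_t + W_hf h_{t-1} [+ w_cf .* c_{t-1}] + b_f)
    //   c_t = f_t .* c_{t-1} + i_t .* g(W_ic x_t + W_hc h_{t-1} + b_c)
    //   o_t = sigmoid(W_io x_t + W_ho h_{t-1} [+ w_co .* c_t] + b_o)
    //   h_t = o_t .* g(c_t)                                               [then projection]
    // With layer normalization, each gate pre-activation is normalized and the bias is
    // added afterwards, which is why the scratch buffers were zero-initialized above.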

    // For each batch and cell: compute input_weight * input.
    if (!params.use_cifg) {
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                input_to_input_weights_buffer, n_cell, n_input, input_buffer, n_batch,
                input_gate_scratch, /*result_stride=*/1);
    }
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            input_to_forget_weights_buffer, n_cell, n_input, input_buffer, n_batch,
            forget_gate_scratch, /*result_stride=*/1);
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_cell_weights_buffer, n_cell,
                                                              n_input, input_buffer, n_batch,
                                                              cell_scratch, /*result_stride=*/1);
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            input_to_output_weights_buffer, n_cell, n_input, input_buffer, n_batch,
            output_gate_scratch, /*result_stride=*/1);

    // If auxiliary input is available, then compute aux_input_weight * aux_input.
    if (aux_input_buffer != nullptr) {
        if (!params.use_cifg) {
            tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                    aux_input_to_input_weights_buffer, n_cell, n_aux_input, aux_input_buffer,
                    n_batch, input_gate_scratch,
                    /*result_stride=*/1);
        }

        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                aux_input_to_forget_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch,
                forget_gate_scratch, /*result_stride=*/1);
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                aux_input_to_cell_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch,
                cell_scratch, /*result_stride=*/1);
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                aux_input_to_output_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch,
                output_gate_scratch, /*result_stride=*/1);
    }

    // For each batch and cell: compute recurrent_weight * output_state.
    if (!params.use_cifg) {
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                recurrent_to_input_weights_buffer, n_cell, n_output, output_state_in_buffer,
                n_batch, input_gate_scratch,
                /*result_stride=*/1);
    }
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            recurrent_to_forget_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch,
            forget_gate_scratch, /*result_stride=*/1);
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            recurrent_to_cell_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch,
            cell_scratch, /*result_stride=*/1);
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            recurrent_to_output_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch,
            output_gate_scratch, /*result_stride=*/1);

    // For each batch and cell: update the input gate.
    if (!params.use_cifg) {
        if (params.use_peephole) {
            tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(
                    cell_to_input_weights_buffer, n_cell, cell_state_in_buffer, n_batch,
                    input_gate_scratch);
        }
        if (params.use_layer_norm) {
            tflite::tensor_utils::MeanStddevNormalization(input_gate_scratch, input_gate_scratch,
                                                          n_cell, n_batch, kLayerNormEpsilon);
            tflite::tensor_utils::VectorBatchVectorCwiseProduct(input_layer_norm_weights_buffer,
                                                                n_cell, input_gate_scratch, n_batch,
                                                                input_gate_scratch);
            tflite::tensor_utils::VectorBatchVectorAdd(input_gate_bias_buffer, n_cell, n_batch,
                                                       input_gate_scratch);
        }
        tflite::tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
                                                   input_gate_scratch);
    }

    // For each batch and cell: update the forget gate.
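    // forget_gate_scratch already holds W_if x_t + W_hf h_{t-1} (+ aux) + b_f; the
    // optional peephole and layer-norm terms below complete the pre-activation before
    // the sigmoid.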
    if (params.use_peephole) {
        tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(cell_to_forget_weights_buffer,
                                                                      n_cell, cell_state_in_buffer,
                                                                      n_batch, forget_gate_scratch);
    }
    if (params.use_layer_norm) {
        tflite::tensor_utils::MeanStddevNormalization(forget_gate_scratch, forget_gate_scratch,
                                                      n_cell, n_batch, kLayerNormEpsilon);
        tflite::tensor_utils::VectorBatchVectorCwiseProduct(forget_layer_norm_weights_buffer,
                                                            n_cell, forget_gate_scratch, n_batch,
                                                            forget_gate_scratch);
        tflite::tensor_utils::VectorBatchVectorAdd(forget_gate_bias_buffer, n_cell, n_batch,
                                                   forget_gate_scratch);
    }
    tflite::tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
                                               forget_gate_scratch);

    // For each batch and cell: update the cell.
    if (params.use_layer_norm) {
        tflite::tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell, n_batch,
                                                      kLayerNormEpsilon);
        tflite::tensor_utils::VectorBatchVectorCwiseProduct(cell_layer_norm_weights_buffer, n_cell,
                                                            cell_scratch, n_batch, cell_scratch);
        tflite::tensor_utils::VectorBatchVectorAdd(cell_bias_buffer, n_cell, n_batch, cell_scratch);
    }
    tflite::tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_in_buffer,
                                                   n_batch * n_cell, cell_state_out_buffer);
    tflite::tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, params.activation,
                                                  cell_scratch);
    if (params.use_cifg) {
        // CIFG couples the gates: the input gate is computed as (1 - forget gate).
        tflite::tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
                                         forget_gate_scratch);
        tflite::tensor_utils::VectorVectorCwiseProductAccumulate(
                cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_out_buffer);
    } else {
        tflite::tensor_utils::VectorVectorCwiseProductAccumulate(
                cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_out_buffer);
    }
    if (params.cell_clip > 0.0) {
        tflite::tensor_utils::ClipVector(cell_state_out_buffer, n_batch * n_cell, params.cell_clip,
                                         cell_state_out_buffer);
    }

    // For each batch and cell: update the output gate.
    if (params.use_peephole) {
        tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(cell_to_output_weights_buffer,
                                                                      n_cell, cell_state_out_buffer,
                                                                      n_batch, output_gate_scratch);
    }
    if (params.use_layer_norm) {
        tflite::tensor_utils::MeanStddevNormalization(output_gate_scratch, output_gate_scratch,
                                                      n_cell, n_batch, kLayerNormEpsilon);
        tflite::tensor_utils::VectorBatchVectorCwiseProduct(output_layer_norm_weights_buffer,
                                                            n_cell, output_gate_scratch, n_batch,
                                                            output_gate_scratch);
        tflite::tensor_utils::VectorBatchVectorAdd(output_gate_bias_buffer, n_cell, n_batch,
                                                   output_gate_scratch);
    }
    tflite::tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
                                               output_gate_scratch);
    tflite::tensor_utils::ApplyActivationToVector(cell_state_out_buffer, n_batch * n_cell,
                                                  params.activation, cell_scratch);
    tflite::tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
                                                   n_batch * n_cell, output_gate_scratch);

    // For each batch: update the projection and output_state.
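    // With a projection layer, output_state becomes h_t = clip(W_proj (o_t .* g(c_t)) +
    // b_proj, proj_clip); otherwise h_t is o_t .* g(c_t) copied straight from the gate
    // scratch.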
    if (params.use_projection_weight) {
        if (params.use_projection_bias) {
            tflite::tensor_utils::VectorBatchVectorAssign(projection_bias_buffer, n_output, n_batch,
                                                          output_buffer);
        } else {
            tflite::tensor_utils::ZeroVector(output_buffer, n_batch * n_output);
        }
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                projection_weights_buffer, n_output, n_cell, output_gate_scratch, n_batch,
                output_buffer,
                /*result_stride=*/1);
        if (params.proj_clip > 0.0) {
            tflite::tensor_utils::ClipVector(output_buffer, n_batch * n_output, params.proj_clip,
                                             output_buffer);
        }
    } else {
        tflite::tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output, output_buffer);
    }
    tflite::tensor_utils::CopyVector(output_buffer, n_batch * n_output, output_state_out_buffer);
    return true;
}

bool LSTMCell::Eval() {
    switch (input_->type) {
        case OperandType::TENSOR_FLOAT32: {
            LSTMEvalFloat32(params_, GetBuffer<const float>(input_), input_->shape(),
                            GetBuffer<const float>(input_to_input_weights_),
                            GetBuffer<const float>(input_to_forget_weights_),
                            GetBuffer<const float>(input_to_cell_weights_),
                            GetBuffer<const float>(input_to_output_weights_),
                            input_to_output_weights_->shape(),
                            GetBuffer<const float>(recurrent_to_input_weights_),
                            GetBuffer<const float>(recurrent_to_forget_weights_),
                            GetBuffer<const float>(recurrent_to_cell_weights_),
                            GetBuffer<const float>(recurrent_to_output_weights_),
                            recurrent_to_output_weights_->shape(),
                            GetBuffer<const float>(cell_to_input_weights_),
                            GetBuffer<const float>(cell_to_forget_weights_),
                            GetBuffer<const float>(cell_to_output_weights_),
                            /*aux_input_buffer=*/nullptr,
                            /*aux_input_to_input_weights_buffer=*/nullptr,
                            /*aux_input_to_forget_weights_buffer=*/nullptr,
                            /*aux_input_to_cell_weights_buffer=*/nullptr,
                            /*aux_input_to_output_weights_buffer=*/nullptr,
                            GetBuffer<const float>(input_gate_bias_),
                            GetBuffer<const float>(forget_gate_bias_),
                            GetBuffer<const float>(cell_bias_),
                            GetBuffer<const float>(output_gate_bias_),
                            GetBuffer<const float>(projection_weights_),
                            GetBuffer<const float>(projection_bias_),
                            GetBuffer<const float>(output_state_in_),
                            GetBuffer<const float>(cell_state_in_),
                            GetBuffer<const float>(input_layer_norm_weights_),
                            GetBuffer<const float>(forget_layer_norm_weights_),
                            GetBuffer<const float>(cell_layer_norm_weights_),
                            GetBuffer<const float>(output_layer_norm_weights_),
                            GetBuffer<float>(output_state_out_), GetBuffer<float>(cell_state_out_),
                            GetBuffer<float>(output_), GetBuffer<float>(scratch_buffer_));
        } break;
        case OperandType::TENSOR_FLOAT16: {
            LSTMEvalFloat16(params_, GetBuffer<const _Float16>(input_), input_->shape(),
                            GetOptionalBuffer<const _Float16>(input_to_input_weights_),
                            GetBuffer<const _Float16>(input_to_forget_weights_),
                            GetBuffer<const _Float16>(input_to_cell_weights_),
                            GetBuffer<const _Float16>(input_to_output_weights_),
                            input_to_output_weights_->shape(),
                            GetOptionalBuffer<const _Float16>(recurrent_to_input_weights_),
                            GetBuffer<const _Float16>(recurrent_to_forget_weights_),
                            GetBuffer<const _Float16>(recurrent_to_cell_weights_),
                            GetBuffer<const _Float16>(recurrent_to_output_weights_),
                            recurrent_to_output_weights_->shape(),
                            GetOptionalBuffer<const _Float16>(cell_to_input_weights_),
                            GetOptionalBuffer<const _Float16>(cell_to_forget_weights_),
                            GetOptionalBuffer<const _Float16>(cell_to_output_weights_),
                            /*aux_input_buffer=*/nullptr,
                            /*aux_input_to_input_weights_buffer=*/nullptr,
                            /*aux_input_to_forget_weights_buffer=*/nullptr,
                            /*aux_input_to_cell_weights_buffer=*/nullptr,
                            /*aux_input_to_output_weights_buffer=*/nullptr,
                            GetOptionalBuffer<const _Float16>(input_gate_bias_),
                            GetBuffer<const _Float16>(forget_gate_bias_),
                            GetBuffer<const _Float16>(cell_bias_),
                            GetBuffer<const _Float16>(output_gate_bias_),
                            GetOptionalBuffer<const _Float16>(projection_weights_),
                            GetOptionalBuffer<const _Float16>(projection_bias_),
                            GetBuffer<const _Float16>(output_state_in_),
                            GetBuffer<const _Float16>(cell_state_in_),
                            GetOptionalBuffer<const _Float16>(input_layer_norm_weights_),
                            GetOptionalBuffer<const _Float16>(forget_layer_norm_weights_),
                            GetOptionalBuffer<const _Float16>(cell_layer_norm_weights_),
                            GetOptionalBuffer<const _Float16>(output_layer_norm_weights_),
                            GetBuffer<_Float16>(output_state_out_),
                            GetBuffer<_Float16>(cell_state_out_), GetBuffer<_Float16>(output_),
                            GetBuffer<_Float16>(scratch_buffer_));
        } break;
        default: {
            LOG(ERROR) << "Unsupported data type: " << static_cast<int>(input_->type);
            return false;
        }
    }
    return true;
}

}  // namespace nn
}  // namespace android