/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include "HalInterfaces.h"
#include "IndexedShapeWrapper.h"
#include "LSTM.h"
#include "OperationResolver.h"
#include "OperationsUtils.h"

namespace android {
namespace nn {
namespace unidirectional_sequence_lstm {

// Inputs
constexpr uint32_t kNumInputs = 28;

// Input tensor of size {max_time, n_batch, n_input}
constexpr uint32_t kInputTensor = 0;

// Input weight tensors of size {n_cell, n_input}
constexpr uint32_t kInputToInputWeightsTensor = 1;  // Optional
constexpr uint32_t kInputToForgetWeightsTensor = 2;
constexpr uint32_t kInputToCellWeightsTensor = 3;
constexpr uint32_t kInputToOutputWeightsTensor = 4;

// Recurrent weight tensors of size {n_cell, n_output}
constexpr uint32_t kRecurrentToInputWeightsTensor = 5;  // Optional
constexpr uint32_t kRecurrentToForgetWeightsTensor = 6;
constexpr uint32_t kRecurrentToCellWeightsTensor = 7;
constexpr uint32_t kRecurrentToOutputWeightsTensor = 8;

// Peephole weight tensors of size {n_cell}, representing a diagonal matrix.
constexpr uint32_t kCellToInputWeightsTensor = 9;    // Optional
constexpr uint32_t kCellToForgetWeightsTensor = 10;  // Optional
constexpr uint32_t kCellToOutputWeightsTensor = 11;  // Optional

// Gate bias tensors of size {n_cell}
constexpr uint32_t kInputGateBiasTensor = 12;  // Optional
constexpr uint32_t kForgetGateBiasTensor = 13;
constexpr uint32_t kCellGateBiasTensor = 14;
constexpr uint32_t kOutputGateBiasTensor = 15;

// Projection weight tensor of size {n_output, n_cell}
constexpr uint32_t kProjectionWeightsTensor = 16;  // Optional
// Projection bias tensor of size {n_output}
constexpr uint32_t kProjectionBiasTensor = 17;  // Optional

// Input from the output of the previous step, tensor of size {n_batch, n_output}
constexpr uint32_t kOutputStateInTensor = 18;
// Input from the cell state of the previous step, tensor of size {n_batch, n_cell}
constexpr uint32_t kCellStateInTensor = 19;

constexpr uint32_t kActivationParam = 20;
constexpr uint32_t kCellClipParam = 21;
constexpr uint32_t kProjClipParam = 22;
constexpr uint32_t kTimeMajorParam = 23;

// Layer norm weight tensors of size {n_cell}, representing a diagonal matrix.
constexpr uint32_t kInputLayerNormWeightsTensor = 24;   // Optional
constexpr uint32_t kForgetLayerNormWeightsTensor = 25;  // Optional
constexpr uint32_t kCellLayerNormWeightsTensor = 26;    // Optional
constexpr uint32_t kOutputLayerNormWeightsTensor = 27;  // Optional
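
// Note: the optional operands above select the LSTM variant. Omitting the
// input gate weights (inputs 1 and 5) selects CIFG; the cell-to-gate weights
// (inputs 9-11) enable peephole connections; inputs 16 and 17 enable
// projection; and inputs 24-27 enable layer normalization. prepare() below
// verifies that each group is consistently present or absent.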

// Output tensors.
constexpr uint32_t kNumOutputs = 1;

constexpr uint32_t kOutputTensor = 0;

namespace {

inline bool hasTensor(IOperationExecutionContext* context, const uint32_t tensor) {
    return context->getInputBuffer(tensor) != nullptr;
}

inline bool isTimeMajor(IOperationExecutionContext* context) {
    return context->getInputValue<bool>(kTimeMajorParam);
}

template <typename T>
inline LSTMParams getLSTMParams(IOperationExecutionContext* context) {
    LSTMParams params;
    params.activation =
            static_cast<TfLiteFusedActivation>(context->getInputValue<int32_t>(kActivationParam));
    params.cell_clip = static_cast<float>(context->getInputValue<T>(kCellClipParam));
    params.proj_clip = static_cast<float>(context->getInputValue<T>(kProjClipParam));
    params.use_cifg = !hasTensor(context, kInputToInputWeightsTensor);
    params.use_peephole = hasTensor(context, kCellToOutputWeightsTensor);
    params.use_layer_norm = hasTensor(context, kOutputLayerNormWeightsTensor);
    params.use_projection_weight = hasTensor(context, kProjectionWeightsTensor);
    params.use_projection_bias = hasTensor(context, kProjectionBiasTensor);
    return params;
}

}  // namespace
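
// All tensor operands must use the same float type as the input tensor, and
// the scalar cell clip and projection clip parameters use the matching scalar
// type: FLOAT32 for TENSOR_FLOAT32 models, FLOAT16 for TENSOR_FLOAT16 models.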
bool validate(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    const OperandType inputType = context->getInputType(kInputTensor);
    std::vector<OperandType> inExpectedTypes;
    std::vector<OperandType> outExpectedTypes;
    if (inputType == OperandType::TENSOR_FLOAT32) {
        inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::INT32,          OperandType::FLOAT32,
                           OperandType::FLOAT32,        OperandType::BOOL,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32};
        outExpectedTypes = {OperandType::TENSOR_FLOAT32};
    } else if (inputType == OperandType::TENSOR_FLOAT16) {
        inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::INT32,          OperandType::FLOAT16,
                           OperandType::FLOAT16,        OperandType::BOOL,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16};
        outExpectedTypes = {OperandType::TENSOR_FLOAT16};
    } else {
        NN_RET_CHECK_FAIL()
                << "Unsupported input operand type for UNIDIRECTIONAL_SEQUENCE_LSTM op: "
                << toString(inputType);
    }
    NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
    NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes));
    return validateHalVersion(context, HalVersion::V1_2);
}

bool prepare(IOperationExecutionContext* context) {
    // Check that none of the required inputs are omitted.
    const std::vector<int> requiredInputs = {
            kInputTensor,
            kInputToForgetWeightsTensor,
            kInputToCellWeightsTensor,
            kInputToOutputWeightsTensor,
            kRecurrentToForgetWeightsTensor,
            kRecurrentToCellWeightsTensor,
            kRecurrentToOutputWeightsTensor,
            kForgetGateBiasTensor,
            kCellGateBiasTensor,
            kOutputGateBiasTensor,
            kOutputStateInTensor,
            kCellStateInTensor,
            kActivationParam,
            kCellClipParam,
            kProjClipParam,
            kTimeMajorParam,
    };
    for (const int requiredInput : requiredInputs) {
        NN_RET_CHECK(!context->isOmittedInput(requiredInput))
                << "required input " << requiredInput << " is omitted";
    }

    const Shape inputShape = context->getInputShape(kInputTensor);
    const uint32_t inputRank = getNumberOfDimensions(inputShape);
    NN_RET_CHECK_EQ(inputRank, 3) << "Invalid input tensor rank: " << inputRank;
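
    // The input layout depends on the time_major parameter: time-major inputs
    // are laid out as {max_time, n_batch, n_input}, batch-major inputs as
    // {n_batch, max_time, n_input}.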
    const uint32_t maxTime = getSizeOfDimension(inputShape, isTimeMajor(context) ? 0 : 1);
    const uint32_t batchSize = getSizeOfDimension(inputShape, isTimeMajor(context) ? 1 : 0);
    const uint32_t inputSize = getSizeOfDimension(inputShape, inputRank - 1);

    const Shape inputToOutputShape = context->getInputShape(kInputToOutputWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToOutputShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToOutputShape, 1), inputSize);
    const uint32_t numCells = getSizeOfDimension(inputToOutputShape, 0);

    const Shape recurrentToOutputShape = context->getInputShape(kRecurrentToOutputWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToOutputShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToOutputShape, 0), numCells);
    const uint32_t outputSize = getSizeOfDimension(recurrentToOutputShape, 1);

    if (hasTensor(context, kInputToInputWeightsTensor)) {
        const Shape inputToInputShape = context->getInputShape(kInputToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputToInputShape), 2);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 0), numCells);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 1), inputSize);
    }

    const Shape inputToForgetShape = context->getInputShape(kInputToForgetWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToForgetShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 1), inputSize);
    const Shape inputToCellShape = context->getInputShape(kInputToCellWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToCellShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 1), inputSize);

    if (hasTensor(context, kRecurrentToInputWeightsTensor)) {
        const Shape recurrentToInputShape = context->getInputShape(kRecurrentToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToInputShape), 2);
        NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 0), numCells);
        NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 1), outputSize);
    }

    const Shape recurrentToForgetShape = context->getInputShape(kRecurrentToForgetWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToForgetShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 1), outputSize);
    const Shape recurrentToCellShape = context->getInputShape(kRecurrentToCellWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToCellShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 1), outputSize);

    // Make sure the input gate's parameters are either both present (regular
    // LSTM) or both absent (CIFG-LSTM).
    const bool cifgWeightsAllOrNone = (hasTensor(context, kInputToInputWeightsTensor) &&
                                       hasTensor(context, kRecurrentToInputWeightsTensor)) ||
                                      (!hasTensor(context, kInputToInputWeightsTensor) &&
                                       !hasTensor(context, kRecurrentToInputWeightsTensor));
    NN_RET_CHECK(cifgWeightsAllOrNone);

    if (hasTensor(context, kCellToInputWeightsTensor)) {
        const Shape cellToInputShape = context->getInputShape(kCellToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToInputShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToInputShape, 0), numCells);
    }

    if (hasTensor(context, kCellToForgetWeightsTensor)) {
        const Shape cellToForgetShape = context->getInputShape(kCellToForgetWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToForgetShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToForgetShape, 0), numCells);
    }

    if (hasTensor(context, kCellToOutputWeightsTensor)) {
        const Shape cellToOutputShape = context->getInputShape(kCellToOutputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToOutputShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToOutputShape, 0), numCells);
    }

    // Make sure the peephole weights are all present or all absent.
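    // Under CIFG there is no input gate, so the cell-to-input peephole weights
    // may be omitted even when the forget and output peephole weights are
    // present; the check below accounts for this.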
    const bool cifgUsed = !hasTensor(context, kInputToInputWeightsTensor);
    const bool peepholeWeightsAllOrNone =
            ((hasTensor(context, kCellToInputWeightsTensor) || cifgUsed) &&
             hasTensor(context, kCellToForgetWeightsTensor) &&
             hasTensor(context, kCellToOutputWeightsTensor)) ||
            (!hasTensor(context, kCellToInputWeightsTensor) &&
             !hasTensor(context, kCellToForgetWeightsTensor) &&
             !hasTensor(context, kCellToOutputWeightsTensor));
    NN_RET_CHECK(peepholeWeightsAllOrNone);

    if (!cifgUsed) {
        NN_RET_CHECK(hasTensor(context, kInputGateBiasTensor));
        const Shape inputGateBiasShape = context->getInputShape(kInputGateBiasTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputGateBiasShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputGateBiasShape, 0), numCells);
    } else {
        NN_RET_CHECK(!hasTensor(context, kInputGateBiasTensor))
                << "Input gate bias tensor is present when CIFG is used";
    }

    const Shape forgetGateBiasShape = context->getInputShape(kForgetGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(forgetGateBiasShape), 1);
    NN_RET_CHECK_EQ(getSizeOfDimension(forgetGateBiasShape, 0), numCells);
    const Shape cellGateBiasShape = context->getInputShape(kCellGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(cellGateBiasShape), 1);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellGateBiasShape, 0), numCells);
    const Shape outputGateBiasShape = context->getInputShape(kOutputGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(outputGateBiasShape), 1);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputGateBiasShape, 0), numCells);

    if (hasTensor(context, kProjectionWeightsTensor)) {
        const Shape projectionShape = context->getInputShape(kProjectionWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(projectionShape), 2);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 0), outputSize);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 1), numCells);
    }

    if (hasTensor(context, kProjectionBiasTensor)) {
        const Shape projectionBiasShape = context->getInputShape(kProjectionBiasTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(projectionBiasShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionBiasShape, 0), outputSize);
    }

    const Shape outputStateShape = context->getInputShape(kOutputStateInTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(outputStateShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 0), batchSize);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 1), outputSize);
    const Shape cellStateShape = context->getInputShape(kCellStateInTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(cellStateShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 0), batchSize);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 1), numCells);
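
    // Layer normalization weights are vectors of size {n_cell}. As with the
    // peephole weights, the input layer norm weights only make sense when an
    // input gate exists, i.e. when CIFG is not used.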
    if (hasTensor(context, kInputLayerNormWeightsTensor)) {
        const Shape inputLayerNormShape = context->getInputShape(kInputLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputLayerNormShape, 0), numCells);
    }

    if (hasTensor(context, kForgetLayerNormWeightsTensor)) {
        const Shape forgetLayerNormShape = context->getInputShape(kForgetLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(forgetLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(forgetLayerNormShape, 0), numCells);
    }

    if (hasTensor(context, kCellLayerNormWeightsTensor)) {
        const Shape cellLayerNormShape = context->getInputShape(kCellLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellLayerNormShape, 0), numCells);
    }

    if (hasTensor(context, kOutputLayerNormWeightsTensor)) {
        const Shape outputLayerNormShape = context->getInputShape(kOutputLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(outputLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(outputLayerNormShape, 0), numCells);
    }

    if (cifgUsed) {
        NN_RET_CHECK(!hasTensor(context, kInputLayerNormWeightsTensor))
                << "Input layer norm weights tensor is present when CIFG is used";
        const bool layerNormWeightsAllOrNoneCifg =
                (hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 hasTensor(context, kCellLayerNormWeightsTensor) &&
                 hasTensor(context, kOutputLayerNormWeightsTensor)) ||
                (!hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 !hasTensor(context, kCellLayerNormWeightsTensor) &&
                 !hasTensor(context, kOutputLayerNormWeightsTensor));
        NN_RET_CHECK(layerNormWeightsAllOrNoneCifg);
    } else {
        const bool layerNormWeightsAllOrNone =
                (hasTensor(context, kInputLayerNormWeightsTensor) &&
                 hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 hasTensor(context, kCellLayerNormWeightsTensor) &&
                 hasTensor(context, kOutputLayerNormWeightsTensor)) ||
                (!hasTensor(context, kInputLayerNormWeightsTensor) &&
                 !hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 !hasTensor(context, kCellLayerNormWeightsTensor) &&
                 !hasTensor(context, kOutputLayerNormWeightsTensor));
        NN_RET_CHECK(layerNormWeightsAllOrNone);
    }

    Shape outputShape = context->getInputShape(kInputTensor);
    outputShape.dimensions[2] = outputSize;

    return context->setOutputShape(kOutputTensor, outputShape);
}

bool execute(IOperationExecutionContext* context) {
    const auto outputStateSize = getNumberOfElements(context->getInputShape(kOutputStateInTensor));
    const auto cellStateSize = getNumberOfElements(context->getInputShape(kCellStateInTensor));
    const bool use_cifg = !hasTensor(context, kInputToInputWeightsTensor);
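    // The scratch buffer holds per-gate intermediate values of the same size
    // as the cell state: four gates (input, forget, cell, output) for a
    // regular LSTM, three when CIFG removes the input gate.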
    const auto scratchSize = use_cifg ? 3 * cellStateSize : 4 * cellStateSize;

    const OperandType inputType = context->getInputType(kInputTensor);
    switch (inputType) {
        case OperandType::TENSOR_FLOAT32: {
            std::vector<float> outputStateOut(outputStateSize);
            std::vector<float> cellStateOut(cellStateSize);
            std::vector<float> scratchBuffer(scratchSize);
            LSTMCell::LSTMEvalFloat32(
                    getLSTMParams<float>(context), context->getInputBuffer<float>(kInputTensor),
                    context->getInputShape(kInputTensor),
                    context->getInputBuffer<float>(kInputToInputWeightsTensor),
                    context->getInputBuffer<float>(kInputToForgetWeightsTensor),
                    context->getInputBuffer<float>(kInputToCellWeightsTensor),
                    context->getInputBuffer<float>(kInputToOutputWeightsTensor),
                    context->getInputShape(kInputToOutputWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToInputWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToForgetWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToCellWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToOutputWeightsTensor),
                    context->getInputShape(kRecurrentToOutputWeightsTensor),
                    context->getInputBuffer<float>(kCellToInputWeightsTensor),
                    context->getInputBuffer<float>(kCellToForgetWeightsTensor),
                    context->getInputBuffer<float>(kCellToOutputWeightsTensor),
                    /*aux_input_buffer=*/nullptr,
                    /*aux_input_to_input_weights_buffer=*/nullptr,
                    /*aux_input_to_forget_weights_buffer=*/nullptr,
                    /*aux_input_to_cell_weights_buffer=*/nullptr,
                    /*aux_input_to_output_weights_buffer=*/nullptr,
                    context->getInputBuffer<float>(kInputGateBiasTensor),
                    context->getInputBuffer<float>(kForgetGateBiasTensor),
                    context->getInputBuffer<float>(kCellGateBiasTensor),
                    context->getInputBuffer<float>(kOutputGateBiasTensor),
                    context->getInputBuffer<float>(kProjectionWeightsTensor),
                    context->getInputBuffer<float>(kProjectionBiasTensor),
                    context->getInputBuffer<float>(kOutputStateInTensor),
                    context->getInputBuffer<float>(kCellStateInTensor),
                    context->getInputBuffer<float>(kInputLayerNormWeightsTensor),
                    context->getInputBuffer<float>(kForgetLayerNormWeightsTensor),
                    context->getInputBuffer<float>(kCellLayerNormWeightsTensor),
                    context->getInputBuffer<float>(kOutputLayerNormWeightsTensor),
                    outputStateOut.data(), cellStateOut.data(),
                    context->getOutputBuffer<float>(kOutputTensor), scratchBuffer.data(),
                    isTimeMajor(context));
        } break;
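        // The TENSOR_FLOAT16 path below mirrors the float path with _Float16
        // buffers. In both cases outputStateOut and cellStateOut receive the
        // final recurrent state but are local scratch here, since this
        // operation's only output is the output tensor.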
        case OperandType::TENSOR_FLOAT16: {
            std::vector<_Float16> outputStateOut(outputStateSize);
            std::vector<_Float16> cellStateOut(cellStateSize);
            std::vector<_Float16> scratchBuffer(scratchSize);
            LSTMCell::LSTMEvalFloat16(
                    getLSTMParams<_Float16>(context),
                    context->getInputBuffer<_Float16>(kInputTensor),
                    context->getInputShape(kInputTensor),
                    context->getInputBuffer<_Float16>(kInputToInputWeightsTensor),
                    context->getInputBuffer<_Float16>(kInputToForgetWeightsTensor),
                    context->getInputBuffer<_Float16>(kInputToCellWeightsTensor),
                    context->getInputBuffer<_Float16>(kInputToOutputWeightsTensor),
                    context->getInputShape(kInputToOutputWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToInputWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToForgetWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToCellWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToOutputWeightsTensor),
                    context->getInputShape(kRecurrentToOutputWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellToInputWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellToForgetWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellToOutputWeightsTensor),
                    /*aux_input_buffer=*/nullptr,
                    /*aux_input_to_input_weights_buffer=*/nullptr,
                    /*aux_input_to_forget_weights_buffer=*/nullptr,
                    /*aux_input_to_cell_weights_buffer=*/nullptr,
                    /*aux_input_to_output_weights_buffer=*/nullptr,
                    context->getInputBuffer<_Float16>(kInputGateBiasTensor),
                    context->getInputBuffer<_Float16>(kForgetGateBiasTensor),
                    context->getInputBuffer<_Float16>(kCellGateBiasTensor),
                    context->getInputBuffer<_Float16>(kOutputGateBiasTensor),
                    context->getInputBuffer<_Float16>(kProjectionWeightsTensor),
                    context->getInputBuffer<_Float16>(kProjectionBiasTensor),
                    context->getInputBuffer<_Float16>(kOutputStateInTensor),
                    context->getInputBuffer<_Float16>(kCellStateInTensor),
                    context->getInputBuffer<_Float16>(kInputLayerNormWeightsTensor),
                    context->getInputBuffer<_Float16>(kForgetLayerNormWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellLayerNormWeightsTensor),
                    context->getInputBuffer<_Float16>(kOutputLayerNormWeightsTensor),
                    outputStateOut.data(), cellStateOut.data(),
                    context->getOutputBuffer<_Float16>(kOutputTensor), scratchBuffer.data(),
                    isTimeMajor(context));
        } break;
        default: {
            LOG(ERROR) << "Unsupported data type: " << static_cast<int>(inputType);
            return false;
        }
    }
    return true;
}

}  // namespace unidirectional_sequence_lstm

NN_REGISTER_OPERATION(UNIDIRECTIONAL_SEQUENCE_LSTM, "UNIDIRECTIONAL_SEQUENCE_LSTM",
                      unidirectional_sequence_lstm::validate, unidirectional_sequence_lstm::prepare,
                      unidirectional_sequence_lstm::execute, .allowOmittedOperand = true);

}  // namespace nn
}  // namespace android