/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "BidirectionalSequenceLSTM.h"

#include "CpuExecutor.h"
#include "CpuOperationUtils.h"
#include "HalInterfaces.h"
#include "OperationsUtils.h"

#include "Tracing.h"

namespace android {
namespace nn {

namespace {

template <typename T>
inline T* GetBuffer(RunTimeOperandInfo* operand) {
    return reinterpret_cast<T*>(operand->buffer);
}

template <typename T>
inline const T* GetBuffer(const RunTimeOperandInfo* operand) {
    return reinterpret_cast<const T*>(operand->buffer);
}

template <typename T>
inline const T* GetOptionalBuffer(const RunTimeOperandInfo* operand) {
    return !IsNullInput(operand) ? reinterpret_cast<const T*>(operand->buffer) : nullptr;
}

}  // anonymous namespace

BidirectionalSequenceLSTM::BidirectionalSequenceLSTM(const Operation& operation,
                                                     std::vector<RunTimeOperandInfo>& operands) {
    input_ = GetInput(operation, operands, kInputTensor);

    fw_input_to_input_weights_ =
            GetInput(operation, operands, kFwInputToInputWeightsTensor);  // optional
    fw_input_to_forget_weights_ = GetInput(operation, operands, kFwInputToForgetWeightsTensor);
    fw_input_to_cell_weights_ = GetInput(operation, operands, kFwInputToCellWeightsTensor);
    fw_input_to_output_weights_ = GetInput(operation, operands, kFwInputToOutputWeightsTensor);

    fw_recurrent_to_input_weights_ =
            GetInput(operation, operands, kFwRecurrentToInputWeightsTensor);  // optional
    fw_recurrent_to_forget_weights_ =
            GetInput(operation, operands, kFwRecurrentToForgetWeightsTensor);
    fw_recurrent_to_cell_weights_ = GetInput(operation, operands, kFwRecurrentToCellWeightsTensor);
    fw_recurrent_to_output_weights_ =
            GetInput(operation, operands, kFwRecurrentToOutputWeightsTensor);

    fw_cell_to_input_weights_ =
            GetInput(operation, operands, kFwCellToInputWeightsTensor);  // optional
    fw_cell_to_forget_weights_ =
            GetInput(operation, operands, kFwCellToForgetWeightsTensor);  // optional
    fw_cell_to_output_weights_ =
            GetInput(operation, operands, kFwCellToOutputWeightsTensor);  // optional

    fw_input_gate_bias_ = GetInput(operation, operands, kFwInputGateBiasTensor);
    fw_forget_gate_bias_ = GetInput(operation, operands, kFwForgetGateBiasTensor);
    fw_cell_bias_ = GetInput(operation, operands, kFwCellGateBiasTensor);
    fw_output_gate_bias_ = GetInput(operation, operands, kFwOutputGateBiasTensor);

    fw_projection_weights_ = GetInput(operation, operands, kFwProjectionWeightsTensor);  // optional
    fw_projection_bias_ = GetInput(operation, operands, kFwProjectionBiasTensor);        // optional

    fw_activation_state_ = GetInput(operation, operands, kFwInputActivationStateTensor);
    fw_cell_state_ = GetInput(operation, operands, kFwInputCellStateTensor);

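    // The backward-direction tensors below mirror the forward-direction set above,
    // retrieved through the corresponding kBw* operand indices.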
    bw_input_to_input_weights_ =
            GetInput(operation, operands, kBwInputToInputWeightsTensor);  // optional
    bw_input_to_forget_weights_ = GetInput(operation, operands, kBwInputToForgetWeightsTensor);
    bw_input_to_cell_weights_ = GetInput(operation, operands, kBwInputToCellWeightsTensor);
    bw_input_to_output_weights_ = GetInput(operation, operands, kBwInputToOutputWeightsTensor);

    bw_recurrent_to_input_weights_ =
            GetInput(operation, operands, kBwRecurrentToInputWeightsTensor);  // optional
    bw_recurrent_to_forget_weights_ =
            GetInput(operation, operands, kBwRecurrentToForgetWeightsTensor);
    bw_recurrent_to_cell_weights_ = GetInput(operation, operands, kBwRecurrentToCellWeightsTensor);
    bw_recurrent_to_output_weights_ =
            GetInput(operation, operands, kBwRecurrentToOutputWeightsTensor);

    bw_cell_to_input_weights_ =
            GetInput(operation, operands, kBwCellToInputWeightsTensor);  // optional
    bw_cell_to_forget_weights_ =
            GetInput(operation, operands, kBwCellToForgetWeightsTensor);  // optional
    bw_cell_to_output_weights_ =
            GetInput(operation, operands, kBwCellToOutputWeightsTensor);  // optional

    bw_input_gate_bias_ = GetInput(operation, operands, kBwInputGateBiasTensor);
    bw_forget_gate_bias_ = GetInput(operation, operands, kBwForgetGateBiasTensor);
    bw_cell_bias_ = GetInput(operation, operands, kBwCellGateBiasTensor);
    bw_output_gate_bias_ = GetInput(operation, operands, kBwOutputGateBiasTensor);

    bw_projection_weights_ = GetInput(operation, operands, kBwProjectionWeightsTensor);  // optional
    bw_projection_bias_ = GetInput(operation, operands, kBwProjectionBiasTensor);        // optional

    bw_activation_state_ = GetInput(operation, operands, kBwInputActivationStateTensor);
    bw_cell_state_ = GetInput(operation, operands, kBwInputCellStateTensor);

    aux_input_ = GetInput(operation, operands, kAuxInputTensor);
    fw_aux_input_to_input_weights_ = GetInput(operation, operands, kFwAuxInputToInputWeightsTensor);
    fw_aux_input_to_forget_weights_ =
            GetInput(operation, operands, kFwAuxInputToForgetWeightsTensor);
    fw_aux_input_to_cell_weights_ = GetInput(operation, operands, kFwAuxInputToCellWeightsTensor);
    fw_aux_input_to_output_weights_ =
            GetInput(operation, operands, kFwAuxInputToOutputWeightsTensor);
    bw_aux_input_to_input_weights_ = GetInput(operation, operands, kBwAuxInputToInputWeightsTensor);
    bw_aux_input_to_forget_weights_ =
            GetInput(operation, operands, kBwAuxInputToForgetWeightsTensor);
    bw_aux_input_to_cell_weights_ = GetInput(operation, operands, kBwAuxInputToCellWeightsTensor);
    bw_aux_input_to_output_weights_ =
            GetInput(operation, operands, kBwAuxInputToOutputWeightsTensor);

    fw_input_layer_norm_weights_ = GetInput(operation, operands, kFwInputLayerNormWeightsTensor);
    fw_forget_layer_norm_weights_ = GetInput(operation, operands, kFwForgetLayerNormWeightsTensor);
    fw_cell_layer_norm_weights_ = GetInput(operation, operands, kFwCellLayerNormWeightsTensor);
    fw_output_layer_norm_weights_ = GetInput(operation, operands, kFwOutputLayerNormWeightsTensor);
    bw_input_layer_norm_weights_ = GetInput(operation, operands, kBwInputLayerNormWeightsTensor);
    bw_forget_layer_norm_weights_ = GetInput(operation, operands, kBwForgetLayerNormWeightsTensor);
    bw_cell_layer_norm_weights_ = GetInput(operation, operands, kBwCellLayerNormWeightsTensor);
    bw_output_layer_norm_weights_ = GetInput(operation, operands, kBwOutputLayerNormWeightsTensor);

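    // Scalar parameters: the fused activation is always an int32 enum, while the
    // cell/projection clipping thresholds are read as float or _Float16 to match
    // the data type of the input tensor.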
    params_.activation = static_cast<TfLiteFusedActivation>(
            getScalarData<int32_t>(*GetInput(operation, operands, kActivationParam)));
    if (input_->type == OperandType::TENSOR_FLOAT32) {
        params_.cell_clip = getScalarData<float>(*GetInput(operation, operands, kCellClipParam));
        params_.proj_clip = getScalarData<float>(*GetInput(operation, operands, kProjClipParam));
    } else {
        params_.cell_clip = static_cast<float>(
                getScalarData<_Float16>(*GetInput(operation, operands, kCellClipParam)));
        params_.proj_clip = static_cast<float>(
                getScalarData<_Float16>(*GetInput(operation, operands, kProjClipParam)));
    }
    params_.merge_outputs = getScalarData<bool>(*GetInput(operation, operands, kMergeOutputsParam));
    params_.time_major = getScalarData<bool>(*GetInput(operation, operands, kTimeMajorParam));
    params_.use_layer_norm = !IsNullInput(fw_input_layer_norm_weights_);

    fw_output_ = GetOutput(operation, operands, kFwOutputTensor);
    if (!params_.merge_outputs) {
        bw_output_ = GetOutput(operation, operands, kBwOutputTensor);
    }
}

bool BidirectionalSequenceLSTM::Prepare(const Operation& operation,
                                        std::vector<RunTimeOperandInfo>& operands,
                                        Shape* fwOutputShape, Shape* bwOutputShape) {
    // Inferring batch size, number of outputs and number of cells from the
    // input tensors.
    NN_CHECK(NumDimensions(input_) == 3);
    const uint32_t max_time = SizeOfDimension(input_, params_.time_major ? 0 : 1);
    const uint32_t n_batch = SizeOfDimension(input_, params_.time_major ? 1 : 0);
    const uint32_t n_input = SizeOfDimension(input_, 2);

    const uint32_t n_fw_cell = SizeOfDimension(fw_input_to_output_weights_, 0);
    NN_CHECK_EQ(NumDimensions(fw_input_to_output_weights_), 2);
    NN_CHECK_EQ(SizeOfDimension(fw_input_to_output_weights_, 1), n_input);

    NN_CHECK_EQ(NumDimensions(fw_recurrent_to_output_weights_), 2);
    NN_CHECK_EQ(SizeOfDimension(fw_recurrent_to_output_weights_, 0), n_fw_cell);
    const uint32_t n_fw_output = SizeOfDimension(fw_recurrent_to_output_weights_, 1);

    // Check that the forward input tensor dimensions match each other.
    if (!LSTMCell::CheckInputTensorDimensions(
                input_, fw_input_to_input_weights_, fw_input_to_forget_weights_,
                fw_input_to_cell_weights_, fw_input_to_output_weights_,
                fw_recurrent_to_input_weights_, fw_recurrent_to_forget_weights_,
                fw_recurrent_to_cell_weights_, fw_recurrent_to_output_weights_,
                fw_cell_to_input_weights_, fw_cell_to_forget_weights_, fw_cell_to_output_weights_,
                fw_input_gate_bias_, fw_forget_gate_bias_, fw_cell_bias_, fw_output_gate_bias_,
                fw_projection_weights_, fw_projection_bias_, fw_input_layer_norm_weights_,
                fw_forget_layer_norm_weights_, fw_cell_layer_norm_weights_,
                fw_output_layer_norm_weights_, n_input, n_fw_output, n_fw_cell, &params_)) {
        return false;
    }

    const bool aux_inputs_all_or_none =
            (!IsNullInput(aux_input_) && !IsNullInput(fw_aux_input_to_cell_weights_) &&
             !IsNullInput(fw_aux_input_to_forget_weights_) &&
             !IsNullInput(fw_aux_input_to_output_weights_) &&
             !IsNullInput(bw_aux_input_to_cell_weights_) &&
             !IsNullInput(bw_aux_input_to_forget_weights_) &&
             !IsNullInput(bw_aux_input_to_output_weights_)) ||
            (IsNullInput(fw_aux_input_to_cell_weights_) &&
             IsNullInput(fw_aux_input_to_forget_weights_) &&
             IsNullInput(fw_aux_input_to_output_weights_) &&
             IsNullInput(bw_aux_input_to_cell_weights_) &&
             IsNullInput(bw_aux_input_to_forget_weights_) &&
             IsNullInput(bw_aux_input_to_output_weights_));
    NN_CHECK(aux_inputs_all_or_none);
    if (!IsNullInput(aux_input_)) {
        // Check that aux_input has the same dimensions (except last) as the input.
        NN_CHECK_EQ(aux_input_->shape().dimensions[0], input_->shape().dimensions[0]);
        NN_CHECK_EQ(aux_input_->shape().dimensions[1], input_->shape().dimensions[1]);
    }

    const uint32_t n_bw_cell = SizeOfDimension(bw_input_to_output_weights_, 0);
    NN_CHECK_EQ(NumDimensions(bw_input_to_output_weights_), 2);
    NN_CHECK_EQ(SizeOfDimension(bw_input_to_output_weights_, 1), n_input);

    NN_CHECK_EQ(NumDimensions(bw_recurrent_to_output_weights_), 2);
    NN_CHECK_EQ(SizeOfDimension(bw_recurrent_to_output_weights_, 0), n_bw_cell);
    const uint32_t n_bw_output = SizeOfDimension(bw_recurrent_to_output_weights_, 1);

    const Shape& inputShape = input_->shape();
    fwOutputShape->type = inputShape.type;
    fwOutputShape->offset = inputShape.offset;
    fwOutputShape->scale = inputShape.scale;
    fwOutputShape->dimensions.resize(3);
    fwOutputShape->dimensions[0] = params_.time_major ? max_time : n_batch;
    fwOutputShape->dimensions[1] = params_.time_major ? n_batch : max_time;
    fwOutputShape->dimensions[2] = params_.merge_outputs ? n_fw_output + n_bw_output : n_fw_output;

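    // When merge_outputs is set, the forward output's last (feature) dimension holds
    // the concatenation of the forward and backward activations, hence its size is
    // n_fw_output + n_bw_output above.
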
    // Check that the backward input tensor dimensions match each other.
    if (!LSTMCell::CheckInputTensorDimensions(
                input_, bw_input_to_input_weights_, bw_input_to_forget_weights_,
                bw_input_to_cell_weights_, bw_input_to_output_weights_,
                bw_recurrent_to_input_weights_, bw_recurrent_to_forget_weights_,
                bw_recurrent_to_cell_weights_, bw_recurrent_to_output_weights_,
                bw_cell_to_input_weights_, bw_cell_to_forget_weights_, bw_cell_to_output_weights_,
                bw_input_gate_bias_, bw_forget_gate_bias_, bw_cell_bias_, bw_output_gate_bias_,
                bw_projection_weights_, bw_projection_bias_, bw_input_layer_norm_weights_,
                bw_forget_layer_norm_weights_, bw_cell_layer_norm_weights_,
                bw_output_layer_norm_weights_, n_input, n_bw_output, n_bw_cell, &params_)) {
        return false;
    }

    if (!params_.merge_outputs) {
        bwOutputShape->type = inputShape.type;
        bwOutputShape->offset = inputShape.offset;
        bwOutputShape->scale = inputShape.scale;
        bwOutputShape->dimensions.resize(3);
        bwOutputShape->dimensions[0] = params_.time_major ? max_time : n_batch;
        bwOutputShape->dimensions[1] = params_.time_major ? n_batch : max_time;
        bwOutputShape->dimensions[2] = n_bw_output;
    }

    if (params_.use_cifg) {
        fw_scratch_shape_.dimensions = {n_batch, n_fw_cell * 3};
        bw_scratch_shape_.dimensions = {n_batch, n_bw_cell * 3};
    } else {
        fw_scratch_shape_.dimensions = {n_batch, n_fw_cell * 4};
        bw_scratch_shape_.dimensions = {n_batch, n_bw_cell * 4};
    }
    fw_scratch_shape_.type = bw_scratch_shape_.type = inputShape.type;
    fw_scratch_shape_.offset = bw_scratch_shape_.offset = inputShape.offset;
    fw_scratch_shape_.scale = bw_scratch_shape_.scale = inputShape.scale;

    return true;
}

bool BidirectionalSequenceLSTM::Eval() {
    const uint32_t n_fw_output = SizeOfDimension(fw_recurrent_to_output_weights_, 1);
    const uint32_t n_bw_output = SizeOfDimension(bw_recurrent_to_output_weights_, 1);
    std::vector<uint32_t> fw_output_dims = input_->shape().dimensions;
    fw_output_dims[2] = n_fw_output;
    std::vector<uint32_t> bw_output_dims = fw_output_dims;
    bw_output_dims[2] = n_bw_output;
    const uint32_t n_fw_output_elements = fw_output_dims[0] * fw_output_dims[1] * fw_output_dims[2];
    const uint32_t n_output_elements =
            fw_output_dims[0] * fw_output_dims[1] * (fw_output_dims[2] + bw_output_dims[2]);

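    // Dispatch on the input data type and run the two LSTMCell evaluations: the
    // forward pass writes into fw_output_, and the backward pass writes either into
    // bw_output_ or, when the outputs are merged, into fw_output_'s buffer right
    // after the forward results, before the two are interleaved along the last
    // dimension with mergeThirdDimension().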
    switch (input_->type) {
        case OperandType::TENSOR_FLOAT32: {
            std::vector<float> fw_scratch_buffer(getNumberOfElements(fw_scratch_shape_));
            const bool kForwardSequence = true;
            LSTMCell::LSTMEvalFloat32(
                    params_, GetBuffer<const float>(input_), input_->shape(),
                    GetBuffer<const float>(fw_input_to_input_weights_),
                    GetBuffer<const float>(fw_input_to_forget_weights_),
                    GetBuffer<const float>(fw_input_to_cell_weights_),
                    GetBuffer<const float>(fw_input_to_output_weights_),
                    fw_input_to_output_weights_->shape(),
                    GetBuffer<const float>(fw_recurrent_to_input_weights_),
                    GetBuffer<const float>(fw_recurrent_to_forget_weights_),
                    GetBuffer<const float>(fw_recurrent_to_cell_weights_),
                    GetBuffer<const float>(fw_recurrent_to_output_weights_),
                    fw_recurrent_to_output_weights_->shape(),
                    GetBuffer<const float>(fw_cell_to_input_weights_),
                    GetBuffer<const float>(fw_cell_to_forget_weights_),
                    GetBuffer<const float>(fw_cell_to_output_weights_),
                    GetOptionalBuffer<const float>(aux_input_),
                    GetOptionalBuffer<const float>(fw_aux_input_to_input_weights_),
                    GetOptionalBuffer<const float>(fw_aux_input_to_forget_weights_),
                    GetOptionalBuffer<const float>(fw_aux_input_to_cell_weights_),
                    GetOptionalBuffer<const float>(fw_aux_input_to_output_weights_),
                    GetBuffer<const float>(fw_input_gate_bias_),
                    GetBuffer<const float>(fw_forget_gate_bias_),
                    GetBuffer<const float>(fw_cell_bias_),
                    GetBuffer<const float>(fw_output_gate_bias_),
                    GetBuffer<const float>(fw_projection_weights_),
                    GetBuffer<const float>(fw_projection_bias_),
                    GetBuffer<const float>(fw_activation_state_),
                    GetBuffer<const float>(fw_cell_state_),
                    GetOptionalBuffer<const float>(fw_input_layer_norm_weights_),
                    GetOptionalBuffer<const float>(fw_forget_layer_norm_weights_),
                    GetOptionalBuffer<const float>(fw_cell_layer_norm_weights_),
                    GetOptionalBuffer<const float>(fw_output_layer_norm_weights_),
                    GetBuffer<float>(fw_activation_state_), GetBuffer<float>(fw_cell_state_),
                    GetBuffer<float>(fw_output_), fw_scratch_buffer.data(), params_.time_major,
                    kForwardSequence);

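            // Backward pass: when merging, write the results contiguously after the
            // forward results (at offset n_fw_output_elements inside fw_output_) so
            // they can be interleaved below; otherwise write into bw_output_.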
            std::vector<float> bw_scratch_buffer(getNumberOfElements(bw_scratch_shape_));
            const bool kBackwardSequence = false;
            LSTMCell::LSTMEvalFloat32(
                    params_, GetBuffer<const float>(input_), input_->shape(),
                    GetBuffer<const float>(bw_input_to_input_weights_),
                    GetBuffer<const float>(bw_input_to_forget_weights_),
                    GetBuffer<const float>(bw_input_to_cell_weights_),
                    GetBuffer<const float>(bw_input_to_output_weights_),
                    bw_input_to_output_weights_->shape(),
                    GetBuffer<const float>(bw_recurrent_to_input_weights_),
                    GetBuffer<const float>(bw_recurrent_to_forget_weights_),
                    GetBuffer<const float>(bw_recurrent_to_cell_weights_),
                    GetBuffer<const float>(bw_recurrent_to_output_weights_),
                    bw_recurrent_to_output_weights_->shape(),
                    GetBuffer<const float>(bw_cell_to_input_weights_),
                    GetBuffer<const float>(bw_cell_to_forget_weights_),
                    GetBuffer<const float>(bw_cell_to_output_weights_),
                    GetOptionalBuffer<const float>(aux_input_),
                    GetOptionalBuffer<const float>(bw_aux_input_to_input_weights_),
                    GetOptionalBuffer<const float>(bw_aux_input_to_forget_weights_),
                    GetOptionalBuffer<const float>(bw_aux_input_to_cell_weights_),
                    GetOptionalBuffer<const float>(bw_aux_input_to_output_weights_),
                    GetBuffer<const float>(bw_input_gate_bias_),
                    GetBuffer<const float>(bw_forget_gate_bias_),
                    GetBuffer<const float>(bw_cell_bias_),
                    GetBuffer<const float>(bw_output_gate_bias_),
                    GetBuffer<const float>(bw_projection_weights_),
                    GetBuffer<const float>(bw_projection_bias_),
                    GetBuffer<const float>(bw_activation_state_),
                    GetBuffer<const float>(bw_cell_state_),
                    GetOptionalBuffer<const float>(bw_input_layer_norm_weights_),
                    GetOptionalBuffer<const float>(bw_forget_layer_norm_weights_),
                    GetOptionalBuffer<const float>(bw_cell_layer_norm_weights_),
                    GetOptionalBuffer<const float>(bw_output_layer_norm_weights_),
                    GetBuffer<float>(bw_activation_state_), GetBuffer<float>(bw_cell_state_),
                    params_.merge_outputs ? GetBuffer<float>(fw_output_) + n_fw_output_elements
                                          : GetBuffer<float>(bw_output_),
                    bw_scratch_buffer.data(), params_.time_major, kBackwardSequence);
            if (params_.merge_outputs) {
                std::vector<float> temp(n_output_elements);
                mergeThirdDimension(GetBuffer<float>(fw_output_), fw_output_dims,
                                    GetBuffer<float>(fw_output_) + n_fw_output_elements,
                                    bw_output_dims, temp.data());
                std::copy(temp.data(), temp.data() + n_output_elements,
                          GetBuffer<float>(fw_output_));
            }
        } break;
        case OperandType::TENSOR_FLOAT16: {
            std::vector<_Float16> fw_scratch_buffer(getNumberOfElements(fw_scratch_shape_));
            const bool kForwardSequence = true;
            LSTMCell::LSTMEvalFloat16(
                    params_, GetBuffer<const _Float16>(input_), input_->shape(),
                    GetOptionalBuffer<const _Float16>(fw_input_to_input_weights_),
                    GetBuffer<const _Float16>(fw_input_to_forget_weights_),
                    GetBuffer<const _Float16>(fw_input_to_cell_weights_),
                    GetBuffer<const _Float16>(fw_input_to_output_weights_),
                    fw_input_to_output_weights_->shape(),
                    GetOptionalBuffer<const _Float16>(fw_recurrent_to_input_weights_),
                    GetBuffer<const _Float16>(fw_recurrent_to_forget_weights_),
                    GetBuffer<const _Float16>(fw_recurrent_to_cell_weights_),
                    GetBuffer<const _Float16>(fw_recurrent_to_output_weights_),
                    fw_recurrent_to_output_weights_->shape(),
                    GetOptionalBuffer<const _Float16>(fw_cell_to_input_weights_),
                    GetOptionalBuffer<const _Float16>(fw_cell_to_forget_weights_),
                    GetOptionalBuffer<const _Float16>(fw_cell_to_output_weights_),
                    GetOptionalBuffer<const _Float16>(aux_input_),
                    GetOptionalBuffer<const _Float16>(fw_aux_input_to_input_weights_),
                    GetOptionalBuffer<const _Float16>(fw_aux_input_to_forget_weights_),
                    GetOptionalBuffer<const _Float16>(fw_aux_input_to_cell_weights_),
                    GetOptionalBuffer<const _Float16>(fw_aux_input_to_output_weights_),
                    GetOptionalBuffer<const _Float16>(fw_input_gate_bias_),
                    GetBuffer<const _Float16>(fw_forget_gate_bias_),
                    GetBuffer<const _Float16>(fw_cell_bias_),
                    GetBuffer<const _Float16>(fw_output_gate_bias_),
                    GetOptionalBuffer<const _Float16>(fw_projection_weights_),
                    GetOptionalBuffer<const _Float16>(fw_projection_bias_),
                    GetBuffer<const _Float16>(fw_activation_state_),
                    GetBuffer<const _Float16>(fw_cell_state_),
                    GetOptionalBuffer<const _Float16>(fw_input_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(fw_forget_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(fw_cell_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(fw_output_layer_norm_weights_),
                    GetBuffer<_Float16>(fw_activation_state_), GetBuffer<_Float16>(fw_cell_state_),
                    GetBuffer<_Float16>(fw_output_), fw_scratch_buffer.data(), params_.time_major,
                    kForwardSequence);

            std::vector<_Float16> bw_scratch_buffer(getNumberOfElements(bw_scratch_shape_));
            const bool kBackwardSequence = false;
            LSTMCell::LSTMEvalFloat16(
                    params_, GetBuffer<const _Float16>(input_), input_->shape(),
                    GetOptionalBuffer<const _Float16>(bw_input_to_input_weights_),
                    GetBuffer<const _Float16>(bw_input_to_forget_weights_),
                    GetBuffer<const _Float16>(bw_input_to_cell_weights_),
                    GetBuffer<const _Float16>(bw_input_to_output_weights_),
                    bw_input_to_output_weights_->shape(),
                    GetOptionalBuffer<const _Float16>(bw_recurrent_to_input_weights_),
                    GetBuffer<const _Float16>(bw_recurrent_to_forget_weights_),
                    GetBuffer<const _Float16>(bw_recurrent_to_cell_weights_),
                    GetBuffer<const _Float16>(bw_recurrent_to_output_weights_),
                    bw_recurrent_to_output_weights_->shape(),
                    GetOptionalBuffer<const _Float16>(bw_cell_to_input_weights_),
                    GetOptionalBuffer<const _Float16>(bw_cell_to_forget_weights_),
                    GetOptionalBuffer<const _Float16>(bw_cell_to_output_weights_),
                    GetOptionalBuffer<const _Float16>(aux_input_),
                    GetOptionalBuffer<const _Float16>(bw_aux_input_to_input_weights_),
                    GetOptionalBuffer<const _Float16>(bw_aux_input_to_forget_weights_),
                    GetOptionalBuffer<const _Float16>(bw_aux_input_to_cell_weights_),
                    GetOptionalBuffer<const _Float16>(bw_aux_input_to_output_weights_),
                    GetOptionalBuffer<const _Float16>(bw_input_gate_bias_),
                    GetBuffer<const _Float16>(bw_forget_gate_bias_),
                    GetBuffer<const _Float16>(bw_cell_bias_),
                    GetBuffer<const _Float16>(bw_output_gate_bias_),
                    GetOptionalBuffer<const _Float16>(bw_projection_weights_),
                    GetOptionalBuffer<const _Float16>(bw_projection_bias_),
                    GetBuffer<const _Float16>(bw_activation_state_),
                    GetBuffer<const _Float16>(bw_cell_state_),
                    GetOptionalBuffer<const _Float16>(bw_input_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(bw_forget_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(bw_cell_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(bw_output_layer_norm_weights_),
                    GetBuffer<_Float16>(bw_activation_state_), GetBuffer<_Float16>(bw_cell_state_),
                    params_.merge_outputs ? GetBuffer<_Float16>(fw_output_) + n_fw_output_elements
                                          : GetBuffer<_Float16>(bw_output_),
                    bw_scratch_buffer.data(), params_.time_major, kBackwardSequence);
            if (params_.merge_outputs) {
                std::vector<_Float16> temp(n_output_elements);
                mergeThirdDimension(GetBuffer<_Float16>(fw_output_), fw_output_dims,
                                    GetBuffer<_Float16>(fw_output_) + n_fw_output_elements,
                                    bw_output_dims, temp.data());
                std::copy(temp.data(), temp.data() + n_output_elements,
                          GetBuffer<_Float16>(fw_output_));
            }
        } break;
        default: {
            LOG(ERROR) << "Unsupported data type: " << static_cast<int>(input_->type);
            return false;
        }
    }
    return true;
}

}  // namespace nn
}  // namespace android