1 /* 2 * Copyright (C) 2019 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #define LOG_TAG "Operations" 18 19 #include "OperationResolver.h" 20 #include "RNN.h" 21 22 namespace android { 23 namespace nn { 24 namespace bidirectional_sequence_rnn { 25 26 constexpr uint32_t kNumInputs = 15; 27 constexpr uint32_t kInputTensor = 0; 28 // Forward cell tensors 29 constexpr uint32_t kFwWeightsTensor = 1; 30 constexpr uint32_t kFwRecurrentWeightsTensor = 2; 31 constexpr uint32_t kFwBiasTensor = 3; 32 constexpr uint32_t kFwHiddenStateTensor = 4; 33 // Backward cell tensors 34 constexpr uint32_t kBwWeightsTensor = 5; 35 constexpr uint32_t kBwRecurrentWeightsTensor = 6; 36 constexpr uint32_t kBwBiasTensor = 7; 37 constexpr uint32_t kBwHiddenStateTensor = 8; 38 // Auxiliary inputs 39 constexpr uint32_t kAuxInputTensor = 9; // optional 40 constexpr uint32_t kFwAuxWeightsTensor = 10; // optional 41 constexpr uint32_t kBwAuxWeightsTensor = 11; // optional 42 // Cell parameters 43 constexpr uint32_t kActivationParam = 12; 44 constexpr uint32_t kTimeMajorParam = 13; 45 constexpr uint32_t kMergeOutputsParam = 14; 46 47 constexpr uint32_t kFwOutputTensor = 0; 48 constexpr uint32_t kBwOutputTensor = 1; // Only if mergeOutputs parameter is false 49 50 namespace { 51 52 template <typename T> 53 void transposeFirstTwoDims(const T* input, const Shape& inputShape, T* output) { 54 const uint32_t firstDimSize = getSizeOfDimension(inputShape, 0); 55 const uint32_t secondDimSize = getSizeOfDimension(inputShape, 1); 56 const uint32_t inputSize = getSizeOfDimension(inputShape, 2); 57 for (int f = 0; f < firstDimSize; ++f) { 58 for (int s = 0; s < secondDimSize; ++s) { 59 for (int i = 0; i < inputSize; ++i) { 60 const uint32_t inputIndex = f * secondDimSize * inputSize + s * inputSize + i; 61 const uint32_t outputIndex = s * firstDimSize * inputSize + f * inputSize + i; 62 output[outputIndex] = input[inputIndex]; 63 } 64 } 65 } 66 } 67 68 Shape removeFirstDim(const Shape& input) { 69 Shape output = input; 70 output.dimensions.resize(input.dimensions.size() - 1); 71 for (int i = 0; i < input.dimensions.size() - 1; ++i) { 72 output.dimensions[i] = input.dimensions[i + 1]; 73 } 74 return output; 75 } 76 77 template <typename T> 78 bool executeTyped(IOperationExecutionContext* context) { 79 const T* input = context->getInputBuffer<T>(kInputTensor); 80 Shape inputShape = context->getInputShape(kInputTensor); 81 82 const T* fwWeights = context->getInputBuffer<T>(kFwWeightsTensor); 83 Shape fwWeightsShape = context->getInputShape(kFwWeightsTensor); 84 const T* fwRecurrentWeights = context->getInputBuffer<T>(kFwRecurrentWeightsTensor); 85 Shape fwRecurrentWeightsShape = context->getInputShape(kFwRecurrentWeightsTensor); 86 const T* fwBias = context->getInputBuffer<T>(kFwBiasTensor); 87 const T* fwHiddenState = context->getInputBuffer<T>(kFwHiddenStateTensor); 88 89 const T* bwWeights = context->getInputBuffer<T>(kBwWeightsTensor); 90 Shape bwWeightsShape = context->getInputShape(kBwWeightsTensor); 91 const T* bwRecurrentWeights = context->getInputBuffer<T>(kBwRecurrentWeightsTensor); 92 Shape bwRecurrentWeightsShape = context->getInputShape(kBwRecurrentWeightsTensor); 93 const T* bwBias = context->getInputBuffer<T>(kBwBiasTensor); 94 const T* bwHiddenState = context->getInputBuffer<T>(kBwHiddenStateTensor); 95 96 const T* auxInput = nullptr; 97 const T* fwAuxWeights = nullptr; 98 const T* bwAuxWeights = nullptr; 99 const bool hasAuxInputs = !context->isOmittedInput(kAuxInputTensor); 100 if (hasAuxInputs) { 101 auxInput = context->getInputBuffer<T>(kAuxInputTensor); 102 fwAuxWeights = context->getInputBuffer<T>(kFwAuxWeightsTensor); 103 bwAuxWeights = context->getInputBuffer<T>(kBwAuxWeightsTensor); 104 } 105 Shape auxInputShape = context->getInputShape(kAuxInputTensor); 106 Shape fwAuxWeightsShape = context->getInputShape(kFwAuxWeightsTensor); 107 Shape bwAuxWeightsShape = context->getInputShape(kBwAuxWeightsTensor); 108 109 int32_t activation = context->getInputValue<int32_t>(kActivationParam); 110 int32_t timeMajor = context->getInputValue<bool>(kTimeMajorParam); 111 int32_t mergeOutputs = context->getInputValue<bool>(kMergeOutputsParam); 112 113 T* fwOutput = context->getOutputBuffer<T>(kFwOutputTensor); 114 Shape fwOutputShape = context->getOutputShape(kFwOutputTensor); 115 T* bwOutput = nullptr; 116 Shape bwOutputShape; 117 if (!mergeOutputs) { 118 bwOutputShape = context->getOutputShape(kBwOutputTensor); 119 bwOutput = context->getOutputBuffer<T>(kBwOutputTensor); 120 } 121 122 // If the input tensors are not in time major format, we transpose the first 123 // two dimensions, and set input and output pointers to temporary vectors 124 // which are transposed back after the RNN is applied. 125 std::vector<T> inputTransposed; 126 std::vector<T> auxInputTransposed; 127 std::vector<T> fwOutputTransposed; 128 std::vector<T> bwOutputTransposed; 129 if (!timeMajor) { 130 // First, resize temporary buffers to accommodate for transposed tensors. 131 inputTransposed.resize(getNumberOfElements(inputShape)); 132 if (hasAuxInputs) { 133 auxInputTransposed.resize(getNumberOfElements(auxInputShape)); 134 } 135 fwOutputTransposed.resize(getNumberOfElements(fwOutputShape)); 136 if (!mergeOutputs) { 137 bwOutputTransposed.resize(getNumberOfElements(bwOutputShape)); 138 } 139 140 // Transpose the input tensors. 141 transposeFirstTwoDims(input, inputShape, inputTransposed.data()); 142 if (hasAuxInputs) { 143 transposeFirstTwoDims(auxInput, auxInputShape, auxInputTransposed.data()); 144 } 145 146 // Change input and output pointers to the temporary buffers. 147 input = inputTransposed.data(); 148 if (hasAuxInputs) { 149 auxInput = auxInputTransposed.data(); 150 } 151 fwOutput = fwOutputTransposed.data(); 152 if (!mergeOutputs) { 153 bwOutput = bwOutputTransposed.data(); 154 } 155 156 // Swap the first two dimensions in the Shapes to reflect the 157 // transposition. 158 std::swap(inputShape.dimensions[0], inputShape.dimensions[1]); 159 if (hasAuxInputs) { 160 std::swap(auxInputShape.dimensions[0], auxInputShape.dimensions[1]); 161 } 162 std::swap(fwOutputShape.dimensions[0], fwOutputShape.dimensions[1]); 163 if (!mergeOutputs) { 164 std::swap(bwOutputShape.dimensions[0], bwOutputShape.dimensions[1]); 165 } 166 } 167 168 const uint32_t maxTime = getSizeOfDimension(inputShape, 0); 169 const uint32_t batchSize = getSizeOfDimension(inputShape, 1); 170 const uint32_t inputSize = getSizeOfDimension(inputShape, 2); 171 uint32_t auxInputSize = 0; 172 if (hasAuxInputs) { 173 auxInputSize = getSizeOfDimension(auxInputShape, 2); 174 } 175 const uint32_t fwNumUnits = getSizeOfDimension(fwWeightsShape, 0); 176 const uint32_t bwNumUnits = getSizeOfDimension(bwWeightsShape, 0); 177 178 Shape fixedTimeInputShape = removeFirstDim(inputShape); 179 Shape fixedTimeAuxInputShape = auxInputShape; 180 if (hasAuxInputs) { 181 fixedTimeAuxInputShape = removeFirstDim(auxInputShape); 182 } 183 184 // Create an additional buffer to store a hidden state between steps. 185 std::vector<T> tempHiddenState(batchSize * fwNumUnits); 186 // Forward pass 187 for (int i = 0; i < maxTime; ++i) { 188 const T* inputBatchPtr = input + i * batchSize * inputSize; 189 const T* auxInputBatchPtr = nullptr; 190 if (hasAuxInputs) { 191 auxInputBatchPtr = auxInput + i * batchSize * auxInputSize; 192 } 193 const uint32_t fwOutputBatchStride = mergeOutputs ? (fwNumUnits + bwNumUnits) : fwNumUnits; 194 T* fwOutputBatchPtr = fwOutput + i * batchSize * fwOutputBatchStride; 195 196 RNN::RNNStep<T>(inputBatchPtr, fixedTimeInputShape, auxInputBatchPtr, 197 fixedTimeAuxInputShape, fwHiddenState, fwBias, fwWeights, fwWeightsShape, 198 fwAuxWeights, fwAuxWeightsShape, fwRecurrentWeights, 199 fwRecurrentWeightsShape, activation, fwOutputBatchStride, 200 /*outputBatchOffset=*/0, fwOutputBatchPtr, tempHiddenState.data()); 201 202 fwHiddenState = tempHiddenState.data(); 203 } 204 205 tempHiddenState.resize(batchSize * bwNumUnits); 206 // Backward pass 207 for (int i = maxTime - 1; i >= 0; --i) { 208 const T* inputBatchPtr = input + i * batchSize * inputSize; 209 const T* auxInputBatchPtr = nullptr; 210 if (hasAuxInputs) { 211 auxInputBatchPtr = auxInput + i * batchSize * auxInputSize; 212 } 213 T* bwOutputBatchPtr; 214 uint32_t bwOutputBatchOffset = 0; 215 uint32_t bwOutputBatchStride; 216 if (mergeOutputs) { 217 bwOutputBatchStride = fwNumUnits + bwNumUnits; 218 bwOutputBatchOffset = fwNumUnits; 219 bwOutputBatchPtr = fwOutput + i * batchSize * bwOutputBatchStride; 220 } else { 221 bwOutputBatchStride = bwNumUnits; 222 bwOutputBatchPtr = bwOutput + i * batchSize * bwOutputBatchStride; 223 } 224 225 RNN::RNNStep<T>(inputBatchPtr, fixedTimeInputShape, auxInputBatchPtr, 226 fixedTimeAuxInputShape, bwHiddenState, bwBias, bwWeights, bwWeightsShape, 227 bwAuxWeights, bwAuxWeightsShape, bwRecurrentWeights, 228 bwRecurrentWeightsShape, activation, bwOutputBatchStride, 229 bwOutputBatchOffset, bwOutputBatchPtr, tempHiddenState.data()); 230 231 bwHiddenState = tempHiddenState.data(); 232 } 233 234 // If the inputs were in batch major format, transpose data in temporary 235 // buffers and write to the output(s). 236 if (!timeMajor) { 237 transposeFirstTwoDims(fwOutputTransposed.data(), fwOutputShape, 238 context->getOutputBuffer<T>(kFwOutputTensor)); 239 if (!mergeOutputs) { 240 transposeFirstTwoDims(bwOutputTransposed.data(), bwOutputShape, 241 context->getOutputBuffer<T>(kBwOutputTensor)); 242 } 243 } 244 return true; 245 } 246 247 } // namespace 248 249 bool validate(const IOperationValidationContext* context) { 250 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); 251 // Exact number is dependent on the mergeOutputs parameter and checked 252 // during preparation. 253 NN_RET_CHECK(context->getNumOutputs() == 1 || context->getNumOutputs() == 2); 254 OperandType inputType = context->getInputType(kInputTensor); 255 if (inputType != OperandType::TENSOR_FLOAT16 && inputType != OperandType::TENSOR_FLOAT32) { 256 LOG(ERROR) << "Unsupported input operand type for UNIDIRECTIONAL_SEQUENCE_RNN op: " 257 << toString(inputType); 258 return false; 259 } 260 NN_RET_CHECK(validateInputTypes( 261 context, {inputType, inputType, inputType, inputType, inputType, inputType, inputType, 262 inputType, inputType, inputType, inputType, inputType, OperandType::INT32, 263 OperandType::BOOL, OperandType::BOOL})); 264 if (context->getNumOutputs() == 1) { 265 NN_RET_CHECK(validateOutputTypes(context, {inputType})); 266 } else { 267 NN_RET_CHECK(validateOutputTypes(context, {inputType, inputType})); 268 } 269 return validateHalVersion(context, HalVersion::V1_2); 270 } 271 272 bool prepare(IOperationExecutionContext* context) { 273 int32_t mergeOutputs = context->getInputValue<bool>(kMergeOutputsParam); 274 if (mergeOutputs) { 275 NN_RET_CHECK_EQ(context->getNumOutputs(), 1); 276 } else { 277 NN_RET_CHECK_EQ(context->getNumOutputs(), 2); 278 } 279 280 // Check that none of the required inputs are omitted. 281 const std::vector<int> requiredInputs = { 282 kInputTensor, kFwWeightsTensor, kFwRecurrentWeightsTensor, kFwBiasTensor, 283 kFwHiddenStateTensor, kBwWeightsTensor, kBwRecurrentWeightsTensor, kBwBiasTensor, 284 kBwHiddenStateTensor, kActivationParam, kTimeMajorParam, kMergeOutputsParam, 285 }; 286 for (const int requiredInput : requiredInputs) { 287 NN_RET_CHECK(!context->isOmittedInput(requiredInput)) 288 << "required input " << requiredInput << " is omitted"; 289 } 290 291 Shape input = context->getInputShape(kInputTensor); 292 Shape fwWeights = context->getInputShape(kFwWeightsTensor); 293 Shape fwRecurrentWeights = context->getInputShape(kFwRecurrentWeightsTensor); 294 Shape fwBias = context->getInputShape(kFwBiasTensor); 295 Shape fwHiddenState = context->getInputShape(kFwHiddenStateTensor); 296 Shape bwWeights = context->getInputShape(kBwWeightsTensor); 297 Shape bwRecurrentWeights = context->getInputShape(kBwRecurrentWeightsTensor); 298 Shape bwBias = context->getInputShape(kBwBiasTensor); 299 Shape bwHiddenState = context->getInputShape(kBwHiddenStateTensor); 300 301 Shape auxInput = context->getInputShape(kAuxInputTensor); 302 Shape fwAuxWeights = context->getInputShape(kFwAuxWeightsTensor); 303 Shape bwAuxWeights = context->getInputShape(kBwAuxWeightsTensor); 304 305 const bool auxInputsAllOrNone = (context->isOmittedInput(kAuxInputTensor) && 306 context->isOmittedInput(kFwAuxWeightsTensor) && 307 context->isOmittedInput(kBwAuxWeightsTensor)) || 308 (!context->isOmittedInput(kAuxInputTensor) && 309 !context->isOmittedInput(kFwAuxWeightsTensor) && 310 !context->isOmittedInput(kBwAuxWeightsTensor)); 311 NN_RET_CHECK(auxInputsAllOrNone); 312 const bool hasAuxInputs = !context->isOmittedInput(kAuxInputTensor); 313 314 int32_t timeMajor = context->getInputValue<bool>(kTimeMajorParam); 315 const uint32_t batchSize = 316 timeMajor ? getSizeOfDimension(input, 1) : getSizeOfDimension(input, 0); 317 const uint32_t maxTime = 318 timeMajor ? getSizeOfDimension(input, 0) : getSizeOfDimension(input, 1); 319 const uint32_t fwNumUnits = getSizeOfDimension(fwWeights, 0); 320 const uint32_t bwNumUnits = getSizeOfDimension(bwWeights, 0); 321 const uint32_t inputSize = getSizeOfDimension(input, 2); 322 323 NN_RET_CHECK_EQ(getNumberOfDimensions(input), 3); 324 NN_RET_CHECK_EQ(getNumberOfDimensions(fwWeights), 2); 325 NN_RET_CHECK_EQ(getNumberOfDimensions(fwRecurrentWeights), 2); 326 NN_RET_CHECK_EQ(getNumberOfDimensions(fwBias), 1); 327 NN_RET_CHECK_EQ(getNumberOfDimensions(fwHiddenState), 2); 328 NN_RET_CHECK_EQ(getNumberOfDimensions(bwWeights), 2); 329 NN_RET_CHECK_EQ(getNumberOfDimensions(bwRecurrentWeights), 2); 330 NN_RET_CHECK_EQ(getNumberOfDimensions(bwBias), 1); 331 NN_RET_CHECK_EQ(getNumberOfDimensions(bwHiddenState), 2); 332 333 NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(fwWeights, 1)); 334 NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwBias, 0)); 335 NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwRecurrentWeights, 0)); 336 NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwRecurrentWeights, 1)); 337 NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(fwHiddenState, 0)); 338 NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwHiddenState, 1)); 339 340 NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(bwWeights, 1)); 341 NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwBias, 0)); 342 NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwRecurrentWeights, 0)); 343 NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwRecurrentWeights, 1)); 344 NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(bwHiddenState, 0)); 345 NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwHiddenState, 1)); 346 347 if (hasAuxInputs) { 348 NN_RET_CHECK_EQ(getNumberOfDimensions(auxInput), 3); 349 NN_RET_CHECK_EQ(getNumberOfDimensions(fwAuxWeights), 2); 350 NN_RET_CHECK_EQ(getNumberOfDimensions(bwAuxWeights), 2); 351 352 NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 0), getSizeOfDimension(input, 0)); 353 NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 1), getSizeOfDimension(input, 1)); 354 NN_RET_CHECK_EQ(getSizeOfDimension(fwAuxWeights, 0), fwNumUnits); 355 NN_RET_CHECK_EQ(getSizeOfDimension(fwAuxWeights, 1), getSizeOfDimension(auxInput, 2)); 356 NN_RET_CHECK_EQ(getSizeOfDimension(bwAuxWeights, 0), bwNumUnits); 357 NN_RET_CHECK_EQ(getSizeOfDimension(bwAuxWeights, 1), getSizeOfDimension(auxInput, 2)); 358 } 359 360 Shape fwOutput = context->getOutputShape(kFwOutputTensor); 361 fwOutput.dimensions.resize(3); 362 fwOutput.dimensions[0] = timeMajor ? maxTime : batchSize; 363 fwOutput.dimensions[1] = timeMajor ? batchSize : maxTime; 364 fwOutput.dimensions[2] = mergeOutputs ? fwNumUnits + bwNumUnits : fwNumUnits; 365 NN_RET_CHECK(context->setOutputShape(kFwOutputTensor, fwOutput)); 366 if (!mergeOutputs) { 367 Shape bwOutput = context->getOutputShape(kBwOutputTensor); 368 bwOutput.dimensions.resize(3); 369 bwOutput.dimensions[0] = timeMajor ? maxTime : batchSize; 370 bwOutput.dimensions[1] = timeMajor ? batchSize : maxTime; 371 bwOutput.dimensions[2] = bwNumUnits; 372 NN_RET_CHECK(context->setOutputShape(kBwOutputTensor, bwOutput)); 373 } 374 375 return true; 376 } 377 378 bool execute(IOperationExecutionContext* context) { 379 if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) { 380 executeTyped<_Float16>(context); 381 } else { 382 executeTyped<float>(context); 383 } 384 return true; 385 } 386 387 } // namespace bidirectional_sequence_rnn 388 389 NN_REGISTER_OPERATION(BIDIRECTIONAL_SEQUENCE_RNN, "BIDIRECTIONAL_SEQUENCE_RNN", 390 bidirectional_sequence_rnn::validate, bidirectional_sequence_rnn::prepare, 391 bidirectional_sequence_rnn::execute, .allowOmittedOperand = true); 392 393 } // namespace nn 394 } // namespace android 395