/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include "OperationResolver.h"
#include "RNN.h"

namespace android {
namespace nn {
namespace unidirectional_sequence_rnn {

constexpr uint32_t kNumInputs = 7;
constexpr uint32_t kInputTensor = 0;
constexpr uint32_t kWeightsTensor = 1;
constexpr uint32_t kRecurrentWeightsTensor = 2;
constexpr uint32_t kBiasTensor = 3;
constexpr uint32_t kHiddenStateTensor = 4;
constexpr uint32_t kActivationParam = 5;
constexpr uint32_t kTimeMajorParam = 6;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;

namespace {
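// transposeFirstTwoDims (below) swaps the first two dimensions of a rank-3
// tensor, converting between batch-major [batchSize, maxTime, inputSize] and
// time-major [maxTime, batchSize, inputSize] layouts. Illustrative example:
// for a [2, 3, 4] input, the element at flat index f * 12 + s * 4 + i is
// copied to flat index s * 8 + f * 4 + i of the [3, 2, 4] output.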
template <typename T>
void transposeFirstTwoDims(const T* input, const Shape& inputShape, T* output) {
    const uint32_t firstDimSize = getSizeOfDimension(inputShape, 0);
    const uint32_t secondDimSize = getSizeOfDimension(inputShape, 1);
    const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
    for (uint32_t f = 0; f < firstDimSize; ++f) {
        for (uint32_t s = 0; s < secondDimSize; ++s) {
            for (uint32_t i = 0; i < inputSize; ++i) {
                const uint32_t inputIndex = f * secondDimSize * inputSize + s * inputSize + i;
                const uint32_t outputIndex = s * firstDimSize * inputSize + f * inputSize + i;
                output[outputIndex] = input[inputIndex];
            }
        }
    }
}

template <typename T>
bool executeTyped(IOperationExecutionContext* context) {
    const T* input = context->getInputBuffer<T>(kInputTensor);
    Shape inputShape = context->getInputShape(kInputTensor);
    const T* weights = context->getInputBuffer<T>(kWeightsTensor);
    Shape weightsShape = context->getInputShape(kWeightsTensor);
    const T* recurrentWeights = context->getInputBuffer<T>(kRecurrentWeightsTensor);
    Shape recurrentWeightsShape = context->getInputShape(kRecurrentWeightsTensor);
    const T* bias = context->getInputBuffer<T>(kBiasTensor);
    const T* hiddenState = context->getInputBuffer<T>(kHiddenStateTensor);
    int32_t activation = context->getInputValue<int32_t>(kActivationParam);

    T* output = context->getOutputBuffer<T>(kOutputTensor);
    Shape outputShape = context->getOutputShape(kOutputTensor);

    int32_t timeMajor = context->getInputValue<int32_t>(kTimeMajorParam);
    // If the input tensors are not in time-major format, transpose the first
    // two dimensions and point the input and output at temporary buffers;
    // the result is transposed back after the RNN has been applied.
    std::vector<T> inputTransposed;
    std::vector<T> outputTransposed;
    if (!timeMajor) {
        // Convert input and output to time-major format.
        inputTransposed.resize(getNumberOfElements(inputShape));
        outputTransposed.resize(getNumberOfElements(outputShape));
        transposeFirstTwoDims(input, inputShape, inputTransposed.data());
        input = inputTransposed.data();
        output = outputTransposed.data();
        std::swap(inputShape.dimensions[0], inputShape.dimensions[1]);
        std::swap(outputShape.dimensions[0], outputShape.dimensions[1]);
    }

    const uint32_t maxTime = getSizeOfDimension(inputShape, 0);
    const uint32_t batchSize = getSizeOfDimension(inputShape, 1);
    const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
    const uint32_t numUnits = getSizeOfDimension(weightsShape, 0);

    // The input shape at a single timestep (time dimension removed).
    Shape fixedTimeInputShape = inputShape;
    fixedTimeInputShape.dimensions.resize(2);
    fixedTimeInputShape.dimensions[0] = inputShape.dimensions[1];
    fixedTimeInputShape.dimensions[1] = inputShape.dimensions[2];

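    // Run the RNN one timestep at a time. Each RNN::RNNStep call consumes one
    // [batchSize, inputSize] slice of the (time-major) input and writes one
    // [batchSize, numUnits] slice of the output; the output written at step t
    // becomes the hidden state fed into step t + 1.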
    for (uint32_t i = 0; i < maxTime; ++i) {
        RNN::RNNStep<T>(input, fixedTimeInputShape, hiddenState, bias, weights, weightsShape,
                        recurrentWeights, recurrentWeightsShape, activation, output);
        input += batchSize * inputSize;
        hiddenState = output;
        output += batchSize * numUnits;
    }

    if (!timeMajor) {
        transposeFirstTwoDims(outputTransposed.data(), outputShape,
                              context->getOutputBuffer<T>(kOutputTensor));
    }
    return true;
}

}  // namespace

bool validate(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    OperandType inputType = context->getInputType(kInputTensor);
    if (inputType != OperandType::TENSOR_FLOAT16 && inputType != OperandType::TENSOR_FLOAT32) {
        LOG(ERROR) << "Unsupported input operand type for UNIDIRECTIONAL_SEQUENCE_RNN op: "
                   << toString(inputType);
        return false;
    }
    NN_RET_CHECK(validateInputTypes(context, {inputType, inputType, inputType, inputType, inputType,
                                              OperandType::INT32, OperandType::INT32}));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    return validateHalVersion(context, HalVersion::V1_2);
}

bool prepare(IOperationExecutionContext* context) {
    Shape input = context->getInputShape(kInputTensor);
    Shape weights = context->getInputShape(kWeightsTensor);
    Shape recurrentWeights = context->getInputShape(kRecurrentWeightsTensor);
    Shape bias = context->getInputShape(kBiasTensor);
    Shape hiddenState = context->getInputShape(kHiddenStateTensor);

    int32_t timeMajor = context->getInputValue<int32_t>(kTimeMajorParam);
    NN_RET_CHECK(timeMajor == 0 || timeMajor == 1);
    const uint32_t batchSize =
            timeMajor ? getSizeOfDimension(input, 1) : getSizeOfDimension(input, 0);
    const uint32_t maxTime =
            timeMajor ? getSizeOfDimension(input, 0) : getSizeOfDimension(input, 1);
    const uint32_t numUnits = getSizeOfDimension(weights, 0);
    const uint32_t inputSize = getSizeOfDimension(input, 2);

    // Check tensor ranks and the shape relationships between the operands:
    // weights is [numUnits, inputSize], recurrentWeights is [numUnits, numUnits],
    // bias is [numUnits] and hiddenState is [batchSize, numUnits].
    NN_RET_CHECK_EQ(getNumberOfDimensions(input), 3);
    NN_RET_CHECK_EQ(getNumberOfDimensions(weights), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentWeights), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bias), 1);
    NN_RET_CHECK_EQ(getNumberOfDimensions(hiddenState), 2);

    NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(weights, 1));
    NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(bias, 0));
    NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(recurrentWeights, 0));
    NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(recurrentWeights, 1));
    NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(hiddenState, 0));
    NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(hiddenState, 1));

    Shape output = context->getOutputShape(kOutputTensor);
    output.dimensions[0] = timeMajor ? maxTime : batchSize;
    output.dimensions[1] = timeMajor ? batchSize : maxTime;
    output.dimensions[2] = numUnits;

    return context->setOutputShape(kOutputTensor, output);
}

bool execute(IOperationExecutionContext* context) {
    if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) {
        return executeTyped<_Float16>(context);
    } else {
        return executeTyped<float>(context);
    }
}

}  // namespace unidirectional_sequence_rnn

NN_REGISTER_OPERATION(UNIDIRECTIONAL_SEQUENCE_RNN, "UNIDIRECTIONAL_SEQUENCE_RNN",
                      unidirectional_sequence_rnn::validate, unidirectional_sequence_rnn::prepare,
                      unidirectional_sequence_rnn::execute);

}  // namespace nn
}  // namespace android