1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "RNN.h" 18 19 #include "CpuExecutor.h" 20 #include "CpuOperationUtils.h" 21 #include "HalInterfaces.h" 22 23 #include "Tracing.h" 24 25 namespace android { 26 namespace nn { 27 28 RNN::RNN(const Operation& operation, 29 std::vector<RunTimeOperandInfo>& operands) { 30 NNTRACE_TRANS("RNN::RNN"); 31 input_ = GetInput(operation, operands, kInputTensor); 32 weights_ = GetInput(operation, operands, kWeightsTensor); 33 recurrent_weights_ = GetInput(operation, operands, kRecurrentWeightsTensor); 34 hidden_state_in_ = GetInput(operation, operands, kHiddenStateInTensor); 35 bias_ = GetInput(operation, operands, kBiasTensor); 36 37 activation_ = static_cast<ActivationFn>( 38 getScalarData<int32_t>(operands[operation.inputs[kActivationParam]])); 39 40 hidden_state_out_ = GetOutput(operation, operands, kHiddenStateOutTensor); 41 output_ = GetOutput(operation, operands, kOutputTensor); 42 } 43 44 bool RNN::Prepare(const Operation &operation, 45 std::vector<RunTimeOperandInfo> &operands, 46 Shape *hiddenStateShape, 47 Shape *outputShape) { 48 NNTRACE_TRANS("RNN::Prepare"); 49 // Check we have all the inputs and outputs we need. 50 const int num_inputs = NumInputsWithValues(operation, operands); 51 NN_CHECK(num_inputs == 5 || num_inputs == 6); 52 NN_CHECK_EQ(NumOutputs(operation), 2); 53 54 const RunTimeOperandInfo *input = 55 GetInput(operation, operands, kInputTensor); 56 const RunTimeOperandInfo *input_weights = 57 GetInput(operation, operands, kWeightsTensor); 58 const RunTimeOperandInfo *recurrent_weights = 59 GetInput(operation, operands, kRecurrentWeightsTensor); 60 const RunTimeOperandInfo *bias = 61 GetInput(operation, operands, kBiasTensor); 62 63 // Check all the parameters of tensor match within themselves and match the 64 // input configuration. 65 const uint32_t batch_size = SizeOfDimension(input, 0); 66 const uint32_t num_units = SizeOfDimension(input_weights, 0); 67 NN_CHECK_EQ(SizeOfDimension(input, 1), SizeOfDimension(input_weights, 1)); 68 NN_CHECK_EQ(SizeOfDimension(input_weights, 0), SizeOfDimension(bias, 0)); 69 NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 0), SizeOfDimension(bias, 0)); 70 NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 1), SizeOfDimension(bias, 0)); 71 72 const Shape &inputShape = input->shape(); 73 74 // Resize state. 75 hiddenStateShape->type = inputShape.type; 76 hiddenStateShape->dimensions = { batch_size, num_units }; 77 78 // Resize output. 79 outputShape->type = inputShape.type; 80 outputShape->dimensions = { batch_size, num_units }; 81 82 return true; 83 } 84 85 bool RNN::Eval() { 86 switch (input_->type) { 87 case OperandType::TENSOR_FLOAT16: { 88 RNNStep<_Float16>(reinterpret_cast<_Float16*>(input_->buffer), input_->shape(), 89 reinterpret_cast<_Float16*>(hidden_state_in_->buffer), 90 reinterpret_cast<_Float16*>(bias_->buffer), 91 reinterpret_cast<_Float16*>(weights_->buffer), weights_->shape(), 92 reinterpret_cast<_Float16*>(recurrent_weights_->buffer), 93 recurrent_weights_->shape(), activation_, 94 reinterpret_cast<_Float16*>(output_->buffer)); 95 memcpy(hidden_state_out_->buffer, output_->buffer, 96 sizeof(_Float16) * getNumberOfElements(output_->shape())); 97 break; 98 } 99 case OperandType::TENSOR_FLOAT32: { 100 RNNStep<float>(reinterpret_cast<float*>(input_->buffer), input_->shape(), 101 reinterpret_cast<float*>(hidden_state_in_->buffer), 102 reinterpret_cast<float*>(bias_->buffer), 103 reinterpret_cast<float*>(weights_->buffer), weights_->shape(), 104 reinterpret_cast<float*>(recurrent_weights_->buffer), 105 recurrent_weights_->shape(), activation_, 106 reinterpret_cast<float*>(output_->buffer)); 107 memcpy(hidden_state_out_->buffer, output_->buffer, 108 sizeof(float) * getNumberOfElements(output_->shape())); 109 break; 110 } 111 default: { 112 LOG(ERROR) << "Unsupported data type: " << static_cast<int>(input_->type); 113 return false; 114 } 115 } 116 return true; 117 } 118 119 template <typename T> 120 bool RNN::RNNStep(const T* inputData, const Shape& inputShape, const T* hiddenStateInputData, 121 const T* biasData, const T* weightsData, const Shape& weightsShape, 122 const T* recurrentWeightsData, const Shape& recurrentWeightsShape, 123 const int32_t activation, T* outputData) { 124 NNTRACE_COMP("RNN::Eval"); 125 126 Shape dummyShape; 127 uint32_t numUnits = weightsShape.dimensions[0]; 128 return RNNStep<T>(inputData, inputShape, /*auxInputData=*/nullptr, /*auxInputShape=*/dummyShape, 129 hiddenStateInputData, biasData, weightsData, weightsShape, 130 /*auxWeightsData=*/nullptr, /*auxWeightsShape=*/dummyShape, 131 recurrentWeightsData, recurrentWeightsShape, activation, 132 /*outputBatchStride=*/numUnits, /*outputBatchOffset=*/0, outputData); 133 } 134 135 // A more general version of the RNNStep function. 136 // Auxiliary input is treated as if it was concatenated to a regular input and 137 // the result was multiplied by the weights matrix which was also concatenated 138 // with auxiliary weights. 139 template <typename T> 140 bool RNN::RNNStep(const T* inputData, const Shape& inputShape, const T* auxInputData, 141 const Shape& auxInputShape, const T* hiddenStateInputData, const T* biasData, 142 const T* weightsData, const Shape& weightsShape, const T* auxWeightsData, 143 const Shape& auxWeightsShape, const T* recurrentWeightsData, 144 const Shape& recurrentWeightsShape, const int32_t activation, 145 const uint32_t outputBatchStride, const uint32_t outputBatchOffset, T* outputData, 146 T* hiddenStateOutput) { 147 NNTRACE_COMP("RNN::Eval"); 148 149 const uint32_t batch_size = inputShape.dimensions[0]; 150 const uint32_t num_units = weightsShape.dimensions[0]; 151 const uint32_t input_size = inputShape.dimensions[1]; 152 const uint32_t input_weights_stride = weightsShape.dimensions[1]; 153 const uint32_t recurrent_weights_stride = recurrentWeightsShape.dimensions[1]; 154 155 uint32_t aux_input_size = 0; 156 uint32_t aux_input_weights_stride = 0; 157 bool hasAuxInput = (auxInputData != nullptr); 158 if (hasAuxInput) { 159 aux_input_size = auxInputShape.dimensions[1]; 160 aux_input_weights_stride = auxWeightsShape.dimensions[1]; 161 } 162 163 // For each batch 164 for (uint32_t b = 0; b < batch_size; b++) { 165 // Initialize the pointer to input, output and bias. 166 const T* input_ptr_batch = inputData + b * input_size; 167 const T* hidden_state_in_ptr_batch = hiddenStateInputData + b * num_units; 168 const T* aux_input_ptr_batch = nullptr; 169 if (hasAuxInput) { 170 aux_input_ptr_batch = auxInputData + b * aux_input_size; 171 } 172 T* output_ptr_batch = outputData + b * outputBatchStride + outputBatchOffset; 173 174 // Initialize input_weights and recurrent_weights. 175 const T* input_weights_ptr = weightsData; 176 const T* recurrent_weights_ptr = recurrentWeightsData; 177 const T* aux_input_weights_ptr = nullptr; 178 if (hasAuxInput) { 179 aux_input_weights_ptr = auxWeightsData; 180 } 181 182 // Output = bias 183 for (uint32_t o = 0; o < num_units; o++) { 184 output_ptr_batch[o] = biasData[o]; 185 } 186 187 // Output += input * input_weights 188 for (uint32_t o = 0; o < num_units; o++) { 189 for (uint32_t i = 0; i < input_size; i++) { 190 output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i]; 191 } 192 input_weights_ptr += input_weights_stride; 193 } 194 195 if (hasAuxInput) { 196 // Output += aux_input * aux_input_weights 197 for (uint32_t o = 0; o < num_units; o++) { 198 for (uint32_t i = 0; i < input_size; i++) { 199 output_ptr_batch[o] += aux_input_ptr_batch[i] * aux_input_weights_ptr[i]; 200 } 201 aux_input_weights_ptr += aux_input_weights_stride; 202 } 203 } 204 205 // Output += recurrent_weights * hidden_state 206 for (uint32_t o = 0; o < num_units; o++) { 207 for (uint32_t h = 0; h < num_units; h++) { 208 output_ptr_batch[o] += hidden_state_in_ptr_batch[h] * recurrent_weights_ptr[h]; 209 } 210 recurrent_weights_ptr += recurrent_weights_stride; 211 } 212 213 // Output = activation(Output) 214 for (uint32_t o = 0; o < num_units; o++) { 215 output_ptr_batch[o] = 216 (ActivationFunctor(static_cast<ActivationFn>(activation)))(output_ptr_batch[o]); 217 if (hiddenStateOutput != nullptr) { 218 *hiddenStateOutput = output_ptr_batch[o]; 219 ++hiddenStateOutput; 220 } 221 } 222 } 223 224 return true; 225 } 226 227 } // namespace nn 228 } // namespace android 229