Home | History | Annotate | Download | only in operations
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "RNN.h"
     18 
     19 #include "CpuExecutor.h"
     20 #include "CpuOperationUtils.h"
     21 #include "HalInterfaces.h"
     22 
     23 #include "Tracing.h"
     24 
     25 namespace android {
     26 namespace nn {
     27 
     28 RNN::RNN(const Operation& operation,
     29          std::vector<RunTimeOperandInfo>& operands) {
     30   NNTRACE_TRANS("RNN::RNN");
     31   input_ = GetInput(operation, operands, kInputTensor);
     32   weights_ = GetInput(operation, operands, kWeightsTensor);
     33   recurrent_weights_ = GetInput(operation, operands, kRecurrentWeightsTensor);
     34   hidden_state_in_ = GetInput(operation, operands, kHiddenStateInTensor);
     35   bias_ = GetInput(operation, operands, kBiasTensor);
     36 
     37   activation_ = static_cast<ActivationFn>(
     38       getScalarData<int32_t>(operands[operation.inputs[kActivationParam]]));
     39 
     40   hidden_state_out_ = GetOutput(operation, operands, kHiddenStateOutTensor);
     41   output_ = GetOutput(operation, operands, kOutputTensor);
     42 }
     43 
     44 bool RNN::Prepare(const Operation &operation,
     45                   std::vector<RunTimeOperandInfo> &operands,
     46                   Shape *hiddenStateShape,
     47                   Shape *outputShape) {
     48   NNTRACE_TRANS("RNN::Prepare");
     49   // Check we have all the inputs and outputs we need.
     50   const int num_inputs = NumInputsWithValues(operation, operands);
     51   NN_CHECK(num_inputs == 5 || num_inputs == 6);
     52   NN_CHECK_EQ(NumOutputs(operation), 2);
     53 
     54   const RunTimeOperandInfo *input =
     55       GetInput(operation, operands, kInputTensor);
     56   const RunTimeOperandInfo *input_weights =
     57       GetInput(operation, operands, kWeightsTensor);
     58   const RunTimeOperandInfo *recurrent_weights =
     59       GetInput(operation, operands, kRecurrentWeightsTensor);
     60   const RunTimeOperandInfo *bias =
     61       GetInput(operation, operands, kBiasTensor);
     62 
     63   // Check all the parameters of tensor match within themselves and match the
     64   // input configuration.
     65   const uint32_t batch_size = SizeOfDimension(input, 0);
     66   const uint32_t num_units = SizeOfDimension(input_weights, 0);
     67   NN_CHECK_EQ(SizeOfDimension(input, 1), SizeOfDimension(input_weights, 1));
     68   NN_CHECK_EQ(SizeOfDimension(input_weights, 0), SizeOfDimension(bias, 0));
     69   NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 0), SizeOfDimension(bias, 0));
     70   NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 1), SizeOfDimension(bias, 0));
     71 
     72   const Shape &inputShape = input->shape();
     73 
     74   // Resize state.
     75   hiddenStateShape->type = inputShape.type;
     76   hiddenStateShape->dimensions = { batch_size, num_units };
     77 
     78   // Resize output.
     79   outputShape->type = inputShape.type;
     80   outputShape->dimensions = { batch_size, num_units };
     81 
     82   return true;
     83 }
     84 
     85 bool RNN::Eval() {
     86     switch (input_->type) {
     87         case OperandType::TENSOR_FLOAT16: {
     88             RNNStep<_Float16>(reinterpret_cast<_Float16*>(input_->buffer), input_->shape(),
     89                               reinterpret_cast<_Float16*>(hidden_state_in_->buffer),
     90                               reinterpret_cast<_Float16*>(bias_->buffer),
     91                               reinterpret_cast<_Float16*>(weights_->buffer), weights_->shape(),
     92                               reinterpret_cast<_Float16*>(recurrent_weights_->buffer),
     93                               recurrent_weights_->shape(), activation_,
     94                               reinterpret_cast<_Float16*>(output_->buffer));
     95             memcpy(hidden_state_out_->buffer, output_->buffer,
     96                    sizeof(_Float16) * getNumberOfElements(output_->shape()));
     97             break;
     98         }
     99         case OperandType::TENSOR_FLOAT32: {
    100             RNNStep<float>(reinterpret_cast<float*>(input_->buffer), input_->shape(),
    101                            reinterpret_cast<float*>(hidden_state_in_->buffer),
    102                            reinterpret_cast<float*>(bias_->buffer),
    103                            reinterpret_cast<float*>(weights_->buffer), weights_->shape(),
    104                            reinterpret_cast<float*>(recurrent_weights_->buffer),
    105                            recurrent_weights_->shape(), activation_,
    106                            reinterpret_cast<float*>(output_->buffer));
    107             memcpy(hidden_state_out_->buffer, output_->buffer,
    108                    sizeof(float) * getNumberOfElements(output_->shape()));
    109             break;
    110         }
    111         default: {
    112             LOG(ERROR) << "Unsupported data type: " << static_cast<int>(input_->type);
    113             return false;
    114         }
    115     }
    116     return true;
    117 }
    118 
    119 template <typename T>
    120 bool RNN::RNNStep(const T* inputData, const Shape& inputShape, const T* hiddenStateInputData,
    121                   const T* biasData, const T* weightsData, const Shape& weightsShape,
    122                   const T* recurrentWeightsData, const Shape& recurrentWeightsShape,
    123                   const int32_t activation, T* outputData) {
    124     NNTRACE_COMP("RNN::Eval");
    125 
    126     Shape dummyShape;
    127     uint32_t numUnits = weightsShape.dimensions[0];
    128     return RNNStep<T>(inputData, inputShape, /*auxInputData=*/nullptr, /*auxInputShape=*/dummyShape,
    129                       hiddenStateInputData, biasData, weightsData, weightsShape,
    130                       /*auxWeightsData=*/nullptr, /*auxWeightsShape=*/dummyShape,
    131                       recurrentWeightsData, recurrentWeightsShape, activation,
    132                       /*outputBatchStride=*/numUnits, /*outputBatchOffset=*/0, outputData);
    133 }
    134 
    135 // A more general version of the RNNStep function.
    136 // Auxiliary input is treated as if it was concatenated to a regular input and
    137 // the result was multiplied by the weights matrix which was also concatenated
    138 // with auxiliary weights.
    139 template <typename T>
    140 bool RNN::RNNStep(const T* inputData, const Shape& inputShape, const T* auxInputData,
    141                   const Shape& auxInputShape, const T* hiddenStateInputData, const T* biasData,
    142                   const T* weightsData, const Shape& weightsShape, const T* auxWeightsData,
    143                   const Shape& auxWeightsShape, const T* recurrentWeightsData,
    144                   const Shape& recurrentWeightsShape, const int32_t activation,
    145                   const uint32_t outputBatchStride, const uint32_t outputBatchOffset, T* outputData,
    146                   T* hiddenStateOutput) {
    147     NNTRACE_COMP("RNN::Eval");
    148 
    149     const uint32_t batch_size = inputShape.dimensions[0];
    150     const uint32_t num_units = weightsShape.dimensions[0];
    151     const uint32_t input_size = inputShape.dimensions[1];
    152     const uint32_t input_weights_stride = weightsShape.dimensions[1];
    153     const uint32_t recurrent_weights_stride = recurrentWeightsShape.dimensions[1];
    154 
    155     uint32_t aux_input_size = 0;
    156     uint32_t aux_input_weights_stride = 0;
    157     bool hasAuxInput = (auxInputData != nullptr);
    158     if (hasAuxInput) {
    159         aux_input_size = auxInputShape.dimensions[1];
    160         aux_input_weights_stride = auxWeightsShape.dimensions[1];
    161     }
    162 
    163     // For each batch
    164     for (uint32_t b = 0; b < batch_size; b++) {
    165         // Initialize the pointer to input, output and bias.
    166         const T* input_ptr_batch = inputData + b * input_size;
    167         const T* hidden_state_in_ptr_batch = hiddenStateInputData + b * num_units;
    168         const T* aux_input_ptr_batch = nullptr;
    169         if (hasAuxInput) {
    170             aux_input_ptr_batch = auxInputData + b * aux_input_size;
    171         }
    172         T* output_ptr_batch = outputData + b * outputBatchStride + outputBatchOffset;
    173 
    174         // Initialize input_weights and recurrent_weights.
    175         const T* input_weights_ptr = weightsData;
    176         const T* recurrent_weights_ptr = recurrentWeightsData;
    177         const T* aux_input_weights_ptr = nullptr;
    178         if (hasAuxInput) {
    179             aux_input_weights_ptr = auxWeightsData;
    180         }
    181 
    182         // Output = bias
    183         for (uint32_t o = 0; o < num_units; o++) {
    184             output_ptr_batch[o] = biasData[o];
    185         }
    186 
    187         // Output += input * input_weights
    188         for (uint32_t o = 0; o < num_units; o++) {
    189             for (uint32_t i = 0; i < input_size; i++) {
    190                 output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i];
    191             }
    192             input_weights_ptr += input_weights_stride;
    193         }
    194 
    195         if (hasAuxInput) {
    196             // Output += aux_input * aux_input_weights
    197             for (uint32_t o = 0; o < num_units; o++) {
    198                 for (uint32_t i = 0; i < input_size; i++) {
    199                     output_ptr_batch[o] += aux_input_ptr_batch[i] * aux_input_weights_ptr[i];
    200                 }
    201                 aux_input_weights_ptr += aux_input_weights_stride;
    202             }
    203         }
    204 
    205         // Output += recurrent_weights * hidden_state
    206         for (uint32_t o = 0; o < num_units; o++) {
    207             for (uint32_t h = 0; h < num_units; h++) {
    208                 output_ptr_batch[o] += hidden_state_in_ptr_batch[h] * recurrent_weights_ptr[h];
    209             }
    210             recurrent_weights_ptr += recurrent_weights_stride;
    211         }
    212 
    213         // Output = activation(Output)
    214         for (uint32_t o = 0; o < num_units; o++) {
    215             output_ptr_batch[o] =
    216                     (ActivationFunctor(static_cast<ActivationFn>(activation)))(output_ptr_batch[o]);
    217             if (hiddenStateOutput != nullptr) {
    218                 *hiddenStateOutput = output_ptr_batch[o];
    219                 ++hiddenStateOutput;
    220             }
    221         }
    222     }
    223 
    224     return true;
    225 }
    226 
    227 }  // namespace nn
    228 }  // namespace android
    229