/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include <utility>
#include <vector>

#include "OperationResolver.h"
#include "RNN.h"

namespace android {
namespace nn {
namespace unidirectional_sequence_rnn {

constexpr uint32_t kNumInputs = 7;
constexpr uint32_t kInputTensor = 0;
constexpr uint32_t kWeightsTensor = 1;
constexpr uint32_t kRecurrentWeightsTensor = 2;
constexpr uint32_t kBiasTensor = 3;
constexpr uint32_t kHiddenStateTensor = 4;
constexpr uint32_t kActivationParam = 5;
constexpr uint32_t kTimeMajorParam = 6;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;
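
// Operand shapes, as checked in prepare() below:
//   input:            [maxTime, batchSize, inputSize] if time-major, else
//                      [batchSize, maxTime, inputSize]
//   weights:          [numUnits, inputSize]
//   recurrentWeights: [numUnits, numUnits]
//   bias:             [numUnits]
//   hiddenState:      [batchSize, numUnits]
//   output:           same batch/time layout as the input, with the last
//                      dimension equal to numUnits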

namespace {

template <typename T>
void transposeFirstTwoDims(const T* input, const Shape& inputShape, T* output) {
    const uint32_t firstDimSize = getSizeOfDimension(inputShape, 0);
    const uint32_t secondDimSize = getSizeOfDimension(inputShape, 1);
    const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
    for (uint32_t f = 0; f < firstDimSize; ++f) {
        for (uint32_t s = 0; s < secondDimSize; ++s) {
            for (uint32_t i = 0; i < inputSize; ++i) {
                const uint32_t inputIndex = f * secondDimSize * inputSize + s * inputSize + i;
                const uint32_t outputIndex = s * firstDimSize * inputSize + f * inputSize + i;
                output[outputIndex] = input[inputIndex];
            }
        }
    }
}
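
// Illustration of transposeFirstTwoDims: for an input of shape [2, 3, inputSize]
// stored row-major as {in[0][0], in[0][1], in[0][2], in[1][0], in[1][1], in[1][2]}
// (each entry a vector of inputSize elements), the output has shape
// [3, 2, inputSize] and is laid out as
// {in[0][0], in[1][0], in[0][1], in[1][1], in[0][2], in[1][2]}, i.e. the
// batch-major <-> time-major conversion used by executeTyped() below.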

template <typename T>
bool executeTyped(IOperationExecutionContext* context) {
    const T* input = context->getInputBuffer<T>(kInputTensor);
    Shape inputShape = context->getInputShape(kInputTensor);
    const T* weights = context->getInputBuffer<T>(kWeightsTensor);
    Shape weightsShape = context->getInputShape(kWeightsTensor);
    const T* recurrentWeights = context->getInputBuffer<T>(kRecurrentWeightsTensor);
    Shape recurrentWeightsShape = context->getInputShape(kRecurrentWeightsTensor);
    const T* bias = context->getInputBuffer<T>(kBiasTensor);
    const T* hiddenState = context->getInputBuffer<T>(kHiddenStateTensor);
    int32_t activation = context->getInputValue<int32_t>(kActivationParam);

    T* output = context->getOutputBuffer<T>(kOutputTensor);
    Shape outputShape = context->getOutputShape(kOutputTensor);

    int32_t timeMajor = context->getInputValue<int32_t>(kTimeMajorParam);
    // If the input tensor is not in time-major format, transpose its first two
    // dimensions and point the input and output pointers at temporary buffers;
    // the result is transposed back into the caller's output buffer after the
    // RNN has run.
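    // For example, a batch-major input of shape [batchSize, maxTime, inputSize]
    // is processed as a time-major [maxTime, batchSize, inputSize] copy, and the
    // time-major result is transposed back into [batchSize, maxTime, numUnits].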
    std::vector<T> inputTransposed;
    std::vector<T> outputTransposed;
    if (!timeMajor) {
        // Convert input and output to time major format.
        inputTransposed.resize(getNumberOfElements(inputShape));
        outputTransposed.resize(getNumberOfElements(outputShape));
        transposeFirstTwoDims(input, inputShape, inputTransposed.data());
        input = inputTransposed.data();
        output = outputTransposed.data();
        std::swap(inputShape.dimensions[0], inputShape.dimensions[1]);
        std::swap(outputShape.dimensions[0], outputShape.dimensions[1]);
    }

    const uint32_t maxTime = getSizeOfDimension(inputShape, 0);
    const uint32_t batchSize = getSizeOfDimension(inputShape, 1);
    const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
    const uint32_t numUnits = getSizeOfDimension(weightsShape, 0);

    // The shape of the input at a single time step, i.e. with the time dimension removed.
    Shape fixedTimeInputShape = inputShape;
    fixedTimeInputShape.dimensions.resize(2);
    fixedTimeInputShape.dimensions[0] = inputShape.dimensions[1];
    fixedTimeInputShape.dimensions[1] = inputShape.dimensions[2];

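    // Run the RNN cell once per time step. Each step consumes one
    // [batchSize, inputSize] slice of the (now time-major) input and writes one
    // [batchSize, numUnits] slice of the output; the slice just written serves
    // as the hidden state for the following step.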
    for (uint32_t i = 0; i < maxTime; ++i) {
        RNN::RNNStep<T>(input, fixedTimeInputShape, hiddenState, bias, weights, weightsShape,
                        recurrentWeights, recurrentWeightsShape, activation, output);
        input += batchSize * inputSize;
        hiddenState = output;
        output += batchSize * numUnits;
    }

    if (!timeMajor) {
        transposeFirstTwoDims(outputTransposed.data(), outputShape,
                              context->getOutputBuffer<T>(kOutputTensor));
    }
    return true;
}

}  // namespace

bool validate(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    OperandType inputType = context->getInputType(kInputTensor);
    if (inputType != OperandType::TENSOR_FLOAT16 && inputType != OperandType::TENSOR_FLOAT32) {
        LOG(ERROR) << "Unsupported input operand type for UNIDIRECTIONAL_SEQUENCE_RNN op: "
                   << toString(inputType);
        return false;
    }
    NN_RET_CHECK(validateInputTypes(context, {inputType, inputType, inputType, inputType, inputType,
                                              OperandType::INT32, OperandType::INT32}));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    return validateHalVersion(context, HalVersion::V1_2);
}

bool prepare(IOperationExecutionContext* context) {
    Shape input = context->getInputShape(kInputTensor);
    Shape weights = context->getInputShape(kWeightsTensor);
    Shape recurrentWeights = context->getInputShape(kRecurrentWeightsTensor);
    Shape bias = context->getInputShape(kBiasTensor);
    Shape hiddenState = context->getInputShape(kHiddenStateTensor);

    int32_t timeMajor = context->getInputValue<int32_t>(kTimeMajorParam);
    NN_RET_CHECK(timeMajor == 0 || timeMajor == 1);
    const uint32_t batchSize =
            timeMajor ? getSizeOfDimension(input, 1) : getSizeOfDimension(input, 0);
    const uint32_t maxTime =
            timeMajor ? getSizeOfDimension(input, 0) : getSizeOfDimension(input, 1);
    const uint32_t numUnits = getSizeOfDimension(weights, 0);
    const uint32_t inputSize = getSizeOfDimension(input, 2);

    NN_RET_CHECK_EQ(getNumberOfDimensions(input), 3);
    NN_RET_CHECK_EQ(getNumberOfDimensions(weights), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentWeights), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bias), 1);
    NN_RET_CHECK_EQ(getNumberOfDimensions(hiddenState), 2);

    NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(weights, 1));
    NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(bias, 0));
    NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(recurrentWeights, 0));
    NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(recurrentWeights, 1));
    NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(hiddenState, 0));
    NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(hiddenState, 1));

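    // The output keeps the input's batch/time layout; only the innermost
    // dimension changes from inputSize to numUnits.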
    Shape output = context->getOutputShape(kOutputTensor);
    output.dimensions[0] = timeMajor ? maxTime : batchSize;
    output.dimensions[1] = timeMajor ? batchSize : maxTime;
    output.dimensions[2] = numUnits;

    return context->setOutputShape(kOutputTensor, output);
}

bool execute(IOperationExecutionContext* context) {
    if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) {
        return executeTyped<_Float16>(context);
    } else {
        return executeTyped<float>(context);
    }
}

}  // namespace unidirectional_sequence_rnn

NN_REGISTER_OPERATION(UNIDIRECTIONAL_SEQUENCE_RNN, "UNIDIRECTIONAL_SEQUENCE_RNN",
                      unidirectional_sequence_rnn::validate, unidirectional_sequence_rnn::prepare,
                      unidirectional_sequence_rnn::execute);

}  // namespace nn
}  // namespace android