      1 /*
      2  * Copyright (C) 2019 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #define LOG_TAG "Operations"
     18 
     19 #include "OperationResolver.h"
     20 #include "RNN.h"
     21 
     22 namespace android {
     23 namespace nn {
     24 namespace bidirectional_sequence_rnn {
     25 
// Operand indices for BIDIRECTIONAL_SEQUENCE_RNN. The op takes 15 inputs,
// three of which (the auxiliary input and its two weight tensors) are
// optional, and produces one or two outputs depending on kMergeOutputsParam.
constexpr uint32_t kNumInputs = 15;
constexpr uint32_t kInputTensor = 0;
// Forward cell tensors
constexpr uint32_t kFwWeightsTensor = 1;
constexpr uint32_t kFwRecurrentWeightsTensor = 2;
constexpr uint32_t kFwBiasTensor = 3;
constexpr uint32_t kFwHiddenStateTensor = 4;
// Backward cell tensors
constexpr uint32_t kBwWeightsTensor = 5;
constexpr uint32_t kBwRecurrentWeightsTensor = 6;
constexpr uint32_t kBwBiasTensor = 7;
constexpr uint32_t kBwHiddenStateTensor = 8;
// Auxiliary inputs
constexpr uint32_t kAuxInputTensor = 9;       // optional
constexpr uint32_t kFwAuxWeightsTensor = 10;  // optional
constexpr uint32_t kBwAuxWeightsTensor = 11;  // optional
// Cell parameters
constexpr uint32_t kActivationParam = 12;
constexpr uint32_t kTimeMajorParam = 13;
constexpr uint32_t kMergeOutputsParam = 14;

// Output operand indices.
constexpr uint32_t kFwOutputTensor = 0;
constexpr uint32_t kBwOutputTensor = 1;  // Only if mergeOutputs parameter is false
     49 
     50 namespace {
     51 
     52 template <typename T>
     53 void transposeFirstTwoDims(const T* input, const Shape& inputShape, T* output) {
     54     const uint32_t firstDimSize = getSizeOfDimension(inputShape, 0);
     55     const uint32_t secondDimSize = getSizeOfDimension(inputShape, 1);
     56     const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
     57     for (int f = 0; f < firstDimSize; ++f) {
     58         for (int s = 0; s < secondDimSize; ++s) {
     59             for (int i = 0; i < inputSize; ++i) {
     60                 const uint32_t inputIndex = f * secondDimSize * inputSize + s * inputSize + i;
     61                 const uint32_t outputIndex = s * firstDimSize * inputSize + f * inputSize + i;
     62                 output[outputIndex] = input[inputIndex];
     63             }
     64         }
     65     }
     66 }
     67 
     68 Shape removeFirstDim(const Shape& input) {
     69     Shape output = input;
     70     output.dimensions.resize(input.dimensions.size() - 1);
     71     for (int i = 0; i < input.dimensions.size() - 1; ++i) {
     72         output.dimensions[i] = input.dimensions[i + 1];
     73     }
     74     return output;
     75 }
     76 
// Runs the full bidirectional sequence RNN for element type T (float or
// _Float16). A forward cell steps through time 0..maxTime-1 and a backward
// cell steps through time maxTime-1..0, each calling RNN::RNNStep once per
// time step. Depending on the mergeOutputs parameter, results are either
// interleaved into the single fw output tensor (bw written at an offset of
// fwNumUnits within each batch row) or written to two separate tensors.
// Batch-major inputs are transposed to time-major up front and the outputs
// are transposed back at the end.
template <typename T>
bool executeTyped(IOperationExecutionContext* context) {
    const T* input = context->getInputBuffer<T>(kInputTensor);
    Shape inputShape = context->getInputShape(kInputTensor);

    // Forward cell parameters.
    const T* fwWeights = context->getInputBuffer<T>(kFwWeightsTensor);
    Shape fwWeightsShape = context->getInputShape(kFwWeightsTensor);
    const T* fwRecurrentWeights = context->getInputBuffer<T>(kFwRecurrentWeightsTensor);
    Shape fwRecurrentWeightsShape = context->getInputShape(kFwRecurrentWeightsTensor);
    const T* fwBias = context->getInputBuffer<T>(kFwBiasTensor);
    const T* fwHiddenState = context->getInputBuffer<T>(kFwHiddenStateTensor);

    // Backward cell parameters.
    const T* bwWeights = context->getInputBuffer<T>(kBwWeightsTensor);
    Shape bwWeightsShape = context->getInputShape(kBwWeightsTensor);
    const T* bwRecurrentWeights = context->getInputBuffer<T>(kBwRecurrentWeightsTensor);
    Shape bwRecurrentWeightsShape = context->getInputShape(kBwRecurrentWeightsTensor);
    const T* bwBias = context->getInputBuffer<T>(kBwBiasTensor);
    const T* bwHiddenState = context->getInputBuffer<T>(kBwHiddenStateTensor);

    // Optional auxiliary input; prepare() enforces that the aux input and
    // both aux weight tensors are provided all-or-none, so checking only
    // kAuxInputTensor here is sufficient.
    const T* auxInput = nullptr;
    const T* fwAuxWeights = nullptr;
    const T* bwAuxWeights = nullptr;
    const bool hasAuxInputs = !context->isOmittedInput(kAuxInputTensor);
    if (hasAuxInputs) {
        auxInput = context->getInputBuffer<T>(kAuxInputTensor);
        fwAuxWeights = context->getInputBuffer<T>(kFwAuxWeightsTensor);
        bwAuxWeights = context->getInputBuffer<T>(kBwAuxWeightsTensor);
    }
    // Shapes are fetched unconditionally; RNNStep only reads them when the
    // corresponding aux buffer pointer is non-null.
    Shape auxInputShape = context->getInputShape(kAuxInputTensor);
    Shape fwAuxWeightsShape = context->getInputShape(kFwAuxWeightsTensor);
    Shape bwAuxWeightsShape = context->getInputShape(kBwAuxWeightsTensor);

    int32_t activation = context->getInputValue<int32_t>(kActivationParam);
    // Boolean scalar parameters widened to int32_t (used only for truthiness).
    int32_t timeMajor = context->getInputValue<bool>(kTimeMajorParam);
    int32_t mergeOutputs = context->getInputValue<bool>(kMergeOutputsParam);

    T* fwOutput = context->getOutputBuffer<T>(kFwOutputTensor);
    Shape fwOutputShape = context->getOutputShape(kFwOutputTensor);
    T* bwOutput = nullptr;
    Shape bwOutputShape;
    if (!mergeOutputs) {
        bwOutputShape = context->getOutputShape(kBwOutputTensor);
        bwOutput = context->getOutputBuffer<T>(kBwOutputTensor);
    }

    // If the input tensors are not in time major format, we transpose the first
    // two dimensions, and set input and output pointers to temporary vectors
    // which are transposed back after the RNN is applied.
    std::vector<T> inputTransposed;
    std::vector<T> auxInputTransposed;
    std::vector<T> fwOutputTransposed;
    std::vector<T> bwOutputTransposed;
    if (!timeMajor) {
        // First, resize temporary buffers to accommodate for transposed tensors.
        inputTransposed.resize(getNumberOfElements(inputShape));
        if (hasAuxInputs) {
            auxInputTransposed.resize(getNumberOfElements(auxInputShape));
        }
        fwOutputTransposed.resize(getNumberOfElements(fwOutputShape));
        if (!mergeOutputs) {
            bwOutputTransposed.resize(getNumberOfElements(bwOutputShape));
        }

        // Transpose the input tensors.
        transposeFirstTwoDims(input, inputShape, inputTransposed.data());
        if (hasAuxInputs) {
            transposeFirstTwoDims(auxInput, auxInputShape, auxInputTransposed.data());
        }

        // Change input and output pointers to the temporary buffers.
        input = inputTransposed.data();
        if (hasAuxInputs) {
            auxInput = auxInputTransposed.data();
        }
        fwOutput = fwOutputTransposed.data();
        if (!mergeOutputs) {
            bwOutput = bwOutputTransposed.data();
        }

        // Swap the first two dimensions in the Shapes to reflect the
        // transposition.
        std::swap(inputShape.dimensions[0], inputShape.dimensions[1]);
        if (hasAuxInputs) {
            std::swap(auxInputShape.dimensions[0], auxInputShape.dimensions[1]);
        }
        std::swap(fwOutputShape.dimensions[0], fwOutputShape.dimensions[1]);
        if (!mergeOutputs) {
            std::swap(bwOutputShape.dimensions[0], bwOutputShape.dimensions[1]);
        }
    }

    // From here on all tensors are time-major: [maxTime, batchSize, ...].
    const uint32_t maxTime = getSizeOfDimension(inputShape, 0);
    const uint32_t batchSize = getSizeOfDimension(inputShape, 1);
    const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
    uint32_t auxInputSize = 0;
    if (hasAuxInputs) {
        auxInputSize = getSizeOfDimension(auxInputShape, 2);
    }
    const uint32_t fwNumUnits = getSizeOfDimension(fwWeightsShape, 0);
    const uint32_t bwNumUnits = getSizeOfDimension(bwWeightsShape, 0);

    // Per-time-step shapes ([batchSize, ...]) handed to RNNStep.
    Shape fixedTimeInputShape = removeFirstDim(inputShape);
    Shape fixedTimeAuxInputShape = auxInputShape;
    if (hasAuxInputs) {
        fixedTimeAuxInputShape = removeFirstDim(auxInputShape);
    }

    // Create an additional buffer to store a hidden state between steps.
    std::vector<T> tempHiddenState(batchSize * fwNumUnits);
    // Forward pass
    for (int i = 0; i < maxTime; ++i) {
        const T* inputBatchPtr = input + i * batchSize * inputSize;
        const T* auxInputBatchPtr = nullptr;
        if (hasAuxInputs) {
            auxInputBatchPtr = auxInput + i * batchSize * auxInputSize;
        }
        // When outputs are merged, each batch row is fwNumUnits + bwNumUnits
        // wide and the forward result occupies its leading fwNumUnits slots.
        const uint32_t fwOutputBatchStride = mergeOutputs ? (fwNumUnits + bwNumUnits) : fwNumUnits;
        T* fwOutputBatchPtr = fwOutput + i * batchSize * fwOutputBatchStride;

        RNN::RNNStep<T>(inputBatchPtr, fixedTimeInputShape, auxInputBatchPtr,
                        fixedTimeAuxInputShape, fwHiddenState, fwBias, fwWeights, fwWeightsShape,
                        fwAuxWeights, fwAuxWeightsShape, fwRecurrentWeights,
                        fwRecurrentWeightsShape, activation, fwOutputBatchStride,
                        /*outputBatchOffset=*/0, fwOutputBatchPtr, tempHiddenState.data());

        // After the first step the hidden state is read from (and written
        // back into) tempHiddenState; the input hidden-state tensor is only
        // used for step 0.
        fwHiddenState = tempHiddenState.data();
    }

    // Reuse the scratch hidden-state buffer for the backward cell (which may
    // have a different number of units). The forward pass is complete, so its
    // contents are no longer needed.
    tempHiddenState.resize(batchSize * bwNumUnits);
    // Backward pass
    for (int i = maxTime - 1; i >= 0; --i) {
        const T* inputBatchPtr = input + i * batchSize * inputSize;
        const T* auxInputBatchPtr = nullptr;
        if (hasAuxInputs) {
            auxInputBatchPtr = auxInput + i * batchSize * auxInputSize;
        }
        T* bwOutputBatchPtr;
        uint32_t bwOutputBatchOffset = 0;
        uint32_t bwOutputBatchStride;
        if (mergeOutputs) {
            // Write into the fw output tensor, offset past the forward
            // cell's fwNumUnits slots within each batch row.
            bwOutputBatchStride = fwNumUnits + bwNumUnits;
            bwOutputBatchOffset = fwNumUnits;
            bwOutputBatchPtr = fwOutput + i * batchSize * bwOutputBatchStride;
        } else {
            bwOutputBatchStride = bwNumUnits;
            bwOutputBatchPtr = bwOutput + i * batchSize * bwOutputBatchStride;
        }

        RNN::RNNStep<T>(inputBatchPtr, fixedTimeInputShape, auxInputBatchPtr,
                        fixedTimeAuxInputShape, bwHiddenState, bwBias, bwWeights, bwWeightsShape,
                        bwAuxWeights, bwAuxWeightsShape, bwRecurrentWeights,
                        bwRecurrentWeightsShape, activation, bwOutputBatchStride,
                        bwOutputBatchOffset, bwOutputBatchPtr, tempHiddenState.data());

        bwHiddenState = tempHiddenState.data();
    }

    // If the inputs were in batch major format, transpose data in temporary
    // buffers and write to the output(s).
    if (!timeMajor) {
        transposeFirstTwoDims(fwOutputTransposed.data(), fwOutputShape,
                              context->getOutputBuffer<T>(kFwOutputTensor));
        if (!mergeOutputs) {
            transposeFirstTwoDims(bwOutputTransposed.data(), bwOutputShape,
                                  context->getOutputBuffer<T>(kBwOutputTensor));
        }
    }
    return true;
}
    246 
    247 }  // namespace
    248 
    249 bool validate(const IOperationValidationContext* context) {
    250     NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    251     // Exact number is dependent on the mergeOutputs parameter and checked
    252     // during preparation.
    253     NN_RET_CHECK(context->getNumOutputs() == 1 || context->getNumOutputs() == 2);
    254     OperandType inputType = context->getInputType(kInputTensor);
    255     if (inputType != OperandType::TENSOR_FLOAT16 && inputType != OperandType::TENSOR_FLOAT32) {
    256         LOG(ERROR) << "Unsupported input operand type for UNIDIRECTIONAL_SEQUENCE_RNN op: "
    257                    << toString(inputType);
    258         return false;
    259     }
    260     NN_RET_CHECK(validateInputTypes(
    261             context, {inputType, inputType, inputType, inputType, inputType, inputType, inputType,
    262                       inputType, inputType, inputType, inputType, inputType, OperandType::INT32,
    263                       OperandType::BOOL, OperandType::BOOL}));
    264     if (context->getNumOutputs() == 1) {
    265         NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    266     } else {
    267         NN_RET_CHECK(validateOutputTypes(context, {inputType, inputType}));
    268     }
    269     return validateHalVersion(context, HalVersion::V1_2);
    270 }
    271 
// Checks all operand shapes for consistency and computes the output shape(s).
// Returns false (via NN_RET_CHECK*) on any violation. On success, sets the
// forward output shape and, when mergeOutputs is false, the backward output
// shape as well.
bool prepare(IOperationExecutionContext* context) {
    // mergeOutputs determines whether the op produces one merged output
    // tensor or separate fw/bw output tensors.
    int32_t mergeOutputs = context->getInputValue<bool>(kMergeOutputsParam);
    if (mergeOutputs) {
        NN_RET_CHECK_EQ(context->getNumOutputs(), 1);
    } else {
        NN_RET_CHECK_EQ(context->getNumOutputs(), 2);
    }

    // Check that none of the required inputs are omitted.
    const std::vector<int> requiredInputs = {
            kInputTensor,         kFwWeightsTensor, kFwRecurrentWeightsTensor, kFwBiasTensor,
            kFwHiddenStateTensor, kBwWeightsTensor, kBwRecurrentWeightsTensor, kBwBiasTensor,
            kBwHiddenStateTensor, kActivationParam, kTimeMajorParam,           kMergeOutputsParam,
    };
    for (const int requiredInput : requiredInputs) {
        NN_RET_CHECK(!context->isOmittedInput(requiredInput))
                << "required input " << requiredInput << " is omitted";
    }

    Shape input = context->getInputShape(kInputTensor);
    Shape fwWeights = context->getInputShape(kFwWeightsTensor);
    Shape fwRecurrentWeights = context->getInputShape(kFwRecurrentWeightsTensor);
    Shape fwBias = context->getInputShape(kFwBiasTensor);
    Shape fwHiddenState = context->getInputShape(kFwHiddenStateTensor);
    Shape bwWeights = context->getInputShape(kBwWeightsTensor);
    Shape bwRecurrentWeights = context->getInputShape(kBwRecurrentWeightsTensor);
    Shape bwBias = context->getInputShape(kBwBiasTensor);
    Shape bwHiddenState = context->getInputShape(kBwHiddenStateTensor);

    Shape auxInput = context->getInputShape(kAuxInputTensor);
    Shape fwAuxWeights = context->getInputShape(kFwAuxWeightsTensor);
    Shape bwAuxWeights = context->getInputShape(kBwAuxWeightsTensor);

    // The auxiliary input and its two weight tensors must be provided
    // together or omitted together.
    const bool auxInputsAllOrNone = (context->isOmittedInput(kAuxInputTensor) &&
                                     context->isOmittedInput(kFwAuxWeightsTensor) &&
                                     context->isOmittedInput(kBwAuxWeightsTensor)) ||
                                    (!context->isOmittedInput(kAuxInputTensor) &&
                                     !context->isOmittedInput(kFwAuxWeightsTensor) &&
                                     !context->isOmittedInput(kBwAuxWeightsTensor));
    NN_RET_CHECK(auxInputsAllOrNone);
    const bool hasAuxInputs = !context->isOmittedInput(kAuxInputTensor);

    // timeMajor selects whether the input layout is [maxTime, batchSize,
    // inputSize] (true) or [batchSize, maxTime, inputSize] (false).
    int32_t timeMajor = context->getInputValue<bool>(kTimeMajorParam);
    const uint32_t batchSize =
            timeMajor ? getSizeOfDimension(input, 1) : getSizeOfDimension(input, 0);
    const uint32_t maxTime =
            timeMajor ? getSizeOfDimension(input, 0) : getSizeOfDimension(input, 1);
    const uint32_t fwNumUnits = getSizeOfDimension(fwWeights, 0);
    const uint32_t bwNumUnits = getSizeOfDimension(bwWeights, 0);
    const uint32_t inputSize = getSizeOfDimension(input, 2);

    // Rank checks.
    NN_RET_CHECK_EQ(getNumberOfDimensions(input), 3);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwWeights), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwRecurrentWeights), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwBias), 1);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwHiddenState), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwWeights), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwRecurrentWeights), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwBias), 1);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwHiddenState), 2);

    // Forward cell dimension agreement: weights are [fwNumUnits, inputSize],
    // recurrent weights [fwNumUnits, fwNumUnits], bias [fwNumUnits], hidden
    // state [batchSize, fwNumUnits].
    NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(fwWeights, 1));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwBias, 0));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwRecurrentWeights, 0));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwRecurrentWeights, 1));
    NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(fwHiddenState, 0));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwHiddenState, 1));

    // Same checks for the backward cell.
    NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(bwWeights, 1));
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwBias, 0));
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwRecurrentWeights, 0));
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwRecurrentWeights, 1));
    NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(bwHiddenState, 0));
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwHiddenState, 1));

    if (hasAuxInputs) {
        // The aux input must match the main input's first two dimensions; its
        // feature size (dim 2) may differ and must match the aux weights.
        NN_RET_CHECK_EQ(getNumberOfDimensions(auxInput), 3);
        NN_RET_CHECK_EQ(getNumberOfDimensions(fwAuxWeights), 2);
        NN_RET_CHECK_EQ(getNumberOfDimensions(bwAuxWeights), 2);

        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 0), getSizeOfDimension(input, 0));
        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 1), getSizeOfDimension(input, 1));
        NN_RET_CHECK_EQ(getSizeOfDimension(fwAuxWeights, 0), fwNumUnits);
        NN_RET_CHECK_EQ(getSizeOfDimension(fwAuxWeights, 1), getSizeOfDimension(auxInput, 2));
        NN_RET_CHECK_EQ(getSizeOfDimension(bwAuxWeights, 0), bwNumUnits);
        NN_RET_CHECK_EQ(getSizeOfDimension(bwAuxWeights, 1), getSizeOfDimension(auxInput, 2));
    }

    // Forward output keeps the input's time/batch layout; its feature size is
    // fwNumUnits + bwNumUnits when outputs are merged into one tensor.
    Shape fwOutput = context->getOutputShape(kFwOutputTensor);
    fwOutput.dimensions.resize(3);
    fwOutput.dimensions[0] = timeMajor ? maxTime : batchSize;
    fwOutput.dimensions[1] = timeMajor ? batchSize : maxTime;
    fwOutput.dimensions[2] = mergeOutputs ? fwNumUnits + bwNumUnits : fwNumUnits;
    NN_RET_CHECK(context->setOutputShape(kFwOutputTensor, fwOutput));
    if (!mergeOutputs) {
        Shape bwOutput = context->getOutputShape(kBwOutputTensor);
        bwOutput.dimensions.resize(3);
        bwOutput.dimensions[0] = timeMajor ? maxTime : batchSize;
        bwOutput.dimensions[1] = timeMajor ? batchSize : maxTime;
        bwOutput.dimensions[2] = bwNumUnits;
        NN_RET_CHECK(context->setOutputShape(kBwOutputTensor, bwOutput));
    }

    return true;
}
    377 
    378 bool execute(IOperationExecutionContext* context) {
    379     if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) {
    380         executeTyped<_Float16>(context);
    381     } else {
    382         executeTyped<float>(context);
    383     }
    384     return true;
    385 }
    386 
    387 }  // namespace bidirectional_sequence_rnn
    388 
    389 NN_REGISTER_OPERATION(BIDIRECTIONAL_SEQUENCE_RNN, "BIDIRECTIONAL_SEQUENCE_RNN",
    390                       bidirectional_sequence_rnn::validate, bidirectional_sequence_rnn::prepare,
    391                       bidirectional_sequence_rnn::execute, .allowOmittedOperand = true);
    392 
    393 }  // namespace nn
    394 }  // namespace android
    395