Home | History | Annotate | Download | only in operations
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "CpuOperationUtils.h"
     18 #include "OperationResolver.h"
     19 #include "OperationsUtils.h"
     20 
     21 #include <cfloat>
     22 #include <cmath>
     23 
     24 #include "Tracing.h"
     25 #include "tensorflow/lite/kernels/internal/common.h"
     26 
     27 namespace android {
     28 namespace nn {
     29 namespace roi_align {
     30 
     31 constexpr char kOperationName[] = "ROI_ALIGN";
     32 
     33 constexpr uint32_t kNumInputs = 10;
     34 constexpr uint32_t kInputTensor = 0;
     35 constexpr uint32_t kRoiTensor = 1;
     36 constexpr uint32_t kBatchSplitTensor = 2;
     37 constexpr uint32_t kOutputHeightScalar = 3;
     38 constexpr uint32_t kOutputWidthScalar = 4;
     39 constexpr uint32_t kHeightStrideSalar = 5;
     40 constexpr uint32_t kWidthStrideScalar = 6;
     41 constexpr uint32_t kHeightSamplingRatioScalar = 7;
     42 constexpr uint32_t kWidthSamplingRatioScalar = 8;
     43 constexpr uint32_t kLayoutScalar = 9;
     44 
     45 constexpr uint32_t kNumOutputs = 1;
     46 constexpr uint32_t kOutputTensor = 0;
     47 
     48 namespace {
     49 
     50 template <typename T_Input, typename T_Roi>
     51 inline bool roiAlignNhwc(const T_Input* inputData, const Shape& inputShape, const T_Roi* roiData,
     52                          const Shape& roiShape, const int32_t* batchSplitData,
     53                          const Shape& batchSplitShape, float heightStride, float widthStride,
     54                          int32_t heightSamplingRatio, int32_t widthSamplingRatio,
     55                          T_Input* outputData, const Shape& outputShape) {
     56     NNTRACE_TRANS("RoiAlign");
     57 
     58     const uint32_t kRoiDim = 4;
     59     const T_Roi heightScale = 1.0f / heightStride;
     60     const T_Roi widthScale = 1.0f / widthStride;
     61 
     62     uint32_t numBatches = getSizeOfDimension(inputShape, 0);
     63     uint32_t inHeight = getSizeOfDimension(inputShape, 1);
     64     uint32_t inWidth = getSizeOfDimension(inputShape, 2);
     65     uint32_t inDepth = getSizeOfDimension(inputShape, 3);
     66     uint32_t outHeight = getSizeOfDimension(outputShape, 1);
     67     uint32_t outWidth = getSizeOfDimension(outputShape, 2);
     68     uint32_t numRois = getSizeOfDimension(roiShape, 0);
     69     uint32_t roiInfoLength = getSizeOfDimension(roiShape, 1);
     70 
     71     T_Input* outPtr = outputData;
     72     const T_Roi* roiDataEnd = roiData + numRois * roiInfoLength;
     73     uint32_t roiIndex = 0;
     74     for (const T_Roi* roiInfo = roiData; roiInfo < roiDataEnd; roiInfo += kRoiDim, roiIndex++) {
     75         uint32_t batchId = static_cast<uint32_t>(batchSplitData[roiIndex]);
     76         // Check for malformed data
     77         // 1. invalid batch id
     78         // 2. Region out of bound: x1|x2|y1|y2 < 0 || x1|x2 > inWidth || y1|y2 > inHeight
     79         // 3. Invalid region: x2 < x1 || y2 < y1
     80         NN_RET_CHECK_GE(batchId, 0);
     81         NN_RET_CHECK_LT(batchId, numBatches);
     82         NN_RET_CHECK(roiInfo[0] >= 0);
     83         NN_RET_CHECK(roiInfo[1] >= 0);
     84         NN_RET_CHECK(roiInfo[2] >= 0);
     85         NN_RET_CHECK(roiInfo[3] >= 0);
     86         NN_RET_CHECK(roiInfo[0] * widthScale <= inWidth);
     87         NN_RET_CHECK(roiInfo[1] * heightScale <= inHeight);
     88         NN_RET_CHECK(roiInfo[2] * widthScale <= inWidth);
     89         NN_RET_CHECK(roiInfo[3] * heightScale <= inHeight);
     90         NN_RET_CHECK(roiInfo[0] <= roiInfo[2]);
     91         NN_RET_CHECK(roiInfo[1] <= roiInfo[3]);
     92 
     93         T_Roi wRoiStart = roiInfo[0] * widthScale;
     94         T_Roi hRoiStart = roiInfo[1] * heightScale;
     95         T_Roi wRoiEnd = roiInfo[2] * widthScale;
     96         T_Roi hRoiEnd = roiInfo[3] * heightScale;
     97 
     98         T_Roi roiWidth = std::max(static_cast<float>(wRoiEnd - wRoiStart), 1.0f);
     99         T_Roi roiHeight = std::max(static_cast<float>(hRoiEnd - hRoiStart), 1.0f);
    100         T_Roi wStepSize = roiWidth / static_cast<T_Roi>(outWidth);
    101         T_Roi hStepSize = roiHeight / static_cast<T_Roi>(outHeight);
    102 
    103         // if samplingRatio = 0, use adaptive value of ceil(roiWidth/outWidth), same for height
    104         uint32_t wSamplingRatio = widthSamplingRatio > 0 ? widthSamplingRatio
    105                                                          : std::ceil(static_cast<float>(wStepSize));
    106         uint32_t hSamplingRatio = heightSamplingRatio > 0
    107                                           ? heightSamplingRatio
    108                                           : std::ceil(static_cast<float>(hStepSize));
    109         int32_t numSamplingPoints = wSamplingRatio * hSamplingRatio;
    110         T_Roi wBinSize = wStepSize / static_cast<T_Roi>(wSamplingRatio);
    111         T_Roi hBinSize = hStepSize / static_cast<T_Roi>(hSamplingRatio);
    112 
    113         const T_Input* batchBase = inputData + batchId * inHeight * inWidth * inDepth;
    114         for (uint32_t i = 0; i < outHeight; i++) {
    115             for (uint32_t j = 0; j < outWidth; j++) {
    116                 T_Roi wStart = wStepSize * j + wRoiStart;
    117                 T_Roi wEnd = wStepSize * (j + 1) + wRoiStart;
    118                 T_Roi hStart = hStepSize * i + hRoiStart;
    119                 T_Roi hEnd = hStepSize * (i + 1) + hRoiStart;
    120 
    121                 // initialize output to zero
    122                 for (uint32_t k = 0; k < inDepth; k++) outPtr[k] = 0;
    123 
    124                 // calculate the sum of the sampling points
    125                 for (uint32_t yInd = 0; yInd < hSamplingRatio; yInd++) {
    126                     for (uint32_t xInd = 0; xInd < wSamplingRatio; xInd++) {
    127                         T_Roi y = hStart + hBinSize / 2 + hBinSize * yInd;
    128                         T_Roi x = wStart + wBinSize / 2 + wBinSize * xInd;
    129 
    130                         // bilinear interpolation of point (x,y)
    131                         // w.r.t box [(x1,y1), (x1,y2), (x2,y1), (x2,y2)]
    132                         uint32_t x1 = std::floor(static_cast<float>(x));
    133                         uint32_t y1 = std::floor(static_cast<float>(y));
    134                         uint32_t x2 = x1 + 1, y2 = y1 + 1;
    135                         T_Roi dx1 = x - static_cast<T_Roi>(x1);
    136                         T_Roi dy1 = y - static_cast<T_Roi>(y1);
    137 
    138                         // dealing with out of bound samples
    139                         if (x1 >= inWidth - 1) {
    140                             x1 = x2 = inWidth - 1;
    141                             dx1 = 0;
    142                         }
    143                         if (y1 >= inHeight - 1) {
    144                             y1 = y2 = inHeight - 1;
    145                             dy1 = 0;
    146                         }
    147 
    148                         T_Roi dx2 = 1.0f - dx1, dy2 = 1.0f - dy1;
    149                         T_Roi ws[] = {dx2 * dy2, dx1 * dy2, dx2 * dy1, dx1 * dy1};
    150                         uint32_t offsets[] = {y1 * inWidth * inDepth + x1 * inDepth,
    151                                               y1 * inWidth * inDepth + x2 * inDepth,
    152                                               y2 * inWidth * inDepth + x1 * inDepth,
    153                                               y2 * inWidth * inDepth + x2 * inDepth};
    154 
    155                         for (uint32_t k = 0; k < inDepth; k++) {
    156                             T_Input interpolation = 0;
    157                             for (uint32_t c = 0; c < 4; c++) {
    158                                 interpolation += ws[c] * batchBase[offsets[c] + k];
    159                             }
    160                             outPtr[k] += interpolation;
    161                         }
    162                     }
    163                 }
    164 
    165                 // take average
    166                 for (uint32_t k = 0; k < inDepth; k++)
    167                     outPtr[k] /= static_cast<T_Input>(numSamplingPoints);
    168                 outPtr += inDepth;
    169             }
    170         }
    171     }
    172     return true;
    173 }
    174 
    175 template <>
    176 inline bool roiAlignNhwc<uint8_t, uint16_t>(const uint8_t* inputData, const Shape& inputShape,
    177                                             const uint16_t* roiData, const Shape& roiShape,
    178                                             const int32_t* batchSplitData,
    179                                             const Shape& batchSplitShape, float heightStride,
    180                                             float widthStride, int32_t heightSamplingRatio,
    181                                             int32_t widthSamplingRatio, uint8_t* outputData,
    182                                             const Shape& outputShape) {
    183     NNTRACE_TRANS("RoiAlignQuant8");
    184 
    185     constexpr float wScale = 1.0f / 255.0f;
    186     constexpr uint32_t kRoiDim = 4;
    187     const float heightScale = 1.0f / heightStride;
    188     const float widthScale = 1.0f / widthStride;
    189 
    190     uint32_t numBatches = getSizeOfDimension(inputShape, 0);
    191     uint32_t inHeight = getSizeOfDimension(inputShape, 1);
    192     uint32_t inWidth = getSizeOfDimension(inputShape, 2);
    193     uint32_t inDepth = getSizeOfDimension(inputShape, 3);
    194     uint32_t outHeight = getSizeOfDimension(outputShape, 1);
    195     uint32_t outWidth = getSizeOfDimension(outputShape, 2);
    196     uint32_t numRois = getSizeOfDimension(roiShape, 0);
    197     uint32_t roiInfoLength = getSizeOfDimension(roiShape, 1);
    198 
    199     uint8_t* outPtr = outputData;
    200     const uint16_t* roiDataEnd = roiData + numRois * roiInfoLength;
    201     uint32_t roiIndex = 0;
    202     for (const uint16_t* roiInfo = roiData; roiInfo < roiDataEnd; roiInfo += kRoiDim, roiIndex++) {
    203         uint32_t batchId = static_cast<uint32_t>(batchSplitData[roiIndex]);
    204         float wRoiStart = static_cast<float>(roiInfo[0]) * widthScale * 0.125f;
    205         float hRoiStart = static_cast<float>(roiInfo[1]) * heightScale * 0.125f;
    206         float wRoiEnd = static_cast<float>(roiInfo[2]) * widthScale * 0.125f;
    207         float hRoiEnd = static_cast<float>(roiInfo[3]) * heightScale * 0.125f;
    208 
    209         // Check for malformed data
    210         // 1. invalid batch id
    211         // 2. Region out of bound: x1|x2|y1|y2 < 0 || x1|x2 > inWidth || y1|y2 > inHeight
    212         // 3. Invalid region: x2 < x1 || y2 < y1
    213         NN_RET_CHECK_GE(batchId, 0);
    214         NN_RET_CHECK_LT(batchId, numBatches);
    215         NN_RET_CHECK(wRoiStart <= inWidth);
    216         NN_RET_CHECK(hRoiStart <= inHeight);
    217         NN_RET_CHECK(wRoiEnd <= inWidth);
    218         NN_RET_CHECK(hRoiEnd <= inHeight);
    219         NN_RET_CHECK_LE(wRoiStart, wRoiEnd);
    220         NN_RET_CHECK_LE(hRoiStart, hRoiEnd);
    221 
    222         float roiWidth = std::max(wRoiEnd - wRoiStart, 1.0f);
    223         float roiHeight = std::max(hRoiEnd - hRoiStart, 1.0f);
    224         float wStepSize = roiWidth / static_cast<float>(outWidth);
    225         float hStepSize = roiHeight / static_cast<float>(outHeight);
    226 
    227         // if samplingRatio = 0, use adaptive value of ceil(roiWidth/outWidth), same for height
    228         uint32_t wSamplingRatio =
    229                 widthSamplingRatio > 0 ? widthSamplingRatio : std::ceil(wStepSize);
    230         uint32_t hSamplingRatio =
    231                 heightSamplingRatio > 0 ? heightSamplingRatio : std::ceil(hStepSize);
    232         int32_t numSamplingPoints = wSamplingRatio * hSamplingRatio;
    233         float wBinSize = wStepSize / static_cast<float>(wSamplingRatio);
    234         float hBinSize = hStepSize / static_cast<float>(hSamplingRatio);
    235 
    236         float realMultiplier = inputShape.scale * wScale / outputShape.scale / numSamplingPoints;
    237         int32_t outputMultiplier = 0;
    238         int32_t outputShift = 0;
    239         if (!QuantizeMultiplierSmallerThanOne(realMultiplier, &outputMultiplier, &outputShift)) {
    240             return false;
    241         }
    242 
    243         const uint8_t* batchBase = inputData + batchId * inHeight * inWidth * inDepth;
    244         for (uint32_t i = 0; i < outHeight; i++) {
    245             for (uint32_t j = 0; j < outWidth; j++) {
    246                 float wStart = wStepSize * j + wRoiStart;
    247                 float wEnd = wStepSize * (j + 1) + wRoiStart;
    248                 float hStart = hStepSize * i + hRoiStart;
    249                 float hEnd = hStepSize * (i + 1) + hRoiStart;
    250 
    251                 std::vector<int32_t> outTemp(inDepth, 0);
    252                 // calculate the sum of the sampling points
    253                 for (uint32_t yInd = 0; yInd < hSamplingRatio; yInd++) {
    254                     for (uint32_t xInd = 0; xInd < wSamplingRatio; xInd++) {
    255                         float y = hStart + hBinSize / 2 + hBinSize * yInd;
    256                         float x = wStart + wBinSize / 2 + wBinSize * xInd;
    257 
    258                         // bilinear interpolation of point (x,y)
    259                         // w.r.t box [(x1,y1), (x1,y2), (x2,y1), (x2,y2)]
    260                         uint32_t x1 = std::floor(x), y1 = std::floor(y);
    261                         uint32_t x2 = x1 + 1, y2 = y1 + 1;
    262                         float dx1 = x - static_cast<float>(x1);
    263                         float dy1 = y - static_cast<float>(y1);
    264 
    265                         // dealing with out of bound samples
    266                         if (x1 >= inWidth - 1) {
    267                             x1 = x2 = inWidth - 1;
    268                             dx1 = 0;
    269                         }
    270                         if (y1 >= inHeight - 1) {
    271                             y1 = y2 = inHeight - 1;
    272                             dy1 = 0;
    273                         }
    274 
    275                         float dx2 = 1.0f - dx1, dy2 = 1.0f - dy1;
    276                         float ws[] = {dx2 * dy2, dx1 * dy2, dx2 * dy1, dx1 * dy1};
    277                         uint32_t offsets[] = {y1 * inWidth * inDepth + x1 * inDepth,
    278                                               y1 * inWidth * inDepth + x2 * inDepth,
    279                                               y2 * inWidth * inDepth + x1 * inDepth,
    280                                               y2 * inWidth * inDepth + x2 * inDepth};
    281 
    282                         for (uint32_t k = 0; k < inDepth; k++) {
    283                             int32_t interpolation = 0;
    284                             for (uint32_t c = 0; c < 4; c++) {
    285                                 int32_t wQuant = static_cast<int32_t>(std::round(ws[c] / wScale));
    286                                 interpolation +=
    287                                         wQuant * (static_cast<int32_t>(batchBase[offsets[c] + k]) -
    288                                                   inputShape.offset);
    289                             }
    290                             outTemp[k] += interpolation;
    291                         }
    292                     }
    293                 }
    294 
    295                 // take average and cast to output quantization
    296                 for (uint32_t k = 0; k < inDepth; k++) {
    297                     int32_t raw_out = tflite::MultiplyByQuantizedMultiplier(
    298                                               outTemp[k], outputMultiplier, -outputShift) +
    299                                       outputShape.offset;
    300                     int32_t clamped_out = std::min(255, std::max(0, raw_out));
    301                     outPtr[k] = static_cast<uint8_t>(clamped_out);
    302                 }
    303                 outPtr += inDepth;
    304             }
    305         }
    306     }
    307     return true;
    308 }
    309 
// Layout wrapper around roiAlignNhwc.
//
// Wraps the caller's input and output buffers in InputWithLayout / OutputWithLayout (both built
// with useNchw), runs roiAlignNhwc on the NHWC buffers and shapes they expose, then calls
// output.commit().
// NOTE(review): presumably the wrappers transpose NCHW data to/from a temporary NHWC buffer when
// useNchw is true and use the caller's buffers directly otherwise. Confirm in their definitions.
// Returns false as soon as any of these steps fails, true otherwise.
template <typename T_Input, typename T_Roi>
inline bool roiAlign(const T_Input* inputData, const Shape& inputShape, const T_Roi* roiData,
                     const Shape& roiShape, const int32_t* batchSplitData,
                     const Shape& batchSplitShape, float heightStride, float widthStride,
                     int32_t heightSamplingRatio, int32_t widthSamplingRatio, bool useNchw,
                     T_Input* outputData, const Shape& outputShape) {
    InputWithLayout<T_Input> input(useNchw);
    OutputWithLayout<T_Input> output(useNchw);
    NN_RET_CHECK(input.initialize(inputData, inputShape));
    NN_RET_CHECK(output.initialize(outputData, outputShape));
    // Template arguments are deduced from the buffers, so uint8_t features with uint16_t boxes
    // resolve to the quantized roiAlignNhwc<uint8_t, uint16_t> specialization.
    NN_RET_CHECK(roiAlignNhwc(input.getNhwcBuffer(), input.getNhwcShape(), roiData, roiShape,
                              batchSplitData, batchSplitShape, heightStride, widthStride,
                              heightSamplingRatio, widthSamplingRatio, output.getNhwcBuffer(),
                              output.getNhwcShape()));
    // Hand the NHWC result back to outputData (any layout conversion presumably happens here).
    NN_RET_CHECK(output.commit());
    return true;
}
    326 }
    327 
    328 }  // namespace
    329 
    330 bool validate(const IOperationValidationContext* context) {
    331     NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    332     NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    333     std::vector<OperandType> inExpectedTypes;
    334     auto inputType = context->getInputType(kInputTensor);
    335     if (inputType == OperandType::TENSOR_FLOAT32) {
    336         inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
    337                            OperandType::TENSOR_INT32,   OperandType::INT32,
    338                            OperandType::INT32,          OperandType::FLOAT32,
    339                            OperandType::FLOAT32,        OperandType::INT32,
    340                            OperandType::INT32,          OperandType::BOOL};
    341     } else if (inputType == OperandType::TENSOR_FLOAT16) {
    342         inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
    343                            OperandType::TENSOR_INT32,   OperandType::INT32,
    344                            OperandType::INT32,          OperandType::FLOAT16,
    345                            OperandType::FLOAT16,        OperandType::INT32,
    346                            OperandType::INT32,          OperandType::BOOL};
    347     } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
    348         inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM,
    349                            OperandType::TENSOR_QUANT16_ASYMM,
    350                            OperandType::TENSOR_INT32,
    351                            OperandType::INT32,
    352                            OperandType::INT32,
    353                            OperandType::FLOAT32,
    354                            OperandType::FLOAT32,
    355                            OperandType::INT32,
    356                            OperandType::INT32,
    357                            OperandType::BOOL};
    358     } else {
    359         LOG(ERROR) << "Unsupported input tensor type for operation " << kOperationName;
    360         return false;
    361     }
    362     NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
    363     NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    364     return validateHalVersion(context, HalVersion::V1_2);
    365 }
    366 
    367 bool prepare(IOperationExecutionContext* context) {
    368     bool useNchw = context->getInputValue<bool>(kLayoutScalar);
    369     Shape input = context->getInputShape(kInputTensor);
    370     Shape roiShape = context->getInputShape(kRoiTensor);
    371     Shape batchSplitShape = context->getInputShape(kBatchSplitTensor);
    372     NN_RET_CHECK_EQ(getNumberOfDimensions(input), 4);
    373     NN_RET_CHECK_EQ(getNumberOfDimensions(roiShape), 2);
    374 
    375     uint32_t numBatches = getSizeOfDimension(input, 0);
    376     uint32_t inHeight = getSizeOfDimension(input, useNchw ? 2 : 1);
    377     uint32_t inWidth = getSizeOfDimension(input, useNchw ? 3 : 2);
    378     uint32_t inDepth = getSizeOfDimension(input, useNchw ? 1 : 3);
    379     uint32_t numRois = getSizeOfDimension(roiShape, 0);
    380     // Every dimension must be positive except for numRois.
    381     NN_RET_CHECK_GT(numBatches, 0);
    382     NN_RET_CHECK_GT(inHeight, 0);
    383     NN_RET_CHECK_GT(inWidth, 0);
    384     NN_RET_CHECK_GT(inDepth, 0);
    385     NN_RET_CHECK_EQ(getSizeOfDimension(roiShape, 1), 4);
    386     NN_RET_CHECK_EQ(getSizeOfDimension(batchSplitShape, 0), numRois);
    387 
    388     int32_t outputHeight = context->getInputValue<int32_t>(kOutputHeightScalar);
    389     int32_t outputWidth = context->getInputValue<int32_t>(kOutputWidthScalar);
    390     int32_t heightSamplingRatio = context->getInputValue<int32_t>(kHeightSamplingRatioScalar);
    391     int32_t widthSamplingRatio = context->getInputValue<int32_t>(kWidthSamplingRatioScalar);
    392     float heightScale, widthScale;
    393     if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) {
    394         heightScale = context->getInputValue<_Float16>(kHeightStrideSalar);
    395         widthScale = context->getInputValue<_Float16>(kWidthStrideScalar);
    396     } else {
    397         heightScale = context->getInputValue<float>(kHeightStrideSalar);
    398         widthScale = context->getInputValue<float>(kWidthStrideScalar);
    399     }
    400     NN_RET_CHECK_GT(outputHeight, 0);
    401     NN_RET_CHECK_GT(outputWidth, 0);
    402     NN_RET_CHECK_GT(heightScale, 0);
    403     NN_RET_CHECK_GT(widthScale, 0);
    404     // Sampling ratio can set to 0 for adaptive value.
    405     NN_RET_CHECK_GE(heightSamplingRatio, 0);
    406     NN_RET_CHECK_GE(widthSamplingRatio, 0);
    407 
    408     if (roiShape.type == OperandType::TENSOR_QUANT16_ASYMM) {
    409         NN_RET_CHECK_EQ(roiShape.scale, 0.125f);
    410         NN_RET_CHECK_EQ(roiShape.offset, 0);
    411     }
    412 
    413     Shape output = context->getOutputShape(kOutputTensor);
    414     output.type = input.type;
    415     if (useNchw) {
    416         output.dimensions = {numRois, inDepth, static_cast<uint32_t>(outputHeight),
    417                              static_cast<uint32_t>(outputWidth)};
    418     } else {
    419         output.dimensions = {numRois, static_cast<uint32_t>(outputHeight),
    420                              static_cast<uint32_t>(outputWidth), inDepth};
    421     }
    422     return context->setOutputShape(kOutputTensor, output);
    423 }
    424 
    425 bool execute(IOperationExecutionContext* context) {
    426     // Bypass execution in the case of zero-sized input.
    427     if (getNumberOfElements(context->getInputShape(kRoiTensor)) == 0) return true;
    428     switch (context->getInputType(kInputTensor)) {
    429         case OperandType::TENSOR_FLOAT16:
    430             return roiAlign(context->getInputBuffer<_Float16>(kInputTensor),
    431                             context->getInputShape(kInputTensor),
    432                             context->getInputBuffer<_Float16>(kRoiTensor),
    433                             context->getInputShape(kRoiTensor),
    434                             context->getInputBuffer<int32_t>(kBatchSplitTensor),
    435                             context->getInputShape(kBatchSplitTensor),
    436                             context->getInputValue<_Float16>(kHeightStrideSalar),
    437                             context->getInputValue<_Float16>(kWidthStrideScalar),
    438                             context->getInputValue<int32_t>(kHeightSamplingRatioScalar),
    439                             context->getInputValue<int32_t>(kWidthSamplingRatioScalar),
    440                             context->getInputValue<bool>(kLayoutScalar),
    441                             context->getOutputBuffer<_Float16>(kOutputTensor),
    442                             context->getOutputShape(kOutputTensor));
    443         case OperandType::TENSOR_FLOAT32:
    444             return roiAlign(context->getInputBuffer<float>(kInputTensor),
    445                             context->getInputShape(kInputTensor),
    446                             context->getInputBuffer<float>(kRoiTensor),
    447                             context->getInputShape(kRoiTensor),
    448                             context->getInputBuffer<int32_t>(kBatchSplitTensor),
    449                             context->getInputShape(kBatchSplitTensor),
    450                             context->getInputValue<float>(kHeightStrideSalar),
    451                             context->getInputValue<float>(kWidthStrideScalar),
    452                             context->getInputValue<int32_t>(kHeightSamplingRatioScalar),
    453                             context->getInputValue<int32_t>(kWidthSamplingRatioScalar),
    454                             context->getInputValue<bool>(kLayoutScalar),
    455                             context->getOutputBuffer<float>(kOutputTensor),
    456                             context->getOutputShape(kOutputTensor));
    457         case OperandType::TENSOR_QUANT8_ASYMM:
    458             return roiAlign(context->getInputBuffer<uint8_t>(kInputTensor),
    459                             context->getInputShape(kInputTensor),
    460                             context->getInputBuffer<uint16_t>(kRoiTensor),
    461                             context->getInputShape(kRoiTensor),
    462                             context->getInputBuffer<int32_t>(kBatchSplitTensor),
    463                             context->getInputShape(kBatchSplitTensor),
    464                             context->getInputValue<float>(kHeightStrideSalar),
    465                             context->getInputValue<float>(kWidthStrideScalar),
    466                             context->getInputValue<int32_t>(kHeightSamplingRatioScalar),
    467                             context->getInputValue<int32_t>(kWidthSamplingRatioScalar),
    468                             context->getInputValue<bool>(kLayoutScalar),
    469                             context->getOutputBuffer<uint8_t>(kOutputTensor),
    470                             context->getOutputShape(kOutputTensor));
    471         default:
    472             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
    473     }
    474 }
    475 
    476 }  // namespace roi_align
    477 
// Registration of ROI_ALIGN, binding the roi_align::validate / prepare / execute entry points
// under the name roi_align::kOperationName.
// NOTE(review): allowZeroSizedInput presumably lets the runtime execute this operation when an
// input has zero elements (e.g. zero ROIs); roi_align::execute returns success immediately when
// the ROI tensor is empty.
NN_REGISTER_OPERATION(ROI_ALIGN, roi_align::kOperationName, roi_align::validate, roi_align::prepare,
                      roi_align::execute, .allowZeroSizedInput = true);
    480 
    481 }  // namespace nn
    482 }  // namespace android
    483