/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "CpuOperationUtils.h"
#include "OperationResolver.h"
#include "OperationsUtils.h"

#include <cfloat>
#include <cmath>

#include "Tracing.h"

namespace android {
namespace nn {
namespace roi_pooling {

constexpr char kOperationName[] = "ROI_POOLING";

constexpr uint32_t kNumInputs = 8;
constexpr uint32_t kInputTensor = 0;
constexpr uint32_t kRoiTensor = 1;
constexpr uint32_t kBatchSplitTensor = 2;
constexpr uint32_t kOutputHeightScalar = 3;
constexpr uint32_t kOutputWidthScalar = 4;
constexpr uint32_t kHeightStrideScalar = 5;
constexpr uint32_t kWidthStrideScalar = 6;
constexpr uint32_t kLayoutScalar = 7;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;
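
// Operand layout, as enforced by validate()/prepare() and consumed by
// roiPoolingNhwc() below:
//   kInputTensor:      4-D feature map, NHWC (or NCHW when kLayoutScalar is set).
//   kRoiTensor:        [numRois, 4] rows of [x1, y1, x2, y2]; dividing by the
//                      width/height strides maps the coordinates onto the
//                      feature map.
//   kBatchSplitTensor: [numRois], the batch index of each ROI.
//   kOutputTensor:     [numRois, outHeight, outWidth, depth] (NHWC) or the
//                      NCHW equivalent.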

namespace {

template <typename T_Input, typename T_Roi>
inline bool roiPoolingNhwc(const T_Input* inputData, const Shape& inputShape, const T_Roi* roiData,
                           const Shape& roiShape, const int32_t* batchSplitData,
                           const Shape& batchSplitShape, float heightStride, float widthStride,
                           T_Input* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("RoiPooling");

    const uint32_t kRoiDim = 4;
    const T_Roi heightScale = 1.0f / heightStride;
    const T_Roi widthScale = 1.0f / widthStride;

    uint32_t numBatches = getSizeOfDimension(inputShape, 0);
    uint32_t inHeight = getSizeOfDimension(inputShape, 1);
    uint32_t inWidth = getSizeOfDimension(inputShape, 2);
    uint32_t inDepth = getSizeOfDimension(inputShape, 3);
    uint32_t outHeight = getSizeOfDimension(outputShape, 1);
    uint32_t outWidth = getSizeOfDimension(outputShape, 2);
    uint32_t numRois = getSizeOfDimension(roiShape, 0);
    uint32_t roiInfoLength = getSizeOfDimension(roiShape, 1);

    T_Input* outPtr = outputData;
    const T_Roi* roiDataEnd = roiData + numRois * roiInfoLength;
    uint32_t roiIndex = 0;
    for (const T_Roi* roiInfo = roiData; roiInfo < roiDataEnd; roiInfo += kRoiDim, roiIndex++) {
        uint32_t batchId = batchSplitData[roiIndex];
        // Check for malformed data
        // 1. Invalid batch id
        // 2. Region out of bounds: x1|x2|y1|y2 < 0 || x1|x2 > inWidth || y1|y2 > inHeight
        // 3. Invalid region: x2 < x1 || y2 < y1
        NN_RET_CHECK_GE(batchId, 0);
        NN_RET_CHECK_LT(batchId, numBatches);
        NN_RET_CHECK(roiInfo[0] >= 0);
        NN_RET_CHECK(roiInfo[1] >= 0);
        NN_RET_CHECK(roiInfo[2] >= 0);
        NN_RET_CHECK(roiInfo[3] >= 0);
        NN_RET_CHECK(roiInfo[0] * widthScale <= inWidth);
        NN_RET_CHECK(roiInfo[1] * heightScale <= inHeight);
        NN_RET_CHECK(roiInfo[2] * widthScale <= inWidth);
        NN_RET_CHECK(roiInfo[3] * heightScale <= inHeight);
        NN_RET_CHECK(roiInfo[0] <= roiInfo[2]);
        NN_RET_CHECK(roiInfo[1] <= roiInfo[3]);

        int32_t wRoiStart = std::round(static_cast<float>(roiInfo[0] * widthScale));
        int32_t hRoiStart = std::round(static_cast<float>(roiInfo[1] * heightScale));
        int32_t wRoiEnd = std::round(static_cast<float>(roiInfo[2] * widthScale));
        int32_t hRoiEnd = std::round(static_cast<float>(roiInfo[3] * heightScale));

        // ROIs with width/height < 1 are considered malformed and are forced to be 1
        T_Roi roiWidth = static_cast<T_Roi>(std::max(wRoiEnd - wRoiStart + 1, 1));
        T_Roi roiHeight = static_cast<T_Roi>(std::max(hRoiEnd - hRoiStart + 1, 1));
        T_Roi wStepSize = roiWidth / static_cast<T_Roi>(outWidth);
        T_Roi hStepSize = roiHeight / static_cast<T_Roi>(outHeight);

        const T_Input* batchBase = inputData + batchId * inHeight * inWidth * inDepth;
        for (uint32_t i = 0; i < outHeight; i++) {
            for (uint32_t j = 0; j < outWidth; j++) {
                // Take floor on start, ceil on end, start included, end excluded, i.e. [start, end)
                // end is guaranteed to be larger than start by at least 1
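                // For example (illustrative values only): with wRoiStart = 2, wRoiEnd = 7
                // and outWidth = 2, roiWidth = 6 and wStepSize = 3, so bin j = 0 covers
                // [2, 5) and bin j = 1 covers [5, 8), before clamping to inWidth below.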
                uint32_t wStart = std::floor(static_cast<float>(wStepSize * j + wRoiStart));
                uint32_t wEnd = std::ceil(static_cast<float>(wStepSize * (j + 1) + wRoiStart));
                uint32_t hStart = std::floor(static_cast<float>(hStepSize * i + hRoiStart));
                uint32_t hEnd = std::ceil(static_cast<float>(hStepSize * (i + 1) + hRoiStart));

                wStart = std::min(wStart, inWidth);
                wEnd = std::min(wEnd, inWidth);
                hStart = std::min(hStart, inHeight);
                hEnd = std::min(hEnd, inHeight);

                for (uint32_t k = 0; k < inDepth; k++) {
                    T_Input maxValue = static_cast<T_Input>(inputShape.offset);
                    bool first = true;
                    for (uint32_t h = hStart; h < hEnd; h++) {
                        for (uint32_t w = wStart; w < wEnd; w++) {
                            T_Input inputValue = batchBase[h * inWidth * inDepth + w * inDepth + k];
                            if (first || inputValue > maxValue) {
                                maxValue = inputValue;
                                first = false;
                            }
                        }
                    }
                    outPtr[k] = maxValue;
                }
                outPtr += inDepth;
            }
        }
    }
    return true;
}

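// Wrapper that handles the optional NCHW layout: InputWithLayout/OutputWithLayout
// (from CpuOperationUtils.h) present the data to roiPoolingNhwc() in NHWC form and,
// when useNchw is set, presumably convert to and from NCHW in initialize()/commit().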
template <typename T_Input, typename T_Roi>
inline bool roiPooling(const T_Input* inputData, const Shape& inputShape, const T_Roi* roiData,
                       const Shape& roiShape, const int32_t* batchSplitData,
                       const Shape& batchSplitShape, float heightStride, float widthStride,
                       bool useNchw, T_Input* outputData, const Shape& outputShape) {
    InputWithLayout<T_Input> input(useNchw);
    OutputWithLayout<T_Input> output(useNchw);
    NN_RET_CHECK(input.initialize(inputData, inputShape));
    NN_RET_CHECK(output.initialize(outputData, outputShape));
    NN_RET_CHECK(roiPoolingNhwc(input.getNhwcBuffer(), input.getNhwcShape(), roiData, roiShape,
                                batchSplitData, batchSplitShape, heightStride, widthStride,
                                output.getNhwcBuffer(), output.getNhwcShape()));
    NN_RET_CHECK(output.commit());
    return true;
}

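// Specialization for quantized inputs: the ROI tensor arrives as
// TENSOR_QUANT16_ASYMM with scale 0.125 and zero offset (enforced in prepare()
// below), so it is dequantized to float first. For example (illustrative value
// only), a quantized coordinate of 800 corresponds to 800 * 0.125 = 100.0 pixels.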
template <>
inline bool roiPooling<uint8_t, uint16_t>(const uint8_t* inputData, const Shape& inputShape,
                                          const uint16_t* roiData, const Shape& roiShape,
                                          const int32_t* batchSplitData,
                                          const Shape& batchSplitShape, float heightStride,
                                          float widthStride, bool useNchw, uint8_t* outputData,
                                          const Shape& outputShape) {
    std::vector<float> roi_float32(getNumberOfElements(roiShape));
    convertQuantToFloat32(roiData, roiShape.scale, roiShape.offset, &roi_float32);
    NN_RET_CHECK(roiPooling(inputData, inputShape, roi_float32.data(), roiShape, batchSplitData,
                            batchSplitShape, heightStride, widthStride, useNchw, outputData,
                            outputShape));
    return true;
}

}  // namespace

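// Supported type combinations (input tensor / ROI tensor):
//   TENSOR_FLOAT32      / TENSOR_FLOAT32
//   TENSOR_FLOAT16      / TENSOR_FLOAT16
//   TENSOR_QUANT8_ASYMM / TENSOR_QUANT16_ASYMM
// The stride scalars are FLOAT16 for the FLOAT16 case and FLOAT32 otherwise.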
bool validate(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    std::vector<OperandType> inExpectedTypes;
    auto inputType = context->getInputType(kInputTensor);
    if (inputType == OperandType::TENSOR_FLOAT32) {
        inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_INT32,   OperandType::INT32,
                           OperandType::INT32,          OperandType::FLOAT32,
                           OperandType::FLOAT32,        OperandType::BOOL};
    } else if (inputType == OperandType::TENSOR_FLOAT16) {
        inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_INT32,   OperandType::INT32,
                           OperandType::INT32,          OperandType::FLOAT16,
                           OperandType::FLOAT16,        OperandType::BOOL};
    } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
        inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM,
                           OperandType::TENSOR_QUANT16_ASYMM,
                           OperandType::TENSOR_INT32,
                           OperandType::INT32,
                           OperandType::INT32,
                           OperandType::FLOAT32,
                           OperandType::FLOAT32,
                           OperandType::BOOL};
    } else {
        LOG(ERROR) << "Unsupported input tensor type for operation " << kOperationName;
        return false;
    }
    NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    return validateHalVersion(context, HalVersion::V1_2);
}

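// Shape inference example (illustrative values only): an NHWC input of
// [2, 32, 32, 8] with 5 ROIs and outputHeight = outputWidth = 7 produces an
// output of shape [5, 7, 7, 8]; the NCHW equivalent (input [2, 8, 32, 32])
// produces [5, 8, 7, 7].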
bool prepare(IOperationExecutionContext* context) {
    bool useNchw = context->getInputValue<bool>(kLayoutScalar);
    Shape input = context->getInputShape(kInputTensor);
    Shape roiShape = context->getInputShape(kRoiTensor);
    Shape batchSplitShape = context->getInputShape(kBatchSplitTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(input), 4);
    NN_RET_CHECK_EQ(getNumberOfDimensions(roiShape), 2);

    uint32_t numBatches = getSizeOfDimension(input, 0);
    uint32_t inHeight = getSizeOfDimension(input, useNchw ? 2 : 1);
    uint32_t inWidth = getSizeOfDimension(input, useNchw ? 3 : 2);
    uint32_t inDepth = getSizeOfDimension(input, useNchw ? 1 : 3);
    uint32_t numRois = getSizeOfDimension(roiShape, 0);
    NN_RET_CHECK_EQ(getSizeOfDimension(roiShape, 1), 4);
    NN_RET_CHECK_EQ(getSizeOfDimension(batchSplitShape, 0), numRois);

    auto outputHeight = context->getInputValue<int32_t>(kOutputHeightScalar);
    auto outputWidth = context->getInputValue<int32_t>(kOutputWidthScalar);
    float heightStride, widthStride;
    if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) {
        heightStride = context->getInputValue<_Float16>(kHeightStrideScalar);
        widthStride = context->getInputValue<_Float16>(kWidthStrideScalar);
    } else {
        heightStride = context->getInputValue<float>(kHeightStrideScalar);
        widthStride = context->getInputValue<float>(kWidthStrideScalar);
    }
    NN_RET_CHECK_GT(outputHeight, 0);
    NN_RET_CHECK_GT(outputWidth, 0);
    NN_RET_CHECK_GT(heightStride, 0);
    NN_RET_CHECK_GT(widthStride, 0);

    if (roiShape.type == OperandType::TENSOR_QUANT16_ASYMM) {
        NN_RET_CHECK_EQ(roiShape.scale, 0.125f);
        NN_RET_CHECK_EQ(roiShape.offset, 0);
    }

    Shape output = input;
    if (useNchw) {
        output.dimensions = {numRois, inDepth, static_cast<uint32_t>(outputHeight),
                             static_cast<uint32_t>(outputWidth)};
    } else {
        output.dimensions = {numRois, static_cast<uint32_t>(outputHeight),
                             static_cast<uint32_t>(outputWidth), inDepth};
    }
    return context->setOutputShape(kOutputTensor, output);
}

bool execute(IOperationExecutionContext* context) {
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return roiPooling(context->getInputBuffer<_Float16>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getInputBuffer<_Float16>(kRoiTensor),
                              context->getInputShape(kRoiTensor),
                              context->getInputBuffer<int32_t>(kBatchSplitTensor),
                              context->getInputShape(kBatchSplitTensor),
                              context->getInputValue<_Float16>(kHeightStrideScalar),
                              context->getInputValue<_Float16>(kWidthStrideScalar),
                              context->getInputValue<bool>(kLayoutScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return roiPooling(context->getInputBuffer<float>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getInputBuffer<float>(kRoiTensor),
                              context->getInputShape(kRoiTensor),
                              context->getInputBuffer<int32_t>(kBatchSplitTensor),
                              context->getInputShape(kBatchSplitTensor),
                              context->getInputValue<float>(kHeightStrideScalar),
                              context->getInputValue<float>(kWidthStrideScalar),
                              context->getInputValue<bool>(kLayoutScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return roiPooling(context->getInputBuffer<uint8_t>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getInputBuffer<uint16_t>(kRoiTensor),
                              context->getInputShape(kRoiTensor),
                              context->getInputBuffer<int32_t>(kBatchSplitTensor),
                              context->getInputShape(kBatchSplitTensor),
                              context->getInputValue<float>(kHeightStrideScalar),
                              context->getInputValue<float>(kWidthStrideScalar),
                              context->getInputValue<bool>(kLayoutScalar),
                              context->getOutputBuffer<uint8_t>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
    }
}

}  // namespace roi_pooling

NN_REGISTER_OPERATION(ROI_POOLING, roi_pooling::kOperationName, roi_pooling::validate,
                      roi_pooling::prepare, roi_pooling::execute);

}  // namespace nn
}  // namespace android