Home | History | Annotate | Download | only in operations
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "CpuOperationUtils.h"
     18 #include "OperationResolver.h"
     19 #include "OperationsUtils.h"
     20 
     21 #include <cfloat>
     22 #include <cmath>
     23 
     24 #include "Tracing.h"
     25 
     26 namespace android {
     27 namespace nn {
     28 namespace heatmap_max_keypoint {
     29 
// Operation name, used in the error messages emitted by validate() and execute().
constexpr char kOperationName[] = "HEATMAP_MAX_KEYPOINT";

// Input operand count and indices.
constexpr uint32_t kNumInputs = 3;
constexpr uint32_t kHeatmapTensor = 0;  // Rank-4 heatmap tensor (checked in prepare()).
constexpr uint32_t kBoxesTensor = 1;    // Rank-2 boxes tensor, 4 values per box (see prepare()).
constexpr uint32_t kLayoutScalar = 2;   // BOOL; true means the heatmap is NCHW, false NHWC.

// Output operand count and indices.
constexpr uint32_t kNumOutputs = 2;
constexpr uint32_t kOutputScoreTensor = 0;     // Scores shaped [numBoxes, numKeypoints].
constexpr uint32_t kOutputKeypointTensor = 1;  // Keypoints shaped [numBoxes, numKeypoints, 2].
     40 
     41 namespace {
     42 
     43 // This function uses Taylor expansion up to the quatratic term to approximate bicubic
     44 // upscaling result.
     45 // 2nd order Taylor expansion: D(x) = D - b'x + 1/2 * x'Ax
     46 // where D = grid[1][1], Taylor expansion center, the original score,
     47 //       x = delta, the correction on max keypoint position,
     48 //       D(x) = deltaScore, the accuracy score after correction
     49 static void solveForDelta(const float grid[3][3], float* delta, float* deltaScore,
     50                           float fpAtol = 1e-5f, float fpRtol = 1e-5f) {
     51     // b: negative 1st order derivative at center
     52     // A: Hessian matrix at center (2nd order derivative)
     53     float A[2][2], b[2];
     54     b[0] = -(grid[1][2] - grid[1][0]) / 2.0f;
     55     b[1] = -(grid[2][1] - grid[0][1]) / 2.0f;
     56     A[0][0] = grid[1][0] - 2.0f * grid[1][1] + grid[1][2];
     57     A[0][1] = (grid[2][2] - grid[2][0] - grid[0][2] + grid[0][0]) / 4.0f;
     58     A[1][0] = A[0][1];
     59     A[1][1] = grid[0][1] - 2.0f * grid[1][1] + grid[2][1];
     60 
     61     // solve Ax=b, where x=delta -> delta = inv(A) * b
     62     float crossProd1 = A[0][0] * A[1][1], crossProd2 = A[0][1] * A[1][0];
     63     float detA = crossProd1 - crossProd2;
     64     // check if A is invertible
     65     if (std::abs(detA) < (fpAtol + fpRtol * crossProd1)) return;
     66     delta[0] = (A[1][1] * b[0] - A[0][1] * b[1]) / detA;
     67     delta[1] = (A[0][0] * b[1] - A[1][0] * b[0]) / detA;
     68 
     69     // clip out of range delta, i.e. delta > 3/2
     70     if (std::abs(delta[0]) > 1.5f || std::abs(delta[1]) > 1.5f) {
     71         float scale = 1.5f / std::max(std::abs(delta[0]), std::abs(delta[1]));
     72         delta[0] *= scale;
     73         delta[1] *= scale;
     74     }
     75 
     76     *deltaScore = grid[1][1] - b[0] * delta[0] - b[1] * delta[1] +
     77                   ((A[0][0] * delta[0] + A[0][1] * delta[1]) * delta[0] +
     78                    (A[1][0] * delta[0] + A[1][1] * delta[1]) * delta[1]) /
     79                           2.0f;
     80 }
     81 
     82 inline bool heatmapMaxKeypointFloat32Nhwc(const float* heatmap, const Shape& heatmapShape,
     83                                           const float* boxes, const Shape& boxesShape,
     84                                           float* outputScoreData, const Shape& outputScoreShape,
     85                                           float* outputKeypointData,
     86                                           const Shape& outputKeypointShape, float fpAtol,
     87                                           float fpRtol) {
     88     NNTRACE_TRANS("HeatmapMaxKeypoint");
     89 
     90     uint32_t numBoxes = getSizeOfDimension(heatmapShape, 0);
     91     uint32_t heatmapSize = getSizeOfDimension(heatmapShape, 1);
     92     uint32_t numKeypoints = getSizeOfDimension(heatmapShape, 3);
     93     uint32_t boxInfoLength = getSizeOfDimension(boxesShape, 1);
     94 
     95     const float* heatmapBase = heatmap;
     96     const float* boxInfoBase = boxes;
     97     float* outputScoreBase = outputScoreData;
     98     float* outputKeypointBase = outputKeypointData;
     99     for (uint32_t i = 0; i < numBoxes; i++) {
    100         NN_RET_CHECK_LE(boxInfoBase[0], boxInfoBase[2]);
    101         NN_RET_CHECK_LE(boxInfoBase[1], boxInfoBase[3]);
    102         for (uint32_t j = 0; j < numKeypoints; j++) {
    103             // find max score and its index
    104             uint32_t maxIndex = 0;
    105             float maxScore = -FLT_MAX;
    106             for (uint32_t k = 0; k < heatmapSize * heatmapSize; k++) {
    107                 float val = heatmapBase[k * numKeypoints + j];
    108                 if (maxScore < val) {
    109                     maxScore = val;
    110                     maxIndex = k;
    111                 }
    112             }
    113 
    114             uint32_t maxIndexWidth = maxIndex % heatmapSize;
    115             uint32_t maxIndexHeight = maxIndex / heatmapSize;
    116 
    117             // get local 3x3 grid
    118             float localGrid[3][3];
    119             for (int32_t dh = -1; dh <= 1; dh++) {
    120                 for (int32_t dw = -1; dw <= 1; dw++) {
    121                     // cast uint32_t to int32_t
    122                     int32_t h = static_cast<int32_t>(maxIndexHeight) + dh;
    123                     int32_t w = static_cast<int32_t>(maxIndexWidth) + dw;
    124 
    125                     // use mirroring for out of bound indexing
    126                     // need to ensure heatmapSize >= 2
    127                     h = h < 0 ? 1 : (h >= heatmapSize ? heatmapSize - 2 : h);
    128                     w = w < 0 ? 1 : (w >= heatmapSize ? heatmapSize - 2 : w);
    129 
    130                     uint32_t heatmapIndex = static_cast<uint32_t>(h) * heatmapSize * numKeypoints +
    131                                             static_cast<uint32_t>(w) * numKeypoints + j;
    132                     localGrid[dh + 1][dw + 1] = heatmapBase[heatmapIndex];
    133                 }
    134             }
    135 
    136             float delta[2] = {0.0f, 0.0f}, deltaScore = maxScore;
    137             solveForDelta(localGrid, delta, &deltaScore, fpAtol, fpRtol);
    138 
    139             float wRoiStart = boxInfoBase[0];
    140             float hRoiStart = boxInfoBase[1];
    141             float wRoiEnd = boxInfoBase[2];
    142             float hRoiEnd = boxInfoBase[3];
    143             float roiWidth = wRoiEnd - wRoiStart;
    144             float roiHeight = hRoiEnd - hRoiStart;
    145             float wRelativePos = (static_cast<float>(maxIndexWidth) + delta[0] + 0.5f) /
    146                                  static_cast<float>(heatmapSize);
    147             float hRelativePos = (static_cast<float>(maxIndexHeight) + delta[1] + 0.5f) /
    148                                  static_cast<float>(heatmapSize);
    149             *outputScoreBase++ = deltaScore;
    150             outputKeypointBase[0] = wRelativePos * roiWidth + wRoiStart;
    151             outputKeypointBase[1] = hRelativePos * roiHeight + hRoiStart;
    152             outputKeypointBase += 2;
    153         }
    154         boxInfoBase += boxInfoLength;
    155         heatmapBase += heatmapSize * heatmapSize * numKeypoints;
    156     }
    157 
    158     return true;
    159 }
    160 
    161 inline bool heatmapMaxKeypointFloat32(const float* heatmap, const Shape& heatmapShape,
    162                                       const float* boxes, const Shape& boxesShape, bool layout,
    163                                       float* outputScoreData, const Shape& outputScoreShape,
    164                                       float* outputKeypointData, const Shape& outputKeypointShape,
    165                                       float fpAtol, float fpRtol) {
    166     std::vector<float> heatmap_nhwc;
    167     Shape heatmapShape_nhwc;
    168     if (layout) {
    169         NN_RET_CHECK(convertNchwToNhwc(heatmap, heatmapShape, &heatmap_nhwc, &heatmapShape_nhwc));
    170     }
    171     const float* heatmap_tmp = layout ? heatmap_nhwc.data() : heatmap;
    172     const Shape& heatmapShape_tmp = layout ? heatmapShape_nhwc : heatmapShape;
    173     return heatmapMaxKeypointFloat32Nhwc(heatmap_tmp, heatmapShape_tmp, boxes, boxesShape,
    174                                          outputScoreData, outputScoreShape, outputKeypointData,
    175                                          outputKeypointShape, fpAtol, fpRtol);
    176 }
    177 
    178 inline bool heatmapMaxKeypointQuant(const uint8_t* heatmap, const Shape& heatmapShape,
    179                                     const uint16_t* boxes, const Shape& boxesShape, bool layout,
    180                                     uint8_t* outputScoreData, const Shape& outputScoreShape,
    181                                     uint16_t* outputKeypointData, const Shape& outputKeypointShape,
    182                                     float fpAtol, float fpRtol) {
    183     std::vector<float> heatmap_float32(getNumberOfElements(heatmapShape));
    184     convertQuantToFloat32(heatmap, heatmapShape.scale, heatmapShape.offset, &heatmap_float32);
    185     std::vector<float> boxes_float32(getNumberOfElements(boxesShape));
    186     convertQuantToFloat32(boxes, boxesShape.scale, boxesShape.offset, &boxes_float32);
    187     std::vector<float> outputScore_float32(getNumberOfElements(outputScoreShape));
    188     std::vector<float> outputKeypoint_float32(getNumberOfElements(outputKeypointShape));
    189     NN_RET_CHECK(heatmapMaxKeypointFloat32(
    190             heatmap_float32.data(), heatmapShape, boxes_float32.data(), boxesShape, layout,
    191             outputScore_float32.data(), outputScoreShape, outputKeypoint_float32.data(),
    192             outputKeypointShape, fpAtol, fpRtol));
    193     convertFloat32ToQuant(outputScore_float32, outputScoreShape.scale, outputScoreShape.offset,
    194                           outputScoreData);
    195     convertFloat32ToQuant(outputKeypoint_float32, outputKeypointShape.scale,
    196                           outputKeypointShape.offset, outputKeypointData);
    197     return true;
    198 }
    199 
    200 }  // namespace
    201 
    202 bool validate(const IOperationValidationContext* context) {
    203     NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    204     NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    205     std::vector<OperandType> inExpectedTypes;
    206     std::vector<OperandType> outExpectedTypes;
    207     auto inputType = context->getInputType(kHeatmapTensor);
    208     if (inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_FLOAT16) {
    209         inExpectedTypes = {inputType, inputType, OperandType::BOOL};
    210         outExpectedTypes = {inputType, inputType};
    211     } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
    212         inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT16_ASYMM,
    213                            OperandType::BOOL};
    214         outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT16_ASYMM};
    215     } else {
    216         LOG(ERROR) << "Unsupported input tensor type for operation " << kOperationName;
    217         return false;
    218     }
    219     NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
    220     NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes));
    221     return validateHalVersion(context, HalVersion::V1_2);
    222 }
    223 
    224 bool prepare(IOperationExecutionContext* context) {
    225     bool layout = context->getInputValue<bool>(kLayoutScalar);
    226     Shape heatmapShape = context->getInputShape(kHeatmapTensor);
    227     Shape boxesShape = context->getInputShape(kBoxesTensor);
    228     NN_RET_CHECK_EQ(getNumberOfDimensions(heatmapShape), 4);
    229     NN_RET_CHECK_EQ(getNumberOfDimensions(boxesShape), 2);
    230 
    231     uint32_t numBoxes = getSizeOfDimension(heatmapShape, 0);
    232     uint32_t heatmapSize = getSizeOfDimension(heatmapShape, 2);
    233     uint32_t numKeypoints = getSizeOfDimension(heatmapShape, layout ? 1 : 3);
    234     uint32_t boxInfoLength = getSizeOfDimension(boxesShape, 1);
    235     NN_RET_CHECK_EQ(getSizeOfDimension(heatmapShape, layout ? 3 : 1), heatmapSize);
    236     NN_RET_CHECK_GE(heatmapSize, 2);
    237     NN_RET_CHECK_EQ(getSizeOfDimension(boxesShape, 0), numBoxes);
    238     NN_RET_CHECK_EQ(boxInfoLength, 4);
    239 
    240     if (heatmapShape.type == OperandType::TENSOR_QUANT8_ASYMM) {
    241         NN_RET_CHECK_EQ(boxesShape.scale, 0.125f);
    242         NN_RET_CHECK_EQ(boxesShape.offset, 0);
    243     }
    244 
    245     Shape outputScore = context->getOutputShape(kOutputScoreTensor);
    246     outputScore.type = heatmapShape.type;
    247     outputScore.dimensions = {numBoxes, numKeypoints};
    248     NN_RET_CHECK(context->setOutputShape(kOutputScoreTensor, outputScore));
    249 
    250     Shape outputKeypoint = context->getOutputShape(kOutputKeypointTensor);
    251     outputKeypoint.type = boxesShape.type;
    252     outputKeypoint.dimensions = {numBoxes, numKeypoints, 2};
    253     outputKeypoint.offset = 0;
    254     outputKeypoint.scale = 0.125f;
    255     NN_RET_CHECK(context->setOutputShape(kOutputKeypointTensor, outputKeypoint));
    256     return true;
    257 }
    258 
    259 bool execute(IOperationExecutionContext* context) {
    260     bool layout = context->getInputValue<bool>(kLayoutScalar);
    261     switch (context->getInputType(kHeatmapTensor)) {
    262         case OperandType::TENSOR_FLOAT16: {
    263             const auto heatmap = context->getInputBuffer<_Float16>(kHeatmapTensor);
    264             const auto heatmapShape = context->getInputShape(kHeatmapTensor);
    265             const auto boxes = context->getInputBuffer<_Float16>(kBoxesTensor);
    266             const auto boxesShape = context->getInputShape(kBoxesTensor);
    267             auto outputScoreData = context->getOutputBuffer<_Float16>(kOutputScoreTensor);
    268             const auto outputScoreShape = context->getOutputShape(kOutputScoreTensor);
    269             auto outputKeypointData = context->getOutputBuffer<_Float16>(kOutputKeypointTensor);
    270             const auto outputKeypointShape = context->getOutputShape(kOutputKeypointTensor);
    271             std::vector<float> heatmap_float32(getNumberOfElements(heatmapShape));
    272             convertFloat16ToFloat32(heatmap, &heatmap_float32);
    273             std::vector<float> boxes_float32(getNumberOfElements(boxesShape));
    274             convertFloat16ToFloat32(boxes, &boxes_float32);
    275             std::vector<float> outputScore_float32(getNumberOfElements(outputScoreShape));
    276             std::vector<float> outputKeypoint_float32(getNumberOfElements(outputKeypointShape));
    277             NN_RET_CHECK(heatmapMaxKeypointFloat32(
    278                     heatmap_float32.data(), heatmapShape, boxes_float32.data(), boxesShape, layout,
    279                     outputScore_float32.data(), outputScoreShape, outputKeypoint_float32.data(),
    280                     outputKeypointShape, 1e-3f, 1e-3f));
    281             convertFloat32ToFloat16(outputScore_float32, outputScoreData);
    282             convertFloat32ToFloat16(outputKeypoint_float32, outputKeypointData);
    283             return true;
    284         }
    285         case OperandType::TENSOR_FLOAT32: {
    286             return heatmapMaxKeypointFloat32(context->getInputBuffer<float>(kHeatmapTensor),
    287                                              context->getInputShape(kHeatmapTensor),
    288                                              context->getInputBuffer<float>(kBoxesTensor),
    289                                              context->getInputShape(kBoxesTensor), layout,
    290                                              context->getOutputBuffer<float>(kOutputScoreTensor),
    291                                              context->getOutputShape(kOutputScoreTensor),
    292                                              context->getOutputBuffer<float>(kOutputKeypointTensor),
    293                                              context->getOutputShape(kOutputKeypointTensor), 1e-5f,
    294                                              1e-5f);
    295         }
    296         case OperandType::TENSOR_QUANT8_ASYMM: {
    297             return heatmapMaxKeypointQuant(
    298                     context->getInputBuffer<uint8_t>(kHeatmapTensor),
    299                     context->getInputShape(kHeatmapTensor),
    300                     context->getInputBuffer<uint16_t>(kBoxesTensor),
    301                     context->getInputShape(kBoxesTensor), layout,
    302                     context->getOutputBuffer<uint8_t>(kOutputScoreTensor),
    303                     context->getOutputShape(kOutputScoreTensor),
    304                     context->getOutputBuffer<uint16_t>(kOutputKeypointTensor),
    305                     context->getOutputShape(kOutputKeypointTensor), 1e-5f, 1e-5f);
    306         }
    307         default:
    308             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
    309     }
    310 }
    311 
    312 }  // namespace heatmap_max_keypoint
    313 
// Registers HEATMAP_MAX_KEYPOINT under heatmap_max_keypoint::kOperationName. The registration
// binds the operation to the validate, prepare and execute functions of the
// heatmap_max_keypoint namespace.
NN_REGISTER_OPERATION(HEATMAP_MAX_KEYPOINT, heatmap_max_keypoint::kOperationName,
                      heatmap_max_keypoint::validate, heatmap_max_keypoint::prepare,
                      heatmap_max_keypoint::execute);
    317 
    318 }  // namespace nn
    319 }  // namespace android
    320