1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "CpuOperationUtils.h" 18 #include "OperationResolver.h" 19 #include "OperationsUtils.h" 20 21 #include <cfloat> 22 #include <cmath> 23 24 #include "Tracing.h" 25 26 namespace android { 27 namespace nn { 28 namespace roi_pooling { 29 30 constexpr char kOperationName[] = "ROI_POOLING"; 31 32 constexpr uint32_t kNumInputs = 8; 33 constexpr uint32_t kInputTensor = 0; 34 constexpr uint32_t kRoiTensor = 1; 35 constexpr uint32_t kBatchSplitTensor = 2; 36 constexpr uint32_t kOutputHeightScalar = 3; 37 constexpr uint32_t kOutputWidthScalar = 4; 38 constexpr uint32_t kHeightStrideSalar = 5; 39 constexpr uint32_t kWidthStrideScalar = 6; 40 constexpr uint32_t kLayoutScalar = 7; 41 42 constexpr uint32_t kNumOutputs = 1; 43 constexpr uint32_t kOutputTensor = 0; 44 45 namespace { 46 47 template <typename T_Input, typename T_Roi> 48 inline bool roiPoolingNhwc(const T_Input* inputData, const Shape& inputShape, const T_Roi* roiData, 49 const Shape& roiShape, const int32_t* batchSplitData, 50 const Shape& batchSplitShape, float heightStride, float widthStride, 51 T_Input* outputData, const Shape& outputShape) { 52 NNTRACE_TRANS("RoiPooling"); 53 54 const uint32_t kRoiDim = 4; 55 const T_Roi heightScale = 1.0f / heightStride; 56 const T_Roi widthScale = 1.0f / widthStride; 57 58 uint32_t numBatches = getSizeOfDimension(inputShape, 0); 59 uint32_t inHeight = getSizeOfDimension(inputShape, 1); 60 uint32_t inWidth = getSizeOfDimension(inputShape, 2); 61 uint32_t inDepth = getSizeOfDimension(inputShape, 3); 62 uint32_t outHeight = getSizeOfDimension(outputShape, 1); 63 uint32_t outWidth = getSizeOfDimension(outputShape, 2); 64 uint32_t numRois = getSizeOfDimension(roiShape, 0); 65 uint32_t roiInfoLength = getSizeOfDimension(roiShape, 1); 66 67 T_Input* outPtr = outputData; 68 const T_Roi* roiDataEnd = roiData + numRois * roiInfoLength; 69 uint32_t roiIndex = 0; 70 for (const T_Roi* roiInfo = roiData; roiInfo < roiDataEnd; roiInfo += kRoiDim, roiIndex++) { 71 uint32_t batchId = batchSplitData[roiIndex]; 72 // Check for malformed data 73 // 1. invalid batch id 74 // 2. Region out of bound: x1|x2|y1|y2 < 0 || x1|x2 > inWidth || y1|y2 > inHeight 75 // 3. Invalid region: x2 < x1 || y2 < y1 76 NN_RET_CHECK_GE(batchId, 0); 77 NN_RET_CHECK_LT(batchId, numBatches); 78 NN_RET_CHECK(roiInfo[0] >= 0); 79 NN_RET_CHECK(roiInfo[1] >= 0); 80 NN_RET_CHECK(roiInfo[2] >= 0); 81 NN_RET_CHECK(roiInfo[3] >= 0); 82 NN_RET_CHECK(roiInfo[0] * widthScale <= inWidth); 83 NN_RET_CHECK(roiInfo[1] * heightScale <= inHeight); 84 NN_RET_CHECK(roiInfo[2] * widthScale <= inWidth); 85 NN_RET_CHECK(roiInfo[3] * heightScale <= inHeight); 86 NN_RET_CHECK(roiInfo[0] <= roiInfo[2]); 87 NN_RET_CHECK(roiInfo[1] <= roiInfo[3]); 88 89 int32_t wRoiStart = std::round(static_cast<float>(roiInfo[0] * widthScale)); 90 int32_t hRoiStart = std::round(static_cast<float>(roiInfo[1] * heightScale)); 91 int32_t wRoiEnd = std::round(static_cast<float>(roiInfo[2] * widthScale)); 92 int32_t hRoiEnd = std::round(static_cast<float>(roiInfo[3] * heightScale)); 93 94 // Rois with width/height < 1 are considered malformed and are forced to be 1 95 T_Roi roiWidth = static_cast<T_Roi>(std::max(wRoiEnd - wRoiStart + 1, 1)); 96 T_Roi roiHeight = static_cast<T_Roi>(std::max(hRoiEnd - hRoiStart + 1, 1)); 97 T_Roi wStepSize = roiWidth / static_cast<T_Roi>(outWidth); 98 T_Roi hStepSize = roiHeight / static_cast<T_Roi>(outHeight); 99 100 const T_Input* batchBase = inputData + batchId * inHeight * inWidth * inDepth; 101 for (uint32_t i = 0; i < outHeight; i++) { 102 for (uint32_t j = 0; j < outWidth; j++) { 103 // Take floor on start, ceil on end, start included, end excluded, i.e. [start, end) 104 // end is guaranteed to larger than start by at least 1 105 uint32_t wStart = std::floor(static_cast<float>(wStepSize * j + wRoiStart)); 106 uint32_t wEnd = std::ceil(static_cast<float>(wStepSize * (j + 1) + wRoiStart)); 107 uint32_t hStart = std::floor(static_cast<float>(hStepSize * i + hRoiStart)); 108 uint32_t hEnd = std::ceil(static_cast<float>(hStepSize * (i + 1) + hRoiStart)); 109 110 wStart = std::min(wStart, inWidth); 111 wEnd = std::min(wEnd, inWidth); 112 hStart = std::min(hStart, inHeight); 113 hEnd = std::min(hEnd, inHeight); 114 115 for (uint32_t k = 0; k < inDepth; k++) { 116 T_Input maxValue = static_cast<T_Input>(inputShape.offset); 117 bool first = true; 118 for (uint32_t h = hStart; h < hEnd; h++) { 119 for (uint32_t w = wStart; w < wEnd; w++) { 120 T_Input inputValue = batchBase[h * inWidth * inDepth + w * inDepth + k]; 121 if (first || inputValue > maxValue) { 122 maxValue = inputValue; 123 first = false; 124 } 125 } 126 } 127 outPtr[k] = maxValue; 128 } 129 outPtr += inDepth; 130 } 131 } 132 } 133 return true; 134 } 135 136 template <typename T_Input, typename T_Roi> 137 inline bool roiPooling(const T_Input* inputData, const Shape& inputShape, const T_Roi* roiData, 138 const Shape& roiShape, const int32_t* batchSplitData, 139 const Shape& batchSplitShape, float heightStride, float widthStride, 140 bool useNchw, T_Input* outputData, const Shape& outputShape) { 141 InputWithLayout<T_Input> input(useNchw); 142 OutputWithLayout<T_Input> output(useNchw); 143 NN_RET_CHECK(input.initialize(inputData, inputShape)); 144 NN_RET_CHECK(output.initialize(outputData, outputShape)); 145 NN_RET_CHECK(roiPoolingNhwc(input.getNhwcBuffer(), input.getNhwcShape(), roiData, roiShape, 146 batchSplitData, batchSplitShape, heightStride, widthStride, 147 output.getNhwcBuffer(), output.getNhwcShape())); 148 NN_RET_CHECK(output.commit()); 149 return true; 150 } 151 152 template <> 153 inline bool roiPooling<uint8_t, uint16_t>(const uint8_t* inputData, const Shape& inputShape, 154 const uint16_t* roiData, const Shape& roiShape, 155 const int32_t* batchSplitData, 156 const Shape& batchSplitShape, float heightStride, 157 float widthStride, bool useNchw, uint8_t* outputData, 158 const Shape& outputShape) { 159 std::vector<float> roi_float32(getNumberOfElements(roiShape)); 160 convertQuantToFloat32(roiData, roiShape.scale, roiShape.offset, &roi_float32); 161 NN_RET_CHECK(roiPooling(inputData, inputShape, roi_float32.data(), roiShape, batchSplitData, 162 batchSplitShape, heightStride, widthStride, useNchw, outputData, 163 outputShape)); 164 return true; 165 } 166 167 } // namespace 168 169 bool validate(const IOperationValidationContext* context) { 170 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); 171 NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); 172 std::vector<OperandType> inExpectedTypes; 173 auto inputType = context->getInputType(kInputTensor); 174 if (inputType == OperandType::TENSOR_FLOAT32) { 175 inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, 176 OperandType::TENSOR_INT32, OperandType::INT32, 177 OperandType::INT32, OperandType::FLOAT32, 178 OperandType::FLOAT32, OperandType::BOOL}; 179 } else if (inputType == OperandType::TENSOR_FLOAT16) { 180 inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, 181 OperandType::TENSOR_INT32, OperandType::INT32, 182 OperandType::INT32, OperandType::FLOAT16, 183 OperandType::FLOAT16, OperandType::BOOL}; 184 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) { 185 inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, 186 OperandType::TENSOR_QUANT16_ASYMM, 187 OperandType::TENSOR_INT32, 188 OperandType::INT32, 189 OperandType::INT32, 190 OperandType::FLOAT32, 191 OperandType::FLOAT32, 192 OperandType::BOOL}; 193 } else { 194 LOG(ERROR) << "Unsupported input tensor type for operation " << kOperationName; 195 return false; 196 } 197 NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); 198 NN_RET_CHECK(validateOutputTypes(context, {inputType})); 199 return validateHalVersion(context, HalVersion::V1_2); 200 } 201 202 bool prepare(IOperationExecutionContext* context) { 203 bool useNchw = context->getInputValue<bool>(kLayoutScalar); 204 Shape input = context->getInputShape(kInputTensor); 205 Shape roiShape = context->getInputShape(kRoiTensor); 206 Shape batchSplitShape = context->getInputShape(kBatchSplitTensor); 207 NN_RET_CHECK_EQ(getNumberOfDimensions(input), 4); 208 NN_RET_CHECK_EQ(getNumberOfDimensions(roiShape), 2); 209 210 uint32_t numBatches = getSizeOfDimension(input, 0); 211 uint32_t inHeight = getSizeOfDimension(input, useNchw ? 2 : 1); 212 uint32_t inWidth = getSizeOfDimension(input, useNchw ? 3 : 2); 213 uint32_t inDepth = getSizeOfDimension(input, useNchw ? 1 : 3); 214 uint32_t numRois = getSizeOfDimension(roiShape, 0); 215 NN_RET_CHECK_EQ(getSizeOfDimension(roiShape, 1), 4); 216 NN_RET_CHECK_EQ(getSizeOfDimension(batchSplitShape, 0), numRois); 217 218 auto outputHeight = context->getInputValue<int32_t>(kOutputHeightScalar); 219 auto outputWidth = context->getInputValue<int32_t>(kOutputWidthScalar); 220 float heightStride, widthStride; 221 if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) { 222 heightStride = context->getInputValue<_Float16>(kHeightStrideSalar); 223 widthStride = context->getInputValue<_Float16>(kWidthStrideScalar); 224 } else { 225 heightStride = context->getInputValue<float>(kHeightStrideSalar); 226 widthStride = context->getInputValue<float>(kWidthStrideScalar); 227 } 228 NN_RET_CHECK_GT(outputHeight, 0); 229 NN_RET_CHECK_GT(outputWidth, 0); 230 NN_RET_CHECK_GT(heightStride, 0); 231 NN_RET_CHECK_GT(widthStride, 0); 232 233 if (roiShape.type == OperandType::TENSOR_QUANT16_ASYMM) { 234 NN_RET_CHECK_EQ(roiShape.scale, 0.125f); 235 NN_RET_CHECK_EQ(roiShape.offset, 0); 236 } 237 238 Shape output = input; 239 if (useNchw) { 240 output.dimensions = {numRois, inDepth, static_cast<uint32_t>(outputHeight), 241 static_cast<uint32_t>(outputWidth)}; 242 } else { 243 output.dimensions = {numRois, static_cast<uint32_t>(outputHeight), 244 static_cast<uint32_t>(outputWidth), inDepth}; 245 } 246 return context->setOutputShape(kOutputTensor, output); 247 } 248 249 bool execute(IOperationExecutionContext* context) { 250 switch (context->getInputType(kInputTensor)) { 251 case OperandType::TENSOR_FLOAT16: 252 return roiPooling(context->getInputBuffer<_Float16>(kInputTensor), 253 context->getInputShape(kInputTensor), 254 context->getInputBuffer<_Float16>(kRoiTensor), 255 context->getInputShape(kRoiTensor), 256 context->getInputBuffer<int32_t>(kBatchSplitTensor), 257 context->getInputShape(kBatchSplitTensor), 258 context->getInputValue<_Float16>(kHeightStrideSalar), 259 context->getInputValue<_Float16>(kWidthStrideScalar), 260 context->getInputValue<bool>(kLayoutScalar), 261 context->getOutputBuffer<_Float16>(kOutputTensor), 262 context->getOutputShape(kOutputTensor)); 263 case OperandType::TENSOR_FLOAT32: 264 return roiPooling(context->getInputBuffer<float>(kInputTensor), 265 context->getInputShape(kInputTensor), 266 context->getInputBuffer<float>(kRoiTensor), 267 context->getInputShape(kRoiTensor), 268 context->getInputBuffer<int32_t>(kBatchSplitTensor), 269 context->getInputShape(kBatchSplitTensor), 270 context->getInputValue<float>(kHeightStrideSalar), 271 context->getInputValue<float>(kWidthStrideScalar), 272 context->getInputValue<bool>(kLayoutScalar), 273 context->getOutputBuffer<float>(kOutputTensor), 274 context->getOutputShape(kOutputTensor)); 275 case OperandType::TENSOR_QUANT8_ASYMM: 276 return roiPooling(context->getInputBuffer<uint8_t>(kInputTensor), 277 context->getInputShape(kInputTensor), 278 context->getInputBuffer<uint16_t>(kRoiTensor), 279 context->getInputShape(kRoiTensor), 280 context->getInputBuffer<int32_t>(kBatchSplitTensor), 281 context->getInputShape(kBatchSplitTensor), 282 context->getInputValue<float>(kHeightStrideSalar), 283 context->getInputValue<float>(kWidthStrideScalar), 284 context->getInputValue<bool>(kLayoutScalar), 285 context->getOutputBuffer<uint8_t>(kOutputTensor), 286 context->getOutputShape(kOutputTensor)); 287 default: 288 NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName; 289 } 290 } 291 292 } // namespace roi_pooling 293 294 NN_REGISTER_OPERATION(ROI_POOLING, roi_pooling::kOperationName, roi_pooling::validate, 295 roi_pooling::prepare, roi_pooling::execute); 296 297 } // namespace nn 298 } // namespace android 299