1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "CpuOperationUtils.h" 18 #include "OperationResolver.h" 19 #include "OperationsUtils.h" 20 21 #include <cfloat> 22 #include <cmath> 23 24 #include "Tracing.h" 25 #include "tensorflow/lite/kernels/internal/common.h" 26 27 namespace android { 28 namespace nn { 29 namespace roi_align { 30 31 constexpr char kOperationName[] = "ROI_ALIGN"; 32 33 constexpr uint32_t kNumInputs = 10; 34 constexpr uint32_t kInputTensor = 0; 35 constexpr uint32_t kRoiTensor = 1; 36 constexpr uint32_t kBatchSplitTensor = 2; 37 constexpr uint32_t kOutputHeightScalar = 3; 38 constexpr uint32_t kOutputWidthScalar = 4; 39 constexpr uint32_t kHeightStrideSalar = 5; 40 constexpr uint32_t kWidthStrideScalar = 6; 41 constexpr uint32_t kHeightSamplingRatioScalar = 7; 42 constexpr uint32_t kWidthSamplingRatioScalar = 8; 43 constexpr uint32_t kLayoutScalar = 9; 44 45 constexpr uint32_t kNumOutputs = 1; 46 constexpr uint32_t kOutputTensor = 0; 47 48 namespace { 49 50 template <typename T_Input, typename T_Roi> 51 inline bool roiAlignNhwc(const T_Input* inputData, const Shape& inputShape, const T_Roi* roiData, 52 const Shape& roiShape, const int32_t* batchSplitData, 53 const Shape& batchSplitShape, float heightStride, float widthStride, 54 int32_t heightSamplingRatio, int32_t widthSamplingRatio, 55 T_Input* outputData, const Shape& outputShape) { 56 NNTRACE_TRANS("RoiAlign"); 57 58 const uint32_t kRoiDim = 4; 59 const T_Roi heightScale = 1.0f / heightStride; 60 const T_Roi widthScale = 1.0f / widthStride; 61 62 uint32_t numBatches = getSizeOfDimension(inputShape, 0); 63 uint32_t inHeight = getSizeOfDimension(inputShape, 1); 64 uint32_t inWidth = getSizeOfDimension(inputShape, 2); 65 uint32_t inDepth = getSizeOfDimension(inputShape, 3); 66 uint32_t outHeight = getSizeOfDimension(outputShape, 1); 67 uint32_t outWidth = getSizeOfDimension(outputShape, 2); 68 uint32_t numRois = getSizeOfDimension(roiShape, 0); 69 uint32_t roiInfoLength = getSizeOfDimension(roiShape, 1); 70 71 T_Input* outPtr = outputData; 72 const T_Roi* roiDataEnd = roiData + numRois * roiInfoLength; 73 uint32_t roiIndex = 0; 74 for (const T_Roi* roiInfo = roiData; roiInfo < roiDataEnd; roiInfo += kRoiDim, roiIndex++) { 75 uint32_t batchId = static_cast<uint32_t>(batchSplitData[roiIndex]); 76 // Check for malformed data 77 // 1. invalid batch id 78 // 2. Region out of bound: x1|x2|y1|y2 < 0 || x1|x2 > inWidth || y1|y2 > inHeight 79 // 3. Invalid region: x2 < x1 || y2 < y1 80 NN_RET_CHECK_GE(batchId, 0); 81 NN_RET_CHECK_LT(batchId, numBatches); 82 NN_RET_CHECK(roiInfo[0] >= 0); 83 NN_RET_CHECK(roiInfo[1] >= 0); 84 NN_RET_CHECK(roiInfo[2] >= 0); 85 NN_RET_CHECK(roiInfo[3] >= 0); 86 NN_RET_CHECK(roiInfo[0] * widthScale <= inWidth); 87 NN_RET_CHECK(roiInfo[1] * heightScale <= inHeight); 88 NN_RET_CHECK(roiInfo[2] * widthScale <= inWidth); 89 NN_RET_CHECK(roiInfo[3] * heightScale <= inHeight); 90 NN_RET_CHECK(roiInfo[0] <= roiInfo[2]); 91 NN_RET_CHECK(roiInfo[1] <= roiInfo[3]); 92 93 T_Roi wRoiStart = roiInfo[0] * widthScale; 94 T_Roi hRoiStart = roiInfo[1] * heightScale; 95 T_Roi wRoiEnd = roiInfo[2] * widthScale; 96 T_Roi hRoiEnd = roiInfo[3] * heightScale; 97 98 T_Roi roiWidth = std::max(static_cast<float>(wRoiEnd - wRoiStart), 1.0f); 99 T_Roi roiHeight = std::max(static_cast<float>(hRoiEnd - hRoiStart), 1.0f); 100 T_Roi wStepSize = roiWidth / static_cast<T_Roi>(outWidth); 101 T_Roi hStepSize = roiHeight / static_cast<T_Roi>(outHeight); 102 103 // if samplingRatio = 0, use adaptive value of ceil(roiWidth/outWidth), same for height 104 uint32_t wSamplingRatio = widthSamplingRatio > 0 ? widthSamplingRatio 105 : std::ceil(static_cast<float>(wStepSize)); 106 uint32_t hSamplingRatio = heightSamplingRatio > 0 107 ? heightSamplingRatio 108 : std::ceil(static_cast<float>(hStepSize)); 109 int32_t numSamplingPoints = wSamplingRatio * hSamplingRatio; 110 T_Roi wBinSize = wStepSize / static_cast<T_Roi>(wSamplingRatio); 111 T_Roi hBinSize = hStepSize / static_cast<T_Roi>(hSamplingRatio); 112 113 const T_Input* batchBase = inputData + batchId * inHeight * inWidth * inDepth; 114 for (uint32_t i = 0; i < outHeight; i++) { 115 for (uint32_t j = 0; j < outWidth; j++) { 116 T_Roi wStart = wStepSize * j + wRoiStart; 117 T_Roi wEnd = wStepSize * (j + 1) + wRoiStart; 118 T_Roi hStart = hStepSize * i + hRoiStart; 119 T_Roi hEnd = hStepSize * (i + 1) + hRoiStart; 120 121 // initialize output to zero 122 for (uint32_t k = 0; k < inDepth; k++) outPtr[k] = 0; 123 124 // calculate the sum of the sampling points 125 for (uint32_t yInd = 0; yInd < hSamplingRatio; yInd++) { 126 for (uint32_t xInd = 0; xInd < wSamplingRatio; xInd++) { 127 T_Roi y = hStart + hBinSize / 2 + hBinSize * yInd; 128 T_Roi x = wStart + wBinSize / 2 + wBinSize * xInd; 129 130 // bilinear interpolation of point (x,y) 131 // w.r.t box [(x1,y1), (x1,y2), (x2,y1), (x2,y2)] 132 uint32_t x1 = std::floor(static_cast<float>(x)); 133 uint32_t y1 = std::floor(static_cast<float>(y)); 134 uint32_t x2 = x1 + 1, y2 = y1 + 1; 135 T_Roi dx1 = x - static_cast<T_Roi>(x1); 136 T_Roi dy1 = y - static_cast<T_Roi>(y1); 137 138 // dealing with out of bound samples 139 if (x1 >= inWidth - 1) { 140 x1 = x2 = inWidth - 1; 141 dx1 = 0; 142 } 143 if (y1 >= inHeight - 1) { 144 y1 = y2 = inHeight - 1; 145 dy1 = 0; 146 } 147 148 T_Roi dx2 = 1.0f - dx1, dy2 = 1.0f - dy1; 149 T_Roi ws[] = {dx2 * dy2, dx1 * dy2, dx2 * dy1, dx1 * dy1}; 150 uint32_t offsets[] = {y1 * inWidth * inDepth + x1 * inDepth, 151 y1 * inWidth * inDepth + x2 * inDepth, 152 y2 * inWidth * inDepth + x1 * inDepth, 153 y2 * inWidth * inDepth + x2 * inDepth}; 154 155 for (uint32_t k = 0; k < inDepth; k++) { 156 T_Input interpolation = 0; 157 for (uint32_t c = 0; c < 4; c++) { 158 interpolation += ws[c] * batchBase[offsets[c] + k]; 159 } 160 outPtr[k] += interpolation; 161 } 162 } 163 } 164 165 // take average 166 for (uint32_t k = 0; k < inDepth; k++) 167 outPtr[k] /= static_cast<T_Input>(numSamplingPoints); 168 outPtr += inDepth; 169 } 170 } 171 } 172 return true; 173 } 174 175 template <> 176 inline bool roiAlignNhwc<uint8_t, uint16_t>(const uint8_t* inputData, const Shape& inputShape, 177 const uint16_t* roiData, const Shape& roiShape, 178 const int32_t* batchSplitData, 179 const Shape& batchSplitShape, float heightStride, 180 float widthStride, int32_t heightSamplingRatio, 181 int32_t widthSamplingRatio, uint8_t* outputData, 182 const Shape& outputShape) { 183 NNTRACE_TRANS("RoiAlignQuant8"); 184 185 constexpr float wScale = 1.0f / 255.0f; 186 constexpr uint32_t kRoiDim = 4; 187 const float heightScale = 1.0f / heightStride; 188 const float widthScale = 1.0f / widthStride; 189 190 uint32_t numBatches = getSizeOfDimension(inputShape, 0); 191 uint32_t inHeight = getSizeOfDimension(inputShape, 1); 192 uint32_t inWidth = getSizeOfDimension(inputShape, 2); 193 uint32_t inDepth = getSizeOfDimension(inputShape, 3); 194 uint32_t outHeight = getSizeOfDimension(outputShape, 1); 195 uint32_t outWidth = getSizeOfDimension(outputShape, 2); 196 uint32_t numRois = getSizeOfDimension(roiShape, 0); 197 uint32_t roiInfoLength = getSizeOfDimension(roiShape, 1); 198 199 uint8_t* outPtr = outputData; 200 const uint16_t* roiDataEnd = roiData + numRois * roiInfoLength; 201 uint32_t roiIndex = 0; 202 for (const uint16_t* roiInfo = roiData; roiInfo < roiDataEnd; roiInfo += kRoiDim, roiIndex++) { 203 uint32_t batchId = static_cast<uint32_t>(batchSplitData[roiIndex]); 204 float wRoiStart = static_cast<float>(roiInfo[0]) * widthScale * 0.125f; 205 float hRoiStart = static_cast<float>(roiInfo[1]) * heightScale * 0.125f; 206 float wRoiEnd = static_cast<float>(roiInfo[2]) * widthScale * 0.125f; 207 float hRoiEnd = static_cast<float>(roiInfo[3]) * heightScale * 0.125f; 208 209 // Check for malformed data 210 // 1. invalid batch id 211 // 2. Region out of bound: x1|x2|y1|y2 < 0 || x1|x2 > inWidth || y1|y2 > inHeight 212 // 3. Invalid region: x2 < x1 || y2 < y1 213 NN_RET_CHECK_GE(batchId, 0); 214 NN_RET_CHECK_LT(batchId, numBatches); 215 NN_RET_CHECK(wRoiStart <= inWidth); 216 NN_RET_CHECK(hRoiStart <= inHeight); 217 NN_RET_CHECK(wRoiEnd <= inWidth); 218 NN_RET_CHECK(hRoiEnd <= inHeight); 219 NN_RET_CHECK_LE(wRoiStart, wRoiEnd); 220 NN_RET_CHECK_LE(hRoiStart, hRoiEnd); 221 222 float roiWidth = std::max(wRoiEnd - wRoiStart, 1.0f); 223 float roiHeight = std::max(hRoiEnd - hRoiStart, 1.0f); 224 float wStepSize = roiWidth / static_cast<float>(outWidth); 225 float hStepSize = roiHeight / static_cast<float>(outHeight); 226 227 // if samplingRatio = 0, use adaptive value of ceil(roiWidth/outWidth), same for height 228 uint32_t wSamplingRatio = 229 widthSamplingRatio > 0 ? widthSamplingRatio : std::ceil(wStepSize); 230 uint32_t hSamplingRatio = 231 heightSamplingRatio > 0 ? heightSamplingRatio : std::ceil(hStepSize); 232 int32_t numSamplingPoints = wSamplingRatio * hSamplingRatio; 233 float wBinSize = wStepSize / static_cast<float>(wSamplingRatio); 234 float hBinSize = hStepSize / static_cast<float>(hSamplingRatio); 235 236 float realMultiplier = inputShape.scale * wScale / outputShape.scale / numSamplingPoints; 237 int32_t outputMultiplier = 0; 238 int32_t outputShift = 0; 239 if (!QuantizeMultiplierSmallerThanOne(realMultiplier, &outputMultiplier, &outputShift)) { 240 return false; 241 } 242 243 const uint8_t* batchBase = inputData + batchId * inHeight * inWidth * inDepth; 244 for (uint32_t i = 0; i < outHeight; i++) { 245 for (uint32_t j = 0; j < outWidth; j++) { 246 float wStart = wStepSize * j + wRoiStart; 247 float wEnd = wStepSize * (j + 1) + wRoiStart; 248 float hStart = hStepSize * i + hRoiStart; 249 float hEnd = hStepSize * (i + 1) + hRoiStart; 250 251 std::vector<int32_t> outTemp(inDepth, 0); 252 // calculate the sum of the sampling points 253 for (uint32_t yInd = 0; yInd < hSamplingRatio; yInd++) { 254 for (uint32_t xInd = 0; xInd < wSamplingRatio; xInd++) { 255 float y = hStart + hBinSize / 2 + hBinSize * yInd; 256 float x = wStart + wBinSize / 2 + wBinSize * xInd; 257 258 // bilinear interpolation of point (x,y) 259 // w.r.t box [(x1,y1), (x1,y2), (x2,y1), (x2,y2)] 260 uint32_t x1 = std::floor(x), y1 = std::floor(y); 261 uint32_t x2 = x1 + 1, y2 = y1 + 1; 262 float dx1 = x - static_cast<float>(x1); 263 float dy1 = y - static_cast<float>(y1); 264 265 // dealing with out of bound samples 266 if (x1 >= inWidth - 1) { 267 x1 = x2 = inWidth - 1; 268 dx1 = 0; 269 } 270 if (y1 >= inHeight - 1) { 271 y1 = y2 = inHeight - 1; 272 dy1 = 0; 273 } 274 275 float dx2 = 1.0f - dx1, dy2 = 1.0f - dy1; 276 float ws[] = {dx2 * dy2, dx1 * dy2, dx2 * dy1, dx1 * dy1}; 277 uint32_t offsets[] = {y1 * inWidth * inDepth + x1 * inDepth, 278 y1 * inWidth * inDepth + x2 * inDepth, 279 y2 * inWidth * inDepth + x1 * inDepth, 280 y2 * inWidth * inDepth + x2 * inDepth}; 281 282 for (uint32_t k = 0; k < inDepth; k++) { 283 int32_t interpolation = 0; 284 for (uint32_t c = 0; c < 4; c++) { 285 int32_t wQuant = static_cast<int32_t>(std::round(ws[c] / wScale)); 286 interpolation += 287 wQuant * (static_cast<int32_t>(batchBase[offsets[c] + k]) - 288 inputShape.offset); 289 } 290 outTemp[k] += interpolation; 291 } 292 } 293 } 294 295 // take average and cast to output quantization 296 for (uint32_t k = 0; k < inDepth; k++) { 297 int32_t raw_out = tflite::MultiplyByQuantizedMultiplier( 298 outTemp[k], outputMultiplier, -outputShift) + 299 outputShape.offset; 300 int32_t clamped_out = std::min(255, std::max(0, raw_out)); 301 outPtr[k] = static_cast<uint8_t>(clamped_out); 302 } 303 outPtr += inDepth; 304 } 305 } 306 } 307 return true; 308 } 309 310 template <typename T_Input, typename T_Roi> 311 inline bool roiAlign(const T_Input* inputData, const Shape& inputShape, const T_Roi* roiData, 312 const Shape& roiShape, const int32_t* batchSplitData, 313 const Shape& batchSplitShape, float heightStride, float widthStride, 314 int32_t heightSamplingRatio, int32_t widthSamplingRatio, bool useNchw, 315 T_Input* outputData, const Shape& outputShape) { 316 InputWithLayout<T_Input> input(useNchw); 317 OutputWithLayout<T_Input> output(useNchw); 318 NN_RET_CHECK(input.initialize(inputData, inputShape)); 319 NN_RET_CHECK(output.initialize(outputData, outputShape)); 320 NN_RET_CHECK(roiAlignNhwc(input.getNhwcBuffer(), input.getNhwcShape(), roiData, roiShape, 321 batchSplitData, batchSplitShape, heightStride, widthStride, 322 heightSamplingRatio, widthSamplingRatio, output.getNhwcBuffer(), 323 output.getNhwcShape())); 324 NN_RET_CHECK(output.commit()); 325 return true; 326 } 327 328 } // namespace 329 330 bool validate(const IOperationValidationContext* context) { 331 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); 332 NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); 333 std::vector<OperandType> inExpectedTypes; 334 auto inputType = context->getInputType(kInputTensor); 335 if (inputType == OperandType::TENSOR_FLOAT32) { 336 inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, 337 OperandType::TENSOR_INT32, OperandType::INT32, 338 OperandType::INT32, OperandType::FLOAT32, 339 OperandType::FLOAT32, OperandType::INT32, 340 OperandType::INT32, OperandType::BOOL}; 341 } else if (inputType == OperandType::TENSOR_FLOAT16) { 342 inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, 343 OperandType::TENSOR_INT32, OperandType::INT32, 344 OperandType::INT32, OperandType::FLOAT16, 345 OperandType::FLOAT16, OperandType::INT32, 346 OperandType::INT32, OperandType::BOOL}; 347 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) { 348 inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, 349 OperandType::TENSOR_QUANT16_ASYMM, 350 OperandType::TENSOR_INT32, 351 OperandType::INT32, 352 OperandType::INT32, 353 OperandType::FLOAT32, 354 OperandType::FLOAT32, 355 OperandType::INT32, 356 OperandType::INT32, 357 OperandType::BOOL}; 358 } else { 359 LOG(ERROR) << "Unsupported input tensor type for operation " << kOperationName; 360 return false; 361 } 362 NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); 363 NN_RET_CHECK(validateOutputTypes(context, {inputType})); 364 return validateHalVersion(context, HalVersion::V1_2); 365 } 366 367 bool prepare(IOperationExecutionContext* context) { 368 bool useNchw = context->getInputValue<bool>(kLayoutScalar); 369 Shape input = context->getInputShape(kInputTensor); 370 Shape roiShape = context->getInputShape(kRoiTensor); 371 Shape batchSplitShape = context->getInputShape(kBatchSplitTensor); 372 NN_RET_CHECK_EQ(getNumberOfDimensions(input), 4); 373 NN_RET_CHECK_EQ(getNumberOfDimensions(roiShape), 2); 374 375 uint32_t numBatches = getSizeOfDimension(input, 0); 376 uint32_t inHeight = getSizeOfDimension(input, useNchw ? 2 : 1); 377 uint32_t inWidth = getSizeOfDimension(input, useNchw ? 3 : 2); 378 uint32_t inDepth = getSizeOfDimension(input, useNchw ? 1 : 3); 379 uint32_t numRois = getSizeOfDimension(roiShape, 0); 380 // Every dimension must be positive except for numRois. 381 NN_RET_CHECK_GT(numBatches, 0); 382 NN_RET_CHECK_GT(inHeight, 0); 383 NN_RET_CHECK_GT(inWidth, 0); 384 NN_RET_CHECK_GT(inDepth, 0); 385 NN_RET_CHECK_EQ(getSizeOfDimension(roiShape, 1), 4); 386 NN_RET_CHECK_EQ(getSizeOfDimension(batchSplitShape, 0), numRois); 387 388 int32_t outputHeight = context->getInputValue<int32_t>(kOutputHeightScalar); 389 int32_t outputWidth = context->getInputValue<int32_t>(kOutputWidthScalar); 390 int32_t heightSamplingRatio = context->getInputValue<int32_t>(kHeightSamplingRatioScalar); 391 int32_t widthSamplingRatio = context->getInputValue<int32_t>(kWidthSamplingRatioScalar); 392 float heightScale, widthScale; 393 if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) { 394 heightScale = context->getInputValue<_Float16>(kHeightStrideSalar); 395 widthScale = context->getInputValue<_Float16>(kWidthStrideScalar); 396 } else { 397 heightScale = context->getInputValue<float>(kHeightStrideSalar); 398 widthScale = context->getInputValue<float>(kWidthStrideScalar); 399 } 400 NN_RET_CHECK_GT(outputHeight, 0); 401 NN_RET_CHECK_GT(outputWidth, 0); 402 NN_RET_CHECK_GT(heightScale, 0); 403 NN_RET_CHECK_GT(widthScale, 0); 404 // Sampling ratio can set to 0 for adaptive value. 405 NN_RET_CHECK_GE(heightSamplingRatio, 0); 406 NN_RET_CHECK_GE(widthSamplingRatio, 0); 407 408 if (roiShape.type == OperandType::TENSOR_QUANT16_ASYMM) { 409 NN_RET_CHECK_EQ(roiShape.scale, 0.125f); 410 NN_RET_CHECK_EQ(roiShape.offset, 0); 411 } 412 413 Shape output = context->getOutputShape(kOutputTensor); 414 output.type = input.type; 415 if (useNchw) { 416 output.dimensions = {numRois, inDepth, static_cast<uint32_t>(outputHeight), 417 static_cast<uint32_t>(outputWidth)}; 418 } else { 419 output.dimensions = {numRois, static_cast<uint32_t>(outputHeight), 420 static_cast<uint32_t>(outputWidth), inDepth}; 421 } 422 return context->setOutputShape(kOutputTensor, output); 423 } 424 425 bool execute(IOperationExecutionContext* context) { 426 // Bypass execution in the case of zero-sized input. 427 if (getNumberOfElements(context->getInputShape(kRoiTensor)) == 0) return true; 428 switch (context->getInputType(kInputTensor)) { 429 case OperandType::TENSOR_FLOAT16: 430 return roiAlign(context->getInputBuffer<_Float16>(kInputTensor), 431 context->getInputShape(kInputTensor), 432 context->getInputBuffer<_Float16>(kRoiTensor), 433 context->getInputShape(kRoiTensor), 434 context->getInputBuffer<int32_t>(kBatchSplitTensor), 435 context->getInputShape(kBatchSplitTensor), 436 context->getInputValue<_Float16>(kHeightStrideSalar), 437 context->getInputValue<_Float16>(kWidthStrideScalar), 438 context->getInputValue<int32_t>(kHeightSamplingRatioScalar), 439 context->getInputValue<int32_t>(kWidthSamplingRatioScalar), 440 context->getInputValue<bool>(kLayoutScalar), 441 context->getOutputBuffer<_Float16>(kOutputTensor), 442 context->getOutputShape(kOutputTensor)); 443 case OperandType::TENSOR_FLOAT32: 444 return roiAlign(context->getInputBuffer<float>(kInputTensor), 445 context->getInputShape(kInputTensor), 446 context->getInputBuffer<float>(kRoiTensor), 447 context->getInputShape(kRoiTensor), 448 context->getInputBuffer<int32_t>(kBatchSplitTensor), 449 context->getInputShape(kBatchSplitTensor), 450 context->getInputValue<float>(kHeightStrideSalar), 451 context->getInputValue<float>(kWidthStrideScalar), 452 context->getInputValue<int32_t>(kHeightSamplingRatioScalar), 453 context->getInputValue<int32_t>(kWidthSamplingRatioScalar), 454 context->getInputValue<bool>(kLayoutScalar), 455 context->getOutputBuffer<float>(kOutputTensor), 456 context->getOutputShape(kOutputTensor)); 457 case OperandType::TENSOR_QUANT8_ASYMM: 458 return roiAlign(context->getInputBuffer<uint8_t>(kInputTensor), 459 context->getInputShape(kInputTensor), 460 context->getInputBuffer<uint16_t>(kRoiTensor), 461 context->getInputShape(kRoiTensor), 462 context->getInputBuffer<int32_t>(kBatchSplitTensor), 463 context->getInputShape(kBatchSplitTensor), 464 context->getInputValue<float>(kHeightStrideSalar), 465 context->getInputValue<float>(kWidthStrideScalar), 466 context->getInputValue<int32_t>(kHeightSamplingRatioScalar), 467 context->getInputValue<int32_t>(kWidthSamplingRatioScalar), 468 context->getInputValue<bool>(kLayoutScalar), 469 context->getOutputBuffer<uint8_t>(kOutputTensor), 470 context->getOutputShape(kOutputTensor)); 471 default: 472 NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName; 473 } 474 } 475 476 } // namespace roi_align 477 478 NN_REGISTER_OPERATION(ROI_ALIGN, roi_align::kOperationName, roi_align::validate, roi_align::prepare, 479 roi_align::execute, .allowZeroSizedInput = true); 480 481 } // namespace nn 482 } // namespace android 483