/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <algorithm>
#include <memory>
#include <mutex>
#include <vector>

#include "CpuOperationUtils.h"
#include "OperationResolver.h"
#include "Operations.h"

#include "Utils.h"
#include "tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h"

#include "Tracing.h"

namespace android {
namespace nn {
namespace conv_2d {

constexpr char kOperationName[] = "CONV_2D";

constexpr uint32_t kInputTensor = 0;
constexpr uint32_t kFilterTensor = 1;
constexpr uint32_t kBiasTensor = 2;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;

namespace {

// If possible we will use this static buffer for the tensor.
constexpr size_t kStaticBufferSize = 1605632;
char static_scratch_buffer[kStaticBufferSize];

// executionMutex is used to protect concurrent access of the static_scratch_buffer
// and other non-threadsafe resources like gemmlowp::GemmContext.
// std::mutex is safe for pthreads on Android.
std::mutex executionMutex;

struct Conv2dParam {
    int32_t padding_left, padding_right;
    int32_t padding_top, padding_bottom;
    int32_t stride_width, stride_height;
    int32_t dilation_width_factor = 1, dilation_height_factor = 1;
    int32_t activation;
    bool useNchw = false;

    bool initialize(const IOperationExecutionContext* context) {
        uint32_t inCount = context->getNumInputs();
        int32_t padding_implicit = 0;
        bool useImplicitPadding = false;
        if ((inCount >= 8 && context->getInputType(7) == OperandType::BOOL) || inCount == 7) {
            padding_implicit = context->getInputValue<int32_t>(3);
            stride_width = context->getInputValue<int32_t>(4);
            stride_height = context->getInputValue<int32_t>(5);
            activation = context->getInputValue<int32_t>(6);
            if (inCount >= 8) {
                useNchw = context->getInputValue<bool>(7);
            }
            if (inCount == 10) {
                dilation_width_factor = context->getInputValue<int32_t>(8);
                dilation_height_factor = context->getInputValue<int32_t>(9);
            }
            useImplicitPadding = true;
        } else if (inCount >= 10 && context->getInputType(7) == OperandType::INT32) {
            padding_left = context->getInputValue<int32_t>(3);
            padding_right = context->getInputValue<int32_t>(4);
            padding_top = context->getInputValue<int32_t>(5);
            padding_bottom = context->getInputValue<int32_t>(6);
            stride_width = context->getInputValue<int32_t>(7);
            stride_height = context->getInputValue<int32_t>(8);
            activation = context->getInputValue<int32_t>(9);
            if (inCount >= 11) {
                useNchw = context->getInputValue<bool>(10);
            }
            if (inCount == 13) {
                dilation_width_factor = context->getInputValue<int32_t>(11);
                dilation_height_factor = context->getInputValue<int32_t>(12);
            }
        } else {
            NN_RET_CHECK_FAIL() << "Unsupported input spec for operation " << kOperationName;
        }
        if (useImplicitPadding) {
            Shape inputShape = context->getInputShape(kInputTensor);
            Shape filterShape = context->getInputShape(kFilterTensor);
            int32_t input_width = getSizeOfDimension(inputShape, useNchw ? 3 : 2);
            int32_t input_height = getSizeOfDimension(inputShape, useNchw ? 2 : 1);
            int32_t filter_width = getSizeOfDimension(filterShape, 2);
            int32_t filter_height = getSizeOfDimension(filterShape, 1);
            calculateExplicitPadding(input_width, stride_width, dilation_width_factor,
                                     filter_width, padding_implicit, &padding_left,
                                     &padding_right);
            calculateExplicitPadding(input_height, stride_height, dilation_height_factor,
                                     filter_height, padding_implicit, &padding_top,
                                     &padding_bottom);
        }
        NN_RET_CHECK_GE(padding_left, 0);
        NN_RET_CHECK_GE(padding_right, 0);
        NN_RET_CHECK_GE(padding_top, 0);
        NN_RET_CHECK_GE(padding_bottom, 0);
        NN_RET_CHECK_GT(stride_width, 0);
        NN_RET_CHECK_GT(stride_height, 0);
        NN_RET_CHECK_GT(dilation_width_factor, 0);
        NN_RET_CHECK_GT(dilation_height_factor, 0);
        NN_RET_CHECK_GE(activation, 0);
        return true;
    }
};
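
// Worked example of the implicit-padding path above (illustrative values only):
// assuming a 224x224 input, a 3x3 filter, stride 1, dilation 1 and SAME padding,
// calculateExplicitPadding computes roughly
//   effectiveFilterSize = (3 - 1) * 1 + 1 = 3
//   outSize             = ceil(224 / 1)   = 224
//   totalPadding        = max(0, (outSize - 1) * 1 + 3 - 224) = 2
// which splits into padding_left/right = padding_top/bottom = 1.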
#define ANDROID_NN_CONV_PARAMETERS(Type)                                           \
    uint32_t height = getSizeOfDimension(inputShape, 1);                           \
    uint32_t width = getSizeOfDimension(inputShape, 2);                            \
    uint32_t filterHeight = getSizeOfDimension(filterShape, 1);                    \
    uint32_t filterWidth = getSizeOfDimension(filterShape, 2);                     \
    uint32_t outHeight = getSizeOfDimension(outputShape, 1);                       \
    uint32_t outWidth = getSizeOfDimension(outputShape, 2);                        \
    uint32_t inDepth = getSizeOfDimension(inputShape, 3);                          \
                                                                                    \
    uint32_t paddingHeight = (uint32_t)padding_top;                                \
    uint32_t paddingWidth = (uint32_t)padding_left;                                \
                                                                                    \
    tflite::Dims<4> im2colDim;                                                      \
    im2colDim.sizes[3] = (int)getSizeOfDimension(outputShape, 0);                  \
    im2colDim.sizes[2] = (int)getSizeOfDimension(outputShape, 1);                  \
    im2colDim.sizes[1] = (int)getSizeOfDimension(outputShape, 2);                  \
    im2colDim.sizes[0] = (int)inDepth * filterHeight * filterWidth;                \
                                                                                    \
    im2colDim.strides[0] = 1;                                                       \
    for (int i = 1; i < 4; i++) {                                                   \
        im2colDim.strides[i] = im2colDim.strides[i - 1] * im2colDim.sizes[i - 1];   \
    }                                                                               \
                                                                                    \
    Type* im2colData = nullptr;                                                     \
    uint64_t im2colByteSize = sizeof(Type);                                         \
    std::unique_ptr<Type[]> im2colGuard;                                            \
    for (int i = 0; i < 4; i++) {                                                   \
        im2colByteSize *= im2colDim.sizes[i];                                       \
    }                                                                               \
    /* http://b/77982879, tflite::optimized_ops::Conv uses int for offsets */      \
    if (im2colByteSize >= 0x7fffffff) {                                             \
        LOG(ERROR) << "Conv size is too large, not enough memory";                 \
        return false;                                                               \
    }                                                                               \
    if (im2colByteSize <= kStaticBufferSize) {                                      \
        im2colData = reinterpret_cast<Type*>(static_scratch_buffer);               \
    } else {                                                                        \
        im2colData = new (std::nothrow) Type[im2colByteSize / sizeof(Type)];       \
        if (im2colData == nullptr) {                                                \
            LOG(ERROR) << "Conv size is too large, not enough memory";             \
            return false;                                                           \
        }                                                                           \
        im2colGuard.reset(im2colData);                                              \
    }
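
// Rough sizing example for the im2col scratch buffer above (illustrative
// numbers): a float convolution with inDepth = 32, a 3x3 filter and a
// 1x112x112 output needs 32 * 3 * 3 * 112 * 112 elements, i.e. about 14.5 MB,
// which exceeds kStaticBufferSize (1,605,632 bytes) and is therefore
// heap-allocated; smaller problems reuse static_scratch_buffer instead.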
bool convNhwc(const float* inputData, const Shape& inputShape, const float* filterData,
              const Shape& filterShape, const float* biasData, const Shape& biasShape,
              int32_t padding_left, int32_t padding_right, int32_t padding_top,
              int32_t padding_bottom, int32_t stride_width, int32_t stride_height,
              int32_t dilation_width_factor, int32_t dilation_height_factor, int32_t activation,
              float* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("convFloat32");

    ANDROID_NN_CONV_PARAMETERS(float)

    float output_activation_min, output_activation_max;
    CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);

    // Prevent concurrent executions that may access the scratch buffer.
    std::unique_lock<std::mutex> lock(executionMutex);
    NNTRACE_COMP_SWITCH("optimized_ops::Conv");
    tflite::optimized_ops::Conv(inputData, convertShapeToDims(inputShape), filterData,
                                convertShapeToDims(filterShape), biasData,
                                convertShapeToDims(biasShape), stride_width, stride_height,
                                dilation_width_factor, dilation_height_factor, paddingWidth,
                                paddingHeight, output_activation_min, output_activation_max,
                                outputData, convertShapeToDims(outputShape), im2colData, im2colDim);
    return true;
}

bool convNhwc(const uint8_t* inputData, const Shape& inputShape, const uint8_t* filterData,
              const Shape& filterShape, const int32_t* biasData, const Shape& biasShape,
              int32_t padding_left, int32_t padding_right, int32_t padding_top,
              int32_t padding_bottom, int32_t stride_width, int32_t stride_height,
              int32_t dilation_width_factor, int32_t dilation_height_factor, int32_t activation,
              uint8_t* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("convQuant8");

    ANDROID_NN_CONV_PARAMETERS(uint8_t)

    int32_t inputOffset = -inputShape.offset;
    int32_t filterOffset = -filterShape.offset;
    int32_t outputOffset = outputShape.offset;

    double real_multiplier = 0.0;
    int32_t output_multiplier = 0;
    int32_t output_shift = 0;
    int32_t output_activation_min = 0;
    int32_t output_activation_max = 0;

    NN_RET_CHECK(GetQuantizedConvolutionMultipler(inputShape, filterShape, biasShape, outputShape,
                                                  &real_multiplier));
    int exponent;
    NN_RET_CHECK(QuantizeMultiplier(real_multiplier, &output_multiplier, &exponent));
    output_shift = -exponent;
    CalculateActivationRangeUint8(activation, outputShape, &output_activation_min,
                                  &output_activation_max);
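    // Requantization sketch (illustrative values): real_multiplier is
    // inputScale * filterScale / outputScale, and QuantizeMultiplier splits it
    // into a Q31 fixed-point multiplier and a power-of-two exponent. For
    // example, scales of 0.5, 0.25 and 1.0 give real_multiplier = 0.125 =
    // 0.5 * 2^-2, i.e. output_multiplier ≈ 1 << 30 and output_shift = 2; the
    // kernel then computes roughly
    //   clamp(((acc * output_multiplier) >> (31 + output_shift)) + outputOffset,
    //         output_activation_min, output_activation_max).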
    static gemmlowp::GemmContext gemm_context;

    // Prevent concurrent executions that may access the scratch buffer and
    // gemm_context.
    std::unique_lock<std::mutex> lock(executionMutex);
    // Allow gemmlowp to automatically decide how many threads to use.
    gemm_context.set_max_num_threads(0);

    NNTRACE_COMP_SWITCH("optimized_ops::Conv");
    tflite::optimized_ops::Conv(
            inputData, convertShapeToDims(inputShape), inputOffset, filterData,
            convertShapeToDims(filterShape), filterOffset, biasData, convertShapeToDims(biasShape),
            stride_width, stride_height, dilation_width_factor, dilation_height_factor,
            paddingWidth, paddingHeight, outputOffset, output_multiplier, output_shift,
            output_activation_min, output_activation_max, outputData,
            convertShapeToDims(outputShape), im2colData, im2colDim, &gemm_context);
    return true;
}

bool convNhwc(const _Float16* inputData, const Shape& inputShape, const _Float16* filterData,
              const Shape& filterShape, const _Float16* biasData, const Shape& biasShape,
              int32_t padding_left, int32_t padding_right, int32_t padding_top,
              int32_t padding_bottom, int32_t stride_width, int32_t stride_height,
              int32_t dilation_width_factor, int32_t dilation_height_factor, int32_t activation,
              _Float16* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("convFloat16");

    std::vector<float> inputData_float32(getNumberOfElements(inputShape));
    std::vector<float> filterData_float32(getNumberOfElements(filterShape));
    std::vector<float> biasData_float32(getNumberOfElements(biasShape));
    std::vector<float> outputData_float32(getNumberOfElements(outputShape));

    convertFloat16ToFloat32(inputData, &inputData_float32);
    convertFloat16ToFloat32(filterData, &filterData_float32);
    convertFloat16ToFloat32(biasData, &biasData_float32);

    convNhwc(inputData_float32.data(), inputShape, filterData_float32.data(), filterShape,
             biasData_float32.data(), biasShape, padding_left, padding_right, padding_top,
             padding_bottom, stride_width, stride_height, dilation_width_factor,
             dilation_height_factor, activation, outputData_float32.data(), outputShape);
    convertFloat32ToFloat16(outputData_float32, outputData);

    return true;
}

template <typename T_Input, typename T_Filter, typename T_Bias>
bool conv(const T_Input* inputData, const Shape& inputShape, const T_Filter* filterData,
          const Shape& filterShape, const T_Bias* biasData, const Shape& biasShape,
          int32_t padding_left, int32_t padding_right, int32_t padding_top, int32_t padding_bottom,
          int32_t stride_width, int32_t stride_height, int32_t dilation_width_factor,
          int32_t dilation_height_factor, int32_t activation, bool useNchw, T_Input* outputData,
          const Shape& outputShape) {
    InputWithLayout<T_Input> input(useNchw);
    OutputWithLayout<T_Input> output(useNchw);
    NN_RET_CHECK(input.initialize(inputData, inputShape));
    NN_RET_CHECK(output.initialize(outputData, outputShape));
    NN_RET_CHECK(convNhwc(input.getNhwcBuffer(), input.getNhwcShape(), filterData, filterShape,
                          biasData, biasShape, padding_left, padding_right, padding_top,
                          padding_bottom, stride_width, stride_height, dilation_width_factor,
                          dilation_height_factor, activation, output.getNhwcBuffer(),
                          output.getNhwcShape()));
    NN_RET_CHECK(output.commit());
    return true;
}
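
// Note on data layout (descriptive only): the convNhwc kernels above assume
// NHWC buffers. When useNchw is set, InputWithLayout/OutputWithLayout (from
// CpuOperationUtils.h) are expected to repack the tensors to NHWC before the
// call, and commit() writes the result back in NCHW order, so conv() itself
// stays layout-agnostic.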
bool convQuant8PerChannelNhwc(const uint8_t* inputData, const Shape& inputShape,
                              const int8_t* filterData, const Shape& filterShape,
                              const float* filterScales, const int32_t* biasData,
                              const Shape& biasShape, int32_t paddingLeft, int32_t paddingRight,
                              int32_t paddingTop, int32_t paddingBottom, int32_t strideWidth,
                              int32_t strideHeight, int32_t dilationWidthFactor,
                              int32_t dilationHeightFactor, int32_t activation,
                              uint8_t* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("convQuant8PerChannel");

    uint32_t numBatches = getSizeOfDimension(inputShape, 0);
    uint32_t inputHeight = getSizeOfDimension(inputShape, 1);
    uint32_t inputWidth = getSizeOfDimension(inputShape, 2);
    uint32_t inputDepth = getSizeOfDimension(inputShape, 3);
    uint32_t filterHeight = getSizeOfDimension(filterShape, 1);
    uint32_t filterWidth = getSizeOfDimension(filterShape, 2);
    uint32_t filterDepth = getSizeOfDimension(filterShape, 3);
    uint32_t outputHeight = getSizeOfDimension(outputShape, 1);
    uint32_t outputWidth = getSizeOfDimension(outputShape, 2);
    uint32_t outputDepth = getSizeOfDimension(outputShape, 3);

    int32_t inputOffset = -inputShape.offset;
    int32_t outputOffset = outputShape.offset;

    auto realMultiplier = std::vector<double>(outputDepth, 0.0);
    auto outputMultiplier = std::vector<int32_t>(outputDepth, 0);
    auto outputShift = std::vector<int32_t>(outputDepth, 0);

    for (uint32_t i = 0; i < outputDepth; ++i) {
        Shape filterChannelShape = filterShape;
        filterChannelShape.scale = filterScales[i];
        Shape biasChannelShape = biasShape;
        biasChannelShape.scale = filterScales[i] * inputShape.scale;
        NN_RET_CHECK(GetQuantizedConvolutionMultipler(
                inputShape, filterChannelShape, biasChannelShape, outputShape, &realMultiplier[i]));
        int exponent;
        NN_RET_CHECK(QuantizeMultiplier(realMultiplier[i], &outputMultiplier[i], &exponent));
        outputShift[i] = -exponent;
    }

    int32_t output_activation_min = 0, output_activation_max = 0;
    CalculateActivationRangeUint8(activation, outputShape, &output_activation_min,
                                  &output_activation_max);
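    // Reference loop below; for each output element it computes, roughly,
    //   acc = sum over (i, j, k) of
    //         filter[d, i, j, k] * (input[b, hIn, wIn, k] + inputOffset)
    //   out[b, h, w, d] = clamp(requantize_d(acc + bias[d]) + outputOffset,
    //                           output_activation_min, output_activation_max)
    // where hIn/wIn are derived from the stride, dilation and padding, and
    // requantize_d applies the per-channel multiplier computed above.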
    const uint8_t* inputBase = inputData;
    uint8_t* outPtr = outputData;
    for (uint32_t b = 0; b < numBatches; b++) {
        for (uint32_t h = 0; h < outputHeight; h++) {
            for (uint32_t w = 0; w < outputWidth; w++) {
                const int8_t* filterBase = filterData;

                for (uint32_t d = 0; d < outputDepth; d++) {
                    int32_t wInputOrigin = static_cast<int32_t>(w) * strideWidth - paddingLeft;
                    int32_t hInputOrigin = static_cast<int32_t>(h) * strideHeight - paddingTop;
                    int32_t sum = 0;

                    for (uint32_t i = 0; i < filterHeight; i++) {
                        for (uint32_t j = 0; j < filterWidth; j++) {
                            for (uint32_t k = 0; k < filterDepth; k++) {
                                int32_t hInput = hInputOrigin +
                                                 dilationHeightFactor * static_cast<int32_t>(i);
                                int32_t wInput = wInputOrigin +
                                                 dilationWidthFactor * static_cast<int32_t>(j);
                                uint32_t dInput = k;
                                if (hInput >= 0 && hInput < static_cast<int32_t>(inputHeight) &&
                                    wInput >= 0 && wInput < static_cast<int32_t>(inputWidth)) {
                                    uint32_t filterIndex =
                                            i * filterWidth * filterDepth + j * filterDepth + k;
                                    uint32_t inputIndex = hInput * inputWidth * inputDepth +
                                                          wInput * inputDepth + dInput;
                                    sum += (static_cast<int32_t>(filterBase[filterIndex])) *
                                           (static_cast<int32_t>(inputBase[inputIndex]) +
                                            inputOffset);
                                }
                            }
                        }
                    }
                    sum += biasData[d];
                    sum = tflite::MultiplyByQuantizedMultiplier(sum, outputMultiplier[d],
                                                                -outputShift[d]);
                    sum += outputOffset;
                    sum = std::max(std::min(sum, output_activation_max), output_activation_min);
                    outPtr[d] = static_cast<uint8_t>(sum);
                    filterBase += filterHeight * filterWidth * filterDepth;
                }
                outPtr += outputDepth;
            }
        }
        inputBase += inputHeight * inputWidth * inputDepth;
    }

    return true;
}

bool convQuant8PerChannel(const uint8_t* inputData, const Shape& inputShape,
                          const int8_t* filterData, const Shape& filterShape,
                          const float* filterScales, const int32_t* biasData,
                          const Shape& biasShape, int32_t paddingLeft, int32_t paddingRight,
                          int32_t paddingTop, int32_t paddingBottom, int32_t strideWidth,
                          int32_t strideHeight, int32_t dilationWidthFactor,
                          int32_t dilationHeightFactor, int32_t activation, bool useNchw,
                          uint8_t* outputData, const Shape& outputShape) {
    InputWithLayout<uint8_t> input(useNchw);
    OutputWithLayout<uint8_t> output(useNchw);
    NN_RET_CHECK(input.initialize(inputData, inputShape));
    NN_RET_CHECK(output.initialize(outputData, outputShape));
    NN_RET_CHECK(convQuant8PerChannelNhwc(
            input.getNhwcBuffer(), input.getNhwcShape(), filterData, filterShape, filterScales,
            biasData, biasShape, paddingLeft, paddingRight, paddingTop, paddingBottom, strideWidth,
            strideHeight, dilationWidthFactor, dilationHeightFactor, activation,
            output.getNhwcBuffer(), output.getNhwcShape()));
    NN_RET_CHECK(output.commit());
    return true;
}

#undef ANDROID_NN_CONV_PARAMETERS

}  // namespace

bool validate(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    auto inputCount = context->getNumInputs();
    auto inputType = context->getInputType(kInputTensor);
    auto filterType = context->getInputType(kFilterTensor);
    std::vector<OperandType> inExpectedTypes;
    if (inputType == OperandType::TENSOR_FLOAT32) {
        inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::INT32,
                           OperandType::INT32,          OperandType::INT32,
                           OperandType::INT32};
    } else if (inputType == OperandType::TENSOR_FLOAT16) {
        inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::INT32,
                           OperandType::INT32,          OperandType::INT32,
                           OperandType::INT32};
    } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
        if (filterType == OperandType::TENSOR_QUANT8_ASYMM ||
            filterType == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
            inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM,
                               filterType,
                               OperandType::TENSOR_INT32,
                               OperandType::INT32,
                               OperandType::INT32,
                               OperandType::INT32,
                               OperandType::INT32};

            if (filterType == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
                NN_RET_CHECK_EQ(
                        context->getInputExtraParams(kFilterTensor).channelQuant().channelDim, 0)
                        << "Unsupported filter tensor channel dimension for operation "
                        << kOperationName;
            }
        } else {
            NN_RET_CHECK_FAIL() << "Unsupported filter tensor type for operation "
                                << kOperationName;
        }
    } else {
        NN_RET_CHECK_FAIL() << "Unsupported input tensor type for operation " << kOperationName;
    }

    // NeuralNetworks.h specifies that ANEURALNETWORKS_CONV_2D's output must
    // meet "outputScale > inputScale * filterScale" for the operand type
    // ANEURALNETWORKS_TENSOR_QUANT8_ASYMM before API level 29. For other
    // operand types (e.g., ANEURALNETWORKS_TENSOR_FLOAT32), this constraint
    // does not apply, so by default the constraint is met.
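    // For instance (illustrative values only), inputScale = 0.5 and
    // filterScale = 0.25 give inputScale * filterScale = 0.125: an outputScale
    // of 0.2 satisfies the pre-API-29 constraint, while an outputScale of 0.1
    // does not and therefore requires HAL version V1_2 below.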
    bool meetsQuantizedScaleConstraintBeforeV1_2 = true;
    if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
        const float inputScale = context->getInputShape(kInputTensor).scale;
        const float filterScale = context->getInputShape(kFilterTensor).scale;
        const float outputScale = context->getInputShape(kOutputTensor).scale;
        meetsQuantizedScaleConstraintBeforeV1_2 = (outputScale > inputScale * filterScale);
    }

    bool withExplicitPadding = false;
    bool withLayout = false;
    bool withDilation = false;
    if (inputCount >= 8) {
        if (context->getInputType(7) == OperandType::INT32 && inputCount >= 10) {
            std::vector<OperandType> explicitScalarTypes(3, OperandType::INT32);
            inExpectedTypes.insert(inExpectedTypes.end(), explicitScalarTypes.begin(),
                                   explicitScalarTypes.end());
            withExplicitPadding = true;
        }
        int inputOffset = withExplicitPadding ? 3 : 0;
        if (inputCount >= 8 + inputOffset) {
            inExpectedTypes.push_back(OperandType::BOOL);
            withLayout = true;
        }
        NN_RET_CHECK_NE(inputCount, 9 + inputOffset)
                << "Provided only one dilation factor value, two values are required for "
                << "operation " << kOperationName;
        if (inputCount == 10 + inputOffset) {
            inExpectedTypes.push_back(OperandType::INT32);
            inExpectedTypes.push_back(OperandType::INT32);
            withDilation = true;
        }
    }

    if (filterType == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL || withLayout || withDilation ||
        !meetsQuantizedScaleConstraintBeforeV1_2) {
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
    } else {
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0));
    }
    return validateInputTypes(context, inExpectedTypes) &&
           validateOutputTypes(context, {inputType});
}

bool prepare(IOperationExecutionContext* context) {
    Shape input = context->getInputShape(kInputTensor);
    Shape filter = context->getInputShape(kFilterTensor);
    Shape bias = context->getInputShape(kBiasTensor);

    if (filter.type == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
        NN_RET_CHECK(input.type == OperandType::TENSOR_QUANT8_ASYMM);
    } else {
        NN_RET_CHECK(input.type == filter.type);
    }
    if (input.type == OperandType::TENSOR_QUANT8_ASYMM) {
        NN_RET_CHECK(bias.type == OperandType::TENSOR_INT32);
    } else {
        NN_RET_CHECK(input.type == bias.type);
    }
    NN_RET_CHECK_EQ(getNumberOfDimensions(input), 4);
    NN_RET_CHECK_EQ(getNumberOfDimensions(filter), 4);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bias), 1);

    Conv2dParam param;
    NN_RET_CHECK(param.initialize(context));

    uint32_t batches = getSizeOfDimension(input, 0);
    uint32_t height = getSizeOfDimension(input, param.useNchw ? 2 : 1);
    uint32_t width = getSizeOfDimension(input, param.useNchw ? 3 : 2);
    uint32_t channels_in = getSizeOfDimension(input, param.useNchw ? 1 : 3);
    uint32_t channels_out = getSizeOfDimension(filter, 0);
    uint32_t filterHeight = getSizeOfDimension(filter, 1);
    uint32_t filterWidth = getSizeOfDimension(filter, 2);
    // Only batches can be zero.
    NN_RET_CHECK_EQ(channels_in, getSizeOfDimension(filter, 3));
    NN_RET_CHECK_EQ(channels_out, getSizeOfDimension(bias, 0));
    NN_RET_CHECK_GT(height, 0);
    NN_RET_CHECK_GT(width, 0);
    NN_RET_CHECK_GT(channels_in, 0);
    NN_RET_CHECK_GT(channels_out, 0);

    int32_t effectiveFilterWidth = (filterWidth - 1) * param.dilation_width_factor + 1;
    int32_t effectiveFilterHeight = (filterHeight - 1) * param.dilation_height_factor + 1;
    NN_RET_CHECK_GT(effectiveFilterWidth, param.padding_left);
    NN_RET_CHECK_GT(effectiveFilterWidth, param.padding_right);
    NN_RET_CHECK_GT(effectiveFilterHeight, param.padding_top);
    NN_RET_CHECK_GT(effectiveFilterHeight, param.padding_bottom);
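    // Output size sketch (illustrative numbers): computeOutSize follows the
    // usual convolution arithmetic, roughly
    //   outSize = (inSize - effectiveFilterSize + paddingHead + paddingTail) / stride + 1.
    // For width 224, a 3x3 filter, dilation 1, padding 1 + 1 and stride 2 this
    // gives (224 - 3 + 2) / 2 + 1 = 112.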
    uint32_t outWidth =
            computeOutSize(width, filterWidth, param.stride_width, param.dilation_width_factor,
                           param.padding_left, param.padding_right);
    uint32_t outHeight =
            computeOutSize(height, filterHeight, param.stride_height, param.dilation_height_factor,
                           param.padding_top, param.padding_bottom);

    Shape output = context->getOutputShape(kOutputTensor);
    output.type = input.type;
    if (param.useNchw) {
        output.dimensions = {batches, channels_out, outHeight, outWidth};
    } else {
        output.dimensions = {batches, outHeight, outWidth, channels_out};
    }
    return context->setOutputShape(kOutputTensor, output);
}

bool execute(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    Conv2dParam param;
    NN_RET_CHECK(param.initialize(context));
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT32:
            return conv(context->getInputBuffer<float>(kInputTensor),
                        context->getInputShape(kInputTensor),
                        context->getInputBuffer<float>(kFilterTensor),
                        context->getInputShape(kFilterTensor),
                        context->getInputBuffer<float>(kBiasTensor),
                        context->getInputShape(kBiasTensor), param.padding_left,
                        param.padding_right, param.padding_top, param.padding_bottom,
                        param.stride_width, param.stride_height, param.dilation_width_factor,
                        param.dilation_height_factor, param.activation, param.useNchw,
                        context->getOutputBuffer<float>(kOutputTensor),
                        context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT16:
            return conv(context->getInputBuffer<_Float16>(kInputTensor),
                        context->getInputShape(kInputTensor),
                        context->getInputBuffer<_Float16>(kFilterTensor),
                        context->getInputShape(kFilterTensor),
                        context->getInputBuffer<_Float16>(kBiasTensor),
                        context->getInputShape(kBiasTensor), param.padding_left,
                        param.padding_right, param.padding_top, param.padding_bottom,
                        param.stride_width, param.stride_height, param.dilation_width_factor,
                        param.dilation_height_factor, param.activation, param.useNchw,
                        context->getOutputBuffer<_Float16>(kOutputTensor),
                        context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            if (context->getInputType(kFilterTensor) ==
                OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
                return convQuant8PerChannel(
                        context->getInputBuffer<uint8_t>(kInputTensor),
                        context->getInputShape(kInputTensor),
                        context->getInputBuffer<int8_t>(kFilterTensor),
                        context->getInputShape(kFilterTensor),
                        context->getInputExtraParams(kFilterTensor).channelQuant().scales.data(),
                        context->getInputBuffer<int32_t>(kBiasTensor),
                        context->getInputShape(kBiasTensor), param.padding_left,
                        param.padding_right, param.padding_top, param.padding_bottom,
                        param.stride_width, param.stride_height, param.dilation_width_factor,
                        param.dilation_height_factor, param.activation, param.useNchw,
                        context->getOutputBuffer<uint8_t>(kOutputTensor),
                        context->getOutputShape(kOutputTensor));
            } else if (context->getInputType(kFilterTensor) == OperandType::TENSOR_QUANT8_ASYMM) {
                return conv(context->getInputBuffer<uint8_t>(kInputTensor),
                            context->getInputShape(kInputTensor),
                            context->getInputBuffer<uint8_t>(kFilterTensor),
                            context->getInputShape(kFilterTensor),
                            context->getInputBuffer<int32_t>(kBiasTensor),
                            context->getInputShape(kBiasTensor), param.padding_left,
                            param.padding_right, param.padding_top, param.padding_bottom,
                            param.stride_width, param.stride_height, param.dilation_width_factor,
                            param.dilation_height_factor, param.activation, param.useNchw,
                            context->getOutputBuffer<uint8_t>(kOutputTensor),
                            context->getOutputShape(kOutputTensor));
            } else {
                NN_RET_CHECK_FAIL() << "Unsupported filter type for operation " << kOperationName;
            }
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
    }
}

}  // namespace conv_2d

NN_REGISTER_OPERATION(CONV_2D, conv_2d::kOperationName, conv_2d::validate, conv_2d::prepare,
                      conv_2d::execute, .allowZeroSizedInput = true);

}  // namespace nn
}  // namespace android