/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Contains the implementation of the operations.

#define LOG_TAG "Operations"

#include "CpuOperationUtils.h"
#include "OperationResolver.h"

#include "tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h"

#include "Tracing.h"

#include <algorithm>
#include <functional>
#include <vector>

namespace android {
namespace nn {
namespace broadcast {

constexpr uint32_t kNumInputs = 3;
constexpr uint32_t kInputTensor1 = 0;
constexpr uint32_t kInputTensor2 = 1;
constexpr uint32_t kActivationScalar = 2;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;

namespace {

#define ANDROID_NN_MACRO_DISPATCH(macro)                                 \
    switch (activation) {                                                \
        case (int32_t)FusedActivationFunc::NONE:                         \
            macro(kNone);                                                \
            break;                                                       \
        case (int32_t)FusedActivationFunc::RELU:                         \
            macro(kRelu);                                                \
            break;                                                       \
        case (int32_t)FusedActivationFunc::RELU1:                        \
            macro(kRelu1);                                               \
            break;                                                       \
        case (int32_t)FusedActivationFunc::RELU6:                        \
            macro(kRelu6);                                               \
            break;                                                       \
        default:                                                         \
            LOG(ERROR) << "Unsupported fused activation function type";  \
            return false;                                                \
    }

using binaryFunctionFloat32 = std::function<bool(
        const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
        int32_t activation, float* out, const Shape& shapeOut)>;

bool binaryOperationFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2,
                            const Shape& shape2, int32_t activation, _Float16* out,
                            const Shape& shapeOut, binaryFunctionFloat32 operationFloat32) {
    std::vector<float> in1_float32(getNumberOfElements(shape1));
    convertFloat16ToFloat32(in1, &in1_float32);
    std::vector<float> in2_float32(getNumberOfElements(shape2));
    convertFloat16ToFloat32(in2, &in2_float32);
    std::vector<float> out_float32(getNumberOfElements(shapeOut));

    operationFloat32(in1_float32.data(), shape1, in2_float32.data(), shape2, activation,
                     out_float32.data(), shapeOut);
    convertFloat32ToFloat16(out_float32, out);

    return true;
}
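
// binaryOperationFloat16 lets each _Float16 entry point below reuse its float32
// kernel: both inputs are widened to float32, the float32 kernel runs, and the
// result is narrowed back to _Float16. A minimal usage sketch (the kernel name
// myOpFloat32 is hypothetical, for illustration only):
//
//   bool myOpFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2,
//                    const Shape& shape2, int32_t activation, _Float16* out,
//                    const Shape& shapeOut) {
//       return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut,
//                                     &myOpFloat32);
//   }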

bool addFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("addFloat32");
    bool needBroadcast = !SameShape(shape1, shape2);
    if (needBroadcast) {
        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastAdd");
#define ANDROID_NN_BROADCAST_ADD(activation)                                              \
    tflite::optimized_ops::BroadcastAdd<tflite::FusedActivationFunctionType::activation>( \
            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out,        \
            convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_ADD)
#undef ANDROID_NN_BROADCAST_ADD
    } else {
        NNTRACE_COMP_SWITCH("optimized_ops::Add");
#define ANDROID_NN_ADD(activation)                                                 \
    tflite::optimized_ops::Add<tflite::FusedActivationFunctionType::activation>(   \
            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out, \
            convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_ADD)
#undef ANDROID_NN_ADD
    }

    return true;
}

bool addFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("addFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &addFloat32);
}

bool addQuant8(const uint8_t* in1, const Shape& shape1, const uint8_t* in2, const Shape& shape2,
               int32_t activation, uint8_t* out, const Shape& shapeOut) {
    NNTRACE_TRANS("addQuant8");
    bool needBroadcast = !SameShape(shape1, shape2);

    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
    const int left_shift = 20;
    const double twice_max_input_scale = 2 * std::max(shape1.scale, shape2.scale);
    const double real_input1_multiplier = shape1.scale / twice_max_input_scale;
    const double real_input2_multiplier = shape2.scale / twice_max_input_scale;
    const double real_output_multiplier =
            twice_max_input_scale / ((1 << left_shift) * shapeOut.scale);

    int32_t input1_multiplier;
    int32_t input1_shift;
    if (!QuantizeMultiplierSmallerThanOne(real_input1_multiplier, &input1_multiplier,
                                          &input1_shift)) {
        return false;
    }
    int32_t input2_multiplier;
    int32_t input2_shift;
    if (!QuantizeMultiplierSmallerThanOne(real_input2_multiplier, &input2_multiplier,
                                          &input2_shift)) {
        return false;
    }
    int32_t output_multiplier;
    int32_t output_shift;
    if (!QuantizeMultiplierSmallerThanOne(real_output_multiplier, &output_multiplier,
                                          &output_shift)) {
        return false;
    }
    int32_t output_activation_min;
    int32_t output_activation_max;
    CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
                                  &output_activation_max);

    if (needBroadcast) {
        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastAdd");
#define ANDROID_NN_BROADCAST_ADD(activation)                                                  \
    tflite::optimized_ops::BroadcastAdd<tflite::FusedActivationFunctionType::activation>(    \
            left_shift, in1, convertShapeToDims(shape1), input1_offset, input1_multiplier,   \
            input1_shift, in2, convertShapeToDims(shape2), input2_offset, input2_multiplier, \
            input2_shift, output_offset, output_multiplier, output_shift,                    \
            output_activation_min, output_activation_max, out, convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_ADD)
#undef ANDROID_NN_BROADCAST_ADD
    } else {
        NNTRACE_COMP_SWITCH("optimized_ops::Add");
#define ANDROID_NN_NORMAL_ADD(activation)                                                    \
    tflite::optimized_ops::Add<tflite::FusedActivationFunctionType::activation>(             \
            left_shift, in1, convertShapeToDims(shape1), input1_offset, input1_multiplier,   \
            input1_shift, in2, convertShapeToDims(shape2), input2_offset, input2_multiplier, \
            input2_shift, output_offset, output_multiplier, output_shift,                    \
            output_activation_min, output_activation_max, out, convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_NORMAL_ADD)
#undef ANDROID_NN_NORMAL_ADD
    }

    return true;
}
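
// Worked example of the requantization in addQuant8 above, with illustrative
// scales shape1.scale = 0.5, shape2.scale = 0.25, shapeOut.scale = 1.0:
//   twice_max_input_scale  = 2 * max(0.5, 0.25)      = 1.0
//   real_input1_multiplier = 0.5 / 1.0               = 0.5
//   real_input2_multiplier = 0.25 / 1.0              = 0.25
//   real_output_multiplier = 1.0 / ((1 << 20) * 1.0) = 2^-20
// Each input is rescaled into a common fixed-point format with
// left_shift = 20 bits of headroom before the addition; the 2^-20 factor in
// the output multiplier removes that headroom when the sum is requantized.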

bool mulFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("mulFloat32");
    bool needBroadcast = !SameShape(shape1, shape2);

    if (needBroadcast) {
        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastMul");
#define ANDROID_NN_BROADCAST_MUL(activation)                                              \
    tflite::optimized_ops::BroadcastMul<tflite::FusedActivationFunctionType::activation>( \
            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out,        \
            convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_MUL)
#undef ANDROID_NN_BROADCAST_MUL
    } else {
        float output_activation_min, output_activation_max;
        CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);

        NNTRACE_COMP_SWITCH("optimized_ops::Mul");
        tflite::optimized_ops::Mul(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                                   output_activation_min, output_activation_max, out,
                                   convertShapeToDims(shapeOut));
    }

    return true;
}

bool mulFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("mulFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &mulFloat32);
}

bool mulQuant8(const uint8_t* in1, const Shape& shape1, const uint8_t* in2, const Shape& shape2,
               int32_t activation, uint8_t* out, const Shape& shapeOut) {
    NNTRACE_TRANS("mulQuant8");
    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
    const double input_product_scale = shape1.scale * shape2.scale;
    const double real_multiplier = input_product_scale / shapeOut.scale;
    int32_t output_multiplier;
    int32_t output_shift;
    if (!QuantizeMultiplierSmallerThanOne(real_multiplier, &output_multiplier, &output_shift)) {
        return false;
    }
    int32_t output_activation_min;
    int32_t output_activation_max;
    CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
                                  &output_activation_max);

    // Use BROADCAST version to handle the normal case.
    NNTRACE_COMP_SWITCH("optimized_ops::BroadcastMul");
    tflite::optimized_ops::BroadcastMul(in1, convertShapeToDims(shape1), input1_offset, in2,
                                        convertShapeToDims(shape2), input2_offset, output_offset,
                                        output_multiplier, output_shift, output_activation_min,
                                        output_activation_max, out, convertShapeToDims(shapeOut));

    return true;
}
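
// In mulQuant8 the product of the two quantized inputs carries the scale
// shape1.scale * shape2.scale, so a single requantization by real_multiplier
// suffices. validate() below requires shapeOut.scale to be strictly greater
// than that product, which guarantees real_multiplier < 1 and keeps
// QuantizeMultiplierSmallerThanOne from failing. For example, with input
// scales 0.5 and 0.25 and an output scale of 0.25,
// real_multiplier = (0.5 * 0.25) / 0.25 = 0.5.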

bool subFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("subFloat32");
    NNTRACE_COMP_SWITCH("optimized_ops::Sub");
    tflite::optimized_ops::Sub(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                               out, convertShapeToDims(shapeOut));

    // TFLite does not apply activation to broadcast sub, so clamp the output here.
    float output_activation_min, output_activation_max;
    CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);
    uint32_t numOutputElements = getNumberOfElements(shapeOut);
    for (uint32_t i = 0; i < numOutputElements; i++) {
        out[i] = std::min(std::max(out[i], output_activation_min), output_activation_max);
    }
    return true;
}

bool subFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("subFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &subFloat32);
}
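
// subQuant8 below reuses the quantized ADD machinery: the inputs are rescaled
// exactly as in addQuant8, but input2_multiplier is negated so that the
// BroadcastAdd kernel effectively computes in1 + (-in2) = in1 - in2.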

bool subQuant8(const uint8_t* in1, const Shape& shape1, const uint8_t* in2, const Shape& shape2,
               int32_t activation, uint8_t* out, const Shape& shapeOut) {
    NNTRACE_TRANS("subQuant8");

    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
    const int left_shift = 20;
    const double twice_max_input_scale = 2 * std::max(shape1.scale, shape2.scale);
    const double real_input1_multiplier = shape1.scale / twice_max_input_scale;
    const double real_input2_multiplier = shape2.scale / twice_max_input_scale;
    const double real_output_multiplier =
            twice_max_input_scale / ((1 << left_shift) * shapeOut.scale);

    int32_t input1_multiplier;
    int32_t input1_shift;
    if (!QuantizeMultiplierSmallerThanOne(real_input1_multiplier, &input1_multiplier,
                                          &input1_shift)) {
        return false;
    }
    int32_t input2_multiplier;
    int32_t input2_shift;
    if (!QuantizeMultiplierSmallerThanOne(real_input2_multiplier, &input2_multiplier,
                                          &input2_shift)) {
        return false;
    }
    input2_multiplier *= -1;
    int32_t output_multiplier;
    int32_t output_shift;
    if (!QuantizeMultiplierSmallerThanOne(real_output_multiplier, &output_multiplier,
                                          &output_shift)) {
        return false;
    }
    int32_t output_activation_min;
    int32_t output_activation_max;
    CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
                                  &output_activation_max);

    // We are using tflite::optimized_ops::BroadcastAdd unconditionally here
    // because tflite::optimized_ops::Add fails to pass some of the
    // sub_quantized_different_scales tests.
    NNTRACE_COMP_SWITCH("optimized_ops::BroadcastAdd");
#define ANDROID_NN_BROADCAST_ADD(activation)                                                  \
    tflite::optimized_ops::BroadcastAdd<tflite::FusedActivationFunctionType::activation>(    \
            left_shift, in1, convertShapeToDims(shape1), input1_offset, input1_multiplier,   \
            input1_shift, in2, convertShapeToDims(shape2), input2_offset, input2_multiplier, \
            input2_shift, output_offset, output_multiplier, output_shift,                    \
            output_activation_min, output_activation_max, out, convertShapeToDims(shapeOut))

    ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_ADD)
#undef ANDROID_NN_BROADCAST_ADD

    return true;
}

bool divFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("divFloat32");
    float output_activation_min, output_activation_max;
    CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);

    bool needBroadcast = !SameShape(shape1, shape2);
    if (needBroadcast) {
        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastDiv");
        tflite::optimized_ops::BroadcastDiv(
                in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                output_activation_min, output_activation_max, out, convertShapeToDims(shapeOut));
    } else {
        NNTRACE_COMP_SWITCH("optimized_ops::Div");
        tflite::optimized_ops::Div(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                                   output_activation_min, output_activation_max, out,
                                   convertShapeToDims(shapeOut));
    }
    return true;
}

bool divFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("divFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &divFloat32);
}

} // namespace
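
// Supported tensor types per operation, as enforced by validate() below:
//   ADD, MUL: TENSOR_FLOAT32, TENSOR_FLOAT16 (since V1_2), TENSOR_QUANT8_ASYMM
//   SUB:      TENSOR_FLOAT32, TENSOR_FLOAT16 (since V1_2), TENSOR_QUANT8_ASYMM (since V1_2)
//   DIV:      TENSOR_FLOAT32, TENSOR_FLOAT16 (since V1_2)
// ADD and MUL were introduced in HAL version V1_0, SUB and DIV in V1_1.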

bool validate(OperationType opType, const IOperationValidationContext* context) {
    const HalVersion opIntroducedAt = (opType == OperationType::DIV || opType == OperationType::SUB)
                                              ? HalVersion::V1_1
                                              : HalVersion::V1_0;
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    auto inputType = context->getInputType(kInputTensor1);
    if (inputType == OperandType::TENSOR_FLOAT32) {
        NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_0, opIntroducedAt)));
    } else if (inputType == OperandType::TENSOR_FLOAT16) {
        NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_2, opIntroducedAt)));
    } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
        if (opType == OperationType::SUB) {
            NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_2, opIntroducedAt)));
        } else if (opType == OperationType::DIV) {
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation DIV";
        } else if (opType == OperationType::MUL) {
            Shape output = context->getOutputShape(kOutputTensor);
            Shape input1 = context->getInputShape(kInputTensor1);
            Shape input2 = context->getInputShape(kInputTensor2);
            NN_RET_CHECK_GT(output.scale, input1.scale * input2.scale);
            NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_0, opIntroducedAt)));
        } else {
            NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_0, opIntroducedAt)));
        }
    } else {
        NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << getOperationName(opType);
    }
    return validateInputTypes(context, {inputType, inputType, OperandType::INT32}) &&
           validateOutputTypes(context, {inputType});
}

bool prepare(IOperationExecutionContext* context) {
    Shape input1 = context->getInputShape(kInputTensor1);
    Shape input2 = context->getInputShape(kInputTensor2);
    Shape output = context->getOutputShape(kOutputTensor);
    NN_RET_CHECK_LE(getNumberOfDimensions(input1), 4);
    NN_RET_CHECK_LE(getNumberOfDimensions(input2), 4);
    NN_RET_CHECK(calculateBroadcastedShape(input1, input2, &output));
    return context->setOutputShape(kOutputTensor, output);
}
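
// prepare() derives the output shape via calculateBroadcastedShape, which
// follows numpy-style broadcasting: dimensions are matched from the right, and
// a dimension of size 1 is stretched to the other operand's size, e.g. shapes
// [2, 3, 4] and [1, 4] broadcast to [2, 3, 4]. Both inputs are limited to at
// most 4 dimensions.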

bool executeAdd(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return addFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return addFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return addQuant8(context->getInputBuffer<uint8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<uint8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<uint8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation ADD";
    }
}

bool executeMul(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return mulFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return mulFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return mulQuant8(context->getInputBuffer<uint8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<uint8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<uint8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation MUL";
    }
}

bool executeSub(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return subFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return subFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return subQuant8(context->getInputBuffer<uint8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<uint8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<uint8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation SUB";
    }
}

bool executeDiv(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return divFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return divFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation DIV";
    }
}

} // namespace broadcast

using std::placeholders::_1;
NN_REGISTER_OPERATION(ADD, "ADD", std::bind(broadcast::validate, OperationType::ADD, _1),
                      broadcast::prepare, broadcast::executeAdd, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(MUL, "MUL", std::bind(broadcast::validate, OperationType::MUL, _1),
                      broadcast::prepare, broadcast::executeMul, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(SUB, "SUB", std::bind(broadcast::validate, OperationType::SUB, _1),
                      broadcast::prepare, broadcast::executeSub, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(DIV, "DIV", std::bind(broadcast::validate, OperationType::DIV, _1),
                      broadcast::prepare, broadcast::executeDiv, .allowZeroSizedInput = true);

} // namespace nn
} // namespace android