// NOTE(review): the line here was code-browser navigation chrome ("Home | History |
// Annotate | Download"), not part of the source; kept only as a comment.
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 // Contains the implementation of the operations.
     18 
     19 #define LOG_TAG "Operations"
     20 
     21 #include "CpuOperationUtils.h"
     22 #include "OperationResolver.h"
     23 
     24 #include "tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h"
     25 #include "tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h"
     26 
     27 #include "Tracing.h"
     28 
     29 #include <algorithm>
     30 
     31 namespace android {
     32 namespace nn {
     33 namespace broadcast {
     34 
// Operand indices shared by the binary broadcast operations (ADD/MUL/SUB/DIV).
// Inputs: the two tensors to combine plus a fused-activation scalar.
constexpr uint32_t kNumInputs = 3;
constexpr uint32_t kInputTensor1 = 0;
constexpr uint32_t kInputTensor2 = 1;
constexpr uint32_t kActivationScalar = 2;  // INT32 scalar carrying a FusedActivationFunc value.

// Outputs: a single tensor with the broadcasted result shape.
constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;
     42 
     43 namespace {
     44 
// Expands `macro` with the compile-time TFLite activation tag (kNone/kRelu/...)
// that matches the runtime `activation` value. On an unrecognized value it logs
// an error and makes the enclosing function return false. An int32_t variable
// named `activation` must be in scope at the expansion site.
#define ANDROID_NN_MACRO_DISPATCH(macro)                                \
    switch (activation) {                                               \
        case (int32_t)FusedActivationFunc::NONE:                        \
            macro(kNone);                                               \
            break;                                                      \
        case (int32_t)FusedActivationFunc::RELU:                        \
            macro(kRelu);                                               \
            break;                                                      \
        case (int32_t)FusedActivationFunc::RELU1:                       \
            macro(kRelu1);                                              \
            break;                                                      \
        case (int32_t)FusedActivationFunc::RELU6:                       \
            macro(kRelu6);                                               \
            break;                                                      \
        default:                                                        \
            LOG(ERROR) << "Unsupported fused activation function type"; \
            return false;                                               \
    }

// Common signature of the float32 kernels (add/mul/sub/divFloat32); used by
// binaryOperationFloat16 to run a float16 op through its float32 counterpart.
using binaryFunctionFloat32 = std::function<bool(
        const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
        int32_t activation, float* out, const Shape& shapeOut)>;
     67 
     68 bool binaryOperationFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2,
     69                             const Shape& shape2, int32_t activation, _Float16* out,
     70                             const Shape& shapeOut, binaryFunctionFloat32 operationFloat32) {
     71     std::vector<float> in1_float32(getNumberOfElements(shape1));
     72     convertFloat16ToFloat32(in1, &in1_float32);
     73     std::vector<float> in2_float32(getNumberOfElements(shape2));
     74     convertFloat16ToFloat32(in2, &in2_float32);
     75     std::vector<float> out_float32(getNumberOfElements(shapeOut));
     76 
     77     operationFloat32(in1_float32.data(), shape1, in2_float32.data(), shape2, activation,
     78                      out_float32.data(), shapeOut);
     79     convertFloat32ToFloat16(out_float32, out);
     80 
     81     return true;
     82 }
     83 
// ADD on fp32 tensors using the TFLite optimized kernels. The fused activation
// is selected at compile time via ANDROID_NN_MACRO_DISPATCH, so an unsupported
// activation value makes this function return false.
bool addFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("addFloat32");
    // Differing shapes require the (slower) broadcasting kernel.
    bool needBroadcast = !SameShape(shape1, shape2);
    if (needBroadcast) {
        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastAdd");
#define ANDROID_NN_BROADCAST_ADD(activation)                                              \
    tflite::optimized_ops::BroadcastAdd<tflite::FusedActivationFunctionType::activation>( \
            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out,        \
            convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_ADD)
#undef ANDROID_NN_BROADCAST_ADD
    } else {
        NNTRACE_COMP_SWITCH("optimized_ops::Add");
#define ANDROID_NN_ADD(activation)                                                 \
    tflite::optimized_ops::Add<tflite::FusedActivationFunctionType::activation>(   \
            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out, \
            convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_ADD)
#undef ANDROID_NN_ADD
    }

    return true;
}
    110 
    111 bool addFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
    112                 int32_t activation, _Float16* out, const Shape& shapeOut) {
    113     NNTRACE_TRANS("addFloat16");
    114     return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &addFloat32);
    115 }
    116 
// ADD on asymmetric quant8 tensors. Both inputs are rescaled into a shared
// fixed-point domain (twice the larger input scale, with `left_shift` bits of
// headroom), summed, then requantized to the output scale/offset. Returns false
// if any rescale multiplier cannot be quantized or the activation is invalid.
bool addQuant8(const uint8_t* in1, const Shape& shape1, const uint8_t* in2, const Shape& shape2,
               int32_t activation, uint8_t* out, const Shape& shapeOut) {
    NNTRACE_TRANS("addQuant8");
    bool needBroadcast = !SameShape(shape1, shape2);

    // Input offsets are negated so the kernel can add them to the raw quantized
    // values; the output offset is applied as-is after requantization.
    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
    const int left_shift = 20;
    const double twice_max_input_scale = 2 * std::max(shape1.scale, shape2.scale);
    const double real_input1_multiplier = shape1.scale / twice_max_input_scale;
    const double real_input2_multiplier = shape2.scale / twice_max_input_scale;
    const double real_output_multiplier =
            twice_max_input_scale / ((1 << left_shift) * shapeOut.scale);

    // Convert each real multiplier (< 1) into a 32-bit fixed-point multiplier
    // plus a right shift, as expected by the TFLite quantized kernels.
    int32_t input1_multiplier;
    int32_t input1_shift;
    if (!QuantizeMultiplierSmallerThanOne(real_input1_multiplier, &input1_multiplier,
                                          &input1_shift)) {
        return false;
    }
    int32_t input2_multiplier;
    int32_t input2_shift;
    if (!QuantizeMultiplierSmallerThanOne(real_input2_multiplier, &input2_multiplier,
                                          &input2_shift)) {
        return false;
    }
    int32_t output_multiplier;
    int32_t output_shift;
    if (!QuantizeMultiplierSmallerThanOne(real_output_multiplier, &output_multiplier,
                                          &output_shift)) {
        return false;
    }
    // Clamp bounds in the quantized domain implementing the fused activation.
    int32_t output_activation_min;
    int32_t output_activation_max;
    CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
                                  &output_activation_max);

    if (needBroadcast) {
        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastAdd");
#define ANDROID_NN_BROADCAST_ADD(activation)                                                     \
    tflite::optimized_ops::BroadcastAdd<tflite::FusedActivationFunctionType::activation>(        \
            left_shift, in1, convertShapeToDims(shape1), input1_offset, input1_multiplier,       \
            input1_shift, in2, convertShapeToDims(shape2), input2_offset, input2_multiplier,     \
            input2_shift, output_offset, output_multiplier, output_shift, output_activation_min, \
            output_activation_max, out, convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_ADD)
#undef ANDROID_NN_BROADCAST_ADD
    } else {
        NNTRACE_COMP_SWITCH("optimized_ops::Add");
#define ANDROID_NN_NORMAL_ADD(activation)                                                        \
    tflite::optimized_ops::Add<tflite::FusedActivationFunctionType::activation>(                 \
            left_shift, in1, convertShapeToDims(shape1), input1_offset, input1_multiplier,       \
            input1_shift, in2, convertShapeToDims(shape2), input2_offset, input2_multiplier,     \
            input2_shift, output_offset, output_multiplier, output_shift, output_activation_min, \
            output_activation_max, out, convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_NORMAL_ADD)
#undef ANDROID_NN_NORMAL_ADD
    }

    return true;
}
    181 
// MUL on fp32 tensors. The broadcast path selects the fused activation at
// compile time via ANDROID_NN_MACRO_DISPATCH; the same-shape path instead
// clamps with a runtime activation range.
bool mulFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("mulFloat32");
    bool needBroadcast = !SameShape(shape1, shape2);

    if (needBroadcast) {
        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastMul");
#define ANDROID_NN_BROADCAST_MUL(activation)                                              \
    tflite::optimized_ops::BroadcastMul<tflite::FusedActivationFunctionType::activation>( \
            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out,        \
            convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_MUL)
#undef ANDROID_NN_BROADCAST_MUL
    } else {
        // Runtime activation range for the non-broadcast kernel.
        float output_activation_min, output_activation_max;
        CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);

        NNTRACE_COMP_SWITCH("optimized_ops::Mul");
        tflite::optimized_ops::Mul(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                                   output_activation_min, output_activation_max, out,
                                   convertShapeToDims(shapeOut));
    }

    return true;
}
    208 
    209 bool mulFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
    210                 int32_t activation, _Float16* out, const Shape& shapeOut) {
    211     NNTRACE_TRANS("mulFloat16");
    212     return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &mulFloat32);
    213 }
    214 
    215 bool mulQuant8(const uint8_t* in1, const Shape& shape1, const uint8_t* in2, const Shape& shape2,
    216                int32_t activation, uint8_t* out, const Shape& shapeOut) {
    217     NNTRACE_TRANS("mulQuant8");
    218     const int32_t input1_offset = -shape1.offset;
    219     const int32_t input2_offset = -shape2.offset;
    220     const int32_t output_offset = shapeOut.offset;
    221     const double input_product_scale = shape1.scale * shape2.scale;
    222     const double real_multiplier = input_product_scale / shapeOut.scale;
    223     int32 output_multiplier;
    224     int output_shift;
    225     if (!QuantizeMultiplierSmallerThanOne(real_multiplier, &output_multiplier, &output_shift)) {
    226         return false;
    227     }
    228     int32_t output_activation_min;
    229     int32_t output_activation_max;
    230     CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
    231                                   &output_activation_max);
    232 
    233     // Use BROADCAST version to handle the normal case.
    234     NNTRACE_COMP_SWITCH("optimized_ops::BroadcastMul");
    235     tflite::optimized_ops::BroadcastMul(in1, convertShapeToDims(shape1), input1_offset, in2,
    236                                         convertShapeToDims(shape2), input2_offset, output_offset,
    237                                         output_multiplier, output_shift, output_activation_min,
    238                                         output_activation_max, out, convertShapeToDims(shapeOut));
    239 
    240     return true;
    241 }
    242 
    243 bool subFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
    244                 int32_t activation, float* out, const Shape& shapeOut) {
    245     NNTRACE_TRANS("subFloat32");
    246     NNTRACE_COMP_SWITCH("optimized_ops::Sub");
    247     tflite::optimized_ops::Sub(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
    248                                out, convertShapeToDims(shapeOut));
    249 
    250     // TFLite does not apply activation to broadcast sub.
    251     float output_activation_min, output_activation_max;
    252     CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);
    253     uint32_t numOutputElements = getNumberOfElements(shapeOut);
    254     for (uint32_t i = 0; i < numOutputElements; i++) {
    255         out[i] = std::min(std::max(out[i], output_activation_min), output_activation_max);
    256     }
    257     return true;
    258 }
    259 
    260 bool subFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
    261                 int32_t activation, _Float16* out, const Shape& shapeOut) {
    262     NNTRACE_TRANS("subFloat16");
    263     return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &subFloat32);
    264 }
    265 
// SUB on asymmetric quant8 tensors, implemented on top of the quantized ADD
// kernel: the second input's fixed-point multiplier is negated, turning the
// addition into a subtraction. Rescaling mirrors addQuant8 (common domain of
// twice the larger input scale, `left_shift` bits of headroom, then
// requantization to the output scale/offset).
bool subQuant8(const uint8_t* in1, const Shape& shape1, const uint8_t* in2, const Shape& shape2,
               int32_t activation, uint8_t* out, const Shape& shapeOut) {
    NNTRACE_TRANS("subQuant8");

    // Input offsets are negated so the kernel can add them to the raw values;
    // the output offset is applied as-is after requantization.
    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
    const int left_shift = 20;
    const double twice_max_input_scale = 2 * std::max(shape1.scale, shape2.scale);
    const double real_input1_multiplier = shape1.scale / twice_max_input_scale;
    const double real_input2_multiplier = shape2.scale / twice_max_input_scale;
    const double real_output_multiplier =
            twice_max_input_scale / ((1 << left_shift) * shapeOut.scale);

    int32_t input1_multiplier;
    int32_t input1_shift;
    if (!QuantizeMultiplierSmallerThanOne(real_input1_multiplier, &input1_multiplier,
                                          &input1_shift)) {
        return false;
    }
    int32_t input2_multiplier;
    int32_t input2_shift;
    if (!QuantizeMultiplierSmallerThanOne(real_input2_multiplier, &input2_multiplier,
                                          &input2_shift)) {
        return false;
    }
    // Negating the second multiplier makes the ADD kernel compute in1 - in2.
    input2_multiplier *= -1;
    int32_t output_multiplier;
    int32_t output_shift;
    if (!QuantizeMultiplierSmallerThanOne(real_output_multiplier, &output_multiplier,
                                          &output_shift)) {
        return false;
    }
    int32_t output_activation_min;
    int32_t output_activation_max;
    CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
                                  &output_activation_max);

    // We are using tflite::optimized_ops::BroadcastAdd unconditionally here
    // because tflite::optimized_ops::Add fails to pass some of the
    // sub_quantized_different_scales tests.
    NNTRACE_COMP_SWITCH("optimized_ops::BroadcastAdd");
#define ANDROID_NN_BROADCAST_ADD(activation)                                                     \
    tflite::optimized_ops::BroadcastAdd<tflite::FusedActivationFunctionType::activation>(        \
            left_shift, in1, convertShapeToDims(shape1), input1_offset, input1_multiplier,       \
            input1_shift, in2, convertShapeToDims(shape2), input2_offset, input2_multiplier,     \
            input2_shift, output_offset, output_multiplier, output_shift, output_activation_min, \
            output_activation_max, out, convertShapeToDims(shapeOut))

    ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_ADD)
#undef ANDROID_NN_BROADCAST_ADD

    return true;
}
    320 
    321 bool divFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
    322                 int32_t activation, float* out, const Shape& shapeOut) {
    323     NNTRACE_TRANS("divFloat32");
    324     float output_activation_min, output_activation_max;
    325     CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);
    326 
    327     bool needBroadcast = !SameShape(shape1, shape2);
    328     if (needBroadcast) {
    329         NNTRACE_COMP_SWITCH("optimized_ops::BroadcastDiv");
    330         tflite::optimized_ops::BroadcastDiv(
    331                 in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
    332                 output_activation_min, output_activation_max, out, convertShapeToDims(shapeOut));
    333     } else {
    334         NNTRACE_COMP_SWITCH("optimized_ops::Div");
    335         tflite::optimized_ops::Div(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
    336                                    output_activation_min, output_activation_max, out,
    337                                    convertShapeToDims(shapeOut));
    338     }
    339     return true;
    340 }
    341 
    342 bool divFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
    343                 int32_t activation, _Float16* out, const Shape& shapeOut) {
    344     NNTRACE_TRANS("divFloat16");
    345     return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &divFloat32);
    346 }
    347 
    348 }  // namespace
    349 
// Validates operand counts, tensor types, and the minimum HAL version required
// for the given (opType, input type) combination. Shared by ADD/MUL/SUB/DIV.
bool validate(OperationType opType, const IOperationValidationContext* context) {
    // SUB and DIV were introduced in HAL v1.1; ADD and MUL exist since v1.0.
    const HalVersion opIntroducedAt = (opType == OperationType::DIV || opType == OperationType::SUB)
                                              ? HalVersion::V1_1
                                              : HalVersion::V1_0;
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    auto inputType = context->getInputType(kInputTensor1);
    if (inputType == OperandType::TENSOR_FLOAT32) {
        NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_0, opIntroducedAt)));
    } else if (inputType == OperandType::TENSOR_FLOAT16) {
        // Float16 tensors require at least HAL v1.2.
        NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_2, opIntroducedAt)));
    } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
        if (opType == OperationType::SUB) {
            // Quantized SUB requires at least HAL v1.2.
            NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_2, opIntroducedAt)));
        } else if (opType == OperationType::DIV) {
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation DIV";
        } else if (opType == OperationType::MUL) {
            Shape output = context->getOutputShape(kOutputTensor);
            Shape input1 = context->getInputShape(kInputTensor1);
            Shape input2 = context->getInputShape(kInputTensor2);
            // mulQuant8 requantizes with a multiplier that must be < 1, which
            // requires output.scale > input1.scale * input2.scale.
            NN_RET_CHECK_GT(output.scale, input1.scale * input2.scale);
            NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_0, opIntroducedAt)));
        } else {
            NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_0, opIntroducedAt)));
        }
    } else {
        NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << getOperationName(opType);
    }
    // Operand layout is (tensor, tensor, INT32 activation) -> (tensor).
    return validateInputTypes(context, {inputType, inputType, OperandType::INT32}) &&
           validateOutputTypes(context, {inputType});
}
    381 
    382 bool prepare(IOperationExecutionContext* context) {
    383     Shape input1 = context->getInputShape(kInputTensor1);
    384     Shape input2 = context->getInputShape(kInputTensor2);
    385     Shape output = context->getOutputShape(kOutputTensor);
    386     NN_RET_CHECK_LE(getNumberOfDimensions(input1), 4);
    387     NN_RET_CHECK_LE(getNumberOfDimensions(input2), 4);
    388     NN_RET_CHECK(calculateBroadcastedShape(input1, input2, &output));
    389     return context->setOutputShape(kOutputTensor, output);
    390 }
    391 
// Dispatches ADD to the type-specific kernel based on the first input's type.
bool executeAdd(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return addFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return addFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return addQuant8(context->getInputBuffer<uint8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<uint8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<uint8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation ADD";
    }
}
    424 
// Dispatches MUL to the type-specific kernel based on the first input's type.
bool executeMul(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return mulFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return mulFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return mulQuant8(context->getInputBuffer<uint8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<uint8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<uint8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation MUL";
    }
}
    457 
// Dispatches SUB to the type-specific kernel based on the first input's type.
bool executeSub(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return subFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return subFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return subQuant8(context->getInputBuffer<uint8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<uint8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<uint8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation SUB";
    }
}
    490 
// Dispatches DIV to the type-specific kernel based on the first input's type.
// Note: no quantized case — validate() rejects TENSOR_QUANT8_ASYMM for DIV.
bool executeDiv(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return divFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return divFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation DIV";
    }
}
    515 
    516 }  // namespace broadcast
    517 
using std::placeholders::_1;
// Register the four binary broadcast operations. They share validate() (bound
// to the concrete operation type), prepare(), and the zero-sized-input bypass
// performed inside each execute function (hence allowZeroSizedInput).
NN_REGISTER_OPERATION(ADD, "ADD", std::bind(broadcast::validate, OperationType::ADD, _1),
                      broadcast::prepare, broadcast::executeAdd, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(MUL, "MUL", std::bind(broadcast::validate, OperationType::MUL, _1),
                      broadcast::prepare, broadcast::executeMul, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(SUB, "SUB", std::bind(broadcast::validate, OperationType::SUB, _1),
                      broadcast::prepare, broadcast::executeSub, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(DIV, "DIV", std::bind(broadcast::validate, OperationType::DIV, _1),
                      broadcast::prepare, broadcast::executeDiv, .allowZeroSizedInput = true);
    527 
    528 }  // namespace nn
    529 }  // namespace android
    530