Home | History | Annotate | Download | only in operations
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "ActivationFunctor.h"
     18 #include "CpuOperationUtils.h"
     19 #include "OperationResolver.h"
     20 
     21 #include "tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h"
     22 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
     23 
     24 #include "Tracing.h"
     25 
     26 namespace android {
     27 namespace nn {
     28 
     29 namespace activation {
     30 
// Single input: the tensor the activation is applied to.
constexpr uint32_t kNumInputs = 1;
constexpr uint32_t kInputTensor = 0;

// Single output: a tensor with the same dimensions as the input.
constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;
     36 
     37 namespace {
     38 
     39 template <typename T>
     40 bool reluFloat(const T* inputData, const Shape& inputShape, T* outputData, const Shape& outputShape,
     41                float reluMin = 0.f, float reluMax = std::numeric_limits<float>::max()) {
     42     NNTRACE_COMP("reluX");
     43     int numElements = getNumberOfElements(inputShape);
     44     for (int i = 0; i < numElements; i++, inputData++, outputData++) {
     45         *outputData = static_cast<T>(
     46                 std::min(std::max(reluMin, static_cast<float>(*inputData)), reluMax));
     47     }
     48     return true;
     49 }
     50 template bool reluFloat<float>(const float* inputData, const Shape& inputShape, float* outputData,
     51                                const Shape& outputShape, float reluMin, float reluMax);
     52 template bool reluFloat<_Float16>(const _Float16* inputData, const Shape& inputShape,
     53                                   _Float16* outputData, const Shape& outputShape, float reluMin,
     54                                   float reluMax);
     55 
     56 template <typename T>
     57 bool relu1Float(const T* inputData, const Shape& inputShape, T* outputData,
     58                 const Shape& outputShape) {
     59     return reluFloat(inputData, inputShape, outputData, outputShape, -1.f, 1.f);
     60 }
     61 template bool relu1Float<float>(const float* inputData, const Shape& inputShape, float* outputData,
     62                                 const Shape& outputShape);
     63 template bool relu1Float<_Float16>(const _Float16* inputData, const Shape& inputShape,
     64                                    _Float16* outputData, const Shape& outputShape);
     65 
     66 template <typename T>
     67 bool relu6Float(const T* inputData, const Shape& inputShape, T* outputData,
     68                 const Shape& outputShape) {
     69     return reluFloat(inputData, inputShape, outputData, outputShape, 0.f, 6.f);
     70 }
     71 template bool relu6Float<float>(const float* inputData, const Shape& inputShape, float* outputData,
     72                                 const Shape& outputShape);
     73 template bool relu6Float<_Float16>(const _Float16* inputData, const Shape& inputShape,
     74                                    _Float16* outputData, const Shape& outputShape);
     75 
     76 bool tanhFloat16(const _Float16* inputData, const Shape& inputShape, _Float16* outputData,
     77                  const Shape& outputShape) {
     78     NNTRACE_COMP("tanhFloat16");
     79     int numElements = getNumberOfElements(inputShape);
     80     for (int i = 0; i < numElements; i++, inputData++, outputData++) {
     81         *outputData = static_cast<_Float16>(std::tanh(static_cast<float>(*inputData)));
     82     }
     83     return true;
     84 }
     85 
     86 bool tanhFloat32(const float* inputData, const Shape& inputShape, float* outputData,
     87                  const Shape& outputShape) {
     88     NNTRACE_COMP("tanhFloat32");
     89     int numElements = getNumberOfElements(inputShape);
     90     for (int i = 0; i < numElements; i++, inputData++, outputData++) {
     91         *outputData = std::tanh(*inputData);
     92     }
     93     return true;
     94 }
     95 
     96 template <typename T>
     97 bool logisticFloat(const T* inputData, const Shape& inputShape, T* outputData,
     98                    const Shape& outputShape) {
     99     NNTRACE_COMP("logisticFloat");
    100     int numElements = getNumberOfElements(inputShape);
    101     for (int i = 0; i < numElements; i++, inputData++, outputData++) {
    102         *outputData = static_cast<T>(1.f / (1.f + std::exp(static_cast<float>(-*inputData))));
    103     }
    104     return true;
    105 }
    106 template bool logisticFloat<float>(const float* inputData, const Shape& inputShape,
    107                                    float* outputData, const Shape& outputShape);
    108 template bool logisticFloat<_Float16>(const _Float16* inputData, const Shape& inputShape,
    109                                       _Float16* outputData, const Shape& outputShape);
    110 
    111 #define ANDROID_NN_RELUX_QUANT8(activation)                                           \
    112     int numElements = getNumberOfElements(inputShape);                                \
    113     int32_t output_activation_min = 0;                                                \
    114     int32_t output_activation_max = 0;                                                \
    115                                                                                       \
    116     CalculateActivationRangeUint8(activation, inputShape, &output_activation_min,     \
    117                                   &output_activation_max);                            \
    118                                                                                       \
    119     for (int i = 0; i < numElements; i++, inputData++, outputData++) {                \
    120         *outputData = std::min((uint8_t)output_activation_max,                        \
    121                                std::max((uint8_t)output_activation_min, *inputData)); \
    122     }
    123 
    124 bool reluQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
    125                 const Shape& outputShape) {
    126     NNTRACE_COMP("reluQuant8");
    127     ANDROID_NN_RELUX_QUANT8(kActivationRelu)
    128     return true;
    129 }
    130 
    131 bool relu1Quant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
    132                  const Shape& outputShape) {
    133     NNTRACE_COMP("relu1Quant8");
    134     ANDROID_NN_RELUX_QUANT8(kActivationRelu1)
    135     return true;
    136 }
    137 
    138 bool relu6Quant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
    139                  const Shape& outputShape) {
    140     NNTRACE_COMP("relu6Quant8");
    141     ANDROID_NN_RELUX_QUANT8(kActivationRelu6)
    142     return true;
    143 }
    144 
    145 #undef ANDROID_NN_RELUX_QUANT8
    146 
    147 bool tanhQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
    148                 const Shape& outputShape) {
    149     NNTRACE_TRANS("tanhQuant8");
    150     if (outputShape.offset != 128 || outputShape.scale != 1.f / 128) {
    151         LOG(ERROR) << "incorrect scale or offset for TANH output";
    152         return false;
    153     }
    154 
    155     int numElements = getNumberOfElements(inputShape);
    156     static constexpr int kInputIntegerBits = 4;
    157 
    158     const double input_real_multiplier =
    159             inputShape.scale * static_cast<double>(1 << (31 - kInputIntegerBits));
    160 
    161     int32_t input_multiplier = 0;
    162     int32_t input_left_shift = 0;
    163     if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, &input_multiplier,
    164                                           &input_left_shift)) {
    165         return false;
    166     }
    167     int32_t input_range_radius = CalculateInputRadius(kInputIntegerBits, input_left_shift);
    168 
    169     NNTRACE_COMP_SWITCH("optimized_ops::Tanh");
    170     tflite::optimized_ops::Tanh(inputData, convertShapeToTflshape(inputShape), inputShape.offset,
    171                                 input_range_radius, input_multiplier, input_left_shift, outputData,
    172                                 convertShapeToTflshape(outputShape));
    173 
    174     return true;
    175 }
    176 
    177 bool logisticQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
    178                     const Shape& outputShape) {
    179     NNTRACE_TRANS("logisticQuant8");
    180     if (outputShape.offset != 0 || outputShape.scale != 1.f / 256) {
    181         LOG(ERROR) << "incorrect scale / offset for output";
    182         return false;
    183     }
    184 
    185     int numElements = getNumberOfElements(inputShape);
    186     static constexpr int kInputIntegerBits = 4;
    187 
    188     const double input_real_multiplier =
    189             inputShape.scale * static_cast<double>(1 << (31 - kInputIntegerBits));
    190 
    191     int32_t input_multiplier = 0;
    192     int32_t input_left_shift = 0;
    193     if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, &input_multiplier,
    194                                           &input_left_shift)) {
    195         return false;
    196     }
    197     int32_t input_range_radius = CalculateInputRadius(kInputIntegerBits, input_left_shift);
    198 
    199     NNTRACE_COMP_SWITCH("optimized_ops::Logistic");
    200     tflite::optimized_ops::Logistic(
    201             inputData, convertShapeToTflshape(inputShape), inputShape.offset, input_range_radius,
    202             input_multiplier, input_left_shift, outputData, convertShapeToTflshape(outputShape));
    203 
    204     return true;
    205 }
    206 
    207 }  // namespace
    208 
    209 bool validate(OperationType opType, const IOperationValidationContext* context) {
    210     NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    211     NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    212     auto inputType = context->getInputType(kInputTensor);
    213     if (inputType == OperandType::TENSOR_FLOAT32) {
    214         NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0));
    215     } else if (inputType == OperandType::TENSOR_FLOAT16) {
    216         NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
    217     } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
    218         if (opType == OperationType::TANH) {
    219             NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
    220         } else {
    221             NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0));
    222         }
    223     } else {
    224         NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << getOperationName(opType);
    225     }
    226     return validateInputTypes(context, {inputType}) && validateOutputTypes(context, {inputType});
    227 }
    228 
    229 bool prepare(OperationType opType, IOperationExecutionContext* context) {
    230     Shape input = context->getInputShape(kInputTensor);
    231     NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
    232     Shape output = input;
    233     if (input.type == OperandType::TENSOR_QUANT8_ASYMM) {
    234         switch (opType) {
    235             case OperationType::RELU:
    236             case OperationType::RELU1:
    237             case OperationType::RELU6:
    238                 break;
    239             case OperationType::LOGISTIC:
    240                 output.scale = 1.f / 256;
    241                 output.offset = 0;
    242                 break;
    243             case OperationType::TANH:
    244                 output.scale = 1.f / 128;
    245                 output.offset = 128;
    246                 break;
    247             default:
    248                 NN_RET_CHECK_FAIL() << "Unsupported operation type";
    249         }
    250     }
    251     return context->setOutputShape(kOutputTensor, output);
    252 }
    253 
    254 bool executeRelu(IOperationExecutionContext* context) {
    255     // Bypass execution in the case of zero-sized input.
    256     if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    257     switch (context->getInputType(kInputTensor)) {
    258         case OperandType::TENSOR_FLOAT16:
    259             return reluFloat(context->getInputBuffer<_Float16>(kInputTensor),
    260                              context->getInputShape(kInputTensor),
    261                              context->getOutputBuffer<_Float16>(kOutputTensor),
    262                              context->getOutputShape(kOutputTensor));
    263         case OperandType::TENSOR_FLOAT32:
    264             return reluFloat(context->getInputBuffer<float>(kInputTensor),
    265                              context->getInputShape(kInputTensor),
    266                              context->getOutputBuffer<float>(kOutputTensor),
    267                              context->getOutputShape(kOutputTensor));
    268         case OperandType::TENSOR_QUANT8_ASYMM:
    269             return reluQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
    270                               context->getInputShape(kInputTensor),
    271                               context->getOutputBuffer<uint8_t>(kOutputTensor),
    272                               context->getOutputShape(kOutputTensor));
    273         default:
    274             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation RELU";
    275     }
    276 }
    277 
    278 bool executeRelu1(IOperationExecutionContext* context) {
    279     // Bypass execution in the case of zero-sized input.
    280     if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    281     switch (context->getInputType(kInputTensor)) {
    282         case OperandType::TENSOR_FLOAT16:
    283             return relu1Float(context->getInputBuffer<_Float16>(kInputTensor),
    284                               context->getInputShape(kInputTensor),
    285                               context->getOutputBuffer<_Float16>(kOutputTensor),
    286                               context->getOutputShape(kOutputTensor));
    287         case OperandType::TENSOR_FLOAT32:
    288             return relu1Float(context->getInputBuffer<float>(kInputTensor),
    289                               context->getInputShape(kInputTensor),
    290                               context->getOutputBuffer<float>(kOutputTensor),
    291                               context->getOutputShape(kOutputTensor));
    292         case OperandType::TENSOR_QUANT8_ASYMM:
    293             return relu1Quant8(context->getInputBuffer<uint8_t>(kInputTensor),
    294                                context->getInputShape(kInputTensor),
    295                                context->getOutputBuffer<uint8_t>(kOutputTensor),
    296                                context->getOutputShape(kOutputTensor));
    297         default:
    298             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation RELU1";
    299     }
    300 }
    301 
    302 bool executeRelu6(IOperationExecutionContext* context) {
    303     // Bypass execution in the case of zero-sized input.
    304     if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    305     switch (context->getInputType(kInputTensor)) {
    306         case OperandType::TENSOR_FLOAT16:
    307             return relu6Float(context->getInputBuffer<_Float16>(kInputTensor),
    308                               context->getInputShape(kInputTensor),
    309                               context->getOutputBuffer<_Float16>(kOutputTensor),
    310                               context->getOutputShape(kOutputTensor));
    311         case OperandType::TENSOR_FLOAT32:
    312             return relu6Float(context->getInputBuffer<float>(kInputTensor),
    313                               context->getInputShape(kInputTensor),
    314                               context->getOutputBuffer<float>(kOutputTensor),
    315                               context->getOutputShape(kOutputTensor));
    316         case OperandType::TENSOR_QUANT8_ASYMM:
    317             return relu6Quant8(context->getInputBuffer<uint8_t>(kInputTensor),
    318                                context->getInputShape(kInputTensor),
    319                                context->getOutputBuffer<uint8_t>(kOutputTensor),
    320                                context->getOutputShape(kOutputTensor));
    321         default:
    322             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation RELU6";
    323     }
    324 }
    325 
    326 bool executeLogistic(IOperationExecutionContext* context) {
    327     // Bypass execution in the case of zero-sized input.
    328     if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    329     switch (context->getInputType(kInputTensor)) {
    330         case OperandType::TENSOR_FLOAT16:
    331             return logisticFloat(context->getInputBuffer<_Float16>(kInputTensor),
    332                                  context->getInputShape(kInputTensor),
    333                                  context->getOutputBuffer<_Float16>(kOutputTensor),
    334                                  context->getOutputShape(kOutputTensor));
    335         case OperandType::TENSOR_FLOAT32:
    336             return logisticFloat(context->getInputBuffer<float>(kInputTensor),
    337                                  context->getInputShape(kInputTensor),
    338                                  context->getOutputBuffer<float>(kOutputTensor),
    339                                  context->getOutputShape(kOutputTensor));
    340         case OperandType::TENSOR_QUANT8_ASYMM:
    341             return logisticQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
    342                                   context->getInputShape(kInputTensor),
    343                                   context->getOutputBuffer<uint8_t>(kOutputTensor),
    344                                   context->getOutputShape(kOutputTensor));
    345         default:
    346             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation LOGISTIC";
    347     }
    348 }
    349 
    350 bool executeTanh(IOperationExecutionContext* context) {
    351     // Bypass execution in the case of zero-sized input.
    352     if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    353     switch (context->getInputType(kInputTensor)) {
    354         case OperandType::TENSOR_FLOAT16:
    355             return tanhFloat16(context->getInputBuffer<_Float16>(kInputTensor),
    356                                context->getInputShape(kInputTensor),
    357                                context->getOutputBuffer<_Float16>(kOutputTensor),
    358                                context->getOutputShape(kOutputTensor));
    359         case OperandType::TENSOR_FLOAT32:
    360             return tanhFloat32(context->getInputBuffer<float>(kInputTensor),
    361                                context->getInputShape(kInputTensor),
    362                                context->getOutputBuffer<float>(kOutputTensor),
    363                                context->getOutputShape(kOutputTensor));
    364         case OperandType::TENSOR_QUANT8_ASYMM:
    365             return tanhQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
    366                               context->getInputShape(kInputTensor),
    367                               context->getOutputBuffer<uint8_t>(kOutputTensor),
    368                               context->getOutputShape(kOutputTensor));
    369         default:
    370             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation TANH";
    371     }
    372 }
    373 
    374 }  // namespace activation
    375 
using std::placeholders::_1;
// All five activations share validate()/prepare(), parameterized by the
// OperationType via std::bind. allowZeroSizedInput = true lets each execute*
// function short-circuit on zero-sized tensors (see the early return at the
// top of each one).
NN_REGISTER_OPERATION(RELU, "RELU", std::bind(activation::validate, OperationType::RELU, _1),
                      std::bind(activation::prepare, OperationType::RELU, _1),
                      activation::executeRelu, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(RELU1, "RELU1", std::bind(activation::validate, OperationType::RELU1, _1),
                      std::bind(activation::prepare, OperationType::RELU1, _1),
                      activation::executeRelu1, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(RELU6, "RELU6", std::bind(activation::validate, OperationType::RELU6, _1),
                      std::bind(activation::prepare, OperationType::RELU6, _1),
                      activation::executeRelu6, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(LOGISTIC, "LOGISTIC",
                      std::bind(activation::validate, OperationType::LOGISTIC, _1),
                      std::bind(activation::prepare, OperationType::LOGISTIC, _1),
                      activation::executeLogistic, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(TANH, "TANH", std::bind(activation::validate, OperationType::TANH, _1),
                      std::bind(activation::prepare, OperationType::TANH, _1),
                      activation::executeTanh, .allowZeroSizedInput = true);
    393 
    394 }  // namespace nn
    395 }  // namespace android
    396