// CPU implementations of the AVERAGE_POOL_2D, L2_POOL_2D, and MAX_POOL_2D operations.
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "CpuOperationUtils.h"
     18 #include "OperationResolver.h"
     19 
     20 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
     21 
     22 #include "Tracing.h"
     23 
     24 namespace android {
     25 namespace nn {
     26 namespace pooling {
     27 
// Index of the data tensor among the operation's inputs.
constexpr uint32_t kInputTensor = 0;

// All pooling operations produce exactly one output tensor.
constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;
     32 
     33 namespace {
     34 
// Decoded pooling hyperparameters (padding, stride, filter size, fused
// activation, data layout) parsed from the operation's scalar inputs.
struct PoolingParam {
    int32_t padding_left, padding_right;
    int32_t padding_top, padding_bottom;
    int32_t stride_width, stride_height;
    int32_t filter_width, filter_height;
    int32_t activation;    // fused activation code, fed to CalculateActivationRange*
    bool useNchw = false;  // layout flag; defaults to NHWC when the optional input is absent

    // Reads the scalar operands from |context|. Signatures with >= 10 inputs
    // carry explicit padding (left/right/top/bottom); shorter signatures carry
    // an implicit padding scheme which is converted to explicit values from
    // the input tensor's spatial dimensions. Returns false (via NN_RET_CHECK,
    // which also logs) when any parameter is out of range.
    bool initialize(const IOperationExecutionContext* context) {
        uint32_t inCount = context->getNumInputs();
        int32_t padding_implicit = 0;
        if (inCount >= 10) {
            // Explicit-padding signature: scalars at inputs 1-9, optional
            // layout flag at index 10.
            padding_left = context->getInputValue<int32_t>(1);
            padding_right = context->getInputValue<int32_t>(2);
            padding_top = context->getInputValue<int32_t>(3);
            padding_bottom = context->getInputValue<int32_t>(4);
            stride_width = context->getInputValue<int32_t>(5);
            stride_height = context->getInputValue<int32_t>(6);
            filter_width = context->getInputValue<int32_t>(7);
            filter_height = context->getInputValue<int32_t>(8);
            activation = context->getInputValue<int32_t>(9);
            if (inCount == 11) {
                useNchw = context->getInputValue<bool>(10);
            }
        } else {
            // Implicit-padding signature: scalars at inputs 1-6, optional
            // layout flag at index 7.
            padding_implicit = context->getInputValue<int32_t>(1);
            stride_width = context->getInputValue<int32_t>(2);
            stride_height = context->getInputValue<int32_t>(3);
            filter_width = context->getInputValue<int32_t>(4);
            filter_height = context->getInputValue<int32_t>(5);
            activation = context->getInputValue<int32_t>(6);
            if (inCount == 8) {
                useNchw = context->getInputValue<bool>(7);
            }
        }
        if (inCount <= 8) {
            // Convert the implicit padding scheme into explicit per-edge
            // values, using the spatial dims selected by the layout flag.
            Shape inputShape = context->getInputShape(kInputTensor);
            int32_t input_height = getSizeOfDimension(inputShape, useNchw ? 2 : 1);
            int32_t input_width = getSizeOfDimension(inputShape, useNchw ? 3 : 2);
            calculateExplicitPadding(input_width, stride_width, filter_width, padding_implicit,
                                     &padding_left, &padding_right);
            calculateExplicitPadding(input_height, stride_height, filter_height, padding_implicit,
                                     &padding_top, &padding_bottom);
        }
        NN_RET_CHECK_GE(padding_left, 0);
        NN_RET_CHECK_GE(padding_right, 0);
        NN_RET_CHECK_GE(padding_top, 0);
        NN_RET_CHECK_GE(padding_bottom, 0);
        NN_RET_CHECK_GT(stride_width, 0);
        NN_RET_CHECK_GT(stride_height, 0);
        NN_RET_CHECK_GT(filter_width, 0);
        NN_RET_CHECK_GT(filter_height, 0);
        NN_RET_CHECK_GE(activation, 0);
        // Each filter window must overlap the input: the padding has to be
        // strictly smaller than the filter in the corresponding dimension.
        NN_RET_CHECK_GT(filter_width, padding_left);
        NN_RET_CHECK_GT(filter_width, padding_right);
        NN_RET_CHECK_GT(filter_height, padding_top);
        NN_RET_CHECK_GT(filter_height, padding_bottom);
        return true;
    }

    // Packs these parameters into a tflite::PoolParams, deriving the
    // quantized (for TENSOR_QUANT8_ASYMM outputs) or float activation range
    // from |output| and the fused activation code.
    tflite::PoolParams toTfliteParam(const Shape& output) const {
        tflite::PoolParams params = {
                .stride_height = stride_height,
                .stride_width = stride_width,
                .filter_height = filter_height,
                .filter_width = filter_width,
                .padding_values = {.height = static_cast<int16_t>(padding_top),
                                   .width = static_cast<int16_t>(padding_left)}};
        if (output.type == OperandType::TENSOR_QUANT8_ASYMM) {
            int32_t output_activation_min = 0;
            int32_t output_activation_max = 0;
            CalculateActivationRangeUint8(activation, output, &output_activation_min,
                                          &output_activation_max);
            params.quantized_activation_min = output_activation_min;
            params.quantized_activation_max = output_activation_max;
        } else {
            float output_activation_min, output_activation_max;
            CalculateActivationRangeFloat(activation, &output_activation_min,
                                          &output_activation_max);
            params.float_activation_min = output_activation_min;
            params.float_activation_max = output_activation_max;
        }
        return params;
    }
};
    120 
    121 bool averagePoolNhwc(const float* inputData, const Shape& inputShape, const PoolingParam& param,
    122                      float* outputData, const Shape& outputShape) {
    123     NNTRACE_TRANS("averagePoolFloat32");
    124     auto op_params = param.toTfliteParam(outputShape);
    125     NNTRACE_COMP_SWITCH("optimized_ops::AveragePool");
    126     tflite::optimized_ops::AveragePool(op_params, convertShapeToTflshape(inputShape), inputData,
    127                                        convertShapeToTflshape(outputShape), outputData);
    128     return true;
    129 }
    130 
    131 bool averagePoolNhwc(const _Float16* inputData, const Shape& inputShape, const PoolingParam& param,
    132                      _Float16* outputData, const Shape& outputShape) {
    133     NNTRACE_TRANS("averagePoolFloat16");
    134     std::vector<float> inputDataFloat32(getNumberOfElements(inputShape));
    135     std::vector<float> outputDataFloat32(getNumberOfElements(outputShape));
    136 
    137     convertFloat16ToFloat32(inputData, &inputDataFloat32);
    138     averagePoolNhwc(inputDataFloat32.data(), inputShape, param, outputDataFloat32.data(),
    139                     outputShape);
    140     convertFloat32ToFloat16(outputDataFloat32, outputData);
    141     return true;
    142 }
    143 
    144 bool averagePoolNhwc(const uint8_t* inputData, const Shape& inputShape, const PoolingParam& param,
    145                      uint8_t* outputData, const Shape& outputShape) {
    146     NNTRACE_TRANS("averagePoolQuant8");
    147     auto op_params = param.toTfliteParam(outputShape);
    148     NNTRACE_COMP_SWITCH("optimized_ops::AveragePool");
    149     tflite::optimized_ops::AveragePool(op_params, convertShapeToTflshape(inputShape), inputData,
    150                                        convertShapeToTflshape(outputShape), outputData);
    151     return true;
    152 }
    153 
    154 bool l2PoolNhwc(const float* inputData, const Shape& inputShape, const PoolingParam& param,
    155                 float* outputData, const Shape& outputShape) {
    156     NNTRACE_TRANS("l2PoolFloat32");
    157     auto op_params = param.toTfliteParam(outputShape);
    158     NNTRACE_COMP_SWITCH("optimized_ops::L2Pool");
    159     tflite::optimized_ops::L2Pool(op_params, convertShapeToTflshape(inputShape), inputData,
    160                                   convertShapeToTflshape(outputShape), outputData);
    161     return true;
    162 }
    163 
    164 bool l2PoolNhwc(const _Float16* inputData, const Shape& inputShape, const PoolingParam& param,
    165                 _Float16* outputData, const Shape& outputShape) {
    166     NNTRACE_TRANS("l2PoolFloat16");
    167     std::vector<float> inputDataFloat32(getNumberOfElements(inputShape));
    168     std::vector<float> outputDataFloat32(getNumberOfElements(outputShape));
    169 
    170     convertFloat16ToFloat32(inputData, &inputDataFloat32);
    171     l2PoolNhwc(inputDataFloat32.data(), inputShape, param, outputDataFloat32.data(), outputShape);
    172     convertFloat32ToFloat16(outputDataFloat32, outputData);
    173     return true;
    174 }
    175 
    176 bool maxPoolNhwc(const float* inputData, const Shape& inputShape, const PoolingParam& param,
    177                  float* outputData, const Shape& outputShape) {
    178     NNTRACE_TRANS("maxPoolFloat32");
    179     auto op_params = param.toTfliteParam(outputShape);
    180     NNTRACE_COMP_SWITCH("optimized_ops::MaxPool");
    181     tflite::optimized_ops::MaxPool(op_params, convertShapeToTflshape(inputShape), inputData,
    182                                    convertShapeToTflshape(outputShape), outputData);
    183     return true;
    184 }
    185 
    186 bool maxPoolNhwc(const uint8_t* inputData, const Shape& inputShape, const PoolingParam& param,
    187                  uint8_t* outputData, const Shape& outputShape) {
    188     NNTRACE_TRANS("maxPoolQuant8");
    189     auto op_params = param.toTfliteParam(outputShape);
    190     NNTRACE_COMP_SWITCH("optimized_ops::MaxPool");
    191     tflite::optimized_ops::MaxPool(op_params, convertShapeToTflshape(inputShape), inputData,
    192                                    convertShapeToTflshape(outputShape), outputData);
    193     return true;
    194 }
    195 
    196 bool maxPoolNhwc(const _Float16* inputData, const Shape& inputShape, const PoolingParam& param,
    197                  _Float16* outputData, const Shape& outputShape) {
    198     NNTRACE_TRANS("maxPoolFloat16");
    199     std::vector<float> inputData_float32(getNumberOfElements(inputShape));
    200     std::vector<float> outputData_float32(getNumberOfElements(outputShape));
    201 
    202     convertFloat16ToFloat32(inputData, &inputData_float32);
    203     maxPoolNhwc(inputData_float32.data(), inputShape, param, outputData_float32.data(),
    204                 outputShape);
    205     convertFloat32ToFloat16(outputData_float32, outputData);
    206     return true;
    207 }
    208 
    209 template <typename T>
    210 bool averagePool(const T* inputData, const Shape& inputShape, const PoolingParam& param,
    211                  T* outputData, const Shape& outputShape) {
    212     InputWithLayout<T> input(param.useNchw);
    213     OutputWithLayout<T> output(param.useNchw);
    214     NN_RET_CHECK(input.initialize(inputData, inputShape));
    215     NN_RET_CHECK(output.initialize(outputData, outputShape));
    216     NN_RET_CHECK(averagePoolNhwc(input.getNhwcBuffer(), input.getNhwcShape(), param,
    217                                  output.getNhwcBuffer(), output.getNhwcShape()));
    218     NN_RET_CHECK(output.commit());
    219     return true;
    220 }
    221 
    222 template <typename T>
    223 bool l2Pool(const T* inputData, const Shape& inputShape, const PoolingParam& param, T* outputData,
    224             const Shape& outputShape) {
    225     InputWithLayout<T> input(param.useNchw);
    226     OutputWithLayout<T> output(param.useNchw);
    227     NN_RET_CHECK(input.initialize(inputData, inputShape));
    228     NN_RET_CHECK(output.initialize(outputData, outputShape));
    229     NN_RET_CHECK(l2PoolNhwc(input.getNhwcBuffer(), input.getNhwcShape(), param,
    230                             output.getNhwcBuffer(), output.getNhwcShape()));
    231     NN_RET_CHECK(output.commit());
    232     return true;
    233 }
    234 
    235 template <typename T>
    236 bool maxPool(const T* inputData, const Shape& inputShape, const PoolingParam& param, T* outputData,
    237              const Shape& outputShape) {
    238     InputWithLayout<T> input(param.useNchw);
    239     OutputWithLayout<T> output(param.useNchw);
    240     NN_RET_CHECK(input.initialize(inputData, inputShape));
    241     NN_RET_CHECK(output.initialize(outputData, outputShape));
    242     NN_RET_CHECK(maxPoolNhwc(input.getNhwcBuffer(), input.getNhwcShape(), param,
    243                              output.getNhwcBuffer(), output.getNhwcShape()));
    244     NN_RET_CHECK(output.commit());
    245     return true;
    246 }
    247 
    248 }  // namespace
    249 
// Validates operand counts, operand types, and the minimum HAL version for
// the given pooling operation's signature (implicit vs. explicit padding,
// optional NCHW layout flag, supported tensor types).
bool validate(OperationType opType, const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    auto inputCount = context->getNumInputs();
    // 7/8 inputs: implicit padding (+ optional layout flag);
    // 10/11 inputs: explicit padding (+ optional layout flag).
    NN_RET_CHECK(inputCount == 11 || inputCount == 10 || inputCount == 8 || inputCount == 7);
    auto inputType = context->getInputType(kInputTensor);
    std::vector<OperandType> inExpectedTypes;
    // Base signature: the data tensor plus six INT32 scalars (the implicit-
    // padding form); explicit-padding scalars are appended below.
    if (inputType == OperandType::TENSOR_FLOAT32) {
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0));
        inExpectedTypes = {
                inputType,          OperandType::INT32, OperandType::INT32, OperandType::INT32,
                OperandType::INT32, OperandType::INT32, OperandType::INT32,
        };
    } else if (inputType == OperandType::TENSOR_FLOAT16) {
        // fp16 tensors require HAL 1.2.
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
        inExpectedTypes = {
                OperandType::TENSOR_FLOAT16, OperandType::INT32, OperandType::INT32,
                OperandType::INT32,          OperandType::INT32, OperandType::INT32,
                OperandType::INT32,
        };
    } else if (opType != OperationType::L2_POOL_2D &&
               inputType == OperandType::TENSOR_QUANT8_ASYMM) {
        // Quant8 is supported for AVERAGE_POOL_2D and MAX_POOL_2D only;
        // L2_POOL_2D has no quantized variant.
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0));
        inExpectedTypes = {
                OperandType::TENSOR_QUANT8_ASYMM,
                OperandType::INT32,
                OperandType::INT32,
                OperandType::INT32,
                OperandType::INT32,
                OperandType::INT32,
                OperandType::INT32,
        };
    } else {
        NN_RET_CHECK_FAIL() << "Unsupported input tensor type for operation "
                            << getOperationName(opType);
    }

    // Explicit padding adds three more INT32 scalars over the implicit form.
    if (inputCount >= 10) {
        std::vector<OperandType> explicitScalarTypes(3, OperandType::INT32);
        inExpectedTypes.insert(inExpectedTypes.end(), explicitScalarTypes.begin(),
                               explicitScalarTypes.end());
    }
    // The optional BOOL layout flag (NCHW) requires HAL 1.2.
    if (inputCount == 11 || inputCount == 8) {
        inExpectedTypes.push_back(OperandType::BOOL);
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
    } else {
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0));
    }
    return validateInputTypes(context, inExpectedTypes) &&
           validateOutputTypes(context, {inputType});
}
    300 
    301 bool prepare(IOperationExecutionContext* context) {
    302     Shape input = context->getInputShape(kInputTensor);
    303     NN_RET_CHECK_EQ(getNumberOfDimensions(input), 4);
    304 
    305     PoolingParam param;
    306     NN_RET_CHECK(param.initialize(context));
    307 
    308     // Only batches can be zero.
    309     uint32_t batches = getSizeOfDimension(input, 0);
    310     uint32_t height = getSizeOfDimension(input, param.useNchw ? 2 : 1);
    311     uint32_t width = getSizeOfDimension(input, param.useNchw ? 3 : 2);
    312     uint32_t channels = getSizeOfDimension(input, param.useNchw ? 1 : 3);
    313     NN_RET_CHECK_GT(height, 0);
    314     NN_RET_CHECK_GT(width, 0);
    315     NN_RET_CHECK_GT(channels, 0);
    316 
    317     uint32_t outWidth = computeOutSize(width, param.filter_width, param.stride_width,
    318                                        param.padding_left, param.padding_right);
    319     uint32_t outHeight = computeOutSize(height, param.filter_height, param.stride_height,
    320                                         param.padding_top, param.padding_bottom);
    321 
    322     Shape output = input;
    323     if (param.useNchw) {
    324         output.dimensions = {batches, channels, outHeight, outWidth};
    325     } else {
    326         output.dimensions = {batches, outHeight, outWidth, channels};
    327     }
    328     return context->setOutputShape(kOutputTensor, output);
    329 }
    330 
// Expands to one `case` of a switch over the input's OperandType, invoking
// `name` with buffers/shapes typed as `cppType`.
#define POOLING_DISPATCH_INPUT_TYPE(name, type, cppType)              \
    case OperandType::type:                                           \
        return name(context->getInputBuffer<cppType>(kInputTensor),   \
                    context->getInputShape(kInputTensor), param,      \
                    context->getOutputBuffer<cppType>(kOutputTensor), \
                    context->getOutputShape(kOutputTensor))
    337 
// Executes AVERAGE_POOL_2D, dispatching on the input tensor's operand type
// (float32, fp16, or quant8).
bool executeAveragePool(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    PoolingParam param;
    NN_RET_CHECK(param.initialize(context));
    switch (context->getInputType(kInputTensor)) {
        POOLING_DISPATCH_INPUT_TYPE(averagePool, TENSOR_FLOAT32, float);
        POOLING_DISPATCH_INPUT_TYPE(averagePool, TENSOR_FLOAT16, _Float16);
        POOLING_DISPATCH_INPUT_TYPE(averagePool, TENSOR_QUANT8_ASYMM, uint8_t);
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation AVERAGE_POOL_2D";
    }
}
    351 
// Executes L2_POOL_2D, dispatching on the input tensor's operand type.
// Note: no quant8 case — L2_POOL_2D has no quantized variant (see validate).
bool executeL2Pool(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    PoolingParam param;
    NN_RET_CHECK(param.initialize(context));
    switch (context->getInputType(kInputTensor)) {
        POOLING_DISPATCH_INPUT_TYPE(l2Pool, TENSOR_FLOAT32, float);
        POOLING_DISPATCH_INPUT_TYPE(l2Pool, TENSOR_FLOAT16, _Float16);
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation L2_POOL_2D";
    }
}
    364 
// Executes MAX_POOL_2D, dispatching on the input tensor's operand type
// (float32, fp16, or quant8).
bool executeMaxPool(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    PoolingParam param;
    NN_RET_CHECK(param.initialize(context));
    switch (context->getInputType(kInputTensor)) {
        POOLING_DISPATCH_INPUT_TYPE(maxPool, TENSOR_FLOAT32, float);
        POOLING_DISPATCH_INPUT_TYPE(maxPool, TENSOR_FLOAT16, _Float16);
        POOLING_DISPATCH_INPUT_TYPE(maxPool, TENSOR_QUANT8_ASYMM, uint8_t);
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation MAX_POOL_2D";
    }
}
    378 
    379 #undef POOLING_DISPATCH_INPUT_TYPE
    380 
    381 }  // namespace pooling
    382 
using std::placeholders::_1;
// Register the three pooling operations. They share prepare() and the
// validate() logic (std::bind fixes the operation type for validation);
// allowZeroSizedInput lets zero-batch tensors reach the execute functions,
// which short-circuit on them.
NN_REGISTER_OPERATION(AVERAGE_POOL_2D, "AVERAGE_POOL_2D",
                      std::bind(pooling::validate, OperationType::AVERAGE_POOL_2D, _1),
                      pooling::prepare, pooling::executeAveragePool, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(L2_POOL_2D, "L2_POOL_2D",
                      std::bind(pooling::validate, OperationType::L2_POOL_2D, _1), pooling::prepare,
                      pooling::executeL2Pool, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(MAX_POOL_2D, "MAX_POOL_2D",
                      std::bind(pooling::validate, OperationType::MAX_POOL_2D, _1),
                      pooling::prepare, pooling::executeMaxPool, .allowZeroSizedInput = true);
    393 
    394 }  // namespace nn
    395 }  // namespace android
    396