/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include "CpuOperationUtils.h"
#include "HalInterfaces.h"
#include "OperationResolver.h"
#include "Tracing.h"

#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"

#include <cmath>
#include <functional>
#include <vector>

namespace android {
namespace nn {

namespace resize_image {

constexpr uint32_t kNumInputs = 4;
constexpr uint32_t kInputTensor = 0;
// The following two scalars give the output width/height when INT32, or the width/height scale
// factors when floating point.
constexpr uint32_t kOutputWidthParamScalar = 1;
constexpr uint32_t kOutputHeightParamScalar = 2;
constexpr uint32_t kLayoutScalar = 3;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;

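// Operand convention shared by RESIZE_BILINEAR and RESIZE_NEAREST_NEIGHBOR, as enforced by
// validate() and prepare() below:
//   input 0:  4-D tensor to resize (NHWC, or NCHW when the layout scalar is true)
//   input 1:  output width as INT32, or width scale factor as FLOAT16/FLOAT32
//   input 2:  output height as INT32, or height scale factor as FLOAT16/FLOAT32
//   input 3:  BOOL layout scalar (optional for RESIZE_BILINEAR); true selects NCHW
//   output 0: resized tensor with the same type as the input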
namespace {

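// Core NHWC implementation. The actual resampling is delegated to the TFLite reference kernels,
// which expect the requested output size as a separate 1-D tensor of two elements {height, width}.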
template <typename T>
bool resizeImageOpNhwc(OperationType opType, const T* inputData, const Shape& inputShape,
                       T* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("resizeImageOpNhwc");
    int32_t height = static_cast<int32_t>(getSizeOfDimension(outputShape, 1));
    int32_t width = static_cast<int32_t>(getSizeOfDimension(outputShape, 2));
    // Fake a 1-D tensor of two elements to hold the output size, as the TFLite kernels expect.
    int32_t outDimData[2] = {height, width};
    Shape outDimShape;
    outDimShape.dimensions = {2};

    if (opType == OperationType::RESIZE_BILINEAR) {
        NNTRACE_COMP_SWITCH("reference_ops::ResizeBilinear");
        tflite::reference_ops::ResizeBilinear({.align_corners = false},
                                              convertShapeToTflshape(inputShape), inputData,
                                              convertShapeToTflshape(outDimShape), outDimData,
                                              convertShapeToTflshape(outputShape), outputData);
    } else if (opType == OperationType::RESIZE_NEAREST_NEIGHBOR) {
        // align_corners = true is not supported.
        NNTRACE_COMP_SWITCH("reference_ops::ResizeNearestNeighbor");
        tflite::reference_ops::ResizeNearestNeighbor(
                {.align_corners = false}, convertShapeToTflshape(inputShape), inputData,
                convertShapeToTflshape(outDimShape), outDimData,
                convertShapeToTflshape(outputShape), outputData);
    }
    return true;
}

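// _Float16 is handled by widening the input to float, running the float implementation above, and
// narrowing the result back to _Float16.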
template <>
bool resizeImageOpNhwc<_Float16>(OperationType opType, const _Float16* inputData,
                                 const Shape& inputShape, _Float16* outputData,
                                 const Shape& outputShape) {
    NNTRACE_TRANS("resizeImageOpNhwcFloat16");
    std::vector<float> inputData_float32(getNumberOfElements(inputShape));
    convertFloat16ToFloat32(inputData, &inputData_float32);
    std::vector<float> outputData_float32(getNumberOfElements(outputShape));
    NN_RET_CHECK(resizeImageOpNhwc(opType, inputData_float32.data(), inputShape,
                                   outputData_float32.data(), outputShape));
    convertFloat32ToFloat16(outputData_float32, outputData);
    return true;
}

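// Layout-aware wrapper around resizeImageOpNhwc: the InputWithLayout / OutputWithLayout helpers
// present the data to resizeImageOpNhwc as NHWC (converting from NCHW when useNchw is true), and
// output.commit() writes the result back in the requested layout.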
template <typename T>
bool resizeImageOp(OperationType opType, const T* inputData, const Shape& inputShape, bool useNchw,
                   T* outputData, const Shape& outputShape) {
    InputWithLayout<T> input(useNchw);
    OutputWithLayout<T> output(useNchw);
    NN_RET_CHECK(input.initialize(inputData, inputShape));
    NN_RET_CHECK(output.initialize(outputData, outputShape));
    NN_RET_CHECK(resizeImageOpNhwc(opType, input.getNhwcBuffer(), input.getNhwcShape(),
                                   output.getNhwcBuffer(), output.getNhwcShape()));
    NN_RET_CHECK(output.commit());
    return true;
}

}  // namespace

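// Validation rules, matching the checks below:
//  - RESIZE_BILINEAR accepts 3 or 4 inputs (the layout scalar is the optional 4th);
//    RESIZE_NEAREST_NEIGHBOR always requires 4 inputs.
//  - The input tensor must be TENSOR_FLOAT16, TENSOR_FLOAT32, or TENSOR_QUANT8_ASYMM;
//    TENSOR_FLOAT16 and TENSOR_QUANT8_ASYMM require HAL v1.2.
//  - The size scalars must both be INT32, or both be the matching float type (FLOAT16 for
//    TENSOR_FLOAT16 inputs, FLOAT32 otherwise); non-INT32 scalars require HAL v1.2, as does the
//    presence of the layout scalar.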
bool validate(OperationType opType, const IOperationValidationContext* context) {
    if (opType == OperationType::RESIZE_BILINEAR) {
        NN_RET_CHECK(context->getNumInputs() == kNumInputs ||
                     context->getNumInputs() == kNumInputs - 1);
    } else if (opType == OperationType::RESIZE_NEAREST_NEIGHBOR) {
        NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    } else {
        NN_RET_CHECK_FAIL() << "Unsupported operation " << getOperationName(opType);
    }
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    auto inputType = context->getInputType(kInputTensor);
    auto scalarType = context->getInputType(kOutputHeightParamScalar);
    std::vector<OperandType> inExpectedTypes = {inputType, scalarType, scalarType};
    NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 ||
                 inputType == OperandType::TENSOR_FLOAT32 ||
                 inputType == OperandType::TENSOR_QUANT8_ASYMM)
            << "Unsupported tensor type for operation " << getOperationName(opType);
    if (inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_QUANT8_ASYMM) {
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
    }
    if (scalarType != OperandType::INT32) {
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
        if (inputType == OperandType::TENSOR_FLOAT32) {
            NN_RET_CHECK(scalarType == OperandType::FLOAT32);
        } else if (inputType == OperandType::TENSOR_FLOAT16) {
            NN_RET_CHECK(scalarType == OperandType::FLOAT16);
        } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
            NN_RET_CHECK(scalarType == OperandType::FLOAT32);
        }
    }
    if (context->getNumInputs() == kNumInputs) {
        inExpectedTypes.push_back(OperandType::BOOL);
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
    } else {
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0));
    }
    return validateInputTypes(context, inExpectedTypes) &&
           validateOutputTypes(context, {inputType});
}

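// Computes the output shape. INT32 size scalars are taken as the output height/width directly;
// floating-point scalars are scale factors applied to the input height/width, with the result
// rounded down.
//
// Worked example with hypothetical values: a {1, 4, 6, 3} NHWC input with FLOAT32 scalars
// width = 1.5f and height = 1.5f gives height = floor(4 * 1.5) = 6 and width = floor(6 * 1.5) = 9,
// i.e. an output shape of {1, 6, 9, 3}.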
bool prepare(OperationType opType, IOperationExecutionContext* context) {
    Shape input = context->getInputShape(kInputTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(input), 4);
    bool useNchw = false;
    if (context->getNumInputs() > kLayoutScalar) {
        useNchw = context->getInputValue<bool>(kLayoutScalar);
    }

    // Only batches can be zero.
    uint32_t batches = getSizeOfDimension(input, 0);
    uint32_t inHeight = getSizeOfDimension(input, useNchw ? 2 : 1);
    uint32_t inWidth = getSizeOfDimension(input, useNchw ? 3 : 2);
    uint32_t channels = getSizeOfDimension(input, useNchw ? 1 : 3);
    NN_RET_CHECK_GT(inHeight, 0);
    NN_RET_CHECK_GT(inWidth, 0);
    NN_RET_CHECK_GT(channels, 0);

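    // Derive the output height/width from the size scalars; floating-point scalars are treated as
    // scale factors (see the worked example above prepare()).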
    int32_t height, width;
    auto scalarType = context->getInputType(kOutputHeightParamScalar);
    if (scalarType == OperandType::INT32) {
        height = context->getInputValue<int32_t>(kOutputHeightParamScalar);
        width = context->getInputValue<int32_t>(kOutputWidthParamScalar);
    } else if (scalarType == OperandType::FLOAT32) {
        height = std::floor(static_cast<float>(inHeight) *
                            context->getInputValue<float>(kOutputHeightParamScalar));
        width = std::floor(static_cast<float>(inWidth) *
                           context->getInputValue<float>(kOutputWidthParamScalar));
    } else if (scalarType == OperandType::FLOAT16) {
        height = std::floor(
                static_cast<float>(inHeight) *
                static_cast<float>(context->getInputValue<_Float16>(kOutputHeightParamScalar)));
        width = std::floor(
                static_cast<float>(inWidth) *
                static_cast<float>(context->getInputValue<_Float16>(kOutputWidthParamScalar)));
    } else {
        NN_RET_CHECK_FAIL() << "Unsupported scalar type for operation " << getOperationName(opType);
    }
    NN_RET_CHECK_GT(height, 0);
    NN_RET_CHECK_GT(width, 0);

    Shape output = input;
    if (useNchw) {
        output.dimensions = {batches, channels, (uint32_t)height, (uint32_t)width};
    } else {
        output.dimensions = {batches, (uint32_t)height, (uint32_t)width, channels};
    }
    return context->setOutputShape(kOutputTensor, output);
}

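// Dispatches on the input tensor type. A zero-sized output (zero batches) is legal because the
// operations are registered with allowZeroSizedInput, so execution is simply skipped in that case.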
bool execute(OperationType opType, IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    bool useNchw = false;
    if (context->getNumInputs() > kLayoutScalar) {
        useNchw = context->getInputValue<bool>(kLayoutScalar);
    }
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return resizeImageOp(opType, context->getInputBuffer<_Float16>(kInputTensor),
                                 context->getInputShape(kInputTensor), useNchw,
                                 context->getOutputBuffer<_Float16>(kOutputTensor),
                                 context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return resizeImageOp(opType, context->getInputBuffer<float>(kInputTensor),
                                 context->getInputShape(kInputTensor), useNchw,
                                 context->getOutputBuffer<float>(kOutputTensor),
                                 context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return resizeImageOp(opType, context->getInputBuffer<uint8_t>(kInputTensor),
                                 context->getInputShape(kInputTensor), useNchw,
                                 context->getOutputBuffer<uint8_t>(kOutputTensor),
                                 context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation "
                                << getOperationName(opType);
    }
}

}  // namespace resize_image

using std::placeholders::_1;

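// Both operations share the implementation above; std::bind fixes the OperationType so the
// resolver sees the usual single-argument validate/prepare/execute entry points.
// allowZeroSizedInput lets inputs with a zero batch dimension reach execute(), which then returns
// early.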
NN_REGISTER_OPERATION(RESIZE_BILINEAR, "RESIZE_BILINEAR",
                      std::bind(resize_image::validate, OperationType::RESIZE_BILINEAR, _1),
                      std::bind(resize_image::prepare, OperationType::RESIZE_BILINEAR, _1),
                      std::bind(resize_image::execute, OperationType::RESIZE_BILINEAR, _1),
                      .allowZeroSizedInput = true);

NN_REGISTER_OPERATION(RESIZE_NEAREST_NEIGHBOR, "RESIZE_NEAREST_NEIGHBOR",
                      std::bind(resize_image::validate, OperationType::RESIZE_NEAREST_NEIGHBOR, _1),
                      std::bind(resize_image::prepare, OperationType::RESIZE_NEAREST_NEIGHBOR, _1),
                      std::bind(resize_image::execute, OperationType::RESIZE_NEAREST_NEIGHBOR, _1),
                      .allowZeroSizedInput = true);

}  // namespace nn
}  // namespace android