Home | History | Annotate | Download | only in operations
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #define LOG_TAG "Operations"
     18 
     19 #include "HalInterfaces.h"
     20 #include "IndexedShapeWrapper.h"
     21 #include "OperationResolver.h"
     22 #include "OperationsUtils.h"
     23 
     24 namespace android {
     25 namespace nn {
     26 namespace comparisons {
     27 
     28 constexpr uint32_t kNumInputs = 2;
     29 constexpr uint32_t kInputTensor1 = 0;
     30 constexpr uint32_t kInputTensor2 = 1;
     31 
     32 constexpr uint32_t kNumOutputs = 1;
     33 constexpr uint32_t kOutputTensor = 0;
     34 
     35 namespace {
     36 
     37 template <typename DataType, typename ComparisonType>
     38 bool compute(const std::function<bool(ComparisonType, ComparisonType)>& func, const DataType* aData,
     39              const Shape& aShape, const DataType* bData, const Shape& bShape, bool8* outputData,
     40              const Shape& outputShape) {
     41     IndexedShapeWrapper aShapeIndexed(aShape);
     42     IndexedShapeWrapper bShapeIndexed(bShape);
     43     IndexedShapeWrapper outputShapeIndexed(outputShape);
     44     std::vector<uint32_t> curIndex(outputShape.dimensions.size(), 0);
     45     bool lastIndex = false;
     46     do {
     47         uint32_t outputFlatIndex;
     48         NN_RET_CHECK(outputShapeIndexed.indexToFlatIndex(curIndex, &outputFlatIndex));
     49         uint32_t aFlatIndex;
     50         NN_RET_CHECK(aShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &aFlatIndex));
     51         uint32_t bFlatIndex;
     52         NN_RET_CHECK(bShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &bFlatIndex));
     53 
     54         if (aShape.type == OperandType::TENSOR_QUANT8_ASYMM) {
     55             const float realA = (aData[aFlatIndex] - aShape.offset) * aShape.scale;
     56             const float realB = (bData[bFlatIndex] - bShape.offset) * bShape.scale;
     57             outputData[outputFlatIndex] = func(realA, realB);
     58         } else {
     59             outputData[outputFlatIndex] = func(aData[aFlatIndex], bData[bFlatIndex]);
     60         }
     61 
     62         NN_RET_CHECK(outputShapeIndexed.nextIndexInplace(&curIndex, &lastIndex));
     63     } while (!lastIndex);
     64     return true;
     65 }
     66 
     67 template <typename DataType, typename ComparisonType>
     68 bool executeLessTyped(IOperationExecutionContext* context) {
     69     return compute<DataType, ComparisonType>(
     70             std::less<ComparisonType>(), context->getInputBuffer<DataType>(kInputTensor1),
     71             context->getInputShape(kInputTensor1), context->getInputBuffer<DataType>(kInputTensor2),
     72             context->getInputShape(kInputTensor2), context->getOutputBuffer<bool8>(kOutputTensor),
     73             context->getOutputShape(kOutputTensor));
     74 }
     75 
     76 template <typename DataType, typename ComparisonType>
     77 bool executeLessEqualTyped(IOperationExecutionContext* context) {
     78     return compute<DataType, ComparisonType>(
     79             std::less_equal<ComparisonType>(), context->getInputBuffer<DataType>(kInputTensor1),
     80             context->getInputShape(kInputTensor1), context->getInputBuffer<DataType>(kInputTensor2),
     81             context->getInputShape(kInputTensor2), context->getOutputBuffer<bool8>(kOutputTensor),
     82             context->getOutputShape(kOutputTensor));
     83 }
     84 
     85 template <typename DataType, typename ComparisonType>
     86 bool executeEqualTyped(IOperationExecutionContext* context) {
     87     return compute<DataType, ComparisonType>(
     88             std::equal_to<ComparisonType>(), context->getInputBuffer<DataType>(kInputTensor1),
     89             context->getInputShape(kInputTensor1), context->getInputBuffer<DataType>(kInputTensor2),
     90             context->getInputShape(kInputTensor2), context->getOutputBuffer<bool8>(kOutputTensor),
     91             context->getOutputShape(kOutputTensor));
     92 }
     93 
     94 template <typename DataType, typename ComparisonType>
     95 bool executeNotEqualTyped(IOperationExecutionContext* context) {
     96     return compute<DataType, ComparisonType>(
     97             std::not_equal_to<ComparisonType>(), context->getInputBuffer<DataType>(kInputTensor1),
     98             context->getInputShape(kInputTensor1), context->getInputBuffer<DataType>(kInputTensor2),
     99             context->getInputShape(kInputTensor2), context->getOutputBuffer<bool8>(kOutputTensor),
    100             context->getOutputShape(kOutputTensor));
    101 }
    102 
    103 template <typename DataType, typename ComparisonType>
    104 bool executeGreaterEqualTyped(IOperationExecutionContext* context) {
    105     return compute<DataType, ComparisonType>(
    106             std::greater_equal<ComparisonType>(), context->getInputBuffer<DataType>(kInputTensor1),
    107             context->getInputShape(kInputTensor1), context->getInputBuffer<DataType>(kInputTensor2),
    108             context->getInputShape(kInputTensor2), context->getOutputBuffer<bool8>(kOutputTensor),
    109             context->getOutputShape(kOutputTensor));
    110 }
    111 
    112 template <typename DataType, typename ComparisonType>
    113 bool executeGreaterTyped(IOperationExecutionContext* context) {
    114     return compute<DataType, ComparisonType>(
    115             std::greater<ComparisonType>(), context->getInputBuffer<DataType>(kInputTensor1),
    116             context->getInputShape(kInputTensor1), context->getInputBuffer<DataType>(kInputTensor2),
    117             context->getInputShape(kInputTensor2), context->getOutputBuffer<bool8>(kOutputTensor),
    118             context->getOutputShape(kOutputTensor));
    119 }
    120 
    121 }  // namespace
    122 
// Shared signature validation for the comparison operations in this file.
//
// Requires exactly kNumInputs inputs and kNumOutputs outputs. The first input
// must be a BOOL8, FLOAT16, FLOAT32, INT32 or QUANT8_ASYMM tensor. Both inputs
// are checked (via validateInputTypes) against that same type, and the output
// must be TENSOR_BOOL8. Each NN_RET_CHECK bails out on the first failure.
//
// NOTE(review): the NN_RET_CHECK macros presumably log their condition text,
// so reordering or rewriting these checks changes the reported errors.
bool validate(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    // The first input's type determines the element type for the whole op.
    OperandType inputType = context->getInputType(kInputTensor1);
    NN_RET_CHECK(
            inputType == OperandType::TENSOR_BOOL8 || inputType == OperandType::TENSOR_FLOAT16 ||
            inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_INT32 ||
            inputType == OperandType::TENSOR_QUANT8_ASYMM)
            << "Unsupported input operand type for comparison op: " << toString(inputType);
    // Both inputs must share inputType; the result is always a BOOL8 tensor.
    NN_RET_CHECK(validateInputTypes(context, {inputType, inputType}));
    NN_RET_CHECK(validateOutputTypes(context, {OperandType::TENSOR_BOOL8}));
    // NOTE(review): presumably records HAL 1.2 as the required version — confirm
    // against validateHalVersion().
    return validateHalVersion(context, HalVersion::V1_2);
}
    136 
    137 bool prepare(IOperationExecutionContext* context) {
    138     Shape input1 = context->getInputShape(kInputTensor1);
    139     Shape input2 = context->getInputShape(kInputTensor2);
    140     Shape output = context->getOutputShape(kOutputTensor);
    141     NN_RET_CHECK(calculateBroadcastedShape(input1, input2, &output));
    142     return context->setOutputShape(kOutputTensor, output);
    143 }
    144 
    145 bool executeLess(IOperationExecutionContext* context) {
    146     switch (context->getInputType(kInputTensor1)) {
    147         case OperandType::TENSOR_FLOAT16:
    148             return executeLessTyped<_Float16, _Float16>(context);
    149         case OperandType::TENSOR_FLOAT32:
    150             return executeLessTyped<float, float>(context);
    151         case OperandType::TENSOR_INT32:
    152             return executeLessTyped<int32_t, int32_t>(context);
    153         case OperandType::TENSOR_QUANT8_ASYMM:
    154             return executeLessTyped<uint8_t, float>(context);
    155         case OperandType::TENSOR_BOOL8:
    156             return executeLessTyped<bool8, bool8>(context);
    157         default:
    158             NN_RET_CHECK_FAIL() << "Unsupported tensor type for comparison";
    159     }
    160 }
    161 
    162 bool executeLessEqual(IOperationExecutionContext* context) {
    163     switch (context->getInputType(kInputTensor1)) {
    164         case OperandType::TENSOR_FLOAT16:
    165             return executeLessEqualTyped<_Float16, _Float16>(context);
    166         case OperandType::TENSOR_FLOAT32:
    167             return executeLessEqualTyped<float, float>(context);
    168         case OperandType::TENSOR_INT32:
    169             return executeLessEqualTyped<int32_t, int32_t>(context);
    170         case OperandType::TENSOR_QUANT8_ASYMM:
    171             return executeLessEqualTyped<uint8_t, float>(context);
    172         case OperandType::TENSOR_BOOL8:
    173             return executeLessEqualTyped<bool8, bool8>(context);
    174         default:
    175             NN_RET_CHECK_FAIL() << "Unsupported tensor type for comparison";
    176     }
    177 }
    178 
    179 bool executeEqual(IOperationExecutionContext* context) {
    180     switch (context->getInputType(kInputTensor1)) {
    181         case OperandType::TENSOR_FLOAT16:
    182             return executeEqualTyped<_Float16, _Float16>(context);
    183         case OperandType::TENSOR_FLOAT32:
    184             return executeEqualTyped<float, float>(context);
    185         case OperandType::TENSOR_INT32:
    186             return executeEqualTyped<int32_t, int32_t>(context);
    187         case OperandType::TENSOR_QUANT8_ASYMM:
    188             return executeEqualTyped<uint8_t, float>(context);
    189         case OperandType::TENSOR_BOOL8:
    190             return executeEqualTyped<bool8, bool8>(context);
    191         default:
    192             NN_RET_CHECK_FAIL() << "Unsupported tensor type for comparison";
    193     }
    194 }
    195 
    196 bool executeNotEqual(IOperationExecutionContext* context) {
    197     switch (context->getInputType(kInputTensor1)) {
    198         case OperandType::TENSOR_FLOAT16:
    199             return executeNotEqualTyped<_Float16, _Float16>(context);
    200         case OperandType::TENSOR_FLOAT32:
    201             return executeNotEqualTyped<float, float>(context);
    202         case OperandType::TENSOR_INT32:
    203             return executeNotEqualTyped<int32_t, int32_t>(context);
    204         case OperandType::TENSOR_QUANT8_ASYMM:
    205             return executeNotEqualTyped<uint8_t, float>(context);
    206         case OperandType::TENSOR_BOOL8:
    207             return executeNotEqualTyped<bool8, bool8>(context);
    208         default:
    209             NN_RET_CHECK_FAIL() << "Unsupported tensor type for comparison";
    210     }
    211 }
    212 
    213 bool executeGreaterEqual(IOperationExecutionContext* context) {
    214     switch (context->getInputType(kInputTensor1)) {
    215         case OperandType::TENSOR_FLOAT16:
    216             return executeGreaterEqualTyped<_Float16, _Float16>(context);
    217         case OperandType::TENSOR_FLOAT32:
    218             return executeGreaterEqualTyped<float, float>(context);
    219         case OperandType::TENSOR_INT32:
    220             return executeGreaterEqualTyped<int32_t, int32_t>(context);
    221         case OperandType::TENSOR_QUANT8_ASYMM:
    222             return executeGreaterEqualTyped<uint8_t, float>(context);
    223         case OperandType::TENSOR_BOOL8:
    224             return executeGreaterEqualTyped<bool8, bool8>(context);
    225         default:
    226             NN_RET_CHECK_FAIL() << "Unsupported tensor type for comparison";
    227     }
    228 }
    229 
    230 bool executeGreater(IOperationExecutionContext* context) {
    231     switch (context->getInputType(kInputTensor1)) {
    232         case OperandType::TENSOR_FLOAT16:
    233             return executeGreaterTyped<_Float16, _Float16>(context);
    234         case OperandType::TENSOR_FLOAT32:
    235             return executeGreaterTyped<float, float>(context);
    236         case OperandType::TENSOR_INT32:
    237             return executeGreaterTyped<int32_t, int32_t>(context);
    238         case OperandType::TENSOR_QUANT8_ASYMM:
    239             return executeGreaterTyped<uint8_t, float>(context);
    240         case OperandType::TENSOR_BOOL8:
    241             return executeGreaterTyped<bool8, bool8>(context);
    242         default:
    243             NN_RET_CHECK_FAIL() << "Unsupported tensor type for comparison";
    244     }
    245 }
    246 
    247 }  // namespace comparisons
    248 
// Registers the six comparison operations. All of them share
// comparisons::validate and comparisons::prepare; they differ only in the
// execute function, which selects the element-wise comparator.
NN_REGISTER_OPERATION(LESS, "LESS", comparisons::validate, comparisons::prepare,
                      comparisons::executeLess);
NN_REGISTER_OPERATION(LESS_EQUAL, "LESS_EQUAL", comparisons::validate, comparisons::prepare,
                      comparisons::executeLessEqual);
NN_REGISTER_OPERATION(EQUAL, "EQUAL", comparisons::validate, comparisons::prepare,
                      comparisons::executeEqual);
NN_REGISTER_OPERATION(NOT_EQUAL, "NOT_EQUAL", comparisons::validate, comparisons::prepare,
                      comparisons::executeNotEqual);
NN_REGISTER_OPERATION(GREATER_EQUAL, "GREATER_EQUAL", comparisons::validate, comparisons::prepare,
                      comparisons::executeGreaterEqual);
NN_REGISTER_OPERATION(GREATER, "GREATER", comparisons::validate, comparisons::prepare,
                      comparisons::executeGreater);
    261 
    262 }  // namespace nn
    263 }  // namespace android
    264