1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #define LOG_TAG "Operations" 18 19 #include "HalInterfaces.h" 20 #include "IndexedShapeWrapper.h" 21 #include "OperationResolver.h" 22 #include "OperationsUtils.h" 23 24 namespace android { 25 namespace nn { 26 namespace comparisons { 27 28 constexpr uint32_t kNumInputs = 2; 29 constexpr uint32_t kInputTensor1 = 0; 30 constexpr uint32_t kInputTensor2 = 1; 31 32 constexpr uint32_t kNumOutputs = 1; 33 constexpr uint32_t kOutputTensor = 0; 34 35 namespace { 36 37 template <typename DataType, typename ComparisonType> 38 bool compute(const std::function<bool(ComparisonType, ComparisonType)>& func, const DataType* aData, 39 const Shape& aShape, const DataType* bData, const Shape& bShape, bool8* outputData, 40 const Shape& outputShape) { 41 IndexedShapeWrapper aShapeIndexed(aShape); 42 IndexedShapeWrapper bShapeIndexed(bShape); 43 IndexedShapeWrapper outputShapeIndexed(outputShape); 44 std::vector<uint32_t> curIndex(outputShape.dimensions.size(), 0); 45 bool lastIndex = false; 46 do { 47 uint32_t outputFlatIndex; 48 NN_RET_CHECK(outputShapeIndexed.indexToFlatIndex(curIndex, &outputFlatIndex)); 49 uint32_t aFlatIndex; 50 NN_RET_CHECK(aShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &aFlatIndex)); 51 uint32_t bFlatIndex; 52 NN_RET_CHECK(bShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &bFlatIndex)); 53 54 if (aShape.type == OperandType::TENSOR_QUANT8_ASYMM) { 55 const float realA = (aData[aFlatIndex] - aShape.offset) * aShape.scale; 56 const float realB = (bData[bFlatIndex] - bShape.offset) * bShape.scale; 57 outputData[outputFlatIndex] = func(realA, realB); 58 } else { 59 outputData[outputFlatIndex] = func(aData[aFlatIndex], bData[bFlatIndex]); 60 } 61 62 NN_RET_CHECK(outputShapeIndexed.nextIndexInplace(&curIndex, &lastIndex)); 63 } while (!lastIndex); 64 return true; 65 } 66 67 template <typename DataType, typename ComparisonType> 68 bool executeLessTyped(IOperationExecutionContext* context) { 69 return compute<DataType, ComparisonType>( 70 std::less<ComparisonType>(), context->getInputBuffer<DataType>(kInputTensor1), 71 context->getInputShape(kInputTensor1), context->getInputBuffer<DataType>(kInputTensor2), 72 context->getInputShape(kInputTensor2), context->getOutputBuffer<bool8>(kOutputTensor), 73 context->getOutputShape(kOutputTensor)); 74 } 75 76 template <typename DataType, typename ComparisonType> 77 bool executeLessEqualTyped(IOperationExecutionContext* context) { 78 return compute<DataType, ComparisonType>( 79 std::less_equal<ComparisonType>(), context->getInputBuffer<DataType>(kInputTensor1), 80 context->getInputShape(kInputTensor1), context->getInputBuffer<DataType>(kInputTensor2), 81 context->getInputShape(kInputTensor2), context->getOutputBuffer<bool8>(kOutputTensor), 82 context->getOutputShape(kOutputTensor)); 83 } 84 85 template <typename DataType, typename ComparisonType> 86 bool executeEqualTyped(IOperationExecutionContext* context) { 87 return compute<DataType, ComparisonType>( 88 std::equal_to<ComparisonType>(), context->getInputBuffer<DataType>(kInputTensor1), 89 context->getInputShape(kInputTensor1), context->getInputBuffer<DataType>(kInputTensor2), 90 context->getInputShape(kInputTensor2), context->getOutputBuffer<bool8>(kOutputTensor), 91 context->getOutputShape(kOutputTensor)); 92 } 93 94 template <typename DataType, typename ComparisonType> 95 bool executeNotEqualTyped(IOperationExecutionContext* context) { 96 return compute<DataType, ComparisonType>( 97 std::not_equal_to<ComparisonType>(), context->getInputBuffer<DataType>(kInputTensor1), 98 context->getInputShape(kInputTensor1), context->getInputBuffer<DataType>(kInputTensor2), 99 context->getInputShape(kInputTensor2), context->getOutputBuffer<bool8>(kOutputTensor), 100 context->getOutputShape(kOutputTensor)); 101 } 102 103 template <typename DataType, typename ComparisonType> 104 bool executeGreaterEqualTyped(IOperationExecutionContext* context) { 105 return compute<DataType, ComparisonType>( 106 std::greater_equal<ComparisonType>(), context->getInputBuffer<DataType>(kInputTensor1), 107 context->getInputShape(kInputTensor1), context->getInputBuffer<DataType>(kInputTensor2), 108 context->getInputShape(kInputTensor2), context->getOutputBuffer<bool8>(kOutputTensor), 109 context->getOutputShape(kOutputTensor)); 110 } 111 112 template <typename DataType, typename ComparisonType> 113 bool executeGreaterTyped(IOperationExecutionContext* context) { 114 return compute<DataType, ComparisonType>( 115 std::greater<ComparisonType>(), context->getInputBuffer<DataType>(kInputTensor1), 116 context->getInputShape(kInputTensor1), context->getInputBuffer<DataType>(kInputTensor2), 117 context->getInputShape(kInputTensor2), context->getOutputBuffer<bool8>(kOutputTensor), 118 context->getOutputShape(kOutputTensor)); 119 } 120 121 } // namespace 122 123 bool validate(const IOperationValidationContext* context) { 124 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); 125 NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); 126 OperandType inputType = context->getInputType(kInputTensor1); 127 NN_RET_CHECK( 128 inputType == OperandType::TENSOR_BOOL8 || inputType == OperandType::TENSOR_FLOAT16 || 129 inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_INT32 || 130 inputType == OperandType::TENSOR_QUANT8_ASYMM) 131 << "Unsupported input operand type for comparison op: " << toString(inputType); 132 NN_RET_CHECK(validateInputTypes(context, {inputType, inputType})); 133 NN_RET_CHECK(validateOutputTypes(context, {OperandType::TENSOR_BOOL8})); 134 return validateHalVersion(context, HalVersion::V1_2); 135 } 136 137 bool prepare(IOperationExecutionContext* context) { 138 Shape input1 = context->getInputShape(kInputTensor1); 139 Shape input2 = context->getInputShape(kInputTensor2); 140 Shape output = context->getOutputShape(kOutputTensor); 141 NN_RET_CHECK(calculateBroadcastedShape(input1, input2, &output)); 142 return context->setOutputShape(kOutputTensor, output); 143 } 144 145 bool executeLess(IOperationExecutionContext* context) { 146 switch (context->getInputType(kInputTensor1)) { 147 case OperandType::TENSOR_FLOAT16: 148 return executeLessTyped<_Float16, _Float16>(context); 149 case OperandType::TENSOR_FLOAT32: 150 return executeLessTyped<float, float>(context); 151 case OperandType::TENSOR_INT32: 152 return executeLessTyped<int32_t, int32_t>(context); 153 case OperandType::TENSOR_QUANT8_ASYMM: 154 return executeLessTyped<uint8_t, float>(context); 155 case OperandType::TENSOR_BOOL8: 156 return executeLessTyped<bool8, bool8>(context); 157 default: 158 NN_RET_CHECK_FAIL() << "Unsupported tensor type for comparison"; 159 } 160 } 161 162 bool executeLessEqual(IOperationExecutionContext* context) { 163 switch (context->getInputType(kInputTensor1)) { 164 case OperandType::TENSOR_FLOAT16: 165 return executeLessEqualTyped<_Float16, _Float16>(context); 166 case OperandType::TENSOR_FLOAT32: 167 return executeLessEqualTyped<float, float>(context); 168 case OperandType::TENSOR_INT32: 169 return executeLessEqualTyped<int32_t, int32_t>(context); 170 case OperandType::TENSOR_QUANT8_ASYMM: 171 return executeLessEqualTyped<uint8_t, float>(context); 172 case OperandType::TENSOR_BOOL8: 173 return executeLessEqualTyped<bool8, bool8>(context); 174 default: 175 NN_RET_CHECK_FAIL() << "Unsupported tensor type for comparison"; 176 } 177 } 178 179 bool executeEqual(IOperationExecutionContext* context) { 180 switch (context->getInputType(kInputTensor1)) { 181 case OperandType::TENSOR_FLOAT16: 182 return executeEqualTyped<_Float16, _Float16>(context); 183 case OperandType::TENSOR_FLOAT32: 184 return executeEqualTyped<float, float>(context); 185 case OperandType::TENSOR_INT32: 186 return executeEqualTyped<int32_t, int32_t>(context); 187 case OperandType::TENSOR_QUANT8_ASYMM: 188 return executeEqualTyped<uint8_t, float>(context); 189 case OperandType::TENSOR_BOOL8: 190 return executeEqualTyped<bool8, bool8>(context); 191 default: 192 NN_RET_CHECK_FAIL() << "Unsupported tensor type for comparison"; 193 } 194 } 195 196 bool executeNotEqual(IOperationExecutionContext* context) { 197 switch (context->getInputType(kInputTensor1)) { 198 case OperandType::TENSOR_FLOAT16: 199 return executeNotEqualTyped<_Float16, _Float16>(context); 200 case OperandType::TENSOR_FLOAT32: 201 return executeNotEqualTyped<float, float>(context); 202 case OperandType::TENSOR_INT32: 203 return executeNotEqualTyped<int32_t, int32_t>(context); 204 case OperandType::TENSOR_QUANT8_ASYMM: 205 return executeNotEqualTyped<uint8_t, float>(context); 206 case OperandType::TENSOR_BOOL8: 207 return executeNotEqualTyped<bool8, bool8>(context); 208 default: 209 NN_RET_CHECK_FAIL() << "Unsupported tensor type for comparison"; 210 } 211 } 212 213 bool executeGreaterEqual(IOperationExecutionContext* context) { 214 switch (context->getInputType(kInputTensor1)) { 215 case OperandType::TENSOR_FLOAT16: 216 return executeGreaterEqualTyped<_Float16, _Float16>(context); 217 case OperandType::TENSOR_FLOAT32: 218 return executeGreaterEqualTyped<float, float>(context); 219 case OperandType::TENSOR_INT32: 220 return executeGreaterEqualTyped<int32_t, int32_t>(context); 221 case OperandType::TENSOR_QUANT8_ASYMM: 222 return executeGreaterEqualTyped<uint8_t, float>(context); 223 case OperandType::TENSOR_BOOL8: 224 return executeGreaterEqualTyped<bool8, bool8>(context); 225 default: 226 NN_RET_CHECK_FAIL() << "Unsupported tensor type for comparison"; 227 } 228 } 229 230 bool executeGreater(IOperationExecutionContext* context) { 231 switch (context->getInputType(kInputTensor1)) { 232 case OperandType::TENSOR_FLOAT16: 233 return executeGreaterTyped<_Float16, _Float16>(context); 234 case OperandType::TENSOR_FLOAT32: 235 return executeGreaterTyped<float, float>(context); 236 case OperandType::TENSOR_INT32: 237 return executeGreaterTyped<int32_t, int32_t>(context); 238 case OperandType::TENSOR_QUANT8_ASYMM: 239 return executeGreaterTyped<uint8_t, float>(context); 240 case OperandType::TENSOR_BOOL8: 241 return executeGreaterTyped<bool8, bool8>(context); 242 default: 243 NN_RET_CHECK_FAIL() << "Unsupported tensor type for comparison"; 244 } 245 } 246 247 } // namespace comparisons 248 249 NN_REGISTER_OPERATION(LESS, "LESS", comparisons::validate, comparisons::prepare, 250 comparisons::executeLess); 251 NN_REGISTER_OPERATION(LESS_EQUAL, "LESS_EQUAL", comparisons::validate, comparisons::prepare, 252 comparisons::executeLessEqual); 253 NN_REGISTER_OPERATION(EQUAL, "EQUAL", comparisons::validate, comparisons::prepare, 254 comparisons::executeEqual); 255 NN_REGISTER_OPERATION(NOT_EQUAL, "NOT_EQUAL", comparisons::validate, comparisons::prepare, 256 comparisons::executeNotEqual); 257 NN_REGISTER_OPERATION(GREATER_EQUAL, "GREATER_EQUAL", comparisons::validate, comparisons::prepare, 258 comparisons::executeGreaterEqual); 259 NN_REGISTER_OPERATION(GREATER, "GREATER", comparisons::validate, comparisons::prepare, 260 comparisons::executeGreater); 261 262 } // namespace nn 263 } // namespace android 264