1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #define LOG_TAG "Operations" 18 19 #include "HalInterfaces.h" 20 #include "IndexedShapeWrapper.h" 21 #include "OperationResolver.h" 22 #include "OperationsUtils.h" 23 24 namespace android { 25 namespace nn { 26 namespace select_op { 27 28 constexpr uint32_t kNumInputs = 3; 29 constexpr uint32_t kInputCondition = 0; 30 constexpr uint32_t kInputTensor1 = 1; 31 constexpr uint32_t kInputTensor2 = 2; 32 33 constexpr uint32_t kNumOutputs = 1; 34 constexpr uint32_t kOutputTensor = 0; 35 36 namespace { 37 38 template <typename T> 39 bool compute(const bool8* conditionData, const Shape& conditionShape, const T* aData, 40 const Shape& aShape, const T* bData, const Shape& bShape, T* outputData, 41 const Shape& outputShape) { 42 // The code assumes that condition has the same shape as all other tensors. 43 // This should be checked during preparation stage. 44 uint32_t size = getNumberOfElements(conditionShape); 45 for (uint32_t i = 0; i < size; ++i) { 46 T a = aData[i]; 47 T b = bData[i]; 48 if (aShape.type == OperandType::TENSOR_QUANT8_ASYMM) { 49 a = requantize(a, aShape, outputShape); 50 b = requantize(b, bShape, outputShape); 51 } 52 outputData[i] = conditionData[i] ? a : b; 53 } 54 return true; 55 } 56 57 template <typename T> 58 bool executeTyped(IOperationExecutionContext* context) { 59 return compute<T>( 60 context->getInputBuffer<bool8>(kInputCondition), 61 context->getInputShape(kInputCondition), context->getInputBuffer<T>(kInputTensor1), 62 context->getInputShape(kInputTensor1), context->getInputBuffer<T>(kInputTensor2), 63 context->getInputShape(kInputTensor2), context->getOutputBuffer<T>(kOutputTensor), 64 context->getOutputShape(kOutputTensor)); 65 } 66 67 } // namespace 68 69 bool validate(const IOperationValidationContext* context) { 70 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); 71 NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); 72 OperandType inputType = context->getInputType(kInputTensor1); 73 NN_RET_CHECK( 74 inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_FLOAT32 || 75 inputType == OperandType::TENSOR_INT32 || inputType == OperandType::TENSOR_QUANT8_ASYMM) 76 << "Unsupported input operand type for select op: " << toString(inputType); 77 NN_RET_CHECK(validateInputTypes(context, {OperandType::TENSOR_BOOL8, inputType, inputType})); 78 NN_RET_CHECK(validateOutputTypes(context, {inputType})); 79 return validateHalVersion(context, HalVersion::V1_2); 80 } 81 82 bool prepare(IOperationExecutionContext* context) { 83 Shape inputCondition = context->getInputShape(kInputCondition); 84 Shape input1 = context->getInputShape(kInputTensor1); 85 if (inputCondition.dimensions.size() != input1.dimensions.size()) { 86 LOG(ERROR) << "Condition and input tensor dimensions are not equal"; 87 return false; 88 } 89 for (int i = 0; i < inputCondition.dimensions.size(); ++i) { 90 if (inputCondition.dimensions[i] != input1.dimensions[i]) { 91 LOG(ERROR) << "Condition and input tensor dimensions are not equal"; 92 return false; 93 } 94 } 95 96 Shape input2 = context->getInputShape(kInputTensor2); 97 NN_RET_CHECK(SameShape(input1, input2)); 98 99 Shape output = context->getOutputShape(kOutputTensor); 100 NN_RET_CHECK(SetShape(input1, &output)); 101 return context->setOutputShape(kOutputTensor, output); 102 } 103 104 bool execute(IOperationExecutionContext* context) { 105 switch (context->getInputType(kInputTensor1)) { 106 case OperandType::TENSOR_FLOAT16: 107 return executeTyped<_Float16>(context); 108 case OperandType::TENSOR_FLOAT32: 109 return executeTyped<float>(context); 110 case OperandType::TENSOR_INT32: 111 return executeTyped<int32_t>(context); 112 case OperandType::TENSOR_QUANT8_ASYMM: 113 return executeTyped<uint8_t>(context); 114 default: 115 NN_RET_CHECK_FAIL() << "Unsupported tensor type for SELECT op."; 116 } 117 } 118 119 } // namespace select_op 120 121 NN_REGISTER_OPERATION(SELECT, "SELECT", select_op::validate, select_op::prepare, 122 select_op::execute); 123 124 } // namespace nn 125 } // namespace android 126