1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #define LOG_TAG "Operations" 18 19 #include "HalInterfaces.h" 20 #include "OperationResolver.h" 21 #include "OperationsUtils.h" 22 #include "Tracing.h" 23 24 namespace android { 25 namespace nn { 26 namespace channel_shuffle { 27 28 constexpr char kOperationName[] = "CHANNEL_SHUFFLE"; 29 30 constexpr uint32_t kNumInputs = 3; 31 constexpr uint32_t kInputTensor = 0; 32 constexpr uint32_t kNumGroups = 1; 33 constexpr uint32_t kInputAxis = 2; 34 35 constexpr uint32_t kNumOutputs = 1; 36 constexpr uint32_t kOutputTensor = 0; 37 38 template <typename T> 39 inline bool eval(const T* inputData, const Shape& inputShape, int32_t numGroups, int32_t axis, 40 T* outputData) { 41 const uint32_t outerSize = getNumberOfElements(inputShape, 0, axis); 42 const uint32_t axisSize = getSizeOfDimension(inputShape, axis); 43 const uint32_t innerSize = 44 getNumberOfElements(inputShape, axis + 1, getNumberOfDimensions(inputShape)); 45 const uint32_t groupSize = axisSize / numGroups; 46 for (uint32_t outer = 0; outer < outerSize; ++outer) { 47 for (uint32_t inner = 0; inner < innerSize; ++inner) { 48 const T* inputBase = inputData + outer * axisSize * innerSize + inner; 49 T* outputBase = outputData + outer * axisSize * innerSize + inner; 50 for (uint32_t i = 0; i < groupSize; i++) { 51 for (uint32_t j = 0; j < static_cast<uint32_t>(numGroups); 52 j++, outputBase += innerSize) { 53 *outputBase = inputBase[innerSize * (i + j * groupSize)]; 54 } 55 } 56 } 57 } 58 return true; 59 } 60 61 bool validate(const IOperationValidationContext* context) { 62 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); 63 NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); 64 auto inputType = context->getInputType(kInputTensor); 65 NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 || 66 inputType == OperandType::TENSOR_FLOAT32 || 67 inputType == OperandType::TENSOR_QUANT8_ASYMM) 68 << "Unsupported tensor type for operation " << kOperationName; 69 NN_RET_CHECK(validateInputTypes(context, {inputType, OperandType::INT32, OperandType::INT32})); 70 NN_RET_CHECK(validateOutputTypes(context, {inputType})); 71 return validateHalVersion(context, HalVersion::V1_2); 72 } 73 74 bool prepare(IOperationExecutionContext* context) { 75 Shape input = context->getInputShape(kInputTensor); 76 int32_t numGroups = context->getInputValue<int32_t>(kNumGroups); 77 int32_t axis = context->getInputValue<int32_t>(kInputAxis); 78 NN_RET_CHECK(handleNegativeAxis(input, &axis)); 79 NN_RET_CHECK(numGroups > 0); 80 NN_RET_CHECK(getSizeOfDimension(input, axis) % numGroups == 0); 81 return context->setOutputShape(kOutputTensor, input); 82 } 83 84 bool execute(IOperationExecutionContext* context) { 85 int32_t numGroups = context->getInputValue<int32_t>(kNumGroups); 86 int32_t axis = context->getInputValue<int32_t>(kInputAxis); 87 NN_RET_CHECK(handleNegativeAxis(context->getInputShape(kInputTensor), &axis)); 88 switch (context->getInputType(kInputTensor)) { 89 case OperandType::TENSOR_FLOAT16: 90 return eval(context->getInputBuffer<_Float16>(kInputTensor), 91 context->getInputShape(kInputTensor), numGroups, axis, 92 context->getOutputBuffer<_Float16>(kOutputTensor)); 93 case OperandType::TENSOR_FLOAT32: 94 return eval(context->getInputBuffer<float>(kInputTensor), 95 context->getInputShape(kInputTensor), numGroups, axis, 96 context->getOutputBuffer<float>(kOutputTensor)); 97 case OperandType::TENSOR_QUANT8_ASYMM: 98 return eval(context->getInputBuffer<uint8_t>(kInputTensor), 99 context->getInputShape(kInputTensor), numGroups, axis, 100 context->getOutputBuffer<uint8_t>(kOutputTensor)); 101 default: 102 NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName; 103 } 104 } 105 106 } // namespace channel_shuffle 107 108 NN_REGISTER_OPERATION(CHANNEL_SHUFFLE, channel_shuffle::kOperationName, channel_shuffle::validate, 109 channel_shuffle::prepare, channel_shuffle::execute); 110 111 } // namespace nn 112 } // namespace android 113