1 /* 2 * Copyright (C) 2019 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "CpuOperationUtils.h" 18 #include "OperationResolver.h" 19 20 #include "tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h" 21 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" 22 23 #include "Tracing.h" 24 25 namespace android { 26 namespace nn { 27 namespace transpose { 28 29 constexpr char kOperationName[] = "TRANSPOSE"; 30 31 constexpr uint32_t kNumInputs = 2; 32 constexpr uint32_t kInputTensor = 0; 33 constexpr uint32_t kPermTensor = 1; 34 35 constexpr uint32_t kNumOutputs = 1; 36 constexpr uint32_t kOutputTensor = 0; 37 38 namespace { 39 40 template <typename T> 41 bool transposeGeneric(const T* inputData, const Shape& inputShape, const int32_t* perm, 42 const Shape& permShape, T* outputData, const Shape& outputShape) { 43 NNTRACE_TRANS("transposeGeneric"); 44 // Reverse the permuted axes and convert to 4D due to the way Dims are 45 // constructed. 46 const int32_t kOutputDimensionNum = 4; 47 48 // permData can be NO_VALUE representing a regular 2D matrix transpose 49 int32_t permSize = perm == nullptr ? 2 : static_cast<int32_t>(getSizeOfDimension(permShape, 0)); 50 int32_t perm_tmp[2] = {1, 0}; 51 if (perm == nullptr) { 52 perm = perm_tmp; 53 } 54 int32_t reversed_perm[kOutputDimensionNum]; 55 for (int32_t output_k = 0, input_k = permSize - 1; output_k < permSize; ++output_k, --input_k) { 56 reversed_perm[output_k] = permSize - perm[input_k] - 1; 57 } 58 for (int32_t k = permSize; k < kOutputDimensionNum; ++k) { 59 reversed_perm[k] = k; 60 } 61 NNTRACE_COMP_SWITCH("reference_ops::Transpose"); 62 tflite::reference_ops::Transpose(inputData, convertShapeToDims(inputShape), outputData, 63 convertShapeToDims(outputShape), reversed_perm); 64 return true; 65 } 66 67 } // namespace 68 69 bool validate(const IOperationValidationContext* context) { 70 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); 71 NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); 72 73 const OperandType inputType = context->getInputType(kInputTensor); 74 if (inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_QUANT8_ASYMM) { 75 NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_1)); 76 } else if (inputType == OperandType::TENSOR_FLOAT16) { 77 NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2)); 78 } else { 79 NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName; 80 } 81 return validateInputTypes(context, {inputType, OperandType::TENSOR_INT32}) && 82 validateOutputTypes(context, {inputType}); 83 } 84 85 bool prepare(IOperationExecutionContext* context) { 86 // Only the permutation tensor can be omitted. 87 NN_RET_CHECK(!context->isOmittedInput(kInputTensor)); 88 NN_RET_CHECK(!context->isOmittedOutput(kOutputTensor)); 89 90 const Shape& input = context->getInputShape(kInputTensor); 91 uint32_t numInputDims = getNumberOfDimensions(input); 92 Shape output = context->getOutputShape(kOutputTensor); 93 output.type = input.type; 94 output.offset = input.offset; 95 output.scale = input.scale; 96 97 // permData can be NO_VALUE representing a regular 2D matrix transpose 98 if (context->isOmittedInput(kPermTensor)) { 99 NN_RET_CHECK_EQ(numInputDims, 2); 100 output.dimensions = {getSizeOfDimension(input, 1), getSizeOfDimension(input, 0)}; 101 } else { 102 const Shape& permShape = context->getInputShape(kPermTensor); 103 const int32_t* permData = context->getInputBuffer<int32_t>(kPermTensor); 104 105 // Transpose op only supports 1D-4D input arrays. 106 NN_RET_CHECK_LE(numInputDims, 4); 107 108 // perm need to be provided as a 1-D int32 tensor. 109 NN_RET_CHECK(permShape.type == OperandType::TENSOR_INT32); 110 NN_RET_CHECK_EQ(getNumberOfDimensions(permShape), 1); 111 NN_RET_CHECK_EQ(numInputDims, getSizeOfDimension(permShape, 0)); 112 113 std::vector<uint32_t> outDims(numInputDims); 114 for (int32_t idx = 0; idx < static_cast<int32_t>(numInputDims); ++idx) { 115 NN_RET_CHECK(permData[idx] >= 0 && permData[idx] < static_cast<int32_t>(numInputDims)); 116 outDims[idx] = getSizeOfDimension(input, permData[idx]); 117 } 118 output.dimensions = outDims; 119 } 120 return context->setOutputShape(kOutputTensor, output); 121 } 122 123 bool execute(IOperationExecutionContext* context) { 124 // Bypass execution in the case of zero-sized input. 125 if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true; 126 127 switch (context->getInputType(kInputTensor)) { 128 case OperandType::TENSOR_FLOAT32: 129 return transposeGeneric(context->getInputBuffer<float>(kInputTensor), 130 context->getInputShape(kInputTensor), 131 context->getInputBuffer<int32_t>(kPermTensor), 132 context->getInputShape(kPermTensor), 133 context->getOutputBuffer<float>(kOutputTensor), 134 context->getOutputShape(kOutputTensor)); 135 case OperandType::TENSOR_FLOAT16: 136 return transposeGeneric(context->getInputBuffer<_Float16>(kInputTensor), 137 context->getInputShape(kInputTensor), 138 context->getInputBuffer<int32_t>(kPermTensor), 139 context->getInputShape(kPermTensor), 140 context->getOutputBuffer<_Float16>(kOutputTensor), 141 context->getOutputShape(kOutputTensor)); 142 case OperandType::TENSOR_QUANT8_ASYMM: 143 return transposeGeneric(context->getInputBuffer<uint8_t>(kInputTensor), 144 context->getInputShape(kInputTensor), 145 context->getInputBuffer<int32_t>(kPermTensor), 146 context->getInputShape(kPermTensor), 147 context->getOutputBuffer<uint8_t>(kOutputTensor), 148 context->getOutputShape(kOutputTensor)); 149 default: 150 NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName; 151 } 152 } 153 154 } // namespace transpose 155 156 NN_REGISTER_OPERATION(TRANSPOSE, transpose::kOperationName, transpose::validate, transpose::prepare, 157 transpose::execute, .allowOmittedOperand = true, .allowZeroSizedInput = true); 158 159 } // namespace nn 160 } // namespace android 161