1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #define LOG_TAG "Operations" 18 19 #include "CpuOperationUtils.h" 20 #include "HalInterfaces.h" 21 #include "OperationResolver.h" 22 #include "Tracing.h" 23 24 #include <cmath> 25 #include <vector> 26 27 namespace android { 28 namespace nn { 29 namespace instance_normalization { 30 31 constexpr char kOperationName[] = "INSTANCE_NORMALIZATION"; 32 33 constexpr uint32_t kNumInputs = 5; 34 constexpr uint32_t kInputTensor = 0; 35 constexpr uint32_t kGammaScalar = 1; 36 constexpr uint32_t kBetaScalar = 2; 37 constexpr uint32_t kEpsilonScalar = 3; 38 constexpr uint32_t kLayoutScalar = 4; 39 40 constexpr uint32_t kNumOutputs = 1; 41 constexpr uint32_t kOutputTensor = 0; 42 43 namespace { 44 45 template <typename T> 46 inline bool instanceNormNhwc(const T* inputData, const Shape& inputShape, T gamma, T beta, 47 T epsilon, T* outputData, const Shape& outputShape) { 48 NNTRACE_TRANS("InstanceNormalizationNhwc"); 49 uint32_t numBatches = getSizeOfDimension(inputShape, 0); 50 uint32_t height = getSizeOfDimension(inputShape, 1); 51 uint32_t width = getSizeOfDimension(inputShape, 2); 52 uint32_t depth = getSizeOfDimension(inputShape, 3); 53 for (uint32_t b = 0; b < numBatches; b++) { 54 for (uint32_t d = 0; d < depth; d++) { 55 uint32_t indexBase = b * height * width * depth + d; 56 T mean = 0, var = 0; 57 for (uint32_t h = 0; h < height; h++) { 58 for (uint32_t w = 0; w < width; w++) { 59 T val = inputData[indexBase + (h * width + w) * depth]; 60 mean += val; 61 var += val * val; 62 } 63 } 64 mean /= static_cast<T>(height * width); 65 var = std::sqrt(static_cast<float>(var / static_cast<T>(height * width)) + epsilon); 66 for (uint32_t h = 0; h < height; h++) { 67 for (uint32_t w = 0; w < width; w++) { 68 uint32_t ind = indexBase + (h * width + w) * depth; 69 outputData[ind] = (inputData[ind] - mean) * gamma / var + beta; 70 } 71 } 72 } 73 } 74 return true; 75 } 76 77 template <typename T> 78 inline bool instanceNorm(const T* inputData, const Shape& inputShape, T gamma, T beta, T epsilon, 79 bool useNchw, T* outputData, const Shape& outputShape) { 80 InputWithLayout<T> input(useNchw); 81 OutputWithLayout<T> output(useNchw); 82 NN_RET_CHECK(input.initialize(inputData, inputShape)); 83 NN_RET_CHECK(output.initialize(outputData, outputShape)); 84 NN_RET_CHECK(instanceNormNhwc(input.getNhwcBuffer(), input.getNhwcShape(), gamma, beta, epsilon, 85 output.getNhwcBuffer(), output.getNhwcShape())); 86 NN_RET_CHECK(output.commit()); 87 return true; 88 } 89 90 } // namespace 91 92 bool validate(const IOperationValidationContext* context) { 93 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); 94 NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); 95 std::vector<OperandType> inExpectedTypes; 96 auto inputType = context->getInputType(kInputTensor); 97 if (inputType == OperandType::TENSOR_FLOAT32) { 98 inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::FLOAT32, OperandType::FLOAT32, 99 OperandType::FLOAT32, OperandType::BOOL}; 100 } else if (inputType == OperandType::TENSOR_FLOAT16) { 101 inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::FLOAT16, OperandType::FLOAT16, 102 OperandType::FLOAT16, OperandType::BOOL}; 103 } else { 104 LOG(ERROR) << "Unsupported input tensor type for operation " << kOperationName; 105 return false; 106 } 107 NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); 108 NN_RET_CHECK(validateOutputTypes(context, {inputType})); 109 return validateHalVersion(context, HalVersion::V1_2); 110 } 111 112 bool prepare(IOperationExecutionContext* context) { 113 Shape input = context->getInputShape(kInputTensor); 114 NN_RET_CHECK_EQ(getNumberOfDimensions(input), 4); 115 return context->setOutputShape(kOutputTensor, input); 116 } 117 118 bool execute(IOperationExecutionContext* context) { 119 switch (context->getInputType(kInputTensor)) { 120 case OperandType::TENSOR_FLOAT16: 121 return instanceNorm(context->getInputBuffer<_Float16>(kInputTensor), 122 context->getInputShape(kInputTensor), 123 context->getInputValue<_Float16>(kGammaScalar), 124 context->getInputValue<_Float16>(kBetaScalar), 125 context->getInputValue<_Float16>(kEpsilonScalar), 126 context->getInputValue<bool>(kLayoutScalar), 127 context->getOutputBuffer<_Float16>(kOutputTensor), 128 context->getOutputShape(kOutputTensor)); 129 case OperandType::TENSOR_FLOAT32: 130 return instanceNorm(context->getInputBuffer<float>(kInputTensor), 131 context->getInputShape(kInputTensor), 132 context->getInputValue<float>(kGammaScalar), 133 context->getInputValue<float>(kBetaScalar), 134 context->getInputValue<float>(kEpsilonScalar), 135 context->getInputValue<bool>(kLayoutScalar), 136 context->getOutputBuffer<float>(kOutputTensor), 137 context->getOutputShape(kOutputTensor)); 138 default: 139 NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName; 140 } 141 } 142 143 } // namespace instance_normalization 144 145 NN_REGISTER_OPERATION(INSTANCE_NORMALIZATION, instance_normalization::kOperationName, 146 instance_normalization::validate, instance_normalization::prepare, 147 instance_normalization::execute); 148 149 } // namespace nn 150 } // namespace android 151