Home | History | Annotate | Download | only in operations
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #define LOG_TAG "Operations"
     18 
     19 #include "CpuOperationUtils.h"
     20 #include "HalInterfaces.h"
     21 #include "OperationResolver.h"
     22 #include "Tracing.h"
     23 
     24 #include <cmath>
     25 #include <vector>
     26 
     27 namespace android {
     28 namespace nn {
     29 namespace instance_normalization {
     30 
     31 constexpr char kOperationName[] = "INSTANCE_NORMALIZATION";
     32 
     33 constexpr uint32_t kNumInputs = 5;
     34 constexpr uint32_t kInputTensor = 0;
     35 constexpr uint32_t kGammaScalar = 1;
     36 constexpr uint32_t kBetaScalar = 2;
     37 constexpr uint32_t kEpsilonScalar = 3;
     38 constexpr uint32_t kLayoutScalar = 4;
     39 
     40 constexpr uint32_t kNumOutputs = 1;
     41 constexpr uint32_t kOutputTensor = 0;
     42 
     43 namespace {
     44 
     45 template <typename T>
     46 inline bool instanceNormNhwc(const T* inputData, const Shape& inputShape, T gamma, T beta,
     47                              T epsilon, T* outputData, const Shape& outputShape) {
     48     NNTRACE_TRANS("InstanceNormalizationNhwc");
     49     uint32_t numBatches = getSizeOfDimension(inputShape, 0);
     50     uint32_t height = getSizeOfDimension(inputShape, 1);
     51     uint32_t width = getSizeOfDimension(inputShape, 2);
     52     uint32_t depth = getSizeOfDimension(inputShape, 3);
     53     for (uint32_t b = 0; b < numBatches; b++) {
     54         for (uint32_t d = 0; d < depth; d++) {
     55             uint32_t indexBase = b * height * width * depth + d;
     56             T mean = 0, var = 0;
     57             for (uint32_t h = 0; h < height; h++) {
     58                 for (uint32_t w = 0; w < width; w++) {
     59                     T val = inputData[indexBase + (h * width + w) * depth];
     60                     mean += val;
     61                     var += val * val;
     62                 }
     63             }
     64             mean /= static_cast<T>(height * width);
     65             var = std::sqrt(static_cast<float>(var / static_cast<T>(height * width)) + epsilon);
     66             for (uint32_t h = 0; h < height; h++) {
     67                 for (uint32_t w = 0; w < width; w++) {
     68                     uint32_t ind = indexBase + (h * width + w) * depth;
     69                     outputData[ind] = (inputData[ind] - mean) * gamma / var + beta;
     70                 }
     71             }
     72         }
     73     }
     74     return true;
     75 }
     76 
// Layout-dispatching entry point for instance normalization.
//
// Wraps the caller's buffers in InputWithLayout / OutputWithLayout and runs the
// NHWC kernel on the NHWC views they expose. When `useNchw` is true, the
// wrappers presumably transpose NCHW <-> NHWC (input on initialize, output on
// commit). NOTE(review): confirm against the InputWithLayout/OutputWithLayout
// implementation.
//
// @param inputData   Caller's input buffer in the layout selected by useNchw.
// @param inputShape  Shape of inputData.
// @param gamma       Scale applied to the normalized value.
// @param beta        Offset added after scaling.
// @param epsilon     Added to the variance before the square root.
// @param useNchw     true: buffers are NCHW; false: buffers are NHWC.
// @param outputData  Caller's output buffer in the same layout as the input.
// @param outputShape Shape of outputData.
// @return false if any wrapper step or the kernel fails (each failure is logged
//         by NN_RET_CHECK); true otherwise.
template <typename T>
inline bool instanceNorm(const T* inputData, const Shape& inputShape, T gamma, T beta, T epsilon,
                         bool useNchw, T* outputData, const Shape& outputShape) {
    InputWithLayout<T> input(useNchw);
    OutputWithLayout<T> output(useNchw);
    NN_RET_CHECK(input.initialize(inputData, inputShape));
    NN_RET_CHECK(output.initialize(outputData, outputShape));
    // The kernel only understands NHWC, so it always receives the wrappers' NHWC views.
    NN_RET_CHECK(instanceNormNhwc(input.getNhwcBuffer(), input.getNhwcShape(), gamma, beta, epsilon,
                                  output.getNhwcBuffer(), output.getNhwcShape()));
    // Publish the NHWC result into the caller's output buffer/layout.
    NN_RET_CHECK(output.commit());
    return true;
}
     89 
     90 }  // namespace
     91 
     92 bool validate(const IOperationValidationContext* context) {
     93     NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
     94     NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
     95     std::vector<OperandType> inExpectedTypes;
     96     auto inputType = context->getInputType(kInputTensor);
     97     if (inputType == OperandType::TENSOR_FLOAT32) {
     98         inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::FLOAT32, OperandType::FLOAT32,
     99                            OperandType::FLOAT32, OperandType::BOOL};
    100     } else if (inputType == OperandType::TENSOR_FLOAT16) {
    101         inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::FLOAT16, OperandType::FLOAT16,
    102                            OperandType::FLOAT16, OperandType::BOOL};
    103     } else {
    104         LOG(ERROR) << "Unsupported input tensor type for operation " << kOperationName;
    105         return false;
    106     }
    107     NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
    108     NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    109     return validateHalVersion(context, HalVersion::V1_2);
    110 }
    111 
// Shape inference for INSTANCE_NORMALIZATION.
//
// Requires the input tensor to be rank 4 (the layout, NHWC or NCHW, is chosen
// at execute time by the BOOL layout operand). The output is given exactly the
// input's Shape, unchanged.
//
// @return false (with a log) if the input rank is not 4; otherwise the result
//         of setting the output shape.
bool prepare(IOperationExecutionContext* context) {
    Shape input = context->getInputShape(kInputTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(input), 4);
    return context->setOutputShape(kOutputTensor, input);
}
    117 
    118 bool execute(IOperationExecutionContext* context) {
    119     switch (context->getInputType(kInputTensor)) {
    120         case OperandType::TENSOR_FLOAT16:
    121             return instanceNorm(context->getInputBuffer<_Float16>(kInputTensor),
    122                                 context->getInputShape(kInputTensor),
    123                                 context->getInputValue<_Float16>(kGammaScalar),
    124                                 context->getInputValue<_Float16>(kBetaScalar),
    125                                 context->getInputValue<_Float16>(kEpsilonScalar),
    126                                 context->getInputValue<bool>(kLayoutScalar),
    127                                 context->getOutputBuffer<_Float16>(kOutputTensor),
    128                                 context->getOutputShape(kOutputTensor));
    129         case OperandType::TENSOR_FLOAT32:
    130             return instanceNorm(context->getInputBuffer<float>(kInputTensor),
    131                                 context->getInputShape(kInputTensor),
    132                                 context->getInputValue<float>(kGammaScalar),
    133                                 context->getInputValue<float>(kBetaScalar),
    134                                 context->getInputValue<float>(kEpsilonScalar),
    135                                 context->getInputValue<bool>(kLayoutScalar),
    136                                 context->getOutputBuffer<float>(kOutputTensor),
    137                                 context->getOutputShape(kOutputTensor));
    138         default:
    139             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
    140     }
    141 }
    142 
    143 }  // namespace instance_normalization
    144 
// Registers the INSTANCE_NORMALIZATION operation under kOperationName, wiring
// its validate / prepare / execute entry points from the
// instance_normalization namespace. NOTE(review): presumably consumed by the
// resolver declared in OperationResolver.h; confirm against that header.
NN_REGISTER_OPERATION(INSTANCE_NORMALIZATION, instance_normalization::kOperationName,
                      instance_normalization::validate, instance_normalization::prepare,
                      instance_normalization::execute);
    148 
    149 }  // namespace nn
    150 }  // namespace android
    151