/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "CpuOperationUtils.h"
#include "OperationResolver.h"

#include "tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"

#include <mutex>
#include <vector>

#include "Tracing.h"

namespace android {
namespace nn {
namespace fully_connected {

constexpr char kOperationName[] = "FULLY_CONNECTED";

constexpr uint32_t kNumInputs = 4;
constexpr uint32_t kInputTensor = 0;
constexpr uint32_t kWeightsTensor = 1;
constexpr uint32_t kBiasTensor = 2;
constexpr uint32_t kActivationScalar = 3;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;

namespace {

// executionMutex is used to protect concurrent access to non-threadsafe resources
// like gemmlowp::GemmContext.
// std::mutex is safe for pthreads on Android.
static std::mutex executionMutex;

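// Computes output = activation(input * transpose(weights) + bias) in fp32. The
// optimized tflite kernel is used except for the b/80425683 case noted below,
// which falls back to the reference kernel.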
bool fullyConnectedFloat32(const float* inputData, const Shape& inputShape,
                           const float* weightsData, const Shape& weightsShape,
                           const float* biasData, const Shape& biasShape, int32_t activation,
                           float* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("fullyConnectedFloat32");
    float output_activation_min, output_activation_max;
    CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);

    // b/80425683: the optimized implementation produces incorrect results when the
    // number of input elements is the square of batch_size.
    uint32_t batch_size = getSizeOfDimension(outputShape, 0);
    uint32_t input_n_elements = getNumberOfElements(inputShape);
    if (batch_size * batch_size == input_n_elements) {
        NNTRACE_COMP_SWITCH("reference_ops::FullyConnected");
        tflite::reference_ops::FullyConnected(inputData, convertShapeToDims(inputShape),
                                              weightsData, convertShapeToDims(weightsShape),
                                              biasData, convertShapeToDims(biasShape),
                                              output_activation_min, output_activation_max,
                                              outputData, convertShapeToDims(outputShape));
    } else {
        NNTRACE_COMP_SWITCH("optimized_ops::FullyConnected");
        tflite::optimized_ops::FullyConnected(inputData, convertShapeToDims(inputShape),
                                              weightsData, convertShapeToDims(weightsShape),
                                              biasData, convertShapeToDims(biasShape),
                                              output_activation_min, output_activation_max,
                                              outputData, convertShapeToDims(outputShape));
    }
    return true;
}

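// The fp16 path has no dedicated kernel: the inputs are widened to fp32, the
// fp32 implementation above is run, and the result is narrowed back to fp16.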
bool fullyConnectedFloat16(const _Float16* inputData, const Shape& inputShape,
                           const _Float16* weightsData, const Shape& weightsShape,
                           const _Float16* biasData, const Shape& biasShape, int32_t activation,
                           _Float16* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("fullyConnectedFloat16");
    std::vector<float> inputDataFloat32(getNumberOfElements(inputShape));
    convertFloat16ToFloat32(inputData, &inputDataFloat32);
    std::vector<float> weightsDataFloat32(getNumberOfElements(weightsShape));
    convertFloat16ToFloat32(weightsData, &weightsDataFloat32);
    std::vector<float> biasDataFloat32(getNumberOfElements(biasShape));
    convertFloat16ToFloat32(biasData, &biasDataFloat32);

    std::vector<float> outputDataFloat32(getNumberOfElements(outputShape));
    fullyConnectedFloat32(inputDataFloat32.data(), inputShape, weightsDataFloat32.data(),
                          weightsShape, biasDataFloat32.data(), biasShape, activation,
                          outputDataFloat32.data(), outputShape);
    convertFloat32ToFloat16(outputDataFloat32, outputData);

    return true;
}

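// Quantized path: the kernel accumulates (input + inputOffset) * (weights + weightsOffset)
// in int32 (the offsets below are the negated zero points), adds the int32 bias,
// requantizes with outputMultiplier/outputShift (derived from
// inputScale * weightsScale / outputScale), adds the output zero point, and clamps the
// result to the activation range for the output's quantization parameters.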
bool fullyConnectedQuant8(const uint8_t* inputData, const Shape& inputShape,
                          const uint8_t* weightsData, const Shape& weightsShape,
                          const int32_t* biasData, const Shape& biasShape, int32_t activation,
                          uint8_t* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("fullyConnectedQuant8");
    int32_t inputOffset = -inputShape.offset;
    int32_t weightsOffset = -weightsShape.offset;
    int32_t outputOffset = outputShape.offset;

    double realMultiplier = 0.0;
    int32_t outputMultiplier = 0;
    int32_t outputShift = 0;
    int32_t outputActivationMin = 0;
    int32_t outputActivationMax = 0;

    NN_RET_CHECK(GetQuantizedConvolutionMultipler(inputShape, weightsShape, biasShape, outputShape,
                                                  &realMultiplier));
    int exponent;
    NN_RET_CHECK(QuantizeMultiplier(realMultiplier, &outputMultiplier, &exponent));
    outputShift = -exponent;
    CalculateActivationRangeUint8(activation, outputShape, &outputActivationMin,
                                  &outputActivationMax);

    static gemmlowp::GemmContext gemmContext;

    // Prevent concurrent executions that access gemmContext.
    std::unique_lock<std::mutex> lock(executionMutex);
    // Allow gemmlowp to automatically decide how many threads to use.
    gemmContext.set_max_num_threads(0);

    NNTRACE_COMP_SWITCH("optimized_ops::FullyConnected");
    tflite::optimized_ops::FullyConnected(inputData, convertShapeToDims(inputShape), inputOffset,
                                          weightsData, convertShapeToDims(weightsShape),
                                          weightsOffset, biasData, convertShapeToDims(biasShape),
                                          outputOffset, outputMultiplier, outputShift,
                                          outputActivationMin, outputActivationMax, outputData,
                                          convertShapeToDims(outputShape), &gemmContext);

    return true;
}

}  // namespace

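// Validates operand counts and types, and enforces the minimum HAL version
// required for the given input tensor type (and, for quantized tensors, for
// the output scale).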
bool validate(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    auto inputType = context->getInputType(kInputTensor);
    std::vector<OperandType> inExpectedTypes;
    std::vector<OperandType> outExpectedTypes;
    if (inputType == OperandType::TENSOR_FLOAT32) {
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0));
        inExpectedTypes = {
                OperandType::TENSOR_FLOAT32,
                OperandType::TENSOR_FLOAT32,
                OperandType::TENSOR_FLOAT32,
                OperandType::INT32,
        };
    } else if (inputType == OperandType::TENSOR_FLOAT16) {
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
        inExpectedTypes = {
                OperandType::TENSOR_FLOAT16,
                OperandType::TENSOR_FLOAT16,
                OperandType::TENSOR_FLOAT16,
                OperandType::INT32,
        };
    } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
        // NeuralNetworks.h specifies that ANEURALNETWORKS_FULLY_CONNECTED's output must
        // meet "outputScale > inputScale * weightsScale" for the operand type
        // ANEURALNETWORKS_TENSOR_QUANT8_ASYMM before API level 29.
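        // For example, with inputScale = 0.25 and weightsScale = 0.5, the product is
        // 0.125, so outputScale must be strictly greater than 0.125 for the model to
        // validate against HAL 1.0/1.1; otherwise HAL 1.2 is required.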
        const float inputScale = context->getInputShape(kInputTensor).scale;
        const float weightsScale = context->getInputShape(kWeightsTensor).scale;
        const float outputScale = context->getOutputShape(kOutputTensor).scale;
        bool meetsQuantizedScaleConstraintBeforeV1_2 = (outputScale > inputScale * weightsScale);

        if (!meetsQuantizedScaleConstraintBeforeV1_2) {
            NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
        } else {
            NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0));
        }

        inExpectedTypes = {
                OperandType::TENSOR_QUANT8_ASYMM,
                OperandType::TENSOR_QUANT8_ASYMM,
                OperandType::TENSOR_INT32,
                OperandType::INT32,
        };
    } else {
        NN_RET_CHECK_FAIL() << "Unsupported input tensor type for operation " << kOperationName;
        return false;
    }
    NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    return true;
}

bool prepare(IOperationExecutionContext* context) {
    Shape input = context->getInputShape(kInputTensor);
    Shape weights = context->getInputShape(kWeightsTensor);
    Shape bias = context->getInputShape(kBiasTensor);

    // Check that all the tensor parameters are consistent with each other and with
    // the input configuration.
    NN_RET_CHECK(input.type == weights.type);
    if (input.type == OperandType::TENSOR_QUANT8_ASYMM) {
        NN_RET_CHECK(bias.type == OperandType::TENSOR_INT32);
    } else {
        NN_RET_CHECK(input.type == bias.type);
    }
    // The TensorFlow fully connected layer specification requires the input to have
    // at least rank 2, so we enforce that here; TFLite itself does not check it.
    NN_RET_CHECK_GE(getNumberOfDimensions(input), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(weights), 2);
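    // The input is treated as a 2-D matrix [batch_size, input_size], where input_size
    // comes from the weights' shape [num_units, input_size]; the output shape is then
    // [batch_size, num_units] and the bias has shape [num_units].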
    uint32_t input_n_elements = getNumberOfElements(input);
    uint32_t num_units = getSizeOfDimension(weights, 0);
    uint32_t input_size = getSizeOfDimension(weights, 1);
    // Only batch_size can be 0.
    NN_RET_CHECK_GT(num_units, 0);
    NN_RET_CHECK_GT(input_size, 0);
    uint32_t batch_size = input_n_elements / input_size;
    NN_RET_CHECK_EQ(getSizeOfDimension(bias, 0), num_units);
    NN_RET_CHECK_EQ(input_size * batch_size, input_n_elements);

    Shape output = context->getOutputShape(kOutputTensor);
    output.type = input.type;
    output.dimensions = {batch_size, num_units};
    return context->setOutputShape(kOutputTensor, output);
}

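// Dispatches to the kernel matching the input tensor type. Zero-sized outputs are
// reported as success without running any kernel.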
bool execute(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT32:
            return fullyConnectedFloat32(context->getInputBuffer<float>(kInputTensor),
                                         context->getInputShape(kInputTensor),
                                         context->getInputBuffer<float>(kWeightsTensor),
                                         context->getInputShape(kWeightsTensor),
                                         context->getInputBuffer<float>(kBiasTensor),
                                         context->getInputShape(kBiasTensor),
                                         context->getInputValue<int32_t>(kActivationScalar),
                                         context->getOutputBuffer<float>(kOutputTensor),
                                         context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT16:
            return fullyConnectedFloat16(context->getInputBuffer<_Float16>(kInputTensor),
                                         context->getInputShape(kInputTensor),
                                         context->getInputBuffer<_Float16>(kWeightsTensor),
                                         context->getInputShape(kWeightsTensor),
                                         context->getInputBuffer<_Float16>(kBiasTensor),
                                         context->getInputShape(kBiasTensor),
                                         context->getInputValue<int32_t>(kActivationScalar),
                                         context->getOutputBuffer<_Float16>(kOutputTensor),
                                         context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return fullyConnectedQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
                                        context->getInputShape(kInputTensor),
                                        context->getInputBuffer<uint8_t>(kWeightsTensor),
                                        context->getInputShape(kWeightsTensor),
                                        context->getInputBuffer<int32_t>(kBiasTensor),
                                        context->getInputShape(kBiasTensor),
                                        context->getInputValue<int32_t>(kActivationScalar),
                                        context->getOutputBuffer<uint8_t>(kOutputTensor),
                                        context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
    }
}

}  // namespace fully_connected

NN_REGISTER_OPERATION(FULLY_CONNECTED, fully_connected::kOperationName, fully_connected::validate,
                      fully_connected::prepare, fully_connected::execute,
                      .allowZeroSizedInput = true);

}  // namespace nn
}  // namespace android