/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "CpuOperationUtils.h"
#include "OperationResolver.h"

#include "tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"

#include "Tracing.h"

namespace android {
namespace nn {
namespace fully_connected {

constexpr char kOperationName[] = "FULLY_CONNECTED";

constexpr uint32_t kNumInputs = 4;
constexpr uint32_t kInputTensor = 0;
constexpr uint32_t kWeightsTensor = 1;
constexpr uint32_t kBiasTensor = 2;
constexpr uint32_t kActivationScalar = 3;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;

namespace {

// executionMutex is used to protect concurrent access of non-threadsafe resources
// like gemmlowp::GemmContext.
// std::mutex is safe for pthreads on Android.
static std::mutex executionMutex;

bool fullyConnectedFloat32(const float* inputData, const Shape& inputShape,
                           const float* weightsData, const Shape& weightsShape,
                           const float* biasData, const Shape& biasShape, int32_t activation,
                           float* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("fullyConnectedFloat32");
    float output_activation_min, output_activation_max;
    CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);

    // b/80425683: the optimized implementation produces incorrect results when the
    // number of input elements is the square of batch_size.
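    // For example, an input of shape {4, 4} has 16 elements and a batch_size of 4;
    // since 4 * 4 == 16, the reference implementation below is used instead of the
    // optimized one.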
    uint32_t batch_size = getSizeOfDimension(outputShape, 0);
    uint32_t input_n_elements = getNumberOfElements(inputShape);
    if (batch_size * batch_size == input_n_elements) {
        NNTRACE_COMP_SWITCH("reference_ops::FullyConnected");
        tflite::reference_ops::FullyConnected(inputData, convertShapeToDims(inputShape),
                                              weightsData, convertShapeToDims(weightsShape),
                                              biasData, convertShapeToDims(biasShape),
                                              output_activation_min, output_activation_max,
                                              outputData, convertShapeToDims(outputShape));
    } else {
        NNTRACE_COMP_SWITCH("optimized_ops::FullyConnected");
        tflite::optimized_ops::FullyConnected(inputData, convertShapeToDims(inputShape),
                                              weightsData, convertShapeToDims(weightsShape),
                                              biasData, convertShapeToDims(biasShape),
                                              output_activation_min, output_activation_max,
                                              outputData, convertShapeToDims(outputShape));
    }
    return true;
}

bool fullyConnectedFloat16(const _Float16* inputData, const Shape& inputShape,
                           const _Float16* weightsData, const Shape& weightsShape,
                           const _Float16* biasData, const Shape& biasShape, int32_t activation,
                           _Float16* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("fullyConnectedFloat16");
    std::vector<float> inputDataFloat32(getNumberOfElements(inputShape));
    convertFloat16ToFloat32(inputData, &inputDataFloat32);
    std::vector<float> weightsDataFloat32(getNumberOfElements(weightsShape));
    convertFloat16ToFloat32(weightsData, &weightsDataFloat32);
    std::vector<float> biasDataFloat32(getNumberOfElements(biasShape));
    convertFloat16ToFloat32(biasData, &biasDataFloat32);

    std::vector<float> outputDataFloat32(getNumberOfElements(outputShape));
    fullyConnectedFloat32(inputDataFloat32.data(), inputShape, weightsDataFloat32.data(),
                          weightsShape, biasDataFloat32.data(), biasShape, activation,
                          outputDataFloat32.data(), outputShape);
    convertFloat32ToFloat16(outputDataFloat32, outputData);

    return true;
}

bool fullyConnectedQuant8(const uint8_t* inputData, const Shape& inputShape,
                          const uint8_t* weightsData, const Shape& weightsShape,
                          const int32_t* biasData, const Shape& biasShape, int32_t activation,
                          uint8_t* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("fullyConnectedQuant8");
    int32_t inputOffset = -inputShape.offset;
    int32_t weightsOffset = -weightsShape.offset;
    int32_t outputOffset = outputShape.offset;

    double realMultiplier = 0.0;
    int32_t outputMultiplier = 0;
    int32_t outputShift = 0;
    int32_t outputActivationMin = 0;
    int32_t outputActivationMax = 0;

    NN_RET_CHECK(GetQuantizedConvolutionMultipler(inputShape, weightsShape, biasShape, outputShape,
                                                  &realMultiplier));
    int exponent;
    NN_RET_CHECK(QuantizeMultiplier(realMultiplier, &outputMultiplier, &exponent));
    outputShift = -exponent;
    CalculateActivationRangeUint8(activation, outputShape, &outputActivationMin,
                                  &outputActivationMax);

    static gemmlowp::GemmContext gemmContext;

    // Prevent concurrent executions that access gemmContext.
    std::unique_lock<std::mutex> lock(executionMutex);
    // Allow gemmlowp to decide automatically how many threads to use.
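    // (Passing 0 as the maximum leaves the choice of worker-thread count to gemmlowp's
    // own heuristics, typically based on the available cores.)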
    gemmContext.set_max_num_threads(0);

    NNTRACE_COMP_SWITCH("optimized_ops::FullyConnected");
    tflite::optimized_ops::FullyConnected(inputData, convertShapeToDims(inputShape), inputOffset,
                                          weightsData, convertShapeToDims(weightsShape),
                                          weightsOffset, biasData, convertShapeToDims(biasShape),
                                          outputOffset, outputMultiplier, outputShift,
                                          outputActivationMin, outputActivationMax, outputData,
                                          convertShapeToDims(outputShape), &gemmContext);

    return true;
}

}  // namespace

bool validate(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    auto inputType = context->getInputType(kInputTensor);
    std::vector<OperandType> inExpectedTypes;
    std::vector<OperandType> outExpectedTypes;
    if (inputType == OperandType::TENSOR_FLOAT32) {
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0));
        inExpectedTypes = {
                OperandType::TENSOR_FLOAT32,
                OperandType::TENSOR_FLOAT32,
                OperandType::TENSOR_FLOAT32,
                OperandType::INT32,
        };
    } else if (inputType == OperandType::TENSOR_FLOAT16) {
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
        inExpectedTypes = {
                OperandType::TENSOR_FLOAT16,
                OperandType::TENSOR_FLOAT16,
                OperandType::TENSOR_FLOAT16,
                OperandType::INT32,
        };
    } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
        // NeuralNetworks.h specifies that ANEURALNETWORKS_FULLY_CONNECTED's output must
        // meet "outputScale > inputScale * weightsScale" for the operand type
        // ANEURALNETWORKS_TENSOR_QUANT8_ASYMM before API level 29.
        const float inputScale = context->getInputShape(kInputTensor).scale;
        const float weightsScale = context->getInputShape(kWeightsTensor).scale;
        const float outputScale = context->getOutputShape(kOutputTensor).scale;
        bool meetsQuantizedScaleConstraintBeforeV1_2 = (outputScale > inputScale * weightsScale);

        if (!meetsQuantizedScaleConstraintBeforeV1_2) {
            NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
        } else {
            NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0));
        }

        inExpectedTypes = {
                OperandType::TENSOR_QUANT8_ASYMM,
                OperandType::TENSOR_QUANT8_ASYMM,
                OperandType::TENSOR_INT32,
                OperandType::INT32,
        };
    } else {
        NN_RET_CHECK_FAIL() << "Unsupported input tensor type for operation " << kOperationName;
        return false;
    }
    NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    return true;
}

bool prepare(IOperationExecutionContext* context) {
    Shape input = context->getInputShape(kInputTensor);
    Shape weights = context->getInputShape(kWeightsTensor);
    Shape bias = context->getInputShape(kBiasTensor);

    // Check that the parameters of each tensor are consistent with each other and
    // with the input configuration.
    NN_RET_CHECK(input.type == weights.type);
    if (input.type == OperandType::TENSOR_QUANT8_ASYMM) {
        NN_RET_CHECK(bias.type == OperandType::TENSOR_INT32);
    } else {
        NN_RET_CHECK(input.type == bias.type);
    }
    // The TensorFlow fully connected layer specification says that the input should
    // be of at least rank 2, so we check that here. TFLite itself does not check this.
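    // For example, an input of shape {2, 3, 4} (24 elements) used with weights of
    // shape {num_units, 12} is treated as a batch of 24 / 12 = 2 rows of size 12,
    // which is what the batch_size computation below yields.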
    NN_RET_CHECK_GE(getNumberOfDimensions(input), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(weights), 2);
    uint32_t input_n_elements = getNumberOfElements(input);
    uint32_t num_units = getSizeOfDimension(weights, 0);
    uint32_t input_size = getSizeOfDimension(weights, 1);
    // Only batch_size can be 0.
    NN_RET_CHECK_GT(num_units, 0);
    NN_RET_CHECK_GT(input_size, 0);
    uint32_t batch_size = input_n_elements / input_size;
    NN_RET_CHECK_EQ(getSizeOfDimension(bias, 0), num_units);
    NN_RET_CHECK_EQ(input_size * batch_size, input_n_elements);

    Shape output = context->getOutputShape(kOutputTensor);
    output.type = input.type;
    output.dimensions = {batch_size, num_units};
    return context->setOutputShape(kOutputTensor, output);
}

bool execute(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT32:
            return fullyConnectedFloat32(context->getInputBuffer<float>(kInputTensor),
                                         context->getInputShape(kInputTensor),
                                         context->getInputBuffer<float>(kWeightsTensor),
                                         context->getInputShape(kWeightsTensor),
                                         context->getInputBuffer<float>(kBiasTensor),
                                         context->getInputShape(kBiasTensor),
                                         context->getInputValue<int32_t>(kActivationScalar),
                                         context->getOutputBuffer<float>(kOutputTensor),
                                         context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT16:
            return fullyConnectedFloat16(context->getInputBuffer<_Float16>(kInputTensor),
                                         context->getInputShape(kInputTensor),
                                         context->getInputBuffer<_Float16>(kWeightsTensor),
                                         context->getInputShape(kWeightsTensor),
                                         context->getInputBuffer<_Float16>(kBiasTensor),
                                         context->getInputShape(kBiasTensor),
                                         context->getInputValue<int32_t>(kActivationScalar),
                                         context->getOutputBuffer<_Float16>(kOutputTensor),
                                         context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return fullyConnectedQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
                                        context->getInputShape(kInputTensor),
                                        context->getInputBuffer<uint8_t>(kWeightsTensor),
                                        context->getInputShape(kWeightsTensor),
                                        context->getInputBuffer<int32_t>(kBiasTensor),
                                        context->getInputShape(kBiasTensor),
                                        context->getInputValue<int32_t>(kActivationScalar),
                                        context->getOutputBuffer<uint8_t>(kOutputTensor),
                                        context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
    }
}

}  // namespace fully_connected

NN_REGISTER_OPERATION(FULLY_CONNECTED, fully_connected::kOperationName, fully_connected::validate,
                      fully_connected::prepare, fully_connected::execute,
                      .allowZeroSizedInput = true);

}  // namespace nn
}  // namespace android