1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "ActivationFunctor.h" 18 #include "CpuOperationUtils.h" 19 #include "OperationResolver.h" 20 21 #include "tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h" 22 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" 23 24 #include "Tracing.h" 25 26 namespace android { 27 namespace nn { 28 29 namespace activation { 30 31 constexpr uint32_t kNumInputs = 1; 32 constexpr uint32_t kInputTensor = 0; 33 34 constexpr uint32_t kNumOutputs = 1; 35 constexpr uint32_t kOutputTensor = 0; 36 37 namespace { 38 39 template <typename T> 40 bool reluFloat(const T* inputData, const Shape& inputShape, T* outputData, const Shape& outputShape, 41 float reluMin = 0.f, float reluMax = std::numeric_limits<float>::max()) { 42 NNTRACE_COMP("reluX"); 43 int numElements = getNumberOfElements(inputShape); 44 for (int i = 0; i < numElements; i++, inputData++, outputData++) { 45 *outputData = static_cast<T>( 46 std::min(std::max(reluMin, static_cast<float>(*inputData)), reluMax)); 47 } 48 return true; 49 } 50 template bool reluFloat<float>(const float* inputData, const Shape& inputShape, float* outputData, 51 const Shape& outputShape, float reluMin, float reluMax); 52 template bool reluFloat<_Float16>(const _Float16* inputData, const Shape& inputShape, 53 _Float16* outputData, const Shape& outputShape, float reluMin, 54 float reluMax); 55 56 template <typename T> 57 bool relu1Float(const T* inputData, const Shape& inputShape, T* outputData, 58 const Shape& outputShape) { 59 return reluFloat(inputData, inputShape, outputData, outputShape, -1.f, 1.f); 60 } 61 template bool relu1Float<float>(const float* inputData, const Shape& inputShape, float* outputData, 62 const Shape& outputShape); 63 template bool relu1Float<_Float16>(const _Float16* inputData, const Shape& inputShape, 64 _Float16* outputData, const Shape& outputShape); 65 66 template <typename T> 67 bool relu6Float(const T* inputData, const Shape& inputShape, T* outputData, 68 const Shape& outputShape) { 69 return reluFloat(inputData, inputShape, outputData, outputShape, 0.f, 6.f); 70 } 71 template bool relu6Float<float>(const float* inputData, const Shape& inputShape, float* outputData, 72 const Shape& outputShape); 73 template bool relu6Float<_Float16>(const _Float16* inputData, const Shape& inputShape, 74 _Float16* outputData, const Shape& outputShape); 75 76 bool tanhFloat16(const _Float16* inputData, const Shape& inputShape, _Float16* outputData, 77 const Shape& outputShape) { 78 NNTRACE_COMP("tanhFloat16"); 79 int numElements = getNumberOfElements(inputShape); 80 for (int i = 0; i < numElements; i++, inputData++, outputData++) { 81 *outputData = static_cast<_Float16>(std::tanh(static_cast<float>(*inputData))); 82 } 83 return true; 84 } 85 86 bool tanhFloat32(const float* inputData, const Shape& inputShape, float* outputData, 87 const Shape& outputShape) { 88 NNTRACE_COMP("tanhFloat32"); 89 int numElements = getNumberOfElements(inputShape); 90 for (int i = 0; i < numElements; i++, inputData++, outputData++) { 91 *outputData = std::tanh(*inputData); 92 } 93 return true; 94 } 95 96 template <typename T> 97 bool logisticFloat(const T* inputData, const Shape& inputShape, T* outputData, 98 const Shape& outputShape) { 99 NNTRACE_COMP("logisticFloat"); 100 int numElements = getNumberOfElements(inputShape); 101 for (int i = 0; i < numElements; i++, inputData++, outputData++) { 102 *outputData = static_cast<T>(1.f / (1.f + std::exp(static_cast<float>(-*inputData)))); 103 } 104 return true; 105 } 106 template bool logisticFloat<float>(const float* inputData, const Shape& inputShape, 107 float* outputData, const Shape& outputShape); 108 template bool logisticFloat<_Float16>(const _Float16* inputData, const Shape& inputShape, 109 _Float16* outputData, const Shape& outputShape); 110 111 #define ANDROID_NN_RELUX_QUANT8(activation) \ 112 int numElements = getNumberOfElements(inputShape); \ 113 int32_t output_activation_min = 0; \ 114 int32_t output_activation_max = 0; \ 115 \ 116 CalculateActivationRangeUint8(activation, inputShape, &output_activation_min, \ 117 &output_activation_max); \ 118 \ 119 for (int i = 0; i < numElements; i++, inputData++, outputData++) { \ 120 *outputData = std::min((uint8_t)output_activation_max, \ 121 std::max((uint8_t)output_activation_min, *inputData)); \ 122 } 123 124 bool reluQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData, 125 const Shape& outputShape) { 126 NNTRACE_COMP("reluQuant8"); 127 ANDROID_NN_RELUX_QUANT8(kActivationRelu) 128 return true; 129 } 130 131 bool relu1Quant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData, 132 const Shape& outputShape) { 133 NNTRACE_COMP("relu1Quant8"); 134 ANDROID_NN_RELUX_QUANT8(kActivationRelu1) 135 return true; 136 } 137 138 bool relu6Quant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData, 139 const Shape& outputShape) { 140 NNTRACE_COMP("relu6Quant8"); 141 ANDROID_NN_RELUX_QUANT8(kActivationRelu6) 142 return true; 143 } 144 145 #undef ANDROID_NN_RELUX_QUANT8 146 147 bool tanhQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData, 148 const Shape& outputShape) { 149 NNTRACE_TRANS("tanhQuant8"); 150 if (outputShape.offset != 128 || outputShape.scale != 1.f / 128) { 151 LOG(ERROR) << "incorrect scale or offset for TANH output"; 152 return false; 153 } 154 155 int numElements = getNumberOfElements(inputShape); 156 static constexpr int kInputIntegerBits = 4; 157 158 const double input_real_multiplier = 159 inputShape.scale * static_cast<double>(1 << (31 - kInputIntegerBits)); 160 161 int32_t input_multiplier = 0; 162 int32_t input_left_shift = 0; 163 if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, &input_multiplier, 164 &input_left_shift)) { 165 return false; 166 } 167 int32_t input_range_radius = CalculateInputRadius(kInputIntegerBits, input_left_shift); 168 169 NNTRACE_COMP_SWITCH("optimized_ops::Tanh"); 170 tflite::optimized_ops::Tanh(inputData, convertShapeToTflshape(inputShape), inputShape.offset, 171 input_range_radius, input_multiplier, input_left_shift, outputData, 172 convertShapeToTflshape(outputShape)); 173 174 return true; 175 } 176 177 bool logisticQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData, 178 const Shape& outputShape) { 179 NNTRACE_TRANS("logisticQuant8"); 180 if (outputShape.offset != 0 || outputShape.scale != 1.f / 256) { 181 LOG(ERROR) << "incorrect scale / offset for output"; 182 return false; 183 } 184 185 int numElements = getNumberOfElements(inputShape); 186 static constexpr int kInputIntegerBits = 4; 187 188 const double input_real_multiplier = 189 inputShape.scale * static_cast<double>(1 << (31 - kInputIntegerBits)); 190 191 int32_t input_multiplier = 0; 192 int32_t input_left_shift = 0; 193 if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, &input_multiplier, 194 &input_left_shift)) { 195 return false; 196 } 197 int32_t input_range_radius = CalculateInputRadius(kInputIntegerBits, input_left_shift); 198 199 NNTRACE_COMP_SWITCH("optimized_ops::Logistic"); 200 tflite::optimized_ops::Logistic( 201 inputData, convertShapeToTflshape(inputShape), inputShape.offset, input_range_radius, 202 input_multiplier, input_left_shift, outputData, convertShapeToTflshape(outputShape)); 203 204 return true; 205 } 206 207 } // namespace 208 209 bool validate(OperationType opType, const IOperationValidationContext* context) { 210 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); 211 NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); 212 auto inputType = context->getInputType(kInputTensor); 213 if (inputType == OperandType::TENSOR_FLOAT32) { 214 NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0)); 215 } else if (inputType == OperandType::TENSOR_FLOAT16) { 216 NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2)); 217 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) { 218 if (opType == OperationType::TANH) { 219 NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2)); 220 } else { 221 NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0)); 222 } 223 } else { 224 NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << getOperationName(opType); 225 } 226 return validateInputTypes(context, {inputType}) && validateOutputTypes(context, {inputType}); 227 } 228 229 bool prepare(OperationType opType, IOperationExecutionContext* context) { 230 Shape input = context->getInputShape(kInputTensor); 231 NN_RET_CHECK_LE(getNumberOfDimensions(input), 4); 232 Shape output = input; 233 if (input.type == OperandType::TENSOR_QUANT8_ASYMM) { 234 switch (opType) { 235 case OperationType::RELU: 236 case OperationType::RELU1: 237 case OperationType::RELU6: 238 break; 239 case OperationType::LOGISTIC: 240 output.scale = 1.f / 256; 241 output.offset = 0; 242 break; 243 case OperationType::TANH: 244 output.scale = 1.f / 128; 245 output.offset = 128; 246 break; 247 default: 248 NN_RET_CHECK_FAIL() << "Unsupported operation type"; 249 } 250 } 251 return context->setOutputShape(kOutputTensor, output); 252 } 253 254 bool executeRelu(IOperationExecutionContext* context) { 255 // Bypass execution in the case of zero-sized input. 256 if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true; 257 switch (context->getInputType(kInputTensor)) { 258 case OperandType::TENSOR_FLOAT16: 259 return reluFloat(context->getInputBuffer<_Float16>(kInputTensor), 260 context->getInputShape(kInputTensor), 261 context->getOutputBuffer<_Float16>(kOutputTensor), 262 context->getOutputShape(kOutputTensor)); 263 case OperandType::TENSOR_FLOAT32: 264 return reluFloat(context->getInputBuffer<float>(kInputTensor), 265 context->getInputShape(kInputTensor), 266 context->getOutputBuffer<float>(kOutputTensor), 267 context->getOutputShape(kOutputTensor)); 268 case OperandType::TENSOR_QUANT8_ASYMM: 269 return reluQuant8(context->getInputBuffer<uint8_t>(kInputTensor), 270 context->getInputShape(kInputTensor), 271 context->getOutputBuffer<uint8_t>(kOutputTensor), 272 context->getOutputShape(kOutputTensor)); 273 default: 274 NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation RELU"; 275 } 276 } 277 278 bool executeRelu1(IOperationExecutionContext* context) { 279 // Bypass execution in the case of zero-sized input. 280 if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true; 281 switch (context->getInputType(kInputTensor)) { 282 case OperandType::TENSOR_FLOAT16: 283 return relu1Float(context->getInputBuffer<_Float16>(kInputTensor), 284 context->getInputShape(kInputTensor), 285 context->getOutputBuffer<_Float16>(kOutputTensor), 286 context->getOutputShape(kOutputTensor)); 287 case OperandType::TENSOR_FLOAT32: 288 return relu1Float(context->getInputBuffer<float>(kInputTensor), 289 context->getInputShape(kInputTensor), 290 context->getOutputBuffer<float>(kOutputTensor), 291 context->getOutputShape(kOutputTensor)); 292 case OperandType::TENSOR_QUANT8_ASYMM: 293 return relu1Quant8(context->getInputBuffer<uint8_t>(kInputTensor), 294 context->getInputShape(kInputTensor), 295 context->getOutputBuffer<uint8_t>(kOutputTensor), 296 context->getOutputShape(kOutputTensor)); 297 default: 298 NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation RELU1"; 299 } 300 } 301 302 bool executeRelu6(IOperationExecutionContext* context) { 303 // Bypass execution in the case of zero-sized input. 304 if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true; 305 switch (context->getInputType(kInputTensor)) { 306 case OperandType::TENSOR_FLOAT16: 307 return relu6Float(context->getInputBuffer<_Float16>(kInputTensor), 308 context->getInputShape(kInputTensor), 309 context->getOutputBuffer<_Float16>(kOutputTensor), 310 context->getOutputShape(kOutputTensor)); 311 case OperandType::TENSOR_FLOAT32: 312 return relu6Float(context->getInputBuffer<float>(kInputTensor), 313 context->getInputShape(kInputTensor), 314 context->getOutputBuffer<float>(kOutputTensor), 315 context->getOutputShape(kOutputTensor)); 316 case OperandType::TENSOR_QUANT8_ASYMM: 317 return relu6Quant8(context->getInputBuffer<uint8_t>(kInputTensor), 318 context->getInputShape(kInputTensor), 319 context->getOutputBuffer<uint8_t>(kOutputTensor), 320 context->getOutputShape(kOutputTensor)); 321 default: 322 NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation RELU6"; 323 } 324 } 325 326 bool executeLogistic(IOperationExecutionContext* context) { 327 // Bypass execution in the case of zero-sized input. 328 if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true; 329 switch (context->getInputType(kInputTensor)) { 330 case OperandType::TENSOR_FLOAT16: 331 return logisticFloat(context->getInputBuffer<_Float16>(kInputTensor), 332 context->getInputShape(kInputTensor), 333 context->getOutputBuffer<_Float16>(kOutputTensor), 334 context->getOutputShape(kOutputTensor)); 335 case OperandType::TENSOR_FLOAT32: 336 return logisticFloat(context->getInputBuffer<float>(kInputTensor), 337 context->getInputShape(kInputTensor), 338 context->getOutputBuffer<float>(kOutputTensor), 339 context->getOutputShape(kOutputTensor)); 340 case OperandType::TENSOR_QUANT8_ASYMM: 341 return logisticQuant8(context->getInputBuffer<uint8_t>(kInputTensor), 342 context->getInputShape(kInputTensor), 343 context->getOutputBuffer<uint8_t>(kOutputTensor), 344 context->getOutputShape(kOutputTensor)); 345 default: 346 NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation LOGISTIC"; 347 } 348 } 349 350 bool executeTanh(IOperationExecutionContext* context) { 351 // Bypass execution in the case of zero-sized input. 352 if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true; 353 switch (context->getInputType(kInputTensor)) { 354 case OperandType::TENSOR_FLOAT16: 355 return tanhFloat16(context->getInputBuffer<_Float16>(kInputTensor), 356 context->getInputShape(kInputTensor), 357 context->getOutputBuffer<_Float16>(kOutputTensor), 358 context->getOutputShape(kOutputTensor)); 359 case OperandType::TENSOR_FLOAT32: 360 return tanhFloat32(context->getInputBuffer<float>(kInputTensor), 361 context->getInputShape(kInputTensor), 362 context->getOutputBuffer<float>(kOutputTensor), 363 context->getOutputShape(kOutputTensor)); 364 case OperandType::TENSOR_QUANT8_ASYMM: 365 return tanhQuant8(context->getInputBuffer<uint8_t>(kInputTensor), 366 context->getInputShape(kInputTensor), 367 context->getOutputBuffer<uint8_t>(kOutputTensor), 368 context->getOutputShape(kOutputTensor)); 369 default: 370 NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation TANH"; 371 } 372 } 373 374 } // namespace activation 375 376 using std::placeholders::_1; 377 NN_REGISTER_OPERATION(RELU, "RELU", std::bind(activation::validate, OperationType::RELU, _1), 378 std::bind(activation::prepare, OperationType::RELU, _1), 379 activation::executeRelu, .allowZeroSizedInput = true); 380 NN_REGISTER_OPERATION(RELU1, "RELU1", std::bind(activation::validate, OperationType::RELU1, _1), 381 std::bind(activation::prepare, OperationType::RELU1, _1), 382 activation::executeRelu1, .allowZeroSizedInput = true); 383 NN_REGISTER_OPERATION(RELU6, "RELU6", std::bind(activation::validate, OperationType::RELU6, _1), 384 std::bind(activation::prepare, OperationType::RELU6, _1), 385 activation::executeRelu6, .allowZeroSizedInput = true); 386 NN_REGISTER_OPERATION(LOGISTIC, "LOGISTIC", 387 std::bind(activation::validate, OperationType::LOGISTIC, _1), 388 std::bind(activation::prepare, OperationType::LOGISTIC, _1), 389 activation::executeLogistic, .allowZeroSizedInput = true); 390 NN_REGISTER_OPERATION(TANH, "TANH", std::bind(activation::validate, OperationType::TANH, _1), 391 std::bind(activation::prepare, OperationType::TANH, _1), 392 activation::executeTanh, .allowZeroSizedInput = true); 393 394 } // namespace nn 395 } // namespace android 396