/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Contains the implementation of the operations.

#define LOG_TAG "Operations"

#include "Operations.h"
#include "CpuOperationUtils.h"

#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"

#include <algorithm>  // std::max
#include <vector>     // std::vector

namespace android {
namespace nn {

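// Dispatches on the runtime |activation| value (expected to be in scope at
// the expansion site) to the matching compile-time
// tflite::FusedActivationFunctionType enumerator, which |macro| receives as
// its argument. Returns false from the enclosing function when the value is
// not one of the four supported activations.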
#define ANDROID_NN_MACRO_DISPATCH(macro)                                    \
    switch (activation) {                                                   \
        case (int32_t) FusedActivationFunc::NONE:                           \
            macro(kNone);                                                   \
            break;                                                          \
        case (int32_t) FusedActivationFunc::RELU:                           \
            macro(kRelu);                                                   \
            break;                                                          \
        case (int32_t) FusedActivationFunc::RELU1:                          \
            macro(kRelu1);                                                  \
            break;                                                          \
        case (int32_t) FusedActivationFunc::RELU6:                          \
            macro(kRelu6);                                                  \
            break;                                                          \
        default:                                                            \
            LOG(ERROR) << "Unsupported fused activation function type";     \
            return false;                                                   \
    }

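// Element-wise float addition. When the operand shapes differ, the broadcast
// kernel is used and the fused activation is baked in as a template argument;
// otherwise the activation is applied as a runtime [min, max] clamp.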
bool addFloat32(const float* in1, const Shape& shape1,
                const float* in2, const Shape& shape2,
                int32_t activation,
                float* out, const Shape& shapeOut) {
    bool needBroadcast = !SameShape(shape1, shape2);

    if (needBroadcast) {
        #define ANDROID_NN_BROADCAST_ADD(activation)                                              \
            tflite::optimized_ops::BroadcastAdd<tflite::FusedActivationFunctionType::activation>( \
                    in1, convertShapeToDims(shape1),                                              \
                    in2, convertShapeToDims(shape2),                                              \
                    out, convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_ADD)
        #undef ANDROID_NN_BROADCAST_ADD
    } else {
        float output_activation_min, output_activation_max;
        CalculateActivationRangeFloat(activation, &output_activation_min,
                                      &output_activation_max);

        tflite::optimized_ops::Add(
                in1, convertShapeToDims(shape1),
                in2, convertShapeToDims(shape2),
                output_activation_min, output_activation_max,
                out, convertShapeToDims(shapeOut));
    }

    return true;
}

bool addQuant8(const uint8_t* in1, const Shape& shape1,
               const uint8_t* in2, const Shape& shape2,
               int32_t activation,
               uint8_t* out, const Shape& shapeOut) {
    bool needBroadcast = !SameShape(shape1, shape2);

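    // Quantized addition cannot sum the raw uint8 values directly because the
    // two inputs may use different scales. Each input (after subtracting its
    // zero point) is widened, shifted left by |left_shift| bits, and rescaled
    // to the common scale 2 * max(scale1, scale2); the int32 sum is then
    // rescaled to the output scale. Every real-valued multiplier below is
    // expected to be < 1, so QuantizeMultiplierSmallerThanOne() can encode it
    // as an int32 fixed-point multiplier plus a right shift (and fails
    // otherwise, rejecting the op).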
    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
    const int left_shift = 20;
    const double twice_max_input_scale = 2 * std::max(shape1.scale, shape2.scale);
    const double real_input1_multiplier = shape1.scale / twice_max_input_scale;
    const double real_input2_multiplier = shape2.scale / twice_max_input_scale;
    const double real_output_multiplier =
            twice_max_input_scale /
            ((1 << left_shift) * shapeOut.scale);

    int32_t input1_multiplier;
    int32_t input1_shift;
    if (!QuantizeMultiplierSmallerThanOne(real_input1_multiplier,
                                          &input1_multiplier, &input1_shift)) {
        return false;
    }
    int32_t input2_multiplier;
    int32_t input2_shift;
    if (!QuantizeMultiplierSmallerThanOne(real_input2_multiplier,
                                          &input2_multiplier, &input2_shift)) {
        return false;
    }
    int32_t output_multiplier;
    int32_t output_shift;
    if (!QuantizeMultiplierSmallerThanOne(real_output_multiplier,
                                          &output_multiplier, &output_shift)) {
        return false;
    }
    int32_t output_activation_min;
    int32_t output_activation_max;
    CalculateActivationRangeUint8(activation, shapeOut,
                                  &output_activation_min,
                                  &output_activation_max);

    if (needBroadcast) {
        tflite::optimized_ops::BroadcastAdd(
                left_shift,
                in1, convertShapeToDims(shape1),
                input1_offset, input1_multiplier, input1_shift,
                in2, convertShapeToDims(shape2),
                input2_offset, input2_multiplier, input2_shift,
                output_offset, output_multiplier, output_shift,
                output_activation_min, output_activation_max,
                out, convertShapeToDims(shapeOut));
    } else {
        #define ANDROID_NN_NORMAL_ADD(activation)                                        \
            tflite::optimized_ops::Add<tflite::FusedActivationFunctionType::activation>( \
                    left_shift,                                                          \
                    in1, convertShapeToDims(shape1),                                     \
                    input1_offset, input1_multiplier, input1_shift,                      \
                    in2, convertShapeToDims(shape2),                                     \
                    input2_offset, input2_multiplier, input2_shift,                      \
                    output_offset, output_multiplier, output_shift,                      \
                    output_activation_min, output_activation_max,                        \
                    out, convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_NORMAL_ADD)
        #undef ANDROID_NN_NORMAL_ADD
    }

    return true;
}

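// Element-wise float multiplication; mirrors addFloat32: template-dispatched
// activation on the broadcast path, runtime clamping otherwise.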
bool mulFloat32(const float* in1, const Shape& shape1,
                const float* in2, const Shape& shape2,
                int32_t activation,
                float* out, const Shape& shapeOut) {
    bool needBroadcast = !SameShape(shape1, shape2);

    if (needBroadcast) {
        #define ANDROID_NN_BROADCAST_MUL(activation)                                              \
            tflite::optimized_ops::BroadcastMul<tflite::FusedActivationFunctionType::activation>( \
                    in1, convertShapeToDims(shape1),                                              \
                    in2, convertShapeToDims(shape2),                                              \
                    out, convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_MUL)
        #undef ANDROID_NN_BROADCAST_MUL
    } else {
        float output_activation_min, output_activation_max;
        CalculateActivationRangeFloat(activation, &output_activation_min,
                                      &output_activation_max);

        tflite::optimized_ops::Mul(
                in1, convertShapeToDims(shape1),
                in2, convertShapeToDims(shape2),
                output_activation_min, output_activation_max,
                out, convertShapeToDims(shapeOut));
    }

    return true;
}

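// Quantized multiplication. The product of two dequantized values carries a
// scale of scale1 * scale2, so a single requantization step with the real
// multiplier (scale1 * scale2) / scaleOut maps the int32 product to the
// output scale. That multiplier is expected to be < 1; otherwise
// QuantizeMultiplierSmallerThanOne() fails and the op is rejected.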
bool mulQuant8(const uint8_t* in1, const Shape& shape1,
               const uint8_t* in2, const Shape& shape2,
               int32_t activation,
               uint8_t* out, const Shape& shapeOut) {
    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
    const double input_product_scale = shape1.scale * shape2.scale;
    const double real_multiplier = input_product_scale / shapeOut.scale;
    int32_t output_multiplier;
    int32_t output_shift;
    if (!QuantizeMultiplierSmallerThanOne(real_multiplier, &output_multiplier,
                                          &output_shift)) {
        return false;
    }
    int32_t output_activation_min;
    int32_t output_activation_max;
    CalculateActivationRangeUint8(activation, shapeOut,
                                  &output_activation_min,
                                  &output_activation_max);

    // Use the BROADCAST version unconditionally: it also handles the normal
    // (same-shape) case.
    tflite::optimized_ops::BroadcastMul(
            in1, convertShapeToDims(shape1), input1_offset,
            in2, convertShapeToDims(shape2), input2_offset,
            output_offset, output_multiplier, output_shift,
            output_activation_min, output_activation_max,
            out, convertShapeToDims(shapeOut));

    return true;
}

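// Element-wise floor. Input and output have the same shape, so a single Dims
// object describes both.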
bool floorFloat32(const float* inputData,
                  float* outputData,
                  const Shape& shape) {
    tflite::Dims<4> dim = convertShapeToDims(shape);
    tflite::optimized_ops::Floor(inputData, dim, outputData, dim);
    return true;
}

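// Dequantizes each element with the affine mapping
// real = scale * (quantized - offset), where |offset| is the zero point.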
bool dequantizeQuant8ToFloat32(const uint8_t* inputData,
                               float* outputData,
                               const Shape& shape) {
    tflite::Dims<4> dim = convertShapeToDims(shape);
    tflite::optimized_ops::Dequantize(inputData, dim,
                                      shape.offset, shape.scale,
                                      outputData, dim);
    return true;
}

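// Element-wise float subtraction and division. Unlike the add/mul kernels
// above, the Sub and Div kernels take the activation bounds as runtime
// arguments on both the broadcast and non-broadcast paths, so no macro
// dispatch is needed.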
bool subFloat32(const float* in1, const Shape& shape1,
                const float* in2, const Shape& shape2,
                int32_t activation,
                float* out, const Shape& shapeOut) {
    float output_activation_min, output_activation_max;
    CalculateActivationRangeFloat(activation, &output_activation_min,
                                  &output_activation_max);

    bool needBroadcast = !SameShape(shape1, shape2);
    if (needBroadcast) {
        tflite::optimized_ops::BroadcastSub(
                in1, convertShapeToDims(shape1),
                in2, convertShapeToDims(shape2),
                output_activation_min, output_activation_max,
                out, convertShapeToDims(shapeOut));
    } else {
        tflite::optimized_ops::Sub(
                in1, convertShapeToDims(shape1),
                in2, convertShapeToDims(shape2),
                output_activation_min, output_activation_max,
                out, convertShapeToDims(shapeOut));
    }
    return true;
}

bool divFloat32(const float* in1, const Shape& shape1,
                const float* in2, const Shape& shape2,
                int32_t activation,
                float* out, const Shape& shapeOut) {
    float output_activation_min, output_activation_max;
    CalculateActivationRangeFloat(activation, &output_activation_min,
                                  &output_activation_max);

    bool needBroadcast = !SameShape(shape1, shape2);
    if (needBroadcast) {
        tflite::optimized_ops::BroadcastDiv(
                in1, convertShapeToDims(shape1),
                in2, convertShapeToDims(shape2),
                output_activation_min, output_activation_max,
                out, convertShapeToDims(shapeOut));
    } else {
        tflite::optimized_ops::Div(
                in1, convertShapeToDims(shape1),
                in2, convertShapeToDims(shape2),
                output_activation_min, output_activation_max,
                out, convertShapeToDims(shapeOut));
    }
    return true;
}

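// Computes the mean of |inputData| over the dimensions listed in |axis|.
// When |keepDims| is true, the reduced dimensions are retained with size 1.
// Handles both float32 and asymmetric quant8 tensors.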
bool meanGeneric(const uint8_t* inputData, const Shape& inputShape,
                 const int32_t* axis, const Shape& axisShape, bool keepDims,
                 uint8_t* outputData, const Shape& outputShape) {
    // Temporary index array used to iterate through the input data.
    std::vector<int32_t> scratchBuffer(getNumberOfDimensions(inputShape));

    // Temporary buffer to store the resolved axis values given the input data.
    int32_t axisSize = static_cast<int32_t>(getSizeOfDimension(axisShape, 0));
    std::vector<int32_t> resolvedAxis(axisSize);

    bool result = true;
    if (inputShape.type == OperandType::TENSOR_FLOAT32) {
        tflite::reference_ops::Mean<float>(
                const_cast<float*>(reinterpret_cast<const float*>(inputData)),
                reinterpret_cast<const int*>(inputShape.dimensions.data()),
                getNumberOfDimensions(inputShape),
                reinterpret_cast<float*>(outputData),
                reinterpret_cast<const int*>(outputShape.dimensions.data()),
                getNumberOfDimensions(outputShape),
                axis, axisSize, keepDims,
                scratchBuffer.data(), resolvedAxis.data());
    } else if (inputShape.type == OperandType::TENSOR_QUANT8_ASYMM) {
        tflite::reference_ops::Mean<uint8_t>(
                const_cast<uint8_t*>(inputData),
                reinterpret_cast<const int*>(inputShape.dimensions.data()),
                getNumberOfDimensions(inputShape),
                outputData,
                reinterpret_cast<const int*>(outputShape.dimensions.data()),
                getNumberOfDimensions(outputShape),
                axis, axisSize, keepDims,
                scratchBuffer.data(), resolvedAxis.data());
    } else {
        LOG(ERROR) << "Unsupported data type";
        result = false;
    }
    return result;
}
} // namespace nn
} // namespace android