Home | History | Annotate | Download | only in include
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef ANDROID_ML_NN_COMMON_OPERATIONS_UTILS_H
     18 #define ANDROID_ML_NN_COMMON_OPERATIONS_UTILS_H
     19 
     20 #include "Utils.h"
     21 
     22 #include <cstdint>
     23 #include <vector>
     24 
     25 // Macro to check if the input parameters for operation are valid or not.
     26 #define NN_CHECK(v)                                                     \
     27   do {                                                                  \
     28     if (!(v)) {                                                         \
     29       LOG(ERROR) << "NN_CHECK failed: "  << #v << "'\n";                \
     30       return false;                                                     \
     31     }                                                                   \
     32   } while(0);
     33 
     34 #define NN_CHECK_EQ(actual, expected)           \
     35   NN_CHECK((actual) == (expected))
     36 
     37 #define NN_OPS_CHECK NN_CHECK
     38 
     39 namespace android {
     40 namespace nn {
     41 
     42 enum PaddingScheme {
     43     kPaddingUnknown = 0,
     44     kPaddingSame = 1,
     45     kPaddingValid = 2,
     46 };
     47 
     48 // The type and dimensions of an operand.
     49 struct Shape {
     50     OperandType type;
     51     std::vector<uint32_t> dimensions;
     52     float scale;
     53     int32_t offset;
     54 };
     55 
     56 // Verifies that the two shapes are the same.
     57 bool SameShape(const Shape& in1, const Shape& in2);
     58 
     59 // Sets out to the same shape as in.
     60 bool SetShape(const Shape& in, Shape* out);
     61 
     62 // Return the total number of elements, i.e. all the dimensions multiplied
     63 // together. For a scalar, returns one.
     64 uint32_t getNumberOfElements(const Shape& shape);
     65 
     66 uint32_t getNumberOfDimensions(const Shape& shape);
     67 
     68 uint32_t getSizeOfDimension(const Shape& shape, uint32_t dimensionIdx);
     69 
     70 inline uint32_t computeOutSize(uint32_t imageSize, uint32_t filterSize, uint32_t stride,
     71                                uint32_t paddingHead, uint32_t paddingTail) {
     72     return (imageSize - filterSize + stride + paddingHead + paddingTail) / stride;
     73 }
     74 
     75 __wur
     76 bool QuantizeMultiplierSmallerThanOne(double double_multiplier,
     77                                       int32_t* quantized_multiplier,
     78                                       int32_t* right_shift);
     79 
     80 __wur
     81 bool QuantizeMultiplierGreaterThanOne(double double_multiplier,
     82                                       int32_t* quantized_multiplier,
     83                                       int* left_shift);
     84 
     85 __wur
     86 bool GetQuantizedConvolutionMultipler(const Shape& inputShape,
     87                                       const Shape& filterShape,
     88                                       const Shape& biasShape,
     89                                       const Shape& outputShape,
     90                                       float* multiplier);
     91 
     92 void CalculateActivationRangeUint8(int32_t activation,
     93                                    const Shape& outputShape,
     94                                    int32_t* act_min,
     95                                    int32_t* act_max);
     96 
     97 int32_t CalculateInputRadius(int input_integer_bits, int input_left_shift);
     98 
     99 inline void calculateExplicitPadding(int32_t in_size, int32_t stride,
    100                                      int32_t filter_size, int32_t padding_implicit,
    101                                      int32_t* padding_head, int32_t* padding_tail) {
    102     *padding_head = 0;
    103     *padding_tail = 0;
    104 
    105     if (padding_implicit == kPaddingSame) {
    106         int32_t out_size = (in_size + stride - 1) / stride;
    107         int32_t tmp = (out_size - 1) * stride + filter_size;
    108         if (tmp > in_size) {
    109             *padding_head = (tmp - in_size) / 2;
    110             *padding_tail = (tmp - in_size) - *padding_head;
    111         }
    112     }
    113 }
    114 
    115 inline PaddingScheme getPaddingScheme(int32_t inWidth, int32_t inHeight,
    116                                       int32_t strideWidth, int32_t strideHeight,
    117                                       int32_t filterWidth, int32_t filterHeight,
    118                                       int32_t paddingLeft, int32_t paddingRight,
    119                                       int32_t paddingTop, int32_t paddingBottom) {
    120     if (paddingLeft == 0 && paddingRight == 0 && paddingTop == 0 && paddingBottom == 0) {
    121         return kPaddingValid;
    122     }
    123 
    124     int32_t expectedPaddingLeft, expectedPaddingRight;
    125     int32_t expectedPaddingTop, expectedPaddingBottom;
    126 
    127     calculateExplicitPadding(inWidth, strideWidth, filterWidth, kPaddingSame,
    128                              &expectedPaddingLeft, &expectedPaddingRight);
    129     calculateExplicitPadding(inHeight, strideHeight, filterHeight, kPaddingSame,
    130                              &expectedPaddingTop, &expectedPaddingBottom);
    131     if (expectedPaddingLeft == paddingLeft && expectedPaddingRight == paddingRight &&
    132         expectedPaddingTop == paddingTop && expectedPaddingBottom == paddingBottom) {
    133         return kPaddingSame;
    134     } else {
    135         return kPaddingUnknown;
    136     }
    137 }
    138 
    139 // Preparation functions for the corresponding ops
    140 bool addMulPrepare(const Shape& in1, const Shape& in2, Shape* out1);
    141 
    142 bool floorPrepare(const Shape& input, Shape* output);
    143 
    144 bool dequantizePrepare(const Shape& input, Shape* output);
    145 
    146 bool depthwiseConvPrepare(const Shape& input,
    147                           const Shape& filter,
    148                           const Shape& bias,
    149                           int32_t padding_left, int32_t padding_right,
    150                           int32_t padding_top, int32_t padding_bottom,
    151                           int32_t stride_width, int32_t stride_height,
    152                           Shape* output);
    153 
    154 bool convPrepare(const Shape& input,
    155                  const Shape& filter,
    156                  const Shape& bias,
    157                  int32_t padding_left, int32_t padding_right,
    158                  int32_t padding_top, int32_t padding_bottom,
    159                  int32_t stride_width, int32_t stride_height,
    160                  Shape* output);
    161 
    162 bool genericPoolingPrepare(const Shape& input,
    163                            int32_t padding_left, int32_t padding_right,
    164                            int32_t padding_top, int32_t padding_bottom,
    165                            int32_t stride_width, int32_t stride_height,
    166                            int32_t filter_width, int32_t filter_height,
    167                            Shape* output);
    168 
    169 bool genericActivationPrepare(const Shape& input, Shape* output);
    170 
    171 bool fullyConnectedPrepare(const Shape& input,
    172                            const Shape& weights,
    173                            const Shape& bias,
    174                            Shape* output);
    175 
    176 bool concatenationPrepare(const std::vector<Shape>& inputShapes,
    177                           int32_t axis,
    178                           Shape* output);
    179 
    180 bool genericNormalizationPrepare(const Shape& input, Shape* output);
    181 
    182 bool reshapePrepare(const Shape& input,
    183                     const int32_t* targetDims,
    184                     const int32_t targetDimsSize,
    185                     Shape* output);
    186 
    187 bool resizeBilinearPrepare(const Shape& input,
    188                            int32_t height,
    189                            int32_t width,
    190                            Shape* output);
    191 
    192 bool depthToSpacePrepare(const Shape& input,
    193                          int32_t blockSize,
    194                          Shape* output);
    195 
    196 bool spaceToDepthPrepare(const Shape& input,
    197                          int32_t blockSize,
    198                          Shape* output);
    199 
    200 bool embeddingLookupPrepare(const Shape &valueShape,
    201                             const Shape &lookupShape,
    202                             Shape *outputShape);
    203 
    204 bool hashtableLookupPrepare(const Shape &lookupShape,
    205                             const Shape &keyShape,
    206                             const Shape &valueShape,
    207                             Shape *outputShape,
    208                             Shape *hitShape);
    209 
    210 #define ANDROID_NN_MACRO_DISPATCH(macro)                                    \
    211     switch (activation) {                                                   \
    212         case (int32_t) FusedActivationFunc::NONE:                           \
    213             macro(kNone);                                                   \
    214             break;                                                          \
    215         case (int32_t) FusedActivationFunc::RELU:                           \
    216             macro(kRelu);                                                   \
    217             break;                                                          \
    218         case (int32_t) FusedActivationFunc::RELU1:                          \
    219             macro(kRelu1);                                                  \
    220             break;                                                          \
    221         case (int32_t) FusedActivationFunc::RELU6:                          \
    222             macro(kRelu6);                                                  \
    223             break;                                                          \
    224         default:                                                            \
    225             LOG(ERROR) << "Unsupported fused activation function type";     \
    226             return false;                                                   \
    227     }
    228 
    229 } // namespace nn
    230 } // namespace android
    231 
    232 #endif // ANDROID_ML_NN_COMMON_OPERATIONS_UTILS_H
    233