Home | History | Annotate | Download | only in include
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef ANDROID_ML_NN_COMMON_OPERATIONS_H
     18 #define ANDROID_ML_NN_COMMON_OPERATIONS_H
     19 
     20 #include "operations/EmbeddingLookup.h"
     21 #include "operations/HashtableLookup.h"
     22 #include "operations/LSHProjection.h"
     23 #include "operations/LSTM.h"
     24 #include "operations/RNN.h"
     25 #include "operations/SVDF.h"
     26 
     27 #include <stddef.h>
     28 
     29 #include <cstdint>
     30 #include <vector>
     31 
     32 namespace android {
     33 namespace nn {
     34 
     // Forward declaration only: the prototypes in this header name Shape solely
     // in `const Shape&` parameters and as the element type of a
     // `const std::vector<Shape>&` parameter. Its definition is not provided in
     // this file.
     35 struct Shape;
     36 
     // addFloat32 / addQuant8: prototypes for a two-input operation. Each takes
     // two read-only input buffers (in1, in2) with one Shape per input, an
     // int32_t `activation` code, and a writable output buffer with its Shape.
     // The Float32 variant uses float buffers; the Quant8 variant uses uint8_t.
     // NOTE(review): by name these add the two inputs and then apply the
     // activation; the bool result presumably reports success. The definitions
     // are not in this header -- confirm there.
     37 bool addFloat32(const float* in1, const Shape& shape1,
     38                 const float* in2, const Shape& shape2,
     39                 int32_t activation,
     40                 float* out, const Shape& shapeOut);
     41 bool addQuant8(const uint8_t* in1, const Shape& shape1,
     42                const uint8_t* in2, const Shape& shape2,
     43                int32_t activation,
     44                uint8_t* out, const Shape& shapeOut);
     45 
     // mulFloat32 / mulQuant8: prototypes for a two-input operation. Each takes
     // two read-only input buffers (in1, in2) with one Shape per input, an
     // int32_t `activation` code, and a writable output buffer with its Shape.
     // The Float32 variant uses float buffers; the Quant8 variant uses uint8_t.
     // NOTE(review): by name these multiply the two inputs and then apply the
     // activation; the bool result presumably reports success. The definitions
     // are not in this header -- confirm there.
     46 bool mulFloat32(const float* in1, const Shape& shape1,
     47                 const float* in2, const Shape& shape2,
     48                 int32_t activation,
     49                 float* out, const Shape& shapeOut);
     50 bool mulQuant8(const uint8_t* in1, const Shape& shape1,
     51                const uint8_t* in2, const Shape& shape2,
     52                int32_t activation,
     53                uint8_t* out, const Shape& shapeOut);
     54 
     // floorFloat32: prototype for a single-input float operation. Reads from
     // `inputData` and writes to `outputData`. Only one Shape is passed, so input
     // and output presumably share it -- TODO confirm.
     // NOTE(review): by name this computes floor() per element; the bool result
     // presumably reports success. Confirm against the definition.
     55 bool floorFloat32(const float* inputData,
     56                   float* outputData,
     57                   const Shape& shape);
     58 
     // dequantizeQuant8ToFloat32: prototype for a conversion from a uint8_t
     // input buffer to a float output buffer. Only one Shape is passed, so input
     // and output presumably share it -- TODO confirm.
     // NOTE(review): where the quantization parameters come from is not visible
     // in this header; the bool result presumably reports success.
     59 bool dequantizeQuant8ToFloat32(const uint8_t* inputData,
     60                                float* outputData,
     61                                const Shape& shape);
     62 
     // depthwiseConvFloat32 / depthwiseConvQuant8: prototypes taking an input
     // buffer, a filter buffer and a bias buffer (each with its own Shape), four
     // explicit padding amounts (left/right/top/bottom), two strides
     // (width/height), a depth multiplier, an int32_t `activation` code, and a
     // writable output buffer with its Shape.
     // In the Float32 variant input, filter, bias and output are all float. In
     // the Quant8 variant input, filter and output are uint8_t while the bias is
     // int32_t.
     // NOTE(review): by name these are depthwise 2-D convolutions; tensor layout
     // and the meaning of the bool result are not visible here -- confirm in the
     // definitions.
     63 bool depthwiseConvFloat32(const float* inputData, const Shape& inputShape,
     64                           const float* filterData, const Shape& filterShape,
     65                           const float* biasData, const Shape& biasShape,
     66                           int32_t padding_left, int32_t padding_right,
     67                           int32_t padding_top, int32_t padding_bottom,
     68                           int32_t stride_width, int32_t stride_height,
     69                           int32_t depth_multiplier, int32_t activation,
     70                           float* outputData, const Shape& outputShape);
     71 bool depthwiseConvQuant8(const uint8_t* inputData, const Shape& inputShape,
     72                          const uint8_t* filterData, const Shape& filterShape,
     73                          const int32_t* biasData, const Shape& biasShape,
     74                          int32_t padding_left, int32_t padding_right,
     75                          int32_t padding_top, int32_t padding_bottom,
     76                          int32_t stride_width, int32_t stride_height,
     77                          int32_t depth_multiplier, int32_t activation,
     78                          uint8_t* outputData, const Shape& outputShape);
     79 
     // convFloat32 / convQuant8: prototypes taking an input buffer, a filter
     // buffer and a bias buffer (each with its own Shape), four explicit padding
     // amounts (left/right/top/bottom), two strides (width/height), an int32_t
     // `activation` code, and a writable output buffer with its Shape.
     // In the Float32 variant all buffers are float. In the Quant8 variant input,
     // filter and output are uint8_t while the bias is int32_t.
     // NOTE(review): by name these are 2-D convolutions; tensor layout and the
     // meaning of the bool result are not visible here -- confirm in the
     // definitions.
     80 bool convFloat32(const float* inputData, const Shape& inputShape,
     81                  const float* filterData, const Shape& filterShape,
     82                  const float* biasData, const Shape& biasShape,
     83                  int32_t padding_left, int32_t padding_right,
     84                  int32_t padding_top, int32_t padding_bottom,
     85                  int32_t stride_width, int32_t stride_height,
     86                  int32_t activation,
     87                  float* outputData, const Shape& outputShape);
     88 bool convQuant8(const uint8_t* inputData, const Shape& inputShape,
     89                 const uint8_t* filterData, const Shape& filterShape,
     90                 const int32_t* biasData, const Shape& biasShape,
     91                 int32_t padding_left, int32_t padding_right,
     92                 int32_t padding_top, int32_t padding_bottom,
     93                 int32_t stride_width, int32_t stride_height,
     94                 int32_t activation,
     95                 uint8_t* outputData, const Shape& outputShape);
     96 
     // Pooling prototypes. All five share one parameter list: an input buffer
     // with its Shape, four explicit padding amounts (left/right/top/bottom),
     // two strides (width/height), the filter window size (width/height), an
     // int32_t `activation` code, and a writable output buffer with its Shape.
     // Float32 variants use float buffers; Quant8 variants use uint8_t buffers.
     // This header declares average and max pooling for both element types, and
     // L2 pooling for float only.
     // NOTE(review): the pooling semantics are inferred from the names, and the
     // bool result presumably reports success -- confirm in the definitions.
     97 bool averagePoolFloat32(const float* inputData, const Shape& inputShape,
     98                         int32_t padding_left, int32_t padding_right,
     99                         int32_t padding_top, int32_t padding_bottom,
    100                         int32_t stride_width, int32_t stride_height,
    101                         int32_t filter_width, int32_t filter_height, int32_t activation,
    102                         float* outputData, const Shape& outputShape);
    103 bool averagePoolQuant8(const uint8_t* inputData, const Shape& inputShape,
    104                        int32_t padding_left, int32_t padding_right,
    105                        int32_t padding_top, int32_t padding_bottom,
    106                        int32_t stride_width, int32_t stride_height,
    107                        int32_t filter_width, int32_t filter_height, int32_t activation,
    108                        uint8_t* outputData, const Shape& outputShape);
    // L2 pooling: only a float variant is declared in this header.
    109 bool l2PoolFloat32(const float* inputData, const Shape& inputShape,
    110                    int32_t padding_left, int32_t padding_right,
    111                    int32_t padding_top, int32_t padding_bottom,
    112                    int32_t stride_width, int32_t stride_height,
    113                    int32_t filter_width, int32_t filter_height, int32_t activation,
    114                    float* outputData, const Shape& outputShape);
    115 bool maxPoolFloat32(const float* inputData, const Shape& inputShape,
    116                     int32_t padding_left, int32_t padding_right,
    117                     int32_t padding_top, int32_t padding_bottom,
    118                     int32_t stride_width, int32_t stride_height,
    119                     int32_t filter_width, int32_t filter_height, int32_t activation,
    120                     float* outputData, const Shape& outputShape);
    121 bool maxPoolQuant8(const uint8_t* inputData, const Shape& inputShape,
    122                    int32_t padding_left, int32_t padding_right,
    123                    int32_t padding_top, int32_t padding_bottom,
    124                    int32_t stride_width, int32_t stride_height,
    125                    int32_t filter_width, int32_t filter_height, int32_t activation,
    126                    uint8_t* outputData, const Shape& outputShape);
    127 
    // Activation-style prototypes. Each takes a read-only input buffer with its
    // Shape and a writable output buffer with its Shape; the two softmax
    // variants additionally take a float `beta`.
    // Float variants declared here: relu, relu1, relu6, tanh, logistic, softmax.
    // uint8_t variants declared here: relu, relu1, relu6, logistic, softmax
    // (no tanhQuant8 is declared in this header).
    // NOTE(review): the exact clamp ranges for relu1/relu6 and the role of
    // `beta` are inferred from names only; the bool result presumably reports
    // success -- confirm in the definitions.
    128 bool reluFloat32(const float* inputData, const Shape& inputShape,
    129                  float* outputData, const Shape& outputShape);
    130 bool relu1Float32(const float* inputData, const Shape& inputShape,
    131                   float* outputData, const Shape& outputShape);
    132 bool relu6Float32(const float* inputData, const Shape& inputShape,
    133                   float* outputData, const Shape& outputShape);
    134 bool tanhFloat32(const float* inputData, const Shape& inputShape,
    135                  float* outputData, const Shape& outputShape);
    136 bool logisticFloat32(const float* inputData, const Shape& inputShape,
    137                      float* outputData, const Shape& outputShape);
    138 bool softmaxFloat32(const float* inputData, const Shape& inputShape,
    139                     const float beta,
    140                     float* outputData, const Shape& outputShape);
    141 bool reluQuant8(const uint8_t* inputData, const Shape& inputShape,
    142                 uint8_t* outputData, const Shape& outputShape);
    143 bool relu1Quant8(const uint8_t* inputData, const Shape& inputShape,
    144                  uint8_t* outputData, const Shape& outputShape);
    145 bool relu6Quant8(const uint8_t* inputData, const Shape& inputShape,
    146                  uint8_t* outputData, const Shape& outputShape);
    147 bool logisticQuant8(const uint8_t* inputData, const Shape& inputShape,
    148                     uint8_t* outputData, const Shape& outputShape);
    149 bool softmaxQuant8(const uint8_t* inputData, const Shape& inputShape,
    150                    const float beta,
    151                    uint8_t* outputData, const Shape& outputShape);
    152 
    // fullyConnectedFloat32 / fullyConnectedQuant8: prototypes taking an input
    // buffer, a weights buffer and a bias buffer (each with its own Shape), an
    // int32_t `activation` code, and a writable output buffer with its Shape.
    // In the Float32 variant all buffers are float. In the Quant8 variant input,
    // weights and output are uint8_t while the bias is int32_t.
    // NOTE(review): by name this is a fully connected (dense) layer; the weight
    // layout and the meaning of the bool result are not visible here -- confirm.
    153 bool fullyConnectedFloat32(const float* inputData, const Shape& inputShape,
    154                            const float* weights, const Shape& weightsShape,
    155                            const float* biasData, const Shape& biasShape,
    156                            int32_t activation,
    157                            float* outputData, const Shape& outputShape);
    158 bool fullyConnectedQuant8(const uint8_t* inputData, const Shape& inputShape,
    159                           const uint8_t* weights, const Shape& weightsShape,
    160                           const int32_t* biasData, const Shape& biasShape,
    161                           int32_t activation,
    162                           uint8_t* outputData, const Shape& outputShape);
    163 
    // concatenationFloat32 / concatenationQuant8: prototypes taking a vector of
    // read-only input pointers, a vector of Shapes, an int32_t `axis`, and a
    // writable output buffer with its Shape. Float32 uses float buffers, Quant8
    // uses uint8_t buffers.
    // NOTE(review): presumably inputShapes[i] describes inputDataPtrs[i] and the
    // inputs are joined along `axis`; the bool result presumably reports
    // success -- confirm in the definitions.
    164 bool concatenationFloat32(const std::vector<const float*>& inputDataPtrs,
    165                           const std::vector<Shape>& inputShapes, int32_t axis,
    166                           float* outputData, const Shape& outputShape);
    167 bool concatenationQuant8(const std::vector<const uint8_t*>& inputDataPtrs,
    168                          const std::vector<Shape>& inputShapes, int32_t axis,
    169                          uint8_t* outputData, const Shape& outputShape);
    170 
    // Normalization prototypes. l2normFloat32 / l2normQuant8 take an input
    // buffer with its Shape and a writable output buffer with its Shape (float
    // and uint8_t respectively). localResponseNormFloat32 is declared for float
    // only and additionally takes an int32_t `radius` and float `bias`, `alpha`
    // and `beta` parameters.
    // NOTE(review): the normalization formulas and the axis they apply to are
    // not visible in this header; the bool result presumably reports success --
    // confirm in the definitions.
    171 bool l2normFloat32(const float* inputData, const Shape& inputShape,
    172                    float* outputData, const Shape& outputShape);
    173 bool l2normQuant8(const uint8_t* inputData, const Shape& inputShape,
    174                   uint8_t* outputData, const Shape& outputShape);
    175 bool localResponseNormFloat32(const float* inputData, const Shape& inputShape,
    176                               int32_t radius, float bias, float alpha, float beta,
    177                               float* outputData, const Shape& outputShape);
    178 
    // reshapeGeneric: prototype taking untyped (void*) input and output buffers,
    // each with its own Shape, so the signature is independent of element type.
    // NOTE(review): presumably the data is copied unchanged and only the Shape
    // differs; the bool result presumably reports success -- confirm in the
    // definition.
    179 bool reshapeGeneric(const void* inputData, const Shape& inputShape,
    180                     void* outputData, const Shape& outputShape);
    181 
    // resizeBilinearFloat32: prototype taking a float input buffer with its
    // Shape and a writable float output buffer with its Shape. No explicit
    // target-size parameter is passed.
    // NOTE(review): presumably the target size comes from outputShape and the
    // resampling is bilinear, as the name says; the bool result presumably
    // reports success -- confirm in the definition.
    182 bool resizeBilinearFloat32(const float* inputData,
    183                            const Shape& inputShape,
    184                            float* outputData,
    185                            const Shape& outputShape);
    186 
    // depthToSpaceGeneric: prototype taking a uint8_t input buffer with its
    // Shape, a scalar int32_t `blockSize`, and a writable uint8_t output buffer
    // with its Shape.
    // NOTE(review): the "Generic" suffix suggests the uint8_t pointers are raw
    // bytes for any element type -- TODO confirm. The rearrangement semantics
    // and the meaning of the bool result are not visible in this header.
    187 bool depthToSpaceGeneric(const uint8_t* inputData, const Shape& inputShape,
    188                          int32_t blockSize,
    189                          uint8_t* outputData, const Shape& outputShape);
    190 
    // spaceToDepthGeneric: prototype taking a uint8_t input buffer with its
    // Shape, a scalar int32_t `blockSize`, and a writable uint8_t output buffer
    // with its Shape.
    // NOTE(review): the "Generic" suffix suggests the uint8_t pointers are raw
    // bytes for any element type -- TODO confirm. The rearrangement semantics
    // and the meaning of the bool result are not visible in this header.
    191 bool spaceToDepthGeneric(const uint8_t* inputData, const Shape& inputShape,
    192                          int32_t blockSize,
    193                          uint8_t* outputData, const Shape& outputShape);
    194 
    // padGeneric: prototype taking a uint8_t input buffer with its Shape, a
    // pointer to int32_t `paddings`, and a writable uint8_t output buffer with
    // its Shape. No Shape or length is passed for `paddings`.
    // NOTE(review): the number and layout of the padding entries (presumably a
    // before/after pair per input dimension) and the fill value are not visible
    // in this header -- confirm in the definition. The bool result presumably
    // reports success.
    195 bool padGeneric(const uint8_t* inputData, const Shape& inputShape,
    196                 const int32_t* paddings,
    197                 uint8_t* outputData, const Shape& outputShape);
    198 
    // batchToSpaceGeneric: prototype taking a uint8_t input buffer with its
    // Shape, a pointer to int32_t block sizes, and a writable uint8_t output
    // buffer with its Shape. Here `blockSize` is an array pointer, not a scalar,
    // and no length is passed for it.
    // NOTE(review): the number of block-size entries (presumably one per
    // spatial dimension) and the meaning of the bool result are not visible in
    // this header -- confirm in the definition.
    199 bool batchToSpaceGeneric(const uint8_t* inputData, const Shape& inputShape,
    200                          const int32_t* blockSize,
    201                          uint8_t* outputData, const Shape& outputShape);
    202 
    // spaceToBatchGeneric: prototype taking a uint8_t input buffer with its
    // Shape, a pointer to int32_t block sizes, a pointer to int32_t padding
    // values together with a Shape describing that padding array, and a
    // writable uint8_t output buffer with its Shape.
    // NOTE(review): the number of block-size entries and the layout of the
    // padding array are not visible in this header; the bool result presumably
    // reports success -- confirm in the definition.
    203 bool spaceToBatchGeneric(const uint8_t* inputData, const Shape& inputShape,
    204                          const int32_t* blockSize,
    205                          const int32_t* padding, const Shape& paddingShape,
    206                          uint8_t* outputData, const Shape& outputShape);
    207 
    // subFloat32: prototype for a two-input float operation. Takes two
    // read-only float buffers (in1, in2) with one Shape per input, an int32_t
    // `activation` code, and a writable float output buffer with its Shape.
    // NOTE(review): by name this computes in1 - in2 and then applies the
    // activation; operand order and the meaning of the bool result should be
    // confirmed in the definition.
    208 bool subFloat32(const float* in1, const Shape& shape1,
    209                 const float* in2, const Shape& shape2,
    210                 int32_t activation,
    211                 float* out, const Shape& shapeOut);
    212 
    // squeezeGeneric: prototype taking untyped (void*) input and output buffers,
    // each with its own Shape, so the signature is independent of element type.
    // No list of dimensions to remove is passed.
    // NOTE(review): presumably the squeezed dimensions are already reflected in
    // outputShape and the data is copied unchanged; the bool result presumably
    // reports success -- confirm in the definition.
    213 bool squeezeGeneric(const void* inputData, const Shape& inputShape,
    214                     void* outputData, const Shape& outputShape);
    215 
    // divFloat32: prototype for a two-input float operation. Takes two
    // read-only float buffers (in1, in2) with one Shape per input, an int32_t
    // `activation` code, and a writable float output buffer with its Shape.
    // NOTE(review): by name this computes in1 / in2 and then applies the
    // activation; operand order, divide-by-zero behavior and the meaning of the
    // bool result should be confirmed in the definition.
    216 bool divFloat32(const float* in1, const Shape& shape1,
    217                 const float* in2, const Shape& shape2,
    218                 int32_t activation,
    219                 float* out, const Shape& shapeOut);
    220 
    // transposeGeneric: prototype taking a uint8_t input buffer with its Shape,
    // a pointer to int32_t `perm` values with a Shape describing that array, and
    // a writable uint8_t output buffer with its Shape.
    // NOTE(review): presumably `perm` is the dimension permutation and the
    // uint8_t pointers are raw bytes for any element type; the bool result
    // presumably reports success -- confirm in the definition.
    221 bool transposeGeneric(const uint8_t* inputData, const Shape& inputShape,
    222                       const int32_t* perm, const Shape& permShape,
    223                       uint8_t* outputData, const Shape& outputShape);
    224 
    // meanGeneric: prototype taking a uint8_t input buffer with its Shape, a
    // pointer to int32_t `axis` values with a Shape describing that array, a
    // bool `keepDims`, and a writable uint8_t output buffer with its Shape.
    // NOTE(review): presumably this averages over the listed axes and
    // `keepDims` controls whether reduced dimensions are kept; how the element
    // type is determined is not visible in this header. The bool result
    // presumably reports success -- confirm in the definition.
    225 bool meanGeneric(const uint8_t* inputData, const Shape& inputShape,
    226                  const int32_t* axis, const Shape& axisShape, bool keepDims,
    227                  uint8_t* outputData, const Shape& outputShape);
    228 
    // stridedSliceGeneric: prototype taking a uint8_t input buffer with its
    // Shape, three pointers to int32_t arrays (beginData, endData, stridesData),
    // three int32_t masks (beginMask, endMask, shrinkAxisMask), and a writable
    // uint8_t output buffer with its Shape. No lengths are passed for the three
    // arrays.
    // NOTE(review): presumably each array has one entry per input dimension and
    // each mask is a per-dimension bit field; the bool result presumably reports
    // success -- confirm in the definition.
    229 bool stridedSliceGeneric(const uint8_t* inputData, const Shape& inputShape,
    230                          const int32_t* beginData, const int32_t* endData,
    231                          const int32_t* stridesData,
    232                          int32_t beginMask, int32_t endMask, int32_t shrinkAxisMask,
    233                          uint8_t* outputData, const Shape& outputShape);
    234 } // namespace nn
    235 } // namespace android
    236 #endif // ANDROID_ML_NN_COMMON_OPERATIONS_H
    237