Home | History | Annotate | Download | only in include
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef ANDROID_ML_NN_COMMON_OPERATIONS_H
     18 #define ANDROID_ML_NN_COMMON_OPERATIONS_H
     19 
     20 #include "operations/BidirectionalSequenceLSTM.h"
     21 #include "operations/Cast.h"
     22 #include "operations/EmbeddingLookup.h"
     23 #include "operations/ExpandDims.h"
     24 #include "operations/HashtableLookup.h"
     25 #include "operations/LSHProjection.h"
     26 #include "operations/LSTM.h"
     27 #include "operations/MaximumMinimum.h"
     28 #include "operations/Multinomial.h"
     29 #include "operations/Pow.h"
     30 #include "operations/QuantizedLSTM.h"
     31 #include "operations/RNN.h"
     32 #include "operations/SVDF.h"
     33 #include "operations/Tile.h"
     34 #include "operations/TopK_V2.h"
     35 
     36 #include <stddef.h>
     37 
     38 #include <cstdint>
     39 #include <vector>
     40 
     41 namespace android {
     42 namespace nn {
     43 
     44 struct Shape;
     45 
     46 bool floorFloat16(const _Float16* inputData, _Float16* outputData, const Shape& shape);
     47 bool floorFloat32(const float* inputData, float* outputData, const Shape& shape);
     48 
     49 bool depthwiseConvFloat16(const _Float16* inputData, const Shape& inputShape,
     50                           const _Float16* filterData, const Shape& filterShape,
     51                           const _Float16* biasData, const Shape& biasShape, int32_t paddingLeft,
     52                           int32_t paddingRight, int32_t paddingTop, int32_t paddingBottom,
     53                           int32_t strideWidth, int32_t strideHeight, int32_t dilationWidthFactor,
     54                           int32_t dilationHeightFactor, int32_t depthMultiplier, int32_t activation,
     55                           _Float16* outputData, const Shape& outputShape);
     56 bool depthwiseConvFloat32(const float* inputData, const Shape& inputShape, const float* filterData,
     57                           const Shape& filterShape, const float* biasData, const Shape& biasShape,
     58                           int32_t paddingLeft, int32_t paddingRight, int32_t paddingTop,
     59                           int32_t paddingBottom, int32_t strideWidth, int32_t strideHeight,
     60                           int32_t dilationWidthFactor, int32_t dilationHeightFactor,
     61                           int32_t depthMultiplier, int32_t activation, float* outputData,
     62                           const Shape& outputShape);
     63 bool depthwiseConvQuant8(const uint8_t* inputData, const Shape& inputShape,
     64                          const uint8_t* filterData, const Shape& filterShape,
     65                          const int32_t* biasData, const Shape& biasShape, int32_t paddingLeft,
     66                          int32_t paddingRight, int32_t paddingTop, int32_t paddingBottom,
     67                          int32_t strideWidth, int32_t strideHeight, int32_t dilationWidthFactor,
     68                          int32_t dilationHeightFactor, int32_t depthMultiplier, int32_t activation,
     69                          uint8_t* outputData, const Shape& outputShape);
     70 bool depthwiseConvQuant8PerChannel(const uint8_t* inputData, const Shape& inputShape,
     71                                    const int8_t* filterData, const Shape& filterShape,
     72                                    const float* filterScales, const int32_t* biasData,
     73                                    const Shape& biasShape, int32_t paddingLeft,
     74                                    int32_t paddingRight, int32_t paddingTop, int32_t paddingBottom,
     75                                    int32_t strideWidth, int32_t strideHeight,
     76                                    int32_t dilationWidthFactor, int32_t dilationHeightFactor,
     77                                    int32_t depthMultiplier, int32_t activation, uint8_t* outputData,
     78                                    const Shape& outputShape);
     79 
     80 bool localResponseNormFloat16(const _Float16* inputData, const Shape& inputShape, int32_t radius,
     81                               float bias, float alpha, float beta, int32_t axis,
     82                               _Float16* outputData, const Shape& outputShape);
     83 bool localResponseNormFloat32(const float* inputData, const Shape& inputShape, int32_t radius,
     84                               float bias, float alpha, float beta, int32_t axis, float* outputData,
     85                               const Shape& outputShape);
     86 
     87 bool copyData(const void* inputData, const Shape& inputShape, void* outputData,
     88               const Shape& outputShape);
     89 
     90 template <typename T>
     91 bool depthToSpaceGeneric(const T* inputData, const Shape& inputShape, int32_t blockSize,
     92                          T* outputData, const Shape& outputShape);
     93 template <typename T>
     94 bool spaceToDepthGeneric(const T* inputData, const Shape& inputShape, int32_t blockSize,
     95                          T* outputData, const Shape& outputShape);
     96 
     97 template <typename T>
     98 bool padGeneric(const T* inputData, const Shape& inputShape, const int32_t* paddings, T pad_value,
     99                 T* outputData, const Shape& outputShape);
    100 
    101 template <typename T>
    102 bool batchToSpaceGeneric(const T* inputData, const Shape& inputShape, const int32_t* blockSize,
    103                          T* outputData, const Shape& outputShape);
    104 
    105 template <typename T>
    106 bool spaceToBatchGeneric(const T* inputData, const Shape& inputShape, const int32_t* blockSize,
    107                          const int32_t* padding, const Shape& paddingShape, T* outputData,
    108                          const Shape& outputShape);
    109 
    110 bool meanFloat16(_Float16* inputData, const Shape& inputShape, const int32_t* axis,
    111                  const Shape& axisShape, bool keepDims, _Float16* outputData,
    112                  const Shape& outputShape);
    113 template <typename T, typename U>
    114 bool meanGeneric(T* inputData, const Shape& inputShape, const int32_t* axis, const Shape& axisShape,
    115                  bool keepDims, T* outputData, const Shape& outputShape);
    116 
    117 bool stridedSliceGeneric(const uint8_t* inputData, const Shape& inputShape,
    118                          const int32_t* beginData, const int32_t* endData,
    119                          const int32_t* stridesData, int32_t beginMask, int32_t endMask,
    120                          int32_t shrinkAxisMask, uint8_t* outputData, const Shape& outputShape);
    121 
    122 bool argMinMaxGeneric(const uint8_t* inputData, const Shape& inputShape, int32_t axis,
    123                       bool isArgMin, uint8_t* outputData, const Shape& outputShape);
    124 
    125 bool splitFloat16(const _Float16* inputData, const Shape& inputShape, int32_t axis,
    126                   const std::vector<_Float16*>* outputDataPtrs,
    127                   const std::vector<Shape>& outputShapes);
    128 
    129 bool splitFloat32(const float* inputData, const Shape& inputShape, const int32_t axis,
    130                   const std::vector<float*>* outputDataPtrs,
    131                   const std::vector<Shape>& outputShapes);
    132 
    133 bool splitInt32(const int32_t* inputData, const Shape& inputShape, const int32_t axis,
    134                 const std::vector<int32_t*>* outputDataPtrs,
    135                 const std::vector<Shape>& outputShapes);
    136 
    137 bool splitQuant8(const uint8_t* inputData, const Shape& inputShape, const int32_t axis,
    138                  const std::vector<uint8_t*>* outputDataPtrs,
    139                  const std::vector<Shape>& outputShapes);
    140 
    141 bool groupedConvFloat16(const _Float16* inputData, const Shape& inputShape,
    142                         const _Float16* filterData, const Shape& filterShape,
    143                         const _Float16* biasData, const Shape& biasShape, int32_t numGroups,
    144                         int32_t padding_left, int32_t padding_right, int32_t padding_top,
    145                         int32_t padding_bottom, int32_t stride_width, int32_t stride_height,
    146                         int32_t activation, _Float16* outputData, const Shape& outputShape);
    147 
    148 bool groupedConvFloat32(const float* inputData, const Shape& inputShape, const float* filterData,
    149                         const Shape& filterShape, const float* biasData, const Shape& biasShape,
    150                         int32_t numGroups, int32_t padding_left, int32_t padding_right,
    151                         int32_t padding_top, int32_t padding_bottom, int32_t stride_width,
    152                         int32_t stride_height, int32_t activation, float* outputData,
    153                         const Shape& outputShape);
    154 
    155 bool groupedConvQuant8(const uint8_t* inputData, const Shape& inputShape, const uint8_t* filterData,
    156                        const Shape& filterShape, const int32_t* biasData, const Shape& biasShape,
    157                        int32_t numGroups, int32_t padding_left, int32_t padding_right,
    158                        int32_t padding_top, int32_t padding_bottom, int32_t stride_width,
    159                        int32_t stride_height, int32_t activation, uint8_t* outputData,
    160                        const Shape& outputShape);
    161 
    162 bool groupedConvQuant8PerChannel(const uint8_t* inputData, const Shape& inputShape,
    163                                  const int8_t* filterData, const Shape& filterShape,
    164                                  const float* filterScales, const int32_t* biasData,
    165                                  const Shape& biasShape, int32_t padding_left,
    166                                  int32_t padding_right, int32_t padding_top, int32_t padding_bottom,
    167                                  int32_t stride_width, int32_t stride_height, int32_t numGroups,
    168                                  int32_t activation, uint8_t* outputData, const Shape& outputShape);
    169 
    170 bool channelShuffleGeneric(const uint8_t* inputData, const Shape& inputShape, int32_t numGroups,
    171                            int32_t axis, uint8_t* outputData, const Shape& outputShape);
    172 }  // namespace nn
    173 }  // namespace android
    174 #endif  // ANDROID_ML_NN_COMMON_OPERATIONS_H
    175