1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef ANDROID_ML_NN_COMMON_OPERATIONS_H 18 #define ANDROID_ML_NN_COMMON_OPERATIONS_H 19 20 #include "operations/BidirectionalSequenceLSTM.h" 21 #include "operations/Cast.h" 22 #include "operations/EmbeddingLookup.h" 23 #include "operations/ExpandDims.h" 24 #include "operations/HashtableLookup.h" 25 #include "operations/LSHProjection.h" 26 #include "operations/LSTM.h" 27 #include "operations/MaximumMinimum.h" 28 #include "operations/Multinomial.h" 29 #include "operations/Pow.h" 30 #include "operations/QuantizedLSTM.h" 31 #include "operations/RNN.h" 32 #include "operations/SVDF.h" 33 #include "operations/Tile.h" 34 #include "operations/TopK_V2.h" 35 36 #include <stddef.h> 37 38 #include <cstdint> 39 #include <vector> 40 41 namespace android { 42 namespace nn { 43 44 struct Shape; 45 46 bool floorFloat16(const _Float16* inputData, _Float16* outputData, const Shape& shape); 47 bool floorFloat32(const float* inputData, float* outputData, const Shape& shape); 48 49 bool depthwiseConvFloat16(const _Float16* inputData, const Shape& inputShape, 50 const _Float16* filterData, const Shape& filterShape, 51 const _Float16* biasData, const Shape& biasShape, int32_t paddingLeft, 52 int32_t paddingRight, int32_t paddingTop, int32_t paddingBottom, 53 int32_t strideWidth, int32_t strideHeight, int32_t dilationWidthFactor, 54 int32_t dilationHeightFactor, int32_t depthMultiplier, int32_t activation, 55 _Float16* outputData, const Shape& outputShape); 56 bool depthwiseConvFloat32(const float* inputData, const Shape& inputShape, const float* filterData, 57 const Shape& filterShape, const float* biasData, const Shape& biasShape, 58 int32_t paddingLeft, int32_t paddingRight, int32_t paddingTop, 59 int32_t paddingBottom, int32_t strideWidth, int32_t strideHeight, 60 int32_t dilationWidthFactor, int32_t dilationHeightFactor, 61 int32_t depthMultiplier, int32_t activation, float* outputData, 62 const Shape& outputShape); 63 bool depthwiseConvQuant8(const uint8_t* inputData, const Shape& inputShape, 64 const uint8_t* filterData, const Shape& filterShape, 65 const int32_t* biasData, const Shape& biasShape, int32_t paddingLeft, 66 int32_t paddingRight, int32_t paddingTop, int32_t paddingBottom, 67 int32_t strideWidth, int32_t strideHeight, int32_t dilationWidthFactor, 68 int32_t dilationHeightFactor, int32_t depthMultiplier, int32_t activation, 69 uint8_t* outputData, const Shape& outputShape); 70 bool depthwiseConvQuant8PerChannel(const uint8_t* inputData, const Shape& inputShape, 71 const int8_t* filterData, const Shape& filterShape, 72 const float* filterScales, const int32_t* biasData, 73 const Shape& biasShape, int32_t paddingLeft, 74 int32_t paddingRight, int32_t paddingTop, int32_t paddingBottom, 75 int32_t strideWidth, int32_t strideHeight, 76 int32_t dilationWidthFactor, int32_t dilationHeightFactor, 77 int32_t depthMultiplier, int32_t activation, uint8_t* outputData, 78 const Shape& outputShape); 79 80 bool localResponseNormFloat16(const _Float16* inputData, const Shape& inputShape, int32_t radius, 81 float bias, float alpha, float beta, int32_t axis, 82 _Float16* outputData, const Shape& outputShape); 83 bool localResponseNormFloat32(const float* inputData, const Shape& inputShape, int32_t radius, 84 float bias, float alpha, float beta, int32_t axis, float* outputData, 85 const Shape& outputShape); 86 87 bool copyData(const void* inputData, const Shape& inputShape, void* outputData, 88 const Shape& outputShape); 89 90 template <typename T> 91 bool depthToSpaceGeneric(const T* inputData, const Shape& inputShape, int32_t blockSize, 92 T* outputData, const Shape& outputShape); 93 template <typename T> 94 bool spaceToDepthGeneric(const T* inputData, const Shape& inputShape, int32_t blockSize, 95 T* outputData, const Shape& outputShape); 96 97 template <typename T> 98 bool padGeneric(const T* inputData, const Shape& inputShape, const int32_t* paddings, T pad_value, 99 T* outputData, const Shape& outputShape); 100 101 template <typename T> 102 bool batchToSpaceGeneric(const T* inputData, const Shape& inputShape, const int32_t* blockSize, 103 T* outputData, const Shape& outputShape); 104 105 template <typename T> 106 bool spaceToBatchGeneric(const T* inputData, const Shape& inputShape, const int32_t* blockSize, 107 const int32_t* padding, const Shape& paddingShape, T* outputData, 108 const Shape& outputShape); 109 110 bool meanFloat16(_Float16* inputData, const Shape& inputShape, const int32_t* axis, 111 const Shape& axisShape, bool keepDims, _Float16* outputData, 112 const Shape& outputShape); 113 template <typename T, typename U> 114 bool meanGeneric(T* inputData, const Shape& inputShape, const int32_t* axis, const Shape& axisShape, 115 bool keepDims, T* outputData, const Shape& outputShape); 116 117 bool stridedSliceGeneric(const uint8_t* inputData, const Shape& inputShape, 118 const int32_t* beginData, const int32_t* endData, 119 const int32_t* stridesData, int32_t beginMask, int32_t endMask, 120 int32_t shrinkAxisMask, uint8_t* outputData, const Shape& outputShape); 121 122 bool argMinMaxGeneric(const uint8_t* inputData, const Shape& inputShape, int32_t axis, 123 bool isArgMin, uint8_t* outputData, const Shape& outputShape); 124 125 bool splitFloat16(const _Float16* inputData, const Shape& inputShape, int32_t axis, 126 const std::vector<_Float16*>* outputDataPtrs, 127 const std::vector<Shape>& outputShapes); 128 129 bool splitFloat32(const float* inputData, const Shape& inputShape, const int32_t axis, 130 const std::vector<float*>* outputDataPtrs, 131 const std::vector<Shape>& outputShapes); 132 133 bool splitInt32(const int32_t* inputData, const Shape& inputShape, const int32_t axis, 134 const std::vector<int32_t*>* outputDataPtrs, 135 const std::vector<Shape>& outputShapes); 136 137 bool splitQuant8(const uint8_t* inputData, const Shape& inputShape, const int32_t axis, 138 const std::vector<uint8_t*>* outputDataPtrs, 139 const std::vector<Shape>& outputShapes); 140 141 bool groupedConvFloat16(const _Float16* inputData, const Shape& inputShape, 142 const _Float16* filterData, const Shape& filterShape, 143 const _Float16* biasData, const Shape& biasShape, int32_t numGroups, 144 int32_t padding_left, int32_t padding_right, int32_t padding_top, 145 int32_t padding_bottom, int32_t stride_width, int32_t stride_height, 146 int32_t activation, _Float16* outputData, const Shape& outputShape); 147 148 bool groupedConvFloat32(const float* inputData, const Shape& inputShape, const float* filterData, 149 const Shape& filterShape, const float* biasData, const Shape& biasShape, 150 int32_t numGroups, int32_t padding_left, int32_t padding_right, 151 int32_t padding_top, int32_t padding_bottom, int32_t stride_width, 152 int32_t stride_height, int32_t activation, float* outputData, 153 const Shape& outputShape); 154 155 bool groupedConvQuant8(const uint8_t* inputData, const Shape& inputShape, const uint8_t* filterData, 156 const Shape& filterShape, const int32_t* biasData, const Shape& biasShape, 157 int32_t numGroups, int32_t padding_left, int32_t padding_right, 158 int32_t padding_top, int32_t padding_bottom, int32_t stride_width, 159 int32_t stride_height, int32_t activation, uint8_t* outputData, 160 const Shape& outputShape); 161 162 bool groupedConvQuant8PerChannel(const uint8_t* inputData, const Shape& inputShape, 163 const int8_t* filterData, const Shape& filterShape, 164 const float* filterScales, const int32_t* biasData, 165 const Shape& biasShape, int32_t padding_left, 166 int32_t padding_right, int32_t padding_top, int32_t padding_bottom, 167 int32_t stride_width, int32_t stride_height, int32_t numGroups, 168 int32_t activation, uint8_t* outputData, const Shape& outputShape); 169 170 bool channelShuffleGeneric(const uint8_t* inputData, const Shape& inputShape, int32_t numGroups, 171 int32_t axis, uint8_t* outputData, const Shape& outputShape); 172 } // namespace nn 173 } // namespace android 174 #endif // ANDROID_ML_NN_COMMON_OPERATIONS_H 175