1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef ANDROID_ML_NN_COMMON_OPERATIONS_H 18 #define ANDROID_ML_NN_COMMON_OPERATIONS_H 19 20 #include "operations/EmbeddingLookup.h" 21 #include "operations/HashtableLookup.h" 22 #include "operations/LSHProjection.h" 23 #include "operations/LSTM.h" 24 #include "operations/RNN.h" 25 #include "operations/SVDF.h" 26 27 #include <stddef.h> 28 29 #include <cstdint> 30 #include <vector> 31 32 namespace android { 33 namespace nn { 34 35 struct Shape; 36 37 bool addFloat32(const float* in1, const Shape& shape1, 38 const float* in2, const Shape& shape2, 39 int32_t activation, 40 float* out, const Shape& shapeOut); 41 bool addQuant8(const uint8_t* in1, const Shape& shape1, 42 const uint8_t* in2, const Shape& shape2, 43 int32_t activation, 44 uint8_t* out, const Shape& shapeOut); 45 46 bool mulFloat32(const float* in1, const Shape& shape1, 47 const float* in2, const Shape& shape2, 48 int32_t activation, 49 float* out, const Shape& shapeOut); 50 bool mulQuant8(const uint8_t* in1, const Shape& shape1, 51 const uint8_t* in2, const Shape& shape2, 52 int32_t activation, 53 uint8_t* out, const Shape& shapeOut); 54 55 bool floorFloat32(const float* inputData, 56 float* outputData, 57 const Shape& shape); 58 59 bool dequantizeQuant8ToFloat32(const uint8_t* inputData, 60 float* outputData, 61 const Shape& shape); 62 63 bool depthwiseConvFloat32(const float* inputData, const Shape& inputShape, 64 const float* filterData, const Shape& filterShape, 65 const float* biasData, const Shape& biasShape, 66 int32_t padding_left, int32_t padding_right, 67 int32_t padding_top, int32_t padding_bottom, 68 int32_t stride_width, int32_t stride_height, 69 int32_t depth_multiplier, int32_t activation, 70 float* outputData, const Shape& outputShape); 71 bool depthwiseConvQuant8(const uint8_t* inputData, const Shape& inputShape, 72 const uint8_t* filterData, const Shape& filterShape, 73 const int32_t* biasData, const Shape& biasShape, 74 int32_t padding_left, int32_t padding_right, 75 int32_t padding_top, int32_t padding_bottom, 76 int32_t stride_width, int32_t stride_height, 77 int32_t depth_multiplier, int32_t activation, 78 uint8_t* outputData, const Shape& outputShape); 79 80 bool convFloat32(const float* inputData, const Shape& inputShape, 81 const float* filterData, const Shape& filterShape, 82 const float* biasData, const Shape& biasShape, 83 int32_t padding_left, int32_t padding_right, 84 int32_t padding_top, int32_t padding_bottom, 85 int32_t stride_width, int32_t stride_height, 86 int32_t activation, 87 float* outputData, const Shape& outputShape); 88 bool convQuant8(const uint8_t* inputData, const Shape& inputShape, 89 const uint8_t* filterData, const Shape& filterShape, 90 const int32_t* biasData, const Shape& biasShape, 91 int32_t padding_left, int32_t padding_right, 92 int32_t padding_top, int32_t padding_bottom, 93 int32_t stride_width, int32_t stride_height, 94 int32_t activation, 95 uint8_t* outputData, const Shape& outputShape); 96 97 bool averagePoolFloat32(const float* inputData, const Shape& inputShape, 98 int32_t padding_left, int32_t padding_right, 99 int32_t padding_top, int32_t padding_bottom, 100 int32_t stride_width, int32_t stride_height, 101 int32_t filter_width, int32_t filter_height, int32_t activation, 102 float* outputData, const Shape& outputShape); 103 bool averagePoolQuant8(const uint8_t* inputData, const Shape& inputShape, 104 int32_t padding_left, int32_t padding_right, 105 int32_t padding_top, int32_t padding_bottom, 106 int32_t stride_width, int32_t stride_height, 107 int32_t filter_width, int32_t filter_height, int32_t activation, 108 uint8_t* outputData, const Shape& outputShape); 109 bool l2PoolFloat32(const float* inputData, const Shape& inputShape, 110 int32_t padding_left, int32_t padding_right, 111 int32_t padding_top, int32_t padding_bottom, 112 int32_t stride_width, int32_t stride_height, 113 int32_t filter_width, int32_t filter_height, int32_t activation, 114 float* outputData, const Shape& outputShape); 115 bool maxPoolFloat32(const float* inputData, const Shape& inputShape, 116 int32_t padding_left, int32_t padding_right, 117 int32_t padding_top, int32_t padding_bottom, 118 int32_t stride_width, int32_t stride_height, 119 int32_t filter_width, int32_t filter_height, int32_t activation, 120 float* outputData, const Shape& outputShape); 121 bool maxPoolQuant8(const uint8_t* inputData, const Shape& inputShape, 122 int32_t padding_left, int32_t padding_right, 123 int32_t padding_top, int32_t padding_bottom, 124 int32_t stride_width, int32_t stride_height, 125 int32_t filter_width, int32_t filter_height, int32_t activation, 126 uint8_t* outputData, const Shape& outputShape); 127 128 bool reluFloat32(const float* inputData, const Shape& inputShape, 129 float* outputData, const Shape& outputShape); 130 bool relu1Float32(const float* inputData, const Shape& inputShape, 131 float* outputData, const Shape& outputShape); 132 bool relu6Float32(const float* inputData, const Shape& inputShape, 133 float* outputData, const Shape& outputShape); 134 bool tanhFloat32(const float* inputData, const Shape& inputShape, 135 float* outputData, const Shape& outputShape); 136 bool logisticFloat32(const float* inputData, const Shape& inputShape, 137 float* outputData, const Shape& outputShape); 138 bool softmaxFloat32(const float* inputData, const Shape& inputShape, 139 const float beta, 140 float* outputData, const Shape& outputShape); 141 bool reluQuant8(const uint8_t* inputData, const Shape& inputShape, 142 uint8_t* outputData, const Shape& outputShape); 143 bool relu1Quant8(const uint8_t* inputData, const Shape& inputShape, 144 uint8_t* outputData, const Shape& outputShape); 145 bool relu6Quant8(const uint8_t* inputData, const Shape& inputShape, 146 uint8_t* outputData, const Shape& outputShape); 147 bool logisticQuant8(const uint8_t* inputData, const Shape& inputShape, 148 uint8_t* outputData, const Shape& outputShape); 149 bool softmaxQuant8(const uint8_t* inputData, const Shape& inputShape, 150 const float beta, 151 uint8_t* outputData, const Shape& outputShape); 152 153 bool fullyConnectedFloat32(const float* inputData, const Shape& inputShape, 154 const float* weights, const Shape& weightsShape, 155 const float* biasData, const Shape& biasShape, 156 int32_t activation, 157 float* outputData, const Shape& outputShape); 158 bool fullyConnectedQuant8(const uint8_t* inputData, const Shape& inputShape, 159 const uint8_t* weights, const Shape& weightsShape, 160 const int32_t* biasData, const Shape& biasShape, 161 int32_t activation, 162 uint8_t* outputData, const Shape& outputShape); 163 164 bool concatenationFloat32(const std::vector<const float*>& inputDataPtrs, 165 const std::vector<Shape>& inputShapes, int32_t axis, 166 float* outputData, const Shape& outputShape); 167 bool concatenationQuant8(const std::vector<const uint8_t*>& inputDataPtrs, 168 const std::vector<Shape>& inputShapes, int32_t axis, 169 uint8_t* outputData, const Shape& outputShape); 170 171 bool l2normFloat32(const float* inputData, const Shape& inputShape, 172 float* outputData, const Shape& outputShape); 173 bool l2normQuant8(const uint8_t* inputData, const Shape& inputShape, 174 uint8_t* outputData, const Shape& outputShape); 175 bool localResponseNormFloat32(const float* inputData, const Shape& inputShape, 176 int32_t radius, float bias, float alpha, float beta, 177 float* outputData, const Shape& outputShape); 178 179 bool reshapeGeneric(const void* inputData, const Shape& inputShape, 180 void* outputData, const Shape& outputShape); 181 182 bool resizeBilinearFloat32(const float* inputData, 183 const Shape& inputShape, 184 float* outputData, 185 const Shape& outputShape); 186 187 bool depthToSpaceGeneric(const uint8_t* inputData, const Shape& inputShape, 188 int32_t blockSize, 189 uint8_t* outputData, const Shape& outputShape); 190 191 bool spaceToDepthGeneric(const uint8_t* inputData, const Shape& inputShape, 192 int32_t blockSize, 193 uint8_t* outputData, const Shape& outputShape); 194 195 bool padGeneric(const uint8_t* inputData, const Shape& inputShape, 196 const int32_t* paddings, 197 uint8_t* outputData, const Shape& outputShape); 198 199 bool batchToSpaceGeneric(const uint8_t* inputData, const Shape& inputShape, 200 const int32_t* blockSize, 201 uint8_t* outputData, const Shape& outputShape); 202 203 bool spaceToBatchGeneric(const uint8_t* inputData, const Shape& inputShape, 204 const int32_t* blockSize, 205 const int32_t* padding, const Shape& paddingShape, 206 uint8_t* outputData, const Shape& outputShape); 207 208 bool subFloat32(const float* in1, const Shape& shape1, 209 const float* in2, const Shape& shape2, 210 int32_t activation, 211 float* out, const Shape& shapeOut); 212 213 bool squeezeGeneric(const void* inputData, const Shape& inputShape, 214 void* outputData, const Shape& outputShape); 215 216 bool divFloat32(const float* in1, const Shape& shape1, 217 const float* in2, const Shape& shape2, 218 int32_t activation, 219 float* out, const Shape& shapeOut); 220 221 bool transposeGeneric(const uint8_t* inputData, const Shape& inputShape, 222 const int32_t* perm, const Shape& permShape, 223 uint8_t* outputData, const Shape& outputShape); 224 225 bool meanGeneric(const uint8_t* inputData, const Shape& inputShape, 226 const int32_t* axis, const Shape& axisShape, bool keepDims, 227 uint8_t* outputData, const Shape& outputShape); 228 229 bool stridedSliceGeneric(const uint8_t* inputData, const Shape& inputShape, 230 const int32_t* beginData, const int32_t* endData, 231 const int32_t* stridesData, 232 int32_t beginMask, int32_t endMask, int32_t shrinkAxisMask, 233 uint8_t* outputData, const Shape& outputShape); 234 } // namespace nn 235 } // namespace android 236 #endif // ANDROID_ML_NN_COMMON_OPERATIONS_H 237