// Copyright 2015 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "eight_bit_int_gemm.h"

#include <memory>

// gemmlowp symbols should have hidden visibility.
// Currently this is ensured in the build system by
// passing -fvisibility=hidden. TODO: it would be
// safer to hardcode it here with some #pragma's.
#include "../public/gemmlowp.h"

// Define GEMMLOWP_USE_META_FASTPATH in order to use the fastpath ARM/NEON
// code. This code path consists of a number of meta-programmed, automatically
// generated GEMM kernels that are suitable for some sizes of input matrices.
// Because the generated code relies heavily on loop unrolling, inlining and
// currying of runtime parameters, the generated binary is quite large
// (approx. 200kB), which might be prohibitive in low-memory situations.

#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32)
#include "../meta/multi_thread_gemm.h"
#endif

namespace gemmlowp {
namespace eight_bit_int_gemm {
namespace {

// To be used as template parameter for GlobalLock.
// GlobalLock<EightBitIntGemmLockId> is the global lock
// on EightBitIntGemm entry points, protecting
// EightBitIntGemm's global state.
struct EightBitIntGemmLockId;

// Global state: consists of one global GemmContext instance.
GemmContext* global_context;

GemmContext* GetOrCreateGlobalContext() {
  if (!global_context) {
    global_context = new GemmContext;
  }
  return global_context;
}

void DestroyGlobalContext() {
  delete global_context;
  global_context = nullptr;
}

template <bool transpose_a, bool transpose_b, bool transpose_c>
void EightBitIntGemmImpl(GemmContext* context, int m, int n, int k,
                         const std::uint8_t* a, std::int32_t a_offset, int lda,
                         const std::uint8_t* b, std::int32_t b_offset, int ldb,
                         std::uint8_t* c, std::int32_t c_offset,
                         std::int32_t c_mult_int, std::int32_t c_shift, int ldc,
                         BitDepthSetting bit_depth) {
  const int lhs_offset = a_offset;
  const int rhs_offset = b_offset;
  const int result_offset = c_offset;
  const int result_mult_int = c_mult_int;
  const int result_shift = c_shift;

  static const MapOrder ResultOrder =
      transpose_c ? MapOrder::RowMajor : MapOrder::ColMajor;
  static const MapOrder LhsOrder =
      transpose_a ? MapOrder::RowMajor : MapOrder::ColMajor;
  static const MapOrder RhsOrder =
      transpose_b ? MapOrder::RowMajor : MapOrder::ColMajor;

  MatrixMap<const std::uint8_t, LhsOrder> lhs(a, m, k, lda);
  MatrixMap<const std::uint8_t, RhsOrder> rhs(b, k, n, ldb);
  MatrixMap<std::uint8_t, ResultOrder> result(c, m, n, ldc);

  switch (bit_depth) {
#define GEMMLOWP_HANDLE_BIT_DEPTH(BIT_DEPTH_SETTING, BIT_DEPTH_PARAMS)     \
  case BitDepthSetting::BIT_DEPTH_SETTING:                                 \
    Gemm<std::uint8_t, BIT_DEPTH_PARAMS>(                                  \
        context, lhs, rhs, &result, lhs_offset, rhs_offset, result_offset, \
        result_mult_int, result_shift);                                    \
    return;
    GEMMLOWP_HANDLE_BIT_DEPTH(A8B8, DefaultL8R8BitDepthParams)
    GEMMLOWP_HANDLE_BIT_DEPTH(A5B7, DefaultL7R5BitDepthParams)
    default:
      abort();
#undef GEMMLOWP_HANDLE_BIT_DEPTH
  }
}
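
// For readers unfamiliar with the macro-based dispatch above: each
// GEMMLOWP_HANDLE_BIT_DEPTH(...) invocation expands to one switch case. As an
// illustration only (not additional code), the A8B8 case expands to roughly:
//
//   case BitDepthSetting::A8B8:
//     Gemm<std::uint8_t, DefaultL8R8BitDepthParams>(
//         context, lhs, rhs, &result, lhs_offset, rhs_offset, result_offset,
//         result_mult_int, result_shift);
//     return;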

template <bool transpose_a, bool transpose_b, bool transpose_c>
void EightBitIntGemmInt32Impl(GemmContext* context, int m, int n, int k,
                              const std::uint8_t* a, std::int32_t a_offset,
                              int lda, const std::uint8_t* b,
                              std::int32_t b_offset, int ldb, std::int32_t* c,
                              int ldc, BitDepthSetting bit_depth) {
  const int lhs_offset = a_offset;
  const int rhs_offset = b_offset;

  static const MapOrder ResultOrder =
      transpose_c ? MapOrder::RowMajor : MapOrder::ColMajor;
  static const MapOrder LhsOrder =
      transpose_a ? MapOrder::RowMajor : MapOrder::ColMajor;
  static const MapOrder RhsOrder =
      transpose_b ? MapOrder::RowMajor : MapOrder::ColMajor;

  MatrixMap<const std::uint8_t, LhsOrder> lhs(a, m, k, lda);
  MatrixMap<const std::uint8_t, RhsOrder> rhs(b, k, n, ldb);
  MatrixMap<std::int32_t, ResultOrder> result(c, m, n, ldc);

  auto empty_pipeline = std::make_tuple();

  switch (bit_depth) {
#define GEMMLOWP_HANDLE_BIT_DEPTH_INT32(BIT_DEPTH_SETTING, BIT_DEPTH_PARAMS) \
  case BitDepthSetting::BIT_DEPTH_SETTING:                                   \
    GemmWithOutputPipeline<std::uint8_t, std::int32_t, BIT_DEPTH_PARAMS>(    \
        context, lhs, rhs, &result, lhs_offset, rhs_offset, empty_pipeline); \
    return;
    GEMMLOWP_HANDLE_BIT_DEPTH_INT32(A8B8, DefaultL8R8BitDepthParams)
    GEMMLOWP_HANDLE_BIT_DEPTH_INT32(A5B7, DefaultL7R5BitDepthParams)
    default:
      abort();
#undef GEMMLOWP_HANDLE_BIT_DEPTH_INT32
  }
}
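
// Note on the int32 variant above: with an empty output pipeline,
// GemmWithOutputPipeline writes out the raw int32 accumulators. In other
// words (a sketch of the intended semantics, see the gemmlowp documentation
// for the authoritative description):
//
//   result(i, j) = sum over l of (a(i, l) + a_offset) * (b(l, j) + b_offset)
//
// The float entry point later in this file then scales these raw
// accumulators by c_offset when converting to float.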

class Scratch {
 public:
  Scratch() : buffer_(), size_(0) {}

  void AssureSize(std::int32_t required_size) {
    if (size_ >= required_size) {
      return;
    }
    buffer_.reset(new std::uint8_t[required_size]);
    size_ = required_size;
  }

  void Clear() {
    buffer_.reset(nullptr);
    size_ = 0;
  }

  std::uint8_t* buffer() { return buffer_.get(); }

 private:
  std::unique_ptr<std::uint8_t[]> buffer_;
  std::int32_t size_;
};

Scratch* global_scratch = nullptr;

Scratch* GetOrCreateGlobalScratch() {
  if (global_scratch == nullptr) {
    global_scratch = new Scratch();
  }
  return global_scratch;
}

void DestroyGlobalScratch() {
  delete global_scratch;
  global_scratch = nullptr;
}

#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32)

bool IsRowMajorOrVector(bool transpose, int stride, int rows, int cols) {
  // Is it row major and nicely packed?
  if (transpose && stride == cols) {
    return true;
  }

  // Is it a one row vector? (a vector is both row and column major)
  if (rows == 1) {
    return true;
  }

  return false;
}

bool IsColumnMajorOrVector(bool transpose, int stride, int rows, int cols) {
  // Is it column major and nicely packed?
  if (!transpose && stride == rows) {
    return true;
  }

  // Is it a one column vector? (a vector is both row and column major)
  if (cols == 1) {
    return true;
  }

  return false;
}

bool CanHandleMetaFastpath(bool transpose_a, bool transpose_b, bool transpose_c,
                           int m, int n, int k, int lda, int ldb, int ldc,
                           BitDepthSetting depth_setting) {
  // Meta fastpath only supports 8bit x 8bit and k up to 2048.
  if (depth_setting != BitDepthSetting::A8B8 || k > 2048) {
    return false;
  }

  // The first operand needs to be a row major matrix or a vector.
  if (!IsRowMajorOrVector(transpose_a, lda, m, k)) {
    return false;
  }

  // The second operand needs to be a column major matrix or a vector.
  if (!IsColumnMajorOrVector(transpose_b, ldb, k, n)) {
    return false;
  }

  // The result can either be a row major matrix, a column major matrix or
  // a vector.
  if (IsRowMajorOrVector(transpose_c, ldc, m, n)) {
    return true;
  }

  if (IsColumnMajorOrVector(transpose_c, ldc, m, n)) {
    return true;
  }

  return false;
}

// Assure enough scratch memory is allocated and run the fast path gemm.
void MetaGemmQuantized8Bit(GemmContext* context, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, int m, int n, int k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t sum_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, bool result_transpose,
                           std::int32_t result_stride, std::uint8_t* result) {
  Scratch* scratch = GetOrCreateGlobalScratch();
  if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) {
    scratch->AssureSize(
        meta::gemm_q8_scratch(m, n, k, context->max_num_threads()));
    meta::multi_thread_gemm_q8(
        context->workers_pool(), context->max_num_threads(), scratch->buffer(),
        lhs, rhs, m, n, k, lhs_offset, rhs_offset, sum_offset,
        multiplicative_offset, shift, result);
  } else {
    scratch->AssureSize(
        meta::gemm_q8_scratch(n, m, k, context->max_num_threads()));
    meta::multi_thread_gemm_q8(
        context->workers_pool(), context->max_num_threads(), scratch->buffer(),
        rhs, lhs, n, m, k, rhs_offset, lhs_offset, sum_offset,
        multiplicative_offset, shift, result);
  }
}
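
// A note on the operand swap in the else-branches of the wrappers above and
// below: when the requested result is column major, the wrappers compute the
// transposed product with the operands exchanged. Since C = A * B implies
// C^T = B^T * A^T, writing C^T in row-major order into the result buffer
// yields exactly the same memory layout as writing C in column-major order,
// so no extra transposition pass is needed.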

// Assure enough scratch memory is allocated and run the 8bit to float fast
// path gemm.
void MetaGemmFloat(GemmContext* context, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, int m, int n, int k,
                   std::int32_t lhs_offset, std::int32_t rhs_offset,
                   float result_offset, bool result_transpose,
                   std::int32_t result_stride, float* result) {
  Scratch* scratch = GetOrCreateGlobalScratch();
  if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) {
    scratch->AssureSize(
        meta::gemm_f_scratch(m, n, k, context->max_num_threads()));
    meta::multi_thread_gemm_f(
        context->workers_pool(), context->max_num_threads(), scratch->buffer(),
        lhs, rhs, m, n, k, lhs_offset, rhs_offset, result_offset, result);
  } else {
    scratch->AssureSize(
        meta::gemm_f_scratch(n, m, k, context->max_num_threads()));
    meta::multi_thread_gemm_f(
        context->workers_pool(), context->max_num_threads(), scratch->buffer(),
        rhs, lhs, n, m, k, rhs_offset, lhs_offset, result_offset, result);
  }
}

#endif

}  // end anonymous namespace

// Public interface entry points

void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c,
                     int m, int n, int k, const std::uint8_t* a,
                     std::int32_t a_offset, int lda, const std::uint8_t* b,
                     std::int32_t b_offset, int ldb, std::uint8_t* c,
                     std::int32_t c_offset, std::int32_t c_mult_int,
                     std::int32_t c_shift, int ldc, BitDepthSetting bit_depth) {
  AutoGlobalLock<EightBitIntGemmLockId> lock;
  GemmContext* context = GetOrCreateGlobalContext();

#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32)
  if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda,
                            ldb, ldc, bit_depth)) {
    MetaGemmQuantized8Bit(context, a, b, m, n, k, a_offset, b_offset, c_offset,
                          c_mult_int, c_shift, transpose_c, ldc, c);
    return;
  }
#endif

#define GEMMLOWP_HANDLE_CASE(ta, tb, tc)                                    \
  if (transpose_a == ta && transpose_b == tb && transpose_c == tc) {        \
    EightBitIntGemmImpl<ta, tb, tc>(context, m, n, k, a, a_offset, lda, b,  \
                                    b_offset, ldb, c, c_offset, c_mult_int, \
                                    c_shift, ldc, bit_depth);               \
  }

  GEMMLOWP_HANDLE_CASE(false, false, false)
  GEMMLOWP_HANDLE_CASE(false, false, true)
  GEMMLOWP_HANDLE_CASE(false, true, false)
  GEMMLOWP_HANDLE_CASE(false, true, true)
  GEMMLOWP_HANDLE_CASE(true, false, false)
  GEMMLOWP_HANDLE_CASE(true, false, true)
  GEMMLOWP_HANDLE_CASE(true, true, false)
  GEMMLOWP_HANDLE_CASE(true, true, true)

#undef GEMMLOWP_HANDLE_CASE
}
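
// A minimal usage sketch for the uint8 entry point above. All dimensions,
// strides and quantization parameters here are example values only; a real
// caller derives a_offset/b_offset/c_offset/c_mult_int/c_shift from its own
// quantization scheme.
//
//   // C(3x5) = A(3x4) * B(4x5), all matrices column-major, uint8 storage.
//   std::vector<std::uint8_t> a(3 * 4), b(4 * 5), c(3 * 5);
//   // ... fill a and b with quantized values ...
//   gemmlowp::eight_bit_int_gemm::EightBitIntGemm(
//       /*transpose_a=*/false, /*transpose_b=*/false, /*transpose_c=*/false,
//       /*m=*/3, /*n=*/5, /*k=*/4,
//       a.data(), /*a_offset=*/-128, /*lda=*/3,
//       b.data(), /*b_offset=*/-128, /*ldb=*/4,
//       c.data(), /*c_offset=*/0, /*c_mult_int=*/1, /*c_shift=*/8, /*ldc=*/3,
//       gemmlowp::eight_bit_int_gemm::BitDepthSetting::A8B8);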

void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c,
                     int m, int n, int k, const std::uint8_t* a,
                     std::int32_t a_offset, std::int32_t lda,
                     const std::uint8_t* b, std::int32_t b_offset,
                     std::int32_t ldb, float* c, float c_offset,
                     std::int32_t ldc, BitDepthSetting bit_depth) {
  AutoGlobalLock<EightBitIntGemmLockId> lock;
  GemmContext* context = GetOrCreateGlobalContext();

#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32)
  if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda,
                            ldb, ldc, bit_depth)) {
    MetaGemmFloat(context, a, b, m, n, k, a_offset, b_offset, c_offset,
                  transpose_c, ldc, c);
    return;
  }
#endif

  // TODO(maciekc): implement a float output stage, get rid of scratch memory.
  Scratch* scratch = GetOrCreateGlobalScratch();
  if (transpose_c) {
    scratch->AssureSize(m * ldc * sizeof(std::int32_t));
  } else {
    scratch->AssureSize(n * ldc * sizeof(std::int32_t));
  }
  std::int32_t* temp_c = reinterpret_cast<std::int32_t*>(scratch->buffer());

#define GEMMLOWP_HANDLE_INT32_CASE(ta, tb, tc)                               \
  if (transpose_a == ta && transpose_b == tb && transpose_c == tc) {         \
    EightBitIntGemmInt32Impl<ta, tb, tc>(context, m, n, k, a, a_offset, lda, \
                                         b, b_offset, ldb, temp_c, ldc,      \
                                         bit_depth);                         \
  }

  GEMMLOWP_HANDLE_INT32_CASE(false, false, false)
  GEMMLOWP_HANDLE_INT32_CASE(false, false, true)
  GEMMLOWP_HANDLE_INT32_CASE(false, true, false)
  GEMMLOWP_HANDLE_INT32_CASE(false, true, true)
  GEMMLOWP_HANDLE_INT32_CASE(true, false, false)
  GEMMLOWP_HANDLE_INT32_CASE(true, false, true)
  GEMMLOWP_HANDLE_INT32_CASE(true, true, false)
  GEMMLOWP_HANDLE_INT32_CASE(true, true, true)

#undef GEMMLOWP_HANDLE_INT32_CASE

  if (transpose_c) {
    // Row major.
    for (int i = 0; i < m; ++i) {
      float* dest_row = c + i * ldc;
      std::int32_t* src_row = temp_c + i * ldc;
      for (int j = 0; j < n; ++j) {
        dest_row[j] = static_cast<float>(src_row[j]) * c_offset;
      }
    }
  } else {
    // Column major.
    for (int i = 0; i < n; ++i) {
      float* dest_column = c + i * ldc;
      std::int32_t* src_column = temp_c + i * ldc;
      for (int j = 0; j < m; ++j) {
        dest_column[j] = static_cast<float>(src_column[j]) * c_offset;
      }
    }
  }
}

void SetMaxNumThreads(int n) {
  AutoGlobalLock<EightBitIntGemmLockId> lock;
  GemmContext* context = GetOrCreateGlobalContext();
  context->set_max_num_threads(n);
}

void FreePersistentResources() {
  AutoGlobalLock<EightBitIntGemmLockId> lock;
  DestroyGlobalContext();
  DestroyGlobalScratch();
}

}  // namespace eight_bit_int_gemm
}  // namespace gemmlowp
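
// Typical lifecycle of this API, as implied by the entry points above (a
// sketch, not additional functionality): all entry points serialize on the
// EightBitIntGemmLockId global lock and lazily create the global GemmContext
// and Scratch, so a caller only needs to
//
//   gemmlowp::eight_bit_int_gemm::SetMaxNumThreads(4);         // optional
//   // ... any number of EightBitIntGemm(...) calls ...
//   gemmlowp::eight_bit_int_gemm::FreePersistentResources();   // when done
//
// FreePersistentResources() releases the global GemmContext and the global
// scratch buffer; a later call recreates them on demand.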