// Copyright 2015 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
#define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
#endif
#include "eight_bit_int_gemm.h"

#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <tuple>

// gemmlowp symbols should have hidden visibility.
// Currently this is ensured in the build system by
// passing -fvisibility-inlines-hidden. TODO: it would be
// safer to hardcode it here with some #pragma's.
#include "../public/gemmlowp.h"

// Define GEMMLOWP_USE_META_FASTPATH in order to use the fastpath ARM/NEON
// code. This code path consists of a number of meta-programmed, automatically
// generated GEMM kernels that are suitable for some sizes of input matrices.
// Because the generated code relies heavily on loop unrolling, inlining and
// currying of runtime parameters, the size of the generated binary is quite
// significant (approx. 200kb), which might be prohibitive in low-memory
// situations.

#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)
#include "../meta/legacy_multi_thread_gemm.h"
#else

#if defined(GEMMLOWP_USE_META_FASTPATH)
#warning "META fast path turned on without NEON!"
#endif

#endif

namespace gemmlowp {
namespace eight_bit_int_gemm {
namespace {

// To be used as template parameter for GlobalLock.
// GlobalLock<EightBitIntGemmLockId> is the global lock
// on EightBitIntGemm entry points, protecting
// EightBitIntGemm's global state.
struct EightBitIntGemmLockId;

// Global state: consists of one global GemmContext instance.
GemmContext* global_context;

GemmContext* GetOrCreateGlobalContext() {
  if (!global_context) {
    global_context = new GemmContext;
  }
  return global_context;
}

void DestroyGlobalContext() {
  delete global_context;
  global_context = nullptr;
}

template <bool transpose_a, bool transpose_b, bool transpose_c>
void EightBitIntGemmImpl(GemmContext* context, int m, int n, int k,
                         const std::uint8_t* a, std::int32_t a_offset, int lda,
                         const std::uint8_t* b, std::int32_t b_offset, int ldb,
                         std::uint8_t* c, std::int32_t c_offset,
                         std::int32_t c_mult_int, std::int32_t c_shift, int ldc,
                         BitDepthSetting bit_depth) {
  const int lhs_offset = a_offset;
  const int rhs_offset = b_offset;
  const int result_offset = c_offset;
  const int result_mult_int = c_mult_int;
  const int result_shift = c_shift;

  static const MapOrder ResultOrder =
      transpose_c ? MapOrder::RowMajor : MapOrder::ColMajor;
  static const MapOrder LhsOrder =
      transpose_a ? MapOrder::RowMajor : MapOrder::ColMajor;
  static const MapOrder RhsOrder =
      transpose_b ? MapOrder::RowMajor : MapOrder::ColMajor;

  MatrixMap<const std::uint8_t, LhsOrder> lhs(a, m, k, lda);
  MatrixMap<const std::uint8_t, RhsOrder> rhs(b, k, n, ldb);
  MatrixMap<std::uint8_t, ResultOrder> result(c, m, n, ldc);

  switch (bit_depth) {
#define GEMMLOWP_HANDLE_BIT_DEPTH(BIT_DEPTH_SETTING, BIT_DEPTH_PARAMS)     \
  case BitDepthSetting::BIT_DEPTH_SETTING:                                 \
    Gemm<std::uint8_t, BIT_DEPTH_PARAMS>(                                  \
        context, lhs, rhs, &result, lhs_offset, rhs_offset, result_offset, \
        result_mult_int, result_shift);                                    \
    return;
    GEMMLOWP_HANDLE_BIT_DEPTH(A8B8, DefaultL8R8BitDepthParams)
    GEMMLOWP_HANDLE_BIT_DEPTH(A5B7, DefaultL7R5BitDepthParams)
    default:
      abort();
#undef GEMMLOWP_HANDLE_BIT_DEPTH
  }
}

template <bool transpose_a, bool transpose_b, bool transpose_c>
void EightBitIntGemmInt32Impl(GemmContext* context, int m, int n, int k,
                              const std::uint8_t* a, std::int32_t a_offset,
                              int lda, const std::uint8_t* b,
                              std::int32_t b_offset, int ldb, std::int32_t* c,
                              int ldc, BitDepthSetting bit_depth) {
  const int lhs_offset = a_offset;
  const int rhs_offset = b_offset;

  static const MapOrder ResultOrder =
      transpose_c ? MapOrder::RowMajor : MapOrder::ColMajor;
  static const MapOrder LhsOrder =
      transpose_a ? MapOrder::RowMajor : MapOrder::ColMajor;
  static const MapOrder RhsOrder =
      transpose_b ? MapOrder::RowMajor : MapOrder::ColMajor;

  MatrixMap<const std::uint8_t, LhsOrder> lhs(a, m, k, lda);
  MatrixMap<const std::uint8_t, RhsOrder> rhs(b, k, n, ldb);
  MatrixMap<std::int32_t, ResultOrder> result(c, m, n, ldc);

  auto empty_pipeline = std::make_tuple();

  switch (bit_depth) {
#define GEMMLOWP_HANDLE_BIT_DEPTH_INT32(BIT_DEPTH_SETTING, BIT_DEPTH_PARAMS) \
  case BitDepthSetting::BIT_DEPTH_SETTING:                                   \
    GemmWithOutputPipeline<std::uint8_t, std::int32_t, BIT_DEPTH_PARAMS>(    \
        context, lhs, rhs, &result, lhs_offset, rhs_offset, empty_pipeline); \
    return;
    GEMMLOWP_HANDLE_BIT_DEPTH_INT32(A8B8, DefaultL8R8BitDepthParams)
    GEMMLOWP_HANDLE_BIT_DEPTH_INT32(A5B7, DefaultL7R5BitDepthParams)
    default:
      abort();
#undef GEMMLOWP_HANDLE_BIT_DEPTH_INT32
  }
}

class Scratch {
 public:
  Scratch() : buffer_(), buffer_32_(nullptr), size_(0) {}

  void AssureSize(std::int32_t required_size) {
    if (size_ >= required_size) {
      return;
    }
    // Over-allocate by 32 bytes so that the usable pointer, buffer_32_, can
    // be rounded up to the next 32-byte boundary.
    buffer_.reset(new std::uint8_t[required_size + 32]);
    buffer_32_ =
        buffer_.get() +
        ((32 - (reinterpret_cast<std::uintptr_t>(buffer_.get()) % 32)) % 32);
    assert((reinterpret_cast<std::uintptr_t>(buffer_32_) % 32) == 0);
    size_ = required_size;
  }

  void Clear() {
    buffer_.reset(nullptr);
    buffer_32_ = nullptr;
    size_ = 0;
  }

  std::uint8_t* buffer() { return buffer_32_; }

 private:
  std::unique_ptr<std::uint8_t[]> buffer_;
  std::uint8_t* buffer_32_;
  std::int32_t size_;
};

Scratch* global_scratch = nullptr;

Scratch* GetOrCreateGlobalScratch() {
  if (global_scratch == nullptr) {
    global_scratch = new Scratch();
  }
  return global_scratch;
}

void DestroyGlobalScratch() {
  delete global_scratch;
  global_scratch = nullptr;
}

#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)

bool IsRowMajorOrVector(bool transpose, int stride, int rows, int cols) {
  // Is it row major and nicely packed?
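  // (A transposed operand is stored row major here; a stride equal to the
  // column count means consecutive rows are contiguous, with no padding.)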
  if (transpose && stride == cols) {
    return true;
  }

  // Is it a one row vector? (a vector is both row and column major)
  if (rows == 1) {
    return true;
  }

  return false;
}

bool IsColumnMajorOrVector(bool transpose, int stride, int rows, int cols) {
  // Is it column major and nicely packed?
  if (!transpose && stride == rows) {
    return true;
  }

  // Is it a one column vector? (a vector is both row and column major)
  if (cols == 1) {
    return true;
  }

  return false;
}

bool CanHandleMetaFastpath(bool transpose_a, bool transpose_b, bool transpose_c,
                           int m, int n, int k, int lda, int ldb, int ldc,
                           BitDepthSetting depth_setting) {
  // Meta fastpath only supports 8bit x 8bit and k between 8 and 2048.
  if (depth_setting != BitDepthSetting::A8B8 || k < 8 || k > 2048) {
    return false;
  }

  // The first operand needs to be a row major matrix or a vector.
  if (!IsRowMajorOrVector(transpose_a, lda, m, k)) {
    return false;
  }

  // The second operand needs to be a column major matrix or a vector.
  if (!IsColumnMajorOrVector(transpose_b, ldb, k, n)) {
    return false;
  }

  // The result can either be a row major matrix, a column major matrix or
  // a vector.
  if (IsRowMajorOrVector(transpose_c, ldc, m, n)) {
    return true;
  }

  if (IsColumnMajorOrVector(transpose_c, ldc, m, n)) {
    return true;
  }

  return false;
}

// Assure enough scratch memory is allocated and run the fast path gemm.
void MetaGemmQuantized8Bit(GemmContext* context, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, int m, int n, int k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t sum_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, bool result_transpose,
                           std::int32_t result_stride, std::uint8_t* result) {
  Scratch* scratch = GetOrCreateGlobalScratch();
  const std::int32_t max_num_threads = context->max_num_threads();
  if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) {
    scratch->AssureSize(meta::gemm_q8_scratch(m, n, k, max_num_threads));
    meta::multi_thread_gemm_q8(context->workers_pool(), max_num_threads,
                               scratch->buffer(), lhs, rhs, m, n, k, lhs_offset,
                               rhs_offset, sum_offset, multiplicative_offset,
                               shift, result);
  } else {
    scratch->AssureSize(meta::gemm_q8_scratch(n, m, k, max_num_threads));
    meta::multi_thread_gemm_q8(context->workers_pool(), max_num_threads,
                               scratch->buffer(), rhs, lhs, n, m, k, rhs_offset,
                               lhs_offset, sum_offset, multiplicative_offset,
                               shift, result);
  }
}

// Assure enough scratch memory is allocated and run the 8bit to float fast
// path gemm.
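// As in MetaGemmQuantized8Bit above, a column major result is handled by
// computing the transposed product, i.e. with the lhs/rhs operands and the
// m/n dimensions swapped.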
void MetaGemmFloat(GemmContext* context, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, int m, int n, int k,
                   std::int32_t lhs_offset, std::int32_t rhs_offset,
                   float result_offset, bool result_transpose,
                   std::int32_t result_stride, float* result) {
  Scratch* scratch = GetOrCreateGlobalScratch();
  const std::int32_t max_num_threads = context->max_num_threads();
  if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) {
    scratch->AssureSize(meta::gemm_f_scratch(m, n, k, max_num_threads));
    meta::multi_thread_gemm_f(context->workers_pool(), max_num_threads,
                              scratch->buffer(), lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, result);
  } else {
    scratch->AssureSize(meta::gemm_f_scratch(n, m, k, max_num_threads));
    meta::multi_thread_gemm_f(context->workers_pool(), max_num_threads,
                              scratch->buffer(), rhs, lhs, n, m, k, rhs_offset,
                              lhs_offset, result_offset, result);
  }
}

#endif

}  // end anonymous namespace

// Public interface entry points

void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c,
                     int m, int n, int k, const std::uint8_t* a,
                     std::int32_t a_offset, int lda, const std::uint8_t* b,
                     std::int32_t b_offset, int ldb, std::uint8_t* c,
                     std::int32_t c_offset, std::int32_t c_mult_int,
                     std::int32_t c_shift, int ldc, BitDepthSetting bit_depth) {
  ScopedLock sl(GlobalMutexes::EightBitIntGemm());
  GemmContext* context = GetOrCreateGlobalContext();

#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)
  if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda,
                            ldb, ldc, bit_depth)) {
    MetaGemmQuantized8Bit(context, a, b, m, n, k, a_offset, b_offset, c_offset,
                          c_mult_int, c_shift, transpose_c, ldc, c);
    return;
  }
#endif

#define GEMMLOWP_HANDLE_CASE(ta, tb, tc)                                    \
  if (transpose_a == ta && transpose_b == tb && transpose_c == tc) {        \
    EightBitIntGemmImpl<ta, tb, tc>(context, m, n, k, a, a_offset, lda, b,  \
                                    b_offset, ldb, c, c_offset, c_mult_int, \
                                    c_shift, ldc, bit_depth);               \
  }

  GEMMLOWP_HANDLE_CASE(false, false, false)
  GEMMLOWP_HANDLE_CASE(false, false, true)
  GEMMLOWP_HANDLE_CASE(false, true, false)
  GEMMLOWP_HANDLE_CASE(false, true, true)
  GEMMLOWP_HANDLE_CASE(true, false, false)
  GEMMLOWP_HANDLE_CASE(true, false, true)
  GEMMLOWP_HANDLE_CASE(true, true, false)
  GEMMLOWP_HANDLE_CASE(true, true, true)

#undef GEMMLOWP_HANDLE_CASE
}

void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c,
                     int m, int n, int k, const std::uint8_t* a,
                     std::int32_t a_offset, std::int32_t lda,
                     const std::uint8_t* b, std::int32_t b_offset,
                     std::int32_t ldb, float* c, float c_offset,
                     std::int32_t ldc, BitDepthSetting bit_depth) {
  ScopedLock sl(GlobalMutexes::EightBitIntGemm());
  GemmContext* context = GetOrCreateGlobalContext();

#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)
  if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda,
                            ldb, ldc, bit_depth)) {
    MetaGemmFloat(context, a, b, m, n, k, a_offset, b_offset, c_offset,
                  transpose_c, ldc, c);
    return;
  }
#endif

  // TODO(maciekc): implement a float output stage, get rid of scratch memory.
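  // Until that output stage exists, run the int32 GEMM into the global
  // scratch buffer and convert to float afterwards, scaling each accumulated
  // value by c_offset (see the loops at the end of this function).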
  Scratch* scratch = GetOrCreateGlobalScratch();
  if (transpose_c) {
    scratch->AssureSize(m * ldc * sizeof(std::int32_t));
  } else {
    scratch->AssureSize(n * ldc * sizeof(std::int32_t));
  }
  std::int32_t* temp_c = reinterpret_cast<std::int32_t*>(scratch->buffer());

#define GEMMLOWP_HANDLE_INT32_CASE(ta, tb, tc)                               \
  if (transpose_a == ta && transpose_b == tb && transpose_c == tc) {         \
    EightBitIntGemmInt32Impl<ta, tb, tc>(context, m, n, k, a, a_offset, lda, \
                                         b, b_offset, ldb, temp_c, ldc,      \
                                         bit_depth);                         \
  }

  GEMMLOWP_HANDLE_INT32_CASE(false, false, false)
  GEMMLOWP_HANDLE_INT32_CASE(false, false, true)
  GEMMLOWP_HANDLE_INT32_CASE(false, true, false)
  GEMMLOWP_HANDLE_INT32_CASE(false, true, true)
  GEMMLOWP_HANDLE_INT32_CASE(true, false, false)
  GEMMLOWP_HANDLE_INT32_CASE(true, false, true)
  GEMMLOWP_HANDLE_INT32_CASE(true, true, false)
  GEMMLOWP_HANDLE_INT32_CASE(true, true, true)

#undef GEMMLOWP_HANDLE_INT32_CASE

  if (transpose_c) {
    // Row major.
    for (int i = 0; i < m; ++i) {
      float* dest_row = c + i * ldc;
      std::int32_t* src_row = temp_c + i * ldc;
      for (int j = 0; j < n; ++j) {
        dest_row[j] = static_cast<float>(src_row[j]) * c_offset;
      }
    }
  } else {
    // Column major.
    for (int i = 0; i < n; ++i) {
      float* dest_column = c + i * ldc;
      std::int32_t* src_column = temp_c + i * ldc;
      for (int j = 0; j < m; ++j) {
        dest_column[j] = static_cast<float>(src_column[j]) * c_offset;
      }
    }
  }
}

void SetMaxNumThreads(int n) {
  ScopedLock sl(GlobalMutexes::EightBitIntGemm());
  GemmContext* context = GetOrCreateGlobalContext();
  context->set_max_num_threads(n);
}

void FreePersistentResources() {
  ScopedLock sl(GlobalMutexes::EightBitIntGemm());
  DestroyGlobalContext();
  DestroyGlobalScratch();
}

}  // namespace eight_bit_int_gemm
}  // namespace gemmlowp
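
// Illustrative usage sketch (not part of the library). It assumes packed,
// column major operands; m, n, k, the pointers and the offsets are
// placeholders supplied by the caller. FreePersistentResources() releases
// the global context and scratch memory once no more GEMMs are needed.
//
//   using namespace gemmlowp::eight_bit_int_gemm;
//   EightBitIntGemm(/*transpose_a=*/false, /*transpose_b=*/false,
//                   /*transpose_c=*/false, m, n, k,
//                   a, a_offset, /*lda=*/m,
//                   b, b_offset, /*ldb=*/k,
//                   c, c_offset, c_mult_int, c_shift, /*ldc=*/m,
//                   BitDepthSetting::A8B8);
//   FreePersistentResources();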