Home | History | Annotate | Download | only in eight_bit_int_gemm
      1 // Copyright 2015 Google Inc. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 #ifndef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
     16 #define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
     17 #endif
     18 #include "eight_bit_int_gemm.h"
     19 
     20 #include <memory>
     21 
     22 // gemmlowp symbols should have hidden visibility.
     23 // currently this is ensured in the build system by
     24 // passing -finlines-visibility-hidden. TODO: it would be
     25 // safer to hardcode it here with some #pragma's.
     26 #include "../public/gemmlowp.h"
     27 
     28 // Define GEMMLOWP_USE_META_FASTPATH in order to use the fastpath ARM/NEON
     29 // code. This code path consists of a number of meta-programmed, automatically
     30 // generated GEMM kernels that are suitable for some sizes of input matrices.
     31 // Due to the fact that the generated code relies heavily on loop unrolling,
     32 // inling and currying of runtime parameters the size of the generated binary
     33 // is quite significant (approx. 200kb) which might be prohibitive in
     34 // low-memory situations.
     35 
     36 #if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)
     37 #include "../meta/legacy_multi_thread_gemm.h"
     38 #else
     39 
     40 #if defined(GEMMLOWP_USE_META_FASTPATH)
     41 #warning "META fast path turned on without NEON!"
     42 #endif
     43 
     44 #endif
     45 
     46 namespace gemmlowp {
     47 namespace eight_bit_int_gemm {
     48 namespace {
     49 
// To be used as template parameter for GlobalLock.
// GlobalLock<EightBitIntGemmLockId> is the global lock
// on EightBitIntGemm entry points, protecting
// EightBitIntGemm's global state.
// NOTE(review): the entry points below actually serialize via
// ScopedLock(GlobalMutexes::EightBitIntGemm()); this tag type looks like a
// legacy remnant of an older locking scheme — confirm before removing.
struct EightBitIntGemmLockId;

// Global state: consists of one global GemmContext instance.
// Lazily created by GetOrCreateGlobalContext() and only touched while the
// EightBitIntGemm global mutex is held by the public entry points.
GemmContext* global_context;
     58 
     59 GemmContext* GetOrCreateGlobalContext() {
     60   if (!global_context) {
     61     global_context = new GemmContext;
     62   }
     63   return global_context;
     64 }
     65 
     66 void DestroyGlobalContext() {
     67   delete global_context;
     68   global_context = nullptr;
     69 }
     70 
// Core uint8-output GEMM dispatch: wraps the raw a/b/c pointers in MatrixMap
// views with the storage orders selected by the template parameters, then
// invokes the templated Gemm entry point for the requested bit depth.
// The a/b/c naming of the public API is translated to gemmlowp's
// lhs/rhs/result vocabulary.
template <bool transpose_a, bool transpose_b, bool transpose_c>
void EightBitIntGemmImpl(GemmContext* context, int m, int n, int k,
                         const std::uint8_t* a, std::int32_t a_offset, int lda,
                         const std::uint8_t* b, std::int32_t b_offset, int ldb,
                         std::uint8_t* c, std::int32_t c_offset,
                         std::int32_t c_mult_int, std::int32_t c_shift, int ldc,
                         BitDepthSetting bit_depth) {
  // Rename the a/b/c parameters into lhs/rhs/result terms.
  const int lhs_offset = a_offset;
  const int rhs_offset = b_offset;
  const int result_offset = c_offset;
  const int result_mult_int = c_mult_int;
  const int result_shift = c_shift;

  // A transposed operand is interpreted as row major, otherwise column major.
  static const MapOrder ResultOrder =
      transpose_c ? MapOrder::RowMajor : MapOrder::ColMajor;
  static const MapOrder LhsOrder =
      transpose_a ? MapOrder::RowMajor : MapOrder::ColMajor;
  static const MapOrder RhsOrder =
      transpose_b ? MapOrder::RowMajor : MapOrder::ColMajor;

  MatrixMap<const std::uint8_t, LhsOrder> lhs(a, m, k, lda);
  MatrixMap<const std::uint8_t, RhsOrder> rhs(b, k, n, ldb);
  MatrixMap<std::uint8_t, ResultOrder> result(c, m, n, ldc);

  // Expand one case per supported bit-depth setting; unknown settings abort.
  switch (bit_depth) {
#define GEMMLOWP_HANDLE_BIT_DEPTH(BIT_DEPTH_SETTING, BIT_DEPTH_PARAMS)     \
  case BitDepthSetting::BIT_DEPTH_SETTING:                                 \
    Gemm<std::uint8_t, BIT_DEPTH_PARAMS>(                                  \
        context, lhs, rhs, &result, lhs_offset, rhs_offset, result_offset, \
        result_mult_int, result_shift);                                    \
    return;
    GEMMLOWP_HANDLE_BIT_DEPTH(A8B8, DefaultL8R8BitDepthParams)
    GEMMLOWP_HANDLE_BIT_DEPTH(A5B7, DefaultL7R5BitDepthParams)
    default:
      abort();
#undef GEMMLOWP_HANDLE_BIT_DEPTH
  }
}
    109 
// Like EightBitIntGemmImpl, but writes raw int32 accumulators to c instead of
// requantized uint8 values: it runs GemmWithOutputPipeline with an empty
// output pipeline, so no offset/multiplier/shift is applied to the result.
// Used by the float-output entry point below, which rescales afterwards.
template <bool transpose_a, bool transpose_b, bool transpose_c>
void EightBitIntGemmInt32Impl(GemmContext* context, int m, int n, int k,
                              const std::uint8_t* a, std::int32_t a_offset,
                              int lda, const std::uint8_t* b,
                              std::int32_t b_offset, int ldb, std::int32_t* c,
                              int ldc, BitDepthSetting bit_depth) {
  const int lhs_offset = a_offset;
  const int rhs_offset = b_offset;

  // A transposed operand is interpreted as row major, otherwise column major.
  static const MapOrder ResultOrder =
      transpose_c ? MapOrder::RowMajor : MapOrder::ColMajor;
  static const MapOrder LhsOrder =
      transpose_a ? MapOrder::RowMajor : MapOrder::ColMajor;
  static const MapOrder RhsOrder =
      transpose_b ? MapOrder::RowMajor : MapOrder::ColMajor;

  MatrixMap<const std::uint8_t, LhsOrder> lhs(a, m, k, lda);
  MatrixMap<const std::uint8_t, RhsOrder> rhs(b, k, n, ldb);
  MatrixMap<std::int32_t, ResultOrder> result(c, m, n, ldc);

  // Empty output pipeline: the int32 accumulators are stored unmodified.
  auto empty_pipeline = std::make_tuple();

  switch (bit_depth) {
#define GEMMLOWP_HANDLE_BIT_DEPTH_INT32(BIT_DEPTH_SETTING, BIT_DEPTH_PARAMS) \
  case BitDepthSetting::BIT_DEPTH_SETTING:                                   \
    GemmWithOutputPipeline<std::uint8_t, std::int32_t, BIT_DEPTH_PARAMS>(    \
        context, lhs, rhs, &result, lhs_offset, rhs_offset, empty_pipeline); \
    return;
    GEMMLOWP_HANDLE_BIT_DEPTH_INT32(A8B8, DefaultL8R8BitDepthParams)
    GEMMLOWP_HANDLE_BIT_DEPTH_INT32(A5B7, DefaultL7R5BitDepthParams)
    default:
      abort();
#undef GEMMLOWP_HANDLE_BIT_DEPTH_INT32
  }
}
    145 
    146 class Scratch {
    147  public:
    148   Scratch() : buffer_(), buffer_32_(nullptr), size_(0) {}
    149 
    150   void AssureSize(std::int32_t required_size) {
    151     if (size_ >= required_size) {
    152       return;
    153     }
    154     buffer_.reset(new std::uint8_t[required_size + 32]);
    155     buffer_32_ =
    156         buffer_.get() +
    157         ((32 - (reinterpret_cast<uintptr_t>(buffer_.get()) % 32)) % 32);
    158     assert((reinterpret_cast<uintptr_t>(buffer_32_) % 32) == 0);
    159     size_ = required_size;
    160   }
    161 
    162   void Clear() {
    163     buffer_.reset(nullptr);
    164     buffer_32_ = nullptr;
    165     size_ = 0;
    166   }
    167 
    168   std::uint8_t* buffer() { return buffer_32_; }
    169 
    170  private:
    171   std::unique_ptr<std::uint8_t[]> buffer_;
    172   std::uint8_t* buffer_32_;
    173   std::int32_t size_;
    174 };
    175 
// Global state: lazily-created scratch buffer shared by the entry points,
// which access it only under the EightBitIntGemm global mutex.
Scratch* global_scratch = nullptr;
    177 
    178 Scratch* GetOrCreateGlobalScratch() {
    179   if (global_scratch == nullptr) {
    180     global_scratch = new Scratch();
    181   }
    182   return global_scratch;
    183 }
    184 
    185 void DestroyGlobalScratch() {
    186   delete global_scratch;
    187   global_scratch = nullptr;
    188 }
    189 
    190 #if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)
    191 
    192 bool IsRowMajorOrVector(bool transpose, int stride, int rows, int cols) {
    193   // Is it row major and nicely packed?
    194   if (transpose && stride == cols) {
    195     return true;
    196   }
    197 
    198   // Is it a one row vector? (a vector is both row and column major)
    199   if (rows == 1) {
    200     return true;
    201   }
    202 
    203   return false;
    204 }
    205 
    206 bool IsColumnMajorOrVector(bool transpose, int stride, int rows, int cols) {
    207   // Is it column major and nicely packed?
    208   if (!transpose && stride == rows) {
    209     return true;
    210   }
    211 
    212   // Is it a one column vector? (a vector is both row and column major)
    213   if (cols == 1) {
    214     return true;
    215   }
    216 
    217   return false;
    218 }
    219 
    220 bool CanHandleMetaFastpath(bool transpose_a, bool transpose_b, bool transpose_c,
    221                            int m, int n, int k, int lda, int ldb, int ldc,
    222                            BitDepthSetting depth_setting) {
    223   // Meta fastpath only supports 8bit x 8bit and k between 8 and 2048.
    224   if (depth_setting != BitDepthSetting::A8B8 || k < 8 || k > 2048) {
    225     return false;
    226   }
    227 
    228   // The first operand needs to be a row major matrix or a vector.
    229   if (!IsRowMajorOrVector(transpose_a, lda, m, k)) {
    230     return false;
    231   }
    232 
    233   // The second operand needs to be a column major matrix or a vector.
    234   if (!IsColumnMajorOrVector(transpose_b, ldb, k, n)) {
    235     return false;
    236   }
    237 
    238   // The result can either be a row major matrix, a column major matrix or
    239   // a vector.
    240   if (IsRowMajorOrVector(transpose_c, ldc, m, n)) {
    241     return true;
    242   }
    243 
    244   if (IsColumnMajorOrVector(transpose_c, ldc, m, n)) {
    245     return true;
    246   }
    247 
    248   return false;
    249 }
    250 
// Assure enough scratch memory is allocated and run the fast path gemm.
//
// When the result is not row major (or a vector), the operands, their
// dimensions, and their offsets are swapped so that the transposed product
// is computed instead -- presumably because the meta kernels emit row-major
// output, so the swap yields a column-major result; confirm against the
// meta kernel implementation.
void MetaGemmQuantized8Bit(GemmContext* context, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, int m, int n, int k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t sum_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, bool result_transpose,
                           std::int32_t result_stride, std::uint8_t* result) {
  Scratch* scratch = GetOrCreateGlobalScratch();
  const std::int32_t max_num_threads = context->max_num_threads();
  if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) {
    scratch->AssureSize(meta::gemm_q8_scratch(m, n, k, max_num_threads));
    meta::multi_thread_gemm_q8(context->workers_pool(), max_num_threads,
                               scratch->buffer(), lhs, rhs, m, n, k, lhs_offset,
                               rhs_offset, sum_offset, multiplicative_offset,
                               shift, result);
  } else {
    // Swapped operands/dimensions/offsets: computes the transposed product.
    scratch->AssureSize(meta::gemm_q8_scratch(n, m, k, max_num_threads));
    meta::multi_thread_gemm_q8(context->workers_pool(), max_num_threads,
                               scratch->buffer(), rhs, lhs, n, m, k, rhs_offset,
                               lhs_offset, sum_offset, multiplicative_offset,
                               shift, result);
  }
}
    275 
// Assure enough scratch memory is allocated and run the 8bit to float fast
// path gemm.
//
// As in MetaGemmQuantized8Bit above, a non-row-major result is handled by
// swapping the operands, dimensions, and offsets to compute the transposed
// product.
void MetaGemmFloat(GemmContext* context, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, int m, int n, int k,
                   std::int32_t lhs_offset, std::int32_t rhs_offset,
                   float result_offset, bool result_transpose,
                   std::int32_t result_stride, float* result) {
  Scratch* scratch = GetOrCreateGlobalScratch();
  const std::int32_t max_num_threads = context->max_num_threads();
  if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) {
    scratch->AssureSize(meta::gemm_f_scratch(m, n, k, max_num_threads));
    meta::multi_thread_gemm_f(context->workers_pool(), max_num_threads,
                              scratch->buffer(), lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, result);
  } else {
    // Swapped operands/dimensions/offsets: computes the transposed product.
    scratch->AssureSize(meta::gemm_f_scratch(n, m, k, max_num_threads));
    meta::multi_thread_gemm_f(context->workers_pool(), max_num_threads,
                              scratch->buffer(), rhs, lhs, n, m, k, rhs_offset,
                              lhs_offset, result_offset, result);
  }
}
    297 
    298 #endif
    299 
    300 }  // end anonymous namespace
    301 
    302 // Public interface entry points
    303 
// Public entry point: 8-bit GEMM with uint8 output. The c_offset, c_mult_int
// and c_shift parameters are forwarded to Gemm, which requantizes the int32
// accumulators down to uint8.
// Calls are serialized by the EightBitIntGemm global mutex.
void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c,
                     int m, int n, int k, const std::uint8_t* a,
                     std::int32_t a_offset, int lda, const std::uint8_t* b,
                     std::int32_t b_offset, int ldb, std::uint8_t* c,
                     std::int32_t c_offset, std::int32_t c_mult_int,
                     std::int32_t c_shift, int ldc, BitDepthSetting bit_depth) {
  ScopedLock sl(GlobalMutexes::EightBitIntGemm());
  GemmContext* context = GetOrCreateGlobalContext();

#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)
  // Take the NEON meta fastpath when the operand layouts allow it.
  if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda,
                            ldb, ldc, bit_depth)) {
    MetaGemmQuantized8Bit(context, a, b, m, n, k, a_offset, b_offset, c_offset,
                          c_mult_int, c_shift, transpose_c, ldc, c);
    return;
  }
#endif

  // Dispatch the runtime transpose flags to the corresponding template
  // instantiation; exactly one of the eight cases below matches.
#define GEMMLOWP_HANDLE_CASE(ta, tb, tc)                                    \
  if (transpose_a == ta && transpose_b == tb && transpose_c == tc) {        \
    EightBitIntGemmImpl<ta, tb, tc>(context, m, n, k, a, a_offset, lda, b,  \
                                    b_offset, ldb, c, c_offset, c_mult_int, \
                                    c_shift, ldc, bit_depth);               \
  }

  GEMMLOWP_HANDLE_CASE(false, false, false)
  GEMMLOWP_HANDLE_CASE(false, false, true)
  GEMMLOWP_HANDLE_CASE(false, true, false)
  GEMMLOWP_HANDLE_CASE(false, true, true)
  GEMMLOWP_HANDLE_CASE(true, false, false)
  GEMMLOWP_HANDLE_CASE(true, false, true)
  GEMMLOWP_HANDLE_CASE(true, true, false)
  GEMMLOWP_HANDLE_CASE(true, true, true)

#undef GEMMLOWP_HANDLE_CASE
}
    340 
// Public entry point: 8-bit GEMM with float output.
// Runs the integer GEMM into an int32 scratch buffer, then converts each
// accumulator to float scaled by c_offset (c = int32_result * c_offset).
// Calls are serialized by the EightBitIntGemm global mutex.
void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c,
                     int m, int n, int k, const std::uint8_t* a,
                     std::int32_t a_offset, std::int32_t lda,
                     const std::uint8_t* b, std::int32_t b_offset,
                     std::int32_t ldb, float* c, float c_offset,
                     std::int32_t ldc, BitDepthSetting bit_depth) {
  ScopedLock sl(GlobalMutexes::EightBitIntGemm());
  GemmContext* context = GetOrCreateGlobalContext();

#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)
  // Take the NEON meta fastpath when the operand layouts allow it.
  if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda,
                            ldb, ldc, bit_depth)) {
    MetaGemmFloat(context, a, b, m, n, k, a_offset, b_offset, c_offset,
                  transpose_c, ldc, c);
    return;
  }
#endif

  // TODO(maciekc): implement a float output stage, get rid of scratch memory.
  // The scratch buffer holds the int32 result: m rows of ldc entries when
  // the result is row major (transpose_c), n columns of ldc otherwise.
  Scratch* scratch = GetOrCreateGlobalScratch();
  if (transpose_c) {
    scratch->AssureSize(m * ldc * sizeof(std::int32_t));
  } else {
    scratch->AssureSize(n * ldc * sizeof(std::int32_t));
  }
  std::int32_t* temp_c = reinterpret_cast<std::int32_t*>(scratch->buffer());

  // Dispatch the runtime transpose flags to the matching template
  // instantiation, producing raw int32 accumulators in temp_c.
#define GEMMLOWP_HANDLE_INT32_CASE(ta, tb, tc)                               \
  if (transpose_a == ta && transpose_b == tb && transpose_c == tc) {         \
    EightBitIntGemmInt32Impl<ta, tb, tc>(context, m, n, k, a, a_offset, lda, \
                                         b, b_offset, ldb, temp_c, ldc,      \
                                         bit_depth);                         \
  }

  GEMMLOWP_HANDLE_INT32_CASE(false, false, false)
  GEMMLOWP_HANDLE_INT32_CASE(false, false, true)
  GEMMLOWP_HANDLE_INT32_CASE(false, true, false)
  GEMMLOWP_HANDLE_INT32_CASE(false, true, true)
  GEMMLOWP_HANDLE_INT32_CASE(true, false, false)
  GEMMLOWP_HANDLE_INT32_CASE(true, false, true)
  GEMMLOWP_HANDLE_INT32_CASE(true, true, false)
  GEMMLOWP_HANDLE_INT32_CASE(true, true, true)

#undef GEMMLOWP_HANDLE_INT32_CASE

  // Convert the int32 accumulators to float, scaling by c_offset.
  if (transpose_c) {
    // Row major.
    for (int i = 0; i < m; ++i) {
      float* dest_row = c + i * ldc;
      std::int32_t* src_row = temp_c + i * ldc;
      for (int j = 0; j < n; ++j) {
        dest_row[j] = static_cast<float>(src_row[j]) * c_offset;
      }
    }
  } else {
    // Column major.
    for (int i = 0; i < n; ++i) {
      float* dest_column = c + i * ldc;
      std::int32_t* src_column = temp_c + i * ldc;
      for (int j = 0; j < m; ++j) {
        dest_column[j] = static_cast<float>(src_column[j]) * c_offset;
      }
    }
  }
}
    406 
    407 void SetMaxNumThreads(int n) {
    408   ScopedLock sl(GlobalMutexes::EightBitIntGemm());
    409   GemmContext* context = GetOrCreateGlobalContext();
    410   context->set_max_num_threads(n);
    411 }
    412 
    413 void FreePersistentResources() {
    414   ScopedLock sl(GlobalMutexes::EightBitIntGemm());
    415   DestroyGlobalContext();
    416   DestroyGlobalScratch();
    417 }
    418 
    419 }  // namespace eight_bit_int_gemm
    420 }  // namespace gemmlowp
    421