// Copyright 2015 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "eight_bit_int_gemm.h"

#include <memory>

// gemmlowp symbols should have hidden visibility.
// Currently this is ensured in the build system by
// passing -fvisibility-inlines-hidden. TODO: it would be
// safer to hardcode it here with some #pragma's.
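//
// A hedged sketch of what that hardcoding could look like (GCC/Clang only,
// illustrative and not enabled here): wrapping the gemmlowp header in a
// visibility pragma so that the symbols it defines default to hidden.
//
//   #pragma GCC visibility push(hidden)
//   #include "../public/gemmlowp.h"
//   #pragma GCC visibility pop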
#include "../public/gemmlowp.h"

// Define GEMMLOWP_USE_META_FASTPATH in order to use the fastpath ARM/NEON
// code. This code path consists of a number of meta-programmed, automatically
// generated GEMM kernels that are suitable for some sizes of input matrices.
// Because the generated code relies heavily on loop unrolling, inlining and
// currying of runtime parameters, the size of the generated binary is quite
// significant (approx. 200kb), which might be prohibitive in low-memory
// situations.
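//
// For illustration, a build could opt in with a compiler define such as
// -DGEMMLOWP_USE_META_FASTPATH (the exact build flag is up to the embedding
// project); the fastpath is additionally gated on GEMMLOWP_NEON_32 below.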

#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32)
#include "../meta/multi_thread_gemm.h"
#endif

namespace gemmlowp {
namespace eight_bit_int_gemm {
namespace {

// To be used as a template parameter for GlobalLock.
// GlobalLock<EightBitIntGemmLockId> is the global lock
// on EightBitIntGemm entry points, protecting
// EightBitIntGemm's global state.
struct EightBitIntGemmLockId;

// Global state: consists of one global GemmContext instance.
GemmContext* global_context;

GemmContext* GetOrCreateGlobalContext() {
  if (!global_context) {
    global_context = new GemmContext;
  }
  return global_context;
}

void DestroyGlobalContext() {
  delete global_context;
  global_context = nullptr;
}

template <bool transpose_a, bool transpose_b, bool transpose_c>
void EightBitIntGemmImpl(GemmContext* context, int m, int n, int k,
                         const std::uint8_t* a, std::int32_t a_offset, int lda,
                         const std::uint8_t* b, std::int32_t b_offset, int ldb,
                         std::uint8_t* c, std::int32_t c_offset,
                         std::int32_t c_mult_int, std::int32_t c_shift, int ldc,
                         BitDepthSetting bit_depth) {
  const int lhs_offset = a_offset;
  const int rhs_offset = b_offset;
  const int result_offset = c_offset;
  const int result_mult_int = c_mult_int;
  const int result_shift = c_shift;

  static const MapOrder ResultOrder =
      transpose_c ? MapOrder::RowMajor : MapOrder::ColMajor;
  static const MapOrder LhsOrder =
      transpose_a ? MapOrder::RowMajor : MapOrder::ColMajor;
  static const MapOrder RhsOrder =
      transpose_b ? MapOrder::RowMajor : MapOrder::ColMajor;

  MatrixMap<const std::uint8_t, LhsOrder> lhs(a, m, k, lda);
  MatrixMap<const std::uint8_t, RhsOrder> rhs(b, k, n, ldb);
  MatrixMap<std::uint8_t, ResultOrder> result(c, m, n, ldc);

  switch (bit_depth) {
#define GEMMLOWP_HANDLE_BIT_DEPTH(BIT_DEPTH_SETTING, BIT_DEPTH_PARAMS)     \
  case BitDepthSetting::BIT_DEPTH_SETTING:                                 \
    Gemm<std::uint8_t, BIT_DEPTH_PARAMS>(                                  \
        context, lhs, rhs, &result, lhs_offset, rhs_offset, result_offset, \
        result_mult_int, result_shift);                                    \
    return;
    GEMMLOWP_HANDLE_BIT_DEPTH(A8B8, DefaultL8R8BitDepthParams)
    GEMMLOWP_HANDLE_BIT_DEPTH(A5B7, DefaultL7R5BitDepthParams)
    default:
      abort();
#undef GEMMLOWP_HANDLE_BIT_DEPTH
  }
}

template <bool transpose_a, bool transpose_b, bool transpose_c>
void EightBitIntGemmInt32Impl(GemmContext* context, int m, int n, int k,
                              const std::uint8_t* a, std::int32_t a_offset,
                              int lda, const std::uint8_t* b,
                              std::int32_t b_offset, int ldb, std::int32_t* c,
                              int ldc, BitDepthSetting bit_depth) {
  const int lhs_offset = a_offset;
  const int rhs_offset = b_offset;

  static const MapOrder ResultOrder =
      transpose_c ? MapOrder::RowMajor : MapOrder::ColMajor;
  static const MapOrder LhsOrder =
      transpose_a ? MapOrder::RowMajor : MapOrder::ColMajor;
  static const MapOrder RhsOrder =
      transpose_b ? MapOrder::RowMajor : MapOrder::ColMajor;

  MatrixMap<const std::uint8_t, LhsOrder> lhs(a, m, k, lda);
  MatrixMap<const std::uint8_t, RhsOrder> rhs(b, k, n, ldb);
  MatrixMap<std::int32_t, ResultOrder> result(c, m, n, ldc);

  auto empty_pipeline = std::make_tuple();

  switch (bit_depth) {
#define GEMMLOWP_HANDLE_BIT_DEPTH_INT32(BIT_DEPTH_SETTING, BIT_DEPTH_PARAMS) \
  case BitDepthSetting::BIT_DEPTH_SETTING:                                   \
    GemmWithOutputPipeline<std::uint8_t, std::int32_t, BIT_DEPTH_PARAMS>(    \
        context, lhs, rhs, &result, lhs_offset, rhs_offset, empty_pipeline); \
    return;
    GEMMLOWP_HANDLE_BIT_DEPTH_INT32(A8B8, DefaultL8R8BitDepthParams)
    GEMMLOWP_HANDLE_BIT_DEPTH_INT32(A5B7, DefaultL7R5BitDepthParams)
    default:
      abort();
#undef GEMMLOWP_HANDLE_BIT_DEPTH_INT32
  }
}

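// Grow-only scratch buffer used by the meta fastpath and by the float
// entry point below. Note that AssureSize discards any existing contents
// when it has to grow the buffer; callers only rely on its capacity.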
class Scratch {
 public:
  Scratch() : buffer_(), size_(0) {}

  void AssureSize(std::int32_t required_size) {
    if (size_ >= required_size) {
      return;
    }
    buffer_.reset(new std::uint8_t[required_size]);
    size_ = required_size;
  }

  void Clear() {
    buffer_.reset(nullptr);
    size_ = 0;
  }

  std::uint8_t* buffer() { return buffer_.get(); }

 private:
  std::unique_ptr<std::uint8_t[]> buffer_;
  std::int32_t size_;
};

Scratch* global_scratch = nullptr;

Scratch* GetOrCreateGlobalScratch() {
  if (global_scratch == nullptr) {
    global_scratch = new Scratch();
  }
  return global_scratch;
}

void DestroyGlobalScratch() {
  delete global_scratch;
  global_scratch = nullptr;
}

#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32)

bool IsRowMajorOrVector(bool transpose, int stride, int rows, int cols) {
  // Is it row major and nicely packed?
  if (transpose && stride == cols) {
    return true;
  }

  // Is it a one-row vector? (A vector is both row major and column major.)
  if (rows == 1) {
    return true;
  }

  return false;
}

bool IsColumnMajorOrVector(bool transpose, int stride, int rows, int cols) {
  // Is it column major and nicely packed?
  if (!transpose && stride == rows) {
    return true;
  }

  // Is it a one-column vector? (A vector is both row major and column major.)
  if (cols == 1) {
    return true;
  }

  return false;
}

bool CanHandleMetaFastpath(bool transpose_a, bool transpose_b, bool transpose_c,
                           int m, int n, int k, int lda, int ldb, int ldc,
                           BitDepthSetting depth_setting) {
  // Meta fastpath only supports 8-bit x 8-bit and k up to 2048.
  if (depth_setting != BitDepthSetting::A8B8 || k > 2048) {
    return false;
  }

  // The first operand needs to be a row major matrix or a vector.
  if (!IsRowMajorOrVector(transpose_a, lda, m, k)) {
    return false;
  }

  // The second operand needs to be a column major matrix or a vector.
  if (!IsColumnMajorOrVector(transpose_b, ldb, k, n)) {
    return false;
  }

  // The result can either be a row major matrix, a column major matrix or
  // a vector.
  if (IsRowMajorOrVector(transpose_c, ldc, m, n)) {
    return true;
  }

  if (IsColumnMajorOrVector(transpose_c, ldc, m, n)) {
    return true;
  }

  return false;
}

// Assure enough scratch memory is allocated and run the fast path gemm.
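// When the result is column major, the operands and dimensions are swapped:
// a column-major m x n result is produced by computing the row-major n x m
// product with rhs and lhs exchanged, using the identity (A*B)^T = B^T * A^T.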
void MetaGemmQuantized8Bit(GemmContext* context, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, int m, int n, int k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t sum_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, bool result_transpose,
                           std::int32_t result_stride, std::uint8_t* result) {
  Scratch* scratch = GetOrCreateGlobalScratch();
  if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) {
    scratch->AssureSize(
        meta::gemm_q8_scratch(m, n, k, context->max_num_threads()));
    meta::multi_thread_gemm_q8(
        context->workers_pool(), context->max_num_threads(), scratch->buffer(),
        lhs, rhs, m, n, k, lhs_offset, rhs_offset, sum_offset,
        multiplicative_offset, shift, result);
  } else {
    scratch->AssureSize(
        meta::gemm_q8_scratch(n, m, k, context->max_num_threads()));
    meta::multi_thread_gemm_q8(
        context->workers_pool(), context->max_num_threads(), scratch->buffer(),
        rhs, lhs, n, m, k, rhs_offset, lhs_offset, sum_offset,
        multiplicative_offset, shift, result);
  }
}

// Assure enough scratch memory is allocated and run the 8-bit to float fast
// path gemm.
void MetaGemmFloat(GemmContext* context, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, int m, int n, int k,
                   std::int32_t lhs_offset, std::int32_t rhs_offset,
                   float result_offset, bool result_transpose,
                   std::int32_t result_stride, float* result) {
  Scratch* scratch = GetOrCreateGlobalScratch();
  if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) {
    scratch->AssureSize(
        meta::gemm_f_scratch(m, n, k, context->max_num_threads()));
    meta::multi_thread_gemm_f(
        context->workers_pool(), context->max_num_threads(), scratch->buffer(),
        lhs, rhs, m, n, k, lhs_offset, rhs_offset, result_offset, result);
  } else {
    scratch->AssureSize(
        meta::gemm_f_scratch(n, m, k, context->max_num_threads()));
    meta::multi_thread_gemm_f(
        context->workers_pool(), context->max_num_threads(), scratch->buffer(),
        rhs, lhs, n, m, k, rhs_offset, lhs_offset, result_offset, result);
  }
}

#endif

}  // end anonymous namespace

// Public interface entry points

void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c,
                     int m, int n, int k, const std::uint8_t* a,
                     std::int32_t a_offset, int lda, const std::uint8_t* b,
                     std::int32_t b_offset, int ldb, std::uint8_t* c,
                     std::int32_t c_offset, std::int32_t c_mult_int,
                     std::int32_t c_shift, int ldc, BitDepthSetting bit_depth) {
  AutoGlobalLock<EightBitIntGemmLockId> lock;
  GemmContext* context = GetOrCreateGlobalContext();

#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32)
  if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda,
                            ldb, ldc, bit_depth)) {
    MetaGemmQuantized8Bit(context, a, b, m, n, k, a_offset, b_offset, c_offset,
                          c_mult_int, c_shift, transpose_c, ldc, c);
    return;
  }
#endif

#define GEMMLOWP_HANDLE_CASE(ta, tb, tc)                                    \
  if (transpose_a == ta && transpose_b == tb && transpose_c == tc) {        \
    EightBitIntGemmImpl<ta, tb, tc>(context, m, n, k, a, a_offset, lda, b,  \
                                    b_offset, ldb, c, c_offset, c_mult_int, \
                                    c_shift, ldc, bit_depth);               \
  }

  GEMMLOWP_HANDLE_CASE(false, false, false)
  GEMMLOWP_HANDLE_CASE(false, false, true)
  GEMMLOWP_HANDLE_CASE(false, true, false)
  GEMMLOWP_HANDLE_CASE(false, true, true)
  GEMMLOWP_HANDLE_CASE(true, false, false)
  GEMMLOWP_HANDLE_CASE(true, false, true)
  GEMMLOWP_HANDLE_CASE(true, true, false)
  GEMMLOWP_HANDLE_CASE(true, true, true)

#undef GEMMLOWP_HANDLE_CASE
}

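// A minimal usage sketch of the entry point above (illustrative only, not
// part of the library): a 2x2 by 2x2 product with column-major operands and
// result, zero offsets, a unit multiplier and no shift. All values are
// placeholders chosen for the example.
//
//   const std::uint8_t a[] = {1, 2, 3, 4};  // 2x2 lhs, column major, lda = 2
//   const std::uint8_t b[] = {5, 6, 7, 8};  // 2x2 rhs, column major, ldb = 2
//   std::uint8_t c[4];                      // 2x2 result, ldc = 2
//   gemmlowp::eight_bit_int_gemm::EightBitIntGemm(
//       false, false, false, 2, 2, 2, a, 0, 2, b, 0, 2, c, 0, 1, 0, 2,
//       gemmlowp::eight_bit_int_gemm::BitDepthSetting::A8B8);
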
void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c,
                     int m, int n, int k, const std::uint8_t* a,
                     std::int32_t a_offset, std::int32_t lda,
                     const std::uint8_t* b, std::int32_t b_offset,
                     std::int32_t ldb, float* c, float c_offset,
                     std::int32_t ldc, BitDepthSetting bit_depth) {
  AutoGlobalLock<EightBitIntGemmLockId> lock;
  GemmContext* context = GetOrCreateGlobalContext();

#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32)
  if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda,
                            ldb, ldc, bit_depth)) {
    MetaGemmFloat(context, a, b, m, n, k, a_offset, b_offset, c_offset,
                  transpose_c, ldc, c);
    return;
  }
#endif

  // TODO(maciekc): implement a float output stage, get rid of scratch memory.
  Scratch* scratch = GetOrCreateGlobalScratch();
  if (transpose_c) {
    scratch->AssureSize(m * ldc * sizeof(std::int32_t));
  } else {
    scratch->AssureSize(n * ldc * sizeof(std::int32_t));
  }
  std::int32_t* temp_c = reinterpret_cast<std::int32_t*>(scratch->buffer());

#define GEMMLOWP_HANDLE_INT32_CASE(ta, tb, tc)                               \
  if (transpose_a == ta && transpose_b == tb && transpose_c == tc) {         \
    EightBitIntGemmInt32Impl<ta, tb, tc>(context, m, n, k, a, a_offset, lda, \
                                         b, b_offset, ldb, temp_c, ldc,      \
                                         bit_depth);                         \
  }

  GEMMLOWP_HANDLE_INT32_CASE(false, false, false)
  GEMMLOWP_HANDLE_INT32_CASE(false, false, true)
  GEMMLOWP_HANDLE_INT32_CASE(false, true, false)
  GEMMLOWP_HANDLE_INT32_CASE(false, true, true)
  GEMMLOWP_HANDLE_INT32_CASE(true, false, false)
  GEMMLOWP_HANDLE_INT32_CASE(true, false, true)
  GEMMLOWP_HANDLE_INT32_CASE(true, true, false)
  GEMMLOWP_HANDLE_INT32_CASE(true, true, true)

#undef GEMMLOWP_HANDLE_INT32_CASE

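  // Convert the int32 accumulators in the scratch buffer into the float
  // destination. Here c_offset acts as a multiplicative scale applied to
  // each accumulated value, not as an additive offset.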
  if (transpose_c) {
    // Row major.
    for (int i = 0; i < m; ++i) {
      float* dest_row = c + i * ldc;
      std::int32_t* src_row = temp_c + i * ldc;
      for (int j = 0; j < n; ++j) {
        dest_row[j] = static_cast<float>(src_row[j]) * c_offset;
      }
    }
  } else {
    // Column major.
    for (int i = 0; i < n; ++i) {
      float* dest_column = c + i * ldc;
      std::int32_t* src_column = temp_c + i * ldc;
      for (int j = 0; j < m; ++j) {
        dest_column[j] = static_cast<float>(src_column[j]) * c_offset;
      }
    }
  }
}

void SetMaxNumThreads(int n) {
  AutoGlobalLock<EightBitIntGemmLockId> lock;
  GemmContext* context = GetOrCreateGlobalContext();
  context->set_max_num_threads(n);
}

void FreePersistentResources() {
  AutoGlobalLock<EightBitIntGemmLockId> lock;
  DestroyGlobalContext();
  DestroyGlobalScratch();
}

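// A hedged sketch of a typical call sequence from client code (illustrative
// only): configure threading once, run any number of gemms, then release the
// global context and scratch buffer when no more gemms will be issued.
//
//   gemmlowp::eight_bit_int_gemm::SetMaxNumThreads(4);
//   // ... calls to EightBitIntGemm(...) ...
//   gemmlowp::eight_bit_int_gemm::FreePersistentResources();
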
}  // namespace eight_bit_int_gemm
}  // namespace gemmlowp