Home | History | Annotate | Download | only in meta
      1 // Copyright 2015 Google Inc. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 // multi_thread_gemm.h: Entry point to the multithreaded version of the
     16 // generated (meta) gemm library.
     17 
     18 #ifndef GEMMLOWP_META_MULTI_THREAD_GEMM_H_
     19 #define GEMMLOWP_META_MULTI_THREAD_GEMM_H_
     20 
     21 #ifdef GEMMLOWP_NEON_32
     22 
     23 #include "multi_thread_common.h"
     24 #include "single_thread_gemm.h"
     25 
     26 namespace gemmlowp {
     27 namespace meta {
     28 namespace internal {
     29 
     30 const std::int32_t kMaxCacheFriendlySize = 24 * 1024;
     31 
     32 template <typename IN_TYPE, typename OUT_TYPE, typename F>
     33 void CacheFriendlyMatrixMatrix(std::uint8_t* scratch, const IN_TYPE* lhs,
     34                                const IN_TYPE* rhs, std::int32_t m,
     35                                std::int32_t n, std::int32_t k, OUT_TYPE* result,
     36                                std::int32_t result_stride, const F& operation) {
     37   const std::int32_t rhs_size = n * k * sizeof(IN_TYPE);
     38   if (rhs_size > kMaxCacheFriendlySize) {
     39     const std::int32_t optimal_n =
     40         std::max(1, 3 * (kMaxCacheFriendlySize / (k * 3)));
     41     const std::int32_t chunks_count_less_one = n / optimal_n - 1;
     42     const std::int32_t chunk_size = optimal_n * k;
     43     for (int i = 0; i < chunks_count_less_one; ++i) {
     44       operation.ExecuteCacheFriendlyMatrixMatrix(
     45           scratch, lhs, rhs + i * chunk_size, m, optimal_n, k,
     46           result + i * optimal_n, result_stride);
     47     }
     48     const std::int32_t n_left = n - chunks_count_less_one * optimal_n;
     49     operation.ExecuteCacheFriendlyMatrixMatrix(
     50         scratch, lhs, rhs + chunks_count_less_one * chunk_size, m, n_left, k,
     51         result + chunks_count_less_one * optimal_n, result_stride);
     52   } else {
     53     operation.ExecuteCacheFriendlyMatrixMatrix(scratch, lhs, rhs, m, n, k,
     54                                                result, result_stride);
     55   }
     56 }
     57 
     58 class GemmQuantized8BitOperation {
     59  public:
     60   GemmQuantized8BitOperation(std::int32_t lhs_offset, std::int32_t rhs_offset,
     61                              std::int32_t sum_offset, std::int32_t multiplier,
     62                              std::int32_t shift)
     63       : lhs_offset(lhs_offset),
     64         rhs_offset(rhs_offset),
     65         sum_offset(sum_offset),
     66         multiplier(multiplier),
     67         shift(shift) {}
     68 
     69   void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
     70                            const std::uint8_t* rhs, std::int32_t m,
     71                            std::int32_t n, std::int32_t k, std::uint8_t* result,
     72                            std::int32_t result_stride) const {
     73     CacheFriendlyMatrixMatrix(scratch, lhs, rhs, m, n, k, result, result_stride,
     74                               *this);
     75   }
     76 
     77   void ExecuteCacheFriendlyMatrixMatrix(std::uint8_t* scratch,
     78                                         const std::uint8_t* lhs,
     79                                         const std::uint8_t* rhs, std::int32_t m,
     80                                         std::int32_t n, std::int32_t k,
     81                                         std::uint8_t* result,
     82                                         std::int32_t result_stride) const {
     83     gemm_q8_strided(scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
     84                     sum_offset, multiplier, shift, result, result_stride);
     85   }
     86 
     87   static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
     88                                        std::int32_t k) {
     89     return 128 * 1024;
     90   }
     91 
     92  private:
     93   std::int32_t lhs_offset;
     94   std::int32_t rhs_offset;
     95   std::int32_t sum_offset;
     96   std::int32_t multiplier;
     97   std::int32_t shift;
     98 };
     99 
    100 class GemmFloatOperation {
    101  public:
    102   GemmFloatOperation(std::int32_t lhs_offset, std::int32_t rhs_offset,
    103                      float result_offset)
    104       : lhs_offset(lhs_offset),
    105         rhs_offset(rhs_offset),
    106         result_offset(result_offset) {}
    107 
    108   void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
    109                            const std::uint8_t* rhs, std::int32_t m,
    110                            std::int32_t n, std::int32_t k, float* result,
    111                            std::int32_t result_stride) const {
    112     CacheFriendlyMatrixMatrix(scratch, lhs, rhs, m, n, k, result, result_stride,
    113                               *this);
    114   }
    115 
    116   void ExecuteCacheFriendlyMatrixMatrix(std::uint8_t* scratch,
    117                                         const std::uint8_t* lhs,
    118                                         const std::uint8_t* rhs, std::int32_t m,
    119                                         std::int32_t n, std::int32_t k,
    120                                         float* result,
    121                                         std::int32_t result_stride) const {
    122     gemm_f_strided(scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
    123                    result_offset, result, result_stride);
    124   }
    125 
    126   static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
    127                                        std::int32_t k) {
    128     return 128 * 1024;
    129   }
    130 
    131  private:
    132   std::int32_t lhs_offset;
    133   std::int32_t rhs_offset;
    134   float result_offset;
    135 };
    136 
    137 class GemmInt32Operation {
    138  public:
    139   GemmInt32Operation(std::int32_t lhs_offset, std::int32_t rhs_offset)
    140       : lhs_offset(lhs_offset), rhs_offset(rhs_offset) {}
    141 
    142   void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
    143                            const std::uint8_t* rhs, std::int32_t m,
    144                            std::int32_t n, std::int32_t k, std::int32_t* result,
    145                            std::int32_t result_stride) const {
    146     CacheFriendlyMatrixMatrix(scratch, lhs, rhs, m, n, k, result, result_stride,
    147                               *this);
    148   }
    149 
    150   void ExecuteCacheFriendlyMatrixMatrix(std::uint8_t* scratch,
    151                                         const std::uint8_t* lhs,
    152                                         const std::uint8_t* rhs, std::int32_t m,
    153                                         std::int32_t n, std::int32_t k,
    154                                         std::int32_t* result,
    155                                         std::int32_t result_stride) const {
    156     gemm_i32_strided(scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset, result,
    157                      result_stride);
    158   }
    159 
    160   static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
    161                                        std::int32_t k) {
    162     return 128 * 1024;
    163   }
    164 
    165  private:
    166   std::int32_t lhs_offset;
    167   std::int32_t rhs_offset;
    168 };
    169 
    170 }  // namespace internal
    171 
    172 std::int32_t gemm_q8_scratch(std::int32_t m, std::int32_t n, std::int32_t k,
    173                              std::int32_t max_threads) {
    174   return internal::ResolveMaxThreads(max_threads) *
    175          internal::GemmQuantized8BitOperation::ScratchPerThread(m, n, k);
    176 }
    177 
    178 void multi_thread_gemm_q8(gemmlowp::WorkersPool* pool, std::int32_t max_threads,
    179                           std::uint8_t* scratch, const std::uint8_t* lhs,
    180                           const std::uint8_t* rhs, std::int32_t m,
    181                           std::int32_t n, std::int32_t k,
    182                           std::int32_t lhs_offset, std::int32_t rhs_offset,
    183                           std::int32_t sum_offset, std::int32_t multiplier,
    184                           std::int32_t shift, std::uint8_t* result) {
    185   internal::GemmQuantized8BitOperation operation(lhs_offset, rhs_offset,
    186                                                  sum_offset, multiplier, shift);
    187   internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, m,
    188                                       n, k, result, n, operation);
    189 }
    190 
    191 std::int32_t gemm_f_scratch(std::int32_t m, std::int32_t n, std::int32_t k,
    192                             std::int32_t max_threads) {
    193   return internal::ResolveMaxThreads(max_threads) *
    194          internal::GemmFloatOperation::ScratchPerThread(m, n, k);
    195 }
    196 
    197 void multi_thread_gemm_f(gemmlowp::WorkersPool* pool, std::int32_t max_threads,
    198                          std::uint8_t* scratch, const std::uint8_t* lhs,
    199                          const std::uint8_t* rhs, std::int32_t m,
    200                          std::int32_t n, std::int32_t k,
    201                          std::int32_t lhs_offset, std::int32_t rhs_offset,
    202                          float result_offset, float* result) {
    203   internal::GemmFloatOperation operation(lhs_offset, rhs_offset, result_offset);
    204   internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, m,
    205                                       n, k, result, n, operation);
    206 }
    207 
    208 std::int32_t gemm_i32_scratch(std::int32_t m, std::int32_t n, std::int32_t k,
    209                               std::int32_t max_threads) {
    210   return internal::ResolveMaxThreads(max_threads) *
    211          internal::GemmInt32Operation::ScratchPerThread(m, n, k);
    212 }
    213 
    214 void multi_thread_gemm_i32(gemmlowp::WorkersPool* pool,
    215                            std::int32_t max_threads, std::uint8_t* scratch,
    216                            const std::uint8_t* lhs, const std::uint8_t* rhs,
    217                            std::int32_t m, std::int32_t n, std::int32_t k,
    218                            std::int32_t lhs_offset, std::int32_t rhs_offset,
    219                            std::int32_t* result) {
    220   internal::GemmInt32Operation operation(lhs_offset, rhs_offset);
    221   internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, m,
    222                                       n, k, result, n, operation);
    223 }
    224 
    225 }  // namespace meta
    226 }  // namespace gemmlowp
    227 
    228 #else
    229 #warning "Meta gemm fast-path requires GEMMLOWP_NEON_32!"
    230 #endif
    231 
    232 #endif  // GEMMLOWP_META_MULTI_THREAD_GEMM_H_
    233