Home | History | Annotate | Download | only in meta
      1 // Copyright 2015 Google Inc. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 // multi_thread_gemv.h: Entry point to the multithreaded version of the
     16 // generated (meta) gemv library.
     17 
     18 #ifndef GEMMLOWP_META_MULTI_THREAD_GEMV_H_
     19 #define GEMMLOWP_META_MULTI_THREAD_GEMV_H_
     20 
     21 #ifdef GEMMLOWP_NEON_32
     22 
     23 #include "multi_thread_common.h"
     24 #include "operations_common.h"
     25 #include "single_thread_gemm.h"
     26 
     27 namespace gemmlowp {
     28 namespace meta {
     29 namespace internal {
     30 
     31 class GemvQuantized8BitOperation : public Quantized8BitOperation {
     32  public:
     33   GemvQuantized8BitOperation(std::int32_t lhs_offset, std::int32_t rhs_offset,
     34                              std::int32_t sum_offset, std::int32_t multiplier,
     35                              std::int32_t shift)
     36       : Quantized8BitOperation(lhs_offset, rhs_offset, sum_offset, multiplier,
     37                                shift) {}
     38 
     39   void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
     40                            const std::uint8_t* rhs, std::int32_t m,
     41                            std::int32_t n, std::int32_t k, std::uint8_t* result,
     42                            std::int32_t result_stride) const {
     43     gemv_q8(scratch, lhs, rhs, n, k, lhs_offset, rhs_offset, sum_offset,
     44             multiplier, shift, result);
     45   }
     46 
     47   static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
     48                                        std::int32_t k) {
     49     return 128 * 1024;
     50   }
     51 };
     52 
     53 class GemvFloatOperation : public FloatOperation {
     54  public:
     55   GemvFloatOperation(std::int32_t lhs_offset, std::int32_t rhs_offset,
     56                      float result_offset)
     57       : FloatOperation(lhs_offset, rhs_offset, result_offset) {}
     58 
     59   void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
     60                            const std::uint8_t* rhs, std::int32_t m,
     61                            std::int32_t n, std::int32_t k, float* result,
     62                            std::int32_t result_stride) const {
     63     gemv_f(scratch, lhs, rhs, n, k, lhs_offset, rhs_offset, result_offset,
     64            result);
     65   }
     66 
     67   static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
     68                                        std::int32_t k) {
     69     return 128 * 1024;
     70   }
     71 };
     72 
     73 class GemvInt32Operation : public Int32Operation {
     74  public:
     75   GemvInt32Operation(std::int32_t lhs_offset, std::int32_t rhs_offset)
     76       : Int32Operation(lhs_offset, rhs_offset) {}
     77 
     78   void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
     79                            const std::uint8_t* rhs, std::int32_t m,
     80                            std::int32_t n, std::int32_t k, std::int32_t* result,
     81                            std::int32_t result_stride) const {
     82     gemv_i32(scratch, lhs, rhs, n, k, lhs_offset, rhs_offset, result);
     83   }
     84 
     85   static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
     86                                        std::int32_t k) {
     87     return 128 * 1024;
     88   }
     89 };
     90 
     91 }  // namespace internal
     92 
     93 std::int32_t gemv_q8_scratch(std::int32_t m, std::int32_t n, std::int32_t k,
     94                              std::int32_t max_threads) {
     95   return internal::ResolveMaxThreads(max_threads) *
     96          internal::GemvQuantized8BitOperation::ScratchPerThread(m, n, k);
     97 }
     98 
     99 void multi_thread_gemv_q8(gemmlowp::WorkersPool* pool, std::int32_t max_threads,
    100                           std::uint8_t* scratch, const std::uint8_t* lhs,
    101                           const std::uint8_t* rhs, std::int32_t n,
    102                           std::int32_t k, std::int32_t lhs_offset,
    103                           std::int32_t rhs_offset, std::int32_t sum_offset,
    104                           std::int32_t multiplier, std::int32_t shift,
    105                           std::uint8_t* result) {
    106   max_threads = internal::ResolveMaxThreads(max_threads);
    107   internal::GemvQuantized8BitOperation operation(lhs_offset, rhs_offset,
    108                                                  sum_offset, multiplier, shift);
    109   if (max_threads == 1) {
    110     operation.ExecuteMatrixMatrix(scratch, lhs, rhs, 1, n, k, result, n);
    111   } else {
    112     internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, 1,
    113                                         n, k, result, n, operation);
    114   }
    115 }
    116 
    117 std::int32_t gemv_f_scratch(std::int32_t m, std::int32_t n, std::int32_t k,
    118                             std::int32_t max_threads) {
    119   return internal::ResolveMaxThreads(max_threads) *
    120          internal::GemvFloatOperation::ScratchPerThread(m, n, k);
    121 }
    122 
    123 void multi_thread_gemv_f(gemmlowp::WorkersPool* pool, std::int32_t max_threads,
    124                          std::uint8_t* scratch, const std::uint8_t* lhs,
    125                          const std::uint8_t* rhs, std::int32_t n,
    126                          std::int32_t k, std::int32_t lhs_offset,
    127                          std::int32_t rhs_offset, float result_offset,
    128                          float* result) {
    129   max_threads = internal::ResolveMaxThreads(max_threads);
    130   internal::GemvFloatOperation operation(lhs_offset, rhs_offset, result_offset);
    131   if (max_threads == 1) {
    132     operation.ExecuteMatrixMatrix(scratch, lhs, rhs, 1, n, k, result, n);
    133   } else {
    134     internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, 1,
    135                                         n, k, result, n, operation);
    136   }
    137 }
    138 
    139 std::int32_t gemv_i32_scratch(std::int32_t m, std::int32_t n, std::int32_t k,
    140                               std::int32_t max_threads) {
    141   return internal::ResolveMaxThreads(max_threads) *
    142          internal::GemvInt32Operation::ScratchPerThread(m, n, k);
    143 }
    144 
    145 void multi_thread_gemv_i32(gemmlowp::WorkersPool* pool,
    146                            std::int32_t max_threads, std::uint8_t* scratch,
    147                            const std::uint8_t* lhs, const std::uint8_t* rhs,
    148                            std::int32_t n, std::int32_t k,
    149                            std::int32_t lhs_offset, std::int32_t rhs_offset,
    150                            std::int32_t* result) {
    151   max_threads = internal::ResolveMaxThreads(max_threads);
    152   internal::GemvInt32Operation operation(lhs_offset, rhs_offset);
    153   if (max_threads == 1) {
    154     operation.ExecuteMatrixMatrix(scratch, lhs, rhs, 1, n, k, result, n);
    155   } else {
    156     internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, 1,
    157                                         n, k, result, n, operation);
    158   }
    159 }
    160 
    161 }  // namespace meta
    162 }  // namespace gemmlowp
    163 
    164 #else
    165 #warning "Meta gemm fast-path requires GEMMLOWP_NEON_32!"
    166 #endif
    167 
    168 #endif  // GEMMLOWP_META_MULTI_THREAD_GEMV_H_
    169