// Copyright 2015 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// multi_thread_gemm.h: Entry point to the multithreaded version of the
// generated (meta) gemm library.

#ifndef GEMMLOWP_META_MULTI_THREAD_GEMM_H_
#define GEMMLOWP_META_MULTI_THREAD_GEMM_H_

#ifdef GEMMLOWP_NEON_32

#include <algorithm>
#include <cstdint>

#include "multi_thread_common.h"
#include "single_thread_gemm.h"

namespace gemmlowp {
namespace meta {
namespace internal {

// Size budget, in bytes, for the rhs block handed to a single
// ExecuteCacheFriendlyMatrixMatrix call (24 KiB).
const std::int32_t kMaxCacheFriendlySize = 24 * 1024;

// Calls operation.ExecuteCacheFriendlyMatrixMatrix over the n dimension in
// chunks so that each rhs chunk stays within kMaxCacheFriendlySize bytes
// (unless a single column of k elements already exceeds it).
//
// For every chunk, rhs advances by chunk_n * k elements and result advances by
// chunk_n elements, with result rows result_stride elements apart.
// NOTE(review): this implies rhs stores k contiguous elements per output
// column; confirm against the gemm_*_strided kernels.
//
// If the whole rhs fits in the budget, the operation is invoked once with the
// original arguments. Otherwise every chunk but the last is optimal_n columns
// wide and the last one absorbs the remainder, i.e. it is between optimal_n
// and 2 * optimal_n - 1 columns wide.
template <typename IN_TYPE, typename OUT_TYPE, typename F>
void CacheFriendlyMatrixMatrix(std::uint8_t* scratch, const IN_TYPE* lhs,
                               const IN_TYPE* rhs, std::int32_t m,
                               std::int32_t n, std::int32_t k, OUT_TYPE* result,
                               std::int32_t result_stride, const F& operation) {
  // Computed in 64 bits: n * k alone can exceed the int32 range.
  const std::int64_t rhs_size = static_cast<std::int64_t>(n) * k *
                                static_cast<std::int64_t>(sizeof(IN_TYPE));
  if (rhs_size > kMaxCacheFriendlySize) {
    // rhs_size is measured in bytes, so convert the byte budget to an element
    // budget before deriving the chunk width (identical for 8-bit inputs).
    const std::int32_t max_chunk_elements =
        kMaxCacheFriendlySize / static_cast<std::int32_t>(sizeof(IN_TYPE));
    // Rounded down to a multiple of 3 (presumably matching the kernels'
    // column blocking -- TODO confirm), and never less than one column.
    const std::int32_t optimal_n =
        std::max<std::int32_t>(1, 3 * (max_chunk_elements / (k * 3)));
    const std::int32_t chunks_count_less_one = n / optimal_n - 1;
    const std::int32_t chunk_size = optimal_n * k;
    for (std::int32_t i = 0; i < chunks_count_less_one; ++i) {
      operation.ExecuteCacheFriendlyMatrixMatrix(
          scratch, lhs, rhs + i * chunk_size, m, optimal_n, k,
          result + i * optimal_n, result_stride);
    }
    // The final chunk takes all remaining columns.
    const std::int32_t n_left = n - chunks_count_less_one * optimal_n;
    operation.ExecuteCacheFriendlyMatrixMatrix(
        scratch, lhs, rhs + chunks_count_less_one * chunk_size, m, n_left, k,
        result + chunks_count_less_one * optimal_n, result_stride);
  } else {
    operation.ExecuteCacheFriendlyMatrixMatrix(scratch, lhs, rhs, m, n, k,
                                               result, result_stride);
  }
}

// Operation object for the quantized uint8 -> uint8 gemm (gemm_q8_strided).
class GemmQuantized8BitOperation {
 public:
  GemmQuantized8BitOperation(std::int32_t lhs_offset, std::int32_t rhs_offset,
                             std::int32_t sum_offset, std::int32_t multiplier,
                             std::int32_t shift)
      : lhs_offset(lhs_offset),
        rhs_offset(rhs_offset),
        sum_offset(sum_offset),
        multiplier(multiplier),
        shift(shift) {}

  // Entry point for one work unit: splits n into cache-friendly chunks.
  void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k, std::uint8_t* result,
                           std::int32_t result_stride) const {
    CacheFriendlyMatrixMatrix(scratch, lhs, rhs, m, n, k, result, result_stride,
                              *this);
  }

  // Runs the single-threaded strided kernel on one chunk.
  void ExecuteCacheFriendlyMatrixMatrix(std::uint8_t* scratch,
                                        const std::uint8_t* lhs,
                                        const std::uint8_t* rhs, std::int32_t m,
                                        std::int32_t n, std::int32_t k,
                                        std::uint8_t* result,
                                        std::int32_t result_stride) const {
    gemm_q8_strided(scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                    sum_offset, multiplier, shift, result, result_stride);
  }

  // Scratch bytes needed per thread: a fixed 128 KiB; m, n and k are unused.
  static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
                                       std::int32_t k) {
    return 128 * 1024;
  }

 private:
  std::int32_t lhs_offset;
  std::int32_t rhs_offset;
  std::int32_t sum_offset;
  std::int32_t multiplier;
  std::int32_t shift;
};

// Operation object for the uint8 -> float gemm (gemm_f_strided).
class GemmFloatOperation {
 public:
  GemmFloatOperation(std::int32_t lhs_offset, std::int32_t rhs_offset,
                     float result_offset)
      : lhs_offset(lhs_offset),
        rhs_offset(rhs_offset),
        result_offset(result_offset) {}

  // Entry point for one work unit: splits n into cache-friendly chunks.
  void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k, float* result,
                           std::int32_t result_stride) const {
    CacheFriendlyMatrixMatrix(scratch, lhs, rhs, m, n, k, result, result_stride,
                              *this);
  }

  // Runs the single-threaded strided kernel on one chunk.
  void ExecuteCacheFriendlyMatrixMatrix(std::uint8_t* scratch,
                                        const std::uint8_t* lhs,
                                        const std::uint8_t* rhs, std::int32_t m,
                                        std::int32_t n, std::int32_t k,
                                        float* result,
                                        std::int32_t result_stride) const {
    gemm_f_strided(scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
                   result_offset, result, result_stride);
  }

  // Scratch bytes needed per thread: a fixed 128 KiB; m, n and k are unused.
  static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
                                       std::int32_t k) {
    return 128 * 1024;
  }

 private:
  std::int32_t lhs_offset;
  std::int32_t rhs_offset;
  float result_offset;
};

// Operation object for the uint8 -> int32 gemm (gemm_i32_strided).
class GemmInt32Operation {
 public:
  GemmInt32Operation(std::int32_t lhs_offset, std::int32_t rhs_offset)
      : lhs_offset(lhs_offset), rhs_offset(rhs_offset) {}

  // Entry point for one work unit: splits n into cache-friendly chunks.
  void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k, std::int32_t* result,
                           std::int32_t result_stride) const {
    CacheFriendlyMatrixMatrix(scratch, lhs, rhs, m, n, k, result, result_stride,
                              *this);
  }

  // Runs the single-threaded strided kernel on one chunk.
  void ExecuteCacheFriendlyMatrixMatrix(std::uint8_t* scratch,
                                        const std::uint8_t* lhs,
                                        const std::uint8_t* rhs, std::int32_t m,
                                        std::int32_t n, std::int32_t k,
                                        std::int32_t* result,
                                        std::int32_t result_stride) const {
    gemm_i32_strided(scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset, result,
                     result_stride);
  }

  // Scratch bytes needed per thread: a fixed 128 KiB; m, n and k are unused.
  static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
                                       std::int32_t k) {
    return 128 * 1024;
  }

 private:
  std::int32_t lhs_offset;
  std::int32_t rhs_offset;
};

}  // namespace internal

// The public functions below are defined in this header, so they are declared
// inline to avoid multiple-definition link errors when the header is included
// from more than one translation unit.

// Total scratch bytes for multi_thread_gemm_q8: the per-thread scratch size
// times the thread count returned by internal::ResolveMaxThreads(max_threads).
inline std::int32_t gemm_q8_scratch(std::int32_t m, std::int32_t n,
                                    std::int32_t k, std::int32_t max_threads) {
  return internal::ResolveMaxThreads(max_threads) *
         internal::GemmQuantized8BitOperation::ScratchPerThread(m, n, k);
}

// Multithreaded quantized gemm writing uint8 results. The result is written
// with a row stride of n, i.e. as a densely packed m x n matrix.
inline void multi_thread_gemm_q8(gemmlowp::WorkersPool* pool,
                                 std::int32_t max_threads,
                                 std::uint8_t* scratch,
                                 const std::uint8_t* lhs,
                                 const std::uint8_t* rhs, std::int32_t m,
                                 std::int32_t n, std::int32_t k,
                                 std::int32_t lhs_offset,
                                 std::int32_t rhs_offset,
                                 std::int32_t sum_offset,
                                 std::int32_t multiplier, std::int32_t shift,
                                 std::uint8_t* result) {
  internal::GemmQuantized8BitOperation operation(lhs_offset, rhs_offset,
                                                 sum_offset, multiplier, shift);
  internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, m,
                                      n, k, result, n, operation);
}

// Total scratch bytes for multi_thread_gemm_f: the per-thread scratch size
// times the thread count returned by internal::ResolveMaxThreads(max_threads).
inline std::int32_t gemm_f_scratch(std::int32_t m, std::int32_t n,
                                   std::int32_t k, std::int32_t max_threads) {
  return internal::ResolveMaxThreads(max_threads) *
         internal::GemmFloatOperation::ScratchPerThread(m, n, k);
}

// Multithreaded gemm writing float results. The result is written with a row
// stride of n, i.e. as a densely packed m x n matrix.
inline void multi_thread_gemm_f(gemmlowp::WorkersPool* pool,
                                std::int32_t max_threads,
                                std::uint8_t* scratch, const std::uint8_t* lhs,
                                const std::uint8_t* rhs, std::int32_t m,
                                std::int32_t n, std::int32_t k,
                                std::int32_t lhs_offset,
                                std::int32_t rhs_offset, float result_offset,
                                float* result) {
  internal::GemmFloatOperation operation(lhs_offset, rhs_offset, result_offset);
  internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, m,
                                      n, k, result, n, operation);
}

// Total scratch bytes for multi_thread_gemm_i32: the per-thread scratch size
// times the thread count returned by internal::ResolveMaxThreads(max_threads).
inline std::int32_t gemm_i32_scratch(std::int32_t m, std::int32_t n,
                                     std::int32_t k, std::int32_t max_threads) {
  return internal::ResolveMaxThreads(max_threads) *
         internal::GemmInt32Operation::ScratchPerThread(m, n, k);
}

// Multithreaded gemm writing int32 results. The result is written with a row
// stride of n, i.e. as a densely packed m x n matrix.
inline void multi_thread_gemm_i32(gemmlowp::WorkersPool* pool,
                                  std::int32_t max_threads,
                                  std::uint8_t* scratch,
                                  const std::uint8_t* lhs,
                                  const std::uint8_t* rhs, std::int32_t m,
                                  std::int32_t n, std::int32_t k,
                                  std::int32_t lhs_offset,
                                  std::int32_t rhs_offset,
                                  std::int32_t* result) {
  internal::GemmInt32Operation operation(lhs_offset, rhs_offset);
  internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, m,
                                      n, k, result, n, operation);
}

}  // namespace meta
}  // namespace gemmlowp

#else
#warning "Meta gemm fast-path requires GEMMLOWP_NEON_32!"
#endif

#endif  // GEMMLOWP_META_MULTI_THREAD_GEMM_H_