// Copyright 2015 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// multi_thread_gemv.h: Entry point to the multithreaded version of the
// generated (meta) gemv library.

#ifndef GEMMLOWP_META_MULTI_THREAD_GEMV_H_
#define GEMMLOWP_META_MULTI_THREAD_GEMV_H_

#ifdef GEMMLOWP_NEON_32

#include "multi_thread_common.h"
#include "operations_common.h"
#include "single_thread_gemm.h"

namespace gemmlowp {
namespace meta {
namespace internal {

// Operation object that forwards a unit of work to gemv_q8() using the
// quantization parameters handed to the constructor.
// NOTE(review): lhs_offset, rhs_offset, sum_offset, multiplier and shift are
// read unqualified below; presumably they are members of
// Quantized8BitOperation -- confirm in operations_common.h.
class GemvQuantized8BitOperation : public Quantized8BitOperation {
 public:
  GemvQuantized8BitOperation(std::int32_t lhs_offset, std::int32_t rhs_offset,
                             std::int32_t sum_offset, std::int32_t multiplier,
                             std::int32_t shift)
      : Quantized8BitOperation(lhs_offset, rhs_offset, sum_offset, multiplier,
                               shift) {}

  // Calls gemv_q8() on the given operands. `m` and `result_stride` are
  // accepted but not forwarded: gemv_q8() receives only `n` and `k`.
  void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k, std::uint8_t* result,
                           std::int32_t result_stride) const {
    gemv_q8(scratch, lhs, rhs, n, k, lhs_offset, rhs_offset, sum_offset,
            multiplier, shift, result);
  }

  // Scratch size requested per thread. The value is a fixed 128 * 1024 and
  // does not depend on m, n or k (presumably a byte count, since scratch
  // buffers are std::uint8_t* -- confirm).
  static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
                                       std::int32_t k) {
    return 128 * 1024;
  }
};

// Operation object that forwards a unit of work to gemv_f(), which writes a
// float result.
// NOTE(review): lhs_offset, rhs_offset and result_offset are read unqualified
// below; presumably they are members of FloatOperation -- confirm.
class GemvFloatOperation : public FloatOperation {
 public:
  GemvFloatOperation(std::int32_t lhs_offset, std::int32_t rhs_offset,
                     float result_offset)
      : FloatOperation(lhs_offset, rhs_offset, result_offset) {}

  // Calls gemv_f() on the given operands. `m` and `result_stride` are
  // accepted but not forwarded.
  void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k, float* result,
                           std::int32_t result_stride) const {
    gemv_f(scratch, lhs, rhs, n, k, lhs_offset, rhs_offset, result_offset,
           result);
  }

  // Scratch size requested per thread: a fixed 128 * 1024, independent of
  // m, n and k.
  static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
                                       std::int32_t k) {
    return 128 * 1024;
  }
};

// Operation object that forwards a unit of work to gemv_i32(), which writes a
// std::int32_t result.
// NOTE(review): lhs_offset and rhs_offset are read unqualified below;
// presumably they are members of Int32Operation -- confirm.
class GemvInt32Operation : public Int32Operation {
 public:
  GemvInt32Operation(std::int32_t lhs_offset, std::int32_t rhs_offset)
      : Int32Operation(lhs_offset, rhs_offset) {}

  // Calls gemv_i32() on the given operands. `m` and `result_stride` are
  // accepted but not forwarded.
  void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k, std::int32_t* result,
                           std::int32_t result_stride) const {
    gemv_i32(scratch, lhs, rhs, n, k, lhs_offset, rhs_offset, result);
  }

  // Scratch size requested per thread: a fixed 128 * 1024, independent of
  // m, n and k.
  static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
                                       std::int32_t k) {
    return 128 * 1024;
  }
};

}  // namespace internal

// The free functions below are defined in this header, so they are marked
// `inline`: without it, including this header from more than one translation
// unit yields multiple definitions of the same external symbol.

// Returns the scratch size for multi_thread_gemv_q8(): the resolved thread
// count multiplied by the per-thread scratch size.
inline std::int32_t gemv_q8_scratch(std::int32_t m, std::int32_t n,
                                    std::int32_t k, std::int32_t max_threads) {
  return internal::ResolveMaxThreads(max_threads) *
         internal::GemvQuantized8BitOperation::ScratchPerThread(m, n, k);
}

// Runs the quantized 8-bit gemv. When the resolved thread count is 1 the
// operation is executed directly on the calling thread; otherwise the work is
// handed to internal::MultiThreadedMatrixMatrix() together with `pool`.
// In both cases the problem is passed on with m = 1 and result_stride = n.
inline void multi_thread_gemv_q8(gemmlowp::WorkersPool* pool,
                                 std::int32_t max_threads,
                                 std::uint8_t* scratch, const std::uint8_t* lhs,
                                 const std::uint8_t* rhs, std::int32_t n,
                                 std::int32_t k, std::int32_t lhs_offset,
                                 std::int32_t rhs_offset,
                                 std::int32_t sum_offset,
                                 std::int32_t multiplier, std::int32_t shift,
                                 std::uint8_t* result) {
  max_threads = internal::ResolveMaxThreads(max_threads);
  internal::GemvQuantized8BitOperation operation(lhs_offset, rhs_offset,
                                                 sum_offset, multiplier, shift);
  if (max_threads == 1) {
    operation.ExecuteMatrixMatrix(scratch, lhs, rhs, 1, n, k, result, n);
  } else {
    internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, 1,
                                        n, k, result, n, operation);
  }
}

// Returns the scratch size for multi_thread_gemv_f(): the resolved thread
// count multiplied by the per-thread scratch size.
inline std::int32_t gemv_f_scratch(std::int32_t m, std::int32_t n,
                                   std::int32_t k, std::int32_t max_threads) {
  return internal::ResolveMaxThreads(max_threads) *
         internal::GemvFloatOperation::ScratchPerThread(m, n, k);
}

// Runs the gemv that produces float output. Dispatch is the same as in
// multi_thread_gemv_q8(): direct execution when the resolved thread count is
// 1, internal::MultiThreadedMatrixMatrix() otherwise, with m = 1 and
// result_stride = n.
inline void multi_thread_gemv_f(gemmlowp::WorkersPool* pool,
                                std::int32_t max_threads, std::uint8_t* scratch,
                                const std::uint8_t* lhs,
                                const std::uint8_t* rhs, std::int32_t n,
                                std::int32_t k, std::int32_t lhs_offset,
                                std::int32_t rhs_offset, float result_offset,
                                float* result) {
  max_threads = internal::ResolveMaxThreads(max_threads);
  internal::GemvFloatOperation operation(lhs_offset, rhs_offset, result_offset);
  if (max_threads == 1) {
    operation.ExecuteMatrixMatrix(scratch, lhs, rhs, 1, n, k, result, n);
  } else {
    internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, 1,
                                        n, k, result, n, operation);
  }
}

// Returns the scratch size for multi_thread_gemv_i32(): the resolved thread
// count multiplied by the per-thread scratch size.
inline std::int32_t gemv_i32_scratch(std::int32_t m, std::int32_t n,
                                     std::int32_t k, std::int32_t max_threads) {
  return internal::ResolveMaxThreads(max_threads) *
         internal::GemvInt32Operation::ScratchPerThread(m, n, k);
}

// Runs the gemv that produces std::int32_t output. Dispatch is the same as in
// multi_thread_gemv_q8(): direct execution when the resolved thread count is
// 1, internal::MultiThreadedMatrixMatrix() otherwise, with m = 1 and
// result_stride = n.
inline void multi_thread_gemv_i32(gemmlowp::WorkersPool* pool,
                                  std::int32_t max_threads,
                                  std::uint8_t* scratch,
                                  const std::uint8_t* lhs,
                                  const std::uint8_t* rhs, std::int32_t n,
                                  std::int32_t k, std::int32_t lhs_offset,
                                  std::int32_t rhs_offset,
                                  std::int32_t* result) {
  max_threads = internal::ResolveMaxThreads(max_threads);
  internal::GemvInt32Operation operation(lhs_offset, rhs_offset);
  if (max_threads == 1) {
    operation.ExecuteMatrixMatrix(scratch, lhs, rhs, 1, n, k, result, n);
  } else {
    internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, 1,
                                        n, k, result, n, operation);
  }
}

}  // namespace meta
}  // namespace gemmlowp

#else
#warning "Meta gemm fast-path requires GEMMLOWP_NEON_32!"
#endif

#endif  // GEMMLOWP_META_MULTI_THREAD_GEMV_H_