Home | History | Annotate | Download | only in meta
      1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 // multi_thread_common.h: Multithreading code shared by different meta gemm
     16 // versions.
     17 
     18 #ifndef GEMMLOWP_META_MULTI_THREAD_COMMON_H_
     19 #define GEMMLOWP_META_MULTI_THREAD_COMMON_H_
     20 
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

#include "../internal/multi_thread_gemm.h"
     22 
     23 namespace gemmlowp {
     24 namespace meta {
     25 namespace internal {
     26 
     27 const std::int32_t kMinTaskSize = 16000;
     28 const std::int32_t kMinTaskDimension = 4;
     29 
     30 struct TaskRect {
     31   std::int32_t m_offset;
     32   std::int32_t m;
     33   std::int32_t n_offset;
     34   std::int32_t n;
     35 
     36   TaskRect(std::int32_t m_offset, std::int32_t m, std::int32_t n_offset,
     37            std::int32_t n)
     38       : m_offset(m_offset), m(m), n_offset(n_offset), n(n) {}
     39 };
     40 
     41 template <typename IN_TYPE, typename OUT_TYPE, typename F>
     42 struct MetaTask : gemmlowp::Task {
     43   std::uint8_t* scratch;
     44   const IN_TYPE* lhs;
     45   const IN_TYPE* rhs;
     46   TaskRect task_rect;
     47   std::int32_t k;
     48   OUT_TYPE* result;
     49   std::int32_t result_stride;
     50   const F& operation;
     51 
     52   MetaTask(std::uint8_t* scratch, const IN_TYPE* lhs, const IN_TYPE* rhs,
     53            const TaskRect& task_rect, std::int32_t k, OUT_TYPE* result,
     54            std::int32_t result_stride, const F& operation)
     55       : scratch(scratch),
     56         lhs(lhs),
     57         rhs(rhs),
     58         task_rect(task_rect),
     59         k(k),
     60         result(result),
     61         result_stride(result_stride),
     62         operation(operation) {}
     63 
     64   void Run() override {
     65     const IN_TYPE* task_lhs = lhs + task_rect.m_offset * k;
     66     const IN_TYPE* task_rhs = rhs + task_rect.n_offset * k;
     67     OUT_TYPE* task_result =
     68         result + task_rect.m_offset * result_stride + task_rect.n_offset;
     69     operation.ExecuteMatrixMatrix(scratch, task_lhs, task_rhs, task_rect.m,
     70                                   task_rect.n, k, task_result, result_stride);
     71   }
     72 };
     73 
     74 std::int32_t ResolveMaxThreads(std::int32_t max_threads) {
     75   if (max_threads == 0) {
     76     static const int hardware_threads_count =
     77         static_cast<int>(sysconf(_SC_NPROCESSORS_CONF));
     78     return hardware_threads_count;
     79   }
     80   return max_threads;
     81 }
     82 
     83 void PrepareTasks(std::int32_t max_tasks, std::int32_t m, std::int32_t n,
     84                   std::int32_t k, std::vector<internal::TaskRect>* tasks) {
     85   const std::int32_t max_tasks_by_size = (m * n * k) / kMinTaskSize;
     86   const std::int32_t max_tasks_m = m / kMinTaskDimension;
     87   const std::int32_t max_tasks_n = n / kMinTaskDimension;
     88   const std::int32_t max_tasks_dimension = std::max(max_tasks_m, max_tasks_n);
     89 
     90   std::int32_t real_tasks = std::max(
     91       1, std::min(max_tasks, std::min(max_tasks_by_size, max_tasks_dimension)));
     92 
     93   if (real_tasks == 1) {
     94     tasks->push_back(TaskRect(0, m, 0, n));
     95     return;
     96   }
     97 
     98   if (max_tasks_m > max_tasks_n) {
     99     const std::int32_t m_chunk = m / real_tasks;
    100     for (int i = 0; i < real_tasks - 1; ++i) {
    101       tasks->push_back(TaskRect(i * m_chunk, m_chunk, 0, n));
    102     }
    103     const std::int32_t last_m_offset = (real_tasks - 1) * m_chunk;
    104     tasks->push_back(TaskRect(last_m_offset, m - last_m_offset, 0, n));
    105   } else {
    106     const std::int32_t n_chunk = n / real_tasks;
    107     for (int i = 0; i < real_tasks - 1; ++i) {
    108       tasks->push_back(TaskRect(0, m, i * n_chunk, n_chunk));
    109     }
    110     const std::int32_t last_n_offset = (real_tasks - 1) * n_chunk;
    111     tasks->push_back(TaskRect(0, m, last_n_offset, n - last_n_offset));
    112   }
    113 }
    114 
    115 template <typename IN_TYPE, typename OUT_TYPE, typename F>
    116 void MultiThreadedMatrixMatrix(gemmlowp::WorkersPool* pool,
    117                                std::int32_t max_threads, std::uint8_t* scratch,
    118                                const IN_TYPE* lhs, const IN_TYPE* rhs,
    119                                std::int32_t m, std::int32_t n, std::int32_t k,
    120                                OUT_TYPE* result, std::int32_t result_stride,
    121                                const F& operation) {
    122   max_threads = internal::ResolveMaxThreads(max_threads);
    123 
    124   std::vector<internal::TaskRect> task_rects;
    125   internal::PrepareTasks(max_threads, m, n, k, &task_rects);
    126 
    127   if (task_rects.size() == 1) {
    128     operation.ExecuteMatrixMatrix(scratch, lhs, rhs, m, n, k, result,
    129                                   result_stride);
    130     return;
    131   }
    132 
    133   std::uint8_t* task_scratch = scratch;
    134   std::int32_t scratch_per_thread = operation.ScratchPerThread(m, n, k);
    135   std::vector<Task*> tasks;
    136   std::for_each(
    137       task_rects.begin(), task_rects.end(),
    138       [&tasks, &task_scratch, lhs, rhs, k, result, result_stride, operation,
    139        scratch_per_thread](internal::TaskRect& rect) {
    140         tasks.push_back(new internal::MetaTask<IN_TYPE, OUT_TYPE, F>(
    141             task_scratch, lhs, rhs, rect, k, result, result_stride, operation));
    142         task_scratch += scratch_per_thread;
    143       });
    144   pool->Execute(tasks);
    145 }
    146 
    147 }  // namespace internal
    148 }  // namespace meta
    149 }  // namespace gemmlowp
    150 
    151 #endif  // GEMMLOWP_META_MULTI_THREAD_COMMON_H_
    152