Home | History | Annotate | Download | only in meta
      1 // Copyright 2015 Google Inc. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 // multi_thread_common.h: Multithreading code shared by different meta gemm
     16 // versions.
     17 
     18 #ifndef GEMMLOWP_META_MULTI_THREAD_COMMON_H_
     19 #define GEMMLOWP_META_MULTI_THREAD_COMMON_H_
     20 
     21 #include "../internal/multi_thread_gemm.h"
     22 
     23 namespace gemmlowp {
     24 namespace meta {
     25 namespace internal {
     26 
     27 const std::int32_t kMinTaskSize = 10000;
     28 const std::int32_t kMinTaskDimension = 6;
     29 
     30 struct TaskRect {
     31   std::int32_t m_offset;
     32   std::int32_t m;
     33   std::int32_t n_offset;
     34   std::int32_t n;
     35 
     36   TaskRect(std::int32_t m_offset, std::int32_t m, std::int32_t n_offset,
     37            std::int32_t n)
     38       : m_offset(m_offset), m(m), n_offset(n_offset), n(n) {}
     39 };
     40 
     41 template <typename IN_TYPE, typename OUT_TYPE, typename F>
     42 struct MetaTask : gemmlowp::Task {
     43   std::uint8_t* scratch;
     44   const IN_TYPE* lhs;
     45   const IN_TYPE* rhs;
     46   TaskRect task_rect;
     47   std::int32_t k;
     48   OUT_TYPE* result;
     49   std::int32_t result_stride;
     50   const F& operation;
     51 
     52   MetaTask(std::uint8_t* scratch, const IN_TYPE* lhs, const IN_TYPE* rhs,
     53            const TaskRect& task_rect, std::int32_t k, OUT_TYPE* result,
     54            std::int32_t result_stride, const F& operation)
     55       : scratch(scratch),
     56         lhs(lhs),
     57         rhs(rhs),
     58         task_rect(task_rect),
     59         k(k),
     60         result(result),
     61         result_stride(result_stride),
     62         operation(operation) {}
     63 
     64   void Run() const override {
     65     const IN_TYPE* task_lhs = lhs + task_rect.m_offset * k;
     66     const IN_TYPE* task_rhs = rhs + task_rect.n_offset * k;
     67     OUT_TYPE* task_result =
     68         result + task_rect.m_offset * result_stride + task_rect.n_offset;
     69     operation.ExecuteMatrixMatrix(scratch, task_lhs, task_rhs, task_rect.m,
     70                                   task_rect.n, k, task_result, result_stride);
     71   }
     72 };
     73 
     74 std::int32_t ResolveMaxThreads(std::int32_t max_threads) {
     75   if (max_threads == 0) {
     76     static const int hardware_threads_count =
     77         static_cast<int>(sysconf(_SC_NPROCESSORS_CONF));
     78     return hardware_threads_count;
     79   }
     80   return max_threads;
     81 }
     82 
     83 void PrepareTasks(std::int32_t max_tasks, std::int32_t m, std::int32_t n,
     84                   std::int32_t k, std::vector<internal::TaskRect>* tasks) {
     85   const std::int32_t max_tasks_by_size = (m * n * k) / kMinTaskSize;
     86   const std::int32_t max_tasks_m = m / kMinTaskDimension;
     87   const std::int32_t max_tasks_n = n / kMinTaskDimension;
     88   const std::int32_t max_tasks_dimension = std::max(max_tasks_m, max_tasks_n);
     89 
     90   std::int32_t real_tasks = std::max(
     91       1, std::min(max_tasks, std::min(max_tasks_by_size, max_tasks_dimension)));
     92 
     93   if (real_tasks == 1) {
     94     tasks->push_back(TaskRect(0, m, 0, n));
     95     return;
     96   }
     97 
     98   if (max_tasks_m > max_tasks_n) {
     99     const std::int32_t m_chunk = m / real_tasks;
    100     for (int i = 0; i < real_tasks - 1; ++i) {
    101       tasks->push_back(TaskRect(i * m_chunk, m_chunk, 0, n));
    102     }
    103     const std::int32_t last_m_offset = (real_tasks - 1) * m_chunk;
    104     tasks->push_back(TaskRect(last_m_offset, m - last_m_offset, 0, n));
    105   } else {
    106     const std::int32_t n_chunk = n / real_tasks;
    107     for (int i = 0; i < real_tasks - 1; ++i) {
    108       tasks->push_back(TaskRect(0, m, i * n_chunk, n_chunk));
    109     }
    110     const std::int32_t last_n_offset = (real_tasks - 1) * n_chunk;
    111     tasks->push_back(TaskRect(0, m, last_n_offset, n - last_n_offset));
    112   }
    113 }
    114 
    115 template <typename IN_TYPE, typename OUT_TYPE, typename F>
    116 void MultiThreadedMatrixMatrix(gemmlowp::WorkersPool* pool,
    117                                std::int32_t max_threads, std::uint8_t* scratch,
    118                                const IN_TYPE* lhs, const IN_TYPE* rhs,
    119                                std::int32_t m, std::int32_t n, std::int32_t k,
    120                                OUT_TYPE* result, std::int32_t result_stride,
    121                                const F& operation) {
    122   max_threads = internal::ResolveMaxThreads(max_threads);
    123   if (max_threads > 1) {
    124     pool->CreateWorkers(max_threads - 1);
    125   }
    126 
    127   std::vector<internal::TaskRect> task_rects;
    128   internal::PrepareTasks(max_threads, m, n, k, &task_rects);
    129 
    130   if (task_rects.size() == 1) {
    131     operation.ExecuteMatrixMatrix(scratch, lhs, rhs, m, n, k, result,
    132                                   result_stride);
    133     return;
    134   }
    135 
    136   std::uint8_t* task_scratch = scratch;
    137   std::int32_t scratch_per_thread = operation.ScratchPerThread(m, n, k);
    138   std::int32_t worker_tasks = task_rects.size() - 1;
    139   pool->counter_to_decrement_when_ready().Reset(worker_tasks);
    140 
    141   for (std::int32_t i = 0; i < worker_tasks; ++i) {
    142     auto task = new internal::MetaTask<IN_TYPE, OUT_TYPE, F>(
    143         task_scratch, lhs, rhs, task_rects[i], k, result, result_stride,
    144         operation);
    145     pool->StartWorker(i, task);
    146     task_scratch += scratch_per_thread;
    147   }
    148 
    149   {
    150     internal::MetaTask<IN_TYPE, OUT_TYPE, F> master_task(
    151         task_scratch, lhs, rhs, task_rects.back(), k, result, result_stride,
    152         operation);
    153     master_task.Run();
    154   }
    155 
    156   pool->counter_to_decrement_when_ready().Wait();
    157 }
    158 
    159 }  // namespace internal
    160 }  // namespace meta
    161 }  // namespace gemmlowp
    162 
    163 #endif  // GEMMLOWP_META_MULTI_THREAD_COMMON_H_
    164