1 // Copyright 2016 The Gemmlowp Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef GEMMLOWP_META_MULTI_THREAD_TRANSFORM_H_ 16 #define GEMMLOWP_META_MULTI_THREAD_TRANSFORM_H_ 17 18 #include "multi_thread_common.h" 19 #include "single_thread_transform.h" 20 21 namespace gemmlowp { 22 namespace meta { 23 namespace internal { 24 25 const int kTransformTaskOverhead = 128000; 26 const int kMinTransformTaskSize = 32000; 27 28 template <typename MultiThreadingContext, typename Params> 29 inline bool PrepareTransform1DTasks(MultiThreadingContext* context, 30 const Params& params, int kernel_size, 31 std::vector<Params>* task_params) { 32 typedef Transform1DUtil<typename Params::InType, typename Params::OutType, 33 typename Params::Kernel> 34 Util; 35 36 const int max_threads = ResolveMaxThreads(context->max_num_threads()); 37 const int task_size = Util::EstimateComputeCost(params.kernel); 38 const int max_tasks_by_size = 39 (task_size - kTransformTaskOverhead) / kMinTransformTaskSize; 40 41 const int real_tasks = std::max(1, std::min(max_threads, max_tasks_by_size)); 42 43 if (real_tasks == 1) { 44 return false; 45 } 46 47 const int chunk = params.kernel.count / real_tasks; 48 for (int i = 0; i < real_tasks - 1; ++i) { 49 task_params->push_back(params); 50 Params& task = task_params->back(); 51 task.kernel.count = chunk; 52 task.input = Util::OffsetInput(params.kernel, params.input, i * chunk); 53 task.output = Util::OffsetOutput(params.kernel, params.output, i * chunk); 54 } 55 task_params->push_back(params); 56 Params& task = task_params->back(); 57 const int sum_chunk = (real_tasks - 1) * chunk; 58 task.kernel.count = params.kernel.count - sum_chunk; 59 task.input = Util::OffsetInput(params.kernel, params.input, sum_chunk); 60 task.output = Util::OffsetOutput(params.kernel, params.output, sum_chunk); 61 return true; 62 } 63 64 template <typename Params, int kernel_size> 65 struct Transform1DTaskRunner : gemmlowp::Task { 66 Transform1DTaskRunner(const Params& params) : params(params) {} 67 68 void Run() override { Transform1D<Params, kernel_size>(params); } 69 70 Params params; 71 }; 72 73 } // namespace internal 74 75 template <typename MultiThreadingContext, typename Params, int kernel_size> 76 inline void MultiThreadTransform1D(MultiThreadingContext* context, 77 const Params& params) { 78 typedef internal::Transform1DTaskRunner<Params, kernel_size> TaskRunnerType; 79 80 std::vector<Params> task_params; 81 if (!internal::PrepareTransform1DTasks<MultiThreadingContext, Params>( 82 context, params, kernel_size, &task_params)) { 83 Transform1D<Params, kernel_size>(params); 84 return; 85 } 86 87 auto workers_pool = context->workers_pool(); 88 std::vector<Task*> tasks; 89 std::for_each(task_params.begin(), task_params.end(), [tasks](Params* param) { 90 tasks.push_back(new TaskRunnerType(param)); 91 }); 92 workers_pool->Execute(tasks); 93 } 94 95 } // namespace meta 96 } // namespace gemmlowp 97 98 #endif // GEMMLOWP_META_MULTI_THREAD_TRANSFORM_H_ 99