// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// multi_thread_gemm.h: Multi-threaded GEMM entry point.
// Reader's note: to understand this file, it is useful to first
// read and understand the much simpler single_thread_gemm.h.

#ifndef GEMMLOWP_INTERNAL_MULTI_THREAD_GEMM_H_
#define GEMMLOWP_INTERNAL_MULTI_THREAD_GEMM_H_

#include <pthread.h>

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

#include "single_thread_gemm.h"

namespace gemmlowp {

// On x86 and ARM platforms we busy-wait for a while before waiting on a
// pthread condition variable. In order to implement that correctly we need
// some explicit memory load/store barriers.

#if defined(GEMMLOWP_ALLOW_INLINE_ASM) && !defined(GEMMLOWP_NO_BUSYWAIT) && \
    (defined(GEMMLOWP_ARM) || defined(GEMMLOWP_X86))

#define GEMMLOWP_USE_BUSYWAIT

const int kMaxBusyWaitNOPs = 32 * 1000 * 1000;

#define GEMMLOWP_NOP "nop\n"

#define GEMMLOWP_STRING_CONCAT_4(X) X X X X
#define GEMMLOWP_NOP4 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP)
#define GEMMLOWP_NOP16 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP4)
#define GEMMLOWP_NOP64 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP16)

// Executes a burst of NOPs (currently 64) and returns how many were executed,
// so that callers can keep an accurate count of the busy-wait work done.
inline int Do256NOPs() {
  asm volatile(GEMMLOWP_NOP64);
  return 64;
}

#undef GEMMLOWP_STRING_CONCAT_4
#undef GEMMLOWP_NOP64
#undef GEMMLOWP_NOP16
#undef GEMMLOWP_NOP4
#undef GEMMLOWP_NOP

inline void WriteBarrier() {
#if defined(_MSC_VER)
  MemoryBarrier();
#elif defined(GEMMLOWP_ARM_32)
  asm volatile("" ::: "memory");
#elif defined(GEMMLOWP_ARM_64)
  asm volatile("dmb ishst" ::: "memory");
#elif defined(GEMMLOWP_X86)
  asm volatile("sfence" ::: "memory");
#else
#error "Unsupported architecture for WriteBarrier."
#endif
}

inline void ReadBarrier() {
#if defined(_MSC_VER)
  MemoryBarrier();
#elif defined(GEMMLOWP_ARM_32)
  asm volatile("" ::: "memory");
#elif defined(GEMMLOWP_ARM_64)
  asm volatile("dmb ishld" ::: "memory");
#elif defined(GEMMLOWP_X86)
  asm volatile("lfence" ::: "memory");
#else
#error "Unsupported architecture for ReadBarrier."
#endif
}

#endif

// Waits until *var != initial_value.
//
// Returns the new value of *var. The guarantee here is that
// the return value is different from initial_value, and that that
// new value has been taken by *var at some point during the
// execution of this function. There is no guarantee that this is
// still the value of *var when this function returns, since *var is
// not assumed to be guarded by any lock.
//
// First does some busy-waiting for a fixed number of no-op cycles,
// then falls back to passive waiting for the given condvar, guarded
// by the given mutex.
//
// The idea of doing some initial busy-waiting is to help get
// better and more consistent multithreading benefits for small GEMM sizes.
// Busy-waiting helps ensure that if we need to wake up soon after having
// started waiting, then we can wake up quickly (as opposed to, say,
// having to wait to be scheduled again by the OS). On the other hand,
// we must still eventually revert to passive waiting for longer waits
// (e.g. worker threads having finished a GEMM and waiting until the next GEMM)
// so as to avoid permanently spinning.
//
template <typename T>
T WaitForVariableChange(volatile T* var, T initial_value, pthread_cond_t* cond,
                        pthread_mutex_t* mutex) {
#ifdef GEMMLOWP_USE_BUSYWAIT
  // If we are on a platform that supports it, spin for some time.
  {
    int nops = 0;
    // First, the trivial case where the variable already changed value.
    T new_value = *var;
    if (new_value != initial_value) {
      ReadBarrier();
      return new_value;
    }
    // Then try busy-waiting.
    while (nops < kMaxBusyWaitNOPs) {
      nops += Do256NOPs();
      new_value = *var;
      if (new_value != initial_value) {
        ReadBarrier();
        return new_value;
      }
    }
  }
#endif

  // Finally, do real passive waiting.
  pthread_mutex_lock(mutex);
  T new_value = *var;
  if (new_value == initial_value) {
    pthread_cond_wait(cond, mutex);
    new_value = *var;
    assert(new_value != initial_value);
  }
  pthread_mutex_unlock(mutex);
  return new_value;
}

// A BlockingCounter lets one thread wait for N events to occur.
// This is how the master thread waits for all the worker threads
// to have finished working.
class BlockingCounter {
 public:
  BlockingCounter() : count_(0), initial_count_(0) {
    pthread_cond_init(&cond_, nullptr);
    pthread_mutex_init(&mutex_, nullptr);
  }

  ~BlockingCounter() {
    pthread_cond_destroy(&cond_);
    pthread_mutex_destroy(&mutex_);
  }

  // Sets/resets the counter; initial_count is the number of
  // decrementing events that the Wait() call will be waiting for.
  void Reset(std::size_t initial_count) {
    pthread_mutex_lock(&mutex_);
    assert(count_ == 0);
    initial_count_ = initial_count;
    count_ = initial_count_;
    pthread_mutex_unlock(&mutex_);
  }

  // Decrements the counter; if the counter hits zero, signals
  // the thread that was waiting for that, and returns true.
  // Otherwise (if the decremented count is still nonzero),
  // returns false.
  bool DecrementCount() {
    pthread_mutex_lock(&mutex_);
    assert(count_ > 0);
    count_--;
#ifdef GEMMLOWP_USE_BUSYWAIT
    WriteBarrier();
#endif
    if (count_ == 0) {
      pthread_cond_signal(&cond_);
    }
    bool retval = count_ == 0;
    pthread_mutex_unlock(&mutex_);
    return retval;
  }

  // Waits for the N other threads (N having been set by Reset())
  // to hit the BlockingCounter.
  void Wait() {
    ScopedProfilingLabel label("BlockingCounter::Wait");
    while (count_) {
#ifdef GEMMLOWP_USE_BUSYWAIT
      ReadBarrier();
#else
      // This is likely unnecessary, but is kept to ensure regressions are not
      // introduced.
#ifndef _WIN32
      asm volatile("" ::: "memory");
#endif
#endif
      const std::size_t count_value = count_;
      if (count_value) {
        WaitForVariableChange(&count_, count_value, &cond_, &mutex_);
      }
    }
  }

 private:
  pthread_cond_t cond_;
  pthread_mutex_t mutex_;
  std::size_t count_;
  std::size_t initial_count_;
};
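
// Illustrative usage sketch (not part of the library): the pattern served by
// BlockingCounter is that the master resets the counter to the number of
// outstanding events, each worker decrements it exactly once, and the master
// blocks in Wait() until the count reaches zero. Roughly:
//
//   BlockingCounter counter;
//   counter.Reset(num_workers);   // master: expect num_workers events
//   ...                           // each worker, when done, calls:
//   counter.DecrementCount();     //   returns true only for the last worker
//   counter.Wait();               // master: blocks until the count is zero
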
// A workload for a worker.
struct Task {
  Task() : local_allocator(nullptr) {}
  virtual ~Task() {}
  virtual void Run() = 0;
  Allocator* local_allocator;
};

// A worker thread.
class Worker {
 public:
  enum class State {
    ThreadStartup,  // The initial state before the thread main loop runs.
    Ready,          // Is not working, has not yet received new work to do.
    HasWork,        // Has work to do.
    ExitAsSoonAsPossible  // Should exit at earliest convenience.
  };

  explicit Worker(BlockingCounter* counter_to_decrement_when_ready)
      : task_(nullptr),
        state_(State::ThreadStartup),
        counter_to_decrement_when_ready_(counter_to_decrement_when_ready) {
    pthread_cond_init(&state_cond_, nullptr);
    pthread_mutex_init(&state_mutex_, nullptr);
    pthread_create(&thread_, nullptr, ThreadFunc, this);
  }

  ~Worker() {
    ChangeState(State::ExitAsSoonAsPossible);
    pthread_join(thread_, nullptr);
    pthread_cond_destroy(&state_cond_);
    pthread_mutex_destroy(&state_mutex_);
  }

  // Changes State; may be called from either the worker thread
  // or the master thread. Not all state transitions are legal;
  // illegal transitions trip the assertions below.
  void ChangeState(State new_state) {
    ScopedProfilingLabel label("Worker::ChangeState");
    pthread_mutex_lock(&state_mutex_);
    assert(new_state != state_);
    switch (state_) {
      case State::ThreadStartup:
        assert(new_state == State::Ready);
        break;
      case State::Ready:
        assert(new_state == State::HasWork ||
               new_state == State::ExitAsSoonAsPossible);
        break;
      case State::HasWork:
        assert(new_state == State::Ready ||
               new_state == State::ExitAsSoonAsPossible);
        break;
      default:
        abort();
    }
    state_ = new_state;
    pthread_cond_signal(&state_cond_);
    if (state_ == State::Ready) {
      counter_to_decrement_when_ready_->DecrementCount();
    }
    pthread_mutex_unlock(&state_mutex_);
  }

  // Thread entry point.
  void ThreadFunc() {
    ScopedProfilingLabel label("Worker::ThreadFunc");
    RegisterCurrentThreadForProfiling();

    ChangeState(State::Ready);

    // Thread main loop
    while (true) {
      // Get a state to act on.
      // In the 'Ready' state, we have nothing to do but to wait until
      // we switch to another state.
      State state_to_act_upon = WaitForVariableChange(
          &state_, State::Ready, &state_cond_, &state_mutex_);

      // We now have a state to act on, so act.
      switch (state_to_act_upon) {
        case State::HasWork:
          // Got work to do! So do it, and then revert to 'Ready' state.
          assert(task_);
          task_->Run();
          task_ = nullptr;
          ChangeState(State::Ready);
          break;
        case State::ExitAsSoonAsPossible:
          return;
        default:
          abort();
      }
    }
  }

  static void* ThreadFunc(void* arg) {
    static_cast<Worker*>(arg)->ThreadFunc();
    return nullptr;
  }

  // Called by the master thread to give this worker work to do.
  // It is only legal to call this if the worker is currently in the
  // 'Ready' state (see the assertion below).
  void StartWork(Task* task) {
    assert(!task_);
    task->local_allocator = &local_allocator_;
    task_ = task;
#ifdef GEMMLOWP_USE_BUSYWAIT
    WriteBarrier();
#endif
    assert(state_ == State::Ready);
    ChangeState(State::HasWork);
  }

 private:
  // The underlying thread.
  pthread_t thread_;

  // The task to be worked on.
  Task* task_;

  // The condition variable and mutex guarding state changes.
  pthread_cond_t state_cond_;
  pthread_mutex_t state_mutex_;

  // The state enum tells if we're currently working, waiting for work, etc.
  State state_;

  // Each worker thread has a local allocator so that it can allocate
  // temporary buffers without blocking the other threads.
  Allocator local_allocator_;

  // Pointer to the master thread's BlockingCounter object, used to notify the
  // master thread when this worker switches to the 'Ready' state.
  BlockingCounter* const counter_to_decrement_when_ready_;
};

// A very simple pool of workers that only allows the very
// specific parallelization pattern that we use here:
// a fixed number of workers can be given work, and one then
// waits for all of them to finish.
//
// See MultiThreadGemmContextBase for how other WorkersPool implementations can
// be used. Note that in those implementations, StartWorker can be free to
// ignore the <index> value; that is, the caller of WorkersPool does not rely
// on <index> to order tasks.
class WorkersPool {
 public:
  WorkersPool() {}

  ~WorkersPool() {
    for (auto w : workers_) {
      delete w;
    }
  }

  void Execute(const std::vector<Task*>& tasks) {
    assert(tasks.size() >= 1);
    // One of the tasks will be run on the current thread.
    std::size_t workers_count = tasks.size() - 1;
    CreateWorkers(workers_count);
    assert(workers_count <= workers_.size());
    counter_to_decrement_when_ready_.Reset(workers_count);
    int n = 0;
    std::for_each(tasks.begin(), tasks.end() - 1,
                  [this, &n](Task* task) { workers_[n++]->StartWork(task); });
    // Execute the remaining workload immediately on the current thread.
    Task* task = tasks.back();
    task->local_allocator = &main_thread_task_allocator_;
    task->Run();
    // Wait for the workers submitted above to finish.
    counter_to_decrement_when_ready_.Wait();
    // Cleanup tasks (best to do this from the same thread that allocated
    // the memory).
    std::for_each(tasks.begin(), tasks.end(), [](Task* task) { delete task; });
  }

 private:
  // Ensures that the pool has at least the given count of workers.
  // If any new worker has to be created, this function waits for it to
  // be ready.
  void CreateWorkers(std::size_t workers_count) {
    if (workers_.size() >= workers_count) {
      return;
    }
    counter_to_decrement_when_ready_.Reset(workers_count - workers_.size());
    while (workers_.size() < workers_count) {
      workers_.push_back(new Worker(&counter_to_decrement_when_ready_));
    }
    counter_to_decrement_when_ready_.Wait();
  }

  // Copy construction disallowed.
  WorkersPool(const WorkersPool&) = delete;

  // The workers in this pool. They are owned by the pool:
  // the pool creates workers and destroys them in its destructor.
  std::vector<Worker*> workers_;

  // The BlockingCounter used to wait for the workers.
  BlockingCounter counter_to_decrement_when_ready_;

  // For N-threaded operations, we will use only N-1 worker threads
  // while the last task will be run directly on the main thread.
  // It will then use this main_thread_task_allocator_; having a
  // dedicated allocator for that (separate from the base allocator_)
  // allows the same code to be used for all tasks regardless of which
  // thread they run on.
  Allocator main_thread_task_allocator_;
};

// The task we use to implement a multi-threaded Gemm: a block of the
// RHS has been packed by the master thread; each worker thread
// then has to pack a block of the LHS and accumulate the Gemm of these
// packed LHS and RHS blocks.
template <typename KernelFormat, typename InputScalar, typename OutputScalar,
          typename BitDepthParams, MapOrder LhsOrder, MapOrder RhsOrder,
          MapOrder ResultOrder, typename LhsOffset, typename RhsOffset,
          typename OutputPipelineType, typename GemmContextType>
struct GemmWithPackedRhsTask : Task {
  typedef PackedSideBlock<typename KernelFormat::Lhs> PackedLhs;
  typedef PackedSideBlock<typename KernelFormat::Rhs> PackedRhs;
  GemmWithPackedRhsTask(GemmContextType* _context, const KernelBase& _kernel,
                        const MatrixMap<const InputScalar, LhsOrder>& _lhs,
                        const PackedRhs& _packed_rhs,
                        MatrixMap<OutputScalar, ResultOrder>* _result,
                        const MatrixBlockBounds& _result_block,
                        const LhsOffset& _lhs_offset,
                        const RhsOffset& _rhs_offset,
                        const BlockParams& _block_params,
                        const OutputPipelineType& _output_pipeline)
      : context(_context),
        kernel(_kernel),
        lhs(_lhs),
        packed_rhs(_packed_rhs),
        result(*_result),
        result_block(_result_block),
        lhs_offset(_lhs_offset),
        rhs_offset(_rhs_offset),
        block_params(_block_params),
        output_pipeline(_output_pipeline) {}

  void Run() override {
    ScopedProfilingLabel label("GemmWithPackedRhsTask");

    const int rows = result_block.rows;
    const int cols = result_block.cols;
    const int depth = lhs.cols();

    PackedLhs packed_lhs(Side::Lhs, local_allocator, block_params);

    PackedResult packed_result(local_allocator, block_params);

    local_allocator->Commit();

    for (int c = 0; c < cols; c += block_params.l2_cols) {
      int cs = std::min(block_params.l2_cols, cols - c);

      for (int r = 0; r < rows; r += block_params.l2_rows) {
        int rs = std::min(block_params.l2_rows, rows - r);

        PackLhs(&packed_lhs, lhs.block(r, 0, rs, depth));

        Compute(kernel, block_params, &packed_result, packed_lhs, packed_rhs,
                depth);

        auto curr_result_block = MatrixBlockBounds(
            result_block.start_row + r, result_block.start_col + c, rs, cs);
        UnpackResult<KernelFormat>(
            &result, curr_result_block, packed_result, depth,
            packed_lhs.sums_of_each_slice(), packed_rhs.sums_of_each_slice(),
            lhs_offset.block(curr_result_block.start_row, rs),
            rhs_offset.block(curr_result_block.start_col, cs), output_pipeline);
      }
    }

    local_allocator->Decommit();
  }

  const GemmContextType* context;
  const KernelBase& kernel;
  const MatrixMap<const InputScalar, LhsOrder> lhs;
  const PackedRhs packed_rhs;
  MatrixMap<OutputScalar, ResultOrder> result;
  const MatrixBlockBounds result_block;
  const LhsOffset& lhs_offset;
  const RhsOffset& rhs_offset;
  const BlockParams& block_params;
  const OutputPipelineType& output_pipeline;
};
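
// Illustrative usage sketch (not part of the library): the way the code in
// this file consumes WorkersPool is to define Task subclasses such as
// GemmWithPackedRhsTask above, hand heap-allocated instances to Execute(),
// and let the pool run them (one of them on the calling thread) and delete
// them when done. Roughly:
//
//   struct MyTask : Task {  // hypothetical example task
//     void Run() override {
//       // Use local_allocator here for temporary buffers.
//     }
//   };
//   ...
//   std::vector<Task*> tasks;
//   for (int i = 0; i < task_count; i++) {
//     tasks.push_back(new MyTask);
//   }
//   workers_pool->Execute(tasks);  // blocks until every task has run
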
// This base class for multi-threading allows subclasses to implement their own
// workers_pool() method. See MultiThreadGemmContext below for an example;
// any other implementation of workers_pool() must return an object with the
// same public methods as WorkersPool.
class MultiThreadGemmContextBase : public SingleThreadGemmContext {
 public:
  void set_max_num_threads(int n) { max_num_threads_ = n; }

  int max_num_threads() const { return max_num_threads_; }

 protected:
  // The maximum number of worker threads to use (including
  // the master thread).
  // The default value 1 means single-threading. That is the default
  // because gemmlowp's primary target is mobile hardware, where thermal
  // constraints usually mean that it may not be realistic to use more
  // than 1 CPU core even if multiple cores are present.
  // The special value 0 means try to detect the number of hardware threads.
  // Note: this assumes that all CPU cores are equivalent. That assumption
  // is defeated on big.LITTLE ARM devices, where we have no API to query
  // the number of big cores (which is typically what we would want to use,
  // leaving aside the above-mentioned thermal issues). That is the other
  // reason why the best compromise here is to let max_num_threads_ default
  // to 1, so users who want multi-threading have to decide for themselves
  // how many threads to use.
  int max_num_threads_ = 1;
};

class MultiThreadGemmContext : public MultiThreadGemmContextBase {
 public:
  WorkersPool* workers_pool() { return &workers_pool_; }

 private:
  // The workers pool used by MultiThreadGemm. Making
  // this part of the context allows it to be persistent,
  // avoiding recreating threads on every Gemm.
  WorkersPool workers_pool_;
};

// Determines how many threads should be used for a given Gemm
// operation.
template <int KernelRows>
inline int HowManyThreads(int max_num_threads, int rows, int cols, int depth) {
  // Early-exit in the default case where multi-threading is disabled.
  if (max_num_threads == 1) {
    return 1;
  }

  // Determine the maximum number of threads.
  int max_count = GetHardwareConcurrency(max_num_threads);

  // Basic calculation: take into account max pool size, and
  // how many rows we have to feed our kernel.
  // The motivation for an absolute minimum number of rows per thread,
  // potentially higher than KernelRows, is that very thin per-thread
  // workloads currently defeat assumptions of the AddMod generator, resulting
  // in substantial bias in TestWithRealData on 24 threads.
  // Ideally, the AddMod generator should be aware of global (r,c) coordinates
  // so as to be independent of the number of threads.
  static const int AbsoluteMinRowsPerThread = 16;
  static const int MinRowsPerThread = KernelRows > AbsoluteMinRowsPerThread
                                          ? KernelRows
                                          : AbsoluteMinRowsPerThread;
  int thread_count = std::min(max_count, CeilQuotient(rows, MinRowsPerThread));

  // At this point, for small products we already have thread_count == 1, so
  // we can avoid doing more work; otherwise, we still want to check
  // that the cubic size (rows*cols*depth) is big enough to keep
  // all the workers busy.
  if (thread_count > 1) {
    // Empirically determined value.
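    // For a sense of scale (illustrative numbers only): rows = cols = 256 and
    // depth = 32 give cubic_size = 2^21, so this criterion alone would cap
    // thread_count at 2^21 / 2^16 = 32.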
    static const std::uint64_t min_cubic_size_per_thread = 64 * 1024;

    // Compute rows*cols*depth in 64-bit arithmetic to avoid 32-bit overflow.
    const std::uint64_t cubic_size =
        std::uint64_t(rows) * std::uint64_t(cols) * std::uint64_t(depth);

    thread_count =
        std::min(thread_count, int(cubic_size / min_cubic_size_per_thread));

    if (thread_count < 1) {
      thread_count = 1;
    }
  }

  assert(thread_count > 0 && thread_count <= max_count);
  return thread_count;
}

// The main multi-threaded Gemm function.
// To understand it, first read the code of SingleThreadGemm().
// The parallelization scheme used here is to have this master function
// pack a block of RHS and then start worker threads to pack a block of LHS
// each, and accumulate the corresponding products.
template <typename KernelFormat, typename InputScalar, typename OutputScalar,
          typename BitDepthParams, MapOrder LhsOrder, MapOrder RhsOrder,
          MapOrder ResultOrder, typename LhsOffset, typename RhsOffset,
          typename OutputPipelineType, typename GemmContextType>
void MultiThreadGemm(GemmContextType* context, const KernelBase& kernel,
                     const MatrixMap<const InputScalar, LhsOrder>& lhs,
                     const MatrixMap<const InputScalar, RhsOrder>& rhs,
                     MatrixMap<OutputScalar, ResultOrder>* result,
                     const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
                     const OutputPipelineType& output_pipeline) {
  ScopedProfilingLabel label("gemmlowp::MultiThreadGemm");

  assert(lhs.cols() == rhs.rows());

  int rows = result->rows();
  int cols = result->cols();
  int depth = lhs.cols();

  // Zero sizes should have been caught earlier and early-returned.
  assert(rows > 0);
  assert(cols > 0);
  assert(depth > 0);

  // The case of rows < cols should have been caught earlier and transposed.
  assert(rows >= cols);

  const int thread_count = HowManyThreads<KernelFormat::kRows>(
      context->max_num_threads(), rows, cols, depth);
  if (thread_count == 1) {
    return SingleThreadGemm<KernelFormat, InputScalar, OutputScalar,
                            BitDepthParams>(context, kernel, lhs, rhs, result,
                                            lhs_offset, rhs_offset,
                                            output_pipeline);
  }
  assert(thread_count > 1);

  // Simple 1:1 mapping of tasks to physical cores, which is very important
  // to getting good multithreaded performance, especially for not-very-large
  // GEMMs, and especially on Android.
  const int task_count = thread_count;

  Allocator* allocator = context->allocator();
  auto* workers_pool = context->workers_pool();

  BlockParams block_params;
  block_params.Init<KernelFormat>(
      rows, cols, depth, task_count, context->l1_bytes_to_use(),
      context->l2_bytes_to_use(), context->l2_rhs_factor());

  PackedSideBlock<typename KernelFormat::Rhs> packed_rhs(Side::Rhs, allocator,
                                                         block_params);
  allocator->Commit();

  // We loop over large blocks of the RHS.
  for (int c = 0; c < cols; c += block_params.l2_cols) {
    int cs = std::min(block_params.l2_cols, cols - c);

    // Pack a large block of the RHS.
    PackRhs(&packed_rhs, rhs.block(0, c, depth, cs));

    // Give work to each worker.
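    // Each task gets a contiguous range of rows, with boundaries rounded up to
    // a multiple of KernelFormat::kRows. For example (illustrative numbers
    // only): with rows = 1000, task_count = 4 and kRows = 8, the row
    // boundaries fall at 0, 256, 504, 752, 1000.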
    std::vector<Task*> tasks;
    int next_start_row = 0;
    for (int n = 0; n < task_count; ++n) {
      int start_row = next_start_row;
      next_start_row = std::min(
          rows, RoundUp<KernelFormat::kRows>(rows * (n + 1) / task_count));

      int block_rows = next_start_row - start_row;
      auto lhs_block = lhs.block(start_row, 0, block_rows, depth);
      typedef GemmWithPackedRhsTask<KernelFormat, InputScalar, OutputScalar,
                                    BitDepthParams, LhsOrder, RhsOrder,
                                    ResultOrder, LhsOffset, RhsOffset,
                                    OutputPipelineType, GemmContextType>
          TaskType;
      tasks.push_back(
          new TaskType(context, kernel, lhs_block, packed_rhs, result,
                       MatrixBlockBounds(start_row, c, block_rows, cs),
                       lhs_offset, rhs_offset, block_params, output_pipeline));
    }
    // Execute the work on the workers (and partially on this thread).
    workers_pool->Execute(tasks);
  }

  allocator->Decommit();
}

}  // namespace gemmlowp

#endif  // GEMMLOWP_INTERNAL_MULTI_THREAD_GEMM_H_