// Copyright 2016 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef GEMMLOWP_META_SINGLE_THREAD_GEMM_H_
#define GEMMLOWP_META_SINGLE_THREAD_GEMM_H_

#include <iostream>
#include "base.h"

namespace gemmlowp {
namespace meta {

// Single-threaded GEMM entry point (defined at the bottom of this file).
// Dispatches the runtime leftover sizes (params.m % kernel_m, etc.) to
// compile-time constants and then runs the given Executor.
template <typename Executor, typename Params, int kernel_m, int kernel_n,
          int kernel_k>
void Gemm(const Params& params);

// Executor that packs the ENTIRE right-hand-side matrix into scratch memory
// up front, then streams the left-hand side through it one kernel-sized chunk
// at a time. Suited to shapes where the packed RHS fits comfortably in
// scratch/cache (see GemmExecutorPackRHSCacheFriendly for the guarded
// variant).
//
// Naming convention used throughout: a trailing 'F' means a Full-size kernel
// tile (m x n x k), a trailing 'L' means a Leftover tile along the
// corresponding dimension (m_leftovers / n_leftovers / k_leftovers).
class GemmExecutorPackRHS {
 public:
  // Upper-bound scratch requirement in bytes: one packed LHS chunk plus ALL
  // packed RHS chunks (the RHS is packed in full), rounded up to a 64KB
  // multiple.
  template <typename P>
  static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n,
                                 int kernel_k) {
    const int lhs_scratch =
        StreamUtil<typename P::InType, typename P::LeftStream>::Scratch(
            params.left_stream, kernel_m, kernel_k);
    // Number of kernel_n-wide column chunks, rounded up to cover leftovers.
    const int rhs_chunks = ((params.n + kernel_n - 1) / kernel_n);
    const int rhs_scratch =
        rhs_chunks *
        StreamUtil<typename P::InType, typename P::RightStream>::Scratch(
            params.right_stream, kernel_n, kernel_k);
    return AlignTo<64 * 1024>(lhs_scratch + rhs_scratch);
  }

  // Performs the GEMM for one fully-resolved shape: (m, n, k) are the
  // compile-time kernel tile sizes, and the *_leftovers values are the
  // compile-time remainders (params.m % m, params.n % n, params.k % k) that
  // the Dispatch3DStage* ladder resolved from runtime values.
  template <typename P, int m, int n, int k, int m_leftovers, int n_leftovers,
            int k_leftovers>
  static void ExecuteDispatch3D(const P& params) {
    // Shorthand typedefs for streams and multiply kernels.
    typedef typename P::InType InType;
    typedef typename P::OutType OutType;

    typedef Stream<typename P::InType, m, k, k_leftovers,
                   typename P::LeftStream>
        LeftStreamF;
    typedef Stream<typename P::InType, m_leftovers, k, k_leftovers,
                   typename P::LeftStream>
        LeftStreamL;

    typedef Stream<typename P::InType, n, k, k_leftovers,
                   typename P::RightStream>
        RightStreamF;
    typedef Stream<typename P::InType, n_leftovers, k, k_leftovers,
                   typename P::RightStream>
        RightStreamL;

    typedef Stream<typename P::OutType, m, n, 0, typename P::OutputStream>
        OutputStreamFF;
    typedef Stream<typename P::OutType, m_leftovers, n, 0,
                   typename P::OutputStream>
        OutputStreamLF;

    // Four kernel variants covering the cross product of full/leftover rows
    // and full/leftover columns.
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m, n, k>
        KernelFF;
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m,
                      n_leftovers, k>
        KernelFL;
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m_leftovers,
                      n, k>
        KernelLF;
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m_leftovers,
                      n_leftovers, k>
        KernelLL;

#ifdef DEBUG
#ifdef DEBUG_METAGEMM_VERBOSE
    std::cout << "GemmExecutor(" << typeid(P).name() << "): " << m << "x" << n
              << "x" << k << " -- " << m_leftovers << "x" << n_leftovers << "x"
              << k_leftovers << " -- " << params.m << "x" << params.n << "x"
              << params.k << std::endl;
    LeftStreamF::Debug(params.left_stream);
    LeftStreamL::Debug(params.left_stream);

    RightStreamF::Debug(params.right_stream);
    RightStreamL::Debug(params.right_stream);

    OutputStreamFF::Debug(params.fused_kernel.output_stream);
    OutputStreamLF::Debug(params.fused_kernel.output_stream);

    KernelFF::Debug(params.fused_kernel);
    KernelFL::Debug(params.fused_kernel);
    KernelLF::Debug(params.fused_kernel);
    KernelLL::Debug(params.fused_kernel);
#endif
#endif

    // Counts of FULL chunks only; leftovers are handled separately below.
    int lhs_chunks = params.m / m;
    int rhs_chunks = params.n / n;

    // Scratch memory for packed LHS & RHS chunks.
    // Layout: [one LHS chunk][all RHS chunks].

    std::uint8_t* packed_lhs = params.scratch;
    std::uint8_t* packed_rhs =
        params.scratch + LeftStreamF::Scratch(params.left_stream);

    // Pack full RHS first.

    std::uint8_t* packed_rhs_chunk = packed_rhs;
    const int packed_rhs_chunk_size =
        RightStreamF::PackedStride(params.right_stream);

    {
      const std::uint8_t* rhs_chunk =
          reinterpret_cast<const std::uint8_t*>(params.rhs);
      const int rhs_chunk_size =
          RightStreamF::UnpackedStride(params.right_stream);

      for (int i = 0; i < rhs_chunks; ++i) {
        RightStreamF::Pack(reinterpret_cast<const InType*>(rhs_chunk),
                           params.right_stream,
                           reinterpret_cast<InType*>(packed_rhs_chunk));

        rhs_chunk += rhs_chunk_size;
        packed_rhs_chunk += packed_rhs_chunk_size;
      }

      // Pack the leftover (narrower than n) RHS chunk at the end.
      RightStreamL::Pack(reinterpret_cast<const InType*>(rhs_chunk),
                         params.right_stream,
                         reinterpret_cast<InType*>(packed_rhs_chunk));
    }

    // Multiply RHS by LHS one LHS chunk at a time.

    const std::uint8_t* lhs_chunk =
        reinterpret_cast<const std::uint8_t*>(params.lhs);
    std::uint8_t* result_strip = reinterpret_cast<std::uint8_t*>(params.result);
    std::uint8_t* result_chunk = result_strip;

    {
      const int lhs_chunk_size =
          LeftStreamF::UnpackedStride(params.left_stream);
      const int result_strip_size =
          OutputStreamFF::UnpackedStride(params.fused_kernel.output_stream);
      const int result_chunk_size =
          OutputStreamFF::UnpackedAdvance(params.fused_kernel.output_stream);

      for (int i = 0; i < lhs_chunks; ++i) {
        // Each full LHS chunk is packed once, then multiplied against every
        // (already-packed) RHS chunk.
        LeftStreamF::Pack(reinterpret_cast<const InType*>(lhs_chunk),
                          params.left_stream,
                          reinterpret_cast<InType*>(packed_lhs));

        result_chunk = result_strip;
        packed_rhs_chunk = packed_rhs;

        for (int j = 0; j < rhs_chunks; ++j) {
          KernelFF::Multiply(reinterpret_cast<const InType*>(packed_lhs),
                             reinterpret_cast<const InType*>(packed_rhs_chunk),
                             params.fused_kernel,
                             reinterpret_cast<OutType*>(result_chunk));

          result_chunk += result_chunk_size;
          packed_rhs_chunk += packed_rhs_chunk_size;
        }

        // Leftover RHS columns for this strip of rows.
        KernelFL::Multiply(reinterpret_cast<const InType*>(packed_lhs),
                           reinterpret_cast<const InType*>(packed_rhs_chunk),
                           params.fused_kernel,
                           reinterpret_cast<OutType*>(result_chunk));

        lhs_chunk += lhs_chunk_size;
        result_strip += result_strip_size;
      }
    }

    // Leftover LHS chunk.
    if (m_leftovers > 0) {  // static if
      const int result_chunk_size =
          OutputStreamLF::UnpackedAdvance(params.fused_kernel.output_stream);

      LeftStreamL::Pack(reinterpret_cast<const InType*>(lhs_chunk),
                        params.left_stream,
                        reinterpret_cast<InType*>(packed_lhs));

      result_chunk = result_strip;
      packed_rhs_chunk = packed_rhs;

      for (int i = 0; i < rhs_chunks; ++i) {
        KernelLF::Multiply(reinterpret_cast<const InType*>(packed_lhs),
                           reinterpret_cast<const InType*>(packed_rhs_chunk),
                           params.fused_kernel,
                           reinterpret_cast<OutType*>(result_chunk));

        result_chunk += result_chunk_size;
        packed_rhs_chunk += packed_rhs_chunk_size;
      }

      // Bottom-right corner: leftover rows x leftover columns.
      KernelLL::Multiply(reinterpret_cast<const InType*>(packed_lhs),
                         reinterpret_cast<const InType*>(packed_rhs_chunk),
                         params.fused_kernel,
                         reinterpret_cast<OutType*>(result_chunk));
    }
  }
};

// Mirror image of GemmExecutorPackRHS: packs the ENTIRE left-hand side up
// front and streams the right-hand side through it one chunk at a time.
class GemmExecutorPackLHS {
 public:
  // Upper-bound scratch requirement in bytes: all packed LHS chunks plus one
  // packed RHS chunk, rounded up to a 64KB multiple.
  template <typename P>
  static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n,
                                 int kernel_k) {
    const int lhs_chunks = ((params.m + kernel_m - 1) / kernel_m);
    const int lhs_scratch =
        lhs_chunks *
        StreamUtil<typename P::InType, typename P::LeftStream>::Scratch(
            params.left_stream, kernel_m, kernel_k);
    const int rhs_scratch =
        StreamUtil<typename P::InType, typename P::RightStream>::Scratch(
            params.right_stream, kernel_n, kernel_k);
    return AlignTo<64 * 1024>(lhs_scratch + rhs_scratch);
  }

  // See GemmExecutorPackRHS::ExecuteDispatch3D for the meaning of the
  // template parameters; this variant swaps the packing roles of LHS and RHS.
  template <typename P, int m, int n, int k, int m_leftovers, int n_leftovers,
            int k_leftovers>
  static void ExecuteDispatch3D(const P& params) {
    // Shorthand typedefs for streams and multiply kernels.
    typedef typename P::InType InType;
    typedef typename P::OutType OutType;

    typedef Stream<typename P::InType, m, k, k_leftovers,
                   typename P::LeftStream>
        LeftStreamF;
    typedef Stream<typename P::InType, m_leftovers, k, k_leftovers,
                   typename P::LeftStream>
        LeftStreamL;

    typedef Stream<typename P::InType, n, k, k_leftovers,
                   typename P::RightStream>
        RightStreamF;
    typedef Stream<typename P::InType, n_leftovers, k, k_leftovers,
                   typename P::RightStream>
        RightStreamL;

    typedef Stream<typename P::OutType, m, n, 0, typename P::OutputStream>
        OutputStreamFF;
    typedef Stream<typename P::OutType, m, n_leftovers, 0,
                   typename P::OutputStream>
        OutputStreamFL;

    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m, n, k>
        KernelFF;
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m,
                      n_leftovers, k>
        KernelFL;
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m_leftovers,
                      n, k>
        KernelLF;
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m_leftovers,
                      n_leftovers, k>
        KernelLL;
#ifdef DEBUG
#ifdef DEBUG_METAGEMM_VERBOSE
    std::cout << "GemmExecutor(" << typeid(P).name() << "): " << m << "x" << n
              << "x" << k << " -- " << m_leftovers << "x" << n_leftovers << "x"
              << k_leftovers << " -- " << params.m << "x" << params.n << "x"
              << params.k << std::endl;
    LeftStreamF::Debug(params.left_stream);
    LeftStreamL::Debug(params.left_stream);

    RightStreamF::Debug(params.right_stream);
    RightStreamL::Debug(params.right_stream);

    OutputStreamFF::Debug(params.fused_kernel.output_stream);
    OutputStreamFL::Debug(params.fused_kernel.output_stream);

    KernelFF::Debug(params.fused_kernel);
    KernelFL::Debug(params.fused_kernel);
    KernelLF::Debug(params.fused_kernel);
    KernelLL::Debug(params.fused_kernel);
#endif
#endif

    // Counts of FULL chunks only; leftovers are handled separately below.
    int lhs_chunks = params.m / m;
    int rhs_chunks = params.n / n;

    // Scratch memory for packed LHS & RHS chunks.
    // Layout: [one RHS chunk][all LHS chunks].
    std::uint8_t* packed_rhs = params.scratch;
    std::uint8_t* packed_lhs =
        params.scratch + RightStreamF::Scratch(params.right_stream);

    // Pack full LHS first.

    std::uint8_t* packed_lhs_chunk = packed_lhs;
    const int packed_lhs_chunk_size =
        LeftStreamF::PackedStride(params.left_stream);

    {
      const std::uint8_t* lhs_chunk =
          reinterpret_cast<const std::uint8_t*>(params.lhs);
      const int lhs_chunk_size =
          LeftStreamF::UnpackedStride(params.left_stream);

      for (int i = 0; i < lhs_chunks; ++i) {
        LeftStreamF::Pack(reinterpret_cast<const InType*>(lhs_chunk),
                          params.left_stream,
                          reinterpret_cast<InType*>(packed_lhs_chunk));

        lhs_chunk += lhs_chunk_size;
        packed_lhs_chunk += packed_lhs_chunk_size;
      }

      // Pack the leftover (shorter than m) LHS chunk at the end.
      LeftStreamL::Pack(reinterpret_cast<const InType*>(lhs_chunk),
                        params.left_stream,
                        reinterpret_cast<InType*>(packed_lhs_chunk));
    }

    // Multiply RHS by LHS one RHS chunk at a time.

    const std::uint8_t* rhs_chunk =
        reinterpret_cast<const std::uint8_t*>(params.rhs);
    std::uint8_t* result_strip = reinterpret_cast<std::uint8_t*>(params.result);
    std::uint8_t* result_chunk = result_strip;

    {
      const int rhs_chunk_size =
          RightStreamF::UnpackedStride(params.right_stream);
      // Note: strip/chunk strides are swapped relative to PackRHS because the
      // outer loop here walks columns of the result rather than rows.
      const int result_strip_size =
          OutputStreamFF::UnpackedAdvance(params.fused_kernel.output_stream);
      const int result_chunk_size =
          OutputStreamFF::UnpackedStride(params.fused_kernel.output_stream);

      for (int i = 0; i < rhs_chunks; ++i) {
        // Each full RHS chunk is packed once, then multiplied against every
        // (already-packed) LHS chunk.
        RightStreamF::Pack(reinterpret_cast<const InType*>(rhs_chunk),
                           params.right_stream,
                           reinterpret_cast<InType*>(packed_rhs));

        result_chunk = result_strip;
        packed_lhs_chunk = packed_lhs;

        for (int j = 0; j < lhs_chunks; ++j) {
          KernelFF::Multiply(reinterpret_cast<const InType*>(packed_lhs_chunk),
                             reinterpret_cast<const InType*>(packed_rhs),
                             params.fused_kernel,
                             reinterpret_cast<OutType*>(result_chunk));

          result_chunk += result_chunk_size;
          packed_lhs_chunk += packed_lhs_chunk_size;
        }

        // Leftover LHS rows for this strip of columns.
        KernelLF::Multiply(reinterpret_cast<const InType*>(packed_lhs_chunk),
                           reinterpret_cast<const InType*>(packed_rhs),
                           params.fused_kernel,
                           reinterpret_cast<OutType*>(result_chunk));

        rhs_chunk += rhs_chunk_size;
        result_strip += result_strip_size;
      }
    }

    // Leftover RHS chunk.
    if (n_leftovers > 0) {  // static if
      const int result_chunk_size =
          OutputStreamFL::UnpackedStride(params.fused_kernel.output_stream);

      RightStreamL::Pack(reinterpret_cast<const InType*>(rhs_chunk),
                         params.right_stream,
                         reinterpret_cast<InType*>(packed_rhs));

      result_chunk = result_strip;
      packed_lhs_chunk = packed_lhs;

      for (int i = 0; i < lhs_chunks; ++i) {
        KernelFL::Multiply(reinterpret_cast<const InType*>(packed_lhs_chunk),
                           reinterpret_cast<const InType*>(packed_rhs),
                           params.fused_kernel,
                           reinterpret_cast<OutType*>(result_chunk));

        result_chunk += result_chunk_size;
        packed_lhs_chunk += packed_lhs_chunk_size;
      }

      // Bottom-right corner: leftover rows x leftover columns.
      KernelLL::Multiply(reinterpret_cast<const InType*>(packed_lhs_chunk),
                         reinterpret_cast<const InType*>(packed_rhs),
                         params.fused_kernel,
                         reinterpret_cast<OutType*>(result_chunk));
    }
  }
};

namespace internal {

// Returns the number of sub-GEMM tasks needed so that the constant memory
// (the side packed in full) plus the per-task chunk memory fits under
// cache_size. Precondition (asserted): one chunk plus the constant part must
// fit in the cache.
inline int CalculateCacheFriendlyTasksCount(int cache_size, int constant_memory,
                                            int per_chunk_memory, int total_dim,
                                            int chunk_dim) {
  assert(constant_memory + per_chunk_memory < cache_size);
  const int available_cache = cache_size - constant_memory;
  const int available_chunks = available_cache / per_chunk_memory;
  const int chunks_count = (total_dim + chunk_dim - 1) / chunk_dim;
  // Ceiling division: how many groups of available_chunks cover chunks_count.
  return (chunks_count + available_chunks - 1) / available_chunks;
}

// Rewrites task_params so that it describes the (m_offset, n_offset)-based
// sub-problem of size m x n: adjusts the m/n dimensions and offsets the
// lhs/rhs/result pointers via the streams' Offset helpers. The k dimension is
// left untouched.
template <typename Params>
inline void UpdateCacheFriendlyTask(int m_offset, int m, int n_offset, int n,
                                    const Params& params, Params* task_params) {
  task_params->m = m;
  task_params->lhs =
      StreamUtil<typename Params::InType, typename Params::LeftStream>::Offset(
          params.left_stream, params.lhs, m_offset, 0);

  task_params->n = n;
  task_params->rhs =
      StreamUtil<typename Params::InType, typename Params::RightStream>::Offset(
          params.right_stream, params.rhs, n_offset, 0);

  task_params->result =
      StreamUtil<typename Params::OutType, typename Params::OutputStream>::
          Offset(params.fused_kernel.output_stream, params.result, m_offset,
                 n_offset);
}

}  // namespace internal

// Wrapper around GemmExecutorPackRHS that splits the problem along the N
// dimension into sub-GEMMs whose packed data fits within cache_size bytes.
// Falls through to a single GemmExecutorPackRHS run when no split is needed.
template <int cache_size = 256 * 1024>
class GemmExecutorPackRHSCacheFriendly {
 public:
  // Scratch is capped at cache_size by construction.
  template <typename P>
  static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n,
                                 int kernel_k) {
    return cache_size;
  }

  template <typename P, int m, int n, int k, int m_leftovers, int n_leftovers,
            int k_leftovers>
  static void ExecuteDispatch3D(const P& params) {
    typedef Stream<typename P::InType, m, k, k_leftovers,
                   typename P::LeftStream>
        LeftStream;

    typedef Stream<typename P::InType, n, k, k_leftovers,
                   typename P::RightStream>
        RightStream;

    const int lhs_scratch = LeftStream::Scratch(params.left_stream);
    const int rhs_scratch = RightStream::Scratch(params.right_stream);

    // The LHS chunk is the constant resident; RHS chunks are the per-task
    // memory (PackRHS packs the whole RHS), so split along N.
    const int cache_friendly_tasks_count =
        internal::CalculateCacheFriendlyTasksCount(cache_size, lhs_scratch,
                                                   rhs_scratch, params.n, n);

    if (cache_friendly_tasks_count == 1) {
      GemmExecutorPackRHS::ExecuteDispatch3D<P, m, n, k, m_leftovers,
                                             n_leftovers, k_leftovers>(params);
      return;
    }

    const int cache_friendly_dim = params.n / cache_friendly_tasks_count;

    // All tasks but the last cover cache_friendly_dim columns each; the last
    // task takes whatever remains.
    P task_params = params;
    for (int i = 0; i < cache_friendly_tasks_count - 1; ++i) {
      internal::UpdateCacheFriendlyTask(0, params.m, i * cache_friendly_dim,
                                        cache_friendly_dim, params,
                                        &task_params);
      Gemm<GemmExecutorPackRHS, P, m, n, k>(task_params);
    }
    const int dim_sum = (cache_friendly_tasks_count - 1) * cache_friendly_dim;
    internal::UpdateCacheFriendlyTask(0, params.m, dim_sum, params.n - dim_sum,
                                      params, &task_params);
    Gemm<GemmExecutorPackRHS, P, m, n, k>(task_params);
  }
};

// Wrapper around GemmExecutorPackLHS that splits the problem along the M
// dimension into sub-GEMMs whose packed data fits within cache_size bytes.
// Falls through to a single GemmExecutorPackLHS run when no split is needed.
template <int cache_size = 256 * 1024>
class GemmExecutorPackLHSCacheFriendly {
 public:
  // Scratch is capped at cache_size by construction.
  template <typename P>
  static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n,
                                 int kernel_k) {
    return cache_size;
  }

  template <typename P, int m, int n, int k, int m_leftovers, int n_leftovers,
            int k_leftovers>
  static void ExecuteDispatch3D(const P& params) {
    typedef Stream<typename P::InType, m, k, k_leftovers,
                   typename P::LeftStream>
        LeftStream;

    typedef Stream<typename P::InType, n, k, k_leftovers,
                   typename P::RightStream>
        RightStream;

    const int lhs_scratch = LeftStream::Scratch(params.left_stream);
    const int rhs_scratch = RightStream::Scratch(params.right_stream);

    // The RHS chunk is the constant resident; LHS chunks are the per-task
    // memory (PackLHS packs the whole LHS), so split along M.
    const int cache_friendly_tasks_count =
        internal::CalculateCacheFriendlyTasksCount(cache_size, rhs_scratch,
                                                   lhs_scratch, params.m, m);

    if (cache_friendly_tasks_count == 1) {
      GemmExecutorPackLHS::ExecuteDispatch3D<P, m, n, k, m_leftovers,
                                             n_leftovers, k_leftovers>(params);
      return;
    }

    const int cache_friendly_dim = params.m / cache_friendly_tasks_count;

    // All tasks but the last cover cache_friendly_dim rows each; the last
    // task takes whatever remains.
    P task_params = params;
    for (int i = 0; i < cache_friendly_tasks_count - 1; ++i) {
      internal::UpdateCacheFriendlyTask(i * cache_friendly_dim,
                                        cache_friendly_dim, 0, params.n, params,
                                        &task_params);
      Gemm<GemmExecutorPackLHS, P, m, n, k>(task_params);
    }
    const int dim_sum = (cache_friendly_tasks_count - 1) * cache_friendly_dim;
    internal::UpdateCacheFriendlyTask(dim_sum, params.m - dim_sum, 0, params.n,
                                      params, &task_params);
    Gemm<GemmExecutorPackLHS, P, m, n, k>(task_params);
  }
};

namespace internal {

// Stage 3.
// Stage 3 of the runtime->compile-time dispatch: matches the runtime
// k-leftover value against the candidate constant variable_k. Each recursive
// instantiation tries one smaller candidate until a match is found.
template <typename E, typename P, int dim_m, int dim_n, int dim_k, int fixed_m,
          int fixed_n, int variable_k>
struct Dispatch3DStage3 {
  static void Execute(const P& params, int k) {
#ifdef DEBUG
#ifdef DEBUG_METAGEMM_VERBOSE
    std::cout << "Dispatch(3): " << dim_m << "x" << dim_n << "x" << dim_k
              << " : " << fixed_m << "x" << fixed_n << "x" << variable_k
              << std::endl
              << std::flush;
#endif
#endif
    if (k != variable_k) {
      // Not this candidate -- recurse with the next smaller leftover value.
      Dispatch3DStage3<E, P, dim_m, dim_n, dim_k, fixed_m, fixed_n,
                       variable_k - 1>::Execute(params, k);
      return;
    }
    // All three leftover dimensions are now compile-time constants: run the
    // executor.
    E::template ExecuteDispatch3D<P, dim_m, dim_n, dim_k, fixed_m, fixed_n,
                                  variable_k>(params);
  }
};

// Recursion terminator: the last candidate (0). Any other runtime value here
// means the dispatch tables were built inconsistently -- abort.
template <typename E, typename P, int dim_m, int dim_n, int dim_k, int fixed_m,
          int fixed_n>
struct Dispatch3DStage3<E, P, dim_m, dim_n, dim_k, fixed_m, fixed_n, 0> {
  static void Execute(const P& params, int k) {
#ifdef DEBUG
#ifdef DEBUG_METAGEMM_VERBOSE
    std::cout << "Dispatch(3): " << dim_m << "x" << dim_n << "x" << dim_k
              << " : " << fixed_m << "x" << fixed_n << "x" << 0 << std::endl
              << std::flush;
#endif
#endif
    if (k != 0) {
      std::cerr << "FATAL: dispatch3DStage3 failed: ran out of cases."
                << std::endl
                << std::flush;
      std::exit(1);
    }
    E::template ExecuteDispatch3D<P, dim_m, dim_n, dim_k, fixed_m, fixed_n, 0>(
        params);
  }
};

// Stage 2.
585 586 template <typename E, typename P, int dim_m, int dim_n, int dim_k, int fixed_m, 587 int variable_n> 588 struct Dispatch3DStage2 { 589 static void Execute(const P& params, int n, int k) { 590 #ifdef DEBUG 591 #ifdef DEBUG_METAGEMM_VERBOSE 592 std::cout << "Dispatch(2): " << dim_m << "x" << dim_n << "x" << dim_k 593 << " : " << fixed_m << "x" << variable_n << std::endl 594 << std::flush; 595 #endif 596 #endif 597 if (n == variable_n) { 598 Dispatch3DStage3<E, P, dim_m, dim_n, dim_k, fixed_m, variable_n, 599 dim_k - 1>::Execute(params, k); 600 } else { 601 Dispatch3DStage2<E, P, dim_m, dim_n, dim_k, fixed_m, 602 variable_n - 1>::Execute(params, n, k); 603 } 604 } 605 }; 606 607 template <typename E, typename P, int dim_m, int dim_n, int dim_k, int fixed_m> 608 struct Dispatch3DStage2<E, P, dim_m, dim_n, dim_k, fixed_m, 0> { 609 static void Execute(const P& params, int n, int k) { 610 #ifdef DEBUG 611 #ifdef DEBUG_METAGEMM_VERBOSE 612 std::cout << "Dispatch(2): " << dim_m << "x" << dim_n << "x" << dim_k 613 << " : " << fixed_m << "x" << 0 << std::endl 614 << std::flush; 615 #endif 616 #endif 617 if (n == 0) { 618 Dispatch3DStage3<E, P, dim_m, dim_n, dim_k, fixed_m, 0, 619 dim_k - 1>::Execute(params, k); 620 } else { 621 std::cerr << "FATAL: dispatch3DStage2 failed: ran out of cases." 622 << std::endl 623 << std::flush; 624 std::exit(1); 625 } 626 } 627 }; 628 629 // Stage 1. 
630 631 template <typename E, typename P, int dim_m, int dim_n, int dim_k, 632 int variable_m> 633 struct Dispatch3DStage1 { 634 static void Execute(const P& params, int m, int n, int k) { 635 #ifdef DEBUG 636 #ifdef DEBUG_METAGEMM_VERBOSE 637 std::cout << "Dispatch(1): " << dim_m << "x" << dim_n << "x" << dim_k 638 << " : " << variable_m << std::endl 639 << std::flush; 640 #endif 641 #endif 642 if (m == variable_m) { 643 Dispatch3DStage2<E, P, dim_m, dim_n, dim_k, variable_m, 644 dim_n - 1>::Execute(params, n, k); 645 } else { 646 Dispatch3DStage1<E, P, dim_m, dim_n, dim_k, variable_m - 1>::Execute( 647 params, m, n, k); 648 } 649 } 650 }; 651 652 template <typename E, typename P, int dim_m, int dim_n, int dim_k> 653 struct Dispatch3DStage1<E, P, dim_m, dim_n, dim_k, 0> { 654 static void Execute(const P& params, int m, int n, int k) { 655 #ifdef DEBUG 656 #ifdef DEBUG_METAGEMM_VERBOSE 657 std::cout << "Dispatch(1): " << dim_m << "x" << dim_n << "x" << dim_k 658 << " : " << 0 << std::endl 659 << std::flush; 660 #endif 661 #endif 662 if (m == 0) { 663 Dispatch3DStage2<E, P, dim_m, dim_n, dim_k, 0, dim_n - 1>::Execute(params, 664 n, k); 665 } else { 666 std::cerr << "FATAL: dispatch3DStage1 failed: ran out of cases." 667 << std::endl 668 << std::flush; 669 std::exit(1); 670 } 671 } 672 }; 673 674 } // namespace internal 675 676 template <typename Executor, typename Params, int kernel_m, int kernel_n, 677 int kernel_k> 678 inline void Gemm(const Params& params) { 679 internal::Dispatch3DStage1<Executor, Params, kernel_m, kernel_n, kernel_k, 680 kernel_m - 1>::Execute(params, params.m % kernel_m, 681 params.n % kernel_n, 682 params.k % kernel_k); 683 } 684 685 } // namespace meta 686 } // namespace gemmlowp 687 688 #endif // GEMMLOWP_META_SINGLE_THREAD_GEMM_H_ 689