Home | History | Annotate | Download | only in meta
      1 // Copyright 2016 The Gemmlowp Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 #ifndef GEMMLOWP_META_SINGLE_THREAD_GEMM_H_
     16 #define GEMMLOWP_META_SINGLE_THREAD_GEMM_H_
     17 
     18 #include <iostream>
     19 #include "base.h"
     20 
     21 namespace gemmlowp {
     22 namespace meta {
     23 
// Forward declaration of the single-threaded GEMM entry point; the definition
// is at the end of this header. It is declared here because the cache-friendly
// executors below call it on their sub-tasks.
//
//   Executor: type exposing a static ExecuteDispatch3D<...>() member template.
//   Params:   parameter bundle; the dispatch reads its m, n and k fields.
//   kernel_m, kernel_n, kernel_k: compile-time kernel dimensions.
template <typename Executor, typename Params, int kernel_m, int kernel_n,
          int kernel_k>
void Gemm(const Params& params);
     27 
// Single-threaded GEMM executor that packs the whole RHS into scratch memory
// up front and then walks the LHS one chunk at a time, multiplying each packed
// LHS chunk with every packed RHS chunk.
//
// Naming used below: the suffix F marks a type instantiated for a full
// kernel-sized chunk (m or n lanes), the suffix L one instantiated for the
// leftover chunk (m_leftovers or n_leftovers lanes). In two-letter suffixes
// the first letter refers to the LHS/M side and the second to the RHS/N side.
class GemmExecutorPackRHS {
 public:
  // Returns the scratch size to reserve before running this executor: the
  // scratch of a single LHS chunk plus the scratch of every RHS chunk
  // (params.n rounded up to whole kernel_n chunks), passed through
  // AlignTo<64 * 1024>.
  // NOTE(review): AlignTo presumably rounds up to a multiple of 64 KiB --
  // confirm against its definition in base.h.
  template <typename P>
  static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n,
                                 int kernel_k) {
    const int lhs_scratch =
        StreamUtil<typename P::InType, typename P::LeftStream>::Scratch(
            params.left_stream, kernel_m, kernel_k);
    // Ceiling division: a partial trailing chunk still needs its own slot.
    const int rhs_chunks = ((params.n + kernel_n - 1) / kernel_n);
    const int rhs_scratch =
        rhs_chunks *
        StreamUtil<typename P::InType, typename P::RightStream>::Scratch(
            params.right_stream, kernel_n, kernel_k);
    return AlignTo<64 * 1024>(lhs_scratch + rhs_scratch);
  }

  // Runs the GEMM described by |params|.
  //
  //   m, n, k:     kernel chunk dimensions.
  //   m_leftovers, n_leftovers, k_leftovers: compile-time remainders; the
  //                dispatcher in this header passes params.m % m, params.n % n
  //                and params.k % k.
  //
  // Reads params.lhs / params.rhs, uses params.scratch for the packed data and
  // writes through params.result.
  template <typename P, int m, int n, int k, int m_leftovers, int n_leftovers,
            int k_leftovers>
  static void ExecuteDispatch3D(const P& params) {
    // Shorthand typedefs for streams and multiply kernels.
    typedef typename P::InType InType;
    typedef typename P::OutType OutType;

    // LHS packers: full chunk (m lanes) and leftover chunk (m_leftovers lanes).
    typedef Stream<typename P::InType, m, k, k_leftovers,
                   typename P::LeftStream>
        LeftStreamF;
    typedef Stream<typename P::InType, m_leftovers, k, k_leftovers,
                   typename P::LeftStream>
        LeftStreamL;

    // RHS packers: full chunk (n lanes) and leftover chunk (n_leftovers lanes).
    typedef Stream<typename P::InType, n, k, k_leftovers,
                   typename P::RightStream>
        RightStreamF;
    typedef Stream<typename P::InType, n_leftovers, k, k_leftovers,
                   typename P::RightStream>
        RightStreamL;

    // Output streams; used here only to query strides/advances of the result.
    typedef Stream<typename P::OutType, m, n, 0, typename P::OutputStream>
        OutputStreamFF;
    typedef Stream<typename P::OutType, m_leftovers, n, 0,
                   typename P::OutputStream>
        OutputStreamLF;

    // Multiply kernels for the four full/leftover combinations of M and N.
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m, n, k>
        KernelFF;
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m,
                      n_leftovers, k>
        KernelFL;
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m_leftovers,
                      n, k>
        KernelLF;
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m_leftovers,
                      n_leftovers, k>
        KernelLL;

    // Verbose tracing; compiled in only when both DEBUG and
    // DEBUG_METAGEMM_VERBOSE are defined.
#ifdef DEBUG
#ifdef DEBUG_METAGEMM_VERBOSE
    std::cout << "GemmExecutor(" << typeid(P).name() << "): " << m << "x" << n
              << "x" << k << " -- " << m_leftovers << "x" << n_leftovers << "x"
              << k_leftovers << " -- " << params.m << "x" << params.n << "x"
              << params.k << std::endl;
    LeftStreamF::Debug(params.left_stream);
    LeftStreamL::Debug(params.left_stream);

    RightStreamF::Debug(params.right_stream);
    RightStreamL::Debug(params.right_stream);

    OutputStreamFF::Debug(params.fused_kernel.output_stream);
    OutputStreamLF::Debug(params.fused_kernel.output_stream);

    KernelFF::Debug(params.fused_kernel);
    KernelFL::Debug(params.fused_kernel);
    KernelLF::Debug(params.fused_kernel);
    KernelLL::Debug(params.fused_kernel);
#endif
#endif

    // Number of full chunks along each dimension; remainders are handled by
    // the *L streams and kernels.
    int lhs_chunks = params.m / m;
    int rhs_chunks = params.n / n;

    // Scratch memory for packed LHS & RHS chunks.
    // Layout: one packed LHS chunk at the start of the scratch buffer,
    // followed by all packed RHS chunks.

    std::uint8_t* packed_lhs = params.scratch;
    std::uint8_t* packed_rhs =
        params.scratch + LeftStreamF::Scratch(params.left_stream);

    // Pack full RHS first.

    std::uint8_t* packed_rhs_chunk = packed_rhs;
    const int packed_rhs_chunk_size =
        RightStreamF::PackedStride(params.right_stream);

    {
      // Byte cursor over the unpacked RHS input.
      const std::uint8_t* rhs_chunk =
          reinterpret_cast<const std::uint8_t*>(params.rhs);
      const int rhs_chunk_size =
          RightStreamF::UnpackedStride(params.right_stream);

      for (int i = 0; i < rhs_chunks; ++i) {
        RightStreamF::Pack(reinterpret_cast<const InType*>(rhs_chunk),
                           params.right_stream,
                           reinterpret_cast<InType*>(packed_rhs_chunk));

        rhs_chunk += rhs_chunk_size;
        packed_rhs_chunk += packed_rhs_chunk_size;
      }

      // Leftover RHS chunk. This call is made unconditionally, including when
      // n_leftovers == 0.
      // NOTE(review): presumably the zero-lane Stream is a no-op -- confirm
      // against the Stream definitions.
      RightStreamL::Pack(reinterpret_cast<const InType*>(rhs_chunk),
                         params.right_stream,
                         reinterpret_cast<InType*>(packed_rhs_chunk));
    }

    // Multiply RHS by LHS one LHS chunk at a time.

    // lhs_chunk walks the unpacked LHS; result_strip marks where the output of
    // the current LHS chunk starts, result_chunk the position inside it.
    const std::uint8_t* lhs_chunk =
        reinterpret_cast<const std::uint8_t*>(params.lhs);
    std::uint8_t* result_strip = reinterpret_cast<std::uint8_t*>(params.result);
    std::uint8_t* result_chunk = result_strip;

    {
      const int lhs_chunk_size =
          LeftStreamF::UnpackedStride(params.left_stream);
      // Distance in the result between consecutive LHS chunks.
      const int result_strip_size =
          OutputStreamFF::UnpackedStride(params.fused_kernel.output_stream);
      // Distance in the result between consecutive RHS chunks.
      const int result_chunk_size =
          OutputStreamFF::UnpackedAdvance(params.fused_kernel.output_stream);

      for (int i = 0; i < lhs_chunks; ++i) {
        // The single LHS scratch slot is overwritten on every iteration.
        LeftStreamF::Pack(reinterpret_cast<const InType*>(lhs_chunk),
                          params.left_stream,
                          reinterpret_cast<InType*>(packed_lhs));

        // Rewind to the first RHS chunk / start of the current result strip.
        result_chunk = result_strip;
        packed_rhs_chunk = packed_rhs;

        for (int j = 0; j < rhs_chunks; ++j) {
          KernelFF::Multiply(reinterpret_cast<const InType*>(packed_lhs),
                             reinterpret_cast<const InType*>(packed_rhs_chunk),
                             params.fused_kernel,
                             reinterpret_cast<OutType*>(result_chunk));

          result_chunk += result_chunk_size;
          packed_rhs_chunk += packed_rhs_chunk_size;
        }

        // Full LHS chunk x leftover RHS chunk; also called unconditionally.
        KernelFL::Multiply(reinterpret_cast<const InType*>(packed_lhs),
                           reinterpret_cast<const InType*>(packed_rhs_chunk),
                           params.fused_kernel,
                           reinterpret_cast<OutType*>(result_chunk));

        lhs_chunk += lhs_chunk_size;
        result_strip += result_strip_size;
      }
    }

    // Leftover LHS chunk.
    // lhs_chunk and result_strip still point just past the last full chunk.
    if (m_leftovers > 0) {  // static if
      const int result_chunk_size =
          OutputStreamLF::UnpackedAdvance(params.fused_kernel.output_stream);

      LeftStreamL::Pack(reinterpret_cast<const InType*>(lhs_chunk),
                        params.left_stream,
                        reinterpret_cast<InType*>(packed_lhs));

      result_chunk = result_strip;
      packed_rhs_chunk = packed_rhs;

      for (int i = 0; i < rhs_chunks; ++i) {
        KernelLF::Multiply(reinterpret_cast<const InType*>(packed_lhs),
                           reinterpret_cast<const InType*>(packed_rhs_chunk),
                           params.fused_kernel,
                           reinterpret_cast<OutType*>(result_chunk));

        result_chunk += result_chunk_size;
        packed_rhs_chunk += packed_rhs_chunk_size;
      }

      // Leftover LHS chunk x leftover RHS chunk.
      KernelLL::Multiply(reinterpret_cast<const InType*>(packed_lhs),
                         reinterpret_cast<const InType*>(packed_rhs_chunk),
                         params.fused_kernel,
                         reinterpret_cast<OutType*>(result_chunk));
    }
  }
};
    216 
// Single-threaded GEMM executor that packs the whole LHS into scratch memory
// up front and then walks the RHS one chunk at a time, multiplying every
// packed LHS chunk with the currently packed RHS chunk.
//
// Naming used below: the suffix F marks a type instantiated for a full
// kernel-sized chunk (m or n lanes), the suffix L one instantiated for the
// leftover chunk (m_leftovers or n_leftovers lanes). In two-letter suffixes
// the first letter refers to the LHS/M side and the second to the RHS/N side.
class GemmExecutorPackLHS {
 public:
  // Returns the scratch size to reserve before running this executor: the
  // scratch of every LHS chunk (params.m rounded up to whole kernel_m chunks)
  // plus the scratch of a single RHS chunk, passed through
  // AlignTo<64 * 1024>.
  // NOTE(review): AlignTo presumably rounds up to a multiple of 64 KiB --
  // confirm against its definition in base.h.
  template <typename P>
  static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n,
                                 int kernel_k) {
    // Ceiling division: a partial trailing chunk still needs its own slot.
    const int lhs_chunks = ((params.m + kernel_m - 1) / kernel_m);
    const int lhs_scratch =
        lhs_chunks *
        StreamUtil<typename P::InType, typename P::LeftStream>::Scratch(
            params.left_stream, kernel_m, kernel_k);
    const int rhs_scratch =
        StreamUtil<typename P::InType, typename P::RightStream>::Scratch(
            params.right_stream, kernel_n, kernel_k);
    return AlignTo<64 * 1024>(lhs_scratch + rhs_scratch);
  }

  // Runs the GEMM described by |params|.
  //
  //   m, n, k:     kernel chunk dimensions.
  //   m_leftovers, n_leftovers, k_leftovers: compile-time remainders; the
  //                dispatcher in this header passes params.m % m, params.n % n
  //                and params.k % k.
  //
  // Reads params.lhs / params.rhs, uses params.scratch for the packed data and
  // writes through params.result.
  template <typename P, int m, int n, int k, int m_leftovers, int n_leftovers,
            int k_leftovers>
  static void ExecuteDispatch3D(const P& params) {
    // Shorthand typedefs for streams and multiply kernels.
    typedef typename P::InType InType;
    typedef typename P::OutType OutType;

    // LHS packers: full chunk (m lanes) and leftover chunk (m_leftovers lanes).
    typedef Stream<typename P::InType, m, k, k_leftovers,
                   typename P::LeftStream>
        LeftStreamF;
    typedef Stream<typename P::InType, m_leftovers, k, k_leftovers,
                   typename P::LeftStream>
        LeftStreamL;

    // RHS packers: full chunk (n lanes) and leftover chunk (n_leftovers lanes).
    typedef Stream<typename P::InType, n, k, k_leftovers,
                   typename P::RightStream>
        RightStreamF;
    typedef Stream<typename P::InType, n_leftovers, k, k_leftovers,
                   typename P::RightStream>
        RightStreamL;

    // Output streams; used here only to query strides/advances of the result.
    typedef Stream<typename P::OutType, m, n, 0, typename P::OutputStream>
        OutputStreamFF;
    typedef Stream<typename P::OutType, m, n_leftovers, 0,
                   typename P::OutputStream>
        OutputStreamFL;

    // Multiply kernels for the four full/leftover combinations of M and N.
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m, n, k>
        KernelFF;
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m,
                      n_leftovers, k>
        KernelFL;
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m_leftovers,
                      n, k>
        KernelLF;
    typedef MulKernel<typename P::InType, typename P::OutType,
                      typename P::Kernel, typename P::OutputStream, m_leftovers,
                      n_leftovers, k>
        KernelLL;
    // Verbose tracing; compiled in only when both DEBUG and
    // DEBUG_METAGEMM_VERBOSE are defined.
#ifdef DEBUG
#ifdef DEBUG_METAGEMM_VERBOSE
    std::cout << "GemmExecutor(" << typeid(P).name() << "): " << m << "x" << n
              << "x" << k << " -- " << m_leftovers << "x" << n_leftovers << "x"
              << k_leftovers << " -- " << params.m << "x" << params.n << "x"
              << params.k << std::endl;
    LeftStreamF::Debug(params.left_stream);
    LeftStreamL::Debug(params.left_stream);

    RightStreamF::Debug(params.right_stream);
    RightStreamL::Debug(params.right_stream);

    OutputStreamFF::Debug(params.fused_kernel.output_stream);
    OutputStreamFL::Debug(params.fused_kernel.output_stream);

    KernelFF::Debug(params.fused_kernel);
    KernelFL::Debug(params.fused_kernel);
    KernelLF::Debug(params.fused_kernel);
    KernelLL::Debug(params.fused_kernel);
#endif
#endif

    // Number of full chunks along each dimension; remainders are handled by
    // the *L streams and kernels.
    int lhs_chunks = params.m / m;
    int rhs_chunks = params.n / n;

    // Scratch memory for packed LHS & RHS chunks.
    // Layout: one packed RHS chunk at the start of the scratch buffer,
    // followed by all packed LHS chunks.
    std::uint8_t* packed_rhs = params.scratch;
    std::uint8_t* packed_lhs =
        params.scratch + RightStreamF::Scratch(params.right_stream);

    // Pack full LHS first.

    std::uint8_t* packed_lhs_chunk = packed_lhs;
    const int packed_lhs_chunk_size =
        LeftStreamF::PackedStride(params.left_stream);

    {
      // Byte cursor over the unpacked LHS input.
      const std::uint8_t* lhs_chunk =
          reinterpret_cast<const std::uint8_t*>(params.lhs);
      const int lhs_chunk_size =
          LeftStreamF::UnpackedStride(params.left_stream);

      for (int i = 0; i < lhs_chunks; ++i) {
        LeftStreamF::Pack(reinterpret_cast<const InType*>(lhs_chunk),
                          params.left_stream,
                          reinterpret_cast<InType*>(packed_lhs_chunk));

        lhs_chunk += lhs_chunk_size;
        packed_lhs_chunk += packed_lhs_chunk_size;
      }

      // Leftover LHS chunk. This call is made unconditionally, including when
      // m_leftovers == 0.
      // NOTE(review): presumably the zero-lane Stream is a no-op -- confirm
      // against the Stream definitions.
      LeftStreamL::Pack(reinterpret_cast<const InType*>(lhs_chunk),
                        params.left_stream,
                        reinterpret_cast<InType*>(packed_lhs_chunk));
    }

    // Multiply RHS by LHS one RHS chunk at a time.

    // rhs_chunk walks the unpacked RHS; result_strip marks where the output of
    // the current RHS chunk starts, result_chunk the position inside it.
    const std::uint8_t* rhs_chunk =
        reinterpret_cast<const std::uint8_t*>(params.rhs);
    std::uint8_t* result_strip = reinterpret_cast<std::uint8_t*>(params.result);
    std::uint8_t* result_chunk = result_strip;

    {
      const int rhs_chunk_size =
          RightStreamF::UnpackedStride(params.right_stream);
      // Distance in the result between consecutive RHS chunks.
      const int result_strip_size =
          OutputStreamFF::UnpackedAdvance(params.fused_kernel.output_stream);
      // Distance in the result between consecutive LHS chunks.
      const int result_chunk_size =
          OutputStreamFF::UnpackedStride(params.fused_kernel.output_stream);

      for (int i = 0; i < rhs_chunks; ++i) {
        // The single RHS scratch slot is overwritten on every iteration.
        RightStreamF::Pack(reinterpret_cast<const InType*>(rhs_chunk),
                           params.right_stream,
                           reinterpret_cast<InType*>(packed_rhs));

        // Rewind to the first LHS chunk / start of the current result strip.
        result_chunk = result_strip;
        packed_lhs_chunk = packed_lhs;

        for (int j = 0; j < lhs_chunks; ++j) {
          KernelFF::Multiply(reinterpret_cast<const InType*>(packed_lhs_chunk),
                             reinterpret_cast<const InType*>(packed_rhs),
                             params.fused_kernel,
                             reinterpret_cast<OutType*>(result_chunk));

          result_chunk += result_chunk_size;
          packed_lhs_chunk += packed_lhs_chunk_size;
        }

        // Leftover LHS chunk x full RHS chunk; also called unconditionally.
        KernelLF::Multiply(reinterpret_cast<const InType*>(packed_lhs_chunk),
                           reinterpret_cast<const InType*>(packed_rhs),
                           params.fused_kernel,
                           reinterpret_cast<OutType*>(result_chunk));

        rhs_chunk += rhs_chunk_size;
        result_strip += result_strip_size;
      }
    }

    // Leftover RHS chunk.
    // rhs_chunk and result_strip still point just past the last full chunk.
    if (n_leftovers > 0) {  // static if
      const int result_chunk_size =
          OutputStreamFL::UnpackedStride(params.fused_kernel.output_stream);

      RightStreamL::Pack(reinterpret_cast<const InType*>(rhs_chunk),
                         params.right_stream,
                         reinterpret_cast<InType*>(packed_rhs));

      result_chunk = result_strip;
      packed_lhs_chunk = packed_lhs;

      for (int i = 0; i < lhs_chunks; ++i) {
        KernelFL::Multiply(reinterpret_cast<const InType*>(packed_lhs_chunk),
                           reinterpret_cast<const InType*>(packed_rhs),
                           params.fused_kernel,
                           reinterpret_cast<OutType*>(result_chunk));

        result_chunk += result_chunk_size;
        packed_lhs_chunk += packed_lhs_chunk_size;
      }

      // Leftover LHS chunk x leftover RHS chunk.
      KernelLL::Multiply(reinterpret_cast<const InType*>(packed_lhs_chunk),
                         reinterpret_cast<const InType*>(packed_rhs),
                         params.fused_kernel,
                         reinterpret_cast<OutType*>(result_chunk));
    }
  }
};
    403 
    404 namespace internal {
    405 
    406 inline int CalculateCacheFriendlyTasksCount(int cache_size, int constant_memory,
    407                                             int per_chunk_memory, int total_dim,
    408                                             int chunk_dim) {
    409   assert(constant_memory + per_chunk_memory < cache_size);
    410   const int available_cache = cache_size - constant_memory;
    411   const int available_chunks = available_cache / per_chunk_memory;
    412   const int chunks_count = (total_dim + chunk_dim - 1) / chunk_dim;
    413   return (chunks_count + available_chunks - 1) / available_chunks;
    414 }
    415 
    416 template <typename Params>
    417 inline void UpdateCacheFriendlyTask(int m_offset, int m, int n_offset, int n,
    418                                     const Params& params, Params* task_params) {
    419   task_params->m = m;
    420   task_params->lhs =
    421       StreamUtil<typename Params::InType, typename Params::LeftStream>::Offset(
    422           params.left_stream, params.lhs, m_offset, 0);
    423 
    424   task_params->n = n;
    425   task_params->rhs =
    426       StreamUtil<typename Params::InType, typename Params::RightStream>::Offset(
    427           params.right_stream, params.rhs, n_offset, 0);
    428 
    429   task_params->result =
    430       StreamUtil<typename Params::OutType, typename Params::OutputStream>::
    431           Offset(params.fused_kernel.output_stream, params.result, m_offset,
    432                  n_offset);
    433 }
    434 
    435 }  // namespace internal
    436 
    437 template <int cache_size = 256 * 1024>
    438 class GemmExecutorPackRHSCacheFriendly {
    439  public:
    440   template <typename P>
    441   static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n,
    442                                  int kernel_k) {
    443     return cache_size;
    444   }
    445 
    446   template <typename P, int m, int n, int k, int m_leftovers, int n_leftovers,
    447             int k_leftovers>
    448   static void ExecuteDispatch3D(const P& params) {
    449     typedef Stream<typename P::InType, m, k, k_leftovers,
    450                    typename P::LeftStream>
    451         LeftStream;
    452 
    453     typedef Stream<typename P::InType, n, k, k_leftovers,
    454                    typename P::RightStream>
    455         RightStream;
    456 
    457     const int lhs_scratch = LeftStream::Scratch(params.left_stream);
    458     const int rhs_scratch = RightStream::Scratch(params.right_stream);
    459 
    460     const int cache_friendly_tasks_count =
    461         internal::CalculateCacheFriendlyTasksCount(cache_size, lhs_scratch,
    462                                                    rhs_scratch, params.n, n);
    463 
    464     if (cache_friendly_tasks_count == 1) {
    465       GemmExecutorPackRHS::ExecuteDispatch3D<P, m, n, k, m_leftovers,
    466                                              n_leftovers, k_leftovers>(params);
    467       return;
    468     }
    469 
    470     const int cache_friendly_dim = params.n / cache_friendly_tasks_count;
    471 
    472     P task_params = params;
    473     for (int i = 0; i < cache_friendly_tasks_count - 1; ++i) {
    474       internal::UpdateCacheFriendlyTask(0, params.m, i * cache_friendly_dim,
    475                                         cache_friendly_dim, params,
    476                                         &task_params);
    477       Gemm<GemmExecutorPackRHS, P, m, n, k>(task_params);
    478     }
    479     const int dim_sum = (cache_friendly_tasks_count - 1) * cache_friendly_dim;
    480     internal::UpdateCacheFriendlyTask(0, params.m, dim_sum, params.n - dim_sum,
    481                                       params, &task_params);
    482     Gemm<GemmExecutorPackRHS, P, m, n, k>(task_params);
    483   }
    484 };
    485 
    486 template <int cache_size = 256 * 1024>
    487 class GemmExecutorPackLHSCacheFriendly {
    488  public:
    489   template <typename P>
    490   static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n,
    491                                  int kernel_k) {
    492     return cache_size;
    493   }
    494 
    495   template <typename P, int m, int n, int k, int m_leftovers, int n_leftovers,
    496             int k_leftovers>
    497   static void ExecuteDispatch3D(const P& params) {
    498     typedef Stream<typename P::InType, m, k, k_leftovers,
    499                    typename P::LeftStream>
    500         LeftStream;
    501 
    502     typedef Stream<typename P::InType, n, k, k_leftovers,
    503                    typename P::RightStream>
    504         RightStream;
    505 
    506     const int lhs_scratch = LeftStream::Scratch(params.left_stream);
    507     const int rhs_scratch = RightStream::Scratch(params.right_stream);
    508 
    509     const int cache_friendly_tasks_count =
    510         internal::CalculateCacheFriendlyTasksCount(cache_size, rhs_scratch,
    511                                                    lhs_scratch, params.m, m);
    512 
    513     if (cache_friendly_tasks_count == 1) {
    514       GemmExecutorPackLHS::ExecuteDispatch3D<P, m, n, k, m_leftovers,
    515                                              n_leftovers, k_leftovers>(params);
    516       return;
    517     }
    518 
    519     const int cache_friendly_dim = params.m / cache_friendly_tasks_count;
    520 
    521     P task_params = params;
    522     for (int i = 0; i < cache_friendly_tasks_count - 1; ++i) {
    523       internal::UpdateCacheFriendlyTask(i * cache_friendly_dim,
    524                                         cache_friendly_dim, 0, params.n, params,
    525                                         &task_params);
    526       Gemm<GemmExecutorPackLHS, P, m, n, k>(task_params);
    527     }
    528     const int dim_sum = (cache_friendly_tasks_count - 1) * cache_friendly_dim;
    529     internal::UpdateCacheFriendlyTask(dim_sum, params.m - dim_sum, 0, params.n,
    530                                       params, &task_params);
    531     Gemm<GemmExecutorPackLHS, P, m, n, k>(task_params);
    532   }
    533 };
    534 
    535 namespace internal {
    536 
    537 // Stage 3.
    538 
    539 template <typename E, typename P, int dim_m, int dim_n, int dim_k, int fixed_m,
    540           int fixed_n, int variable_k>
    541 struct Dispatch3DStage3 {
    542   static void Execute(const P& params, int k) {
    543 #ifdef DEBUG
    544 #ifdef DEBUG_METAGEMM_VERBOSE
    545     std::cout << "Dispatch(3): " << dim_m << "x" << dim_n << "x" << dim_k
    546               << " : " << fixed_m << "x" << fixed_n << "x" << variable_k
    547               << std::endl
    548               << std::flush;
    549 #endif
    550 #endif
    551     if (k == variable_k) {
    552       E::template ExecuteDispatch3D<P, dim_m, dim_n, dim_k, fixed_m, fixed_n,
    553                                     variable_k>(params);
    554     } else {
    555       Dispatch3DStage3<E, P, dim_m, dim_n, dim_k, fixed_m, fixed_n,
    556                        variable_k - 1>::Execute(params, k);
    557     }
    558   }
    559 };
    560 
    561 template <typename E, typename P, int dim_m, int dim_n, int dim_k, int fixed_m,
    562           int fixed_n>
    563 struct Dispatch3DStage3<E, P, dim_m, dim_n, dim_k, fixed_m, fixed_n, 0> {
    564   static void Execute(const P& params, int k) {
    565 #ifdef DEBUG
    566 #ifdef DEBUG_METAGEMM_VERBOSE
    567     std::cout << "Dispatch(3): " << dim_m << "x" << dim_n << "x" << dim_k
    568               << " : " << fixed_m << "x" << fixed_n << "x" << 0 << std::endl
    569               << std::flush;
    570 #endif
    571 #endif
    572     if (k == 0) {
    573       E::template ExecuteDispatch3D<P, dim_m, dim_n, dim_k, fixed_m, fixed_n,
    574                                     0>(params);
    575     } else {
    576       std::cerr << "FATAL: dispatch3DStage3 failed: ran out of cases."
    577                 << std::endl
    578                 << std::flush;
    579       std::exit(1);
    580     }
    581   }
    582 };
    583 
    584 // Stage 2.
    585 
    586 template <typename E, typename P, int dim_m, int dim_n, int dim_k, int fixed_m,
    587           int variable_n>
    588 struct Dispatch3DStage2 {
    589   static void Execute(const P& params, int n, int k) {
    590 #ifdef DEBUG
    591 #ifdef DEBUG_METAGEMM_VERBOSE
    592     std::cout << "Dispatch(2): " << dim_m << "x" << dim_n << "x" << dim_k
    593               << " : " << fixed_m << "x" << variable_n << std::endl
    594               << std::flush;
    595 #endif
    596 #endif
    597     if (n == variable_n) {
    598       Dispatch3DStage3<E, P, dim_m, dim_n, dim_k, fixed_m, variable_n,
    599                        dim_k - 1>::Execute(params, k);
    600     } else {
    601       Dispatch3DStage2<E, P, dim_m, dim_n, dim_k, fixed_m,
    602                        variable_n - 1>::Execute(params, n, k);
    603     }
    604   }
    605 };
    606 
    607 template <typename E, typename P, int dim_m, int dim_n, int dim_k, int fixed_m>
    608 struct Dispatch3DStage2<E, P, dim_m, dim_n, dim_k, fixed_m, 0> {
    609   static void Execute(const P& params, int n, int k) {
    610 #ifdef DEBUG
    611 #ifdef DEBUG_METAGEMM_VERBOSE
    612     std::cout << "Dispatch(2): " << dim_m << "x" << dim_n << "x" << dim_k
    613               << " : " << fixed_m << "x" << 0 << std::endl
    614               << std::flush;
    615 #endif
    616 #endif
    617     if (n == 0) {
    618       Dispatch3DStage3<E, P, dim_m, dim_n, dim_k, fixed_m, 0,
    619                        dim_k - 1>::Execute(params, k);
    620     } else {
    621       std::cerr << "FATAL: dispatch3DStage2 failed: ran out of cases."
    622                 << std::endl
    623                 << std::flush;
    624       std::exit(1);
    625     }
    626   }
    627 };
    628 
    629 // Stage 1.
    630 
    631 template <typename E, typename P, int dim_m, int dim_n, int dim_k,
    632           int variable_m>
    633 struct Dispatch3DStage1 {
    634   static void Execute(const P& params, int m, int n, int k) {
    635 #ifdef DEBUG
    636 #ifdef DEBUG_METAGEMM_VERBOSE
    637     std::cout << "Dispatch(1): " << dim_m << "x" << dim_n << "x" << dim_k
    638               << " : " << variable_m << std::endl
    639               << std::flush;
    640 #endif
    641 #endif
    642     if (m == variable_m) {
    643       Dispatch3DStage2<E, P, dim_m, dim_n, dim_k, variable_m,
    644                        dim_n - 1>::Execute(params, n, k);
    645     } else {
    646       Dispatch3DStage1<E, P, dim_m, dim_n, dim_k, variable_m - 1>::Execute(
    647           params, m, n, k);
    648     }
    649   }
    650 };
    651 
    652 template <typename E, typename P, int dim_m, int dim_n, int dim_k>
    653 struct Dispatch3DStage1<E, P, dim_m, dim_n, dim_k, 0> {
    654   static void Execute(const P& params, int m, int n, int k) {
    655 #ifdef DEBUG
    656 #ifdef DEBUG_METAGEMM_VERBOSE
    657     std::cout << "Dispatch(1): " << dim_m << "x" << dim_n << "x" << dim_k
    658               << " : " << 0 << std::endl
    659               << std::flush;
    660 #endif
    661 #endif
    662     if (m == 0) {
    663       Dispatch3DStage2<E, P, dim_m, dim_n, dim_k, 0, dim_n - 1>::Execute(params,
    664                                                                          n, k);
    665     } else {
    666       std::cerr << "FATAL: dispatch3DStage1 failed: ran out of cases."
    667                 << std::endl
    668                 << std::flush;
    669       std::exit(1);
    670     }
    671   }
    672 };
    673 
    674 }  // namespace internal
    675 
    676 template <typename Executor, typename Params, int kernel_m, int kernel_n,
    677           int kernel_k>
    678 inline void Gemm(const Params& params) {
    679   internal::Dispatch3DStage1<Executor, Params, kernel_m, kernel_n, kernel_k,
    680                              kernel_m - 1>::Execute(params, params.m % kernel_m,
    681                                                     params.n % kernel_n,
    682                                                     params.k % kernel_k);
    683 }
    684 
    685 }  // namespace meta
    686 }  // namespace gemmlowp
    687 
    688 #endif  // GEMMLOWP_META_SINGLE_THREAD_GEMM_H_
    689