      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog (at) gmail.com>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H
     11 #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H
     12 
     13 // evaluator for thread pool device
     14 #ifdef EIGEN_USE_THREADS
     15 
     16 namespace Eigen {
     17 
     18 #ifdef EIGEN_USE_SIMPLE_THREAD_POOL
     19 namespace internal {
     20 
     21 template<typename LhsScalar, typename LhsMapper, typename Index>
     22 struct packLhsArg {
     23   LhsScalar* blockA;
     24   const LhsMapper& lhs;
     25   const Index m_start;
     26   const Index k_start;
     27   const Index mc;
     28   const Index kc;
     29 };
     30 
     31 template<typename LhsScalar, typename RhsScalar, typename RhsMapper, typename OutputMapper, typename Index>
     32 struct packRhsAndKernelArg {
     33   const MaxSizeVector<LhsScalar*>* blockAs;
     34   RhsScalar* blockB;
     35   const RhsMapper& rhs;
     36   OutputMapper& output;
     37   const Index m;
     38   const Index k;
     39   const Index n;
     40   const Index mc;
     41   const Index kc;
     42   const Index nc;
     43   const Index num_threads;
     44   const Index num_blockAs;
     45   const Index max_m;
     46   const Index k_block_idx;
     47   const Index m_block_idx;
     48   const Index n_block_idx;
     49   const Index m_blocks;
     50   const Index n_blocks;
     51   MaxSizeVector<Notification*>* kernel_notifications;
     52   const MaxSizeVector<Notification*>* lhs_notifications;
     53   const bool need_to_pack;
     54 };
     55 
     56 }  // end namespace internal
     57 #endif  // EIGEN_USE_SIMPLE_THREAD_POOL
     58 
     59 template<typename Indices, typename LeftArgType, typename RightArgType>
     60 struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, ThreadPoolDevice> :
     61     public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, ThreadPoolDevice> > {
     62 
     63   typedef ThreadPoolDevice Device;
     64 
     65   typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self;
     66   typedef TensorContractionEvaluatorBase<Self> Base;
     67 
     68   typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
     69   typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
     70   typedef typename XprType::Index Index;
     71   typedef typename XprType::CoeffReturnType CoeffReturnType;
     72   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
     73 
     74   enum {
     75     Layout = TensorEvaluator<LeftArgType, Device>::Layout,
     76   };
     77 
      78   // Most of the code assumes that both input tensors are ColMajor. If the
     79   // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
     80   // If we want to compute A * B = C, where A is LHS and B is RHS, the code
     81   // will pretend B is LHS and A is RHS.
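           // Illustrative example: for RowMajor data, C = A * B is equivalent to the
           // ColMajor computation C^T = B^T * A^T, which is why the roles of A and B
           // are exchanged below.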
     82   typedef typename internal::conditional<
     83     static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
     84   typedef typename internal::conditional<
     85     static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;
     86 
     87   static const int LDims =
     88       internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
     89   static const int RDims =
     90       internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
     91   static const int ContractDims = internal::array_size<Indices>::value;
     92 
     93   typedef array<Index, LDims> left_dim_mapper_t;
     94   typedef array<Index, RDims> right_dim_mapper_t;
     95 
     96   typedef array<Index, ContractDims> contract_t;
     97   typedef array<Index, LDims - ContractDims> left_nocontract_t;
     98   typedef array<Index, RDims - ContractDims> right_nocontract_t;
     99 
    100   static const int NumDims = LDims + RDims - 2 * ContractDims;
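           // For example, contracting a rank-3 lhs with a rank-2 rhs along a single
           // index pair gives NumDims = 3 + 2 - 2 * 1 = 3 output dimensions.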
    101 
    102   typedef DSizes<Index, NumDims> Dimensions;
    103 
    104   // typedefs needed in evalTo
    105   typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
    106   typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
    107   typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;
    108 
    109   typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
    110   typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
    111 
    112   TensorEvaluator(const XprType& op, const Device& device) :
    113       Base(op, device) {}
    114 
    115 #ifndef EIGEN_USE_SIMPLE_THREAD_POOL
    116   template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
    117             bool rhs_inner_dim_reordered, int Alignment>
    118   void evalProduct(Scalar* buffer) const {
    119     typedef
    120         typename internal::remove_const<typename EvalLeftArgType::Scalar>::type
    121             LhsScalar;
    122     typedef
    123         typename internal::remove_const<typename EvalRightArgType::Scalar>::type
    124             RhsScalar;
    125     typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;
    126     typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
    127     typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
    128     typedef internal::TensorContractionInputMapper<
    129         LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t,
    130         contract_t, internal::packet_traits<LhsScalar>::size,
    131         lhs_inner_dim_contiguous, false, Unaligned>
    132         LhsMapper;
    133     typedef internal::TensorContractionInputMapper<
    134         RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t,
    135         contract_t, internal::packet_traits<RhsScalar>::size,
    136         rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Unaligned>
    137         RhsMapper;
    138     typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
    139     typedef internal::gemm_pack_lhs<LhsScalar, Index,
    140                                     typename LhsMapper::SubMapper, Traits::mr,
    141                                     Traits::LhsProgress, ColMajor>
    142         LhsPacker;
    143     typedef internal::gemm_pack_rhs<
    144         RhsScalar, Index, typename RhsMapper::SubMapper, Traits::nr, ColMajor>
    145         RhsPacker;
    146     typedef internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper,
    147                                   Traits::mr, Traits::nr, false, false>
    148         GebpKernel;
    149 
    150     const Index m = this->m_i_size;
    151     const Index n = this->m_j_size;
    152     const Index k = this->m_k_size;
    153     if (m == 0 || n == 0 || k == 0) return;
    154 
    155     // Compute a set of algorithm parameters:
    156     // - kernel block sizes (bm, bn, bk)
    157     // - task grain sizes (number of kernels executed per task: gm, gn)
    158     // - number of threads
    159     // - sharding by row/column
    160     // - parallel packing or first lhs then rhs
    161     // and some derived parameters:
    162     // - number of tasks (nm, nn, nk)
    163     // - number of kernels (nm0, nn0)
    164     // Unfortunately, all these parameters are tightly interdependent.
    165     // So in some cases we first compute approximate values, then compute other
    166     // values based on these approximations and then refine the approximations.
    167 
    168     // There are lots of heuristics here. There is some reasoning behind them,
    169     // but ultimately they are just tuned on contraction benchmarks for
    170     // different input configurations, thread counts and instruction sets.
    171     // So feel free to question any of them.
    172 
    173     // Compute whether we want to shard by row or by column.
     174   // This is a first approximation; it will be refined later. Since we don't
     175   // know the number of threads yet, we use 2, because what we are most
    176     // interested in at this point is whether it makes sense to use
    177     // parallelization at all or not.
    178     bool shard_by_col = shardByCol(m, n, 2);
    179 
    180     // First approximation of kernel blocking sizes.
     181   // Again, we don't know the number of threads yet, so we use 2.
    182     Index bm, bn, bk;
    183     if (shard_by_col) {
    184       internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index,
    185                                           internal::ShardByCol>
    186           blocking(k, m, n, 2);
    187       bm = blocking.mc();
    188       bn = blocking.nc();
    189       bk = blocking.kc();
    190     } else {
    191       internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index,
    192                                           internal::ShardByRow>
    193           blocking(k, m, n, 2);
    194       bm = blocking.mc();
    195       bn = blocking.nc();
    196       bk = blocking.kc();
    197     }
    198 
    199     // Compute optimal number of threads.
     200     // Note: we use bk instead of k here because we are interested in the amount
     201     // of _parallelizable_ computation, and computations are not parallelizable
     202     // across the k dimension.
    203     const TensorOpCost cost =
    204         contractionCost(m, n, bm, bn, bk, shard_by_col, false);
    205     int num_threads = TensorCostModel<ThreadPoolDevice>::numThreads(
    206         static_cast<double>(n) * m, cost, this->m_device.numThreads());
    207 
    208     // TODO(dvyukov): this is a stop-gap to prevent regressions while the cost
    209     // model is not tuned. Remove this when the cost model is tuned.
    210     if (n == 1) num_threads = 1;
    211 
    212     if (num_threads == 1) {
    213       // The single-threaded algorithm should be faster in this case.
    214       if (n == 1)
    215         this->template evalGemv<lhs_inner_dim_contiguous,
    216                                 rhs_inner_dim_contiguous,
    217                                 rhs_inner_dim_reordered, Alignment>(buffer);
    218       else
    219         this->template evalGemm<lhs_inner_dim_contiguous,
    220                                 rhs_inner_dim_contiguous,
    221                                 rhs_inner_dim_reordered, Alignment>(buffer);
    222       return;
    223     }
    224 
     225     // Now that we know the number of threads, recalculate sharding and blocking.
    226     shard_by_col = shardByCol(m, n, num_threads);
    227     if (shard_by_col) {
    228       internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index,
    229                                           internal::ShardByCol>
    230           blocking(k, m, n, num_threads);
    231       bm = blocking.mc();
    232       bn = blocking.nc();
    233       bk = blocking.kc();
    234     } else {
    235       internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index,
    236                                           internal::ShardByRow>
    237           blocking(k, m, n, num_threads);
    238       bm = blocking.mc();
    239       bn = blocking.nc();
    240       bk = blocking.kc();
    241     }
    242 
    243     // Number of kernels for each dimension.
    244     Index nm0 = divup(m, bm);
    245     Index nn0 = divup(n, bn);
    246     Index nk = divup(k, bk);
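             // Worked example (hypothetical sizes): m = 1000, n = 500, k = 300 with
             // bm = 64, bn = 32, bk = 128 gives nm0 = 16, nn0 = 16 and nk = 3.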
    247 
    248     // Calculate task grain size (number of kernels executed per task).
    249     // This task size coarsening serves two purposes:
    250     // 1. It reduces per-task overheads including synchronization overheads.
     251     // 2. It allows better cache use (reusing the same packed rhs in several
     252     // consecutive kernels).
    253     Index gm = 1;
    254     Index gn = 1;
    255     // If we are sharding by column, then we prefer to reduce rows first.
    256     if (shard_by_col) {
    257       gm = coarsenM(m, n, bm, bn, bk, gn, num_threads, shard_by_col);
    258       gn = coarsenN(m, n, bm, bn, bk, gm, num_threads, shard_by_col);
    259     } else {
    260       gn = coarsenN(m, n, bm, bn, bk, gm, num_threads, shard_by_col);
    261       gm = coarsenM(m, n, bm, bn, bk, gn, num_threads, shard_by_col);
    262     }
    263     // Number of tasks in each dimension.
    264     Index nm = divup(nm0, gm);
    265     Index nn = divup(nn0, gn);
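             // Continuing the hypothetical example above: with nm0 = 16 and a grain
             // size gm = 4, we get nm = divup(16, 4) = 4 tasks along the m dimension.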
    266 
     267     // Last but not least, decide whether we want to issue both lhs and rhs
    268     // packing in parallel; or issue lhs packing first, and then issue rhs
    269     // packing when lhs packing completes (for !shard_by_col lhs and rhs are
    270     // swapped). Parallel packing allows more parallelism (for both packing and
    271     // kernels), while sequential packing provides better locality (once
     272     // a thread finishes rhs packing it proceeds to kernels with that rhs).
    273     // First, we are interested in parallel packing if there are few tasks.
    274     bool parallel_pack = num_threads >= nm * nn;
    275     // Also do parallel packing if all data fits into L2$.
    276     if (m * bk * Index(sizeof(LhsScalar)) + n * bk * Index(sizeof(RhsScalar)) <=
    277         l2CacheSize() * num_threads)
    278       parallel_pack = true;
    279     // But don't do it if we will use each rhs only once. Locality seems to be
    280     // more important in this case.
    281     if ((shard_by_col ? nm : nn) == 1) parallel_pack = false;
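             // Worked example (hypothetical sizes): with float scalars, m = n = 1024,
             // bk = 256 and 4 threads, the packed data is (1024 + 1024) * 256 * 4 bytes
             // = 2 MiB, so parallel packing is enabled whenever 4 * l2CacheSize() is at
             // least 2 MiB.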
    282 
    283     LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides,
    284                   this->m_i_strides, this->m_left_contracting_strides,
    285                   this->m_k_strides);
    286 
    287     RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides,
    288                   this->m_j_strides, this->m_right_contracting_strides,
    289                   this->m_k_strides);
    290 
    291     Context<LhsPacker, RhsPacker, GebpKernel, LhsMapper, RhsMapper,
    292             OutputMapper>(this->m_device, num_threads, lhs, rhs, buffer, m, n,
    293                           k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0,
    294                           shard_by_col, parallel_pack)
    295         .run();
    296   }
    297 
    298   // Context coordinates a single parallel gemm operation.
    299   template <typename LhsPacker, typename RhsPacker, typename GebpKernel,
    300             typename LhsMapper, typename RhsMapper, typename OutputMapper>
    301   class Context {
    302    public:
    303     Context(const Device& device, int num_threads, LhsMapper& lhs,
    304             RhsMapper& rhs, Scalar* buffer, Index tm, Index tn, Index tk, Index bm,
    305             Index bn, Index bk, Index nm, Index nn, Index nk, Index gm,
    306             Index gn, Index nm0, Index nn0, bool shard_by_col,
    307             bool parallel_pack)
    308         : device_(device),
    309           lhs_(lhs),
    310           rhs_(rhs),
    311           buffer_(buffer),
    312           output_(buffer, tm),
    313           num_threads_(num_threads),
    314           shard_by_col_(shard_by_col),
    315           parallel_pack_(parallel_pack),
    316           m_(tm),
    317           n_(tn),
    318           k_(tk),
    319           bm_(bm),
    320           bn_(bn),
    321           bk_(bk),
    322           nm_(nm),
    323           nn_(nn),
    324           nk_(nk),
    325           gm_(gm),
    326           gn_(gn),
    327           nm0_(nm0),
    328           nn0_(nn0)
    329   {
    330       for (Index x = 0; x < P; x++) {
    331         // Normal number of notifications for k slice switch is
     332         // nm_ + nn_ + nm_ * nn_. However, the first P - 1 slices will receive only
    333         // nm_ + nn_ notifications, because they will not receive notifications
     334         // from preceding kernels.
    335         state_switch_[x] =
    336             x == 0
    337                 ? 1
    338                 : (parallel_pack_ ? nn_ + nm_ : (shard_by_col_ ? nn_ : nm_)) +
    339                       (x == P - 1 ? nm_ * nn_ : 0);
    340         state_packing_ready_[x] =
    341             parallel_pack_ ? 0 : (shard_by_col_ ? nm_ : nn_);
    342         state_kernel_[x] = new std::atomic<uint8_t>*[nm_];
    343         for (Index m = 0; m < nm_; m++) {
    344           state_kernel_[x][m] = new std::atomic<uint8_t>[nn_];
    345           // Kernels generally receive 3 notifications (previous kernel + 2
    346           // packing), but the first slice won't get notifications from previous
    347           // kernels.
    348           for (Index n = 0; n < nn_; n++)
    349             state_kernel_[x][m][n].store(
    350                 (x == 0 ? 0 : 1) + (parallel_pack_ ? 2 : 1),
    351                 std::memory_order_relaxed);
    352         }
    353       }
    354 
    355       // Allocate memory for packed rhs/lhs matrices.
    356       size_t align = numext::maxi(EIGEN_MAX_ALIGN_BYTES, 1);
    357       size_t lhs_size =
    358           divup<size_t>(bm_ * bk_ * sizeof(LhsScalar), align) * align;
    359       size_t rhs_size =
    360           divup<size_t>(bn_ * bk_ * sizeof(RhsScalar), align) * align;
    361       packed_mem_ = static_cast<char*>(internal::aligned_malloc(
    362           (nm0_ * lhs_size + nn0_ * rhs_size) * std::min<size_t>(nk_, P - 1)));
    363       char* mem = static_cast<char*>(packed_mem_);
    364       for (Index x = 0; x < numext::mini<Index>(nk_, P - 1); x++) {
    365         packed_lhs_[x].resize(nm0_);
    366         for (Index m = 0; m < nm0_; m++) {
    367           packed_lhs_[x][m] = reinterpret_cast<LhsScalar*>(mem);
    368           mem += lhs_size;
    369         }
    370         packed_rhs_[x].resize(nn0_);
    371         for (Index n = 0; n < nn0_; n++) {
    372           packed_rhs_[x][n] = reinterpret_cast<RhsScalar*>(mem);
    373           mem += rhs_size;
    374         }
    375       }
    376     }
    377 
    378     ~Context() {
    379       for (Index x = 0; x < P; x++) {
    380         for (Index m = 0; m < nm_; m++) delete[] state_kernel_[x][m];
    381         delete[] state_kernel_[x];
    382       }
    383       internal::aligned_free(packed_mem_);
    384     }
    385 
    386     void run() {
    387       // Kick off packing of the first slice.
    388       signal_switch(0, 1);
    389       // Wait for overall completion.
    390       // TODO(dvyukov): this wait can lead to deadlock.
    391       // If nthreads contractions are concurrently submitted from worker
    392       // threads, this wait will block all worker threads and the system will
    393       // deadlock.
    394       done_.Wait();
    395     }
    396 
    397    private:
    398     Notification done_;
    399     const Device& device_;
    400     LhsMapper& lhs_;
    401     RhsMapper& rhs_;
    402     Scalar* const buffer_;
    403     OutputMapper output_;
    404     const int num_threads_;
    405     const bool shard_by_col_;
    406     const bool parallel_pack_;
    407     // Matrix sizes.
    408     const Index m_;
    409     const Index n_;
    410     const Index k_;
    411     // Block sizes.
    412     const Index bm_;
    413     const Index bn_;
    414     const Index bk_;
    415     // Number of tasks.
    416     const Index nm_;
    417     const Index nn_;
    418     const Index nk_;
    419     // Task grain sizes (number of kernels executed per task).
    420     const Index gm_;
    421     const Index gn_;
     422     // Number of blocks (this is different from nm_/nn_ because of task size
    423     // coarsening).
    424     const Index nm0_;
    425     const Index nn0_;
    426 
    427     // Parallelization strategy.
    428     //
    429     // Blocks related to the same k block can run in parallel because they write
     430     // to different output blocks. So we parallelize within k slices, which
     431     // gives us a parallelism level of m x n. Before we can start any kernels
     432     // related to the k-th slice, we need to issue m lhs packing tasks and n rhs
    433     // packing tasks.
    434     //
     435     // However, there is a bottleneck when we are finishing kernels for the k-th
     436     // slice (at the very end there is only 1 runnable kernel). To mitigate this
     437     // bottleneck we allow kernels from the k-th and (k+1)-th slices to run in
    438     // parallel. Note that (m, n, k) and (m, n, k+1) kernels write to the same
    439     // output block, so they must not run in parallel.
    440     //
    441     // This gives us the following dependency graph.
     442     // On each k slice we have m x n kernel tasks, m lhs packing tasks and n rhs
    443     // packing tasks.
    444     // Kernel (m, n, k) can start when:
    445     //  - kernel (m, n, k-1) has finished
    446     //  - lhs packing (m, k) has finished
    447     //  - rhs packing (n, k) has finished
    448     // Lhs/rhs packing can start when:
     449     //  - all packing for slice k-1 has finished (artificially imposed to limit
     450     //  the amount of parallel packing)
    451     //
    452     // On top of that we limit runnable tasks to two consecutive k slices.
     453     // This is done to limit the amount of memory we need for packed lhs/rhs
    454     // (for each k slice we need m*bk + n*bk memory in packed_lhs_/packed_rhs_).
    455     //
    456     // state_switch_ tracks when we are ready to switch to the next k slice.
    457     // state_kernel_[m][n] tracks when we are ready to kick off kernel (m, n).
     458     // These variables roll over 3 consecutive k slices: the first two are being
     459     // actively executed, plus one to track completion of kernels in the second
    460     // slice.
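             // Illustrative example: with nm_ = 2, nn_ = 3 and parallel packing, a
             // steady-state k slice expects nm_ + nn_ + nm_ * nn_ = 2 + 3 + 6 = 11
             // switch notifications before moving on to the next slice.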
    461     static const Index P = 3;
    462     void* packed_mem_;
    463     std::vector<LhsScalar*> packed_lhs_[P - 1];
    464     std::vector<RhsScalar*> packed_rhs_[P - 1];
    465     std::atomic<uint8_t>** state_kernel_[P];
    466     // state_switch_ is frequently modified by worker threads, while other
     467     // fields are read-only after the constructor. Let's move it to a separate cache
    468     // line to reduce cache-coherency traffic.
    469     char pad_[128];
    470     std::atomic<Index> state_packing_ready_[P];
    471     std::atomic<Index> state_switch_[P];
    472 
    473     void pack_lhs(Index m, Index k) {
    474       const Index mend = m * gm_ + gm(m);
    475       for (Index m1 = m * gm_; m1 < mend; m1++)
    476         LhsPacker()(packed_lhs_[k % (P - 1)][m1],
    477                     lhs_.getSubMapper(m1 * bm_, k * bk_), bk(k), bm(m1));
    478 
    479       if (!parallel_pack_ && shard_by_col_) {
    480         signal_packing(k);
    481       } else {
    482         signal_switch(k + 1);
    483         for (Index n = nn_ - 1; n >= 0; n--) signal_kernel(m, n, k, n == 0);
    484       }
    485     }
    486 
    487     void pack_rhs(Index n, Index k) {
    488       const Index nend = n * gn_ + gn(n);
    489       for (Index n1 = n * gn_; n1 < nend; n1++) {
    490         if (k == 0) {
    491           // Zero the output memory in parallel.
     492           // On a 10000x2x10000 mm, zeroing can easily take half of the time.
     493           // Zero the m x bn output panel. Safe to do here because all kernels that will
    494           // write to this memory depend on completion of this task.
    495           // Note: don't call device_.memset() here. device_.memset() blocks on
     496           // a thread pool worker thread, which can lead to underutilization and
    497           // deadlocks.
    498           memset(buffer_ + n1 * bn_ * m_, 0, bn(n1) * m_ * sizeof(Scalar));
    499         }
    500         RhsPacker()(packed_rhs_[k % (P - 1)][n1],
    501                     rhs_.getSubMapper(k * bk_, n1 * bn_), bk(k), bn(n1));
    502       }
    503 
    504       if (parallel_pack_ || shard_by_col_) {
    505         signal_switch(k + 1);
    506         for (Index m = nm_ - 1; m >= 0; m--) signal_kernel(m, n, k, m == 0);
    507       } else {
    508         signal_packing(k);
    509       }
    510     }
    511 
    512     void kernel(Index m, Index n, Index k) {
    513       // Note: order of iteration matters here. Iteration over m is innermost
     514       // because we want to reuse the same packed rhs in consecutive tasks
    515       // (rhs fits into L2$ while lhs only into L3$).
    516       const Index nend = n * gn_ + gn(n);
    517       const Index mend = m * gm_ + gm(m);
    518       if (shard_by_col_) {
    519         for (Index n1 = n * gn_; n1 < nend; n1++) {
    520           for (Index m1 = m * gm_; m1 < mend; m1++)
    521             GebpKernel()(output_.getSubMapper(m1 * bm_, n1 * bn_),
    522                          packed_lhs_[k % (P - 1)][m1],
    523                          packed_rhs_[k % (P - 1)][n1], bm(m1), bk(k), bn(n1),
    524                          Scalar(1), -1, -1, 0, 0);
    525         }
    526       } else {
    527         for (Index m1 = m * gm_; m1 < mend; m1++)
    528           for (Index n1 = n * gn_; n1 < nend; n1++) {
    529             GebpKernel()(output_.getSubMapper(m1 * bm_, n1 * bn_),
    530                          packed_lhs_[k % (P - 1)][m1],
    531                          packed_rhs_[k % (P - 1)][n1], bm(m1), bk(k), bn(n1),
    532                          Scalar(1), -1, -1, 0, 0);
    533           }
    534       }
    535       signal_kernel(m, n, k + 1, false);
    536       signal_switch(k + 2);
    537     }
    538 
    539     void signal_packing(Index k) {
    540       eigen_assert(!parallel_pack_);
    541       Index s = state_packing_ready_[k % P].fetch_sub(1);
    542       eigen_assert(s > 0);
    543       if (s != 1) return;
    544       state_packing_ready_[k % P] = shard_by_col_ ? nm_ : nn_;
    545       enqueue_packing(k, shard_by_col_);
    546     }
    547 
    548     void signal_kernel(Index m, Index n, Index k, bool sync) {
    549       std::atomic<uint8_t>* state = &state_kernel_[k % P][m][n];
    550       Index s = state->load();
    551       eigen_assert(s > 0);
    552       if (s != 1 && state->fetch_sub(1) != 1) return;
    553       state->store(parallel_pack_ ? 3 : 2, std::memory_order_relaxed);
    554       if (sync)
    555         kernel(m, n, k);
    556       else
    557         device_.enqueueNoNotification([=]() { kernel(m, n, k); });
    558     }
    559 
    560     void signal_switch(Index k, Index v = 1) {
    561       Index s = state_switch_[k % P].fetch_sub(v);
    562       eigen_assert(s >= v);
    563       if (s != v) return;
    564 
    565       // Ready to switch to the next k slice.
    566       // Reset counter for the next iteration.
    567       state_switch_[k % P] =
    568           (parallel_pack_ ? nm_ + nn_ : (shard_by_col_ ? nn_ : nm_)) +
    569           nm_ * nn_;
    570       if (k < nk_) {
    571         // Issue lhs/rhs packing. Their completion will in turn kick off
    572         // kernels.
    573         if (parallel_pack_) {
    574           enqueue_packing(k, !shard_by_col_);
    575           enqueue_packing(k, shard_by_col_);
    576         } else if (shard_by_col_) {
    577           enqueue_packing(k, false);
    578         } else {
    579           enqueue_packing(k, true);
    580         }
    581 
    582         // Termination handling.
     583       // Because kernel completion signals the k + 2 switch, we need to finish
     584       // nk + 2 slices without issuing any tasks on the nk + 1 slice. So here we
     585       // pretend that all nk + 1 packing tasks just finish instantly, so that
     586       // the nk + 2 switch only waits for completion of the nk kernels.
    587       } else if (k == nk_) {
    588         signal_switch(k + 1,
    589                       parallel_pack_ ? nm_ + nn_ : (shard_by_col_ ? nn_ : nm_));
    590       } else {
    591         done_.Notify();
    592       }
    593     }
    594 
    595     // Enqueue all rhs/lhs packing for k-th slice.
    596     void enqueue_packing(Index k, bool rhs) {
    597       enqueue_packing_helper(0, rhs ? nn_ : nm_, k, rhs);
    598     }
    599 
    600     void enqueue_packing_helper(Index start, Index end, Index k, bool rhs) {
    601       if (end - start == 1) {
    602         if (rhs)
    603           pack_rhs(start, k);
    604         else
    605           pack_lhs(start, k);
    606       } else {
    607         Index mid = (start + end) / 2;
    608         device_.enqueueNoNotification(
    609             [=]() { enqueue_packing_helper(mid, end, k, rhs); });
    610         device_.enqueueNoNotification(
    611             [=]() { enqueue_packing_helper(start, mid, k, rhs); });
    612       }
    613     }
    614 
     615     // Block sizes, accounting for a potentially incomplete last block.
    616     Index bm(Index m) const { return m + 1 < nm0_ ? bm_ : m_ + bm_ - bm_ * nm0_; }
    617     Index bn(Index n) const { return n + 1 < nn0_ ? bn_ : n_ + bn_ - bn_ * nn0_; }
    618     Index bk(Index k) const { return k + 1 < nk_ ? bk_ : k_ + bk_ - bk_ * nk_; }
     619     // Task grain sizes, accounting for a potentially incomplete last task.
    620     Index gm(Index m) const { return m + 1 < nm_ ? gm_ : nm0_ + gm_ - gm_ * nm_; }
    621     Index gn(Index n) const { return n + 1 < nn_ ? gn_ : nn0_ + gn_ - gn_ * nn_; }
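             // For example: with m_ = 1000, bm_ = 64 and nm0_ = 16, the last block has
             // bm(15) = 1000 + 64 - 64 * 16 = 40 rows instead of 64.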
    622 
    623     Context(const Context&) = delete;
    624     void operator=(const Context&) = delete;
    625   };
    626 
    627   // Decide whether we want to shard m x n contraction by columns or by rows.
    628   static bool shardByCol(Index m, Index n, Index num_threads) {
     629     // Note: we are comparing both n and m against Traits::nr; it is not
    630     // a mistake. We are trying to figure out how both n and m will fit into
    631     // the main sharding dimension.
    632 
    633     // Sharding by column is the default
    634     // ... unless there is enough data for vectorization over rows
    635     if (m / num_threads >= Traits::nr &&
    636         // and not enough data for vectorization over columns
    637         (n / num_threads < Traits::nr ||
    638          // ... or barely enough data for vectorization over columns,
     639          // but it is not evenly divisible across threads
     640          (n / num_threads < 4 * Traits::nr &&
     641           (n % (num_threads * Traits::nr)) != 0 &&
     642           // ... and it is evenly divisible across threads for rows
     643           ((m % (num_threads * Traits::nr)) == 0 ||
     644            // ... or it is not evenly divisible in either dimension, but
    645            // there is much more data over rows so that corner effects are
    646            // mitigated.
    647            (m / n >= 6)))))
    648       return false;
     649     // Also shard by row if the matrix is just substantially elongated along
     650     // the row dimension (i.e. m is much larger than n).
    651     if (n / num_threads < 16 * Traits::nr && m > n * 32) return false;
    652     return true;
    653   }
    654 
    655   Index coarsenM(Index m, Index n, Index bm, Index bn, Index bk, Index gn,
    656                  int num_threads, bool shard_by_col) const {
    657     Index gm = 1;
    658     Index gm1 = 1;
    659     Index nm0 = divup(m, bm);
    660     Index nm1 = nm0;
    661     for (;;) {
    662       // Find the next candidate for m grain size. It needs to result in
     663       // a different number of blocks. E.g. if we have 10 kernels, we want to try
    664       // 5 and 10, but not 6, 7, 8 and 9.
    665       while (gm1 <= nm0 && nm1 == divup(nm0, gm1)) gm1++;
    666       if (gm1 > nm0) break;
    667       // Check the candidate.
    668       int res = checkGrain(m, n, bm, bn, bk, gm1, gn, gm, gn, num_threads,
    669                            shard_by_col);
    670       if (res < 0) break;
    671       nm1 = divup(nm0, gm1);
    672       if (res == 0) continue;
    673       // Commit new grain size.
    674       gm = gm1;
    675     }
    676     return gm;
    677   }
    678 
    679   Index coarsenN(Index m, Index n, Index bm, Index bn, Index bk, Index gm,
    680                  int num_threads, bool shard_by_col) const {
    681     Index gn = 1;
    682     Index gn1 = 1;
    683     Index nn0 = divup(n, bn);
    684     Index nn1 = nn0;
    685     for (;;) {
    686       while (gn1 <= nn0 && nn1 == divup(nn0, gn1)) gn1++;
    687       if (gn1 > nn0) break;
    688       int res = checkGrain(m, n, bm, bn, bk, gm, gn1, gm, gn, num_threads,
    689                            shard_by_col);
    690       if (res < 0) break;
    691       nn1 = divup(nn0, gn1);
    692       if (res == 0) continue;
    693       gn = gn1;
    694     }
    695     return gn;
    696   }
    697 
    698   // checkGrain checks whether grain (gm, gn) is suitable and is better than
    699   // (oldgm, oldgn).
    700   int checkGrain(Index m, Index n, Index bm, Index bn, Index bk, Index gm,
    701                  Index gn, Index oldgm, Index oldgn, int num_threads,
    702                  bool shard_by_col) const {
    703     const TensorOpCost cost =
    704         contractionCost(bm * gm, bn * gn, bm, bn, bk, shard_by_col, true);
    705     double taskSize = TensorCostModel<ThreadPoolDevice>::taskSize(
    706         static_cast<double>(bm) * gm * bn * gn, cost);
     707     // If the task is too small, then we accept it regardless of anything
     708     // else; otherwise synchronization overheads will dominate.
    709     if (taskSize < 1) return 1;
    710     // If it is too large, then we reject it and all larger tasks.
    711     if (taskSize > 2) return -1;
     712     // Now we are in a presumably good task size range.
     713     // The main deciding factor here is parallelism. Consider that we have 12
     714     // kernels and 4 threads. Grains of 2, 3 and 4 all yield good task sizes.
     715     // But grains of 2 and 4 yield 6 and 3 tasks, which gives us parallelism of
     716     // 0.75 (at most 3/4 of cores will be busy), while grain size 3 gives us 4
     717     // tasks, which gives us parallelism of 1 (we can load all cores).
    718     Index nm0 = divup(m, bm);
    719     Index nn0 = divup(n, bn);
    720     Index new_tasks = divup(nm0, gm) * divup(nn0, gn);
    721     double new_parallelism = static_cast<double>(new_tasks) /
    722                              (divup<int>(new_tasks, num_threads) * num_threads);
    723     Index old_tasks = divup(nm0, oldgm) * divup(nn0, oldgn);
    724     double old_parallelism = static_cast<double>(old_tasks) /
    725                              (divup<int>(old_tasks, num_threads) * num_threads);
    726     if (new_parallelism > old_parallelism || new_parallelism == 1) return 1;
    727     return 0;
    728   }
    729 
    730 #else  // EIGEN_USE_SIMPLE_THREAD_POOL
    731 
    732   template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
    733   void evalProduct(Scalar* buffer) const {
    734     if (this->m_j_size == 1) {
    735       this->template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
    736       return;
    737     }
    738 
    739     evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
    740   }
    741 
    742   template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
    743   void evalGemm(Scalar* buffer) const {
    744     // columns in left side, rows in right side
    745     const Index k = this->m_k_size;
    746 
    747     // rows in left side
    748     const Index m = this->m_i_size;
    749 
    750     // columns in right side
    751     const Index n = this->m_j_size;
    752 
     753     // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar))
    754     this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));
    755 
    756 
    757     const int lhs_packet_size = internal::unpacket_traits<typename LeftEvaluator::PacketReturnType>::size;
    758     const int rhs_packet_size = internal::unpacket_traits<typename RightEvaluator::PacketReturnType>::size;
    759 
    760     typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
    761                                                    LeftEvaluator, left_nocontract_t,
    762                                                    contract_t, lhs_packet_size,
    763                                                    lhs_inner_dim_contiguous,
    764                                                    false, Unaligned> LhsMapper;
    765 
    766     typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
    767                                                    RightEvaluator, right_nocontract_t,
    768                                                    contract_t, rhs_packet_size,
    769                                                    rhs_inner_dim_contiguous,
    770                                                    rhs_inner_dim_reordered, Unaligned> RhsMapper;
    771 
    772     typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
    773 
    774     // TODO: packing could be faster sometimes if we supported row major tensor mappers
    775     typedef internal::gemm_pack_lhs<LhsScalar, Index, typename LhsMapper::SubMapper, Traits::mr,
    776                                     Traits::LhsProgress, ColMajor> LhsPacker;
    777     typedef internal::gemm_pack_rhs<RhsScalar, Index, typename RhsMapper::SubMapper, Traits::nr, ColMajor> RhsPacker;
    778 
    779     // TODO: replace false, false with conjugate values?
    780     typedef internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper,
    781                                   Traits::mr, Traits::nr, false, false> GebpKernel;
    782 
    783     typedef internal::packLhsArg<LhsScalar, LhsMapper, Index> packLArg;
    784     typedef internal::packRhsAndKernelArg<LhsScalar, RhsScalar, RhsMapper, OutputMapper, Index> packRKArg;
    785 
    786     // initialize data mappers
    787     LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
    788                   this->m_left_contracting_strides, this->m_k_strides);
    789 
    790     RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
    791                   this->m_right_contracting_strides, this->m_k_strides);
    792 
    793     OutputMapper output(buffer, m);
    794 
    795     // compute block sizes (which depend on number of threads)
    796     const Index num_threads = this->m_device.numThreads();
    797     internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index, internal::ShardByCol> blocking(k, m, n, num_threads);
    798     Index mc = blocking.mc();
    799     Index nc = blocking.nc();
    800     Index kc = blocking.kc();
    801     eigen_assert(mc <= m);
    802     eigen_assert(nc <= n);
    803     eigen_assert(kc <= k);
    804 
    805 #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
    806     const Index k_blocks = CEIL_DIV(k, kc);
    807     const Index n_blocks = CEIL_DIV(n, nc);
    808     const Index m_blocks = CEIL_DIV(m, mc);
    809     const Index sizeA = mc * kc;
    810     const Index sizeB = kc * nc;
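             // For example (hypothetical blocking): mc = 64, kc = 128, nc = 32 gives
             // sizeA = 8192 and sizeB = 4096 scalars per packed block.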
    811 
    812     /*    cout << "m: " << m << " n: " << n << " k: " << k << endl;
    813     cout << "mc: " << mc << " nc: " << nc << " kc: " << kc << endl;
    814     cout << "m_blocks: " << m_blocks << " n_blocks: " << n_blocks << " k_blocks: " << k_blocks << endl;
    815     cout << "num threads: " << num_threads << endl;
    816     */
    817 
    818     // note: m_device.allocate should return 16 byte aligned pointers, but if blockA and blockB
     819     //       aren't 16-byte aligned, segfaults will happen due to SIMD instructions
     820     // note: You can get away with allocating just a single blockA and offsets and meet
     821     //       the alignment requirements with the assumption that
    822     //       (Traits::mr * sizeof(ResScalar)) % 16 == 0
    823     const Index numBlockAs = numext::mini(num_threads, m_blocks);
    824     MaxSizeVector<LhsScalar *> blockAs(num_threads);
    825     for (int i = 0; i < num_threads; i++) {
    826       blockAs.push_back(static_cast<LhsScalar *>(this->m_device.allocate(sizeA * sizeof(LhsScalar))));
    827     }
    828 
    829     // To circumvent alignment issues, I'm just going to separately allocate the memory for each thread
    830     // TODO: is this too much memory to allocate? This simplifies coding a lot, but is wasteful.
    831     //       Other options: (1) reuse memory when a thread finishes. con: tricky
    832     //                      (2) allocate block B memory in each thread. con: overhead
    833     MaxSizeVector<RhsScalar *> blockBs(n_blocks);
    834     for (int i = 0; i < n_blocks; i++) {
    835       blockBs.push_back(static_cast<RhsScalar *>(this->m_device.allocate(sizeB * sizeof(RhsScalar))));
    836     }
    837 
    838     // lhs_notifications starts with all null Notifications
    839     MaxSizeVector<Notification*> lhs_notifications(num_threads, nullptr);
    840 
    841     // this should really be numBlockAs * n_blocks;
    842     const Index num_kernel_notifications = num_threads * n_blocks;
    843     MaxSizeVector<Notification*> kernel_notifications(num_kernel_notifications,
    844                                                     nullptr);
    845 
    846     for (Index k_block_idx = 0; k_block_idx < k_blocks; k_block_idx++) {
    847       const Index k_start = k_block_idx * kc;
     848       // make sure we don't overshoot the right edge of the left matrix
    849       const Index actual_kc = numext::mini(k_start + kc, k) - k_start;
    850 
    851       for (Index m_block_idx = 0; m_block_idx < m_blocks; m_block_idx += numBlockAs) {
    852         const Index num_blocks = numext::mini(m_blocks-m_block_idx, numBlockAs);
    853 
    854         for (Index mt_block_idx = m_block_idx; mt_block_idx < m_block_idx+num_blocks; mt_block_idx++) {
    855           const Index m_start = mt_block_idx * mc;
    856           const Index actual_mc = numext::mini(m_start + mc, m) - m_start;
    857           eigen_assert(actual_mc > 0);
    858 
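                   // blockAId selects one of the num_threads preallocated blockAs
                   // buffers in round-robin fashion; the loop below waits for any
                   // kernels still reading that buffer before it is reused for packing.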
    859           Index blockAId = (k_block_idx * m_blocks + mt_block_idx) % num_threads;
    860 
    861           for (int i = 0; i < n_blocks; ++i) {
    862             Index notification_id = (blockAId * n_blocks + i);
    863             // Wait for any current kernels using this slot to complete
    864             // before using it.
    865             if (kernel_notifications[notification_id]) {
    866               wait_until_ready(kernel_notifications[notification_id]);
    867               delete kernel_notifications[notification_id];
    868             }
    869             kernel_notifications[notification_id] = new Notification();
    870           }
    871           const packLArg arg = {
    872             blockAs[blockAId], // blockA
    873             lhs,        // lhs
    874             m_start,    // m
    875             k_start,    // k
    876             actual_mc,  // mc
    877             actual_kc,  // kc
    878           };
    879 
    880           // Delete any existing notification since we may be
    881           // replacing it.  The algorithm should ensure that there are
    882           // no existing waiters on this notification.
    883           delete lhs_notifications[blockAId];
    884           lhs_notifications[blockAId] =
    885           this->m_device.enqueue(&Self::packLhs<packLArg, LhsPacker>, arg);
    886         }
    887 
    888         // now start kernels.
    889         const Index m_base_start = m_block_idx * mc;
    890         const bool need_to_pack = m_block_idx == 0;
    891 
    892         for (Index n_block_idx = 0; n_block_idx < n_blocks; n_block_idx++) {
    893           const Index n_start = n_block_idx * nc;
    894           const Index actual_nc = numext::mini(n_start + nc, n) - n_start;
    895 
    896           // first make sure the previous kernels are all done before overwriting rhs. Also wait if
     897           // we're going to start a new k. In both cases need_to_pack is true.
    898           if (need_to_pack) {
    899             for (Index i = num_blocks; i < num_threads; ++i) {
    900               Index blockAId = (k_block_idx * m_blocks + i + m_block_idx) % num_threads;
    901               Index future_id = (blockAId * n_blocks + n_block_idx);
    902               wait_until_ready(kernel_notifications[future_id]);
    903             }
    904           }
    905 
    906           packRKArg arg = {
    907             &blockAs, // blockA
    908             blockBs[n_block_idx], // blockB
    909             rhs,          // rhs
    910             output,       // output
    911             m_base_start, // m
    912             k_start,      // k
    913             n_start,      // n
    914             mc,           // mc
    915             actual_kc,    // kc
    916             actual_nc,    // nc
    917             num_threads,
    918             numBlockAs,
    919             m,
    920             k_block_idx,
    921             m_block_idx,
    922             n_block_idx, // n_block_idx
    923             m_blocks, // m_blocks
    924             n_blocks, // n_blocks
    925             &kernel_notifications, // kernel notifications
    926             &lhs_notifications,    // lhs notifications
    927             need_to_pack, // need_to_pack
    928           };
    929 
    930           // We asynchronously kick off this function, which ends up
    931           // notifying the appropriate kernel_notifications objects,
    932           // which this thread waits on before exiting.
    933           this->m_device.enqueueNoNotification(&Self::packRhsAndKernel<packRKArg, RhsPacker, GebpKernel>, arg);
    934         }
    935       }
    936     }
    937 
    938     // Make sure all the kernels are done.
    939     for (size_t i = 0; i < kernel_notifications.size(); ++i) {
    940       wait_until_ready(kernel_notifications[i]);
    941       delete kernel_notifications[i];
    942     }
    943 
    944     // No need to wait for lhs notifications since they should have
    945     // already been waited on.  Just clean them up.
    946     for (size_t i = 0; i < lhs_notifications.size(); ++i) {
    947       delete lhs_notifications[i];
    948     }
    949 
    950     // deallocate all of the memory for both A and B's
    951     for (size_t i = 0; i < blockAs.size(); i++) {
    952       this->m_device.deallocate(blockAs[i]);
    953     }
    954     for (size_t i = 0; i < blockBs.size(); i++) {
    955       this->m_device.deallocate(blockBs[i]);
    956     }
    957 
    958 #undef CEIL_DIV
    959   }
    960 
    961   /*
    962    * Packs a LHS block of size (mt, kc) starting at lhs(m, k). Before packing
    963    * the LHS block, check that all of the kernels that worked on the same
    964    * mt_block_idx in the previous m_block are done.
    965    */
    966   template <typename packLArg, typename LhsPacker>
    967   static void packLhs(const packLArg arg) {
    968     // perform actual packing
    969     LhsPacker pack_lhs;
    970     pack_lhs(arg.blockA, arg.lhs.getSubMapper(arg.m_start, arg.k_start), arg.kc, arg.mc);
    971   }
    972 
    973   /*
    974    * Packs a RHS block of size (kc, nc) starting at (k, n) after checking that
    975    * all kernels in the previous block are done.
    976    * Then for each LHS future, we wait on the future and then call GEBP
    977    * on the area packed by the future (which starts at
    978    * blockA + future_idx * mt * kc) on the LHS and with the full packed
    979    * RHS block.
    980    * The output of this GEBP is written to output(m + i * mt, n).
    981    */
    982   template <typename packRKArg, typename RhsPacker, typename GebpKernel>
    983   static void packRhsAndKernel(packRKArg arg) {
    984     if (arg.need_to_pack) {
    985       RhsPacker pack_rhs;
    986       pack_rhs(arg.blockB, arg.rhs.getSubMapper(arg.k, arg.n), arg.kc, arg.nc);
    987     }
    988 
    989     GebpKernel gebp;
    990     for (Index mt_block_idx = 0; mt_block_idx < arg.num_blockAs; mt_block_idx++) {
    991       const Index m_base_start = arg.m + arg.mc*mt_block_idx;
    992       if (m_base_start < arg.max_m) {
    993         Index blockAId = (arg.k_block_idx * arg.m_blocks + mt_block_idx + arg.m_block_idx) % arg.num_threads;
    994         wait_until_ready((*arg.lhs_notifications)[blockAId]);
    995         const Index actual_mc = numext::mini(m_base_start + arg.mc, arg.max_m) - m_base_start;
    996         gebp(arg.output.getSubMapper(m_base_start, arg.n),
    997              (*arg.blockAs)[blockAId], arg.blockB,
    998              actual_mc, arg.kc, arg.nc, Scalar(1), -1, -1, 0, 0);
    999 
   1000         // Notify that the kernel is done.
   1001         const Index set_idx = blockAId * arg.n_blocks + arg.n_block_idx;
   1002         (*arg.kernel_notifications)[set_idx]->Notify();
   1003       }
   1004     }
   1005   }
   1006 #endif  // EIGEN_USE_SIMPLE_THREAD_POOL
   1007 
   1008   TensorOpCost contractionCost(Index m, Index n, Index bm, Index bn, Index bk,
   1009                                bool shard_by_col, bool prepacked) const {
   1010     const int packed_size = std::min<int>(PacketType<LhsScalar, Device>::size,
   1011                                           PacketType<RhsScalar, Device>::size);
   1012     const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
   1013     const double kd = static_cast<double>(bk);
    1014     // Peak VFMA bandwidth is 0.5. However, if we do not have enough data for
    1015     // vectorization, the bandwidth drops. The 4.0 and 2.0 bandwidth values were
    1016     // determined experimentally.
   1017     double computeBandwidth = bk == 1 ? 4.0 :
   1018           (shard_by_col ? bn : bm) < Traits::nr ||
   1019           (shard_by_col ? bm : bn) < Traits::mr ? 2.0 : 0.5;
   1020 #ifndef EIGEN_VECTORIZE_FMA
   1021     // Bandwidth of all of VFMA/MULPS/ADDPS is 0.5 on latest Intel processors.
    1022     // However, for MULPS/ADDPS we have a dependent sequence of 2 such instructions,
    1023     // so the overall bandwidth is 1.0.
   1024     if (computeBandwidth == 0.5) computeBandwidth = 1.0;
   1025 #endif
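             // For example: with bk = 256, blocks large enough for full vectorization
             // and FMA available, computeBandwidth is 0.5 and the compute term below is
             // kd * computeBandwidth = 128 (in the cost model's units).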
   1026     // Computations.
   1027     TensorOpCost cost = TensorOpCost(0, 0, kd * computeBandwidth, true, packed_size);
   1028     // Output stores.
   1029     cost += TensorOpCost(0, sizeof(CoeffReturnType), 0, true, output_packet_size);
   1030     if (prepacked) {
   1031       // Packing and kernels are executed in different tasks. When we calculate
    1032       // the task grain size we look only at the kernel cost, assuming that the
    1033       // kernel is more expensive than packing.
   1034       return cost;
   1035     }
   1036     // Lhs/rhs loads + computations.
   1037     TensorOpCost lhsCost = this->m_leftImpl.costPerCoeff(true) * (kd / n);
   1038     TensorOpCost rhsCost = this->m_rightImpl.costPerCoeff(true) * (kd / m);
   1039     // Lhs packing memory cost does not contribute considerably to overall
   1040     // execution time because lhs is prefetched early and accessed sequentially.
   1041     if (shard_by_col)
   1042       lhsCost.dropMemoryCost();
   1043     else
   1044       rhsCost.dropMemoryCost();
   1045     return cost + lhsCost + rhsCost;
   1046   }
   1047 };
   1048 
   1049 } // end namespace Eigen
   1050 
   1051 #endif  // EIGEN_USE_THREADS
   1052 #endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H
   1053