// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014-2015 Benoit Steiner <benoit.steiner.goog@gmail.com>
// Copyright (C) 2015 Navdeep Jaitly <ndjaitly@google.com>
// Copyright (C) 2014 Eric Martin <eric@ericmart.in>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H

#if defined(EIGEN_USE_GPU) && defined(__CUDACC__)

namespace Eigen {

template<typename Scalar, typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool needs_edge_check>
__device__ EIGEN_STRONG_INLINE void
EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
                               const OutputMapper output, Scalar* lhs_shmem, Scalar* rhs_shmem,
                               const Index m_size, const Index n_size, const Index k_size) {

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  // declare and initialize 64 registers for output 8x8 block

  // prefetch registers
  Scalar lhs_pf0;
  Scalar lhs_pf1;
  Scalar lhs_pf2;
  Scalar lhs_pf3;
  Scalar lhs_pf4;
  Scalar lhs_pf5;
  Scalar lhs_pf6;
  Scalar lhs_pf7;

  Scalar rhs_pf0;
  Scalar rhs_pf1;
  Scalar rhs_pf2;
  Scalar rhs_pf3;
  Scalar rhs_pf4;
  Scalar rhs_pf5;
  Scalar rhs_pf6;
  Scalar rhs_pf7;

  // shared memory is formatted
  // (contract idx in block, nocontract idx in block, block idx)
  // where block idx is column major. This transposition limits the number of
  // bank conflicts when reading the LHS. The core idea is that since the contracting
  // index is shared by both sides, then the contracting index should be in threadIdx.x.

  // On the LHS, we pad each row inside of each block with an extra element. This makes
  // each block 8 rows of 9 elements, which is 72 elements. This gives no bank conflicts
  // on writes and very few 2-way conflicts on reads. There is an 8x8 grid of these blocks.

  // On the RHS we just add 8 padding elements to the end of each block. This gives no bank
  // conflicts on writes and also none on reads.
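
  // A sketch of the bank arithmetic behind the claim above (assuming a 4-byte
  // Scalar and 32 four-byte shared memory banks): the 8 threads that share a
  // threadIdx.y write LHS offsets threadIdx.y * 72 + threadIdx.x * 9 +
  // threadIdx.z, i.e. a stride of 9 words across threadIdx.x. Since 9 is
  // coprime with 32, those 8 writes land in 8 distinct banks; with an
  // unpadded stride of 8 they would pile up two deep in only 4 banks.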

  // storage indices
  const Index lhs_store_idx_base = threadIdx.y * 72 + threadIdx.x * 9 + threadIdx.z;
  const Index rhs_store_idx_base = threadIdx.y * 72 + threadIdx.z * 8 + threadIdx.x;

  const Index lhs_store_idx_0 = lhs_store_idx_base + 576 * 0;
  const Index lhs_store_idx_1 = lhs_store_idx_base + 576 * 1;
  const Index lhs_store_idx_2 = lhs_store_idx_base + 576 * 2;
  const Index lhs_store_idx_3 = lhs_store_idx_base + 576 * 3;
  const Index lhs_store_idx_4 = lhs_store_idx_base + 576 * 4;
  const Index lhs_store_idx_5 = lhs_store_idx_base + 576 * 5;
  const Index lhs_store_idx_6 = lhs_store_idx_base + 576 * 6;
  const Index lhs_store_idx_7 = lhs_store_idx_base + 576 * 7;

  const Index rhs_store_idx_0 = rhs_store_idx_base + 576 * 0;
  const Index rhs_store_idx_1 = rhs_store_idx_base + 576 * 1;
  const Index rhs_store_idx_2 = rhs_store_idx_base + 576 * 2;
  const Index rhs_store_idx_3 = rhs_store_idx_base + 576 * 3;
  const Index rhs_store_idx_4 = rhs_store_idx_base + 576 * 4;
  const Index rhs_store_idx_5 = rhs_store_idx_base + 576 * 5;
  const Index rhs_store_idx_6 = rhs_store_idx_base + 576 * 6;
  const Index rhs_store_idx_7 = rhs_store_idx_base + 576 * 7;

  // in the loading code, the following variables are important:
  // threadIdx.x: the vertical position in an 8x8 block
  // threadIdx.y: the vertical index of the 8x8 block in the grid
  // threadIdx.z: the horizontal position in an 8x8 block
  // k: the horizontal index of the 8x8 block in the grid
  //
  // The k parameter is implicit (it was the loop counter for a loop that went
  // from 0 to 7), but that loop has been unrolled into the code below.

  const Index load_idx_vert = threadIdx.x + 8 * threadIdx.y;
  const Index lhs_vert = base_m + load_idx_vert;
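
  // Worked example of this mapping (illustration only): thread (x=3, y=2, z=1)
  // has load_idx_vert = 3 + 8 * 2 = 19, so it prefetches LHS row base_m + 19
  // at columns base_k + 1 + {0, 8, ..., 56}, and stores those eight values at
  // lhs_store_idx_base = 2 * 72 + 3 * 9 + 1 = 172 plus successive multiples
  // of 576 (one per 8x8 k block).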

#define prefetchIntoRegisters(base_k)                           \
  {                                                             \
    lhs_pf0 = conv(0);                                          \
    lhs_pf1 = conv(0);                                          \
    lhs_pf2 = conv(0);                                          \
    lhs_pf3 = conv(0);                                          \
    lhs_pf4 = conv(0);                                          \
    lhs_pf5 = conv(0);                                          \
    lhs_pf6 = conv(0);                                          \
    lhs_pf7 = conv(0);                                          \
                                                                \
    rhs_pf0 = conv(0);                                          \
    rhs_pf1 = conv(0);                                          \
    rhs_pf2 = conv(0);                                          \
    rhs_pf3 = conv(0);                                          \
    rhs_pf4 = conv(0);                                          \
    rhs_pf5 = conv(0);                                          \
    rhs_pf6 = conv(0);                                          \
    rhs_pf7 = conv(0);                                          \
                                                                \
    if (!needs_edge_check || lhs_vert < m_size) {               \
      const Index lhs_horiz_0 = base_k + threadIdx.z + 0 * 8;   \
      const Index lhs_horiz_1 = base_k + threadIdx.z + 1 * 8;   \
      const Index lhs_horiz_2 = base_k + threadIdx.z + 2 * 8;   \
      const Index lhs_horiz_3 = base_k + threadIdx.z + 3 * 8;   \
      const Index lhs_horiz_4 = base_k + threadIdx.z + 4 * 8;   \
      const Index lhs_horiz_5 = base_k + threadIdx.z + 5 * 8;   \
      const Index lhs_horiz_6 = base_k + threadIdx.z + 6 * 8;   \
      const Index lhs_horiz_7 = base_k + threadIdx.z + 7 * 8;   \
                                                                \
      if (!needs_edge_check || lhs_horiz_7 < k_size) {          \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4);                   \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5);                   \
        lhs_pf6 = lhs(lhs_vert, lhs_horiz_6);                   \
        lhs_pf7 = lhs(lhs_vert, lhs_horiz_7);                   \
      } else if (lhs_horiz_6 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4);                   \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5);                   \
        lhs_pf6 = lhs(lhs_vert, lhs_horiz_6);                   \
      } else if (lhs_horiz_5 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4);                   \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5);                   \
      } else if (lhs_horiz_4 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4);                   \
      } else if (lhs_horiz_3 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3);                   \
      } else if (lhs_horiz_2 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2);                   \
      } else if (lhs_horiz_1 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1);                   \
      } else if (lhs_horiz_0 < k_size) {                        \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0);                   \
      }                                                         \
    }                                                           \
                                                                \
    const Index rhs_vert = base_k + load_idx_vert;              \
    if (!needs_edge_check || rhs_vert < k_size) {               \
      const Index rhs_horiz_0 = base_n + threadIdx.z + 0 * 8;   \
      const Index rhs_horiz_1 = base_n + threadIdx.z + 1 * 8;   \
      const Index rhs_horiz_2 = base_n + threadIdx.z + 2 * 8;   \
      const Index rhs_horiz_3 = base_n + threadIdx.z + 3 * 8;   \
      const Index rhs_horiz_4 = base_n + threadIdx.z + 4 * 8;   \
      const Index rhs_horiz_5 = base_n + threadIdx.z + 5 * 8;   \
      const Index rhs_horiz_6 = base_n + threadIdx.z + 6 * 8;   \
      const Index rhs_horiz_7 = base_n + threadIdx.z + 7 * 8;   \
                                                                \
      if (rhs_horiz_7 < n_size) {                               \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                   \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4);                   \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5);                   \
        rhs_pf6 = rhs(rhs_vert, rhs_horiz_6);                   \
        rhs_pf7 = rhs(rhs_vert, rhs_horiz_7);                   \
      } else if (rhs_horiz_6 < n_size) {                        \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                   \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4);                   \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5);                   \
        rhs_pf6 = rhs(rhs_vert, rhs_horiz_6);                   \
      } else if (rhs_horiz_5 < n_size) {                        \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                   \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4);                   \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5);                   \
      } else if (rhs_horiz_4 < n_size) {                        \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                   \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4);                   \
      } else if (rhs_horiz_3 < n_size) {                        \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3);                   \
      } else if (rhs_horiz_2 < n_size) {                        \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2);                   \
      } else if (rhs_horiz_1 < n_size) {                        \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1);                   \
      } else if (rhs_horiz_0 < n_size) {                        \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0);                   \
      }                                                         \
    }                                                           \
  }                                                             \

#define writeRegToShmem(_)                      \
  lhs_shmem[lhs_store_idx_0] = lhs_pf0;         \
  rhs_shmem[rhs_store_idx_0] = rhs_pf0;         \
                                                \
  lhs_shmem[lhs_store_idx_1] = lhs_pf1;         \
  rhs_shmem[rhs_store_idx_1] = rhs_pf1;         \
                                                \
  lhs_shmem[lhs_store_idx_2] = lhs_pf2;         \
  rhs_shmem[rhs_store_idx_2] = rhs_pf2;         \
                                                \
  lhs_shmem[lhs_store_idx_3] = lhs_pf3;         \
  rhs_shmem[rhs_store_idx_3] = rhs_pf3;         \
                                                \
  lhs_shmem[lhs_store_idx_4] = lhs_pf4;         \
  rhs_shmem[rhs_store_idx_4] = rhs_pf4;         \
                                                \
  lhs_shmem[lhs_store_idx_5] = lhs_pf5;         \
  rhs_shmem[rhs_store_idx_5] = rhs_pf5;         \
                                                \
  lhs_shmem[lhs_store_idx_6] = lhs_pf6;         \
  rhs_shmem[rhs_store_idx_6] = rhs_pf6;         \
                                                \
  lhs_shmem[lhs_store_idx_7] = lhs_pf7;         \
  rhs_shmem[rhs_store_idx_7] = rhs_pf7;         \

  // declare and initialize result array
#define res(i, j) _res_##i##j
#define initResultRow(i)                        \
  Scalar res(i, 0) = conv(0);                   \
  Scalar res(i, 1) = conv(0);                   \
  Scalar res(i, 2) = conv(0);                   \
  Scalar res(i, 3) = conv(0);                   \
  Scalar res(i, 4) = conv(0);                   \
  Scalar res(i, 5) = conv(0);                   \
  Scalar res(i, 6) = conv(0);                   \
  Scalar res(i, 7) = conv(0);                   \

  internal::scalar_cast_op<int, Scalar> conv;
  initResultRow(0);
  initResultRow(1);
  initResultRow(2);
  initResultRow(3);
  initResultRow(4);
  initResultRow(5);
  initResultRow(6);
  initResultRow(7);
#undef initResultRow
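
  // Note: res(i, j) is a token-pasting macro, so e.g. res(3, 5) names the
  // register variable _res_35 declared by the initResultRow expansions above.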

  for (Index base_k = 0; base_k < k_size; base_k += 64) {
    // wait for previous iteration to finish with shmem. Despite common sense,
    // the code is a bit faster with this here than at the bottom of the loop.
    __syncthreads();

    prefetchIntoRegisters(base_k);
    writeRegToShmem();

    #undef prefetchIntoRegisters
    #undef writeRegToShmem

    // wait for shared mem packing to be done before starting computation
    __syncthreads();

    // compute 8x8 matrix product by outer product. This involves packing one column
    // of LHS and one row of RHS into registers (takes 16 registers).

#define lcol(i) _lcol##i
    Scalar lcol(0);
    Scalar lcol(1);
    Scalar lcol(2);
    Scalar lcol(3);
    Scalar lcol(4);
    Scalar lcol(5);
    Scalar lcol(6);
    Scalar lcol(7);

#define rrow(j) _rrow##j
    Scalar rrow(0);
    Scalar rrow(1);
    Scalar rrow(2);
    Scalar rrow(3);
    Scalar rrow(4);
    Scalar rrow(5);
    Scalar rrow(6);
    Scalar rrow(7);

    // Now x corresponds to k, y to m, and z to n
    const Scalar* lhs_block = &lhs_shmem[threadIdx.x + 9 * threadIdx.y];
    const Scalar* rhs_block = &rhs_shmem[threadIdx.x + 8 * threadIdx.z];

#define lhs_element(i, j) lhs_block[72 * ((i) + 8 * (j))]
#define rhs_element(i, j) rhs_block[72 * ((i) + 8 * (j))]
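
    // Illustration of the addressing (derived from the layout above): for
    // compute-phase thread (x, y, z), lhs_element(i, j) reads
    // lhs_shmem[x + 9 * y + 72 * i + 576 * j], i.e. x picks the little-k
    // offset inside a padded 9-element row, y the row inside an 8x8 block,
    // i the vertical (m) block and j the horizontal (k) block of the grid
    // packed by writeRegToShmem.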

#define loadData(i, j)                          \
    lcol(0) = lhs_element(0, j);                \
    rrow(0) = rhs_element(i, 0);                \
    lcol(1) = lhs_element(1, j);                \
    rrow(1) = rhs_element(i, 1);                \
    lcol(2) = lhs_element(2, j);                \
    rrow(2) = rhs_element(i, 2);                \
    lcol(3) = lhs_element(3, j);                \
    rrow(3) = rhs_element(i, 3);                \
    lcol(4) = lhs_element(4, j);                \
    rrow(4) = rhs_element(i, 4);                \
    lcol(5) = lhs_element(5, j);                \
    rrow(5) = rhs_element(i, 5);                \
    lcol(6) = lhs_element(6, j);                \
    rrow(6) = rhs_element(i, 6);                \
    lcol(7) = lhs_element(7, j);                \
    rrow(7) = rhs_element(i, 7);                \

#define computeCol(j)                           \
    res(0, j) += lcol(0) * rrow(j);             \
    res(1, j) += lcol(1) * rrow(j);             \
    res(2, j) += lcol(2) * rrow(j);             \
    res(3, j) += lcol(3) * rrow(j);             \
    res(4, j) += lcol(4) * rrow(j);             \
    res(5, j) += lcol(5) * rrow(j);             \
    res(6, j) += lcol(6) * rrow(j);             \
    res(7, j) += lcol(7) * rrow(j);             \

#define computePass(i)                          \
    loadData(i, i);                             \
                                                \
    computeCol(0);                              \
    computeCol(1);                              \
    computeCol(2);                              \
    computeCol(3);                              \
    computeCol(4);                              \
    computeCol(5);                              \
    computeCol(6);                              \
    computeCol(7);                              \

    computePass(0);
    computePass(1);
    computePass(2);
    computePass(3);
    computePass(4);
    computePass(5);
    computePass(6);
    computePass(7);

#undef lcol
#undef rrow
#undef lhs_element
#undef rhs_element
#undef loadData
#undef computeCol
#undef computePass
  } // end loop over k

  // we've now iterated over all of the large (i.e. width 64) k blocks and
  // accumulated results in registers. At this point thread (x, y, z) contains
  // the sum across all big k blocks of the product of little k block of index (x, y)
  // with block of index (x, z). To compute the final output, we need to reduce
  // the 8 threads over x by summation.
#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor(res(i, j), mask)

#define reduceRow(i, mask)                      \
  shuffleInc(i, 0, mask);                       \
  shuffleInc(i, 1, mask);                       \
  shuffleInc(i, 2, mask);                       \
  shuffleInc(i, 3, mask);                       \
  shuffleInc(i, 4, mask);                       \
  shuffleInc(i, 5, mask);                       \
  shuffleInc(i, 6, mask);                       \
  shuffleInc(i, 7, mask);                       \

#define reduceMatrix(mask)                      \
  reduceRow(0, mask);                           \
  reduceRow(1, mask);                           \
  reduceRow(2, mask);                           \
  reduceRow(3, mask);                           \
  reduceRow(4, mask);                           \
  reduceRow(5, mask);                           \
  reduceRow(6, mask);                           \
  reduceRow(7, mask);                           \

  // actually perform the reduction, now each thread of index (_, y, z)
  // contains the correct values in its registers that belong in the output
  // block
  reduceMatrix(1);
  reduceMatrix(2);
  reduceMatrix(4);

#undef shuffleInc
#undef reduceRow
#undef reduceMatrix
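
  // The three reduceMatrix calls form a butterfly reduction: __shfl_xor with
  // masks 1, 2 and 4 adds together lanes whose threadIdx.x differs in one
  // bit, so after all three passes every lane holds the full sum over the 8
  // x-lanes. A scalar sketch of the same idiom (illustration, not part of
  // this kernel):
  //
  //   float v = partial;
  //   for (int mask = 1; mask < 8; mask <<= 1)
  //     v += __shfl_xor(v, mask);  // afterwards all 8 lanes hold the total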

  // now we need to copy the 64 values into main memory. We can't split work
  // among threads because all variables are in registers. There are two ways
  // to do this:
  // (1) have 1 thread do 64 writes from registers into global memory
  // (2) have 1 thread do 64 writes into shared memory, and then 8 threads
  //     each do 8 writes into global memory. We can just overwrite the shared
  //     memory from the problem we just solved.
  // (2) is slightly faster than (1) due to less branching and more ILP

  // TODO: won't yield much gain, but could just use currently unused shared mem
  //       and then we won't have to sync
  // wait for shared mem to be out of use
  __syncthreads();

#define writeResultShmem(i, j)                                          \
  lhs_shmem[i + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j] = res(i, j); \

#define writeRow(i)                             \
  writeResultShmem(i, 0);                       \
  writeResultShmem(i, 1);                       \
  writeResultShmem(i, 2);                       \
  writeResultShmem(i, 3);                       \
  writeResultShmem(i, 4);                       \
  writeResultShmem(i, 5);                       \
  writeResultShmem(i, 6);                       \
  writeResultShmem(i, 7);                       \

  if (threadIdx.x == 0) {
    writeRow(0);
    writeRow(1);
    writeRow(2);
    writeRow(3);
    writeRow(4);
    writeRow(5);
    writeRow(6);
    writeRow(7);
  }
#undef writeResultShmem
#undef writeRow

  const int max_i_write = numext::mini((int)((m_size - base_m - threadIdx.y + 7) / 8), 8);
  const int max_j_write = numext::mini((int)((n_size - base_n - threadIdx.z + 7) / 8), 8);

  if (threadIdx.x < max_i_write) {
    if (max_j_write == 8) {
      // TODO: can I trade bank conflicts for coalesced writes?
      Scalar val0 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 0];
      Scalar val1 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 1];
      Scalar val2 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 2];
      Scalar val3 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 3];
      Scalar val4 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 4];
      Scalar val5 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 5];
      Scalar val6 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 6];
      Scalar val7 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 7];

      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 0) = val0;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 1) = val1;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 2) = val2;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 3) = val3;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 4) = val4;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 5) = val5;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 6) = val6;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 7) = val7;
    } else {
#pragma unroll 7
      for (int j = 0; j < max_j_write; j++) {
        Scalar val = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j];
        output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * j) = val;
      }
    }
  }
#undef res
}


template<typename Scalar, typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
__launch_bounds__(512)
EigenContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output,
                       const Index m_size, const Index n_size, const Index k_size) {
  __shared__ Scalar lhs_shmem[72 * 64];
  __shared__ Scalar rhs_shmem[72 * 64];

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  if (base_m + 63 < m_size && base_n + 63 < n_size) {
    EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
  } else {
    EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
  }
}
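
// A sketch of how this kernel is typically launched (the actual call lives in
// the contraction evaluator, not in this file; the shapes below are inferred
// from the indexing above): one 8x8x8 block of 512 threads, matching
// __launch_bounds__(512), produces one 64x64 output tile, so the grid covers
// ceil(m/64) x ceil(n/64) tiles:
//
//   const Index m_blocks = (m_size + 63) / 64;
//   const Index n_blocks = (n_size + 63) / 64;
//   const dim3 num_blocks(m_blocks, n_blocks, 1);
//   const dim3 block_size(8, 8, 8);
//   EigenContractionKernel<Scalar, Index, LhsMapper, RhsMapper, OutputMapper>
//       <<<num_blocks, block_size, 0, stream>>>(lhs, rhs, output,
//                                               m_size, n_size, k_size);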


template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
         bool CHECK_RHS_BOUNDARY>
__device__ EIGEN_STRONG_INLINE void
EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output, float2 lhs_shmem2[][16],
                       float2 rhs_shmem2[][8], const Index m_size,
                       const Index n_size, const Index k_size,
                       const Index base_m, const Index base_n) {
  typedef float Scalar;

  // prefetch registers
  float4 lhs_pf0, rhs_pf0;

  float4 results[4];
  for (int i = 0; i < 4; i++) {
    results[i].x = results[i].y = results[i].z = results[i].w = 0;
  }


#define prefetch_lhs(reg, row, col)                   \
    if (!CHECK_LHS_BOUNDARY) {                        \
      if (col < k_size) {                             \
        reg = lhs.loadPacket<Unaligned>(row, col);    \
      }                                               \
    } else {                                          \
      if (col < k_size) {                             \
        if (row + 3 < m_size) {                       \
          reg = lhs.loadPacket<Unaligned>(row, col);  \
        } else if (row + 2 < m_size) {                \
          reg.x = lhs(row + 0, col);                  \
          reg.y = lhs(row + 1, col);                  \
          reg.z = lhs(row + 2, col);                  \
        } else if (row + 1 < m_size) {                \
          reg.x = lhs(row + 0, col);                  \
          reg.y = lhs(row + 1, col);                  \
        } else if (row < m_size) {                    \
          reg.x = lhs(row + 0, col);                  \
        }                                             \
      }                                               \
    }                                                 \

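
  // Usage note (illustration): prefetch_lhs(lhs_pf0, lhs_vert, lhs_horiz)
  // fills lhs_pf0 with rows lhs_vert..lhs_vert+3 of column lhs_horiz,
  // falling back to scalar loads near the m edge; out-of-range components
  // keep the zero that lhs_pf0 is initialized to at the top of the k loop.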

  Index lhs_vert = base_m+threadIdx.x*4;

  for (Index k = 0; k < k_size; k += 16) {
    lhs_pf0 = internal::pset1<float4>(0);
    rhs_pf0 = internal::pset1<float4>(0);

    Index lhs_horiz = threadIdx.y+k;
    prefetch_lhs(lhs_pf0, lhs_vert, lhs_horiz)

    Index rhs_vert = k+(threadIdx.x%4)*4;
    Index rhs_horiz0 = (threadIdx.x>>2)+threadIdx.y*4+base_n;

    if (!CHECK_RHS_BOUNDARY) {
      if ((rhs_vert + 3) < k_size) {
        // only the k edge can be out of range here
        rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
      } else if (rhs_vert + 2 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
      } else if (rhs_vert + 1 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
      } else if (rhs_vert < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
      }
    } else {
      if (rhs_horiz0 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
        } else if ((rhs_vert + 2) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        } else if ((rhs_vert + 1) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        } else if (rhs_vert < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        }
      }
    }
    float x1, x2;
    // the following could be a single bitwise operation... some day
    if ((threadIdx.x%8) < 4) {
      x1 = rhs_pf0.y;
      x2 = rhs_pf0.w;
    } else {
      x1 = rhs_pf0.x;
      x2 = rhs_pf0.z;
    }
    x1 = __shfl_xor(x1, 4);
    x2 = __shfl_xor(x2, 4);
    if ((threadIdx.x%8) < 4) {
      rhs_pf0.y = x1;
      rhs_pf0.w = x2;
    } else {
      rhs_pf0.x = x1;
      rhs_pf0.z = x2;
    }
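
    // Worked example of the swap above (illustration): within each group of
    // eight lanes, lanes 0-3 contribute their (y, w) components and lanes 4-7
    // their (x, z); __shfl_xor with mask 4 pairs lane L with lane L^4, so
    // lane 1 receives lane 5's (x, z) into its own (y, w) while lane 5
    // receives lane 1's (y, w) into its (x, z). After the swap each float4
    // mixes the same k offsets of two neighbouring output columns.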

    // We have 64 features.
    // Row 0 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 0, 1.
    // Row 1 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 2, 3.
    // ...
    // Row 31 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 62, 63
    // Row 32 -> times (2, 6, 10, 14, 3, 7, 11, 15) for features 0, 1
    // ...
    rhs_shmem2[(threadIdx.x>>3) + threadIdx.y*2][threadIdx.x%8] = make_float2(rhs_pf0.x, rhs_pf0.y);
    rhs_shmem2[(threadIdx.x>>3) + threadIdx.y*2+32][threadIdx.x%8] = make_float2(rhs_pf0.z, rhs_pf0.w);

    // Row 0 (time 0) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), ..  (60, 61)
    // Row 1 (time 1) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), ..  (60, 61)
    // ...
    // Row 15 (time 15) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), ..  (60, 61)
    // Row 16 (time 0) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), ..  (62, 63)
    // ...

    lhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(lhs_pf0.x, lhs_pf0.y);
    lhs_shmem2[threadIdx.y+16][threadIdx.x] = make_float2(lhs_pf0.z, lhs_pf0.w);


#define add_vals(fl1, fl2, fr1, fr2)\
    results[0].x += fl1.x * fr1.x;\
    results[0].y += fl1.y * fr1.x;\
    results[0].z += fl2.x * fr1.x;\
    results[0].w += fl2.y * fr1.x;\
\
    results[1].x += fl1.x * fr1.y;\
    results[1].y += fl1.y * fr1.y;\
    results[1].z += fl2.x * fr1.y;\
    results[1].w += fl2.y * fr1.y;\
\
    results[2].x += fl1.x * fr2.x;\
    results[2].y += fl1.y * fr2.x;\
    results[2].z += fl2.x * fr2.x;\
    results[2].w += fl2.y * fr2.x;\
\
    results[3].x += fl1.x * fr2.y;\
    results[3].y += fl1.y * fr2.y;\
    results[3].z += fl2.x * fr2.y;\
    results[3].w += fl2.y * fr2.y;\

    __syncthreads();

    // Do the multiplies.
    #pragma unroll
    for (int koff = 0; koff < 16; koff++) {
      // 32 x threads.
      float2 fl1 = lhs_shmem2[koff][threadIdx.x];
      float2 fl2 = lhs_shmem2[koff + 16][threadIdx.x];

      int start_feature = threadIdx.y * 4;
      float2 fr1 = rhs_shmem2[(start_feature>>1) + 32*((koff%4)/2)][koff/4 + (koff%2)*4];
      float2 fr2 = rhs_shmem2[(start_feature>>1) + 1 + 32*((koff%4)/2)][koff/4 + (koff%2)*4];

      add_vals(fl1, fl2, fr1, fr2)
    }
    __syncthreads();
  }

#undef prefetch_lhs
#undef add_vals

  Index horiz_base = threadIdx.y*4+base_n;
  if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
    for (int i = 0; i < 4; i++) {
      output(lhs_vert, horiz_base + i) = results[i].x;
      output(lhs_vert + 1, horiz_base + i) = results[i].y;
      output(lhs_vert + 2, horiz_base + i) = results[i].z;
      output(lhs_vert + 3, horiz_base + i) = results[i].w;
    }
  } else if (!CHECK_RHS_BOUNDARY) {
    // CHECK LHS
    if (lhs_vert + 3 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    } else if (lhs_vert + 2 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
      }
    } else if (lhs_vert + 1 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
      }
    } else if (lhs_vert < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
      }
    }
  } else if (!CHECK_LHS_BOUNDARY) {
    // CHECK RHS
    /*
    int ncols_rem = fminf(n_size- horiz_base, 4);
    for (int i = 0; i < ncols_rem; i++) {
      output(lhs_vert, horiz_base + i) = results[i].x;
      output(lhs_vert + 1, horiz_base + i) = results[i].y;
      output(lhs_vert + 2, horiz_base + i) = results[i].z;
      output(lhs_vert + 3, horiz_base + i) = results[i].w;
    }*/
    for (int i = 0; i < 4; i++) {
      if (horiz_base+i < n_size) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  } else {
    // CHECK both boundaries.
    for (int i = 0; i < 4; i++) {
      if (horiz_base+i < n_size) {
        if (lhs_vert < m_size)
          output(lhs_vert, horiz_base + i) = results[i].x;
        if (lhs_vert + 1 < m_size)
          output(lhs_vert + 1, horiz_base + i) = results[i].y;
        if (lhs_vert + 2 < m_size)
          output(lhs_vert + 2, horiz_base + i) = results[i].z;
        if (lhs_vert + 3 < m_size)
          output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  }
}


template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
         bool CHECK_RHS_BOUNDARY>
__device__ EIGEN_STRONG_INLINE void
EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output, float2 lhs_shmem2[][32],
                       float2 rhs_shmem2[][8], const Index m_size,
                       const Index n_size, const Index k_size,
                       const Index base_m, const Index base_n) {
  typedef float Scalar;

  // prefetch registers
  float4 lhs_pf0, lhs_pf1, lhs_pf2, lhs_pf3;
  float4 rhs_pf0, rhs_pf1;

  float4 results[8];
  for (int i = 0; i < 8; i++) {
    results[i].x = results[i].y = results[i].z = results[i].w = 0;
  }


  Index lhs_vert = base_m+threadIdx.x*4+(threadIdx.y%4)*32;
  for (Index k = 0; k < k_size; k += 32) {
    lhs_pf0 = internal::pset1<float4>(0);
    lhs_pf1 = internal::pset1<float4>(0);
    lhs_pf2 = internal::pset1<float4>(0);
    lhs_pf3 = internal::pset1<float4>(0);

    rhs_pf0 = internal::pset1<float4>(0);
    rhs_pf1 = internal::pset1<float4>(0);

    if (!CHECK_LHS_BOUNDARY) {
      if ((threadIdx.y/4+k+24) < k_size) {
        lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
        lhs_pf1 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
        lhs_pf2 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
        lhs_pf3 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
      } else if ((threadIdx.y/4+k+16) < k_size) {
        lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
        lhs_pf1 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
        lhs_pf2 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
      } else if ((threadIdx.y/4+k+8) < k_size) {
        lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
        lhs_pf1 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
      } else if ((threadIdx.y/4+k) < k_size) {
        lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
      }
    } else {
      // just CHECK_LHS_BOUNDARY
      if (lhs_vert + 3 < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
          lhs_pf1 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
          lhs_pf2 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
          lhs_pf3 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
          lhs_pf1 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
          lhs_pf2 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
          lhs_pf1 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
        }
      } else if (lhs_vert + 2 < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
          lhs_pf2.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+16));
          lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
          lhs_pf3.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+24));
          lhs_pf3.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
          lhs_pf2.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
        }
      } else if (lhs_vert + 1 < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
          lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
          lhs_pf3.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
        }
      } else if (lhs_vert < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
        }
      }
    }
    __syncthreads();
    Index rhs_vert = k+threadIdx.x*4;
    Index rhs_horiz0 = threadIdx.y*2+base_n;
    Index rhs_horiz1 = threadIdx.y*2+1+base_n;
    if (!CHECK_RHS_BOUNDARY) {
      if ((rhs_vert + 3) < k_size) {
        // only the k edge can be out of range here
        rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
        rhs_pf1 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz1);
      } else if (rhs_vert + 2 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
        rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
        rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
      } else if (rhs_vert + 1 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
        rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
      } else if (rhs_vert < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
      }
    } else {
      if (rhs_horiz1 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          // just CHECK_RHS_BOUNDARY
          rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
          rhs_pf1 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz1);
        } else if (rhs_vert + 2 < k_size) {
          // just CHECK_RHS_BOUNDARY
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
          rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
          rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
          rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
        } else if (rhs_vert + 1 < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
          rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
        } else if (rhs_vert < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
        }
      } else if (rhs_horiz0 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          // just CHECK_RHS_BOUNDARY
          rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
        } else if ((rhs_vert + 2) < k_size) {
          // just CHECK_RHS_BOUNDARY
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        } else if ((rhs_vert + 1) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        } else if (rhs_vert < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        }
      }
    }
    __syncthreads();
    // Loaded. Do computation
    // Row 0 -> times (0, 4, 8, .. 28) for features 0, 1.
    // Row 1 -> times (0, 4, 8, .. 28) for features 2, 3.
    // ..
    // Row 31 -> times (0, 4, 8, .. 28) for features 62, 63
    rhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(rhs_pf0.x, rhs_pf1.x);
    // Row 32 -> times (1, 5, 9, .. 29) for features 0, 1.
    // Row 33 -> times (1, 5, 9, .. 29) for features 2, 3.
    // ..
    rhs_shmem2[threadIdx.y+32][threadIdx.x] = make_float2(rhs_pf0.y, rhs_pf1.y);
    // Row 64 -> times (2, 6, 10, .. 30) for features 0, 1.
    // Row 65 -> times (2, 6, 10, .. 30) for features 2, 3.
    rhs_shmem2[threadIdx.y+64][threadIdx.x] = make_float2(rhs_pf0.z, rhs_pf1.z);
    // Row 96 -> times (3, 7, 11, .. 31) for features 0, 1.
    // Row 97 -> times (3, 7, 11, .. 31) for features 2, 3.
    rhs_shmem2[threadIdx.y+96][threadIdx.x] = make_float2(rhs_pf0.w, rhs_pf1.w);

    // LHS.
    // Row 0 (time 0) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), ..  (60, 61) .. (124, 125)
    // Row 1 (time 1) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), ..  (60, 61) .. (124, 125)
    // ...
    // Row 31 (time 31) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), ..  (60, 61) .. (124, 125)
    // Row 32 (time 0) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), ..  (62, 63) .. (126, 127)
    // ...

#define add_vals(a_feat1, a_feat2, f1, f2, f3, f4)\
      results[0].x += a_feat1.x * f1.x;\
      results[1].x += a_feat1.x * f1.y;\
      results[2].x += a_feat1.x * f2.x;\
      results[3].x += a_feat1.x * f2.y;\
      results[4].x += a_feat1.x * f3.x;\
      results[5].x += a_feat1.x * f3.y;\
      results[6].x += a_feat1.x * f4.x;\
      results[7].x += a_feat1.x * f4.y;\
\
      results[0].y += a_feat1.y * f1.x;\
      results[1].y += a_feat1.y * f1.y;\
      results[2].y += a_feat1.y * f2.x;\
      results[3].y += a_feat1.y * f2.y;\
      results[4].y += a_feat1.y * f3.x;\
      results[5].y += a_feat1.y * f3.y;\
      results[6].y += a_feat1.y * f4.x;\
      results[7].y += a_feat1.y * f4.y;\
\
      results[0].z += a_feat2.x * f1.x;\
      results[1].z += a_feat2.x * f1.y;\
      results[2].z += a_feat2.x * f2.x;\
      results[3].z += a_feat2.x * f2.y;\
      results[4].z += a_feat2.x * f3.x;\
      results[5].z += a_feat2.x * f3.y;\
      results[6].z += a_feat2.x * f4.x;\
      results[7].z += a_feat2.x * f4.y;\
\
      results[0].w += a_feat2.y * f1.x;\
      results[1].w += a_feat2.y * f1.y;\
      results[2].w += a_feat2.y * f2.x;\
      results[3].w += a_feat2.y * f2.y;\
      results[4].w += a_feat2.y * f3.x;\
      results[5].w += a_feat2.y * f3.y;\
      results[6].w += a_feat2.y * f4.x;\
      results[7].w += a_feat2.y * f4.y;\

    lhs_shmem2[threadIdx.y/4][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf0.x, lhs_pf0.y);
    lhs_shmem2[threadIdx.y/4+8][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf1.x, lhs_pf1.y);
    lhs_shmem2[threadIdx.y/4+16][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf2.x, lhs_pf2.y);
    lhs_shmem2[threadIdx.y/4+24][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf3.x, lhs_pf3.y);

    lhs_shmem2[threadIdx.y/4 + 32][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf0.z, lhs_pf0.w);
    lhs_shmem2[threadIdx.y/4 + 40][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf1.z, lhs_pf1.w);
    lhs_shmem2[threadIdx.y/4 + 48][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf2.z, lhs_pf2.w);
    lhs_shmem2[threadIdx.y/4 + 56][threadIdx.x+(threadIdx.y%4)*8] = make_float2(lhs_pf3.z, lhs_pf3.w);

    __syncthreads();

    // Do the multiplies.
    #pragma unroll
    for (int koff = 0; koff < 32; koff++) {
      float2 a3 = lhs_shmem2[koff][threadIdx.x + (threadIdx.y % 4) * 8];
      float2 a4 = lhs_shmem2[koff + 32][threadIdx.x + (threadIdx.y % 4) * 8];

      // first feature is start_feature = (threadIdx.y / 4) * 8; the last one
      // is start_feature + 7.
      int start_feature = (threadIdx.y / 4) * 8;

      float2 br1 = rhs_shmem2[start_feature/2 +     (koff % 4) * 32][koff/4];
      float2 br2 = rhs_shmem2[start_feature/2 + 1 + (koff % 4) * 32][koff/4];
      float2 br3 = rhs_shmem2[start_feature/2 + 2 + (koff % 4) * 32][koff/4];
      float2 br4 = rhs_shmem2[start_feature/2 + 3 + (koff % 4) * 32][koff/4];

      add_vals(a3, a4, br1, br2, br3, br4)
    }
    __syncthreads();
  } // end loop over k


  __syncthreads();
  Index horiz_base = (threadIdx.y/4)*8+base_n;
  if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
    for (int i = 0; i < 8; i++) {
      output(lhs_vert, horiz_base + i) = results[i].x;
      output(lhs_vert + 1, horiz_base + i) = results[i].y;
      output(lhs_vert + 2, horiz_base + i) = results[i].z;
      output(lhs_vert + 3, horiz_base + i) = results[i].w;
    }
  } else if (!CHECK_RHS_BOUNDARY) {
    if (lhs_vert + 3 < m_size) {
      for (int i = 0; i < 8; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    } else if (lhs_vert + 2 < m_size) {
      for (int i = 0; i < 8; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
      }
    } else if (lhs_vert + 1 < m_size) {
      for (int i = 0; i < 8; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
      }
    } else if (lhs_vert < m_size) {
      for (int i = 0; i < 8; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
      }
    }
   1099   } else if (!CHECK_LHS_BOUNDARY) {
   1100     // CHECK BOUNDARY_B
   1101     for (int i = 0; i < 8; i++) {
   1102       if (horiz_base + i < n_size) {
   1103         output(lhs_vert, horiz_base + i) = results[i].x;
   1104         output(lhs_vert + 1, horiz_base + i) = results[i].y;
   1105         output(lhs_vert + 2, horiz_base + i) = results[i].z;
   1106         output(lhs_vert + 3, horiz_base + i) = results[i].w;
   1107       }
   1108     }
   1109   } else {
   1110     // CHECK both boundaries.
   1111     for (int i = 0; i < 8; i++) {
   1112       if (horiz_base + i < n_size) {
   1113         if (lhs_vert < m_size)
   1114           output(lhs_vert, horiz_base + i) = results[i].x;
   1115         if (lhs_vert + 1 < m_size)
   1116           output(lhs_vert + 1, horiz_base + i) = results[i].y;
   1117         if (lhs_vert + 2 < m_size)
   1118           output(lhs_vert + 2, horiz_base + i) = results[i].z;
   1119         if (lhs_vert + 3 < m_size)
   1120           output(lhs_vert + 3, horiz_base + i) = results[i].w;
   1121       }
   1122     }
   1123   }
   1124 }


template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
__launch_bounds__(256)
EigenFloatContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output,
                       const Index m_size, const Index n_size, const Index k_size) {
  __shared__ float2 lhs_shmem[64*32];
  __shared__ float2 rhs_shmem[128*8];

  typedef float2 LHS_MEM[64][32];
  typedef float2 RHS_MEM[128][8];

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 128 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  const bool check_rhs = (base_n + 63) >= n_size;
  const bool check_lhs128 = (base_m + 127) >= m_size;

  if (!check_rhs) {
    if (!check_lhs128) {
      // >= 128 rows left
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(
                     lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(
                     lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    }
  } else {
    if (!check_lhs128) {
      // >= 128 rows left
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(
                     lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(
                     lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    }
  }
}
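
// Per-block resources for EigenFloatContractionKernel: 8x32 = 256 threads
// cooperate on a 128 (row) x 64 (column) tile of the output. Shared memory
// holds 64*32 float2 for the LHS (16 KB) plus 128*8 float2 for the RHS
// (8 KB), 24 KB per block in total.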

template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
__launch_bounds__(256)
EigenFloatContractionKernel16x16(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output,
                       const Index m_size, const Index n_size, const Index k_size) {
  __shared__ float2 lhs_shmem[32][16];
  __shared__ float2 rhs_shmem[64][8];

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  if (base_m + 63 < m_size) {
    if (base_n + 63 < n_size) {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(
                     lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(
                     lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    }
  } else {
    if (base_n + 63 < n_size) {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(
                     lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(
                     lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    }
  }
}
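
// The 16x16 variant uses 16x16 = 256 threads per block on a 64x64 output
// tile, with 32*16 + 64*8 = 1024 float2 (8 KB) of shared memory per block.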


template<typename Indices, typename LeftArgType, typename RightArgType>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, GpuDevice> :
    public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, GpuDevice> > {

  typedef GpuDevice Device;

  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self;
  typedef TensorContractionEvaluatorBase<Self> Base;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, GpuDevice>::type PacketReturnType;

  enum {
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
  };

  // Most of the code assumes that both input tensors are ColMajor. If the
  // inputs are RowMajor, we "cheat" by swapping the LHS and RHS: to compute
  // A * B = C, where A is the LHS and B is the RHS, the code pretends that
  // B is the LHS and A is the RHS.
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;

  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;

  typedef array<Index, LDims> left_dim_mapper_t;
  typedef array<Index, RDims> right_dim_mapper_t;

  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, LDims - ContractDims> left_nocontract_t;
  typedef array<Index, RDims - ContractDims> right_nocontract_t;

  static const int NumDims = LDims + RDims - 2 * ContractDims;

  typedef DSizes<Index, NumDims> Dimensions;

  // typedefs needed in evalTo
  typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
  typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;

  typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
  typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;

  typedef typename LeftEvaluator::Dimensions LeftDimensions;
  typedef typename RightEvaluator::Dimensions RightDimensions;

  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) :
      Base(op, device) {}

  // We need to redefine this method to make nvcc happy.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
    this->m_leftImpl.evalSubExprsIfNeeded(NULL);
    this->m_rightImpl.evalSubExprsIfNeeded(NULL);
    if (data) {
      evalTo(data);
      return false;
    } else {
      this->m_result = static_cast<Scalar *>(this->m_device.allocate(this->dimensions().TotalSize() * sizeof(Scalar)));
      evalTo(this->m_result);
      return true;
    }
  }
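
  // Note: returning false tells the caller that its buffer already holds the
  // result, while returning true means the evaluator allocated m_result and
  // owns it until cleanup() releases it.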

  void evalTo(Scalar* buffer) const {
    if (this->m_lhs_inner_dim_contiguous) {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, true, true, Unaligned>(buffer);
        }
        else {
          evalTyped<true, true, false, Unaligned>(buffer);
        }
      }
      else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, false, true, Unaligned>(buffer);
        }
        else {
          evalTyped<true, false, false, Unaligned>(buffer);
        }
      }
    }
    else {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, true, true, Unaligned>(buffer);
        }
        else {
          evalTyped<false, true, false, Unaligned>(buffer);
        }
      }
      else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, false, true, Unaligned>(buffer);
        }
        else {
          evalTyped<false, false, false, Unaligned>(buffer);
        }
      }
    }
  }
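
  // The nested branches above promote three runtime flags
  // (m_lhs_inner_dim_contiguous, m_rhs_inner_dim_contiguous and
  // m_rhs_inner_dim_reordered) into compile-time template arguments,
  // instantiating all 2^3 = 8 specializations of evalTyped.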

  template <typename LhsScalar, typename RhsScalar, typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper> struct LaunchKernels {
    static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output, Index m, Index n, Index k, const GpuDevice& device) {
      const Index m_blocks = (m + 63) / 64;
      const Index n_blocks = (n + 63) / 64;
      const dim3 num_blocks(m_blocks, n_blocks, 1);
      const dim3 block_size(8, 8, 8);
      LAUNCH_CUDA_KERNEL((EigenContractionKernel<Scalar, Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
    }
  };
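
  // (m + 63) / 64 and (n + 63) / 64 are integer ceiling divisions, so a
  // partially filled edge tile still gets a block; the kernels' boundary
  // checks mask off the out-of-range rows and columns.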

  template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper> struct LaunchKernels<float, float, Index, LhsMapper, RhsMapper, OutputMapper> {
    static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output, Index m, Index n, Index k, const GpuDevice& device) {
      if (m < 768 || n < 768) {
        const Index m_blocks = (m + 63) / 64;
        const Index n_blocks = (n + 63) / 64;
        const dim3 num_blocks(m_blocks, n_blocks, 1);
        const dim3 block_size(16, 16, 1);
        LAUNCH_CUDA_KERNEL((EigenFloatContractionKernel16x16<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
      } else {
        const Index m_blocks = (m + 127) / 128;
        const Index n_blocks = (n + 63) / 64;
        const dim3 num_blocks(m_blocks, n_blocks, 1);
        const dim3 block_size(8, 32, 1);
        LAUNCH_CUDA_KERNEL((EigenFloatContractionKernel<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
      }
    }
  };
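
  // The m/n < 768 split is a performance heuristic: small problems take the
  // 16x16 kernel with 64x64 tiles, which wastes less work on partially
  // filled edge tiles, while large problems take the 8x32 kernel, whose
  // larger 128x64 tiles get more reuse out of each shared-memory load.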

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  void evalTyped(Scalar* buffer) const {
    // columns in left side, rows in right side
    const Index k = this->m_k_size;
    EIGEN_UNUSED_VARIABLE(k)

    // rows in left side
    const Index m = this->m_i_size;

    // columns in right side
    const Index n = this->m_j_size;

    // Zero out the result buffer (which must hold at least m * n * sizeof(Scalar) bytes).
    this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));

    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, 4,
                                                   lhs_inner_dim_contiguous,
                                                   false, Unaligned> LhsMapper;

    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, 4,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, Unaligned> RhsMapper;

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;

    // initialize data mappers
    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
                  this->m_left_contracting_strides, this->m_k_strides);

    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
                  this->m_right_contracting_strides, this->m_k_strides);

    OutputMapper output(buffer, m);
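
    // Eight-byte shared-memory banks match the float2 accesses used by the
    // float kernels, so each 64-bit load maps to a single bank instead of
    // being split across two four-byte banks.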

    setCudaSharedMemConfig(cudaSharedMemBankSizeEightByte);
    LaunchKernels<LhsScalar, RhsScalar, Index, LhsMapper, RhsMapper, OutputMapper>::Run(lhs, rhs, output, m, n, k, this->m_device);
  }
};
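
// A minimal usage sketch of the GPU contraction path above. Everything here
// is illustrative: d_a, d_b and d_c stand for device buffers the caller has
// already allocated and filled.
//
//   Eigen::CudaStreamDevice stream;
//   Eigen::GpuDevice gpu_device(&stream);
//   Eigen::TensorMap<Eigen::Tensor<float, 2> > a(d_a, 1024, 512);
//   Eigen::TensorMap<Eigen::Tensor<float, 2> > b(d_b, 512, 256);
//   Eigen::TensorMap<Eigen::Tensor<float, 2> > c(d_c, 1024, 256);
//   Eigen::array<Eigen::IndexPair<int>, 1> dims;
//   dims[0] = Eigen::IndexPair<int>(1, 0);  // contract a's columns with b's rows
//   c.device(gpu_device) = a.contract(b, dims);  // dispatches to the kernels above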

} // end namespace Eigen

#endif // EIGEN_USE_GPU and __CUDACC__
#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H