// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014-2015 Benoit Steiner <benoit.steiner.goog (at) gmail.com>
// Copyright (C) 2015 Navdeep Jaitly <ndjaitly (at) google.com>
// Copyright (C) 2014 Eric Martin <eric (at) ericmart.in>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H

#if defined(EIGEN_USE_GPU) && defined(__CUDACC__)

namespace Eigen {

template<typename Scalar, typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool needs_edge_check>
__device__ EIGEN_STRONG_INLINE void
EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
                               const OutputMapper output, Scalar* lhs_shmem, Scalar* rhs_shmem,
                               const Index m_size, const Index n_size, const Index k_size) {

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  // declare and initialize 64 registers for output 8x8 block

  // prefetch registers
  Scalar lhs_pf0;
  Scalar lhs_pf1;
  Scalar lhs_pf2;
  Scalar lhs_pf3;
  Scalar lhs_pf4;
  Scalar lhs_pf5;
  Scalar lhs_pf6;
  Scalar lhs_pf7;

  Scalar rhs_pf0;
  Scalar rhs_pf1;
  Scalar rhs_pf2;
  Scalar rhs_pf3;
  Scalar rhs_pf4;
  Scalar rhs_pf5;
  Scalar rhs_pf6;
  Scalar rhs_pf7;

  // Shared memory is formatted as
  // (contract idx in block, nocontract idx in block, block idx),
  // where the block index is column major. This transposition limits the number of
  // bank conflicts when reading the LHS. The core idea is that since the contracting
  // index is shared by both sides, it should live in threadIdx.x.

  // On the LHS, we pad each row inside of each block with an extra element. This makes
  // each block 8 rows of 9 elements, which is 72 elements. This gives no bank conflicts
  // on writes and very few 2-way conflicts on reads. There is an 8x8 grid of these blocks.

  // On the RHS we just add 8 padding elements to the end of each block. This gives no bank
  // conflicts on writes and also none on reads.
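
  // A hedged worked example of the layout arithmetic above (added comment, not
  // part of the original source): with blockDim = (8, 8, 8), each operand tile
  // holds 64x64 scalars. The LHS tile is stored as an 8x8 grid of padded 8x9
  // blocks, i.e. 72 scalars per block and 8 * 72 = 576 scalars per row of
  // blocks, which is where the constants 9, 72 and 576 used below come from,
  // and why lhs_shmem/rhs_shmem are allocated as 72 * 64 scalars in the
  // EigenContractionKernel launcher further down.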

  // storage indices
  const Index lhs_store_idx_base = threadIdx.y * 72 + threadIdx.x * 9 + threadIdx.z;
  const Index rhs_store_idx_base = threadIdx.y * 72 + threadIdx.z * 8 + threadIdx.x;

  const Index lhs_store_idx_0 = lhs_store_idx_base + 576 * 0;
  const Index lhs_store_idx_1 = lhs_store_idx_base + 576 * 1;
  const Index lhs_store_idx_2 = lhs_store_idx_base + 576 * 2;
  const Index lhs_store_idx_3 = lhs_store_idx_base + 576 * 3;
  const Index lhs_store_idx_4 = lhs_store_idx_base + 576 * 4;
  const Index lhs_store_idx_5 = lhs_store_idx_base + 576 * 5;
  const Index lhs_store_idx_6 = lhs_store_idx_base + 576 * 6;
  const Index lhs_store_idx_7 = lhs_store_idx_base + 576 * 7;

  const Index rhs_store_idx_0 = rhs_store_idx_base + 576 * 0;
  const Index rhs_store_idx_1 = rhs_store_idx_base + 576 * 1;
  const Index rhs_store_idx_2 = rhs_store_idx_base + 576 * 2;
  const Index rhs_store_idx_3 = rhs_store_idx_base + 576 * 3;
  const Index rhs_store_idx_4 = rhs_store_idx_base + 576 * 4;
  const Index rhs_store_idx_5 = rhs_store_idx_base + 576 * 5;
  const Index rhs_store_idx_6 = rhs_store_idx_base + 576 * 6;
  const Index rhs_store_idx_7 = rhs_store_idx_base + 576 * 7;

  // in the loading code, the following variables are important:
  // threadIdx.x: the vertical position in an 8x8 block
  // threadIdx.y: the vertical index of the 8x8 block in the grid
  // threadIdx.z: the horizontal position in an 8x8 block
  // k: the horizontal index of the 8x8 block in the grid
  //
  // The k parameter is implicit (it was the loop counter for a loop that went
  // from 0 to 7, but that loop is now unrolled in the code below).

  const Index load_idx_vert = threadIdx.x + 8 * threadIdx.y;
  const Index lhs_vert = base_m + load_idx_vert;

#define prefetchIntoRegisters(base_k) \
  { \
    lhs_pf0 = conv(0); \
    lhs_pf1 = conv(0); \
    lhs_pf2 = conv(0); \
    lhs_pf3 = conv(0); \
    lhs_pf4 = conv(0); \
    lhs_pf5 = conv(0); \
    lhs_pf6 = conv(0); \
    lhs_pf7 = conv(0); \
 \
    rhs_pf0 = conv(0); \
    rhs_pf1 = conv(0); \
    rhs_pf2 = conv(0); \
    rhs_pf3 = conv(0); \
    rhs_pf4 = conv(0); \
    rhs_pf5 = conv(0); \
    rhs_pf6 = conv(0); \
    rhs_pf7 = conv(0); \
 \
    if (!needs_edge_check || lhs_vert < m_size) { \
      const Index lhs_horiz_0 = base_k + threadIdx.z + 0 * 8; \
      const Index lhs_horiz_1 = base_k + threadIdx.z + 1 * 8; \
      const Index lhs_horiz_2 = base_k + threadIdx.z + 2 * 8; \
      const Index lhs_horiz_3 = base_k + threadIdx.z + 3 * 8; \
      const Index lhs_horiz_4 = base_k + threadIdx.z + 4 * 8; \
      const Index lhs_horiz_5 = base_k + threadIdx.z + 5 * 8; \
      const Index lhs_horiz_6 = base_k + threadIdx.z + 6 * 8; \
      const Index lhs_horiz_7 = base_k + threadIdx.z + 7 * 8; \
 \
      if (!needs_edge_check || lhs_horiz_7 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
        lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
        lhs_pf7 = lhs(lhs_vert, lhs_horiz_7); \
      } else if (lhs_horiz_6 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
        lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
      } else if (lhs_horiz_5 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
        lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
      } else if (lhs_horiz_4 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
        lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
      } else if (lhs_horiz_3 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
        lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
      } else if (lhs_horiz_2 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
        lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
      } else if (lhs_horiz_1 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
        lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
      } else if (lhs_horiz_0 < k_size) { \
        lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
      } \
    } \
 \
    const Index rhs_vert = base_k + load_idx_vert; \
    if (!needs_edge_check || rhs_vert < k_size) { \
      const Index rhs_horiz_0 = base_n + threadIdx.z + 0 * 8; \
      const Index rhs_horiz_1 = base_n + threadIdx.z + 1 * 8; \
      const Index rhs_horiz_2 = base_n + threadIdx.z + 2 * 8; \
      const Index rhs_horiz_3 = base_n + threadIdx.z + 3 * 8; \
      const Index rhs_horiz_4 = base_n + threadIdx.z + 4 * 8; \
      const Index rhs_horiz_5 = base_n + threadIdx.z + 5 * 8; \
      const Index rhs_horiz_6 = base_n + threadIdx.z + 6 * 8; \
      const Index rhs_horiz_7 = base_n + threadIdx.z + 7 * 8; \
 \
      if (rhs_horiz_7 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
        rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
        rhs_pf7 = rhs(rhs_vert, rhs_horiz_7); \
      } else if (rhs_horiz_6 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
        rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
      } else if (rhs_horiz_5 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
        rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
      } else if (rhs_horiz_4 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
        rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
      } else if (rhs_horiz_3 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
        rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
      } else if (rhs_horiz_2 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
        rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
      } else if (rhs_horiz_1 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
        rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
      } else if (rhs_horiz_0 < n_size) { \
        rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
      } \
    } \
  } \

#define writeRegToShmem(_) \
  lhs_shmem[lhs_store_idx_0] = lhs_pf0; \
  rhs_shmem[rhs_store_idx_0] = rhs_pf0; \
 \
  lhs_shmem[lhs_store_idx_1] = lhs_pf1; \
  rhs_shmem[rhs_store_idx_1] = rhs_pf1; \
 \
  lhs_shmem[lhs_store_idx_2] = lhs_pf2; \
  rhs_shmem[rhs_store_idx_2] = rhs_pf2; \
 \
  lhs_shmem[lhs_store_idx_3] = lhs_pf3; \
  rhs_shmem[rhs_store_idx_3] = rhs_pf3; \
 \
  lhs_shmem[lhs_store_idx_4] = lhs_pf4; \
  rhs_shmem[rhs_store_idx_4] = rhs_pf4; \
 \
  lhs_shmem[lhs_store_idx_5] = lhs_pf5; \
  rhs_shmem[rhs_store_idx_5] = rhs_pf5; \
 \
  lhs_shmem[lhs_store_idx_6] = lhs_pf6; \
  rhs_shmem[rhs_store_idx_6] = rhs_pf6; \
 \
  lhs_shmem[lhs_store_idx_7] = lhs_pf7; \
  rhs_shmem[rhs_store_idx_7] = rhs_pf7; \

// declare and initialize result array
#define res(i, j) _res_##i##j
#define initResultRow(i) \
  Scalar res(i, 0) = conv(0); \
  Scalar res(i, 1) = conv(0); \
  Scalar res(i, 2) = conv(0); \
  Scalar res(i, 3) = conv(0); \
  Scalar res(i, 4) = conv(0); \
  Scalar res(i, 5) = conv(0); \
  Scalar res(i, 6) = conv(0); \
  Scalar res(i, 7) = conv(0); \

  internal::scalar_cast_op<int, Scalar> conv;
  initResultRow(0);
  initResultRow(1);
  initResultRow(2);
  initResultRow(3);
  initResultRow(4);
  initResultRow(5);
  initResultRow(6);
  initResultRow(7);
#undef initResultRow

  for (Index base_k = 0; base_k < k_size; base_k += 64) {
    // wait for previous iteration to finish with shmem. Despite common sense,
    // the code is a bit faster with this here than at the bottom of the loop.
    __syncthreads();

    prefetchIntoRegisters(base_k);
    writeRegToShmem();

    #undef prefetchIntoRegisters
    #undef writeRegToShmem

    // wait for shared mem packing to be done before starting computation
    __syncthreads();

    // compute the 8x8 matrix product by outer product. This involves packing one column
    // of the LHS and one row of the RHS into registers (takes 16 registers).
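    // Illustrative sketch (added comment, not in the original source): each
    // computePass(p) below calls loadData(p, p) to pull one column of the LHS
    // block and one row of the RHS block into the lcol()/rrow() registers and
    // then performs the rank-1 update
    //   res(i, j) += lcol(i) * rrow(j)   for i, j in 0..7,
    // so the eight passes accumulate a full 8x8 partial product per thread.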

#define lcol(i) _lcol##i
    Scalar lcol(0);
    Scalar lcol(1);
    Scalar lcol(2);
    Scalar lcol(3);
    Scalar lcol(4);
    Scalar lcol(5);
    Scalar lcol(6);
    Scalar lcol(7);

#define rrow(j) _rrow##j
    Scalar rrow(0);
    Scalar rrow(1);
    Scalar rrow(2);
    Scalar rrow(3);
    Scalar rrow(4);
    Scalar rrow(5);
    Scalar rrow(6);
    Scalar rrow(7);

    // Now x corresponds to k, y to m, and z to n
    const Scalar* lhs_block = &lhs_shmem[threadIdx.x + 9 * threadIdx.y];
    const Scalar* rhs_block = &rhs_shmem[threadIdx.x + 8 * threadIdx.z];

#define lhs_element(i, j) lhs_block[72 * ((i) + 8 * (j))]
#define rhs_element(i, j) rhs_block[72 * ((i) + 8 * (j))]

#define loadData(i, j) \
    lcol(0) = lhs_element(0, j); \
    rrow(0) = rhs_element(i, 0); \
    lcol(1) = lhs_element(1, j); \
    rrow(1) = rhs_element(i, 1); \
    lcol(2) = lhs_element(2, j); \
    rrow(2) = rhs_element(i, 2); \
    lcol(3) = lhs_element(3, j); \
    rrow(3) = rhs_element(i, 3); \
    lcol(4) = lhs_element(4, j); \
    rrow(4) = rhs_element(i, 4); \
    lcol(5) = lhs_element(5, j); \
    rrow(5) = rhs_element(i, 5); \
    lcol(6) = lhs_element(6, j); \
    rrow(6) = rhs_element(i, 6); \
    lcol(7) = lhs_element(7, j); \
    rrow(7) = rhs_element(i, 7); \

#define computeCol(j) \
    res(0, j) += lcol(0) * rrow(j); \
    res(1, j) += lcol(1) * rrow(j); \
    res(2, j) += lcol(2) * rrow(j); \
    res(3, j) += lcol(3) * rrow(j); \
    res(4, j) += lcol(4) * rrow(j); \
    res(5, j) += lcol(5) * rrow(j); \
    res(6, j) += lcol(6) * rrow(j); \
    res(7, j) += lcol(7) * rrow(j); \

#define computePass(i) \
    loadData(i, i); \
 \
    computeCol(0); \
    computeCol(1); \
    computeCol(2); \
    computeCol(3); \
    computeCol(4); \
    computeCol(5); \
    computeCol(6); \
    computeCol(7); \

    computePass(0);
    computePass(1);
    computePass(2);
    computePass(3);
    computePass(4);
    computePass(5);
    computePass(6);
    computePass(7);

#undef lcol
#undef rrow
#undef lhs_element
#undef rhs_element
#undef loadData
#undef computeCol
#undef computePass
  } // end loop over k

  // We've now iterated over all of the large (i.e. width 64) k blocks and
  // accumulated results in registers. At this point thread (x, y, z) contains
  // the sum across all big k blocks of the product of the little k block of
  // index (x, y) with the block of index (y, z). To compute the final output,
  // we need to reduce the 8 threads over y by summation.
#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor(res(i, j), mask)

#define reduceRow(i, mask) \
  shuffleInc(i, 0, mask); \
  shuffleInc(i, 1, mask); \
  shuffleInc(i, 2, mask); \
  shuffleInc(i, 3, mask); \
  shuffleInc(i, 4, mask); \
  shuffleInc(i, 5, mask); \
  shuffleInc(i, 6, mask); \
  shuffleInc(i, 7, mask); \

#define reduceMatrix(mask) \
  reduceRow(0, mask); \
  reduceRow(1, mask); \
  reduceRow(2, mask); \
  reduceRow(3, mask); \
  reduceRow(4, mask); \
  reduceRow(5, mask); \
  reduceRow(6, mask); \
  reduceRow(7, mask); \

  // actually perform the reduction; now each thread of index (_, y, z)
  // contains the correct values in its registers that belong in the output
  // block
  reduceMatrix(1);
  reduceMatrix(2);
  reduceMatrix(4);

#undef shuffleInc
#undef reduceRow
#undef reduceMatrix

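  // Illustrative note (added comment, not in the original source): the three
  // reduceMatrix() calls above form a classic XOR-shuffle butterfly. For a
  // value v held per lane, the sequence
  //   v += __shfl_xor(v, 1);
  //   v += __shfl_xor(v, 2);
  //   v += __shfl_xor(v, 4);
  // leaves every participating lane holding the sum over the 8 lanes whose
  // lane IDs differ only in their low three bits, which is how the per-thread
  // partial products are combined without going through shared memory.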
  // Now we need to copy the 64 values into main memory. We can't split the work
  // among threads because all variables are in registers. There are two ways
  // to do this:
  // (1) have 1 thread do 64 writes from registers into global memory
  // (2) have 1 thread do 64 writes into shared memory, and then 8 threads
  //     each do 8 writes into global memory. We can just overwrite the shared
  //     memory from the problem we just solved.
  // (2) is slightly faster than (1) due to less branching and more ILP.

  // TODO: won't yield much gain, but could just use currently unused shared mem
  // and then we won't have to sync
  // wait for shared mem to be out of use
  __syncthreads();

#define writeResultShmem(i, j) \
  lhs_shmem[i + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j] = res(i, j); \

#define writeRow(i) \
  writeResultShmem(i, 0); \
  writeResultShmem(i, 1); \
  writeResultShmem(i, 2); \
  writeResultShmem(i, 3); \
  writeResultShmem(i, 4); \
  writeResultShmem(i, 5); \
  writeResultShmem(i, 6); \
  writeResultShmem(i, 7); \

  if (threadIdx.x == 0) {
    writeRow(0);
    writeRow(1);
    writeRow(2);
    writeRow(3);
    writeRow(4);
    writeRow(5);
    writeRow(6);
    writeRow(7);
  }
#undef writeResultShmem
#undef writeRow

  const int max_i_write = numext::mini((int)((m_size - base_m - threadIdx.y + 7) / 8), 8);
  const int max_j_write = numext::mini((int)((n_size - base_n - threadIdx.z + 7) / 8), 8);

  if (threadIdx.x < max_i_write) {
    if (max_j_write == 8) {
      // TODO: can I trade bank conflicts for coalesced writes?
      Scalar val0 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 0];
      Scalar val1 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 1];
      Scalar val2 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 2];
      Scalar val3 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 3];
      Scalar val4 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 4];
      Scalar val5 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 5];
      Scalar val6 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 6];
      Scalar val7 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 7];

      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 0) = val0;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 1) = val1;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 2) = val2;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 3) = val3;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 4) = val4;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 5) = val5;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 6) = val6;
      output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 7) = val7;
    } else {
#pragma unroll 7
      for (int j = 0; j < max_j_write; j++) {
        Scalar val = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j];
        output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * j) = val;
      }
    }
  }
#undef res
}


template<typename Scalar, typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
__launch_bounds__(512)
EigenContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
                       const OutputMapper output,
                       const Index m_size, const Index n_size, const Index k_size) {
  __shared__ Scalar lhs_shmem[72 * 64];
  __shared__ Scalar rhs_shmem[72 * 64];

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  if (base_m + 63 < m_size && base_n + 63 < n_size) {
    EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
  } else {
    EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
  }
}


template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
         bool CHECK_RHS_BOUNDARY>
__device__ EIGEN_STRONG_INLINE void
EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rhs,
                                         const OutputMapper output, float2 lhs_shmem2[][16],
                                         float2 rhs_shmem2[][8], const Index m_size,
                                         const Index n_size, const Index k_size,
                                         const Index base_m, const Index base_n) {
  typedef float Scalar;

  // prefetch registers
  float4 lhs_pf0, rhs_pf0;

  float4 results[4];
  for (int i = 0; i < 4; i++) {
    results[i].x = results[i].y = results[i].z = results[i].w = 0;
  }


#define prefetch_lhs(reg, row, col) \
  if (!CHECK_LHS_BOUNDARY) { \
    if (col < k_size) { \
      reg = lhs.loadPacket<Unaligned>(row, col); \
    } \
  } else { \
    if (col < k_size) { \
      if (row + 3 < m_size) { \
        reg = lhs.loadPacket<Unaligned>(row, col); \
      } else if (row + 2 < m_size) { \
        reg.x = lhs(row + 0, col); \
        reg.y = lhs(row + 1, col); \
        reg.z = lhs(row + 2, col); \
      } else if (row + 1 < m_size) { \
        reg.x = lhs(row + 0, col); \
        reg.y = lhs(row + 1, col); \
      } else if (row < m_size) { \
        reg.x = lhs(row + 0, col); \
      } \
    } \
  } \


  Index lhs_vert = base_m + threadIdx.x * 4;

  for (Index k = 0; k < k_size; k += 16) {
    lhs_pf0 = internal::pset1<float4>(0);
    rhs_pf0 = internal::pset1<float4>(0);

    Index lhs_horiz = threadIdx.y + k;
    prefetch_lhs(lhs_pf0, lhs_vert, lhs_horiz)

    Index rhs_vert = k + (threadIdx.x % 4) * 4;
    Index rhs_horiz0 = (threadIdx.x >> 2) + threadIdx.y * 4 + base_n;

    if (!CHECK_RHS_BOUNDARY) {
      if ((rhs_vert + 3) < k_size) {
        // just CHECK_RHS_BOUNDARY
        rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
      } else if (rhs_vert + 2 < k_size) {
        // just CHECK_RHS_BOUNDARY
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
      } else if (rhs_vert + 1 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
      } else if (rhs_vert < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
      }
    } else {
      if (rhs_horiz0 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
        } else if ((rhs_vert + 2) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        } else if ((rhs_vert + 1) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        } else if (rhs_vert < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        }
      }
    }
    float x1, x2;
    // the following can be a bitwise operation..... some day.
    if ((threadIdx.x % 8) < 4) {
      x1 = rhs_pf0.y;
      x2 = rhs_pf0.w;
    } else {
      x1 = rhs_pf0.x;
      x2 = rhs_pf0.z;
    }
    x1 = __shfl_xor(x1, 4);
    x2 = __shfl_xor(x2, 4);
    if ((threadIdx.x % 8) < 4) {
      rhs_pf0.y = x1;
      rhs_pf0.w = x2;
    } else {
      rhs_pf0.x = x1;
      rhs_pf0.z = x2;
    }

    // We have 64 features.
    // Row 0 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 0, 1.
    // Row 1 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 2, 3.
    // ...
    // Row 31 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 62, 63
    // Row 32 -> times (2, 6, 10, 14, 3, 7, 11, 15) for features 0, 1
    // ...
    rhs_shmem2[(threadIdx.x >> 3) + threadIdx.y * 2][threadIdx.x % 8] = make_float2(rhs_pf0.x, rhs_pf0.y);
    rhs_shmem2[(threadIdx.x >> 3) + threadIdx.y * 2 + 32][threadIdx.x % 8] = make_float2(rhs_pf0.z, rhs_pf0.w);

    // Row 0 (time 0) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61)
    // Row 1 (time 1) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61)
    // ...
    // Row 15 (time 15) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61)
    // Row 16 (time 0) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), .. (62, 63)
    // ...

    lhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(lhs_pf0.x, lhs_pf0.y);
    lhs_shmem2[threadIdx.y + 16][threadIdx.x] = make_float2(lhs_pf0.z, lhs_pf0.w);


#define add_vals(fl1, fl2, fr1, fr2) \
    results[0].x += fl1.x * fr1.x; \
    results[0].y += fl1.y * fr1.x; \
    results[0].z += fl2.x * fr1.x; \
    results[0].w += fl2.y * fr1.x; \
 \
    results[1].x += fl1.x * fr1.y; \
    results[1].y += fl1.y * fr1.y; \
    results[1].z += fl2.x * fr1.y; \
    results[1].w += fl2.y * fr1.y; \
 \
    results[2].x += fl1.x * fr2.x; \
    results[2].y += fl1.y * fr2.x; \
    results[2].z += fl2.x * fr2.x; \
    results[2].w += fl2.y * fr2.x; \
 \
    results[3].x += fl1.x * fr2.y; \
    results[3].y += fl1.y * fr2.y; \
    results[3].z += fl2.x * fr2.y; \
    results[3].w += fl2.y * fr2.y; \

    __syncthreads();

    // Do the multiplies.
#pragma unroll
    for (int koff = 0; koff < 16; koff++) {
      // 32 x threads.
      float2 fl1 = lhs_shmem2[koff][threadIdx.x];
      float2 fl2 = lhs_shmem2[koff + 16][threadIdx.x];

      int start_feature = threadIdx.y * 4;
      float2 fr1 = rhs_shmem2[(start_feature >> 1) + 32 * ((koff % 4) / 2)][koff / 4 + (koff % 2) * 4];
      float2 fr2 = rhs_shmem2[(start_feature >> 1) + 1 + 32 * ((koff % 4) / 2)][koff / 4 + (koff % 2) * 4];

      add_vals(fl1, fl2, fr1, fr2)
    }
    __syncthreads();
  }

#undef prefetch_lhs
#undef add_vals

  Index horiz_base = threadIdx.y * 4 + base_n;
  if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
    for (int i = 0; i < 4; i++) {
      output(lhs_vert, horiz_base + i) = results[i].x;
      output(lhs_vert + 1, horiz_base + i) = results[i].y;
      output(lhs_vert + 2, horiz_base + i) = results[i].z;
      output(lhs_vert + 3, horiz_base + i) = results[i].w;
    }
  } else if (!CHECK_RHS_BOUNDARY) {
    // CHECK LHS
    if (lhs_vert + 3 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    } else if (lhs_vert + 2 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
      }
    } else if (lhs_vert + 1 < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
      }
    } else if (lhs_vert < m_size) {
      for (int i = 0; i < 4; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
      }
    }
  } else if (!CHECK_LHS_BOUNDARY) {
    // CHECK RHS
    /*
    int ncols_rem = fminf(n_size - horiz_base, 4);
    for (int i = 0; i < ncols_rem; i++) {
      output(lhs_vert, horiz_base + i) = results[i].x;
      output(lhs_vert + 1, horiz_base + i) = results[i].y;
      output(lhs_vert + 2, horiz_base + i) = results[i].z;
      output(lhs_vert + 3, horiz_base + i) = results[i].w;
    }*/
    for (int i = 0; i < 4; i++) {
      if (horiz_base + i < n_size) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  } else {
    // CHECK both boundaries.
    for (int i = 0; i < 4; i++) {
      if (horiz_base + i < n_size) {
        if (lhs_vert < m_size)
          output(lhs_vert, horiz_base + i) = results[i].x;
        if (lhs_vert + 1 < m_size)
          output(lhs_vert + 1, horiz_base + i) = results[i].y;
        if (lhs_vert + 2 < m_size)
          output(lhs_vert + 2, horiz_base + i) = results[i].z;
        if (lhs_vert + 3 < m_size)
          output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  }
}


template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
         bool CHECK_RHS_BOUNDARY>
__device__ EIGEN_STRONG_INLINE void
EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
                                    const OutputMapper output, float2 lhs_shmem2[][32],
                                    float2 rhs_shmem2[][8], const Index m_size,
                                    const Index n_size, const Index k_size,
                                    const Index base_m, const Index base_n) {
  typedef float Scalar;

  // prefetch registers
  float4 lhs_pf0, lhs_pf1, lhs_pf2, lhs_pf3;
  float4 rhs_pf0, rhs_pf1;

  float4 results[8];
  for (int i = 0; i < 8; i++) {
    results[i].x = results[i].y = results[i].z = results[i].w = 0;
  }


  Index lhs_vert = base_m + threadIdx.x * 4 + (threadIdx.y % 4) * 32;
  for (Index k = 0; k < k_size; k += 32) {
    lhs_pf0 = internal::pset1<float4>(0);
    lhs_pf1 = internal::pset1<float4>(0);
    lhs_pf2 = internal::pset1<float4>(0);
    lhs_pf3 = internal::pset1<float4>(0);

    rhs_pf0 = internal::pset1<float4>(0);
    rhs_pf1 = internal::pset1<float4>(0);

    if (!CHECK_LHS_BOUNDARY) {
      if ((threadIdx.y/4+k+24) < k_size) {
        lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
        lhs_pf1 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
        lhs_pf2 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
        lhs_pf3 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
      } else if ((threadIdx.y/4+k+16) < k_size) {
        lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
        lhs_pf1 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
        lhs_pf2 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
      } else if ((threadIdx.y/4+k+8) < k_size) {
        lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
        lhs_pf1 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
      } else if ((threadIdx.y/4+k) < k_size) {
        lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
      }
    } else {
      // just CHECK_LHS_BOUNDARY
      if (lhs_vert + 3 < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
          lhs_pf1 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
          lhs_pf2 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
          lhs_pf3 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
          lhs_pf1 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
          lhs_pf2 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
          lhs_pf1 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0 = lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
        }
      } else if (lhs_vert + 2 < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
          lhs_pf2.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+16));
          lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
          lhs_pf3.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+24));
          lhs_pf3.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
          lhs_pf2.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y/4+k));
        }
      } else if (lhs_vert + 1 < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
          lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
          lhs_pf3.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y/4+k));
        }
      } else if (lhs_vert < m_size) {
        if ((threadIdx.y/4+k+24) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
          lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+24));
        } else if ((threadIdx.y/4+k+16) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+16));
        } else if ((threadIdx.y/4+k+8) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y/4+k+8));
        } else if ((threadIdx.y/4+k) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y/4+k));
        }
      }
    }
    __syncthreads();
    Index rhs_vert = k + threadIdx.x * 4;
    Index rhs_horiz0 = threadIdx.y * 2 + base_n;
    Index rhs_horiz1 = threadIdx.y * 2 + 1 + base_n;
    if (!CHECK_RHS_BOUNDARY) {
      if ((rhs_vert + 3) < k_size) {
        // just CHECK_RHS_BOUNDARY
        rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
        rhs_pf1 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz1);
      } else if (rhs_vert + 2 < k_size) {
        // just CHECK_RHS_BOUNDARY
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
        rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
        rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
      } else if (rhs_vert + 1 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
        rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
      } else if (rhs_vert < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
      }
    } else {
      if (rhs_horiz1 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          // just CHECK_RHS_BOUNDARY
          rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
          rhs_pf1 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz1);
        } else if (rhs_vert + 2 < k_size) {
          // just CHECK_RHS_BOUNDARY
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
          rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
          rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
          rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
        } else if (k + threadIdx.x * 4 + 1 < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
          rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
        } else if (k + threadIdx.x * 4 < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
        }
      } else if (rhs_horiz0 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          // just CHECK_RHS_BOUNDARY
          rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
        } else if ((rhs_vert + 2) < k_size) {
          // just CHECK_RHS_BOUNDARY
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        } else if ((rhs_vert + 1) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        } else if (rhs_vert < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        }
      }
    }
    __syncthreads();
    // Loaded. Do computation
    // Row 0 -> times (0, 4, 8, .. 28) for features 0, 1.
    // Row 1 -> times (0, 4, 8, .. 28) for features 2, 3.
    // ..
    // Row 31 -> times (0, 4, 8, .. 28) for features 62, 63
    rhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(rhs_pf0.x, rhs_pf1.x);
    // Row 32 -> times (1, 5, 9, .. 29) for features 0, 1.
    // Row 33 -> times (1, 5, 9, .. 29) for features 2, 3.
    // ..
    rhs_shmem2[threadIdx.y + 32][threadIdx.x] = make_float2(rhs_pf0.y, rhs_pf1.y);
    // Row 64 -> times (2, 6, 10, .. 30) for features 0, 1.
    // Row 65 -> times (2, 6, 10, .. 30) for features 2, 3.
    rhs_shmem2[threadIdx.y + 64][threadIdx.x] = make_float2(rhs_pf0.z, rhs_pf1.z);
    // Row 96 -> times (3, 7, 11, .. 31) for features 0, 1.
    // Row 97 -> times (3, 7, 11, .. 31) for features 2, 3.
    rhs_shmem2[threadIdx.y + 96][threadIdx.x] = make_float2(rhs_pf0.w, rhs_pf1.w);

    // LHS.
    // Row 0 (time 0) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61) .. (124, 125)
    // Row 1 (time 1) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61) .. (124, 125)
    // ...
    // Row 8 (time 0) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), .. (62, 63) .. (126, 127)
    // Row 15 (time 7) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), .. (62, 63) .. (126, 127)


#define add_vals(a_feat1, a_feat2, f1, f2, f3, f4) \
      results[0].x += a_feat1.x * f1.x; \
      results[1].x += a_feat1.x * f1.y; \
      results[2].x += a_feat1.x * f2.x; \
      results[3].x += a_feat1.x * f2.y; \
      results[4].x += a_feat1.x * f3.x; \
      results[5].x += a_feat1.x * f3.y; \
      results[6].x += a_feat1.x * f4.x; \
      results[7].x += a_feat1.x * f4.y; \
 \
      results[0].y += a_feat1.y * f1.x; \
      results[1].y += a_feat1.y * f1.y; \
      results[2].y += a_feat1.y * f2.x; \
      results[3].y += a_feat1.y * f2.y; \
      results[4].y += a_feat1.y * f3.x; \
      results[5].y += a_feat1.y * f3.y; \
      results[6].y += a_feat1.y * f4.x; \
      results[7].y += a_feat1.y * f4.y; \
 \
      results[0].z += a_feat2.x * f1.x; \
      results[1].z += a_feat2.x * f1.y; \
      results[2].z += a_feat2.x * f2.x; \
      results[3].z += a_feat2.x * f2.y; \
      results[4].z += a_feat2.x * f3.x; \
      results[5].z += a_feat2.x * f3.y; \
      results[6].z += a_feat2.x * f4.x; \
      results[7].z += a_feat2.x * f4.y; \
 \
      results[0].w += a_feat2.y * f1.x; \
      results[1].w += a_feat2.y * f1.y; \
      results[2].w += a_feat2.y * f2.x; \
      results[3].w += a_feat2.y * f2.y; \
      results[4].w += a_feat2.y * f3.x; \
      results[5].w += a_feat2.y * f3.y; \
      results[6].w += a_feat2.y * f4.x; \
      results[7].w += a_feat2.y * f4.y; \

    lhs_shmem2[threadIdx.y/4][threadIdx.x + (threadIdx.y%4)*8] = make_float2(lhs_pf0.x, lhs_pf0.y);
    lhs_shmem2[threadIdx.y/4+8][threadIdx.x + (threadIdx.y%4)*8] = make_float2(lhs_pf1.x, lhs_pf1.y);
    lhs_shmem2[threadIdx.y/4+16][threadIdx.x + (threadIdx.y%4)*8] = make_float2(lhs_pf2.x, lhs_pf2.y);
    lhs_shmem2[threadIdx.y/4+24][threadIdx.x + (threadIdx.y%4)*8] = make_float2(lhs_pf3.x, lhs_pf3.y);

    lhs_shmem2[threadIdx.y/4 + 32][threadIdx.x + (threadIdx.y%4)*8] = make_float2(lhs_pf0.z, lhs_pf0.w);
    lhs_shmem2[threadIdx.y/4 + 40][threadIdx.x + (threadIdx.y%4)*8] = make_float2(lhs_pf1.z, lhs_pf1.w);
    lhs_shmem2[threadIdx.y/4 + 48][threadIdx.x + (threadIdx.y%4)*8] = make_float2(lhs_pf2.z, lhs_pf2.w);
    lhs_shmem2[threadIdx.y/4 + 56][threadIdx.x + (threadIdx.y%4)*8] = make_float2(lhs_pf3.z, lhs_pf3.w);

    __syncthreads();

    // Do the multiplies.
#pragma unroll
    for (int koff = 0; koff < 32; koff++) {
      float2 a3 = lhs_shmem2[koff][threadIdx.x + (threadIdx.y % 4) * 8];
      float2 a4 = lhs_shmem2[koff + 32][threadIdx.x + (threadIdx.y % 4) * 8];

      // first feature is at (threadIdx.y/4) * 8; the last is at start + 7.
      int start_feature = (threadIdx.y / 4) * 8;

      float2 br1 = rhs_shmem2[start_feature/2 +     (koff % 4) * 32][koff/4];
      float2 br2 = rhs_shmem2[start_feature/2 + 1 + (koff % 4) * 32][koff/4];
      float2 br3 = rhs_shmem2[start_feature/2 + 2 + (koff % 4) * 32][koff/4];
      float2 br4 = rhs_shmem2[start_feature/2 + 3 + (koff % 4) * 32][koff/4];

      add_vals(a3, a4, br1, br2, br3, br4)
    }
    __syncthreads();
  } // end loop over k


  __syncthreads();
  Index horiz_base = (threadIdx.y / 4) * 8 + base_n;
  if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
    for (int i = 0; i < 8; i++) {
      output(lhs_vert, horiz_base + i) = results[i].x;
      output(lhs_vert + 1, horiz_base + i) = results[i].y;
      output(lhs_vert + 2, horiz_base + i) = results[i].z;
      output(lhs_vert + 3, horiz_base + i) = results[i].w;
    }
  } else if (!CHECK_RHS_BOUNDARY) {
    if (lhs_vert + 3 < m_size) {
      for (int i = 0; i < 8; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    } else if (lhs_vert + 2 < m_size) {
      for (int i = 0; i < 8; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
      }
    } else if (lhs_vert + 1 < m_size) {
      for (int i = 0; i < 8; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
      }
    } else if (lhs_vert < m_size) {
      for (int i = 0; i < 8; i++) {
        output(lhs_vert, horiz_base + i) = results[i].x;
      }
    }
  } else if (!CHECK_LHS_BOUNDARY) {
    // CHECK RHS boundary
    for (int i = 0; i < 8; i++) {
      if (horiz_base + i < n_size) {
        output(lhs_vert, horiz_base + i) = results[i].x;
        output(lhs_vert + 1, horiz_base + i) = results[i].y;
        output(lhs_vert + 2, horiz_base + i) = results[i].z;
        output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  } else {
    // CHECK both boundaries.
    for (int i = 0; i < 8; i++) {
      if (horiz_base + i < n_size) {
        if (lhs_vert < m_size)
          output(lhs_vert, horiz_base + i) = results[i].x;
        if (lhs_vert + 1 < m_size)
          output(lhs_vert + 1, horiz_base + i) = results[i].y;
        if (lhs_vert + 2 < m_size)
          output(lhs_vert + 2, horiz_base + i) = results[i].z;
        if (lhs_vert + 3 < m_size)
          output(lhs_vert + 3, horiz_base + i) = results[i].w;
      }
    }
  }
}


template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
__launch_bounds__(256)
EigenFloatContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
                            const OutputMapper output,
                            const Index m_size, const Index n_size, const Index k_size) {
  __shared__ float2 lhs_shmem[64*32];
  __shared__ float2 rhs_shmem[128*8];

  typedef float2 LHS_MEM[64][32];
  typedef float2 RHS_MEM[128][8];

  typedef float2 LHS_MEM16x16[32][16];
  typedef float2 RHS_MEM16x16[64][8];

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 128 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  bool check_rhs = (base_n + 63) >= n_size;
  bool check_lhs128 = (base_m + 127) >= m_size;

  if (!check_rhs) {
    if (!check_lhs128) {
      // >= 128 rows left
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(
          lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(
          lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    }
  } else {
    if (!check_lhs128) {
      // >= 128 rows left
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(
          lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(
          lhs, rhs, output, *((LHS_MEM *) lhs_shmem), *((RHS_MEM *) rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    }
  }
}

template<typename Index, typename LhsMapper,
         typename RhsMapper, typename OutputMapper>
__global__ void
__launch_bounds__(256)
EigenFloatContractionKernel16x16(const LhsMapper lhs, const RhsMapper rhs,
                                 const OutputMapper output,
                                 const Index m_size, const Index n_size, const Index k_size) {
  __shared__ float2 lhs_shmem[32][16];
  __shared__ float2 rhs_shmem[64][8];

  const Index m_block_idx = blockIdx.x;
  const Index n_block_idx = blockIdx.y;

  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  if (base_m + 63 < m_size) {
    if (base_n + 63 < n_size) {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    }
  } else {
    if (base_n + 63 < n_size) {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    }
  }
}


template<typename Indices, typename LeftArgType, typename RightArgType>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, GpuDevice> :
    public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, GpuDevice> > {

  typedef GpuDevice Device;

  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self;
  typedef TensorContractionEvaluatorBase<Self> Base;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, GpuDevice>::type PacketReturnType;

  enum {
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
  };

  // Most of the code is assuming that both input tensors are ColMajor. If the
  // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
  // If we want to compute A * B = C, where A is LHS and B is RHS, the code
  // will pretend B is LHS and A is RHS.
  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;

  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;

  typedef array<Index, LDims> left_dim_mapper_t;
  typedef array<Index, RDims> right_dim_mapper_t;

  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, LDims - ContractDims> left_nocontract_t;
  typedef array<Index, RDims - ContractDims> right_nocontract_t;

  static const int NumDims = LDims + RDims - 2 * ContractDims;

  typedef DSizes<Index, NumDims> Dimensions;

  // typedefs needed in evalTo
  typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
  typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;

  typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
  typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;

  typedef typename LeftEvaluator::Dimensions LeftDimensions;
  typedef typename RightEvaluator::Dimensions RightDimensions;

  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) :
      Base(op, device) {}

  // We need to redefine this method to make nvcc happy
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
    this->m_leftImpl.evalSubExprsIfNeeded(NULL);
    this->m_rightImpl.evalSubExprsIfNeeded(NULL);
    if (data) {
      evalTo(data);
      return false;
    } else {
      this->m_result = static_cast<Scalar *>(this->m_device.allocate(this->dimensions().TotalSize() * sizeof(Scalar)));
      evalTo(this->m_result);
      return true;
    }
  }

  void evalTo(Scalar* buffer) const {
    if (this->m_lhs_inner_dim_contiguous) {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, true, true, Unaligned>(buffer);
        }
        else {
          evalTyped<true, true, false, Unaligned>(buffer);
        }
      }
      else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, false, true, Unaligned>(buffer);
        }
        else {
          evalTyped<true, false, false, Unaligned>(buffer);
        }
      }
    }
    else {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, true, true, Unaligned>(buffer);
        }
        else {
          evalTyped<false, true, false, Unaligned>(buffer);
        }
      }
      else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, false, true, Unaligned>(buffer);
        }
        else {
          evalTyped<false, false, false, Unaligned>(buffer);
        }
      }
    }
  }

  template <typename LhsScalar, typename RhsScalar, typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper> struct LaunchKernels {
    static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output, Index m, Index n, Index k, const GpuDevice& device) {
      const Index m_blocks = (m + 63) / 64;
      const Index n_blocks = (n + 63) / 64;
      const dim3 num_blocks(m_blocks, n_blocks, 1);
      const dim3 block_size(8, 8, 8);
      LAUNCH_CUDA_KERNEL((EigenContractionKernel<Scalar, Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
    }
  };

  template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper> struct LaunchKernels<float, float, Index, LhsMapper, RhsMapper, OutputMapper> {
    static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output, Index m, Index n, Index k, const GpuDevice& device) {
      if (m < 768 || n < 768) {
        const Index m_blocks = (m + 63) / 64;
        const Index n_blocks = (n + 63) / 64;
        const dim3 num_blocks(m_blocks, n_blocks, 1);
        const dim3 block_size(16, 16, 1);
        LAUNCH_CUDA_KERNEL((EigenFloatContractionKernel16x16<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
      } else {
        const Index m_blocks = (m + 127) / 128;
        const Index n_blocks = (n + 63) / 64;
        const dim3 num_blocks(m_blocks, n_blocks, 1);
        const dim3 block_size(8, 32, 1);
        LAUNCH_CUDA_KERNEL((EigenFloatContractionKernel<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks, block_size, 0, device, lhs, rhs, output, m, n, k);
      }
    }
  };

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  void evalTyped(Scalar* buffer) const {
    // columns in left side, rows in right side
    const Index k = this->m_k_size;
    EIGEN_UNUSED_VARIABLE(k)

    // rows in left side
    const Index m = this->m_i_size;

    // columns in right side
    const Index n = this->m_j_size;

    // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar))
    this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));

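    // Note (added comment, not in the original source): the mapper typedefs
    // below are what let the kernels above treat arbitrary tensor expressions
    // as plain 2-D matrices. The kernels only index the LHS as lhs(row, col)
    // with row in [0, m) and col in [0, k), and the RHS as rhs(row, col) with
    // row in [0, k) and col in [0, n); the TensorContractionInputMapper
    // instances translate those coordinates back into strided tensor accesses.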
    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, 4,
                                                   lhs_inner_dim_contiguous,
                                                   false, Unaligned> LhsMapper;

    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, 4,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, Unaligned> RhsMapper;

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;


    // initialize data mappers
    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
                  this->m_left_contracting_strides, this->m_k_strides);

    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
                  this->m_right_contracting_strides, this->m_k_strides);

    OutputMapper output(buffer, m);

    setCudaSharedMemConfig(cudaSharedMemBankSizeEightByte);
    LaunchKernels<LhsScalar, RhsScalar, Index, LhsMapper, RhsMapper, OutputMapper>::Run(lhs, rhs, output, m, n, k, this->m_device);
  }
};

} // end namespace Eigen

#endif // EIGEN_USE_GPU and __CUDACC__
#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_CUDA_H