//
// Copyright 2016 Google Inc.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//

#ifndef HS_CUDA_MACROS_ONCE
#define HS_CUDA_MACROS_ONCE

//
//
//

#ifdef __cplusplus
extern "C" {
#endif

#include <stdint.h>

#ifdef __cplusplus
}
#endif

//
// Define the key type based on the key and value word counts
//

#if   HS_KEY_WORDS == 1
#if   HS_VAL_WORDS == 0
#define HS_KEY_TYPE  uint32_t
#endif
#elif HS_KEY_WORDS == 2
#define HS_KEY_TYPE  uint64_t
#endif
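
//
// For example, a 64-bit sorter (HS_KEY_WORDS == 2) sorts uint64_t
// keys, while a key-only 32-bit sorter (HS_KEY_WORDS == 1 and
// HS_VAL_WORDS == 0) sorts uint32_t keys.
//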

//
// FYI, restrict shouldn't have any impact on these kernels, and
// benchmarks appear to confirm that.
//

#define HS_RESTRICT  __restrict__

//
//
//

#define HS_SCOPE()                              \
  static

#define HS_KERNEL_QUALIFIER()                   \
  __global__ void

//
// The sm_35 arch has a maximum of 16 blocks per multiprocessor.  Just
// clamp it to 16 when targeting this arch.
//
// This only arises when compiling the 32-bit sorting kernels.
//
// You can also generate a narrower, 16-warp-wide 32-bit sorting
// kernel, which is sometimes faster and sometimes slower than the
// 32-block configuration.
//

#if ( __CUDA_ARCH__ == 350 )
#define HS_CUDA_MAX_BPM  16
#else
#define HS_CUDA_MAX_BPM  UINT32_MAX // effectively no clamp
#endif

#define HS_CLAMPED_BPM(min_bpm)                                 \
  ((min_bpm) < HS_CUDA_MAX_BPM ? (min_bpm) : HS_CUDA_MAX_BPM)
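
//
// For illustration: a kernel generated to require at least 32 blocks
// per multiprocessor is clamped only when targeting sm_35:
//
//   HS_CLAMPED_BPM(32) -> 16          // __CUDA_ARCH__ == 350
//   HS_CLAMPED_BPM(32) -> 32          // all other archs
//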

//
//
//

#define HS_LAUNCH_BOUNDS(max_tpb,min_bpm)       \
  __launch_bounds__(max_tpb,HS_CLAMPED_BPM(min_bpm))

//
// KERNEL PROTOS
//

#define HS_BS_KERNEL_NAME(slab_count_ru_log2)   \
  hs_kernel_bs_##slab_count_ru_log2

#define HS_BS_KERNEL_PROTO(slab_count,slab_count_ru_log2)             \
  HS_SCOPE()                                                          \
  HS_KERNEL_QUALIFIER()                                               \
  HS_LAUNCH_BOUNDS(HS_SLAB_THREADS*slab_count,1)                      \
  HS_BS_KERNEL_NAME(slab_count_ru_log2)(HS_KEY_TYPE       * const HS_RESTRICT vout, \
                                        HS_KEY_TYPE const * const HS_RESTRICT vin)
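
//
// For illustration (not compiled): assuming HS_SLAB_THREADS is 32,
// HS_BS_KERNEL_PROTO(16,4) expands to approximately:
//
//   static __global__ void
//   __launch_bounds__(512,1)
//   hs_kernel_bs_4(HS_KEY_TYPE       * const __restrict__ vout,
//                  HS_KEY_TYPE const * const __restrict__ vin)
//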

//

#define HS_OFFSET_BS_KERNEL_NAME(slab_count_ru_log2)    \
  hs_kernel_bs_##slab_count_ru_log2

#define HS_OFFSET_BS_KERNEL_PROTO(slab_count,slab_count_ru_log2)              \
  HS_SCOPE()                                                                  \
  HS_KERNEL_QUALIFIER()                                                       \
  HS_LAUNCH_BOUNDS(HS_SLAB_THREADS*slab_count,HS_BS_SLABS/(1<<slab_count_ru_log2)) \
  HS_OFFSET_BS_KERNEL_NAME(slab_count_ru_log2)(HS_KEY_TYPE       * const HS_RESTRICT vout, \
                                               HS_KEY_TYPE const * const HS_RESTRICT vin,  \
                                               uint32_t            const             slab_offset)

//

#define HS_BC_KERNEL_NAME(slab_count_log2)      \
  hs_kernel_bc_##slab_count_log2

#define HS_BC_KERNEL_PROTO(slab_count,slab_count_log2)                \
  HS_SCOPE()                                                          \
  HS_KERNEL_QUALIFIER()                                               \
  HS_LAUNCH_BOUNDS(HS_SLAB_THREADS*slab_count,HS_BS_SLABS/(1<<slab_count_log2)) \
  HS_BC_KERNEL_NAME(slab_count_log2)(HS_KEY_TYPE * const HS_RESTRICT vout)

//

#define HS_HM_KERNEL_NAME(s)                    \
  hs_kernel_hm_##s

#define HS_HM_KERNEL_PROTO(s)                                 \
  HS_SCOPE()                                                  \
  HS_KERNEL_QUALIFIER()                                       \
  HS_HM_KERNEL_NAME(s)(HS_KEY_TYPE * const HS_RESTRICT vout)

//

#define HS_FM_KERNEL_NAME(s,r)                  \
  hs_kernel_fm_##s##_##r

#define HS_FM_KERNEL_PROTO(s,r)                                      \
  HS_SCOPE()                                                         \
  HS_KERNEL_QUALIFIER()                                              \
  HS_FM_KERNEL_NAME(s,r)(HS_KEY_TYPE * const HS_RESTRICT vout)

//

#define HS_OFFSET_FM_KERNEL_NAME(s,r)           \
  hs_kernel_fm_##s##_##r

#define HS_OFFSET_FM_KERNEL_PROTO(s,r)                                \
  HS_SCOPE()                                                          \
  HS_KERNEL_QUALIFIER()                                               \
  HS_OFFSET_FM_KERNEL_NAME(s,r)(HS_KEY_TYPE * const HS_RESTRICT vout, \
                                uint32_t      const             span_offset)

//

#define HS_TRANSPOSE_KERNEL_NAME()              \
  hs_kernel_transpose

#define HS_TRANSPOSE_KERNEL_PROTO()                             \
  HS_SCOPE()                                                    \
  HS_KERNEL_QUALIFIER()                                         \
  HS_LAUNCH_BOUNDS(HS_SLAB_THREADS,1)                           \
  HS_TRANSPOSE_KERNEL_NAME()(HS_KEY_TYPE * const HS_RESTRICT vout)

//
// BLOCK LOCAL MEMORY DECLARATION
//

#define HS_BLOCK_LOCAL_MEM_DECL(width,height)   \
  __shared__ struct {                           \
    HS_KEY_TYPE m[width * height];              \
  } shared

//
// BLOCK BARRIER
//

#define HS_BLOCK_BARRIER()                      \
  __syncthreads()

//
// GRID VARIABLES
//

#define HS_GLOBAL_SIZE_X() (gridDim.x * blockDim.x)
#define HS_GLOBAL_ID_X()   (blockDim.x * blockIdx.x + threadIdx.x)
#define HS_LOCAL_ID_X()    threadIdx.x
#define HS_WARP_ID_X()     (threadIdx.x / 32)
#define HS_LANE_ID()       (threadIdx.x & 31)
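
//
// Example: a thread with threadIdx.x == 37 has HS_WARP_ID_X() == 1 and
// HS_LANE_ID() == 5 -- it is lane 5 of its block's second warp.
//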

//
// SLAB GLOBAL
//

#define HS_SLAB_GLOBAL_PREAMBLE()               \
  uint32_t const gmem_idx =                     \
    (HS_GLOBAL_ID_X() & ~(HS_SLAB_THREADS-1)) * \
    HS_SLAB_HEIGHT + HS_LANE_ID()

#define HS_OFFSET_SLAB_GLOBAL_PREAMBLE()                        \
  uint32_t const gmem_idx =                                     \
    ((slab_offset + HS_GLOBAL_ID_X()) & ~(HS_SLAB_THREADS-1)) * \
    HS_SLAB_HEIGHT + HS_LANE_ID()

#define HS_SLAB_GLOBAL_LOAD(extent,row_idx)  \
  extent[gmem_idx + HS_SLAB_THREADS * row_idx]

#define HS_SLAB_GLOBAL_STORE(row_idx,reg)    \
  vout[gmem_idx + HS_SLAB_THREADS * row_idx] = reg
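
//
// Worked example, assuming HS_SLAB_THREADS == 32 and a hypothetical
// HS_SLAB_HEIGHT of 16: the thread with HS_GLOBAL_ID_X() == 70 is
// lane 6 of its warp, and its warp's slab starts at element
// (70 & ~31) * 16 == 1024, so:
//
//   gmem_idx                   == 1024 + 6       == 1030
//   HS_SLAB_GLOBAL_LOAD(vin,3) == vin[1030 + 96] == vin[1126]
//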

//
// SLAB LOCAL
//

#define HS_SLAB_LOCAL_L(offset)                 \
  shared.m[smem_l_idx + (offset)]

#define HS_SLAB_LOCAL_R(offset)                 \
  shared.m[smem_r_idx + (offset)]

//
// SLAB LOCAL VERTICAL LOADS
//

#define HS_BX_LOCAL_V(offset)                   \
  shared.m[HS_LOCAL_ID_X() + (offset)]

//
// BLOCK SORT MERGE HORIZONTAL
//

#define HS_BS_MERGE_H_PREAMBLE(slab_count)                      \
  uint32_t const smem_l_idx =                                   \
    HS_WARP_ID_X() * (HS_SLAB_THREADS * slab_count) +           \
    HS_LANE_ID();                                               \
  uint32_t const smem_r_idx =                                   \
    (HS_WARP_ID_X() ^ 1) * (HS_SLAB_THREADS * slab_count) +     \
    (HS_LANE_ID() ^ (HS_SLAB_THREADS - 1))
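
//
// For illustration, with a slab_count of 2 and HS_SLAB_THREADS == 32:
// warp 0, lane 0 gets smem_l_idx == 0 and smem_r_idx == 64 + 31 == 95,
// i.e. the right index addresses the partner warp's rows in reversed
// lane order -- the access pattern a flip merge requires.
//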

//
// BLOCK CLEAN MERGE HORIZONTAL
//

#define HS_BC_MERGE_H_PREAMBLE(slab_count)                      \
  uint32_t const gmem_l_idx =                                   \
    (HS_GLOBAL_ID_X() & ~(HS_SLAB_THREADS*slab_count-1)) *      \
    HS_SLAB_HEIGHT + HS_LOCAL_ID_X();                           \
  uint32_t const smem_l_idx =                                   \
    HS_WARP_ID_X() * (HS_SLAB_THREADS * slab_count) +           \
    HS_LANE_ID()

#define HS_BC_GLOBAL_LOAD_L(slab_idx)                   \
  vout[gmem_l_idx + (HS_SLAB_THREADS * slab_idx)]

//
// SLAB FLIP AND HALF PREAMBLES
//

#define HS_SLAB_FLIP_PREAMBLE(mask)                             \
  uint32_t const flip_lane_idx  = HS_LANE_ID() ^ mask;          \
  int32_t  const t_lt           = HS_LANE_ID() < flip_lane_idx;

// if we want to shfl_xor: uint32_t const flip_lane_mask = mask;

#define HS_SLAB_HALF_PREAMBLE(mask)                             \
  uint32_t const half_lane_idx  = HS_LANE_ID() ^ mask;          \
  int32_t  const t_lt           = HS_LANE_ID() < half_lane_idx;

// if we want to shfl_xor: uint32_t const half_lane_mask = mask;

//
// Inter-lane compare exchange
//

// good
#define HS_CMP_XCHG_V0(a,b)                     \
  {                                             \
    HS_KEY_TYPE const t = min(a,b);             \
    b = max(a,b);                               \
    a = t;                                      \
  }

// surprisingly fast -- #1 on 64-bit keys
#define HS_CMP_XCHG_V1(a,b)                     \
  {                                             \
    HS_KEY_TYPE const tmp = a;                  \
    a  = (a < b) ? a : b;                       \
    b ^= a ^ tmp;                               \
  }

// good
#define HS_CMP_XCHG_V2(a,b)                     \
  if (a >= b) {                                 \
    HS_KEY_TYPE const t = a;                    \
    a = b;                                      \
    b = t;                                      \
  }

// good
#define HS_CMP_XCHG_V3(a,b)                     \
  {                                             \
    int32_t     const ge = a >= b;              \
    HS_KEY_TYPE const t  = a;                   \
    a = ge ? b : a;                             \
    b = ge ? t : b;                             \
  }
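
//
// Worked example for HS_CMP_XCHG_V1 with a == 7 and b == 3:
//
//   tmp = 7;  a = (7 < 3) ? 7 : 3;  b = 3 ^ (3 ^ 7);
//
// leaves a == 3 and b == 7.  When a is already <= b the xor terms
// cancel and both values stay in place.
//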

//
//
//

#if   (HS_KEY_WORDS == 1)
#define HS_CMP_XCHG(a,b)  HS_CMP_XCHG_V0(a,b)
#elif (HS_KEY_WORDS == 2)
#define HS_CMP_XCHG(a,b)  HS_CMP_XCHG_V0(a,b)
#endif

//
// The flip/half comparisons rely on a "conditional min/max":
//
//  - if the flag is true, return min(a,b)
//  - otherwise, return max(a,b)
//
// What's a little surprising is that V1 is faster than V0 for 32-bit
// keys.
//
// I suspect either a code generation problem or that the sequence
// maps well to the GEN instruction set.
//
// We mostly care about 64-bit keys and, unsurprisingly, V0 is
// fastest for this wider type.
//

// this is what you would normally use
#define HS_COND_MIN_MAX_V0(lt,a,b) ((a <= b) ^ lt) ? b : a

// this seems to be faster for 32-bit keys
#define HS_COND_MIN_MAX_V1(lt,a,b) (lt ? b : a) ^ ((a ^ b) & HS_LTE_TO_MASK(a,b))
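
//
// Worked example for HS_COND_MIN_MAX_V0 with a == 3 and b == 7:
//
//   lt == 1: ((3 <= 7) ^ 1) == 0 -> a == 3    // min
//   lt == 0: ((3 <= 7) ^ 0) == 1 -> b == 7    // max
//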

//
//
//

#if   (HS_KEY_WORDS == 1)
#define HS_COND_MIN_MAX(lt,a,b) HS_COND_MIN_MAX_V0(lt,a,b)
#elif (HS_KEY_WORDS == 2)
#define HS_COND_MIN_MAX(lt,a,b) HS_COND_MIN_MAX_V0(lt,a,b)
#endif

//
// HotSort shuffles are always warp-wide
//

#define HS_SHFL_ALL 0xFFFFFFFF

//
// Conditional inter-subgroup flip/half compare exchange
//

#define HS_CMP_FLIP(i,a,b)                                              \
  {                                                                     \
    HS_KEY_TYPE const ta = __shfl_sync(HS_SHFL_ALL,a,flip_lane_idx);    \
    HS_KEY_TYPE const tb = __shfl_sync(HS_SHFL_ALL,b,flip_lane_idx);    \
    a = HS_COND_MIN_MAX(t_lt,a,tb);                                     \
    b = HS_COND_MIN_MAX(t_lt,b,ta);                                     \
  }

#define HS_CMP_HALF(i,a)                                                \
  {                                                                     \
    HS_KEY_TYPE const ta = __shfl_sync(HS_SHFL_ALL,a,half_lane_idx);    \
    a = HS_COND_MIN_MAX(t_lt,a,ta);                                     \
  }
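
//
// For illustration, HS_SLAB_FLIP_PREAMBLE(1) pairs lane 0 with lane 1.
// In HS_CMP_FLIP, lane 0 (t_lt == 1) ends up with min(a0,b1) in a and
// min(b0,a1) in b, while lane 1 keeps the corresponding maxes -- a
// flip compare-exchange across the lane pair.
//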

//
// The device's comparison operator might return what we actually
// want.  For example, it appears GEN 'cmp' returns {true:-1,false:0}.
//

#define HS_CMP_IS_ZERO_ONE

#ifdef HS_CMP_IS_ZERO_ONE
// OpenCL requires a {true: +1, false: 0} scalar result
// (a <= b) -> { +1, 0 } -> NEGATE -> { 0xFFFFFFFF, 0 }
#define HS_LTE_TO_MASK(a,b) (HS_KEY_TYPE)(-(a <= b))
#define HS_CMP_TO_MASK(a)   (HS_KEY_TYPE)(-a)
#else
// However, OpenCL requires { -1, 0 } for vectors
// (a <= b) -> { 0xFFFFFFFF, 0 }
#define HS_LTE_TO_MASK(a,b) (a <= b) // FIXME for uint64
#define HS_CMP_TO_MASK(a)   (a)
#endif

//
// The "flip-merge" and "half-merge" preambles are very similar
//
// For now, we're only using the .y dimension for the span idx
//

#define HS_OFFSET_HM_PREAMBLE(half_span,span_offset)                    \
  uint32_t const span_idx    = span_offset + blockIdx.y;                \
  uint32_t const span_stride = HS_GLOBAL_SIZE_X();                      \
  uint32_t const span_size   = span_stride * half_span * 2;             \
  uint32_t const span_base   = span_idx * span_size;                    \
  uint32_t const span_off    = HS_GLOBAL_ID_X();                        \
  uint32_t const span_l      = span_base + span_off

#define HS_HM_PREAMBLE(half_span)               \
  HS_OFFSET_HM_PREAMBLE(half_span,0)

#define HS_FM_PREAMBLE(half_span)                                       \
  HS_HM_PREAMBLE(half_span);                                            \
  uint32_t const span_r = span_base + span_stride * (half_span + 1) - span_off - 1

#define HS_OFFSET_FM_PREAMBLE(half_span)                                \
  HS_OFFSET_HM_PREAMBLE(half_span,span_offset);                         \
  uint32_t const span_r = span_base + span_stride * (half_span + 1) - span_off - 1
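
//
// Worked example for HS_FM_PREAMBLE(8), assuming span_stride == 256
// and span_idx == 0: the span covers 2 * 8 rows of 256 keys, and the
// thread with span_off == 0 gets:
//
//   span_l == 0
//   span_r == 256 * (8 + 1) - 0 - 1 == 2303
//
// so its left cursor walks the first half forward while its right
// cursor sits at the mirrored position -- a flip-merge access pattern.
//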

//
//
//

#define HS_XM_GLOBAL_L(stride_idx)              \
  vout[span_l + span_stride * stride_idx]

#define HS_XM_GLOBAL_LOAD_L(stride_idx)         \
  HS_XM_GLOBAL_L(stride_idx)

#define HS_XM_GLOBAL_STORE_L(stride_idx,reg)    \
  HS_XM_GLOBAL_L(stride_idx) = reg

#define HS_FM_GLOBAL_R(stride_idx)              \
  vout[span_r + span_stride * stride_idx]

#define HS_FM_GLOBAL_LOAD_R(stride_idx)         \
  HS_FM_GLOBAL_R(stride_idx)

#define HS_FM_GLOBAL_STORE_R(stride_idx,reg)    \
  HS_FM_GLOBAL_R(stride_idx) = reg

//
// This snarl of macros is for transposing a "slab" of sorted elements
// into linear order.
//
// This can occur as the last step in hs_sort() or via a custom kernel
// that inspects the slab and then transposes and stores it to memory.
//
// The slab format can be inspected more efficiently than a linear
// arrangement.
//
// The prime example is detecting when adjacent keys (in sort order)
// have differing high-order bits ("key changes").  The index of each
// change is recorded to an auxiliary array.
//
// A post-processing step like this needs to be able to navigate the
// slab and eventually transpose and store the slab in linear order.
//

#define HS_SUBGROUP_SHUFFLE_XOR(v,m)   __shfl_xor_sync(HS_SHFL_ALL,v,m)

#define HS_TRANSPOSE_REG(prefix,row)   prefix##row
#define HS_TRANSPOSE_DECL(prefix,row)  HS_KEY_TYPE const HS_TRANSPOSE_REG(prefix,row)
#define HS_TRANSPOSE_PRED(level)       is_lo_##level

#define HS_TRANSPOSE_TMP_REG(prefix_curr,row_ll,row_ur)       \
  prefix_curr##row_ll##_##row_ur

#define HS_TRANSPOSE_TMP_DECL(prefix_curr,row_ll,row_ur)      \
  HS_KEY_TYPE const HS_TRANSPOSE_TMP_REG(prefix_curr,row_ll,row_ur)

#define HS_TRANSPOSE_STAGE(level)                       \
  bool const HS_TRANSPOSE_PRED(level) =                 \
    (HS_LANE_ID() & (1 << (level-1))) == 0;

#define HS_TRANSPOSE_BLEND(prefix_prev,prefix_curr,level,row_ll,row_ur) \
  HS_TRANSPOSE_TMP_DECL(prefix_curr,row_ll,row_ur) =                    \
    HS_SUBGROUP_SHUFFLE_XOR(HS_TRANSPOSE_PRED(level) ?                  \
                            HS_TRANSPOSE_REG(prefix_prev,row_ll) :      \
                            HS_TRANSPOSE_REG(prefix_prev,row_ur),       \
                            1<<(level-1));                              \
                                                                        \
  HS_TRANSPOSE_DECL(prefix_curr,row_ll) =                               \
    HS_TRANSPOSE_PRED(level)                  ?                         \
    HS_TRANSPOSE_TMP_REG(prefix_curr,row_ll,row_ur) :                   \
    HS_TRANSPOSE_REG(prefix_prev,row_ll);                               \
                                                                        \
  HS_TRANSPOSE_DECL(prefix_curr,row_ur) =                               \
    HS_TRANSPOSE_PRED(level)                  ?                         \
    HS_TRANSPOSE_REG(prefix_prev,row_ur)      :                         \
    HS_TRANSPOSE_TMP_REG(prefix_curr,row_ll,row_ur);

#define HS_TRANSPOSE_REMAP(prefix,row_from,row_to)      \
  vout[gmem_idx + ((row_to-1) << HS_SLAB_WIDTH_LOG2)] = \
    HS_TRANSPOSE_REG(prefix,row_from);
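
//
// For illustration (not compiled): the first transpose stage over
// hypothetical registers r1 and r2,
//
//   HS_TRANSPOSE_STAGE(1)
//   HS_TRANSPOSE_BLEND(r,s,1,1,2)
//
// expands to approximately:
//
//   bool const is_lo_1 = (HS_LANE_ID() & 1) == 0;
//
//   HS_KEY_TYPE const s1_2 = __shfl_xor_sync(HS_SHFL_ALL,
//                                            is_lo_1 ? r1 : r2, 1);
//   HS_KEY_TYPE const s1   = is_lo_1 ? s1_2 : r1;
//   HS_KEY_TYPE const s2   = is_lo_1 ? r2   : s1_2;
//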

//
//
//

#endif

//
//
//