/*
 * Copyright 2016 Google Inc.
 *
 * Use of this source code is governed by a BSD-style license that can
 * be found in the LICENSE file.
 *
 */

//
//
//

#ifdef __cplusplus
extern "C" {
#endif

#include "common/cuda/assert_cuda.h"
#include "common/macros.h"
#include "common/util.h"

#ifdef __cplusplus
}
#endif

//
// We want concurrent kernel execution to occur in a few places.
//
// The summary is:
//
//   1) If necessary, some max-valued keys are written to the end of
//      the vin/vout buffers.
//
//   2) Blocks of slabs of keys are sorted.
//
//   3) If necessary, the blocks of slabs are merged until complete.
//
//   4) If requested, the slabs will be converted from slab ordering
//      to linear ordering.
//
// Below is the general "happens-before" relationship between HotSort
// compute kernels.
//
// Note the diagram assumes vin and vout are different buffers.  If
// they're not, then the first merge doesn't include the pad_vout
// event in the wait list.
//
//                    +----------+            +---------+
//                    | pad_vout |            | pad_vin |
//                    +----+-----+            +----+----+
//                         |                       |
//                         |                WAITFOR(pad_vin)
//                         |                       |
//                         |                 +-----v-----+
//                         |                 |           |
//                         |            +----v----+ +----v----+
//                         |            | bs_full | | bs_frac |
//                         |            +----+----+ +----+----+
//                         |                 |           |
//                         |                 +-----v-----+
//                         |                       |
//                         |  +------NO------JUST ONE BLOCK?
//                         | /                     |
//                         |/                     YES
//                         +                       |
//                         |                       v
//                         |         END_WITH_EVENTS(bs_full,bs_frac)
//                         |
//                         |
//        WAITFOR(pad_vout,bs_full,bs_frac) >>> first iteration of loop <<<
//                         |
//                         |
//                         +-----------<------------+
//                         |                        |
//                   +-----v-----+                  |
//                   |           |                  |
//              +----v----+ +----v----+             |
//              | fm_full | | fm_frac |             |
//              +----+----+ +----+----+             |
//                   |           |                  ^
//                   +-----v-----+                  |
//                         |                        |
//              WAITFOR(fm_full,fm_frac)            |
//                         |                        |
//                         v                        |
//                      +--v--+                WAITFOR(bc)
//                      | hm  |                     |
//                      +-----+                     |
//                         |                        |
//                    WAITFOR(hm)                   |
//                         |                        ^
//                      +--v--+                     |
//                      | bc  |                     |
//                      +-----+                     |
//                         |                        |
//                         v                        |
//                  MERGING COMPLETE?-------NO------+
//                         |
//                        YES
//                         |
//                         v
//                END_WITH_EVENTS(bc)
//
//
// NOTE: CUDA streams are in-order so a dependency isn't required for
// kernels launched on the same stream.
//
// Scheduling these kernels onto the available streams is a more
// subtle problem than it first appears.
//
// We'll take a different approach and declare the "happens before"
// kernel relationships:
//
//      concurrent (pad_vin,pad_vout) -> (pad_vin)  happens_before (bs_full,bs_frac)
//                                       (pad_vout) happens_before (fm_full,fm_frac)
//
//      concurrent (bs_full,bs_frac)  -> (bs_full)  happens_before (fm_full,fm_frac)
//                                       (bs_frac)  happens_before (fm_full,fm_frac)
//
//      concurrent (fm_full,fm_frac)  -> (fm_full)  happens_before (hm)
//                                       (fm_frac)  happens_before (hm)
//
//      launch     (hm)               -> (hm)       happens_before (hm)
//                                       (hm)       happens_before (bc)
//
//      launch     (bc)               -> (bc)       happens_before (fm_full,fm_frac)
//
//
// We can go ahead and permanently map kernel launches to our 3
// streams.  As an optimization, we'll dynamically assign each kernel
// to the lowest available stream.  This transforms the problem into
// one that considers streams happening before streams -- which
// kernels are involved doesn't matter.
//
//      STREAM0   STREAM1   STREAM2
//      -------   -------   -------
//
//      pad_vin             pad_vout     (pad_vin)  happens_before (bs_full,bs_frac)
//                                       (pad_vout) happens_before (fm_full,fm_frac)
//
//      bs_full   bs_frac                (bs_full)  happens_before (fm_full,fm_frac)
//                                       (bs_frac)  happens_before (fm_full,fm_frac)
//
//      fm_full   fm_frac                (fm_full)  happens_before (hm or bc)
//                                       (fm_frac)  happens_before (hm or bc)
//
//      hm                               (hm)       happens_before (hm or bc)
//
//      bc                               (bc)       happens_before (fm_full,fm_frac)
//
// A single final kernel will always complete on stream 0.
//
// This simplifies reasoning about concurrency that's downstream of
// hs_cuda_sort().
//

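//
// For reference, each "happens before" edge between two different
// streams is realized with a CUDA event record/wait pair, which is
// exactly what hs_barrier_enqueue() below does.  A minimal sketch
// (error handling omitted, stream names are illustrative):
//
//   cudaEvent_t ev;
//
//   cudaEventCreate(&ev);
//   cudaEventRecord(ev,stream_from);      // work on "from" happens before...
//   cudaStreamWaitEvent(stream_to,ev,0);  // ...anything later enqueued on "to"
//   cudaEventDestroy(ev);
//
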
typedef void (*hs_kernel_offset_bs_pfn)(HS_KEY_TYPE       * const HS_RESTRICT vout,
                                        HS_KEY_TYPE const * const HS_RESTRICT vin,
                                        uint32_t            const slab_offset);

static hs_kernel_offset_bs_pfn const hs_kernels_offset_bs[]
{
#if HS_BS_SLABS_LOG2_RU >= 1
  hs_kernel_bs_0,
#endif
#if HS_BS_SLABS_LOG2_RU >= 2
  hs_kernel_bs_1,
#endif
#if HS_BS_SLABS_LOG2_RU >= 3
  hs_kernel_bs_2,
#endif
#if HS_BS_SLABS_LOG2_RU >= 4
  hs_kernel_bs_3,
#endif
#if HS_BS_SLABS_LOG2_RU >= 5
  hs_kernel_bs_4,
#endif
#if HS_BS_SLABS_LOG2_RU >= 6
  hs_kernel_bs_5,
#endif
#if HS_BS_SLABS_LOG2_RU >= 7
  hs_kernel_bs_6,
#endif
#if HS_BS_SLABS_LOG2_RU >= 8
  hs_kernel_bs_7,
#endif
};

//
//
//

typedef void (*hs_kernel_bc_pfn)(HS_KEY_TYPE * const HS_RESTRICT vout);

static hs_kernel_bc_pfn const hs_kernels_bc[]
{
  hs_kernel_bc_0,
#if HS_BC_SLABS_LOG2_MAX >= 1
  hs_kernel_bc_1,
#endif
#if HS_BC_SLABS_LOG2_MAX >= 2
  hs_kernel_bc_2,
#endif
#if HS_BC_SLABS_LOG2_MAX >= 3
  hs_kernel_bc_3,
#endif
#if HS_BC_SLABS_LOG2_MAX >= 4
  hs_kernel_bc_4,
#endif
#if HS_BC_SLABS_LOG2_MAX >= 5
  hs_kernel_bc_5,
#endif
#if HS_BC_SLABS_LOG2_MAX >= 6
  hs_kernel_bc_6,
#endif
#if HS_BC_SLABS_LOG2_MAX >= 7
  hs_kernel_bc_7,
#endif
#if HS_BC_SLABS_LOG2_MAX >= 8
  hs_kernel_bc_8,
#endif
};

//
//
//

typedef void (*hs_kernel_hm_pfn)(HS_KEY_TYPE * const HS_RESTRICT vout);

static hs_kernel_hm_pfn const hs_kernels_hm[]
{
#if (HS_HM_SCALE_MIN == 0)
  hs_kernel_hm_0,
#endif
#if (HS_HM_SCALE_MIN <= 1) && (1 <= HS_HM_SCALE_MAX)
  hs_kernel_hm_1,
#endif
#if (HS_HM_SCALE_MIN <= 2) && (2 <= HS_HM_SCALE_MAX)
  hs_kernel_hm_2,
#endif
};

//
//
//

typedef void (*hs_kernel_fm_pfn)(HS_KEY_TYPE * const HS_RESTRICT vout);

static hs_kernel_fm_pfn const hs_kernels_fm[]
{
#if (HS_FM_SCALE_MIN == 0)
#if (HS_BS_SLABS_LOG2_RU == 1)
  hs_kernel_fm_0_0,
#endif
#if (HS_BS_SLABS_LOG2_RU == 2)
  hs_kernel_fm_0_1,
#endif
#if (HS_BS_SLABS_LOG2_RU == 3)
  hs_kernel_fm_0_2,
#endif
#if (HS_BS_SLABS_LOG2_RU == 4)
  hs_kernel_fm_0_3,
#endif
#if (HS_BS_SLABS_LOG2_RU == 5)
  hs_kernel_fm_0_4,
#endif
#if (HS_BS_SLABS_LOG2_RU == 6)
  hs_kernel_fm_0_5,
#endif
#if (HS_BS_SLABS_LOG2_RU == 7)
  hs_kernel_fm_0_6,
#endif
#endif

#if (HS_FM_SCALE_MIN <= 1) && (1 <= HS_FM_SCALE_MAX)
  CONCAT_MACRO(hs_kernel_fm_1_,HS_BS_SLABS_LOG2_RU),
#endif

#if (HS_FM_SCALE_MIN <= 2) && (2 <= HS_FM_SCALE_MAX)
#if (HS_BS_SLABS_LOG2_RU == 1)
  hs_kernel_fm_2_2,
#endif
#if (HS_BS_SLABS_LOG2_RU == 2)
  hs_kernel_fm_2_3,
#endif
#if (HS_BS_SLABS_LOG2_RU == 3)
  hs_kernel_fm_2_4,
#endif
#if (HS_BS_SLABS_LOG2_RU == 4)
  hs_kernel_fm_2_5,
#endif
#if (HS_BS_SLABS_LOG2_RU == 5)
  hs_kernel_fm_2_6,
#endif
#if (HS_BS_SLABS_LOG2_RU == 6)
  hs_kernel_fm_2_7,
#endif
#if (HS_BS_SLABS_LOG2_RU == 7)
  hs_kernel_fm_2_8,
#endif

#endif
};

//
//
//

typedef void (*hs_kernel_offset_fm_pfn)(HS_KEY_TYPE * const HS_RESTRICT vout,
                                        uint32_t const span_offset);

#if (HS_FM_SCALE_MIN == 0)
static hs_kernel_offset_fm_pfn const hs_kernels_offset_fm_0[]
{
#if (HS_BS_SLABS_LOG2_RU >= 2)
  hs_kernel_fm_0_0,
#endif
#if (HS_BS_SLABS_LOG2_RU >= 3)
  hs_kernel_fm_0_1,
#endif
#if (HS_BS_SLABS_LOG2_RU >= 4)
  hs_kernel_fm_0_2,
#endif
#if (HS_BS_SLABS_LOG2_RU >= 5)
  hs_kernel_fm_0_3,
#endif
#if (HS_BS_SLABS_LOG2_RU >= 6)
  hs_kernel_fm_0_4,
#endif
#if (HS_BS_SLABS_LOG2_RU >= 7)
  hs_kernel_fm_0_5,
#endif
};
#endif

#if (HS_FM_SCALE_MIN <= 1) && (1 <= HS_FM_SCALE_MAX)
static hs_kernel_offset_fm_pfn const hs_kernels_offset_fm_1[]
{
#if (HS_BS_SLABS_LOG2_RU >= 1)
  hs_kernel_fm_1_0,
#endif
#if (HS_BS_SLABS_LOG2_RU >= 2)
  hs_kernel_fm_1_1,
#endif
#if (HS_BS_SLABS_LOG2_RU >= 3)
  hs_kernel_fm_1_2,
#endif
#if (HS_BS_SLABS_LOG2_RU >= 4)
  hs_kernel_fm_1_3,
#endif
#if (HS_BS_SLABS_LOG2_RU >= 5)
  hs_kernel_fm_1_4,
#endif
#if (HS_BS_SLABS_LOG2_RU >= 6)
  hs_kernel_fm_1_5,
#endif
#if (HS_BS_SLABS_LOG2_RU >= 7)
  hs_kernel_fm_1_6,
#endif
};
#endif

#if (HS_FM_SCALE_MIN <= 2) && (2 <= HS_FM_SCALE_MAX)
static hs_kernel_offset_fm_pfn const hs_kernels_offset_fm_2[]
{
  hs_kernel_fm_2_0,
#if (HS_BS_SLABS_LOG2_RU >= 1)
  hs_kernel_fm_2_1,
#endif
#if (HS_BS_SLABS_LOG2_RU >= 2)
  hs_kernel_fm_2_2,
#endif
#if (HS_BS_SLABS_LOG2_RU >= 3)
  hs_kernel_fm_2_3,
#endif
#if (HS_BS_SLABS_LOG2_RU >= 4)
  hs_kernel_fm_2_4,
#endif
#if (HS_BS_SLABS_LOG2_RU >= 5)
  hs_kernel_fm_2_5,
#endif
#if (HS_BS_SLABS_LOG2_RU >= 6)
  hs_kernel_fm_2_6,
#endif
#if (HS_BS_SLABS_LOG2_RU >= 7)
  hs_kernel_fm_2_7,
#endif
};
#endif

static hs_kernel_offset_fm_pfn const * const hs_kernels_offset_fm[]
{
#if (HS_FM_SCALE_MIN == 0)
  hs_kernels_offset_fm_0,
#endif
#if (HS_FM_SCALE_MIN <= 1) && (1 <= HS_FM_SCALE_MAX)
  hs_kernels_offset_fm_1,
#endif
#if (HS_FM_SCALE_MIN <= 2) && (2 <= HS_FM_SCALE_MAX)
  hs_kernels_offset_fm_2,
#endif
};

//
//
//

typedef uint32_t hs_indices_t;

//
//
//

struct hs_state
{
  // key buffers
  HS_KEY_TYPE *  vin;
  HS_KEY_TYPE *  vout; // can be vin

  cudaStream_t   streams[3];

  // pool of stream indices
  hs_indices_t   pool;

  // bx_ru is the rounded-up number of slabs (warps) in vin
  uint32_t       bx_ru;
};

//
//
//

static
uint32_t
hs_indices_acquire(hs_indices_t * const indices)
{
  //
  // FIXME -- an FFS intrinsic might be faster but there are so few
  // bits in this implementation that it might not matter.
  //
  if      (*indices & 1)
    {
      *indices = *indices & ~1;
      return 0;
    }
  else if (*indices & 2)
    {
      *indices = *indices & ~2;
      return 1;
    }
  else // if (*indices & 4)
    {
      *indices = *indices & ~4;
      return 2;
    }
}
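
//
// For reference, the FFS variant mentioned in the FIXME above would
// look roughly like this (assuming a GCC/Clang host compiler for
// __builtin_ffs; hs_state_acquire() below could be reshaped the same
// way; the name is illustrative):
//
//   static uint32_t
//   hs_indices_acquire_ffs(hs_indices_t * const indices)
//   {
//     uint32_t const idx = (uint32_t)__builtin_ffs((int)*indices) - 1; // lowest set bit
//
//     *indices &= *indices - 1; // clear the lowest set bit
//
//     return idx;
//   }
//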


static
uint32_t
hs_state_acquire(struct hs_state * const state,
                 hs_indices_t    * const indices)
{
  //
  // FIXME -- an FFS intrinsic might be faster but there are so few
  // bits in this implementation that it might not matter.
  //
  if      (state->pool & 1)
    {
      state->pool &= ~1;
      *indices    |=  1;
      return 0;
    }
  else if (state->pool & 2)
    {
      state->pool &= ~2;
      *indices    |=  2;
      return 1;
    }
  else // (state->pool & 4)
    {
      state->pool &= ~4;
      *indices    |=  4;
      return 2;
    }
}

static
void
hs_indices_merge(hs_indices_t * const to, hs_indices_t const from)
{
  *to |= from;
}

static
void
hs_barrier_enqueue(cudaStream_t to, cudaStream_t from)
{
  cudaEvent_t event_before;

  cuda(EventCreate(&event_before));

  cuda(EventRecord(event_before,from));

  cuda(StreamWaitEvent(to,event_before,0));

  cuda(EventDestroy(event_before));
}

static
hs_indices_t
hs_barrier(struct hs_state * const state,
           hs_indices_t      const before,
           hs_indices_t    * const after,
           uint32_t          const count) // count is 1 or 2
{
  // return streams this stage depends on back into the pool
  hs_indices_merge(&state->pool,before);

  hs_indices_t indices = 0;

  // acquire 'count' stream indices for this stage
  for (uint32_t ii=0; ii<count; ii++)
    {
      hs_indices_t new_indices = 0;

      // new index
      uint32_t const idx = hs_state_acquire(state,&new_indices);

      // add the new index to the indices
      indices |= new_indices;

      // only enqueue barriers when streams are different
      uint32_t const wait = before & ~new_indices;

      if (wait != 0)
        {
          cudaStream_t to = state->streams[idx];

          //
          // FIXME -- an FFS loop might be slower for so few bits. So
          // leave it as is for now.
          //
          if (wait & 1)
            hs_barrier_enqueue(to,state->streams[0]);
          if (wait & 2)
            hs_barrier_enqueue(to,state->streams[1]);
          if (wait & 4)
            hs_barrier_enqueue(to,state->streams[2]);
        }
    }

  hs_indices_merge(after,indices);

  return indices;
}
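
//
// An illustrative walk-through: suppose bs_full and bs_frac ran on
// streams 0 and 1, so the next stage calls hs_barrier() with
// before = 0b011 and count = 2.  The pool is topped back up to 0b111,
// the two lowest free indices (0 and 1) are re-acquired, and the only
// barriers enqueued make stream 0 wait on an event recorded on
// stream 1 and stream 1 wait on an event recorded on stream 0; the
// same-stream dependencies are implicit because CUDA streams execute
// in order.
//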

//
//
//

#ifndef NDEBUG

#include <stdio.h>
#define HS_STREAM_SYNCHRONIZE(s)                \
  cuda(StreamSynchronize(s));                   \
  fprintf(stderr,"%s\n",__func__);
#else

#define HS_STREAM_SYNCHRONIZE(s)

#endif

//
//
//

static
void
hs_transpose(struct hs_state * const state)
{
  HS_TRANSPOSE_KERNEL_NAME()
    <<<state->bx_ru,HS_SLAB_THREADS,0,state->streams[0]>>>
    (state->vout);

  HS_STREAM_SYNCHRONIZE(state->streams[0]);
}

//
//
//

static
void
hs_bc(struct hs_state * const state,
      hs_indices_t      const hs_bc,
      hs_indices_t    * const fm,
      uint32_t          const down_slabs,
      uint32_t          const clean_slabs_log2)
{
  // enqueue any necessary barriers
  hs_indices_t indices = hs_barrier(state,hs_bc,fm,1);

  // block clean the minimal number of down_slabs_log2 spans
  uint32_t const frac_ru = (1u << clean_slabs_log2) - 1;
  uint32_t const full    = (down_slabs + frac_ru) >> clean_slabs_log2;
  uint32_t const threads = HS_SLAB_THREADS << clean_slabs_log2;

  // stream will *always* be stream[0]
  cudaStream_t stream  = state->streams[hs_indices_acquire(&indices)];

  hs_kernels_bc[clean_slabs_log2]
    <<<full,threads,0,stream>>>
    (state->vout);

  HS_STREAM_SYNCHRONIZE(stream);
}

//
//
//

static
uint32_t
hs_hm(struct hs_state  * const state,
      hs_indices_t       const hs_bc,
      hs_indices_t     * const hs_bc_tmp,
      uint32_t           const down_slabs,
      uint32_t           const clean_slabs_log2)
{
  // enqueue any necessary barriers
  hs_indices_t   indices    = hs_barrier(state,hs_bc,hs_bc_tmp,1);

  // how many scaled half-merge spans are there?
  uint32_t const frac_ru    = (1 << clean_slabs_log2) - 1;
  uint32_t const spans      = (down_slabs + frac_ru) >> clean_slabs_log2;

  // for now, just clamp to the max
  uint32_t const log2_rem   = clean_slabs_log2 - HS_BC_SLABS_LOG2_MAX;
  uint32_t const scale_log2 = MIN_MACRO(HS_HM_SCALE_MAX,log2_rem);
  uint32_t const log2_out   = log2_rem - scale_log2;

  //
  // Size the grid
  //
  // The simplifying choices below limit the maximum keys that can be
  // sorted with this grid scheme to around ~2B.
  //
  //   .x : slab height << clean_log2  -- this is the slab span
  //   .y : [1...65535]                -- this is the slab index
  //   .z : ( this could also be used to further expand .y )
  //
  // Note that OpenCL declares a grid in terms of global threads and
  // not grids and blocks
  //
  dim3 grid;

  grid.x = (HS_SLAB_HEIGHT / HS_HM_BLOCK_HEIGHT) << log2_out;
  grid.y = spans;
  grid.z = 1;

  cudaStream_t stream = state->streams[hs_indices_acquire(&indices)];

  hs_kernels_hm[scale_log2-HS_HM_SCALE_MIN]
    <<<grid,HS_SLAB_THREADS * HS_HM_BLOCK_HEIGHT,0,stream>>>
    (state->vout);

  HS_STREAM_SYNCHRONIZE(stream);

  return log2_out;
}
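
//
// For illustration only (assuming HS_BC_SLABS_LOG2_MAX = 4 and
// HS_HM_SCALE_MAX = 1): a merged span of 2^7 = 128 slabs
// (clean_slabs_log2 = 7) gives log2_rem = 3, scale_log2 = 1 and
// log2_out = 2, so a single hm pass drops the per-span cleanup to
// 2^2 slabs, which the block cleaner (bc) can then handle directly.
//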

//
// FIXME -- some of this logic can be skipped if BS is a power-of-two
//

static
uint32_t
hs_fm(struct hs_state * const state,
      hs_indices_t      const fm,
      hs_indices_t    * const hs_bc,
      uint32_t        * const down_slabs,
      uint32_t          const up_scale_log2)
{
  //
  // FIXME OPTIMIZATION: in previous HotSort launchers it's sometimes
  // a performance win to bias toward launching the smaller flip merge
  // kernel in order to get more warps in flight (increased
  // occupancy).  This is useful when merging small numbers of slabs.
  //
  // Note that HS_FM_SCALE_MIN will always be 0 or 1.
  //
  // So, for now, just clamp to the max until there is a reason to
  // restore the fancier and probably low-impact approach.
  //
  uint32_t const scale_log2 = MIN_MACRO(HS_FM_SCALE_MAX,up_scale_log2);
  uint32_t const clean_log2 = up_scale_log2 - scale_log2;

  // number of slabs in a full-sized scaled flip-merge span
  uint32_t const full_span_slabs = HS_BS_SLABS << up_scale_log2;

  // how many full-sized scaled flip-merge spans are there?
  uint32_t full_fm = state->bx_ru / full_span_slabs;
  uint32_t frac_fm = 0;

  // initialize down_slabs
  *down_slabs = full_fm * full_span_slabs;

  // how many half-size scaled + fractional scaled spans are there?
  uint32_t const span_rem        = state->bx_ru - *down_slabs;
  uint32_t const half_span_slabs = full_span_slabs >> 1;

  // if we have over a half-span then fractionally merge it
  if (span_rem > half_span_slabs)
    {
      // the remaining slabs will be cleaned
      *down_slabs += span_rem;

      uint32_t const frac_rem      = span_rem - half_span_slabs;
      uint32_t const frac_rem_pow2 = pow2_ru_u32(frac_rem);

      if (frac_rem_pow2 >= half_span_slabs)
        {
          // bump it up to a full span
          full_fm += 1;
        }
      else
        {
          // otherwise, add fractional
          frac_fm  = MAX_MACRO(1,frac_rem_pow2 >> clean_log2);
        }
    }

  // enqueue any necessary barriers
  bool const   both    = (full_fm != 0) && (frac_fm != 0);
  hs_indices_t indices = hs_barrier(state,fm,hs_bc,both ? 2 : 1);

  //
  // Size the grid
  //
  // The simplifying choices below limit the maximum keys that can be
  // sorted with this grid scheme to around ~2B.
  //
  //   .x : slab height << clean_log2  -- this is the slab span
  //   .y : [1...65535]                -- this is the slab index
  //   .z : ( this could also be used to further expand .y )
  //
  // Note that OpenCL declares a grid in terms of global threads and
  // not grids and blocks
  //
  dim3 grid;

  grid.x = (HS_SLAB_HEIGHT / HS_FM_BLOCK_HEIGHT) << clean_log2;
  grid.z = 1;

  if (full_fm > 0)
    {
      cudaStream_t stream = state->streams[hs_indices_acquire(&indices)];

      grid.y = full_fm;

      hs_kernels_fm[scale_log2-HS_FM_SCALE_MIN]
        <<<grid,HS_SLAB_THREADS * HS_FM_BLOCK_HEIGHT,0,stream>>>
          (state->vout);

      HS_STREAM_SYNCHRONIZE(stream);
    }

  if (frac_fm > 0)
    {
      cudaStream_t stream = state->streams[hs_indices_acquire(&indices)];

      grid.y = 1;

      hs_kernels_offset_fm[scale_log2-HS_FM_SCALE_MIN][msb_idx_u32(frac_fm)]
        <<<grid,HS_SLAB_THREADS * HS_FM_BLOCK_HEIGHT,0,stream>>>
        (state->vout,full_fm);

      HS_STREAM_SYNCHRONIZE(stream);
    }

  return clean_log2;
}
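
//
// A worked example (assuming, for illustration, HS_BS_SLABS = 16 and
// HS_FM_SCALE_MAX >= 1): with bx_ru = 56 slabs and up_scale_log2 = 1,
// a full span is 32 slabs, so full_fm = 1 and down_slabs starts at 32.
// The remaining 24 slabs exceed the 16-slab half span, so all 56 slabs
// will be cleaned; frac_rem = 8 is already a power of two and, with
// clean_log2 = 0, the fractional launch uses the offset flip-merge
// kernel with frac_fm = 8 and span_offset = full_fm = 1.
//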

//
//
//

static
void
hs_bs(struct hs_state * const state,
      hs_indices_t      const bs,
      hs_indices_t    * const fm,
      uint32_t          const count_padded_in)
{
  uint32_t const slabs_in = count_padded_in / HS_SLAB_KEYS;
  uint32_t const full_bs  = slabs_in / HS_BS_SLABS;
  uint32_t const frac_bs  = slabs_in - full_bs * HS_BS_SLABS;
  bool     const both     = (full_bs != 0) && (frac_bs != 0);

  // enqueue any necessary barriers
  hs_indices_t   indices  = hs_barrier(state,bs,fm,both ? 2 : 1);

  if (full_bs != 0)
    {
      cudaStream_t stream = state->streams[hs_indices_acquire(&indices)];

      CONCAT_MACRO(hs_kernel_bs_,HS_BS_SLABS_LOG2_RU)
        <<<full_bs,HS_BS_SLABS*HS_SLAB_THREADS,0,stream>>>
        (state->vout,state->vin);

      HS_STREAM_SYNCHRONIZE(stream);
    }

  if (frac_bs != 0)
    {
      cudaStream_t stream = state->streams[hs_indices_acquire(&indices)];

      hs_kernels_offset_bs[msb_idx_u32(frac_bs)]
        <<<1,frac_bs*HS_SLAB_THREADS,0,stream>>>
        (state->vout,state->vin,full_bs*HS_BS_SLABS*HS_SLAB_THREADS);

      HS_STREAM_SYNCHRONIZE(stream);
    }
}
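
//
// A short example (assuming, for illustration, HS_SLAB_KEYS = 1024 and
// HS_BS_SLABS = 16): a count_padded_in of 20480 keys is 20 slabs, so
// full_bs = 1 and frac_bs = 4.  One full block-sort launch covers the
// first 16 slabs and one offset block-sort launch covers the remaining
// 4, starting at an offset of 16 * HS_SLAB_THREADS.
//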

//
//
//

static
void
hs_keyset_pre_merge(struct hs_state * const state,
                    hs_indices_t    * const fm,
                    uint32_t          const count_lo,
                    uint32_t          const count_hi)
{
  uint32_t const vout_span = count_hi - count_lo;
  cudaStream_t   stream    = state->streams[hs_state_acquire(state,fm)];

  cuda(MemsetAsync(state->vout + count_lo,
                   0xFF,
                   vout_span * sizeof(HS_KEY_TYPE),
                   stream));
}

//
//
//

static
void
hs_keyset_pre_sort(struct hs_state * const state,
                   hs_indices_t    * const bs,
                   uint32_t          const count,
                   uint32_t          const count_hi)
{
  uint32_t const vin_span = count_hi - count;
  cudaStream_t   stream   = state->streams[hs_state_acquire(state,bs)];

  cuda(MemsetAsync(state->vin + count,
                   0xFF,
                   vin_span * sizeof(HS_KEY_TYPE),
                   stream));
}

//
//
//

void
CONCAT_MACRO(hs_cuda_sort_,HS_KEY_TYPE_PRETTY)
  (HS_KEY_TYPE * const vin,
   HS_KEY_TYPE * const vout,
   uint32_t      const count,
   uint32_t      const count_padded_in,
   uint32_t      const count_padded_out,
   bool          const linearize,
   cudaStream_t        stream0,  // primary stream
   cudaStream_t        stream1,  // auxiliary
   cudaStream_t        stream2)  // auxiliary
{
  // is this sort in place?
  bool const is_in_place = (vout == NULL);

  // streams, key buffers, stream pool and slab count
  struct hs_state state;

  state.vin        = vin;
  state.vout       = is_in_place ? vin : vout;
  state.streams[0] = stream0;
  state.streams[1] = stream1;
  state.streams[2] = stream2;
  state.pool       = 0x7; // 3 bits
  state.bx_ru      = (count + HS_SLAB_KEYS - 1) / HS_SLAB_KEYS;

  // initialize vin
  uint32_t const count_hi                 = is_in_place ? count_padded_out : count_padded_in;
  bool     const is_pre_sort_keyset_reqd  = count_hi > count;
  bool     const is_pre_merge_keyset_reqd = !is_in_place && (count_padded_out > count_padded_in);

  hs_indices_t bs = 0;

  // initialize any trailing keys in vin before sorting
  if (is_pre_sort_keyset_reqd)
    hs_keyset_pre_sort(&state,&bs,count,count_hi);

  hs_indices_t fm = 0;

  // concurrently initialize any trailing keys in vout before merging
  if (is_pre_merge_keyset_reqd)
    hs_keyset_pre_merge(&state,&fm,count_padded_in,count_padded_out);

  // immediately sort blocks of slabs
  hs_bs(&state,bs,&fm,count_padded_in);

  //
  // we're done if this was a single bs block...
  //
  // otherwise, merge sorted spans of slabs until done
  //
  if (state.bx_ru > HS_BS_SLABS)
    {
      int32_t up_scale_log2 = 1;

      while (true)
        {
          hs_indices_t hs_or_bc = 0;

          uint32_t down_slabs;

          // flip merge slabs -- return span of slabs that must be cleaned
          uint32_t clean_slabs_log2 = hs_fm(&state,
                                            fm,
                                            &hs_or_bc,
                                            &down_slabs,
                                            up_scale_log2);

          // if the span is greater than the largest slab block cleaner then half-merge
          while (clean_slabs_log2 > HS_BC_SLABS_LOG2_MAX)
            {
              hs_indices_t hs_or_bc_tmp;

              clean_slabs_log2 = hs_hm(&state,
                                       hs_or_bc,
                                       &hs_or_bc_tmp,
                                       down_slabs,
                                       clean_slabs_log2);
              hs_or_bc = hs_or_bc_tmp;
            }

          // reset fm
          fm = 0;

          // launch clean slab grid -- is it the final launch?
          hs_bc(&state,
                hs_or_bc,
                &fm,
                down_slabs,
                clean_slabs_log2);

          // was this the final block clean?
          if (((uint32_t)HS_BS_SLABS << up_scale_log2) >= state.bx_ru)
            break;

          // otherwise, merge twice as many slabs
          up_scale_log2 += 1;
        }
    }

  // slabs or linear?
  if (linearize) {
    // guaranteed to be on stream0
    hs_transpose(&state);
  }
}
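
//
// A minimal usage sketch (assuming a 32-bit key build where
// HS_KEY_TYPE_PRETTY expands to u32; error handling omitted and the
// stream/buffer names are illustrative):
//
//   uint32_t padded_in, padded_out;
//
//   hs_cuda_pad_u32(count,&padded_in,&padded_out);
//
//   cudaMalloc((void**)&vin, padded_in  * sizeof(uint32_t));
//   cudaMalloc((void**)&vout,padded_out * sizeof(uint32_t));
//
//   // copy 'count' keys into vin, then:
//   hs_cuda_sort_u32(vin,vout,count,padded_in,padded_out,
//                    true,  // linearize slabs back into a sorted array
//                    stream0,stream1,stream2);
//
// Passing NULL for vout sorts in place, in which case vin should be
// sized to hold padded_out keys.
//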

//
// all grids will be computed as a function of the minimum number of slabs
//

void
CONCAT_MACRO(hs_cuda_pad_,HS_KEY_TYPE_PRETTY)
  (uint32_t   const count,
   uint32_t * const count_padded_in,
   uint32_t * const count_padded_out)
{
  //
  // round up the count to slabs
  //
  uint32_t const slabs_ru        = (count + HS_SLAB_KEYS - 1) / HS_SLAB_KEYS;
  uint32_t const blocks          = slabs_ru / HS_BS_SLABS;
  uint32_t const block_slabs     = blocks * HS_BS_SLABS;
  uint32_t const slabs_ru_rem    = slabs_ru - block_slabs;
  uint32_t const slabs_ru_rem_ru = MIN_MACRO(pow2_ru_u32(slabs_ru_rem),HS_BS_SLABS);

  *count_padded_in  = (block_slabs + slabs_ru_rem_ru) * HS_SLAB_KEYS;
  *count_padded_out = *count_padded_in;

  //
  // will merging be required?
  //
  if (slabs_ru > HS_BS_SLABS)
    {
      // more than one block
      uint32_t const blocks_lo       = pow2_rd_u32(blocks);
      uint32_t const block_slabs_lo  = blocks_lo * HS_BS_SLABS;
      uint32_t const block_slabs_rem = slabs_ru - block_slabs_lo;

      if (block_slabs_rem > 0)
        {
          uint32_t const block_slabs_rem_ru     = pow2_ru_u32(block_slabs_rem);

          uint32_t const block_slabs_hi         = MAX_MACRO(block_slabs_rem_ru,
                                                            blocks_lo << (1 - HS_FM_SCALE_MIN));

          uint32_t const block_slabs_padded_out = MIN_MACRO(block_slabs_lo+block_slabs_hi,
                                                            block_slabs_lo*2); // clamp non-pow2 blocks

          *count_padded_out = block_slabs_padded_out * HS_SLAB_KEYS;
        }
    }
}
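
//
// A worked example (assuming, for illustration, HS_SLAB_KEYS = 1024,
// HS_BS_SLABS = 16 and HS_FM_SCALE_MIN = 1): count = 50000 keys rounds
// up to 49 slabs, i.e. 3 full blocks plus 1 remaining slab, so
// count_padded_in = 49 * 1024 = 50176.  Merging is required, and since
// 3 blocks is not a power of two the output padding grows: blocks_lo = 2,
// block_slabs_lo = 32, block_slabs_rem = 17 rounds up to 32, so
// count_padded_out = 64 * 1024 = 65536 keys.
//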

//
//
//

void
CONCAT_MACRO(hs_cuda_info_,HS_KEY_TYPE_PRETTY)
  (uint32_t * const key_words,
   uint32_t * const val_words,
   uint32_t * const slab_height,
   uint32_t * const slab_width_log2)
{
  *key_words       = HS_KEY_WORDS;
  *val_words       = HS_VAL_WORDS;
  *slab_height     = HS_SLAB_HEIGHT;
  *slab_width_log2 = HS_SLAB_WIDTH_LOG2;
}

//
//
//
