Home | History | Annotate | Download | only in vk
      1 //
      2 // Copyright 2016 Google Inc.
      3 //
      4 // Use of this source code is governed by a BSD-style
      5 // license that can be found in the LICENSE file.
      6 //
      7 
#ifndef HS_GLSL_MACROS_ONCE
#define HS_GLSL_MACROS_ONCE

//
// Define the type based on key and val sizes
//
// HS_KEY_WORDS / HS_VAL_WORDS are 32-bit word counts injected by the
// code generator.  Combinations not covered below leave HS_KEY_TYPE
// undefined -- presumably so an unsupported configuration fails at
// shader-compile time (confirm against the generator).
//

#if   HS_KEY_WORDS == 1
#if   HS_VAL_WORDS == 0
#define HS_KEY_TYPE  uint
#endif
#elif HS_KEY_WORDS == 2       // FIXME -- might want to use uint2
#define HS_KEY_TYPE  uint64_t // GL_ARB_gpu_shader_int64
#endif

//
// FYI, restrict shouldn't have any impact on these kernels and
// benchmarks appear to prove that true
//

#define HS_RESTRICT restrict
     29 
     30 //
     31 //
     32 //
     33 
// Buffer binding qualifier for descriptor binding point n.
#define HS_GLSL_BINDING(n)                      \
  layout( binding = n)

// Compute workgroup dimensions for the kernel being emitted.
#define HS_GLSL_WORKGROUP_SIZE(x,y,z)           \
  layout( local_size_x = x,                     \
          local_size_y = y,                     \
          local_size_z = z) in
     41 
     42 //
     43 // KERNEL PROTOS
     44 //
     45 
// Block sort: sorts slabs read from Vin and writes them to Vout.
// The slab_count_ru_log2 argument is unused in this expansion --
// presumably retained for signature parity with other platform
// backends (verify against the generator).
#define HS_BS_KERNEL_PROTO(slab_count,slab_count_ru_log2)               \
  HS_GLSL_SUBGROUP_SIZE()                                               \
  HS_GLSL_WORKGROUP_SIZE(HS_SLAB_THREADS*slab_count,1,1);               \
  HS_GLSL_BINDING(0) writeonly buffer Vout { HS_KEY_TYPE vout[]; };     \
  HS_GLSL_BINDING(1) readonly  buffer Vin  { HS_KEY_TYPE vin[];  };     \
  void main()

// Block clean: merges slabs in place within Vout.
// slab_count_log2 is unused in this expansion.
#define HS_BC_KERNEL_PROTO(slab_count,slab_count_log2)          \
  HS_GLSL_SUBGROUP_SIZE()                                       \
  HS_GLSL_WORKGROUP_SIZE(HS_SLAB_THREADS*slab_count,1,1);       \
  HS_GLSL_BINDING(0) buffer Vout { HS_KEY_TYPE vout[]; };       \
  void main()

// Flip merge: in-place merge over Vout; (s,r) are unused here.
#define HS_FM_KERNEL_PROTO(s,r)                                 \
  HS_GLSL_SUBGROUP_SIZE()                                       \
  HS_GLSL_WORKGROUP_SIZE(HS_SLAB_THREADS,1,1);                  \
  HS_GLSL_BINDING(0) buffer Vout { HS_KEY_TYPE vout[]; };       \
  void main()

// Half merge: in-place merge over Vout; (s) is unused here.
#define HS_HM_KERNEL_PROTO(s)                                   \
  HS_GLSL_SUBGROUP_SIZE()                                       \
  HS_GLSL_WORKGROUP_SIZE(HS_SLAB_THREADS,1,1);                  \
  HS_GLSL_BINDING(0) buffer Vout { HS_KEY_TYPE vout[]; };       \
  void main()

// Transpose: rewrites sorted slabs in Vout into linear order.
#define HS_TRANSPOSE_KERNEL_PROTO()                             \
  HS_GLSL_SUBGROUP_SIZE()                                       \
  HS_GLSL_WORKGROUP_SIZE(HS_SLAB_THREADS,1,1);                  \
  HS_GLSL_BINDING(0) buffer Vout { HS_KEY_TYPE vout[]; };       \
  void main()
     76 
     77 //
     78 // BLOCK LOCAL MEMORY DECLARATION
     79 //
     80 
// Workgroup-shared scratch used by the merging kernels: a flat array
// of width*height keys, accessed through the HS_*_LOCAL_* macros.
#define HS_BLOCK_LOCAL_MEM_DECL(width,height)   \
  shared struct {                               \
    HS_KEY_TYPE m[width * height];              \
  } smem

//
// BLOCK BARRIER
//
// Workgroup-wide barrier (execution + shared-memory visibility).
//

#define HS_BLOCK_BARRIER()                      \
  barrier()
     92 
     93 //
     94 // SHUFFLES
     95 //
     96 
#if   (HS_KEY_WORDS == 1)
// 32-bit keys shuffle directly -- the casts are identities.
#define HS_SHUFFLE_CAST_TO(v)         v
#define HS_SHUFFLE_CAST_FROM(v)       v
#elif (HS_KEY_WORDS == 2)
// 64-bit keys are round-tripped through double for the subgroup
// shuffles -- presumably because the shuffle builtins lack a uint64_t
// overload on the targets of interest (the bitcasts come from
// GL_ARB_gpu_shader_int64); verify against the target compiler.
#define HS_SHUFFLE_CAST_TO(v)         uint64BitsToDouble(v)
#define HS_SHUFFLE_CAST_FROM(v)       doubleBitsToUint64(v)
#endif

// Value-preserving wrappers around the subgroup shuffle builtins.
#define HS_SUBGROUP_SHUFFLE(v,i)      HS_SHUFFLE_CAST_FROM(subgroupShuffle(HS_SHUFFLE_CAST_TO(v),i))
#define HS_SUBGROUP_SHUFFLE_XOR(v,m)  HS_SHUFFLE_CAST_FROM(subgroupShuffleXor(HS_SHUFFLE_CAST_TO(v),m))
#define HS_SUBGROUP_SHUFFLE_UP(v,d)   HS_SHUFFLE_CAST_FROM(subgroupShuffleUp(HS_SHUFFLE_CAST_TO(v),d))
#define HS_SUBGROUP_SHUFFLE_DOWN(v,d) HS_SHUFFLE_CAST_FROM(subgroupShuffleDown(HS_SHUFFLE_CAST_TO(v),d))
    109 
    110 //
    111 // SLAB GLOBAL
    112 //
    113 
// gmem_idx addresses this invocation's column within its slab:
// the slab base (first lane of the subgroup, scaled by the slab
// height) plus this lane's offset within a slab row.
#define HS_SLAB_GLOBAL_PREAMBLE()                       \
  const uint gmem_idx =                                 \
    (gl_GlobalInvocationID.x & ~(HS_SLAB_THREADS-1)) *  \
    HS_SLAB_HEIGHT +                                    \
    (gl_LocalInvocationID.x  &  (HS_SLAB_THREADS-1))

// Load one slab row from an extent; rows are HS_SLAB_THREADS apart.
#define HS_SLAB_GLOBAL_LOAD(extent,row_idx)  \
  extent[gmem_idx + HS_SLAB_THREADS * row_idx]

// Store one slab row to vout.
#define HS_SLAB_GLOBAL_STORE(row_idx,reg)    \
  vout[gmem_idx + HS_SLAB_THREADS * row_idx] = reg
    125 
    126 //
    127 // SLAB LOCAL
    128 //
    129 
// Left/right shared-memory accessors; smem_l_idx / smem_r_idx are
// declared by the HS_*_MERGE_H_PREAMBLE macros below.
#define HS_SLAB_LOCAL_L(offset)                 \
    smem.m[smem_l_idx + (offset)]

#define HS_SLAB_LOCAL_R(offset)                 \
    smem.m[smem_r_idx + (offset)]

//
// SLAB LOCAL VERTICAL LOADS
//
// Column access into shared memory, indexed by the local invocation.
//

#define HS_BX_LOCAL_V(offset)                   \
  smem.m[gl_LocalInvocationID.x + (offset)]
    142 
    143 //
    144 // BLOCK SORT MERGE HORIZONTAL
    145 //
    146 
// smem_l_idx: this subgroup's slab column in shared memory.
// smem_r_idx: the partner subgroup's slab (subgroup id XOR 1) with
// the lane order reversed (lane id XOR (HS_SLAB_THREADS-1)) so the
// flip compare pairs mirrored elements.
#define HS_BS_MERGE_H_PREAMBLE(slab_count)                      \
  const uint smem_l_idx =                                       \
    HS_SUBGROUP_ID() * (HS_SLAB_THREADS * slab_count) +         \
    HS_SUBGROUP_LANE_ID();                                      \
  const uint smem_r_idx =                                       \
    (HS_SUBGROUP_ID() ^ 1) * (HS_SLAB_THREADS * slab_count) +   \
    (HS_SUBGROUP_LANE_ID() ^ (HS_SLAB_THREADS - 1))

//
// BLOCK CLEAN MERGE HORIZONTAL
//
// gmem_l_idx: base of this workgroup's span of slabs in vout;
// smem_l_idx: this subgroup's slab column in shared memory.
//

#define HS_BC_MERGE_H_PREAMBLE(slab_count)                              \
  const uint gmem_l_idx =                                               \
    (gl_GlobalInvocationID.x & ~(HS_SLAB_THREADS * slab_count -1))      \
    * HS_SLAB_HEIGHT + gl_LocalInvocationID.x;                          \
  const uint smem_l_idx =                                               \
    HS_SUBGROUP_ID() * (HS_SLAB_THREADS * slab_count) +                 \
    HS_SUBGROUP_LANE_ID()

// Row load from the left span; rows are HS_SLAB_THREADS apart.
#define HS_BC_GLOBAL_LOAD_L(slab_idx)        \
  vout[gmem_l_idx + (HS_SLAB_THREADS * slab_idx)]
    169 
    170 //
    171 // SLAB FLIP AND HALF PREAMBLES
    172 //
    173 
#if 0

// Lane-index variants: compute the partner lane explicitly and
// shuffle by index.  Disabled; retained for reference (pairs with
// the HS_SUBGROUP_SHUFFLE form of HS_CMP_FLIP/HS_CMP_HALF).

#define HS_SLAB_FLIP_PREAMBLE(mask)                                     \
  const uint flip_lane_idx = HS_SUBGROUP_LANE_ID() ^ mask;              \
  const bool t_lt          = HS_SUBGROUP_LANE_ID() < flip_lane_idx

#define HS_SLAB_HALF_PREAMBLE(mask)                                     \
  const uint half_lane_idx = HS_SUBGROUP_LANE_ID() ^ mask;              \
  const bool t_lt          = HS_SUBGROUP_LANE_ID() < half_lane_idx

#else

// Active variants: keep the XOR mask for shuffle-xor exchanges.
// t_lt is true for the lane whose id is below its partner's.

#define HS_SLAB_FLIP_PREAMBLE(mask)                                     \
  const uint flip_lane_mask = mask;                                     \
  const bool t_lt           = gl_LocalInvocationID.x < (gl_LocalInvocationID.x ^ mask)

#define HS_SLAB_HALF_PREAMBLE(mask)                                     \
  const uint half_lane_mask = mask;                                     \
  const bool t_lt           = gl_LocalInvocationID.x < (gl_LocalInvocationID.x ^ mask)

#endif
    195 
    196 //
    197 // Inter-lane compare exchange
    198 //
    199 
    200 // best on 32-bit keys
    201 #define HS_CMP_XCHG_V0(a,b)                     \
    202   {                                             \
    203     const HS_KEY_TYPE t = min(a,b);             \
    204     b = max(a,b);                               \
    205     a = t;                                      \
    206   }
    207 
    208 // good on Intel GEN 32-bit keys
    209 #define HS_CMP_XCHG_V1(a,b)                     \
    210   {                                             \
    211     const HS_KEY_TYPE tmp = a;                  \
    212     a  = (a < b) ? a : b;                       \
    213     b ^= a ^ tmp;                               \
    214   }
    215 
    216 // best on 64-bit keys
    217 #define HS_CMP_XCHG_V2(a,b)                     \
    218   if (a >= b) {                                 \
    219     const HS_KEY_TYPE t = a;                    \
    220     a = b;                                      \
    221     b = t;                                      \
    222   }
    223 
    224 // ok
    225 #define HS_CMP_XCHG_V3(a,b)                     \
    226   {                                             \
    227     const bool        ge = a >= b;              \
    228     const HS_KEY_TYPE t  = a;                   \
    229     a = ge ? b : a;                             \
    230     b = ge ? t : b;                             \
    231   }
    232 
    233 //
    234 // The flip/half comparisons rely on a "conditional min/max":
    235 //
//  - if the flag is true, return min(a,b)
//  - otherwise, return max(a,b)
    238 //
    239 // What's a little surprising is that sequence (1) is faster than (2)
    240 // for 32-bit keys.
    241 //
    242 // I suspect either a code generation problem or that the sequence
    243 // maps well to the GEN instruction set.
    244 //
    245 // We mostly care about 64-bit keys and unsurprisingly sequence (2) is
    246 // fastest for this wider type.
    247 //
    248 
    249 #define HS_LOGICAL_XOR()  !=
    250 
    251 // this is what you would normally use
    252 #define HS_COND_MIN_MAX_V0(lt,a,b) ((a <= b) HS_LOGICAL_XOR() lt) ? b : a
    253 
    254 // this seems to be faster for 32-bit keys on Intel GEN
    255 #define HS_COND_MIN_MAX_V1(lt,a,b) (lt ? b : a) ^ ((a ^ b) & HS_LTE_TO_MASK(a,b))
    256 
    257 //
    258 // Conditional inter-subgroup flip/half compare exchange
    259 //
    260 
#if 0

// Lane-index variants: fetch the partner's values by explicit lane
// index (pairs with the disabled flip_lane_idx/half_lane_idx
// preambles).  Disabled; retained for reference.

#define HS_CMP_FLIP(i,a,b)                                              \
  {                                                                     \
    const HS_KEY_TYPE ta = HS_SUBGROUP_SHUFFLE(a,flip_lane_idx);        \
    const HS_KEY_TYPE tb = HS_SUBGROUP_SHUFFLE(b,flip_lane_idx);        \
    a = HS_COND_MIN_MAX(t_lt,a,tb);                                     \
    b = HS_COND_MIN_MAX(t_lt,b,ta);                                     \
  }

#define HS_CMP_HALF(i,a)                                                \
  {                                                                     \
    const HS_KEY_TYPE ta = HS_SUBGROUP_SHUFFLE(a,half_lane_idx);        \
    a = HS_COND_MIN_MAX(t_lt,a,ta);                                     \
  }

#else

// Active variants: exchange values with the partner lane via
// shuffle-xor, then keep the conditional min/max on each side.  The
// i argument is unused in this expansion.  t_lt and the *_lane_mask
// names come from the HS_SLAB_*_PREAMBLE macros.

#define HS_CMP_FLIP(i,a,b)                                              \
  {                                                                     \
    const HS_KEY_TYPE ta = HS_SUBGROUP_SHUFFLE_XOR(a,flip_lane_mask);   \
    const HS_KEY_TYPE tb = HS_SUBGROUP_SHUFFLE_XOR(b,flip_lane_mask);   \
    a = HS_COND_MIN_MAX(t_lt,a,tb);                                     \
    b = HS_COND_MIN_MAX(t_lt,b,ta);                                     \
  }

#define HS_CMP_HALF(i,a)                                                \
  {                                                                     \
    const HS_KEY_TYPE ta = HS_SUBGROUP_SHUFFLE_XOR(a,half_lane_mask);   \
    a = HS_COND_MIN_MAX(t_lt,a,ta);                                     \
  }

#endif
    294 
    295 //
    296 // The device's comparison operator might return what we actually
    297 // want.  For example, it appears GEN 'cmp' returns {true:-1,false:0}.
    298 //
    299 
    300 #define HS_CMP_IS_ZERO_ONE
    301 
    302 #ifdef HS_CMP_IS_ZERO_ONE
    303 // OpenCL requires a {true: +1, false: 0} scalar result
    304 // (a < b) -> { +1, 0 } -> NEGATE -> { 0, 0xFFFFFFFF }
    305 #define HS_LTE_TO_MASK(a,b) (HS_KEY_TYPE)(-(a <= b))
    306 #define HS_CMP_TO_MASK(a)   (HS_KEY_TYPE)(-a)
    307 #else
    308 // However, OpenCL requires { -1, 0 } for vectors
    309 // (a < b) -> { 0xFFFFFFFF, 0 }
    310 #define HS_LTE_TO_MASK(a,b) (a <= b) // FIXME for uint64
    311 #define HS_CMP_TO_MASK(a)   (a)
    312 #endif
    313 
    314 //
    315 // The "flip-merge" and "half-merge" preambles are very similar
    316 //
    317 // For now, we're only using the .y dimension for the span idx
    318 //
    319 
// span_idx    : which merge span this workgroup works on (.y dim)
// span_stride : keys between successive rows of the span
// span_size   : total keys in the span (2 * half_span rows)
// span_l      : this invocation's left-hand element
#define HS_HM_PREAMBLE(half_span)                                       \
  const uint span_idx    = gl_WorkGroupID.y;                            \
  const uint span_stride = gl_NumWorkGroups.x * gl_WorkGroupSize.x;     \
  const uint span_size   = span_stride * half_span * 2;                 \
  const uint span_base   = span_idx * span_size;                        \
  const uint span_off    = gl_GlobalInvocationID.x;                     \
  const uint span_l      = span_base + span_off

// The flip merge additionally addresses the mirrored right-hand
// element of the span (span_r).
#define HS_FM_PREAMBLE(half_span)                                       \
  HS_HM_PREAMBLE(half_span);                                            \
  const uint span_r      = span_base + span_stride * (half_span + 1) - span_off - 1
    331 
    332 //
    333 //
    334 //
    335 
    336 #define HS_XM_GLOBAL_L(stride_idx)              \
    337   vout[span_l + span_stride * stride_idx]
    338 
    339 #define HS_XM_GLOBAL_LOAD_L(stride_idx)         \
    340   HS_XM_GLOBAL_L(stride_idx)
    341 
    342 #define HS_XM_GLOBAL_STORE_L(stride_idx,reg)    \
    343   HS_XM_GLOBAL_L(stride_idx) = reg
    344 
    345 #define HS_FM_GLOBAL_R(stride_idx)              \
    346   vout[span_r + span_stride * stride_idx]
    347 
    348 #define HS_FM_GLOBAL_LOAD_R(stride_idx)         \
    349   HS_FM_GLOBAL_R(stride_idx)
    350 
    351 #define HS_FM_GLOBAL_STORE_R(stride_idx,reg)    \
    352   HS_FM_GLOBAL_R(stride_idx) = reg
    353 
    354 //
    355 // This snarl of macros is for transposing a "slab" of sorted elements
    356 // into linear order.
    357 //
    358 // This can occur as the last step in hs_sort() or via a custom kernel
    359 // that inspects the slab and then transposes and stores it to memory.
    360 //
    361 // The slab format can be inspected more efficiently than a linear
    362 // arrangement.
    363 //
    364 // The prime example is detecting when adjacent keys (in sort order)
    365 // have differing high order bits ("key changes").  The index of each
    366 // change is recorded to an auxilary array.
    367 //
    368 // A post-processing step like this needs to be able to navigate the
    369 // slab and eventually transpose and store the slab in linear order.
    370 //
    371 
// Token-pasted register name, e.g. (r,3) -> r3.
#define HS_TRANSPOSE_REG(prefix,row)   prefix##row
// Declaration of a transpose register.
#define HS_TRANSPOSE_DECL(prefix,row)  const HS_KEY_TYPE HS_TRANSPOSE_REG(prefix,row)
// Per-stage predicate name, e.g. (2) -> is_lo_2.
#define HS_TRANSPOSE_PRED(level)       is_lo_##level

// Temporary register carrying the value exchanged at one stage.
#define HS_TRANSPOSE_TMP_REG(prefix_curr,row_ll,row_ur)       \
  prefix_curr##row_ll##_##row_ur

#define HS_TRANSPOSE_TMP_DECL(prefix_curr,row_ll,row_ur)      \
  const HS_KEY_TYPE HS_TRANSPOSE_TMP_REG(prefix_curr,row_ll,row_ur)

// The stage predicate is true for lanes whose bit (level-1) is clear,
// i.e. the "low" half of each 2^level-wide exchange group.
#define HS_TRANSPOSE_STAGE(level)                       \
  const bool HS_TRANSPOSE_PRED(level) =                 \
    (HS_SUBGROUP_LANE_ID() & (1 << (level-1))) == 0;

// One butterfly step of the transpose: ship the off-diagonal register
// to the partner lane at distance 1<<(level-1) via shuffle-xor, then
// select which of the kept/received values lands in each of the two
// new registers for this stage.
#define HS_TRANSPOSE_BLEND(prefix_prev,prefix_curr,level,row_ll,row_ur) \
  HS_TRANSPOSE_TMP_DECL(prefix_curr,row_ll,row_ur) =                    \
    HS_SUBGROUP_SHUFFLE_XOR(HS_TRANSPOSE_PRED(level) ?                  \
                            HS_TRANSPOSE_REG(prefix_prev,row_ll) :      \
                            HS_TRANSPOSE_REG(prefix_prev,row_ur),       \
                            1<<(level-1));                              \
                                                                        \
  HS_TRANSPOSE_DECL(prefix_curr,row_ll) =                               \
    HS_TRANSPOSE_PRED(level)                  ?                         \
    HS_TRANSPOSE_TMP_REG(prefix_curr,row_ll,row_ur) :                   \
    HS_TRANSPOSE_REG(prefix_prev,row_ll);                               \
                                                                        \
  HS_TRANSPOSE_DECL(prefix_curr,row_ur) =                               \
    HS_TRANSPOSE_PRED(level)                  ?                         \
    HS_TRANSPOSE_REG(prefix_prev,row_ur)      :                         \
    HS_TRANSPOSE_TMP_REG(prefix_curr,row_ll,row_ur);

// Store a fully-transposed register to its linear slot in vout.
#define HS_TRANSPOSE_REMAP(prefix,row_from,row_to)      \
  vout[gmem_idx + ((row_to-1) << HS_SLAB_WIDTH_LOG2)] = \
    HS_TRANSPOSE_REG(prefix,row_from);
    406 
    407 //
    408 //
    409 //
    410 
    411 #endif
    412 
    413 //
    414 //
    415 //
    416