Home | History | Annotate | Download | only in cl_kernel
      1 /*
      2  * function: kernel_3d_denoise
      3  *     3D Noise Reduction
      4  * gain:        The parameter determines the filtering strength for the reference block
      5  * threshold:   Noise variances of observed image
      6  * restoredPrev: The previous restored image, image2d_t as read only
      7  * output:      restored image, image2d_t as write only
      8  * input:       observed image, image2d_t as read only
      9  * inputPrev1:  reference image, image2d_t as read only
     10  * inputPrev2:  reference image, image2d_t as read only
     11  */
     12 
     13 #ifndef REFERENCE_FRAME_COUNT
     14 #define REFERENCE_FRAME_COUNT 2
     15 #endif
     16 
     17 #ifndef ENABLE_IIR_FILERING
     18 #define ENABLE_IIR_FILERING 1
     19 #endif
     20 
     21 #define ENABLE_GRADIENT     1
     22 
     23 #ifndef WORKGROUP_WIDTH
     24 #define WORKGROUP_WIDTH    2
     25 #endif
     26 
     27 #ifndef WORKGROUP_HEIGHT
     28 #define WORKGROUP_HEIGHT   32
     29 #endif
     30 
     31 #define REF_BLOCK_X_OFFSET  1
     32 #define REF_BLOCK_Y_OFFSET  4
     33 
     34 #define REF_BLOCK_WIDTH     (WORKGROUP_WIDTH + 2 * REF_BLOCK_X_OFFSET)
     35 #define REF_BLOCK_HEIGHT    (WORKGROUP_HEIGHT + 2 * REF_BLOCK_Y_OFFSET)
     36 
     37 inline int2 subgroup_pos(const int sg_id, const int sg_lid)
     38 {
     39     int2 pos;
     40     pos.x = mad24(2, sg_id % 2, sg_lid % 2);
     41     pos.y = mad24(4, sg_id / 2, sg_lid / 2);
     42 
     43     return pos;
     44 }
     45 
     46 inline void average_slice(float8 ref,
     47                           float8 observe,
     48                           float8* restore,
     49                           float2* sum_weight,
     50                           float gain,
     51                           float threshold,
     52                           uint sg_id,
     53                           uint sg_lid)
     54 {
     55     float8 grad = 0.0f;
     56     float8 gradient = 0.0f;
     57     float8 dist = 0.0f;
     58     float8 distance = 0.0f;
     59     float weight = 0.0f;
     60 
     61 #if ENABLE_GRADIENT
     62     // calculate & cumulate gradient
     63     if (sg_lid % 2 == 0) {
     64         grad = intel_sub_group_shuffle(ref, 4);
     65     } else {
     66         grad = intel_sub_group_shuffle(ref, 5);
     67     }
     68     gradient = (float8)(grad.s1, grad.s1, grad.s1, grad.s1, grad.s5, grad.s5, grad.s5, grad.s5);
     69 
     70     // normalize gradient "1/(4*255.0f) = 0.00098039f"
     71     grad = fabs(gradient - ref) * 0.00098039f;
     72     //grad = mad(-2, gradient, (ref + grad)) * 0.0004902f;
     73 
     74     grad.s0 = (grad.s0 + grad.s1 + grad.s2 + grad.s3);
     75     grad.s4 = (grad.s4 + grad.s5 + grad.s6 + grad.s7);
     76 #endif
     77     // calculate & normalize distance "1/255.0f = 0.00392157f"
     78     dist = (observe - ref) * 0.00392157f;
     79     dist = dist * dist;
     80 
     81     float8 dist_shuffle[8];
     82     dist_shuffle[0] = (intel_sub_group_shuffle(dist, 0));
     83     dist_shuffle[1] = (intel_sub_group_shuffle(dist, 1));
     84     dist_shuffle[2] = (intel_sub_group_shuffle(dist, 2));
     85     dist_shuffle[3] = (intel_sub_group_shuffle(dist, 3));
     86     dist_shuffle[4] = (intel_sub_group_shuffle(dist, 4));
     87     dist_shuffle[5] = (intel_sub_group_shuffle(dist, 5));
     88     dist_shuffle[6] = (intel_sub_group_shuffle(dist, 6));
     89     dist_shuffle[7] = (intel_sub_group_shuffle(dist, 7));
     90 
     91     if (sg_lid % 2 == 0) {
     92         distance = dist_shuffle[0];
     93         distance += dist_shuffle[2];
     94         distance += dist_shuffle[4];
     95         distance += dist_shuffle[6];
     96     }
     97     else {
     98         distance = dist_shuffle[1];
     99         distance += dist_shuffle[3];
    100         distance += dist_shuffle[5];
    101         distance += dist_shuffle[7];
    102     }
    103 
    104     // cumulate distance
    105     dist.s0 = (distance.s0 + distance.s1 + distance.s2 + distance.s3);
    106     dist.s4 = (distance.s4 + distance.s5 + distance.s6 + distance.s7);
    107     gain = (grad.s0 < threshold) ? gain : 2.0f * gain;
    108     weight = native_exp(-gain * dist.s0);
    109     (*restore).lo = mad(weight, ref.lo, (*restore).lo);
    110     (*sum_weight).lo = (*sum_weight).lo + weight;
    111 
    112     gain = (grad.s4 < threshold) ? gain : 2.0f * gain;
    113     weight = native_exp(-gain * dist.s4);
    114     (*restore).hi = mad(weight, ref.hi, (*restore).hi);
    115     (*sum_weight).hi = (*sum_weight).hi + weight;
    116 }
    117 
    118 inline void weighted_average (__read_only image2d_t input,
    119                               __local uchar8* ref_cache,
    120                               bool load_observe,
    121                               float8* observe,
    122                               float8* restore,
    123                               float2* sum_weight,
    124                               float gain,
    125                               float threshold,
    126                               uint sg_id,
    127                               uint sg_lid)
    128 {
    129     sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP_TO_EDGE | CLK_FILTER_NEAREST;
    130 
    131     int local_id_x = get_local_id(0);
    132     int local_id_y = get_local_id(1);
    133     const int group_id_x = get_group_id(0);
    134     const int group_id_y = get_group_id(1);
    135 
    136     int start_x = mad24(group_id_x, WORKGROUP_WIDTH, -REF_BLOCK_X_OFFSET);
    137     int start_y = mad24(group_id_y, WORKGROUP_HEIGHT, -REF_BLOCK_Y_OFFSET);
    138 
    139     int i = local_id_x + local_id_y * WORKGROUP_WIDTH;
    140     for ( int j = i; j < (REF_BLOCK_HEIGHT * REF_BLOCK_WIDTH);
    141             j += (WORKGROUP_HEIGHT * WORKGROUP_WIDTH) ) {
    142         int corrd_x = start_x + (j % REF_BLOCK_WIDTH);
    143         int corrd_y = start_y + (j / REF_BLOCK_WIDTH);
    144 
    145         ref_cache[j] = as_uchar8( convert_ushort4(read_imageui(input,
    146                                   sampler,
    147                                   (int2)(corrd_x, corrd_y))));
    148     }
    149     barrier(CLK_LOCAL_MEM_FENCE);
    150 
    151 #if WORKGROUP_WIDTH == 4
    152     int2 pos = subgroup_pos(sg_id, sg_lid);
    153     local_id_x = pos.x;
    154     local_id_y = pos.y;
    155 #endif
    156 
    157     if (load_observe) {
    158         (*observe) = convert_float8(
    159                          ref_cache[mad24(local_id_y + REF_BLOCK_Y_OFFSET,
    160                                          REF_BLOCK_WIDTH,
    161                                          local_id_x + REF_BLOCK_X_OFFSET)]);
    162         (*restore) = (*observe);
    163         (*sum_weight) = 1.0f;
    164     }
    165 
    166     float8 ref[2] = {0.0f, 0.0f};
    167     __local uchar4* p_ref = (__local uchar4*)(ref_cache);
    168 
    169     // top-left
    170     ref[0] = convert_float8(*(__local uchar8*)(p_ref + mad24(local_id_y,
    171                             2 * REF_BLOCK_WIDTH,
    172                             mad24(2, local_id_x, 1))));
    173     average_slice(ref[0], *observe, restore, sum_weight, gain, threshold, sg_id, sg_lid);
    174 
    175     // top-right
    176     ref[1] = convert_float8(*(__local uchar8*)(p_ref + mad24(local_id_y,
    177                             2 * REF_BLOCK_WIDTH,
    178                             mad24(2, local_id_x, 3))));
    179     average_slice(ref[1], *observe, restore, sum_weight, gain, threshold, sg_id, sg_lid);
    180 
    181     // top-mid
    182     average_slice((float8)(ref[0].hi, ref[1].lo), *observe, restore, sum_weight, gain, threshold, sg_id, sg_lid);
    183 
    184     // mid-left
    185     ref[0] = convert_float8(*(__local uchar8*)(p_ref + mad24((local_id_y + 4),
    186                             2 * REF_BLOCK_WIDTH,
    187                             mad24(2, local_id_x, 1))));
    188     average_slice(ref[0], *observe, restore, sum_weight, gain, threshold, sg_id, sg_lid);
    189 
    190     // mid-right
    191     ref[1] = convert_float8(*(__local uchar8*)(p_ref + mad24((local_id_y + 4),
    192                             2 * REF_BLOCK_WIDTH,
    193                             mad24(2, local_id_x, 3))));
    194     average_slice(ref[1], *observe, restore, sum_weight, gain, threshold, sg_id, sg_lid);
    195 
    196     // mid-mid
    197     if (!load_observe) {
    198         average_slice((float8)(ref[0].hi, ref[1].lo), *observe, restore, sum_weight, gain, threshold, sg_id, sg_lid);
    199     }
    200 
    201     // bottom-left
    202     ref[0] = convert_float8(*(__local uchar8*)(p_ref + mad24((local_id_y + 8),
    203                             2 * REF_BLOCK_WIDTH,
    204                             mad24(2, local_id_x, 1))));
    205     average_slice(ref[0], *observe, restore, sum_weight, gain, threshold, sg_id, sg_lid);
    206 
    207     // bottom-right
    208     ref[1] = convert_float8(*(__local uchar8*)(p_ref + mad24((local_id_y + 8),
    209                             2 * REF_BLOCK_WIDTH,
    210                             mad24(2, local_id_x, 3))));
    211     average_slice(ref[1], *observe, restore, sum_weight, gain, threshold, sg_id, sg_lid);
    212 
    213     // bottom-mid
    214     average_slice((float8)(ref[0].hi, ref[1].lo), *observe, restore, sum_weight, gain, threshold, sg_id, sg_lid);
    215 }
    216 
    217 __kernel void kernel_3d_denoise ( float gain,
    218                                   float threshold,
    219                                   __read_only image2d_t restoredPrev,
    220                                   __write_only image2d_t output,
    221                                   __read_only image2d_t input,
    222                                   __read_only image2d_t inputPrev1,
    223                                   __read_only image2d_t inputPrev2)
    224 {
    225     float8 restore = 0.0f;
    226     float8 observe = 0.0f;
    227     float2 sum_weight = 0.0f;
    228 
    229     const int sg_id = get_sub_group_id();
    230     const int sg_lid = (get_local_id(1) * WORKGROUP_WIDTH + get_local_id(0)) % 8;
    231 
    232     __local uchar8 ref_cache[REF_BLOCK_HEIGHT * REF_BLOCK_WIDTH];
    233 
    234     weighted_average (input, ref_cache, true, &observe, &restore, &sum_weight, gain, threshold, sg_id, sg_lid);
    235 
    236 #if ENABLE_IIR_FILERING
    237     weighted_average (restoredPrev, ref_cache, false, &observe, &restore, &sum_weight, gain, threshold, sg_id, sg_lid);
    238 #else
    239 #if REFERENCE_FRAME_COUNT > 1
    240     weighted_average (inputPrev1, ref_cache, false, &observe, &restore, &sum_weight, gain, threshold, sg_id, sg_lid);
    241 #endif
    242 
    243 #if REFERENCE_FRAME_COUNT > 2
    244     weighted_average (inputPrev2, ref_cache, false, &observe, &restore, &sum_weight, gain, threshold, sg_id, sg_lid);
    245 #endif
    246 #endif
    247 
    248     restore.lo = restore.lo / sum_weight.lo;
    249     restore.hi = restore.hi / sum_weight.hi;
    250 
    251     int local_id_x = get_local_id(0);
    252     int local_id_y = get_local_id(1);
    253     const int group_id_x = get_group_id(0);
    254     const int group_id_y = get_group_id(1);
    255 
    256 #if WORKGROUP_WIDTH == 4
    257     int2 pos = subgroup_pos(sg_id, sg_lid);
    258     local_id_x = pos.x;
    259     local_id_y = pos.y;
    260 #endif
    261 
    262     int coor_x = mad24(group_id_x, WORKGROUP_WIDTH, local_id_x);
    263     int coor_y = mad24(group_id_y, WORKGROUP_HEIGHT, local_id_y);
    264 
    265     write_imageui(output,
    266                   (int2)(coor_x, coor_y),
    267                   convert_uint4(as_ushort4(convert_uchar8(restore))));
    268 }
    269 
    270