Home | History | Annotate | Download | only in neon
      1 /*
      2  *  Copyright (c) 2012 The WebM project authors. All Rights Reserved.
      3  *
      4  *  Use of this source code is governed by a BSD-style license
      5  *  that can be found in the LICENSE file in the root of the source
      6  *  tree. An additional intellectual property rights grant can be found
      7  *  in the file PATENTS.  All contributing project authors may
      8  *  be found in the AUTHORS file in the root of the source tree.
      9  */
     10 
     11 #include <arm_neon.h>
     12 
     13 #include "vp8/encoder/denoising.h"
     14 #include "vpx_mem/vpx_mem.h"
     15 #include "./vp8_rtcd.h"
     16 
     17 /*
     18  * The filter function was modified to reduce the computational complexity.
     19  *
     20  * Step 1:
     21  *  Instead of applying tap coefficients for each pixel, we calculated the
     22  *  pixel adjustments vs. pixel diff value ahead of time.
     23  *     adjustment = filtered_value - current_raw
     24  *                = (filter_coefficient * diff + 128) >> 8
     25  *  where
     26  *     filter_coefficient = (255 << 8) / (256 + ((abs_diff * 330) >> 3));
     27  *     filter_coefficient += filter_coefficient /
     28  *                           (3 + motion_magnitude_adjustment);
     29  *     filter_coefficient is clamped to 0 ~ 255.
     30  *
     31  * Step 2:
     32  *  The adjustment vs. diff curve flattens very quickly as diff increases.
     33  *  This allows us to approximate the curve with only a few levels without
     34  *  changing the filtering algorithm too much.
     35  *  The adjustments were further corrected by checking the motion magnitude.
     36  *  The levels used are:
     37  *      diff          level       adjustment w/o       adjustment w/
     38  *                               motion correction    motion correction
     39  *      [-255, -16]     3              -6                   -7
     40  *      [-15, -8]       2              -4                   -5
     41  *      [-7, -4]        1              -3                   -4
     42  *      [-3, 3]         0              diff                 diff
     43  *      [4, 7]          1               3                    4
     44  *      [8, 15]         2               4                    5
     45  *      [16, 255]       3               6                    7
     46  */
     47 
     48 int vp8_denoiser_filter_neon(unsigned char *mc_running_avg_y,
     49                              int mc_running_avg_y_stride,
     50                              unsigned char *running_avg_y,
     51                              int running_avg_y_stride, unsigned char *sig,
     52                              int sig_stride, unsigned int motion_magnitude,
     53                              int increase_denoising) {
     54   /* If motion_magnitude is small, making the denoiser more aggressive by
     55    * increasing the adjustment for each level, level1 adjustment is
     56    * increased, the deltas stay the same.
     57    */
     58   int shift_inc =
     59       (increase_denoising && motion_magnitude <= MOTION_MAGNITUDE_THRESHOLD)
     60           ? 1
     61           : 0;
     62   const uint8x16_t v_level1_adjustment = vmovq_n_u8(
     63       (motion_magnitude <= MOTION_MAGNITUDE_THRESHOLD) ? 4 + shift_inc : 3);
     64   const uint8x16_t v_delta_level_1_and_2 = vdupq_n_u8(1);
     65   const uint8x16_t v_delta_level_2_and_3 = vdupq_n_u8(2);
     66   const uint8x16_t v_level1_threshold = vmovq_n_u8(4 + shift_inc);
     67   const uint8x16_t v_level2_threshold = vdupq_n_u8(8);
     68   const uint8x16_t v_level3_threshold = vdupq_n_u8(16);
     69   int64x2_t v_sum_diff_total = vdupq_n_s64(0);
     70 
     71   /* Go over lines. */
     72   int r;
     73   for (r = 0; r < 16; ++r) {
     74     /* Load inputs. */
     75     const uint8x16_t v_sig = vld1q_u8(sig);
     76     const uint8x16_t v_mc_running_avg_y = vld1q_u8(mc_running_avg_y);
     77 
     78     /* Calculate absolute difference and sign masks. */
     79     const uint8x16_t v_abs_diff = vabdq_u8(v_sig, v_mc_running_avg_y);
     80     const uint8x16_t v_diff_pos_mask = vcltq_u8(v_sig, v_mc_running_avg_y);
     81     const uint8x16_t v_diff_neg_mask = vcgtq_u8(v_sig, v_mc_running_avg_y);
     82 
     83     /* Figure out which level that put us in. */
     84     const uint8x16_t v_level1_mask = vcleq_u8(v_level1_threshold, v_abs_diff);
     85     const uint8x16_t v_level2_mask = vcleq_u8(v_level2_threshold, v_abs_diff);
     86     const uint8x16_t v_level3_mask = vcleq_u8(v_level3_threshold, v_abs_diff);
     87 
     88     /* Calculate absolute adjustments for level 1, 2 and 3. */
     89     const uint8x16_t v_level2_adjustment =
     90         vandq_u8(v_level2_mask, v_delta_level_1_and_2);
     91     const uint8x16_t v_level3_adjustment =
     92         vandq_u8(v_level3_mask, v_delta_level_2_and_3);
     93     const uint8x16_t v_level1and2_adjustment =
     94         vaddq_u8(v_level1_adjustment, v_level2_adjustment);
     95     const uint8x16_t v_level1and2and3_adjustment =
     96         vaddq_u8(v_level1and2_adjustment, v_level3_adjustment);
     97 
     98     /* Figure adjustment absolute value by selecting between the absolute
     99      * difference if in level0 or the value for level 1, 2 and 3.
    100      */
    101     const uint8x16_t v_abs_adjustment =
    102         vbslq_u8(v_level1_mask, v_level1and2and3_adjustment, v_abs_diff);
    103 
    104     /* Calculate positive and negative adjustments. Apply them to the signal
    105      * and accumulate them. Adjustments are less than eight and the maximum
    106      * sum of them (7 * 16) can fit in a signed char.
    107      */
    108     const uint8x16_t v_pos_adjustment =
    109         vandq_u8(v_diff_pos_mask, v_abs_adjustment);
    110     const uint8x16_t v_neg_adjustment =
    111         vandq_u8(v_diff_neg_mask, v_abs_adjustment);
    112 
    113     uint8x16_t v_running_avg_y = vqaddq_u8(v_sig, v_pos_adjustment);
    114     v_running_avg_y = vqsubq_u8(v_running_avg_y, v_neg_adjustment);
    115 
    116     /* Store results. */
    117     vst1q_u8(running_avg_y, v_running_avg_y);
    118 
    119     /* Sum all the accumulators to have the sum of all pixel differences
    120      * for this macroblock.
    121      */
    122     {
    123       const int8x16_t v_sum_diff =
    124           vqsubq_s8(vreinterpretq_s8_u8(v_pos_adjustment),
    125                     vreinterpretq_s8_u8(v_neg_adjustment));
    126 
    127       const int16x8_t fe_dc_ba_98_76_54_32_10 = vpaddlq_s8(v_sum_diff);
    128 
    129       const int32x4_t fedc_ba98_7654_3210 =
    130           vpaddlq_s16(fe_dc_ba_98_76_54_32_10);
    131 
    132       const int64x2_t fedcba98_76543210 = vpaddlq_s32(fedc_ba98_7654_3210);
    133 
    134       v_sum_diff_total = vqaddq_s64(v_sum_diff_total, fedcba98_76543210);
    135     }
    136 
    137     /* Update pointers for next iteration. */
    138     sig += sig_stride;
    139     mc_running_avg_y += mc_running_avg_y_stride;
    140     running_avg_y += running_avg_y_stride;
    141   }
    142 
    143   /* Too much adjustments => copy block. */
    144   {
    145     int64x1_t x = vqadd_s64(vget_high_s64(v_sum_diff_total),
    146                             vget_low_s64(v_sum_diff_total));
    147     int sum_diff = vget_lane_s32(vabs_s32(vreinterpret_s32_s64(x)), 0);
    148     int sum_diff_thresh = SUM_DIFF_THRESHOLD;
    149 
    150     if (increase_denoising) sum_diff_thresh = SUM_DIFF_THRESHOLD_HIGH;
    151     if (sum_diff > sum_diff_thresh) {
    152       // Before returning to copy the block (i.e., apply no denoising),
    153       // checK if we can still apply some (weaker) temporal filtering to
    154       // this block, that would otherwise not be denoised at all. Simplest
    155       // is to apply an additional adjustment to running_avg_y to bring it
    156       // closer to sig. The adjustment is capped by a maximum delta, and
    157       // chosen such that in most cases the resulting sum_diff will be
    158       // within the accceptable range given by sum_diff_thresh.
    159 
    160       // The delta is set by the excess of absolute pixel diff over the
    161       // threshold.
    162       int delta = ((sum_diff - sum_diff_thresh) >> 8) + 1;
    163       // Only apply the adjustment for max delta up to 3.
    164       if (delta < 4) {
    165         const uint8x16_t k_delta = vmovq_n_u8(delta);
    166         sig -= sig_stride * 16;
    167         mc_running_avg_y -= mc_running_avg_y_stride * 16;
    168         running_avg_y -= running_avg_y_stride * 16;
    169         for (r = 0; r < 16; ++r) {
    170           uint8x16_t v_running_avg_y = vld1q_u8(running_avg_y);
    171           const uint8x16_t v_sig = vld1q_u8(sig);
    172           const uint8x16_t v_mc_running_avg_y = vld1q_u8(mc_running_avg_y);
    173 
    174           /* Calculate absolute difference and sign masks. */
    175           const uint8x16_t v_abs_diff = vabdq_u8(v_sig, v_mc_running_avg_y);
    176           const uint8x16_t v_diff_pos_mask =
    177               vcltq_u8(v_sig, v_mc_running_avg_y);
    178           const uint8x16_t v_diff_neg_mask =
    179               vcgtq_u8(v_sig, v_mc_running_avg_y);
    180           // Clamp absolute difference to delta to get the adjustment.
    181           const uint8x16_t v_abs_adjustment = vminq_u8(v_abs_diff, (k_delta));
    182 
    183           const uint8x16_t v_pos_adjustment =
    184               vandq_u8(v_diff_pos_mask, v_abs_adjustment);
    185           const uint8x16_t v_neg_adjustment =
    186               vandq_u8(v_diff_neg_mask, v_abs_adjustment);
    187 
    188           v_running_avg_y = vqsubq_u8(v_running_avg_y, v_pos_adjustment);
    189           v_running_avg_y = vqaddq_u8(v_running_avg_y, v_neg_adjustment);
    190 
    191           /* Store results. */
    192           vst1q_u8(running_avg_y, v_running_avg_y);
    193 
    194           {
    195             const int8x16_t v_sum_diff =
    196                 vqsubq_s8(vreinterpretq_s8_u8(v_neg_adjustment),
    197                           vreinterpretq_s8_u8(v_pos_adjustment));
    198 
    199             const int16x8_t fe_dc_ba_98_76_54_32_10 = vpaddlq_s8(v_sum_diff);
    200             const int32x4_t fedc_ba98_7654_3210 =
    201                 vpaddlq_s16(fe_dc_ba_98_76_54_32_10);
    202             const int64x2_t fedcba98_76543210 =
    203                 vpaddlq_s32(fedc_ba98_7654_3210);
    204 
    205             v_sum_diff_total = vqaddq_s64(v_sum_diff_total, fedcba98_76543210);
    206           }
    207           /* Update pointers for next iteration. */
    208           sig += sig_stride;
    209           mc_running_avg_y += mc_running_avg_y_stride;
    210           running_avg_y += running_avg_y_stride;
    211         }
    212         {
    213           // Update the sum of all pixel differences of this MB.
    214           x = vqadd_s64(vget_high_s64(v_sum_diff_total),
    215                         vget_low_s64(v_sum_diff_total));
    216           sum_diff = vget_lane_s32(vabs_s32(vreinterpret_s32_s64(x)), 0);
    217 
    218           if (sum_diff > sum_diff_thresh) {
    219             return COPY_BLOCK;
    220           }
    221         }
    222       } else {
    223         return COPY_BLOCK;
    224       }
    225     }
    226   }
    227 
    228   /* Tell above level that block was filtered. */
    229   running_avg_y -= running_avg_y_stride * 16;
    230   sig -= sig_stride * 16;
    231 
    232   vp8_copy_mem16x16(running_avg_y, running_avg_y_stride, sig, sig_stride);
    233 
    234   return FILTER_BLOCK;
    235 }
    236 
    237 int vp8_denoiser_filter_uv_neon(unsigned char *mc_running_avg,
    238                                 int mc_running_avg_stride,
    239                                 unsigned char *running_avg,
    240                                 int running_avg_stride, unsigned char *sig,
    241                                 int sig_stride, unsigned int motion_magnitude,
    242                                 int increase_denoising) {
    243   /* If motion_magnitude is small, making the denoiser more aggressive by
    244    * increasing the adjustment for each level, level1 adjustment is
    245    * increased, the deltas stay the same.
    246    */
    247   int shift_inc =
    248       (increase_denoising && motion_magnitude <= MOTION_MAGNITUDE_THRESHOLD_UV)
    249           ? 1
    250           : 0;
    251   const uint8x16_t v_level1_adjustment = vmovq_n_u8(
    252       (motion_magnitude <= MOTION_MAGNITUDE_THRESHOLD_UV) ? 4 + shift_inc : 3);
    253 
    254   const uint8x16_t v_delta_level_1_and_2 = vdupq_n_u8(1);
    255   const uint8x16_t v_delta_level_2_and_3 = vdupq_n_u8(2);
    256   const uint8x16_t v_level1_threshold = vmovq_n_u8(4 + shift_inc);
    257   const uint8x16_t v_level2_threshold = vdupq_n_u8(8);
    258   const uint8x16_t v_level3_threshold = vdupq_n_u8(16);
    259   int64x2_t v_sum_diff_total = vdupq_n_s64(0);
    260   int r;
    261 
    262   {
    263     uint16x4_t v_sum_block = vdup_n_u16(0);
    264 
    265     // Avoid denoising color signal if its close to average level.
    266     for (r = 0; r < 8; ++r) {
    267       const uint8x8_t v_sig = vld1_u8(sig);
    268       const uint16x4_t _76_54_32_10 = vpaddl_u8(v_sig);
    269       v_sum_block = vqadd_u16(v_sum_block, _76_54_32_10);
    270       sig += sig_stride;
    271     }
    272     sig -= sig_stride * 8;
    273     {
    274       const uint32x2_t _7654_3210 = vpaddl_u16(v_sum_block);
    275       const uint64x1_t _76543210 = vpaddl_u32(_7654_3210);
    276       const int sum_block = vget_lane_s32(vreinterpret_s32_u64(_76543210), 0);
    277       if (abs(sum_block - (128 * 8 * 8)) < SUM_DIFF_FROM_AVG_THRESH_UV) {
    278         return COPY_BLOCK;
    279       }
    280     }
    281   }
    282 
    283   /* Go over lines. */
    284   for (r = 0; r < 4; ++r) {
    285     /* Load inputs. */
    286     const uint8x8_t v_sig_lo = vld1_u8(sig);
    287     const uint8x8_t v_sig_hi = vld1_u8(&sig[sig_stride]);
    288     const uint8x16_t v_sig = vcombine_u8(v_sig_lo, v_sig_hi);
    289     const uint8x8_t v_mc_running_avg_lo = vld1_u8(mc_running_avg);
    290     const uint8x8_t v_mc_running_avg_hi =
    291         vld1_u8(&mc_running_avg[mc_running_avg_stride]);
    292     const uint8x16_t v_mc_running_avg =
    293         vcombine_u8(v_mc_running_avg_lo, v_mc_running_avg_hi);
    294     /* Calculate absolute difference and sign masks. */
    295     const uint8x16_t v_abs_diff = vabdq_u8(v_sig, v_mc_running_avg);
    296     const uint8x16_t v_diff_pos_mask = vcltq_u8(v_sig, v_mc_running_avg);
    297     const uint8x16_t v_diff_neg_mask = vcgtq_u8(v_sig, v_mc_running_avg);
    298 
    299     /* Figure out which level that put us in. */
    300     const uint8x16_t v_level1_mask = vcleq_u8(v_level1_threshold, v_abs_diff);
    301     const uint8x16_t v_level2_mask = vcleq_u8(v_level2_threshold, v_abs_diff);
    302     const uint8x16_t v_level3_mask = vcleq_u8(v_level3_threshold, v_abs_diff);
    303 
    304     /* Calculate absolute adjustments for level 1, 2 and 3. */
    305     const uint8x16_t v_level2_adjustment =
    306         vandq_u8(v_level2_mask, v_delta_level_1_and_2);
    307     const uint8x16_t v_level3_adjustment =
    308         vandq_u8(v_level3_mask, v_delta_level_2_and_3);
    309     const uint8x16_t v_level1and2_adjustment =
    310         vaddq_u8(v_level1_adjustment, v_level2_adjustment);
    311     const uint8x16_t v_level1and2and3_adjustment =
    312         vaddq_u8(v_level1and2_adjustment, v_level3_adjustment);
    313 
    314     /* Figure adjustment absolute value by selecting between the absolute
    315      * difference if in level0 or the value for level 1, 2 and 3.
    316      */
    317     const uint8x16_t v_abs_adjustment =
    318         vbslq_u8(v_level1_mask, v_level1and2and3_adjustment, v_abs_diff);
    319 
    320     /* Calculate positive and negative adjustments. Apply them to the signal
    321      * and accumulate them. Adjustments are less than eight and the maximum
    322      * sum of them (7 * 16) can fit in a signed char.
    323      */
    324     const uint8x16_t v_pos_adjustment =
    325         vandq_u8(v_diff_pos_mask, v_abs_adjustment);
    326     const uint8x16_t v_neg_adjustment =
    327         vandq_u8(v_diff_neg_mask, v_abs_adjustment);
    328 
    329     uint8x16_t v_running_avg = vqaddq_u8(v_sig, v_pos_adjustment);
    330     v_running_avg = vqsubq_u8(v_running_avg, v_neg_adjustment);
    331 
    332     /* Store results. */
    333     vst1_u8(running_avg, vget_low_u8(v_running_avg));
    334     vst1_u8(&running_avg[running_avg_stride], vget_high_u8(v_running_avg));
    335 
    336     /* Sum all the accumulators to have the sum of all pixel differences
    337      * for this macroblock.
    338      */
    339     {
    340       const int8x16_t v_sum_diff =
    341           vqsubq_s8(vreinterpretq_s8_u8(v_pos_adjustment),
    342                     vreinterpretq_s8_u8(v_neg_adjustment));
    343 
    344       const int16x8_t fe_dc_ba_98_76_54_32_10 = vpaddlq_s8(v_sum_diff);
    345 
    346       const int32x4_t fedc_ba98_7654_3210 =
    347           vpaddlq_s16(fe_dc_ba_98_76_54_32_10);
    348 
    349       const int64x2_t fedcba98_76543210 = vpaddlq_s32(fedc_ba98_7654_3210);
    350 
    351       v_sum_diff_total = vqaddq_s64(v_sum_diff_total, fedcba98_76543210);
    352     }
    353 
    354     /* Update pointers for next iteration. */
    355     sig += sig_stride * 2;
    356     mc_running_avg += mc_running_avg_stride * 2;
    357     running_avg += running_avg_stride * 2;
    358   }
    359 
    360   /* Too much adjustments => copy block. */
    361   {
    362     int64x1_t x = vqadd_s64(vget_high_s64(v_sum_diff_total),
    363                             vget_low_s64(v_sum_diff_total));
    364     int sum_diff = vget_lane_s32(vabs_s32(vreinterpret_s32_s64(x)), 0);
    365     int sum_diff_thresh = SUM_DIFF_THRESHOLD_UV;
    366     if (increase_denoising) sum_diff_thresh = SUM_DIFF_THRESHOLD_HIGH_UV;
    367     if (sum_diff > sum_diff_thresh) {
    368       // Before returning to copy the block (i.e., apply no denoising),
    369       // checK if we can still apply some (weaker) temporal filtering to
    370       // this block, that would otherwise not be denoised at all. Simplest
    371       // is to apply an additional adjustment to running_avg_y to bring it
    372       // closer to sig. The adjustment is capped by a maximum delta, and
    373       // chosen such that in most cases the resulting sum_diff will be
    374       // within the accceptable range given by sum_diff_thresh.
    375 
    376       // The delta is set by the excess of absolute pixel diff over the
    377       // threshold.
    378       int delta = ((sum_diff - sum_diff_thresh) >> 8) + 1;
    379       // Only apply the adjustment for max delta up to 3.
    380       if (delta < 4) {
    381         const uint8x16_t k_delta = vmovq_n_u8(delta);
    382         sig -= sig_stride * 8;
    383         mc_running_avg -= mc_running_avg_stride * 8;
    384         running_avg -= running_avg_stride * 8;
    385         for (r = 0; r < 4; ++r) {
    386           const uint8x8_t v_sig_lo = vld1_u8(sig);
    387           const uint8x8_t v_sig_hi = vld1_u8(&sig[sig_stride]);
    388           const uint8x16_t v_sig = vcombine_u8(v_sig_lo, v_sig_hi);
    389           const uint8x8_t v_mc_running_avg_lo = vld1_u8(mc_running_avg);
    390           const uint8x8_t v_mc_running_avg_hi =
    391               vld1_u8(&mc_running_avg[mc_running_avg_stride]);
    392           const uint8x16_t v_mc_running_avg =
    393               vcombine_u8(v_mc_running_avg_lo, v_mc_running_avg_hi);
    394           /* Calculate absolute difference and sign masks. */
    395           const uint8x16_t v_abs_diff = vabdq_u8(v_sig, v_mc_running_avg);
    396           const uint8x16_t v_diff_pos_mask = vcltq_u8(v_sig, v_mc_running_avg);
    397           const uint8x16_t v_diff_neg_mask = vcgtq_u8(v_sig, v_mc_running_avg);
    398           // Clamp absolute difference to delta to get the adjustment.
    399           const uint8x16_t v_abs_adjustment = vminq_u8(v_abs_diff, (k_delta));
    400 
    401           const uint8x16_t v_pos_adjustment =
    402               vandq_u8(v_diff_pos_mask, v_abs_adjustment);
    403           const uint8x16_t v_neg_adjustment =
    404               vandq_u8(v_diff_neg_mask, v_abs_adjustment);
    405           const uint8x8_t v_running_avg_lo = vld1_u8(running_avg);
    406           const uint8x8_t v_running_avg_hi =
    407               vld1_u8(&running_avg[running_avg_stride]);
    408           uint8x16_t v_running_avg =
    409               vcombine_u8(v_running_avg_lo, v_running_avg_hi);
    410 
    411           v_running_avg = vqsubq_u8(v_running_avg, v_pos_adjustment);
    412           v_running_avg = vqaddq_u8(v_running_avg, v_neg_adjustment);
    413 
    414           /* Store results. */
    415           vst1_u8(running_avg, vget_low_u8(v_running_avg));
    416           vst1_u8(&running_avg[running_avg_stride],
    417                   vget_high_u8(v_running_avg));
    418 
    419           {
    420             const int8x16_t v_sum_diff =
    421                 vqsubq_s8(vreinterpretq_s8_u8(v_neg_adjustment),
    422                           vreinterpretq_s8_u8(v_pos_adjustment));
    423 
    424             const int16x8_t fe_dc_ba_98_76_54_32_10 = vpaddlq_s8(v_sum_diff);
    425             const int32x4_t fedc_ba98_7654_3210 =
    426                 vpaddlq_s16(fe_dc_ba_98_76_54_32_10);
    427             const int64x2_t fedcba98_76543210 =
    428                 vpaddlq_s32(fedc_ba98_7654_3210);
    429 
    430             v_sum_diff_total = vqaddq_s64(v_sum_diff_total, fedcba98_76543210);
    431           }
    432           /* Update pointers for next iteration. */
    433           sig += sig_stride * 2;
    434           mc_running_avg += mc_running_avg_stride * 2;
    435           running_avg += running_avg_stride * 2;
    436         }
    437         {
    438           // Update the sum of all pixel differences of this MB.
    439           x = vqadd_s64(vget_high_s64(v_sum_diff_total),
    440                         vget_low_s64(v_sum_diff_total));
    441           sum_diff = vget_lane_s32(vabs_s32(vreinterpret_s32_s64(x)), 0);
    442 
    443           if (sum_diff > sum_diff_thresh) {
    444             return COPY_BLOCK;
    445           }
    446         }
    447       } else {
    448         return COPY_BLOCK;
    449       }
    450     }
    451   }
    452 
    453   /* Tell above level that block was filtered. */
    454   running_avg -= running_avg_stride * 8;
    455   sig -= sig_stride * 8;
    456 
    457   vp8_copy_mem8x8(running_avg, running_avg_stride, sig, sig_stride);
    458 
    459   return FILTER_BLOCK;
    460 }
    461