Home | History | Annotate | Download | only in util
      1 /*
      2  *  Copyright (c) 2015 The WebRTC project authors. All Rights Reserved.
      3  *
      4  *  Use of this source code is governed by a BSD-style license
      5  *  that can be found in the LICENSE file in the root of the source
      6  *  tree. An additional intellectual property rights grant can be found
      7  *  in the file PATENTS.  All contributing project authors may
      8  *  be found in the AUTHORS file in the root of the source tree.
      9  */
     10 
     11 #include <arm_neon.h>
     12 
     13 #include "webrtc/modules/video_processing/util/denoiser_filter_neon.h"
     14 
     15 namespace webrtc {
     16 
     17 static int HorizontalAddS16x8(const int16x8_t v_16x8) {
     18   const int32x4_t a = vpaddlq_s16(v_16x8);
     19   const int64x2_t b = vpaddlq_s32(a);
     20   const int32x2_t c = vadd_s32(vreinterpret_s32_s64(vget_low_s64(b)),
     21                                vreinterpret_s32_s64(vget_high_s64(b)));
     22   return vget_lane_s32(c, 0);
     23 }
     24 
     25 static int HorizontalAddS32x4(const int32x4_t v_32x4) {
     26   const int64x2_t b = vpaddlq_s32(v_32x4);
     27   const int32x2_t c = vadd_s32(vreinterpret_s32_s64(vget_low_s64(b)),
     28                                vreinterpret_s32_s64(vget_high_s64(b)));
     29   return vget_lane_s32(c, 0);
     30 }
     31 
     32 static void VarianceNeonW8(const uint8_t* a,
     33                            int a_stride,
     34                            const uint8_t* b,
     35                            int b_stride,
     36                            int w,
     37                            int h,
     38                            uint32_t* sse,
     39                            int64_t* sum) {
     40   int16x8_t v_sum = vdupq_n_s16(0);
     41   int32x4_t v_sse_lo = vdupq_n_s32(0);
     42   int32x4_t v_sse_hi = vdupq_n_s32(0);
     43 
     44   for (int i = 0; i < h; ++i) {
     45     for (int j = 0; j < w; j += 8) {
     46       const uint8x8_t v_a = vld1_u8(&a[j]);
     47       const uint8x8_t v_b = vld1_u8(&b[j]);
     48       const uint16x8_t v_diff = vsubl_u8(v_a, v_b);
     49       const int16x8_t sv_diff = vreinterpretq_s16_u16(v_diff);
     50       v_sum = vaddq_s16(v_sum, sv_diff);
     51       v_sse_lo =
     52           vmlal_s16(v_sse_lo, vget_low_s16(sv_diff), vget_low_s16(sv_diff));
     53       v_sse_hi =
     54           vmlal_s16(v_sse_hi, vget_high_s16(sv_diff), vget_high_s16(sv_diff));
     55     }
     56     a += a_stride;
     57     b += b_stride;
     58   }
     59 
     60   *sum = HorizontalAddS16x8(v_sum);
     61   *sse =
     62       static_cast<uint32_t>(HorizontalAddS32x4(vaddq_s32(v_sse_lo, v_sse_hi)));
     63 }
     64 
     65 void DenoiserFilterNEON::CopyMem16x16(const uint8_t* src,
     66                                       int src_stride,
     67                                       uint8_t* dst,
     68                                       int dst_stride) {
     69   uint8x16_t qtmp;
     70   for (int r = 0; r < 16; r++) {
     71     qtmp = vld1q_u8(src);
     72     vst1q_u8(dst, qtmp);
     73     src += src_stride;
     74     dst += dst_stride;
     75   }
     76 }
     77 
     78 void DenoiserFilterNEON::CopyMem8x8(const uint8_t* src,
     79                                     int src_stride,
     80                                     uint8_t* dst,
     81                                     int dst_stride) {
     82   uint8x8_t vtmp;
     83 
     84   for (int r = 0; r < 8; r++) {
     85     vtmp = vld1_u8(src);
     86     vst1_u8(dst, vtmp);
     87     src += src_stride;
     88     dst += dst_stride;
     89   }
     90 }
     91 
     92 uint32_t DenoiserFilterNEON::Variance16x8(const uint8_t* a,
     93                                           int a_stride,
     94                                           const uint8_t* b,
     95                                           int b_stride,
     96                                           uint32_t* sse) {
     97   int64_t sum = 0;
     98   VarianceNeonW8(a, a_stride << 1, b, b_stride << 1, 16, 8, sse, &sum);
     99   return *sse - ((sum * sum) >> 7);
    100 }
    101 
    102 DenoiserDecision DenoiserFilterNEON::MbDenoise(uint8_t* mc_running_avg_y,
    103                                                int mc_running_avg_y_stride,
    104                                                uint8_t* running_avg_y,
    105                                                int running_avg_y_stride,
    106                                                const uint8_t* sig,
    107                                                int sig_stride,
    108                                                uint8_t motion_magnitude,
    109                                                int increase_denoising) {
    110   // If motion_magnitude is small, making the denoiser more aggressive by
    111   // increasing the adjustment for each level, level1 adjustment is
    112   // increased, the deltas stay the same.
    113   int shift_inc =
    114       (increase_denoising && motion_magnitude <= kMotionMagnitudeThreshold) ? 1
    115                                                                             : 0;
    116   const uint8x16_t v_level1_adjustment = vmovq_n_u8(
    117       (motion_magnitude <= kMotionMagnitudeThreshold) ? 4 + shift_inc : 3);
    118   const uint8x16_t v_delta_level_1_and_2 = vdupq_n_u8(1);
    119   const uint8x16_t v_delta_level_2_and_3 = vdupq_n_u8(2);
    120   const uint8x16_t v_level1_threshold = vmovq_n_u8(4 + shift_inc);
    121   const uint8x16_t v_level2_threshold = vdupq_n_u8(8);
    122   const uint8x16_t v_level3_threshold = vdupq_n_u8(16);
    123   int64x2_t v_sum_diff_total = vdupq_n_s64(0);
    124 
    125   // Go over lines.
    126   for (int r = 0; r < 16; ++r) {
    127     // Load inputs.
    128     const uint8x16_t v_sig = vld1q_u8(sig);
    129     const uint8x16_t v_mc_running_avg_y = vld1q_u8(mc_running_avg_y);
    130 
    131     // Calculate absolute difference and sign masks.
    132     const uint8x16_t v_abs_diff = vabdq_u8(v_sig, v_mc_running_avg_y);
    133     const uint8x16_t v_diff_pos_mask = vcltq_u8(v_sig, v_mc_running_avg_y);
    134     const uint8x16_t v_diff_neg_mask = vcgtq_u8(v_sig, v_mc_running_avg_y);
    135 
    136     // Figure out which level that put us in.
    137     const uint8x16_t v_level1_mask = vcleq_u8(v_level1_threshold, v_abs_diff);
    138     const uint8x16_t v_level2_mask = vcleq_u8(v_level2_threshold, v_abs_diff);
    139     const uint8x16_t v_level3_mask = vcleq_u8(v_level3_threshold, v_abs_diff);
    140 
    141     // Calculate absolute adjustments for level 1, 2 and 3.
    142     const uint8x16_t v_level2_adjustment =
    143         vandq_u8(v_level2_mask, v_delta_level_1_and_2);
    144     const uint8x16_t v_level3_adjustment =
    145         vandq_u8(v_level3_mask, v_delta_level_2_and_3);
    146     const uint8x16_t v_level1and2_adjustment =
    147         vaddq_u8(v_level1_adjustment, v_level2_adjustment);
    148     const uint8x16_t v_level1and2and3_adjustment =
    149         vaddq_u8(v_level1and2_adjustment, v_level3_adjustment);
    150 
    151     // Figure adjustment absolute value by selecting between the absolute
    152     // difference if in level0 or the value for level 1, 2 and 3.
    153     const uint8x16_t v_abs_adjustment =
    154         vbslq_u8(v_level1_mask, v_level1and2and3_adjustment, v_abs_diff);
    155 
    156     // Calculate positive and negative adjustments. Apply them to the signal
    157     // and accumulate them. Adjustments are less than eight and the maximum
    158     // sum of them (7 * 16) can fit in a signed char.
    159     const uint8x16_t v_pos_adjustment =
    160         vandq_u8(v_diff_pos_mask, v_abs_adjustment);
    161     const uint8x16_t v_neg_adjustment =
    162         vandq_u8(v_diff_neg_mask, v_abs_adjustment);
    163 
    164     uint8x16_t v_running_avg_y = vqaddq_u8(v_sig, v_pos_adjustment);
    165     v_running_avg_y = vqsubq_u8(v_running_avg_y, v_neg_adjustment);
    166 
    167     // Store results.
    168     vst1q_u8(running_avg_y, v_running_avg_y);
    169 
    170     // Sum all the accumulators to have the sum of all pixel differences
    171     // for this macroblock.
    172     {
    173       const int8x16_t v_sum_diff =
    174           vqsubq_s8(vreinterpretq_s8_u8(v_pos_adjustment),
    175                     vreinterpretq_s8_u8(v_neg_adjustment));
    176       const int16x8_t fe_dc_ba_98_76_54_32_10 = vpaddlq_s8(v_sum_diff);
    177       const int32x4_t fedc_ba98_7654_3210 =
    178           vpaddlq_s16(fe_dc_ba_98_76_54_32_10);
    179       const int64x2_t fedcba98_76543210 = vpaddlq_s32(fedc_ba98_7654_3210);
    180 
    181       v_sum_diff_total = vqaddq_s64(v_sum_diff_total, fedcba98_76543210);
    182     }
    183 
    184     // Update pointers for next iteration.
    185     sig += sig_stride;
    186     mc_running_avg_y += mc_running_avg_y_stride;
    187     running_avg_y += running_avg_y_stride;
    188   }
    189 
    190   // Too much adjustments => copy block.
    191   {
    192     int64x1_t x = vqadd_s64(vget_high_s64(v_sum_diff_total),
    193                             vget_low_s64(v_sum_diff_total));
    194     int sum_diff = vget_lane_s32(vabs_s32(vreinterpret_s32_s64(x)), 0);
    195     int sum_diff_thresh = kSumDiffThreshold;
    196 
    197     if (increase_denoising)
    198       sum_diff_thresh = kSumDiffThresholdHigh;
    199     if (sum_diff > sum_diff_thresh) {
    200       // Before returning to copy the block (i.e., apply no denoising),
    201       // checK if we can still apply some (weaker) temporal filtering to
    202       // this block, that would otherwise not be denoised at all. Simplest
    203       // is to apply an additional adjustment to running_avg_y to bring it
    204       // closer to sig. The adjustment is capped by a maximum delta, and
    205       // chosen such that in most cases the resulting sum_diff will be
    206       // within the accceptable range given by sum_diff_thresh.
    207 
    208       // The delta is set by the excess of absolute pixel diff over the
    209       // threshold.
    210       int delta = ((sum_diff - sum_diff_thresh) >> 8) + 1;
    211       // Only apply the adjustment for max delta up to 3.
    212       if (delta < 4) {
    213         const uint8x16_t k_delta = vmovq_n_u8(delta);
    214         sig -= sig_stride * 16;
    215         mc_running_avg_y -= mc_running_avg_y_stride * 16;
    216         running_avg_y -= running_avg_y_stride * 16;
    217         for (int r = 0; r < 16; ++r) {
    218           uint8x16_t v_running_avg_y = vld1q_u8(running_avg_y);
    219           const uint8x16_t v_sig = vld1q_u8(sig);
    220           const uint8x16_t v_mc_running_avg_y = vld1q_u8(mc_running_avg_y);
    221 
    222           // Calculate absolute difference and sign masks.
    223           const uint8x16_t v_abs_diff = vabdq_u8(v_sig, v_mc_running_avg_y);
    224           const uint8x16_t v_diff_pos_mask =
    225               vcltq_u8(v_sig, v_mc_running_avg_y);
    226           const uint8x16_t v_diff_neg_mask =
    227               vcgtq_u8(v_sig, v_mc_running_avg_y);
    228           // Clamp absolute difference to delta to get the adjustment.
    229           const uint8x16_t v_abs_adjustment = vminq_u8(v_abs_diff, (k_delta));
    230 
    231           const uint8x16_t v_pos_adjustment =
    232               vandq_u8(v_diff_pos_mask, v_abs_adjustment);
    233           const uint8x16_t v_neg_adjustment =
    234               vandq_u8(v_diff_neg_mask, v_abs_adjustment);
    235 
    236           v_running_avg_y = vqsubq_u8(v_running_avg_y, v_pos_adjustment);
    237           v_running_avg_y = vqaddq_u8(v_running_avg_y, v_neg_adjustment);
    238 
    239           // Store results.
    240           vst1q_u8(running_avg_y, v_running_avg_y);
    241 
    242           {
    243             const int8x16_t v_sum_diff =
    244                 vqsubq_s8(vreinterpretq_s8_u8(v_neg_adjustment),
    245                           vreinterpretq_s8_u8(v_pos_adjustment));
    246 
    247             const int16x8_t fe_dc_ba_98_76_54_32_10 = vpaddlq_s8(v_sum_diff);
    248             const int32x4_t fedc_ba98_7654_3210 =
    249                 vpaddlq_s16(fe_dc_ba_98_76_54_32_10);
    250             const int64x2_t fedcba98_76543210 =
    251                 vpaddlq_s32(fedc_ba98_7654_3210);
    252 
    253             v_sum_diff_total = vqaddq_s64(v_sum_diff_total, fedcba98_76543210);
    254           }
    255           // Update pointers for next iteration.
    256           sig += sig_stride;
    257           mc_running_avg_y += mc_running_avg_y_stride;
    258           running_avg_y += running_avg_y_stride;
    259         }
    260         {
    261           // Update the sum of all pixel differences of this MB.
    262           x = vqadd_s64(vget_high_s64(v_sum_diff_total),
    263                         vget_low_s64(v_sum_diff_total));
    264           sum_diff = vget_lane_s32(vabs_s32(vreinterpret_s32_s64(x)), 0);
    265 
    266           if (sum_diff > sum_diff_thresh) {
    267             return COPY_BLOCK;
    268           }
    269         }
    270       } else {
    271         return COPY_BLOCK;
    272       }
    273     }
    274   }
    275 
    276   // Tell above level that block was filtered.
    277   running_avg_y -= running_avg_y_stride * 16;
    278   sig -= sig_stride * 16;
    279 
    280   return FILTER_BLOCK;
    281 }
    282 
    283 }  // namespace webrtc
    284