/*
 *  Copyright (c) 2016 The WebM project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include <arm_neon.h>
#include <assert.h>

#include "./vpx_dsp_rtcd.h"
#include "vpx/vpx_integer.h"
#include "vpx_dsp/arm/transpose_neon.h"

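// Rounding/dither table defined on the C side of vpx_dsp. The down filter
// below reads 8 contiguous int16 entries starting at (row & 127); nothing
// else about the table's contents is assumed here.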
extern const int16_t vpx_rv[];

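// Five-tap rounded-average kernel. Per lane this mirrors the scalar kernel
// (k1/k2/k3 are just names for the rounding-halving adds used below):
//   k1 = (a2 + a1 + 1) >> 1;
//   k2 = (b2 + b1 + 1) >> 1;
//   k3 = (k1 + k2 + 1) >> 1;
//   out = (k3 + v0 + 1) >> 1;
// so v0 keeps roughly half the weight and each neighbor contributes ~1/8.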
static uint8x8_t average_k_out(const uint8x8_t a2, const uint8x8_t a1,
                               const uint8x8_t v0, const uint8x8_t b1,
                               const uint8x8_t b2) {
  const uint8x8_t k1 = vrhadd_u8(a2, a1);
  const uint8x8_t k2 = vrhadd_u8(b2, b1);
  const uint8x8_t k3 = vrhadd_u8(k1, k2);
  return vrhadd_u8(k3, v0);
}

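// Build a per-lane mask that is all ones where every neighbor differs from v0
// by less than the filter threshold, i.e. where the pixel may be smoothed.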
static uint8x8_t generate_mask(const uint8x8_t a2, const uint8x8_t a1,
                               const uint8x8_t v0, const uint8x8_t b1,
                               const uint8x8_t b2, const uint8x8_t filter) {
  const uint8x8_t a2_v0 = vabd_u8(a2, v0);
  const uint8x8_t a1_v0 = vabd_u8(a1, v0);
  const uint8x8_t b1_v0 = vabd_u8(b1, v0);
  const uint8x8_t b2_v0 = vabd_u8(b2, v0);

  uint8x8_t max = vmax_u8(a2_v0, a1_v0);
  max = vmax_u8(b1_v0, max);
  max = vmax_u8(b2_v0, max);
  return vclt_u8(max, filter);
}

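// Select the smoothed value where the mask passed and keep the original pixel
// elsewhere.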
static uint8x8_t generate_output(const uint8x8_t a2, const uint8x8_t a1,
                                 const uint8x8_t v0, const uint8x8_t b1,
                                 const uint8x8_t b2, const uint8x8_t filter) {
  const uint8x8_t k_out = average_k_out(a2, a1, v0, b1, b2);
  const uint8x8_t mask = generate_mask(a2, a1, v0, b1, b2, filter);

  return vbsl_u8(mask, k_out, v0);
}

// Same functions but for uint8x16_t.
static uint8x16_t average_k_outq(const uint8x16_t a2, const uint8x16_t a1,
                                 const uint8x16_t v0, const uint8x16_t b1,
                                 const uint8x16_t b2) {
  const uint8x16_t k1 = vrhaddq_u8(a2, a1);
  const uint8x16_t k2 = vrhaddq_u8(b2, b1);
  const uint8x16_t k3 = vrhaddq_u8(k1, k2);
  return vrhaddq_u8(k3, v0);
}

static uint8x16_t generate_maskq(const uint8x16_t a2, const uint8x16_t a1,
                                 const uint8x16_t v0, const uint8x16_t b1,
                                 const uint8x16_t b2, const uint8x16_t filter) {
  const uint8x16_t a2_v0 = vabdq_u8(a2, v0);
  const uint8x16_t a1_v0 = vabdq_u8(a1, v0);
  const uint8x16_t b1_v0 = vabdq_u8(b1, v0);
  const uint8x16_t b2_v0 = vabdq_u8(b2, v0);

  uint8x16_t max = vmaxq_u8(a2_v0, a1_v0);
  max = vmaxq_u8(b1_v0, max);
  max = vmaxq_u8(b2_v0, max);
  return vcltq_u8(max, filter);
}

static uint8x16_t generate_outputq(const uint8x16_t a2, const uint8x16_t a1,
                                   const uint8x16_t v0, const uint8x16_t b1,
                                   const uint8x16_t b2,
                                   const uint8x16_t filter) {
  const uint8x16_t k_out = average_k_outq(a2, a1, v0, b1, b2);
  const uint8x16_t mask = generate_maskq(a2, a1, v0, b1, b2, filter);

  return vbslq_u8(mask, k_out, v0);
}

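// First runs the "down" (vertical) pass over the stripe, writing filtered rows
// to dst, then re-filters dst in place "across" (horizontally) by transposing
// 8x8 blocks so the same vertical kernel can be reused.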
void vpx_post_proc_down_and_across_mb_row_neon(uint8_t *src_ptr,
                                               uint8_t *dst_ptr, int src_stride,
                                               int dst_stride, int cols,
                                               uint8_t *f, int size) {
  uint8_t *src, *dst;
  int row;
  int col;

  // Process a stripe of macroblocks. The stripe will be a multiple of 16 (for
  // Y) or 8 (for U/V) wide (cols) and the height (size) will be 16 (for Y) or 8
  // (for U/V).
  assert((size == 8 || size == 16) && cols % 8 == 0);

  // Process 16-wide columns for as long as at least 16 columns remain.
  for (col = 0; col < cols - 8; col += 16) {
    uint8x16_t a0, a1, a2, a3, a4, a5, a6, a7;
    src = src_ptr - 2 * src_stride;
    dst = dst_ptr;

    a0 = vld1q_u8(src);
    src += src_stride;
    a1 = vld1q_u8(src);
    src += src_stride;
    a2 = vld1q_u8(src);
    src += src_stride;
    a3 = vld1q_u8(src);
    src += src_stride;

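    // Software pipeline: a0-a3 carry the sliding four-row context (v0 is the
    // center tap, so row r is filtered from rows r - 2 .. r + 2); each pass
    // loads four new rows and emits four filtered rows.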
    for (row = 0; row < size; row += 4) {
      uint8x16_t v_out_0, v_out_1, v_out_2, v_out_3;
      const uint8x16_t filterq = vld1q_u8(f + col);

      a4 = vld1q_u8(src);
      src += src_stride;
      a5 = vld1q_u8(src);
      src += src_stride;
      a6 = vld1q_u8(src);
      src += src_stride;
      a7 = vld1q_u8(src);
      src += src_stride;

      v_out_0 = generate_outputq(a0, a1, a2, a3, a4, filterq);
      v_out_1 = generate_outputq(a1, a2, a3, a4, a5, filterq);
      v_out_2 = generate_outputq(a2, a3, a4, a5, a6, filterq);
      v_out_3 = generate_outputq(a3, a4, a5, a6, a7, filterq);

      vst1q_u8(dst, v_out_0);
      dst += dst_stride;
      vst1q_u8(dst, v_out_1);
      dst += dst_stride;
      vst1q_u8(dst, v_out_2);
      dst += dst_stride;
      vst1q_u8(dst, v_out_3);
      dst += dst_stride;

      // Rotate over to the next slot.
      a0 = a4;
      a1 = a5;
      a2 = a6;
      a3 = a7;
    }

    src_ptr += 16;
    dst_ptr += 16;
  }

  // Clean up any leftover 8-wide column.
  if (col != cols) {
    uint8x8_t a0, a1, a2, a3, a4, a5, a6, a7;
    src = src_ptr - 2 * src_stride;
    dst = dst_ptr;

    a0 = vld1_u8(src);
    src += src_stride;
    a1 = vld1_u8(src);
    src += src_stride;
    a2 = vld1_u8(src);
    src += src_stride;
    a3 = vld1_u8(src);
    src += src_stride;

    for (row = 0; row < size; row += 4) {
      uint8x8_t v_out_0, v_out_1, v_out_2, v_out_3;
      const uint8x8_t filter = vld1_u8(f + col);

      a4 = vld1_u8(src);
      src += src_stride;
      a5 = vld1_u8(src);
      src += src_stride;
      a6 = vld1_u8(src);
      src += src_stride;
      a7 = vld1_u8(src);
      src += src_stride;

      v_out_0 = generate_output(a0, a1, a2, a3, a4, filter);
      v_out_1 = generate_output(a1, a2, a3, a4, a5, filter);
      v_out_2 = generate_output(a2, a3, a4, a5, a6, filter);
      v_out_3 = generate_output(a3, a4, a5, a6, a7, filter);

      vst1_u8(dst, v_out_0);
      dst += dst_stride;
      vst1_u8(dst, v_out_1);
      dst += dst_stride;
      vst1_u8(dst, v_out_2);
      dst += dst_stride;
      vst1_u8(dst, v_out_3);
      dst += dst_stride;

      // Rotate over to the next slot.
      a0 = a4;
      a1 = a5;
      a2 = a6;
      a3 = a7;
    }

    // Not strictly necessary but makes resetting dst_ptr easier.
    dst_ptr += 8;
  }

  dst_ptr -= cols;

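  // Across pass: filter dst in place one 8-row band at a time, transposing
  // 8x8 blocks so the vertical 5-tap kernel runs along the rows.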
  for (row = 0; row < size; row += 8) {
    uint8x8_t a0, a1, a2, a3;
    uint8x8_t b0, b1, b2, b3, b4, b5, b6, b7;

    src = dst_ptr;
    dst = dst_ptr;

    // Load 8 rows and transpose the first 4 columns. Two of the columns are
    // discarded below; they will be reloaded as part of the first 8x8 block.
    load_and_transpose_u8_4x8(src, dst_stride, &a0, &a1, &a2, &a3);
    a3 = a1;
    a2 = a1 = a0;  // Extend left border.

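    // Step past the two columns replaced by the border extension so the first
    // 8x8 load starts at column 2.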
    src += 2;

    for (col = 0; col < cols; col += 8) {
      uint8x8_t v_out_0, v_out_1, v_out_2, v_out_3, v_out_4, v_out_5, v_out_6,
          v_out_7;
      // Although the filter is meant to be applied vertically, applying it
      // horizontally here is OK because it is set in blocks of 8 (or 16).
      const uint8x8_t filter = vld1_u8(f + col);

      load_and_transpose_u8_8x8(src, dst_stride, &b0, &b1, &b2, &b3, &b4, &b5,
                                &b6, &b7);

      if (col + 8 == cols) {
        // Last block of columns: extend the right border (b5).
        b6 = b7 = b5;
      }

      v_out_0 = generate_output(a0, a1, a2, a3, b0, filter);
      v_out_1 = generate_output(a1, a2, a3, b0, b1, filter);
      v_out_2 = generate_output(a2, a3, b0, b1, b2, filter);
      v_out_3 = generate_output(a3, b0, b1, b2, b3, filter);
      v_out_4 = generate_output(b0, b1, b2, b3, b4, filter);
      v_out_5 = generate_output(b1, b2, b3, b4, b5, filter);
      v_out_6 = generate_output(b2, b3, b4, b5, b6, filter);
      v_out_7 = generate_output(b3, b4, b5, b6, b7, filter);

      transpose_and_store_u8_8x8(dst, dst_stride, v_out_0, v_out_1, v_out_2,
                                 v_out_3, v_out_4, v_out_5, v_out_6, v_out_7);

      a0 = b4;
      a1 = b5;
      a2 = b6;
      a3 = b7;

      src += 8;
      dst += 8;
    }

    dst_ptr += 8 * dst_stride;
  }
}

// Vectorized form of the scalar recurrence "sum += x; sumsq += x * y;".
// Each lane i receives the prefix totals x[0] + ... + x[i] and
// xy[0] + ... + xy[i], advancing four steps of the sliding window at once.
static void accumulate_sum_sumsq(const int16x4_t x, const int32x4_t xy,
                                 int16x4_t *const sum, int32x4_t *const sumsq) {
  const int16x4_t zero = vdup_n_s16(0);
  const int32x4_t zeroq = vdupq_n_s32(0);

  // Add in the first set because vext doesn't work with '0'.
  *sum = vadd_s16(*sum, x);
  *sumsq = vaddq_s32(*sumsq, xy);

  // Shift x and xy to the right and sum. vext requires an immediate.
  *sum = vadd_s16(*sum, vext_s16(zero, x, 1));
  *sumsq = vaddq_s32(*sumsq, vextq_s32(zeroq, xy, 1));

  *sum = vadd_s16(*sum, vext_s16(zero, x, 2));
  *sumsq = vaddq_s32(*sumsq, vextq_s32(zeroq, xy, 2));

  *sum = vadd_s16(*sum, vext_s16(zero, x, 3));
  *sumsq = vaddq_s32(*sumsq, vextq_s32(zeroq, xy, 3));
}
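// Example: with x = {1, 2, 3, 4} the additions above contribute {1, 3, 6, 10}
// to the four lanes of sum, matching four sequential scalar updates.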

// Generate mask based on (sumsq * 15 - sum * sum < flimit)
static uint16x4_t calculate_mask(const int16x4_t sum, const int32x4_t sumsq,
                                 const int32x4_t f, const int32x4_t fifteen) {
  const int32x4_t a = vmulq_s32(sumsq, fifteen);
  const int32x4_t b = vmlsl_s16(a, sum, sum);
  const uint32x4_t mask32 = vcltq_s32(b, f);
  return vmovn_u32(mask32);
}
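// Note vmlsl_s16 computes a - sum * sum in one widening multiply-subtract, so
// the whole test runs at 32-bit precision before the mask is narrowed.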

static uint8x8_t combine_mask(const int16x4_t sum_low, const int16x4_t sum_high,
                              const int32x4_t sumsq_low,
                              const int32x4_t sumsq_high, const int32x4_t f) {
  const int32x4_t fifteen = vdupq_n_s32(15);
  const uint16x4_t mask16_low = calculate_mask(sum_low, sumsq_low, f, fifteen);
  const uint16x4_t mask16_high =
      calculate_mask(sum_high, sumsq_high, f, fifteen);
  return vmovn_u16(vcombine_u16(mask16_low, mask16_high));
}

// Apply filter of (8 + sum + s[c]) >> 4.
static uint8x8_t filter_pixels(const int16x8_t sum, const uint8x8_t s) {
  const int16x8_t s16 = vreinterpretq_s16_u16(vmovl_u8(s));
  const int16x8_t sum_s = vaddq_s16(sum, s16);

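  // vqrshrun_n_s16 supplies the "+8" rounding bias before the shift and
  // saturates the result to u8.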
  return vqrshrun_n_s16(sum_s, 4);
}

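// In-place horizontal low-pass over a 15-tap window, [c - 7, c + 7]. Moving
// one column right adds tap c + 7 and drops tap c - 8, so the scalar
// recurrence is:
//   x = s[c + 7] - s[c - 8];
//   y = s[c + 7] + s[c - 8];
//   sum += x;
//   sumsq += x * y;  // (a - b) * (a + b) == a * a - b * b
// The NEON code below advances eight columns per iteration.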
void vpx_mbpost_proc_across_ip_neon(uint8_t *src, int pitch, int rows, int cols,
                                    int flimit) {
  int row, col;
  const int32x4_t f = vdupq_n_s32(flimit);

  assert(cols % 8 == 0);

  for (row = 0; row < rows; ++row) {
    // Sum the first 9 window elements, which are all extended from s[0]
    // (taps -8 through 0 of the left-extended border).
    // sumsq gets primed with +16.
    int sumsq = src[0] * src[0] * 9 + 16;
    int sum = src[0] * 9;

    uint8x8_t left_context, s, right_context;
    int16x4_t sum_low, sum_high;
    int32x4_t sumsq_low, sumsq_high;

    // Sum (+square) the next 6 elements.
    // Skip [0] because it's included above.
    for (col = 1; col <= 6; ++col) {
      sumsq += src[col] * src[col];
      sum += src[col];
    }

    // Prime the sums. Later the loop uses the _high values to prime the new
    // vectors.
    sumsq_high = vdupq_n_s32(sumsq);
    sum_high = vdup_n_s16(sum);

    // Manually extend the left border.
    left_context = vdup_n_u8(src[0]);

    for (col = 0; col < cols; col += 8) {
      uint8x8_t mask, output;
      int16x8_t x, y;
      int32x4_t xy_low, xy_high;

      s = vld1_u8(src + col);

      if (col + 8 == cols) {
        // Last block of columns: extend the right border.
        right_context = vdup_n_u8(src[col + 7]);
      } else {
        right_context = vld1_u8(src + col + 7);
      }

      x = vreinterpretq_s16_u16(vsubl_u8(right_context, left_context));
      y = vreinterpretq_s16_u16(vaddl_u8(right_context, left_context));
      xy_low = vmull_s16(vget_low_s16(x), vget_low_s16(y));
      xy_high = vmull_s16(vget_high_s16(x), vget_high_s16(y));

      // Catch up to the last summed value.
      sum_low = vdup_lane_s16(sum_high, 3);
      sumsq_low = vdupq_lane_s32(vget_high_s32(sumsq_high), 1);

      accumulate_sum_sumsq(vget_low_s16(x), xy_low, &sum_low, &sumsq_low);

      // This must run after the low half completes: its final (lane 3) value
      // seeds the high half.
      sum_high = vdup_lane_s16(sum_low, 3);
      sumsq_high = vdupq_lane_s32(vget_high_s32(sumsq_low), 1);

      accumulate_sum_sumsq(vget_high_s16(x), xy_high, &sum_high, &sumsq_high);

      mask = combine_mask(sum_low, sum_high, sumsq_low, sumsq_high, f);

      output = filter_pixels(vcombine_s16(sum_low, sum_high), s);
      output = vbsl_u8(mask, output, s);

      vst1_u8(src + col, output);

      left_context = s;
    }

    src += pitch;
  }
}

// Apply filter of (vpx_rv + sum + s[c]) >> 4.
static uint8x8_t filter_pixels_rv(const int16x8_t sum, const uint8x8_t s,
                                  const int16x8_t rv) {
  const int16x8_t s16 = vreinterpretq_s16_u16(vmovl_u8(s));
  const int16x8_t sum_s = vaddq_s16(sum, s16);
  const int16x8_t rounded = vaddq_s16(sum_s, rv);

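  // rv supplies the rounding bias here, so use the non-rounding saturating
  // shift (contrast vqrshrun_n_s16 in filter_pixels above).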
  return vqshrun_n_s16(rounded, 4);
}

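// Vertical version of the same 15-tap sliding window, processing an 8-wide
// stripe of columns at a time. above_context is an 8-entry delay line whose
// slot 0 stands in for the dropped tap at row r - 8.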
void vpx_mbpost_proc_down_neon(uint8_t *dst, int pitch, int rows, int cols,
                               int flimit) {
  int row, col, i;
  const int32x4_t f = vdupq_n_s32(flimit);
  uint8x8_t below_context = vdup_n_u8(0);

  // 8 columns are processed at a time.
  // If rows is less than 8 the bottom border extension fails.
  assert(cols % 8 == 0);
  assert(rows >= 8);

  // Load and keep the first 8 values in memory. Process a vertical stripe that
  // is 8 wide.
  for (col = 0; col < cols; col += 8) {
    uint8x8_t s, above_context[8];
    int16x8_t sum, sum_tmp;
    int32x4_t sumsq_low, sumsq_high;

    // Load and extend the top border.
    s = vld1_u8(dst);
    for (i = 0; i < 8; i++) {
      above_context[i] = s;
    }

    sum_tmp = vreinterpretq_s16_u16(vmovl_u8(s));

    // sum * 9
    sum = vmulq_n_s16(sum_tmp, 9);

    // (sum * 9) * sum == sum * sum * 9
    sumsq_low = vmull_s16(vget_low_s16(sum), vget_low_s16(sum_tmp));
    sumsq_high = vmull_s16(vget_high_s16(sum), vget_high_s16(sum_tmp));

    // Accumulate the next 6 rows to finish priming sum and sumsq; the loaded
    // pixels themselves are not kept.
    for (i = 1; i <= 6; ++i) {
      const uint8x8_t a = vld1_u8(dst + i * pitch);
      const int16x8_t b = vreinterpretq_s16_u16(vmovl_u8(a));
      sum = vaddq_s16(sum, b);

      sumsq_low = vmlal_s16(sumsq_low, vget_low_s16(b), vget_low_s16(b));
      sumsq_high = vmlal_s16(sumsq_high, vget_high_s16(b), vget_high_s16(b));
    }

    for (row = 0; row < rows; ++row) {
      uint8x8_t mask, output;
      int16x8_t x, y;
      int32x4_t xy_low, xy_high;

      s = vld1_u8(dst + row * pitch);

      // Past the bottom of the image, keep re-using the last in-range row,
      // which extends the border.
      if (row + 7 < rows) {
        below_context = vld1_u8(dst + (row + 7) * pitch);
      }

      x = vreinterpretq_s16_u16(vsubl_u8(below_context, above_context[0]));
      y = vreinterpretq_s16_u16(vaddl_u8(below_context, above_context[0]));
      xy_low = vmull_s16(vget_low_s16(x), vget_low_s16(y));
      xy_high = vmull_s16(vget_high_s16(x), vget_high_s16(y));

      sum = vaddq_s16(sum, x);

      sumsq_low = vaddq_s32(sumsq_low, xy_low);
      sumsq_high = vaddq_s32(sumsq_high, xy_high);

      mask = combine_mask(vget_low_s16(sum), vget_high_s16(sum), sumsq_low,
                          sumsq_high, f);

      output = filter_pixels_rv(sum, s, vld1q_s16(vpx_rv + (row & 127)));
      output = vbsl_u8(mask, output, s);

      vst1_u8(dst + row * pitch, output);

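      // Slide the delay line: above_context[0] must hold row r - 8 on the
      // next iteration.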
      above_context[0] = above_context[1];
      above_context[1] = above_context[2];
      above_context[2] = above_context[3];
      above_context[3] = above_context[4];
      above_context[4] = above_context[5];
      above_context[5] = above_context[6];
      above_context[6] = above_context[7];
      above_context[7] = s;
    }

    dst += 8;
  }
}