/*
 *  Copyright (c) 2018, Alliance for Open Media. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#ifndef AOM_AV1_COMMON_ARM_CONVOLVE_NEON_H_
#define AOM_AV1_COMMON_ARM_CONVOLVE_NEON_H_

#include <arm_neon.h>

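/* SUBPEL_TAPS rounded up to the next multiple of 8. */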
#define HORIZ_EXTRA_ROWS ((SUBPEL_TAPS + 7) & ~0x07)

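/* Vertical pass of the symmetric 7-tap Wiener filter for eight pixels.
 * The symmetric tap pairs are folded (s0 + s6, s1 + s5, s2 + s4) so only
 * four multiplies per accumulator are needed. The offset
 * (1 << (bd + round1_bits - 1)) is subtracted, the sum is rounded and
 * shifted right by round1_bits, negative values are clamped to zero, and
 * the result is saturated down to 8 bits. */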
static INLINE uint8x8_t wiener_convolve8_vert_4x8(
    const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
    const int16x8_t s3, const int16x8_t s4, const int16x8_t s5,
    const int16x8_t s6, int16_t *filter_y, const int bd,
    const int round1_bits) {
  int16x8_t ss0, ss1, ss2;
  int32x4_t sum0, sum1;
  uint16x4_t tmp0, tmp1;
  uint16x8_t tmp;
  uint8x8_t res;

  const int32_t round_const = (1 << (bd + round1_bits - 1));
  const int32x4_t round_bits = vdupq_n_s32(-round1_bits);
  const int32x4_t zero = vdupq_n_s32(0);
  const int32x4_t round_vec = vdupq_n_s32(round_const);

  /* Fold the symmetric tap pairs. */
  ss0 = vaddq_s16(s0, s6);
  ss1 = vaddq_s16(s1, s5);
  ss2 = vaddq_s16(s2, s4);

  sum0 = vmull_n_s16(vget_low_s16(ss0), filter_y[0]);
  sum0 = vmlal_n_s16(sum0, vget_low_s16(ss1), filter_y[1]);
  sum0 = vmlal_n_s16(sum0, vget_low_s16(ss2), filter_y[2]);
  sum0 = vmlal_n_s16(sum0, vget_low_s16(s3), filter_y[3]);

  sum1 = vmull_n_s16(vget_high_s16(ss0), filter_y[0]);
  sum1 = vmlal_n_s16(sum1, vget_high_s16(ss1), filter_y[1]);
  sum1 = vmlal_n_s16(sum1, vget_high_s16(ss2), filter_y[2]);
  sum1 = vmlal_n_s16(sum1, vget_high_s16(s3), filter_y[3]);

  sum0 = vsubq_s32(sum0, round_vec);
  sum1 = vsubq_s32(sum1, round_vec);

  /* Rounding right shift by round1_bits. */
  sum0 = vrshlq_s32(sum0, round_bits);
  sum1 = vrshlq_s32(sum1, round_bits);

  /* Clamp negative results to zero. */
  sum0 = vmaxq_s32(sum0, zero);
  sum1 = vmaxq_s32(sum1, zero);

  /* Saturating narrow from int32x4_t to uint8x8_t. */
  tmp0 = vqmovn_u32(vreinterpretq_u32_s32(sum0));
  tmp1 = vqmovn_u32(vreinterpretq_u32_s32(sum1));
  tmp = vcombine_u16(tmp0, tmp1);
  res = vqmovn_u16(tmp);

  return res;
}

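/* Horizontal pass of the symmetric 7-tap Wiener filter for eight pixels.
 * Unlike wiener_convolve8_horiz_4x8 below, this variant expects the caller
 * to have folded the symmetric tap pairs into s0..s2 already, so only the
 * four distinct coefficients are applied here. The center tap product is
 * computed with a widening multiply, then the sum is offset, rounded and
 * shifted right by round0_bits, and clipped to the intermediate range. */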
static INLINE uint16x8_t wiener_convolve8_horiz_8x8(
    const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
    const int16x8_t s3, int16_t *filter_x, const int bd,
    const int round0_bits) {
  int16x8_t sum;
  uint16x8_t res;
  int32x4_t sum_0, sum_1;
  int32x4_t s3_0, s3_1;
  const int32_t round_const_0 = (1 << (bd + FILTER_BITS - 1));
  const int32_t round_const_1 = (1 << ((bd) + 1 + FILTER_BITS - round0_bits));

  /* Negated shift amount for the right shift by conv_params->round_0. */
  const int32x4_t round_bits = vdupq_n_s32(-round0_bits);

  const int32x4_t round_vec_0 = vdupq_n_s32(round_const_0);
  const int32x4_t round_vec_1 = vdupq_n_s32(round_const_1);

  sum = vmulq_n_s16(s0, filter_x[0]);
  sum = vmlaq_n_s16(sum, s1, filter_x[1]);
  sum = vmlaq_n_s16(sum, s2, filter_x[2]);

  /* Widen the sum from one int16x8_t to two int32x4_t registers. */
  sum_0 = vmovl_s16(vget_low_s16(sum));
  sum_1 = vmovl_s16(vget_high_s16(sum));

  /* The center tap coefficient filter_x[3] can be as large as 128, so the
   * accumulated result can exceed the 16-bit range; compute the center tap
   * product with a widening multiply and accumulate in 32 bits.
   */
  s3_0 = vmull_n_s16(vget_low_s16(s3), filter_x[3]);
  s3_1 = vmull_n_s16(vget_high_s16(s3), filter_x[3]);
  sum_0 = vaddq_s32(sum_0, s3_0);
  sum_1 = vaddq_s32(sum_1, s3_1);

  /* Add the rounding offset. */
  sum_0 = vaddq_s32(sum_0, round_vec_0);
  sum_1 = vaddq_s32(sum_1, round_vec_0);

  /* Saturating rounding right shift. */
  sum_0 = vqrshlq_s32(sum_0, round_bits);
  sum_1 = vqrshlq_s32(sum_1, round_bits);

  /* Clip to the maximum intermediate value. */
  sum_0 = vminq_s32(sum_0, round_vec_1);
  sum_1 = vminq_s32(sum_1, round_vec_1);

  res = vcombine_u16(vqmovun_s32(sum_0), vqmovun_s32(sum_1));
  return res;
}

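/* Horizontal pass of the symmetric 7-tap Wiener filter for four pixels.
 * Folds the symmetric tap pairs in-place, then offsets, rounds and shifts
 * right by round0_bits, and clamps the result to [0, round_const_1]. */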
static INLINE uint16x4_t wiener_convolve8_horiz_4x8(
    const int16x4_t s0, const int16x4_t s1, const int16x4_t s2,
    const int16x4_t s3, const int16x4_t s4, const int16x4_t s5,
    const int16x4_t s6, int16_t *filter_x, const int bd,
    const int round0_bits) {
  uint16x4_t res;
  int32x4_t sum_0, s3_0;
  int16x4_t sum, temp0, temp1, temp2;

  const int32_t round_const_0 = (1 << (bd + FILTER_BITS - 1));
  const int32_t round_const_1 = (1 << ((bd) + 1 + FILTER_BITS - round0_bits));
  const int32x4_t round_bits = vdupq_n_s32(-round0_bits);
  const int32x4_t zero = vdupq_n_s32(0);
  const int32x4_t round_vec_0 = vdupq_n_s32(round_const_0);
  const int32x4_t round_vec_1 = vdupq_n_s32(round_const_1);

  /* Fold the symmetric tap pairs. */
  temp0 = vadd_s16(s0, s6);
  temp1 = vadd_s16(s1, s5);
  temp2 = vadd_s16(s2, s4);

  sum = vmul_n_s16(temp0, filter_x[0]);
  sum = vmla_n_s16(sum, temp1, filter_x[1]);
  sum = vmla_n_s16(sum, temp2, filter_x[2]);
  sum_0 = vmovl_s16(sum);

  /* The center tap coefficient filter_x[3] can be as large as 128, so the
   * accumulated result can exceed the 16-bit range; 32 bits are required
   * to hold the result.
   */
  s3_0 = vmull_n_s16(s3, filter_x[3]);
  sum_0 = vaddq_s32(sum_0, s3_0);

  sum_0 = vaddq_s32(sum_0, round_vec_0);
  sum_0 = vrshlq_s32(sum_0, round_bits);

  /* Clamp to [0, round_const_1]. */
  sum_0 = vmaxq_s32(sum_0, zero);
  sum_0 = vminq_s32(sum_0, round_vec_1);
  res = vqmovun_s32(sum_0);
  return res;
}

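/* 8-tap convolution of eight pixels kept entirely in 16-bit arithmetic.
 * horiz_const is a constant offset added before filtering; shift_round_0
 * holds negated shift amounts, so vqrshlq_s16 performs a saturating
 * rounding right shift. */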
static INLINE int16x8_t
convolve8_8x8_s16(const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
                  const int16x8_t s3, const int16x8_t s4, const int16x8_t s5,
                  const int16x8_t s6, const int16x8_t s7, const int16_t *filter,
                  const int16x8_t horiz_const, const int16x8_t shift_round_0) {
  int16x8_t sum;
  int16x8_t res;

  sum = horiz_const;
  sum = vmlaq_n_s16(sum, s0, filter[0]);
  sum = vmlaq_n_s16(sum, s1, filter[1]);
  sum = vmlaq_n_s16(sum, s2, filter[2]);
  sum = vmlaq_n_s16(sum, s3, filter[3]);
  sum = vmlaq_n_s16(sum, s4, filter[4]);
  sum = vmlaq_n_s16(sum, s5, filter[5]);
  sum = vmlaq_n_s16(sum, s6, filter[6]);
  sum = vmlaq_n_s16(sum, s7, filter[7]);

  res = vqrshlq_s16(sum, shift_round_0);

  return res;
}

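/* Four-pixel variant of convolve8_8x8_s16 above. */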
static INLINE int16x4_t
convolve8_4x4_s16(const int16x4_t s0, const int16x4_t s1, const int16x4_t s2,
                  const int16x4_t s3, const int16x4_t s4, const int16x4_t s5,
                  const int16x4_t s6, const int16x4_t s7, const int16_t *filter,
                  const int16x4_t horiz_const, const int16x4_t shift_round_0) {
  int16x4_t sum;
  sum = horiz_const;
  sum = vmla_n_s16(sum, s0, filter[0]);
  sum = vmla_n_s16(sum, s1, filter[1]);
  sum = vmla_n_s16(sum, s2, filter[2]);
  sum = vmla_n_s16(sum, s3, filter[3]);
  sum = vmla_n_s16(sum, s4, filter[4]);
  sum = vmla_n_s16(sum, s5, filter[5]);
  sum = vmla_n_s16(sum, s6, filter[6]);
  sum = vmla_n_s16(sum, s7, filter[7]);

  sum = vqrshl_s16(sum, shift_round_0);

  return sum;
}

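/* 8-tap convolution of four pixels accumulated in 32 bits to avoid the
 * 16-bit overflow of the *_s16 variants; the y_filter name reflects its
 * use in vertical passes. Adds offset_const, applies a saturating rounding
 * right shift (round_shift_vec holds negated shift amounts), clamps
 * negative values to zero and narrows to 16 bits. */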
static INLINE uint16x4_t convolve8_4x4_s32(
    const int16x4_t s0, const int16x4_t s1, const int16x4_t s2,
    const int16x4_t s3, const int16x4_t s4, const int16x4_t s5,
    const int16x4_t s6, const int16x4_t s7, const int16_t *y_filter,
    const int32x4_t round_shift_vec, const int32x4_t offset_const) {
  int32x4_t sum0;
  uint16x4_t res;
  const int32x4_t zero = vdupq_n_s32(0);

  sum0 = vmull_n_s16(s0, y_filter[0]);
  sum0 = vmlal_n_s16(sum0, s1, y_filter[1]);
  sum0 = vmlal_n_s16(sum0, s2, y_filter[2]);
  sum0 = vmlal_n_s16(sum0, s3, y_filter[3]);
  sum0 = vmlal_n_s16(sum0, s4, y_filter[4]);
  sum0 = vmlal_n_s16(sum0, s5, y_filter[5]);
  sum0 = vmlal_n_s16(sum0, s6, y_filter[6]);
  sum0 = vmlal_n_s16(sum0, s7, y_filter[7]);

  sum0 = vaddq_s32(sum0, offset_const);
  sum0 = vqrshlq_s32(sum0, round_shift_vec);
  sum0 = vmaxq_s32(sum0, zero);
  res = vmovn_u32(vreinterpretq_u32_s32(sum0));

  return res;
}

#endif  // AOM_AV1_COMMON_ARM_CONVOLVE_NEON_H_