/*
 *  Copyright (c) 2012 The WebM project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include <immintrin.h>  // AVX2

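// Computes, for a 16x16 block, the sum of the byte-wise source/reference
// differences (*Sum) and the sum of their squared differences (*SSE).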
void vp9_get16x16var_avx2(const unsigned char *src_ptr,
                          int source_stride,
                          const unsigned char *ref_ptr,
                          int recon_stride,
                          unsigned int *SSE,
                          int *Sum) {
    __m256i src, src_expand_low, src_expand_high, ref, ref_expand_low;
    __m256i ref_expand_high, madd_low, madd_high;
    unsigned int i, src_2strides, ref_2strides;
    __m256i zero_reg = _mm256_set1_epi16(0);
    __m256i sum_ref_src = _mm256_set1_epi16(0);
    __m256i madd_ref_src = _mm256_set1_epi16(0);

    // Process two rows per 256-bit register, halving the number of loop
    // iterations compared to the SSE2 code.
    src_2strides = source_stride << 1;
    ref_2strides = recon_stride << 1;
    for (i = 0; i < 8; i++) {
        // load two consecutive 16-byte rows into one 256-bit register
        src = _mm256_castsi128_si256(
              _mm_loadu_si128((__m128i const *) (src_ptr)));
        src = _mm256_inserti128_si256(src,
              _mm_loadu_si128((__m128i const *)(src_ptr + source_stride)), 1);

        ref = _mm256_castsi128_si256(
              _mm_loadu_si128((__m128i const *) (ref_ptr)));
        ref = _mm256_inserti128_si256(ref,
              _mm_loadu_si128((__m128i const *)(ref_ptr + recon_stride)), 1);

        // zero-extend each 8-bit lane to 16 bits
        src_expand_low = _mm256_unpacklo_epi8(src, zero_reg);
        src_expand_high = _mm256_unpackhi_epi8(src, zero_reg);

        ref_expand_low = _mm256_unpacklo_epi8(ref, zero_reg);
        ref_expand_high = _mm256_unpackhi_epi8(ref, zero_reg);

        // src - ref
        src_expand_low = _mm256_sub_epi16(src_expand_low, ref_expand_low);
        src_expand_high = _mm256_sub_epi16(src_expand_high, ref_expand_high);

        // (src - ref)^2 for the low half, adjacent pairs summed into 32 bits
        madd_low = _mm256_madd_epi16(src_expand_low, src_expand_low);

        // fold the high-half differences onto the low half
        src_expand_low = _mm256_add_epi16(src_expand_low, src_expand_high);

        // (src - ref)^2 for the high half, adjacent pairs summed into 32 bits
        madd_high = _mm256_madd_epi16(src_expand_high, src_expand_high);

        sum_ref_src = _mm256_add_epi16(sum_ref_src, src_expand_low);

        // accumulate both squared halves
        madd_ref_src = _mm256_add_epi32(madd_ref_src,
                       _mm256_add_epi32(madd_low, madd_high));

        src_ptr += src_2strides;
        ref_ptr += ref_2strides;
    }

    {
        __m128i sum_res, madd_res;
        __m128i expand_sum_low, expand_sum_high, expand_sum;
        __m128i expand_madd_low, expand_madd_high, expand_madd;
        __m128i ex_expand_sum_low, ex_expand_sum_high, ex_expand_sum;

        // add the high 128-bit lane to the low lane
        sum_res = _mm_add_epi16(_mm256_castsi256_si128(sum_ref_src),
                                _mm256_extractf128_si256(sum_ref_src, 1));

        madd_res = _mm_add_epi32(_mm256_castsi256_si128(madd_ref_src),
                                 _mm256_extractf128_si256(madd_ref_src, 1));

        // widen each signed 16-bit sum to 32 bits: interleave zeros below,
        // then arithmetic-shift right by 16 to sign-extend
        expand_sum_low = _mm_unpacklo_epi16(_mm256_castsi256_si128(zero_reg),
                                            sum_res);
        expand_sum_high = _mm_unpackhi_epi16(_mm256_castsi256_si128(zero_reg),
                                             sum_res);

        expand_sum_low = _mm_srai_epi32(expand_sum_low, 16);
        expand_sum_high = _mm_srai_epi32(expand_sum_high, 16);

        expand_sum = _mm_add_epi32(expand_sum_low, expand_sum_high);

        // expand each 32 bits of the madd result to 64 bits
        expand_madd_low = _mm_unpacklo_epi32(madd_res,
                          _mm256_castsi256_si128(zero_reg));
        expand_madd_high = _mm_unpackhi_epi32(madd_res,
                           _mm256_castsi256_si128(zero_reg));

        expand_madd = _mm_add_epi32(expand_madd_low, expand_madd_high);

        ex_expand_sum_low = _mm_unpacklo_epi32(expand_sum,
                            _mm256_castsi256_si128(zero_reg));
        ex_expand_sum_high = _mm_unpackhi_epi32(expand_sum,
                             _mm256_castsi256_si128(zero_reg));

        ex_expand_sum = _mm_add_epi32(ex_expand_sum_low, ex_expand_sum_high);

        // shift right by 8 bytes and add, folding the upper half onto the
        // lower half
        madd_res = _mm_srli_si128(expand_madd, 8);
        sum_res = _mm_srli_si128(ex_expand_sum, 8);

        madd_res = _mm_add_epi32(madd_res, expand_madd);
        sum_res = _mm_add_epi32(sum_res, ex_expand_sum);

        *((int*)SSE) = _mm_cvtsi128_si32(madd_res);

        *((int*)Sum) = _mm_cvtsi128_si32(sum_res);
    }
}
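
// Usage sketch (illustrative): the block variance is typically derived from
// the two outputs as variance = SSE - Sum^2 / N, with N the pixel count
// (256 for a 16x16 block). Assuming <stdint.h> for int64_t:
//
//   unsigned int sse;
//   int sum;
//   vp9_get16x16var_avx2(src, src_stride, ref, ref_stride, &sse, &sum);
//   variance = sse - (unsigned int)(((int64_t)sum * sum) >> 8);
//
// (>> 8 is an exact divide by 256 since sum * sum is non-negative)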
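// Computes the same Sum/SSE pair for a 32-pixel-wide block. Note the loop
// covers 16 rows per call (a 32x16 region); the caller presumably invokes
// it once per 16-row strip and accumulates the results.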
void vp9_get32x32var_avx2(const unsigned char *src_ptr,
                          int source_stride,
                          const unsigned char *ref_ptr,
                          int recon_stride,
                          unsigned int *SSE,
                          int *Sum) {
    __m256i src, src_expand_low, src_expand_high, ref, ref_expand_low;
    __m256i ref_expand_high, madd_low, madd_high;
    unsigned int i;
    __m256i zero_reg = _mm256_set1_epi16(0);
    __m256i sum_ref_src = _mm256_set1_epi16(0);
    __m256i madd_ref_src = _mm256_set1_epi16(0);

    // process one full 32-pixel row per iteration
    for (i = 0; i < 16; i++) {
        src = _mm256_loadu_si256((__m256i const *) (src_ptr));

        ref = _mm256_loadu_si256((__m256i const *) (ref_ptr));

        // zero-extend each 8-bit lane to 16 bits
        src_expand_low = _mm256_unpacklo_epi8(src, zero_reg);
        src_expand_high = _mm256_unpackhi_epi8(src, zero_reg);

        ref_expand_low = _mm256_unpacklo_epi8(ref, zero_reg);
        ref_expand_high = _mm256_unpackhi_epi8(ref, zero_reg);

        // src - ref
        src_expand_low = _mm256_sub_epi16(src_expand_low, ref_expand_low);
        src_expand_high = _mm256_sub_epi16(src_expand_high, ref_expand_high);

        // (src - ref)^2 for the low half, adjacent pairs summed into 32 bits
        madd_low = _mm256_madd_epi16(src_expand_low, src_expand_low);

        // fold the high-half differences onto the low half
        src_expand_low = _mm256_add_epi16(src_expand_low, src_expand_high);

        // (src - ref)^2 for the high half, adjacent pairs summed into 32 bits
        madd_high = _mm256_madd_epi16(src_expand_high, src_expand_high);

        sum_ref_src = _mm256_add_epi16(sum_ref_src, src_expand_low);

        // accumulate both squared halves
        madd_ref_src = _mm256_add_epi32(madd_ref_src,
                       _mm256_add_epi32(madd_low, madd_high));

        src_ptr += source_stride;
        ref_ptr += recon_stride;
    }

    {
        __m256i expand_sum_low, expand_sum_high, expand_sum;
        __m256i expand_madd_low, expand_madd_high, expand_madd;
        __m256i ex_expand_sum_low, ex_expand_sum_high, ex_expand_sum;

        // widen each signed 16-bit sum to 32 bits: interleave zeros below,
        // then arithmetic-shift right by 16 to sign-extend
        expand_sum_low = _mm256_unpacklo_epi16(zero_reg, sum_ref_src);
        expand_sum_high = _mm256_unpackhi_epi16(zero_reg, sum_ref_src);

        expand_sum_low = _mm256_srai_epi32(expand_sum_low, 16);
        expand_sum_high = _mm256_srai_epi32(expand_sum_high, 16);

        expand_sum = _mm256_add_epi32(expand_sum_low, expand_sum_high);

        // expand each 32 bits of the madd result to 64 bits
        expand_madd_low = _mm256_unpacklo_epi32(madd_ref_src, zero_reg);
        expand_madd_high = _mm256_unpackhi_epi32(madd_ref_src, zero_reg);

        expand_madd = _mm256_add_epi32(expand_madd_low, expand_madd_high);

        ex_expand_sum_low = _mm256_unpacklo_epi32(expand_sum, zero_reg);
        ex_expand_sum_high = _mm256_unpackhi_epi32(expand_sum, zero_reg);

        ex_expand_sum = _mm256_add_epi32(ex_expand_sum_low, ex_expand_sum_high);

        // shift right by 8 bytes and add; _mm256_srli_si256 shifts within
        // each 128-bit lane, so this folds each lane's upper half onto its
        // lower half
        madd_ref_src = _mm256_srli_si256(expand_madd, 8);
        sum_ref_src = _mm256_srli_si256(ex_expand_sum, 8);

        madd_ref_src = _mm256_add_epi32(madd_ref_src, expand_madd);
        sum_ref_src = _mm256_add_epi32(sum_ref_src, ex_expand_sum);

        // extract the low lane and the high lane and add the results
        *((int*)SSE) = _mm_cvtsi128_si32(_mm256_castsi256_si128(madd_ref_src)) +
                       _mm_cvtsi128_si32(_mm256_extractf128_si256(madd_ref_src, 1));

        *((int*)Sum) = _mm_cvtsi128_si32(_mm256_castsi256_si128(sum_ref_src)) +
                       _mm_cvtsi128_si32(_mm256_extractf128_si256(sum_ref_src, 1));
    }
}
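
// For reference, a plain-C sketch of what both kernels compute over a
// w x h block (illustrative only, assuming the same argument conventions):
//
//   static void get_var_c(const unsigned char *src, int src_stride,
//                         const unsigned char *ref, int ref_stride,
//                         int w, int h, unsigned int *sse, int *sum) {
//     int i, j;
//     *sse = 0;
//     *sum = 0;
//     for (i = 0; i < h; ++i) {
//       for (j = 0; j < w; ++j) {
//         const int diff = src[j] - ref[j];
//         *sum += diff;
//         *sse += (unsigned int)(diff * diff);
//       }
//       src += src_stride;
//       ref += ref_stride;
//     }
//   }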