/*
 *  Copyright (c) 2014 The WebM project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include <arm_neon.h>
#include <assert.h>

#include "./vpx_dsp_rtcd.h"
#include "./vpx_config.h"

#include "vpx/vpx_integer.h"
#include "vpx_dsp/arm/mem_neon.h"
#include "vpx_ports/mem.h"

static INLINE int horizontal_add_s16x8(const int16x8_t v_16x8) {
  const int32x4_t a = vpaddlq_s16(v_16x8);
  const int64x2_t b = vpaddlq_s32(a);
  const int32x2_t c = vadd_s32(vreinterpret_s32_s64(vget_low_s64(b)),
                               vreinterpret_s32_s64(vget_high_s64(b)));
  return vget_lane_s32(c, 0);
}

static INLINE int horizontal_add_s32x4(const int32x4_t v_32x4) {
  const int64x2_t b = vpaddlq_s32(v_32x4);
  const int32x2_t c = vadd_s32(vreinterpret_s32_s64(vget_low_s64(b)),
                               vreinterpret_s32_s64(vget_high_s64(b)));
  return vget_lane_s32(c, 0);
}

// The variance helper functions use int16_t for sum. 8 values are accumulated
// and then added (at which point they expand up to int32_t). To avoid
// overflow, there can be no more than 32767 / 255 ~= 128 values accumulated
// in each column. For a 32x32 buffer, this results in 32 / 8 = 4 values per
// row * 32 rows = 128. Asserts have been added to each function to warn
// against reaching this limit.

// Process a block of width 4 four rows at a time.
static void variance_neon_w4x4(const uint8_t *a, int a_stride, const uint8_t *b,
                               int b_stride, int h, uint32_t *sse, int *sum) {
  int i;
  int16x8_t sum_s16 = vdupq_n_s16(0);
  int32x4_t sse_lo_s32 = vdupq_n_s32(0);
  int32x4_t sse_hi_s32 = vdupq_n_s32(0);

  // Since the width is only 4, each lane of sum_s16 accumulates two values per
  // loop iteration (h / 2 values in total), so h may be at most 256.
  assert(h <= 256);

  for (i = 0; i < h; i += 4) {
    const uint8x16_t a_u8 = load_unaligned_u8q(a, a_stride);
    const uint8x16_t b_u8 = load_unaligned_u8q(b, b_stride);
    const uint16x8_t diff_lo_u16 =
        vsubl_u8(vget_low_u8(a_u8), vget_low_u8(b_u8));
    const uint16x8_t diff_hi_u16 =
        vsubl_u8(vget_high_u8(a_u8), vget_high_u8(b_u8));

    const int16x8_t diff_lo_s16 = vreinterpretq_s16_u16(diff_lo_u16);
    const int16x8_t diff_hi_s16 = vreinterpretq_s16_u16(diff_hi_u16);

    sum_s16 = vaddq_s16(sum_s16, diff_lo_s16);
    sum_s16 = vaddq_s16(sum_s16, diff_hi_s16);

    sse_lo_s32 = vmlal_s16(sse_lo_s32, vget_low_s16(diff_lo_s16),
                           vget_low_s16(diff_lo_s16));
    sse_lo_s32 = vmlal_s16(sse_lo_s32, vget_high_s16(diff_lo_s16),
                           vget_high_s16(diff_lo_s16));

    sse_hi_s32 = vmlal_s16(sse_hi_s32, vget_low_s16(diff_hi_s16),
                           vget_low_s16(diff_hi_s16));
    sse_hi_s32 = vmlal_s16(sse_hi_s32, vget_high_s16(diff_hi_s16),
                           vget_high_s16(diff_hi_s16));

    a += 4 * a_stride;
    b += 4 * b_stride;
  }

  *sum = horizontal_add_s16x8(sum_s16);
  *sse = (uint32_t)horizontal_add_s32x4(vaddq_s32(sse_lo_s32, sse_hi_s32));
}
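
// For reference, the NEON helpers in this file accumulate the same two
// quantities as the scalar sketch below: the sum of pixel differences and the
// sum of squared differences. This is illustrative only and is not part of
// the build; the name variance_ref_c is made up here.
//
//   static void variance_ref_c(const uint8_t *a, int a_stride,
//                              const uint8_t *b, int b_stride, int w, int h,
//                              uint32_t *sse, int *sum) {
//     int r, c;
//     *sse = 0;
//     *sum = 0;
//     for (r = 0; r < h; ++r) {
//       for (c = 0; c < w; ++c) {
//         const int diff = a[c] - b[c];
//         *sum += diff;
//         *sse += (uint32_t)(diff * diff);
//       }
//       a += a_stride;
//       b += b_stride;
//     }
//   }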

// Process a block of any size where the width is divisible by 16.
static void variance_neon_w16(const uint8_t *a, int a_stride, const uint8_t *b,
                              int b_stride, int w, int h, uint32_t *sse,
                              int *sum) {
  int i, j;
  int16x8_t sum_s16 = vdupq_n_s16(0);
  int32x4_t sse_lo_s32 = vdupq_n_s32(0);
  int32x4_t sse_hi_s32 = vdupq_n_s32(0);

  // The loop loads 16 values at a time but doubles them up when accumulating
  // into sum_s16, so each lane receives w / 8 values per row.
  assert(w / 8 * h <= 128);

  for (i = 0; i < h; ++i) {
    for (j = 0; j < w; j += 16) {
      const uint8x16_t a_u8 = vld1q_u8(a + j);
      const uint8x16_t b_u8 = vld1q_u8(b + j);

      const uint16x8_t diff_lo_u16 =
          vsubl_u8(vget_low_u8(a_u8), vget_low_u8(b_u8));
      const uint16x8_t diff_hi_u16 =
          vsubl_u8(vget_high_u8(a_u8), vget_high_u8(b_u8));

      const int16x8_t diff_lo_s16 = vreinterpretq_s16_u16(diff_lo_u16);
      const int16x8_t diff_hi_s16 = vreinterpretq_s16_u16(diff_hi_u16);

      sum_s16 = vaddq_s16(sum_s16, diff_lo_s16);
      sum_s16 = vaddq_s16(sum_s16, diff_hi_s16);

      sse_lo_s32 = vmlal_s16(sse_lo_s32, vget_low_s16(diff_lo_s16),
                             vget_low_s16(diff_lo_s16));
      sse_lo_s32 = vmlal_s16(sse_lo_s32, vget_high_s16(diff_lo_s16),
                             vget_high_s16(diff_lo_s16));

      sse_hi_s32 = vmlal_s16(sse_hi_s32, vget_low_s16(diff_hi_s16),
                             vget_low_s16(diff_hi_s16));
      sse_hi_s32 = vmlal_s16(sse_hi_s32, vget_high_s16(diff_hi_s16),
                             vget_high_s16(diff_hi_s16));
    }
    a += a_stride;
    b += b_stride;
  }

  *sum = horizontal_add_s16x8(sum_s16);
  *sse = (uint32_t)horizontal_add_s32x4(vaddq_s32(sse_lo_s32, sse_hi_s32));
}
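
// Worked example of the accumulator limit above (informal; it follows from the
// overflow comment near the top of the file): each int16_t lane can hold at
// most 32767 / 255 ~= 128 differences. A 32x32 block gives 32 / 8 * 32 = 128
// values per lane, exactly at the limit, while a full 64x64 block would give
// 64 / 8 * 64 = 512. This is why the 32x64, 64x32 and 64x64 variances further
// down split the block into slices that stay at or below the limit (two 32x32
// halves, or 64x16 slices with 64 / 8 * 16 = 128) and add the partial results
// in 32-bit variables.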

// Process a block of width 8 two rows at a time.
static void variance_neon_w8x2(const uint8_t *a, int a_stride, const uint8_t *b,
                               int b_stride, int h, uint32_t *sse, int *sum) {
  int i = 0;
  int16x8_t sum_s16 = vdupq_n_s16(0);
  int32x4_t sse_lo_s32 = vdupq_n_s32(0);
  int32x4_t sse_hi_s32 = vdupq_n_s32(0);

  // Each column has its own accumulator entry in sum_s16.
  assert(h <= 128);

  do {
    const uint8x8_t a_0_u8 = vld1_u8(a);
    const uint8x8_t a_1_u8 = vld1_u8(a + a_stride);
    const uint8x8_t b_0_u8 = vld1_u8(b);
    const uint8x8_t b_1_u8 = vld1_u8(b + b_stride);
    const uint16x8_t diff_0_u16 = vsubl_u8(a_0_u8, b_0_u8);
    const uint16x8_t diff_1_u16 = vsubl_u8(a_1_u8, b_1_u8);
    const int16x8_t diff_0_s16 = vreinterpretq_s16_u16(diff_0_u16);
    const int16x8_t diff_1_s16 = vreinterpretq_s16_u16(diff_1_u16);
    sum_s16 = vaddq_s16(sum_s16, diff_0_s16);
    sum_s16 = vaddq_s16(sum_s16, diff_1_s16);
    sse_lo_s32 = vmlal_s16(sse_lo_s32, vget_low_s16(diff_0_s16),
                           vget_low_s16(diff_0_s16));
    sse_lo_s32 = vmlal_s16(sse_lo_s32, vget_low_s16(diff_1_s16),
                           vget_low_s16(diff_1_s16));
    sse_hi_s32 = vmlal_s16(sse_hi_s32, vget_high_s16(diff_0_s16),
                           vget_high_s16(diff_0_s16));
    sse_hi_s32 = vmlal_s16(sse_hi_s32, vget_high_s16(diff_1_s16),
                           vget_high_s16(diff_1_s16));
    a += a_stride + a_stride;
    b += b_stride + b_stride;
    i += 2;
  } while (i < h);

  *sum = horizontal_add_s16x8(sum_s16);
  *sse = (uint32_t)horizontal_add_s32x4(vaddq_s32(sse_lo_s32, sse_hi_s32));
}

void vpx_get8x8var_neon(const uint8_t *a, int a_stride, const uint8_t *b,
                        int b_stride, unsigned int *sse, int *sum) {
  variance_neon_w8x2(a, a_stride, b, b_stride, 8, sse, sum);
}

void vpx_get16x16var_neon(const uint8_t *a, int a_stride, const uint8_t *b,
                          int b_stride, unsigned int *sse, int *sum) {
  variance_neon_w16(a, a_stride, b, b_stride, 16, 16, sse, sum);
}

#define varianceNxM(n, m, shift)                                            \
  unsigned int vpx_variance##n##x##m##_neon(const uint8_t *a, int a_stride, \
                                            const uint8_t *b, int b_stride, \
                                            unsigned int *sse) {            \
    int sum;                                                                \
    if (n == 4)                                                             \
      variance_neon_w4x4(a, a_stride, b, b_stride, m, sse, &sum);           \
    else if (n == 8)                                                        \
      variance_neon_w8x2(a, a_stride, b, b_stride, m, sse, &sum);           \
    else                                                                    \
      variance_neon_w16(a, a_stride, b, b_stride, n, m, sse, &sum);         \
    if (n * m < 16 * 16)                                                    \
      return *sse - ((sum * sum) >> shift);                                 \
    else                                                                    \
      return *sse - (uint32_t)(((int64_t)sum * sum) >> shift);              \
  }

varianceNxM(4, 4, 4);
varianceNxM(4, 8, 5);
varianceNxM(8, 4, 5);
varianceNxM(8, 8, 6);
varianceNxM(8, 16, 7);
varianceNxM(16, 8, 7);
varianceNxM(16, 16, 8);
varianceNxM(16, 32, 9);
varianceNxM(32, 16, 9);
varianceNxM(32, 32, 10);
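
// Informal note on the shift values used above and below: for an N x M block,
// variance = SSE - (sum * sum) / (N * M), and since every block size here is
// a power of two the division is done with a shift of log2(N * M). For
// example, 16x16 -> 256 = 2^8 (shift 8), 32x64 and 64x32 -> 2048 = 2^11, and
// 64x64 -> 4096 = 2^12, matching the explicit >> 11 and >> 12 below.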

unsigned int vpx_variance32x64_neon(const uint8_t *a, int a_stride,
                                    const uint8_t *b, int b_stride,
                                    unsigned int *sse) {
  int sum1, sum2;
  uint32_t sse1, sse2;
  variance_neon_w16(a, a_stride, b, b_stride, 32, 32, &sse1, &sum1);
  variance_neon_w16(a + (32 * a_stride), a_stride, b + (32 * b_stride),
                    b_stride, 32, 32, &sse2, &sum2);
  *sse = sse1 + sse2;
  sum1 += sum2;
  return *sse - (unsigned int)(((int64_t)sum1 * sum1) >> 11);
}

unsigned int vpx_variance64x32_neon(const uint8_t *a, int a_stride,
                                    const uint8_t *b, int b_stride,
                                    unsigned int *sse) {
  int sum1, sum2;
  uint32_t sse1, sse2;
  variance_neon_w16(a, a_stride, b, b_stride, 64, 16, &sse1, &sum1);
  variance_neon_w16(a + (16 * a_stride), a_stride, b + (16 * b_stride),
                    b_stride, 64, 16, &sse2, &sum2);
  *sse = sse1 + sse2;
  sum1 += sum2;
  return *sse - (unsigned int)(((int64_t)sum1 * sum1) >> 11);
}

unsigned int vpx_variance64x64_neon(const uint8_t *a, int a_stride,
                                    const uint8_t *b, int b_stride,
                                    unsigned int *sse) {
  int sum1, sum2;
  uint32_t sse1, sse2;

  variance_neon_w16(a, a_stride, b, b_stride, 64, 16, &sse1, &sum1);
  variance_neon_w16(a + (16 * a_stride), a_stride, b + (16 * b_stride),
                    b_stride, 64, 16, &sse2, &sum2);
  sse1 += sse2;
  sum1 += sum2;

  variance_neon_w16(a + (16 * 2 * a_stride), a_stride, b + (16 * 2 * b_stride),
                    b_stride, 64, 16, &sse2, &sum2);
  sse1 += sse2;
  sum1 += sum2;

  variance_neon_w16(a + (16 * 3 * a_stride), a_stride, b + (16 * 3 * b_stride),
                    b_stride, 64, 16, &sse2, &sum2);
  *sse = sse1 + sse2;
  sum1 += sum2;
  return *sse - (unsigned int)(((int64_t)sum1 * sum1) >> 12);
}

unsigned int vpx_mse16x16_neon(const unsigned char *src_ptr, int source_stride,
                               const unsigned char *ref_ptr, int recon_stride,
                               unsigned int *sse) {
  int i;
  int16x4_t d22s16, d23s16, d24s16, d25s16, d26s16, d27s16, d28s16, d29s16;
  int64x1_t d0s64;
  uint8x16_t q0u8, q1u8, q2u8, q3u8;
  int32x4_t q7s32, q8s32, q9s32, q10s32;
  uint16x8_t q11u16, q12u16, q13u16, q14u16;
  int64x2_t q1s64;

  q7s32 = vdupq_n_s32(0);
  q8s32 = vdupq_n_s32(0);
  q9s32 = vdupq_n_s32(0);
  q10s32 = vdupq_n_s32(0);

  for (i = 0; i < 8; i++) {  // mse16x16_neon_loop
    q0u8 = vld1q_u8(src_ptr);
    src_ptr += source_stride;
    q1u8 = vld1q_u8(src_ptr);
    src_ptr += source_stride;
    q2u8 = vld1q_u8(ref_ptr);
    ref_ptr += recon_stride;
    q3u8 = vld1q_u8(ref_ptr);
    ref_ptr += recon_stride;

    q11u16 = vsubl_u8(vget_low_u8(q0u8), vget_low_u8(q2u8));
    q12u16 = vsubl_u8(vget_high_u8(q0u8), vget_high_u8(q2u8));
    q13u16 = vsubl_u8(vget_low_u8(q1u8), vget_low_u8(q3u8));
    q14u16 = vsubl_u8(vget_high_u8(q1u8), vget_high_u8(q3u8));

    d22s16 = vreinterpret_s16_u16(vget_low_u16(q11u16));
    d23s16 = vreinterpret_s16_u16(vget_high_u16(q11u16));
    q7s32 = vmlal_s16(q7s32, d22s16, d22s16);
    q8s32 = vmlal_s16(q8s32, d23s16, d23s16);

    d24s16 = vreinterpret_s16_u16(vget_low_u16(q12u16));
    d25s16 = vreinterpret_s16_u16(vget_high_u16(q12u16));
    q9s32 = vmlal_s16(q9s32, d24s16, d24s16);
    q10s32 = vmlal_s16(q10s32, d25s16, d25s16);

    d26s16 = vreinterpret_s16_u16(vget_low_u16(q13u16));
    d27s16 = vreinterpret_s16_u16(vget_high_u16(q13u16));
    q7s32 = vmlal_s16(q7s32, d26s16, d26s16);
    q8s32 = vmlal_s16(q8s32, d27s16, d27s16);

    d28s16 = vreinterpret_s16_u16(vget_low_u16(q14u16));
    d29s16 = vreinterpret_s16_u16(vget_high_u16(q14u16));
    q9s32 = vmlal_s16(q9s32, d28s16, d28s16);
    q10s32 = vmlal_s16(q10s32, d29s16, d29s16);
  }

  q7s32 = vaddq_s32(q7s32, q8s32);
  q9s32 = vaddq_s32(q9s32, q10s32);
  q10s32 = vaddq_s32(q7s32, q9s32);

  q1s64 = vpaddlq_s32(q10s32);
  d0s64 = vadd_s64(vget_low_s64(q1s64), vget_high_s64(q1s64));

  vst1_lane_u32((uint32_t *)sse, vreinterpret_u32_s64(d0s64), 0);
  return vget_lane_u32(vreinterpret_u32_s64(d0s64), 0);
}
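
// Example call sites (illustrative only; src, ref and the strides below are
// placeholders, not names used elsewhere in this file):
//
//   unsigned int sse, var, mse;
//   var = vpx_variance16x16_neon(src, src_stride, ref, ref_stride, &sse);
//   mse = vpx_mse16x16_neon(src, src_stride, ref, ref_stride, &sse);
//
// vpx_mse16x16_neon() returns the sum of squared differences directly, i.e.
// the variance without the (sum * sum) / 256 mean correction.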

// Compute the sum of squared differences for a 4x4 block. Only the low four
// lanes of each 8-wide load contribute to the result.
unsigned int vpx_get4x4sse_cs_neon(const unsigned char *src_ptr,
                                   int source_stride,
                                   const unsigned char *ref_ptr,
                                   int recon_stride) {
  int16x4_t d22s16, d24s16, d26s16, d28s16;
  int64x1_t d0s64;
  uint8x8_t d0u8, d1u8, d2u8, d3u8, d4u8, d5u8, d6u8, d7u8;
  int32x4_t q7s32, q8s32, q9s32, q10s32;
  uint16x8_t q11u16, q12u16, q13u16, q14u16;
  int64x2_t q1s64;

  d0u8 = vld1_u8(src_ptr);
  src_ptr += source_stride;
  d4u8 = vld1_u8(ref_ptr);
  ref_ptr += recon_stride;
  d1u8 = vld1_u8(src_ptr);
  src_ptr += source_stride;
  d5u8 = vld1_u8(ref_ptr);
  ref_ptr += recon_stride;
  d2u8 = vld1_u8(src_ptr);
  src_ptr += source_stride;
  d6u8 = vld1_u8(ref_ptr);
  ref_ptr += recon_stride;
  d3u8 = vld1_u8(src_ptr);
  src_ptr += source_stride;
  d7u8 = vld1_u8(ref_ptr);
  ref_ptr += recon_stride;

  q11u16 = vsubl_u8(d0u8, d4u8);
  q12u16 = vsubl_u8(d1u8, d5u8);
  q13u16 = vsubl_u8(d2u8, d6u8);
  q14u16 = vsubl_u8(d3u8, d7u8);

  d22s16 = vget_low_s16(vreinterpretq_s16_u16(q11u16));
  d24s16 = vget_low_s16(vreinterpretq_s16_u16(q12u16));
  d26s16 = vget_low_s16(vreinterpretq_s16_u16(q13u16));
  d28s16 = vget_low_s16(vreinterpretq_s16_u16(q14u16));

  q7s32 = vmull_s16(d22s16, d22s16);
  q8s32 = vmull_s16(d24s16, d24s16);
  q9s32 = vmull_s16(d26s16, d26s16);
  q10s32 = vmull_s16(d28s16, d28s16);

  q7s32 = vaddq_s32(q7s32, q8s32);
  q9s32 = vaddq_s32(q9s32, q10s32);
  q9s32 = vaddq_s32(q7s32, q9s32);

  q1s64 = vpaddlq_s32(q9s32);
  d0s64 = vadd_s64(vget_low_s64(q1s64), vget_high_s64(q1s64));

  return vget_lane_u32(vreinterpret_u32_s64(d0s64), 0);
}