Home | History | Annotate | Download | only in test
      1 /*
      2  *  Copyright (c) 2016 The WebM project authors. All Rights Reserved.
      3  *
      4  *  Use of this source code is governed by a BSD-style license
      5  *  that can be found in the LICENSE file in the root of the source
      6  *  tree. An additional intellectual property rights grant can be found
      7  *  in the file PATENTS.  All contributing project authors may
      8  *  be found in the AUTHORS file in the root of the source tree.
      9  */
     10 
     11 #include <limits>
     12 
     13 #include "third_party/googletest/src/include/gtest/gtest.h"
     14 
     15 #include "./vp9_rtcd.h"
     16 #include "test/acm_random.h"
     17 #include "test/buffer.h"
     18 #include "test/register_state_check.h"
     19 #include "vpx_ports/vpx_timer.h"
     20 
     21 namespace {
     22 
     23 using ::libvpx_test::ACMRandom;
     24 using ::libvpx_test::Buffer;
     25 
     26 typedef void (*TemporalFilterFunc)(const uint8_t *a, unsigned int stride,
     27                                    const uint8_t *b, unsigned int w,
     28                                    unsigned int h, int filter_strength,
     29                                    int filter_weight, unsigned int *accumulator,
     30                                    uint16_t *count);
     31 
     32 // Calculate the difference between 'a' and 'b', sum in blocks of 9, and apply
     33 // filter based on strength and weight. Store the resulting filter amount in
     34 // 'count' and apply it to 'b' and store it in 'accumulator'.
     35 void reference_filter(const Buffer<uint8_t> &a, const Buffer<uint8_t> &b, int w,
     36                       int h, int filter_strength, int filter_weight,
     37                       Buffer<unsigned int> *accumulator,
     38                       Buffer<uint16_t> *count) {
     39   Buffer<int> diff_sq = Buffer<int>(w, h, 0);
     40   ASSERT_TRUE(diff_sq.Init());
     41   diff_sq.Set(0);
     42 
     43   int rounding = 0;
     44   if (filter_strength > 0) {
     45     rounding = 1 << (filter_strength - 1);
     46   }
     47 
     48   // Calculate all the differences. Avoids re-calculating a bunch of extra
     49   // values.
     50   for (int height = 0; height < h; ++height) {
     51     for (int width = 0; width < w; ++width) {
     52       int diff = a.TopLeftPixel()[height * a.stride() + width] -
     53                  b.TopLeftPixel()[height * b.stride() + width];
     54       diff_sq.TopLeftPixel()[height * diff_sq.stride() + width] = diff * diff;
     55     }
     56   }
     57 
     58   // For any given point, sum the neighboring values and calculate the
     59   // modifier.
     60   for (int height = 0; height < h; ++height) {
     61     for (int width = 0; width < w; ++width) {
     62       // Determine how many values are being summed.
     63       int summed_values = 9;
     64 
     65       if (height == 0 || height == (h - 1)) {
     66         summed_values -= 3;
     67       }
     68 
     69       if (width == 0 || width == (w - 1)) {
     70         if (summed_values == 6) {  // corner
     71           summed_values -= 2;
     72         } else {
     73           summed_values -= 3;
     74         }
     75       }
     76 
     77       // Sum the diff_sq of the surrounding values.
     78       int sum = 0;
     79       for (int idy = -1; idy <= 1; ++idy) {
     80         for (int idx = -1; idx <= 1; ++idx) {
     81           const int y = height + idy;
     82           const int x = width + idx;
     83 
     84           // If inside the border.
     85           if (y >= 0 && y < h && x >= 0 && x < w) {
     86             sum += diff_sq.TopLeftPixel()[y * diff_sq.stride() + x];
     87           }
     88         }
     89       }
     90 
     91       sum *= 3;
     92       sum /= summed_values;
     93       sum += rounding;
     94       sum >>= filter_strength;
     95 
     96       // Clamp the value and invert it.
     97       if (sum > 16) sum = 16;
     98       sum = 16 - sum;
     99 
    100       sum *= filter_weight;
    101 
    102       count->TopLeftPixel()[height * count->stride() + width] += sum;
    103       accumulator->TopLeftPixel()[height * accumulator->stride() + width] +=
    104           sum * b.TopLeftPixel()[height * b.stride() + width];
    105     }
    106   }
    107 }
    108 
    109 class TemporalFilterTest : public ::testing::TestWithParam<TemporalFilterFunc> {
    110  public:
    111   virtual void SetUp() {
    112     filter_func_ = GetParam();
    113     rnd_.Reset(ACMRandom::DeterministicSeed());
    114   }
    115 
    116  protected:
    117   TemporalFilterFunc filter_func_;
    118   ACMRandom rnd_;
    119 };
    120 
    121 TEST_P(TemporalFilterTest, SizeCombinations) {
    122   // Depending on subsampling this function may be called with values of 8 or 16
    123   // for width and height, in any combination.
    124   Buffer<uint8_t> a = Buffer<uint8_t>(16, 16, 8);
    125   ASSERT_TRUE(a.Init());
    126 
    127   const int filter_weight = 2;
    128   const int filter_strength = 6;
    129 
    130   for (int width = 8; width <= 16; width += 8) {
    131     for (int height = 8; height <= 16; height += 8) {
    132       // The second buffer must not have any border.
    133       Buffer<uint8_t> b = Buffer<uint8_t>(width, height, 0);
    134       ASSERT_TRUE(b.Init());
    135       Buffer<unsigned int> accum_ref = Buffer<unsigned int>(width, height, 0);
    136       ASSERT_TRUE(accum_ref.Init());
    137       Buffer<unsigned int> accum_chk = Buffer<unsigned int>(width, height, 0);
    138       ASSERT_TRUE(accum_chk.Init());
    139       Buffer<uint16_t> count_ref = Buffer<uint16_t>(width, height, 0);
    140       ASSERT_TRUE(count_ref.Init());
    141       Buffer<uint16_t> count_chk = Buffer<uint16_t>(width, height, 0);
    142       ASSERT_TRUE(count_chk.Init());
    143 
    144       // The difference between the buffers must be small to pass the threshold
    145       // to apply the filter.
    146       a.Set(&rnd_, 0, 7);
    147       b.Set(&rnd_, 0, 7);
    148 
    149       accum_ref.Set(rnd_.Rand8());
    150       accum_chk.CopyFrom(accum_ref);
    151       count_ref.Set(rnd_.Rand8());
    152       count_chk.CopyFrom(count_ref);
    153       reference_filter(a, b, width, height, filter_strength, filter_weight,
    154                        &accum_ref, &count_ref);
    155       ASM_REGISTER_STATE_CHECK(
    156           filter_func_(a.TopLeftPixel(), a.stride(), b.TopLeftPixel(), width,
    157                        height, filter_strength, filter_weight,
    158                        accum_chk.TopLeftPixel(), count_chk.TopLeftPixel()));
    159       EXPECT_TRUE(accum_chk.CheckValues(accum_ref));
    160       EXPECT_TRUE(count_chk.CheckValues(count_ref));
    161       if (HasFailure()) {
    162         printf("Width: %d Height: %d\n", width, height);
    163         count_chk.PrintDifference(count_ref);
    164         accum_chk.PrintDifference(accum_ref);
    165         return;
    166       }
    167     }
    168   }
    169 }
    170 
    171 TEST_P(TemporalFilterTest, CompareReferenceRandom) {
    172   for (int width = 8; width <= 16; width += 8) {
    173     for (int height = 8; height <= 16; height += 8) {
    174       Buffer<uint8_t> a = Buffer<uint8_t>(width, height, 8);
    175       ASSERT_TRUE(a.Init());
    176       // The second buffer must not have any border.
    177       Buffer<uint8_t> b = Buffer<uint8_t>(width, height, 0);
    178       ASSERT_TRUE(b.Init());
    179       Buffer<unsigned int> accum_ref = Buffer<unsigned int>(width, height, 0);
    180       ASSERT_TRUE(accum_ref.Init());
    181       Buffer<unsigned int> accum_chk = Buffer<unsigned int>(width, height, 0);
    182       ASSERT_TRUE(accum_chk.Init());
    183       Buffer<uint16_t> count_ref = Buffer<uint16_t>(width, height, 0);
    184       ASSERT_TRUE(count_ref.Init());
    185       Buffer<uint16_t> count_chk = Buffer<uint16_t>(width, height, 0);
    186       ASSERT_TRUE(count_chk.Init());
    187 
    188       for (int filter_strength = 0; filter_strength <= 6; ++filter_strength) {
    189         for (int filter_weight = 0; filter_weight <= 2; ++filter_weight) {
    190           for (int repeat = 0; repeat < 100; ++repeat) {
    191             if (repeat < 50) {
    192               a.Set(&rnd_, 0, 7);
    193               b.Set(&rnd_, 0, 7);
    194             } else {
    195               // Check large (but close) values as well.
    196               a.Set(&rnd_, std::numeric_limits<uint8_t>::max() - 7,
    197                     std::numeric_limits<uint8_t>::max());
    198               b.Set(&rnd_, std::numeric_limits<uint8_t>::max() - 7,
    199                     std::numeric_limits<uint8_t>::max());
    200             }
    201 
    202             accum_ref.Set(rnd_.Rand8());
    203             accum_chk.CopyFrom(accum_ref);
    204             count_ref.Set(rnd_.Rand8());
    205             count_chk.CopyFrom(count_ref);
    206             reference_filter(a, b, width, height, filter_strength,
    207                              filter_weight, &accum_ref, &count_ref);
    208             ASM_REGISTER_STATE_CHECK(filter_func_(
    209                 a.TopLeftPixel(), a.stride(), b.TopLeftPixel(), width, height,
    210                 filter_strength, filter_weight, accum_chk.TopLeftPixel(),
    211                 count_chk.TopLeftPixel()));
    212             EXPECT_TRUE(accum_chk.CheckValues(accum_ref));
    213             EXPECT_TRUE(count_chk.CheckValues(count_ref));
    214             if (HasFailure()) {
    215               printf("Weight: %d Strength: %d\n", filter_weight,
    216                      filter_strength);
    217               count_chk.PrintDifference(count_ref);
    218               accum_chk.PrintDifference(accum_ref);
    219               return;
    220             }
    221           }
    222         }
    223       }
    224     }
    225   }
    226 }
    227 
    228 TEST_P(TemporalFilterTest, DISABLED_Speed) {
    229   Buffer<uint8_t> a = Buffer<uint8_t>(16, 16, 8);
    230   ASSERT_TRUE(a.Init());
    231 
    232   const int filter_weight = 2;
    233   const int filter_strength = 6;
    234 
    235   for (int width = 8; width <= 16; width += 8) {
    236     for (int height = 8; height <= 16; height += 8) {
    237       // The second buffer must not have any border.
    238       Buffer<uint8_t> b = Buffer<uint8_t>(width, height, 0);
    239       ASSERT_TRUE(b.Init());
    240       Buffer<unsigned int> accum_ref = Buffer<unsigned int>(width, height, 0);
    241       ASSERT_TRUE(accum_ref.Init());
    242       Buffer<unsigned int> accum_chk = Buffer<unsigned int>(width, height, 0);
    243       ASSERT_TRUE(accum_chk.Init());
    244       Buffer<uint16_t> count_ref = Buffer<uint16_t>(width, height, 0);
    245       ASSERT_TRUE(count_ref.Init());
    246       Buffer<uint16_t> count_chk = Buffer<uint16_t>(width, height, 0);
    247       ASSERT_TRUE(count_chk.Init());
    248 
    249       a.Set(&rnd_, 0, 7);
    250       b.Set(&rnd_, 0, 7);
    251 
    252       accum_chk.Set(0);
    253       count_chk.Set(0);
    254 
    255       vpx_usec_timer timer;
    256       vpx_usec_timer_start(&timer);
    257       for (int i = 0; i < 10000; ++i) {
    258         filter_func_(a.TopLeftPixel(), a.stride(), b.TopLeftPixel(), width,
    259                      height, filter_strength, filter_weight,
    260                      accum_chk.TopLeftPixel(), count_chk.TopLeftPixel());
    261       }
    262       vpx_usec_timer_mark(&timer);
    263       const int elapsed_time = static_cast<int>(vpx_usec_timer_elapsed(&timer));
    264       printf("Temporal filter %dx%d time: %5d us\n", width, height,
    265              elapsed_time);
    266     }
    267   }
    268 }
    269 
    270 INSTANTIATE_TEST_CASE_P(C, TemporalFilterTest,
    271                         ::testing::Values(&vp9_temporal_filter_apply_c));
    272 
    273 #if HAVE_SSE4_1
    274 INSTANTIATE_TEST_CASE_P(SSE4_1, TemporalFilterTest,
    275                         ::testing::Values(&vp9_temporal_filter_apply_sse4_1));
    276 #endif  // HAVE_SSE4_1
    277 }  // namespace
    278