Home | History | Annotate | Download | only in test
      1 #define GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS
      2 
      3 #include "test.h"
      4 
      5 #include "../internal/fixedpoint.h"
      6 
      7 using namespace gemmlowp;
      8 
      9 template <int tIntegerBits>
     10 void test_convert(FixedPoint<int32_t, tIntegerBits> x) {
     11   typedef FixedPoint<int32_t, tIntegerBits> F;
     12   F y = ToFixedPoint<int32_t, tIntegerBits>(ToDouble(x));
     13   Check(y == x);
     14 }
     15 
     16 template <int tIntegerBits_a, int tIntegerBits_b>
     17 void test_Rescale(FixedPoint<int32_t, tIntegerBits_a> a) {
     18   FixedPoint<int32_t, tIntegerBits_b> actual = Rescale<tIntegerBits_b>(a);
     19   FixedPoint<int32_t, tIntegerBits_b> expected =
     20       ToFixedPoint<int32_t, tIntegerBits_b>(ToDouble(a));
     21   Check(actual == expected);
     22 }
     23 
     24 template <int tIntegerBits_a, int tIntegerBits_b>
     25 void test_Rescale(const std::vector<int32_t>& testvals_int32) {
     26   for (auto a : testvals_int32) {
     27     FixedPoint<int32_t, tIntegerBits_a> aq;
     28     aq.raw() = a;
     29     test_Rescale<tIntegerBits_a, tIntegerBits_b>(aq);
     30   }
     31 }
     32 
     33 template <int tIntegerBits_a, int tIntegerBits_b>
     34 void test_mul(FixedPoint<int32_t, tIntegerBits_a> a,
     35               FixedPoint<int32_t, tIntegerBits_b> b) {
     36   static const int IntegerBits_ab = tIntegerBits_a + tIntegerBits_b;
     37   FixedPoint<int32_t, IntegerBits_ab> ab;
     38   ab = a * b;
     39   double a_double = ToDouble(a);
     40   double b_double = ToDouble(b);
     41   double ab_double = a_double * b_double;
     42   FixedPoint<int32_t, IntegerBits_ab> expected =
     43       ToFixedPoint<int32_t, IntegerBits_ab>(ab_double);
     44   int64_t diff = int64_t(ab.raw()) - int64_t(expected.raw());
     45   Check(std::abs(diff) <= 1);
     46 }
     47 
     48 template <int tIntegerBits_a, int tIntegerBits_b>
     49 void test_mul(const std::vector<int32_t>& testvals_int32) {
     50   for (auto a : testvals_int32) {
     51     for (auto b : testvals_int32) {
     52       FixedPoint<int32_t, tIntegerBits_a> aq;
     53       FixedPoint<int32_t, tIntegerBits_b> bq;
     54       aq.raw() = a;
     55       bq.raw() = b;
     56       test_mul(aq, bq);
     57     }
     58   }
     59 }
     60 
     61 template <int tExponent, int tIntegerBits_a>
     62 void test_ExactMulByPot(FixedPoint<int32_t, tIntegerBits_a> a) {
     63   double x = ToDouble(a) * std::pow(2.0, tExponent);
     64   double y = ToDouble(ExactMulByPot<tExponent>(a));
     65   Check(x == y);
     66 }
     67 
     68 template <int tExponent, int tIntegerBits_a>
     69 void test_ExactMulByPot(const std::vector<int32_t>& testvals_int32) {
     70   for (auto a : testvals_int32) {
     71     FixedPoint<int32_t, tIntegerBits_a> aq;
     72     aq.raw() = a;
     73     test_ExactMulByPot<tExponent, tIntegerBits_a>(aq);
     74   }
     75 }
     76 
     77 void test_exp_on_interval_between_negative_one_quarter_and_0_excl(
     78     FixedPoint<int32_t, 0> a) {
     79   double a_double = ToDouble(a);
     80   double expected = std::exp(a_double);
     81   double actual =
     82       ToDouble(exp_on_interval_between_negative_one_quarter_and_0_excl(a));
     83   double error = expected - actual;
     84   Check(std::abs(error) < 3e-7);
     85 }
     86 
     87 void test_exp_on_interval_between_negative_one_quarter_and_0_excl(
     88     const std::vector<int32_t>& testvals_int32) {
     89   for (auto a : testvals_int32) {
     90     typedef FixedPoint<int32_t, 0> F;
     91     F aq = SaturatingRoundingMultiplyByPOT<-3>(F::FromRaw(a)) -
     92            F::ConstantPOT<-3>();
     93     test_exp_on_interval_between_negative_one_quarter_and_0_excl(aq);
     94   }
     95 }
     96 
     97 template <int tIntegerBits>
     98 void test_exp_on_negative_values(FixedPoint<int32_t, tIntegerBits> a) {
     99   double a_double = ToDouble(a);
    100   double expected = std::exp(a_double);
    101   double actual = ToDouble(exp_on_negative_values(a));
    102   double error = expected - actual;
    103   Check(std::abs(error) < 3e-7);
    104 }
    105 
    106 template <int tIntegerBits>
    107 void test_exp_on_negative_values(const std::vector<int32_t>& testvals_int32) {
    108   for (auto a : testvals_int32) {
    109     if (a < 0) {
    110       FixedPoint<int32_t, tIntegerBits> aq;
    111       aq.raw() = a;
    112       test_exp_on_negative_values(aq);
    113     }
    114   }
    115 }
    116 
    117 void test_one_minus_x_over_one_plus_x_for_x_in_0_1(FixedPoint<int32_t, 0> a) {
    118   double a_double = ToDouble(a);
    119   double expected = (1 - a_double) / (1 + a_double);
    120   FixedPoint<int32_t, 0> retval = one_minus_x_over_one_plus_x_for_x_in_0_1(a);
    121   double actual = ToDouble(retval);
    122   double error = expected - actual;
    123   Check(std::abs(error) < 6e-9);
    124 }
    125 
    126 void test_one_minus_x_over_one_plus_x_for_x_in_0_1(
    127     const std::vector<int32_t>& testvals_int32) {
    128   for (auto a : testvals_int32) {
    129     if (a > 0) {
    130       FixedPoint<int32_t, 0> aq;
    131       aq.raw() = a;
    132       test_one_minus_x_over_one_plus_x_for_x_in_0_1(aq);
    133     }
    134   }
    135 }
    136 
    137 template <int tIntegerBits>
    138 void test_tanh(FixedPoint<int32_t, tIntegerBits> a) {
    139   double a_double = ToDouble(a);
    140   double expected = std::tanh(a_double);
    141   double actual = ToDouble(tanh(a));
    142   double error = expected - actual;
    143   Check(std::abs(error) < 1.5e-7);
    144 }
    145 
    146 template <int tIntegerBits>
    147 void test_tanh(const std::vector<int32_t>& testvals_int32) {
    148   for (auto a : testvals_int32) {
    149     FixedPoint<int32_t, tIntegerBits> aq;
    150     aq.raw() = a;
    151     test_tanh(aq);
    152   }
    153 }
    154 
    155 #ifdef GEMMLOWP_NEON
    156 void test_int32x4(const std::vector<int32_t>& testvals_int32) {
    157   size_t n = testvals_int32.size();
    158   size_t n4 = n - (n % 4);
    159   std::vector<int32_t> results_int32(n4);
    160   std::vector<int32_t> results_int32x4(n4);
    161 
    162   for (size_t i = 0; i < n4; i++) {
    163     results_int32[i] =
    164         tanh(FixedPoint<int32_t, 4>::FromRaw(testvals_int32[i])).raw();
    165   }
    166   for (size_t i = 0; i < n4; i++) {
    167     vst1q_s32(
    168         &results_int32x4[i],
    169         tanh(FixedPoint<int32x4_t, 4>::FromRaw(vld1q_s32(&testvals_int32[i])))
    170             .raw());
    171   }
    172 
    173   for (size_t i = 0; i < n4; i++) {
    174     Check(results_int32[i] == results_int32x4[i]);
    175   }
    176 }
    177 #endif  // GEMMLOWP_NEON
    178 
    179 int main() {
    180   std::vector<int32_t> testvals_int32;
    181 
    182   for (int i = 0; i < 31; i++) {
    183     testvals_int32.push_back((1 << i) - 2);
    184     testvals_int32.push_back((1 << i) - 1);
    185     testvals_int32.push_back((1 << i));
    186     testvals_int32.push_back((1 << i) + 1);
    187     testvals_int32.push_back((1 << i) + 2);
    188     testvals_int32.push_back(-(1 << i) - 2);
    189     testvals_int32.push_back(-(1 << i) - 1);
    190     testvals_int32.push_back(-(1 << i));
    191     testvals_int32.push_back(-(1 << i) + 1);
    192     testvals_int32.push_back(-(1 << i) + 2);
    193   }
    194   testvals_int32.push_back(std::numeric_limits<int32_t>::min());
    195   testvals_int32.push_back(std::numeric_limits<int32_t>::min() + 1);
    196   testvals_int32.push_back(std::numeric_limits<int32_t>::min() + 2);
    197   testvals_int32.push_back(std::numeric_limits<int32_t>::max() - 2);
    198   testvals_int32.push_back(std::numeric_limits<int32_t>::max() - 1);
    199   testvals_int32.push_back(std::numeric_limits<int32_t>::max());
    200 
    201   uint32_t random = 1;
    202   for (int i = 0; i < 1000; i++) {
    203     random = random * 1664525 + 1013904223;
    204     testvals_int32.push_back(static_cast<int32_t>(random));
    205   }
    206 
    207   std::sort(testvals_int32.begin(), testvals_int32.end());
    208 
    209   for (auto a : testvals_int32) {
    210     FixedPoint<int32_t, 4> x;
    211     x.raw() = a;
    212     test_convert(x);
    213   }
    214 
    215   test_mul<0, 0>(testvals_int32);
    216   test_mul<0, 1>(testvals_int32);
    217   test_mul<2, 0>(testvals_int32);
    218   test_mul<1, 1>(testvals_int32);
    219   test_mul<4, 4>(testvals_int32);
    220   test_mul<3, 5>(testvals_int32);
    221   test_mul<7, 2>(testvals_int32);
    222   test_mul<14, 15>(testvals_int32);
    223 
    224   test_Rescale<0, 0>(testvals_int32);
    225   test_Rescale<0, 1>(testvals_int32);
    226   test_Rescale<2, 0>(testvals_int32);
    227   test_Rescale<4, 4>(testvals_int32);
    228   test_Rescale<4, 5>(testvals_int32);
    229   test_Rescale<6, 3>(testvals_int32);
    230   test_Rescale<13, 9>(testvals_int32);
    231 
    232   test_ExactMulByPot<0, 0>(testvals_int32);
    233   test_ExactMulByPot<0, 4>(testvals_int32);
    234   test_ExactMulByPot<1, 4>(testvals_int32);
    235   test_ExactMulByPot<3, 2>(testvals_int32);
    236   test_ExactMulByPot<-4, 5>(testvals_int32);
    237   test_ExactMulByPot<-2, 6>(testvals_int32);
    238 
    239   test_exp_on_interval_between_negative_one_quarter_and_0_excl(testvals_int32);
    240 
    241   test_exp_on_negative_values<1>(testvals_int32);
    242   test_exp_on_negative_values<2>(testvals_int32);
    243   test_exp_on_negative_values<3>(testvals_int32);
    244   test_exp_on_negative_values<4>(testvals_int32);
    245   test_exp_on_negative_values<5>(testvals_int32);
    246   test_exp_on_negative_values<6>(testvals_int32);
    247 
    248   test_one_minus_x_over_one_plus_x_for_x_in_0_1(testvals_int32);
    249 
    250   test_tanh<1>(testvals_int32);
    251   test_tanh<2>(testvals_int32);
    252   test_tanh<3>(testvals_int32);
    253   test_tanh<4>(testvals_int32);
    254   test_tanh<5>(testvals_int32);
    255   test_tanh<6>(testvals_int32);
    256 
    257 #ifdef GEMMLOWP_NEON
    258   test_int32x4(testvals_int32);
    259 #endif  // GEMMLOWP_NEON
    260 
    261   std::cerr << "All tests passed." << std::endl;
    262 }
    263