1 #define GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS 2 3 #include "test.h" 4 5 #include "../internal/fixedpoint.h" 6 7 using namespace gemmlowp; 8 9 template <int tIntegerBits> 10 void test_convert(FixedPoint<int32_t, tIntegerBits> x) { 11 typedef FixedPoint<int32_t, tIntegerBits> F; 12 F y = ToFixedPoint<int32_t, tIntegerBits>(ToDouble(x)); 13 Check(y == x); 14 } 15 16 template <int tIntegerBits_a, int tIntegerBits_b> 17 void test_Rescale(FixedPoint<int32_t, tIntegerBits_a> a) { 18 FixedPoint<int32_t, tIntegerBits_b> actual = Rescale<tIntegerBits_b>(a); 19 FixedPoint<int32_t, tIntegerBits_b> expected = 20 ToFixedPoint<int32_t, tIntegerBits_b>(ToDouble(a)); 21 Check(actual == expected); 22 } 23 24 template <int tIntegerBits_a, int tIntegerBits_b> 25 void test_Rescale(const std::vector<int32_t>& testvals_int32) { 26 for (auto a : testvals_int32) { 27 FixedPoint<int32_t, tIntegerBits_a> aq; 28 aq.raw() = a; 29 test_Rescale<tIntegerBits_a, tIntegerBits_b>(aq); 30 } 31 } 32 33 template <int tIntegerBits_a, int tIntegerBits_b> 34 void test_mul(FixedPoint<int32_t, tIntegerBits_a> a, 35 FixedPoint<int32_t, tIntegerBits_b> b) { 36 static const int IntegerBits_ab = tIntegerBits_a + tIntegerBits_b; 37 FixedPoint<int32_t, IntegerBits_ab> ab; 38 ab = a * b; 39 double a_double = ToDouble(a); 40 double b_double = ToDouble(b); 41 double ab_double = a_double * b_double; 42 FixedPoint<int32_t, IntegerBits_ab> expected = 43 ToFixedPoint<int32_t, IntegerBits_ab>(ab_double); 44 int64_t diff = int64_t(ab.raw()) - int64_t(expected.raw()); 45 Check(std::abs(diff) <= 1); 46 } 47 48 template <int tIntegerBits_a, int tIntegerBits_b> 49 void test_mul(const std::vector<int32_t>& testvals_int32) { 50 for (auto a : testvals_int32) { 51 for (auto b : testvals_int32) { 52 FixedPoint<int32_t, tIntegerBits_a> aq; 53 FixedPoint<int32_t, tIntegerBits_b> bq; 54 aq.raw() = a; 55 bq.raw() = b; 56 test_mul(aq, bq); 57 } 58 } 59 } 60 61 template <int tExponent, int tIntegerBits_a> 62 void test_ExactMulByPot(FixedPoint<int32_t, tIntegerBits_a> a) { 63 double x = ToDouble(a) * std::pow(2.0, tExponent); 64 double y = ToDouble(ExactMulByPot<tExponent>(a)); 65 Check(x == y); 66 } 67 68 template <int tExponent, int tIntegerBits_a> 69 void test_ExactMulByPot(const std::vector<int32_t>& testvals_int32) { 70 for (auto a : testvals_int32) { 71 FixedPoint<int32_t, tIntegerBits_a> aq; 72 aq.raw() = a; 73 test_ExactMulByPot<tExponent, tIntegerBits_a>(aq); 74 } 75 } 76 77 void test_exp_on_interval_between_negative_one_quarter_and_0_excl( 78 FixedPoint<int32_t, 0> a) { 79 double a_double = ToDouble(a); 80 double expected = std::exp(a_double); 81 double actual = 82 ToDouble(exp_on_interval_between_negative_one_quarter_and_0_excl(a)); 83 double error = expected - actual; 84 Check(std::abs(error) < 3e-7); 85 } 86 87 void test_exp_on_interval_between_negative_one_quarter_and_0_excl( 88 const std::vector<int32_t>& testvals_int32) { 89 for (auto a : testvals_int32) { 90 typedef FixedPoint<int32_t, 0> F; 91 F aq = SaturatingRoundingMultiplyByPOT<-3>(F::FromRaw(a)) - 92 F::ConstantPOT<-3>(); 93 test_exp_on_interval_between_negative_one_quarter_and_0_excl(aq); 94 } 95 } 96 97 template <int tIntegerBits> 98 void test_exp_on_negative_values(FixedPoint<int32_t, tIntegerBits> a) { 99 double a_double = ToDouble(a); 100 double expected = std::exp(a_double); 101 double actual = ToDouble(exp_on_negative_values(a)); 102 double error = expected - actual; 103 Check(std::abs(error) < 3e-7); 104 } 105 106 template <int tIntegerBits> 107 void test_exp_on_negative_values(const std::vector<int32_t>& testvals_int32) { 108 for (auto a : testvals_int32) { 109 if (a < 0) { 110 FixedPoint<int32_t, tIntegerBits> aq; 111 aq.raw() = a; 112 test_exp_on_negative_values(aq); 113 } 114 } 115 } 116 117 void test_one_minus_x_over_one_plus_x_for_x_in_0_1(FixedPoint<int32_t, 0> a) { 118 double a_double = ToDouble(a); 119 double expected = (1 - a_double) / (1 + a_double); 120 FixedPoint<int32_t, 0> retval = one_minus_x_over_one_plus_x_for_x_in_0_1(a); 121 double actual = ToDouble(retval); 122 double error = expected - actual; 123 Check(std::abs(error) < 6e-9); 124 } 125 126 void test_one_minus_x_over_one_plus_x_for_x_in_0_1( 127 const std::vector<int32_t>& testvals_int32) { 128 for (auto a : testvals_int32) { 129 if (a > 0) { 130 FixedPoint<int32_t, 0> aq; 131 aq.raw() = a; 132 test_one_minus_x_over_one_plus_x_for_x_in_0_1(aq); 133 } 134 } 135 } 136 137 template <int tIntegerBits> 138 void test_tanh(FixedPoint<int32_t, tIntegerBits> a) { 139 double a_double = ToDouble(a); 140 double expected = std::tanh(a_double); 141 double actual = ToDouble(tanh(a)); 142 double error = expected - actual; 143 Check(std::abs(error) < 1.5e-7); 144 } 145 146 template <int tIntegerBits> 147 void test_tanh(const std::vector<int32_t>& testvals_int32) { 148 for (auto a : testvals_int32) { 149 FixedPoint<int32_t, tIntegerBits> aq; 150 aq.raw() = a; 151 test_tanh(aq); 152 } 153 } 154 155 #ifdef GEMMLOWP_NEON 156 void test_int32x4(const std::vector<int32_t>& testvals_int32) { 157 size_t n = testvals_int32.size(); 158 size_t n4 = n - (n % 4); 159 std::vector<int32_t> results_int32(n4); 160 std::vector<int32_t> results_int32x4(n4); 161 162 for (size_t i = 0; i < n4; i++) { 163 results_int32[i] = 164 tanh(FixedPoint<int32_t, 4>::FromRaw(testvals_int32[i])).raw(); 165 } 166 for (size_t i = 0; i < n4; i++) { 167 vst1q_s32( 168 &results_int32x4[i], 169 tanh(FixedPoint<int32x4_t, 4>::FromRaw(vld1q_s32(&testvals_int32[i]))) 170 .raw()); 171 } 172 173 for (size_t i = 0; i < n4; i++) { 174 Check(results_int32[i] == results_int32x4[i]); 175 } 176 } 177 #endif // GEMMLOWP_NEON 178 179 int main() { 180 std::vector<int32_t> testvals_int32; 181 182 for (int i = 0; i < 31; i++) { 183 testvals_int32.push_back((1 << i) - 2); 184 testvals_int32.push_back((1 << i) - 1); 185 testvals_int32.push_back((1 << i)); 186 testvals_int32.push_back((1 << i) + 1); 187 testvals_int32.push_back((1 << i) + 2); 188 testvals_int32.push_back(-(1 << i) - 2); 189 testvals_int32.push_back(-(1 << i) - 1); 190 testvals_int32.push_back(-(1 << i)); 191 testvals_int32.push_back(-(1 << i) + 1); 192 testvals_int32.push_back(-(1 << i) + 2); 193 } 194 testvals_int32.push_back(std::numeric_limits<int32_t>::min()); 195 testvals_int32.push_back(std::numeric_limits<int32_t>::min() + 1); 196 testvals_int32.push_back(std::numeric_limits<int32_t>::min() + 2); 197 testvals_int32.push_back(std::numeric_limits<int32_t>::max() - 2); 198 testvals_int32.push_back(std::numeric_limits<int32_t>::max() - 1); 199 testvals_int32.push_back(std::numeric_limits<int32_t>::max()); 200 201 uint32_t random = 1; 202 for (int i = 0; i < 1000; i++) { 203 random = random * 1664525 + 1013904223; 204 testvals_int32.push_back(static_cast<int32_t>(random)); 205 } 206 207 std::sort(testvals_int32.begin(), testvals_int32.end()); 208 209 for (auto a : testvals_int32) { 210 FixedPoint<int32_t, 4> x; 211 x.raw() = a; 212 test_convert(x); 213 } 214 215 test_mul<0, 0>(testvals_int32); 216 test_mul<0, 1>(testvals_int32); 217 test_mul<2, 0>(testvals_int32); 218 test_mul<1, 1>(testvals_int32); 219 test_mul<4, 4>(testvals_int32); 220 test_mul<3, 5>(testvals_int32); 221 test_mul<7, 2>(testvals_int32); 222 test_mul<14, 15>(testvals_int32); 223 224 test_Rescale<0, 0>(testvals_int32); 225 test_Rescale<0, 1>(testvals_int32); 226 test_Rescale<2, 0>(testvals_int32); 227 test_Rescale<4, 4>(testvals_int32); 228 test_Rescale<4, 5>(testvals_int32); 229 test_Rescale<6, 3>(testvals_int32); 230 test_Rescale<13, 9>(testvals_int32); 231 232 test_ExactMulByPot<0, 0>(testvals_int32); 233 test_ExactMulByPot<0, 4>(testvals_int32); 234 test_ExactMulByPot<1, 4>(testvals_int32); 235 test_ExactMulByPot<3, 2>(testvals_int32); 236 test_ExactMulByPot<-4, 5>(testvals_int32); 237 test_ExactMulByPot<-2, 6>(testvals_int32); 238 239 test_exp_on_interval_between_negative_one_quarter_and_0_excl(testvals_int32); 240 241 test_exp_on_negative_values<1>(testvals_int32); 242 test_exp_on_negative_values<2>(testvals_int32); 243 test_exp_on_negative_values<3>(testvals_int32); 244 test_exp_on_negative_values<4>(testvals_int32); 245 test_exp_on_negative_values<5>(testvals_int32); 246 test_exp_on_negative_values<6>(testvals_int32); 247 248 test_one_minus_x_over_one_plus_x_for_x_in_0_1(testvals_int32); 249 250 test_tanh<1>(testvals_int32); 251 test_tanh<2>(testvals_int32); 252 test_tanh<3>(testvals_int32); 253 test_tanh<4>(testvals_int32); 254 test_tanh<5>(testvals_int32); 255 test_tanh<6>(testvals_int32); 256 257 #ifdef GEMMLOWP_NEON 258 test_int32x4(testvals_int32); 259 #endif // GEMMLOWP_NEON 260 261 std::cerr << "All tests passed." << std::endl; 262 } 263