1 // Copyright 2015 Google Inc. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #include "test.h" 16 17 #include <unistd.h> 18 #include <cstdint> 19 #include <cstdlib> 20 #include <ctime> 21 #include <iostream> 22 #include <memory> 23 #include <string> 24 #include <vector> 25 #ifdef __APPLE__ 26 #include <TargetConditionals.h> 27 #endif 28 29 #include "../eight_bit_int_gemm/eight_bit_int_gemm.h" 30 #include "../internal/kernel_reference.h" 31 #include "test_data.h" 32 33 namespace gemmlowp { 34 35 void ReferenceEightBitIntGemm(bool transpose_a, bool transpose_b, 36 bool transpose_c, int m, int n, int k, 37 const uint8_t* a, int32_t a_offset, int lda, 38 const uint8_t* b, int32_t b_offset, int ldb, 39 uint8_t* c, int32_t c_offset, int32_t c_mult_int, 40 int32_t c_shift, int ldc) { 41 assert((c_shift >= 0) && (c_shift <= 32)); 42 43 assert(a != nullptr); 44 assert(b != nullptr); 45 assert(c != nullptr); 46 47 int a_i_stride; 48 int a_l_stride; 49 if (transpose_a) { 50 a_i_stride = lda; 51 a_l_stride = 1; 52 } else { 53 a_i_stride = 1; 54 a_l_stride = lda; 55 } 56 int b_j_stride; 57 int b_l_stride; 58 if (transpose_b) { 59 b_j_stride = 1; 60 b_l_stride = ldb; 61 } else { 62 b_j_stride = ldb; 63 b_l_stride = 1; 64 } 65 int c_i_stride; 66 int c_j_stride; 67 if (transpose_c) { 68 c_i_stride = ldc; 69 c_j_stride = 1; 70 } else { 71 c_i_stride = 1; 72 c_j_stride = ldc; 73 } 74 int i, j, l; 75 76 const std::int32_t kRoundingTerm = (c_shift < 1) ? 0 : (1 << (c_shift - 1)); 77 78 for (j = 0; j < n; j++) { 79 for (i = 0; i < m; i++) { 80 int32_t total = 0; 81 for (l = 0; l < k; l++) { 82 const int a_index = i * a_i_stride + l * a_l_stride; 83 const uint8_t a_as_byte = a[a_index]; 84 const int32_t a_as_int = static_cast<int32_t>(a_as_byte) + a_offset; 85 const int b_index = j * b_j_stride + l * b_l_stride; 86 const uint8_t b_as_byte = b[b_index]; 87 const int32_t b_as_int = static_cast<int32_t>(b_as_byte) + b_offset; 88 const int32_t mult_as_int = a_as_int * b_as_int; 89 total += mult_as_int; 90 } 91 int32_t output = 92 (((total + c_offset) * c_mult_int) + kRoundingTerm) >> c_shift; 93 if (output > 255) { 94 output = 255; 95 } 96 if (output < 0) { 97 output = 0; 98 } 99 const int c_index = i * c_i_stride + j * c_j_stride; 100 c[c_index] = static_cast<uint8_t>(output); 101 } 102 } 103 } 104 105 // *GemmWrapper's allow to wrap various Gemm functions in a uniform 106 // interface, so we can use the same testing code to test all of them 107 108 template <typename Kernel, typename Scalar, typename tBitDepthParams> 109 struct SingleThreadGemmWrapper { 110 typedef tBitDepthParams BitDepthParams; 111 112 static const char* Name() { 113 static char buf[256]; 114 snprintf(buf, sizeof(buf), "SingleThreadGemm, Kernel: %s", Kernel().Name()); 115 return buf; 116 } 117 118 typedef SingleThreadGemmContext Context; 119 120 template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder> 121 static void Gemm(Context* context, 122 const MatrixMap<const Scalar, LhsOrder>& lhs, 123 const MatrixMap<const Scalar, RhsOrder>& rhs, 124 MatrixMap<Scalar, ResultOrder>* result, int lhs_offset, 125 int rhs_offset, int result_offset, int result_mult_int, 126 int result_shift) { 127 const OffsetColDup lhs_offset_vector(lhs_offset, lhs.rows()); 128 const OffsetRowDup rhs_offset_vector(rhs_offset, rhs.cols()); 129 SingleThreadGemm<typename Kernel::Format, Scalar, Scalar, BitDepthParams, 130 LhsOrder, RhsOrder, ResultOrder, 131 OffsetColDup, OffsetRowDup>( 132 context, Kernel(), lhs, rhs, result, lhs_offset_vector, 133 rhs_offset_vector, 134 MakeStandardOutputPipeline(result_offset, result_mult_int, 135 result_shift)); 136 } 137 }; 138 139 template <typename Kernel, typename Scalar, typename tBitDepthParams> 140 struct MultiThreadGemmWrapper { 141 typedef tBitDepthParams BitDepthParams; 142 143 static const char* Name() { 144 static char buf[256]; 145 snprintf(buf, sizeof(buf), "MultiThreadGemm, Kernel: %s", Kernel().Name()); 146 return buf; 147 } 148 149 typedef MultiThreadGemmContext Context; 150 151 template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder> 152 static void Gemm(Context* context, 153 const MatrixMap<const Scalar, LhsOrder>& lhs, 154 const MatrixMap<const Scalar, RhsOrder>& rhs, 155 MatrixMap<Scalar, ResultOrder>* result, int lhs_offset, 156 int rhs_offset, int result_offset, int result_mult_int, 157 int result_shift) { 158 const OffsetColDup lhs_offset_vector(lhs_offset, lhs.rows()); 159 const OffsetRowDup rhs_offset_vector(rhs_offset, rhs.cols()); 160 MultiThreadGemm<typename Kernel::Format, Scalar, Scalar, BitDepthParams, 161 LhsOrder, RhsOrder, ResultOrder, 162 OffsetColDup, OffsetRowDup>( 163 context, Kernel(), lhs, rhs, result, lhs_offset_vector, 164 rhs_offset_vector, 165 MakeStandardOutputPipeline(result_offset, result_mult_int, 166 result_shift)); 167 } 168 }; 169 170 template <typename Scalar, typename tBitDepthParams> 171 struct PublicGemmWrapper { 172 typedef tBitDepthParams BitDepthParams; 173 174 static const char* Name() { return "public Gemm"; } 175 176 typedef GemmContext Context; 177 178 template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder> 179 static void Gemm(Context* context, 180 const MatrixMap<const Scalar, LhsOrder>& lhs, 181 const MatrixMap<const Scalar, RhsOrder>& rhs, 182 MatrixMap<Scalar, ResultOrder>* result, int lhs_offset, 183 int rhs_offset, int result_offset, int result_mult_int, 184 int result_shift) { 185 gemmlowp::Gemm<uint8_t, BitDepthParams, LhsOrder, RhsOrder, ResultOrder>( 186 context, lhs, rhs, result, lhs_offset, rhs_offset, result_offset, 187 result_mult_int, result_shift); 188 } 189 }; 190 191 template <eight_bit_int_gemm::BitDepthSetting BitDepth> 192 struct BitDepthParamsForSettings {}; 193 194 template <> 195 struct BitDepthParamsForSettings<eight_bit_int_gemm::BitDepthSetting::A8B8> 196 : DefaultL8R8BitDepthParams {}; 197 198 template <> 199 struct BitDepthParamsForSettings<eight_bit_int_gemm::BitDepthSetting::A5B7> 200 : DefaultL7R5BitDepthParams {}; 201 202 template <typename Scalar, eight_bit_int_gemm::BitDepthSetting BitDepth> 203 struct EightBitIntGemmWrapper { 204 typedef BitDepthParamsForSettings<BitDepth> BitDepthParams; 205 206 static const char* Name() { return "EightBitIntGemm"; } 207 208 typedef void Context; 209 210 template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder> 211 static void Gemm(Context*, const MatrixMap<const Scalar, LhsOrder>& lhs, 212 const MatrixMap<const Scalar, RhsOrder>& rhs, 213 MatrixMap<Scalar, ResultOrder>* result, int lhs_offset, 214 int rhs_offset, int result_offset, int result_mult_int, 215 int result_shift) { 216 const bool transpose_c = ResultOrder == MapOrder::RowMajor; 217 const bool transpose_a = LhsOrder == MapOrder::RowMajor; 218 const bool transpose_b = RhsOrder == MapOrder::RowMajor; 219 eight_bit_int_gemm::EightBitIntGemm( 220 transpose_a, transpose_b, transpose_c, lhs.rows(), rhs.cols(), 221 lhs.cols(), lhs.data(), lhs_offset, lhs.stride(), rhs.data(), 222 rhs_offset, rhs.stride(), result->data(), result_offset, 223 result_mult_int, result_shift, result->stride(), BitDepth); 224 } 225 }; 226 227 template <typename Scalar> 228 struct ReferenceEightBitIntGemmWrapper { 229 typedef DefaultL8R8BitDepthParams BitDepthParams; 230 231 static const char* Name() { return "ReferenceEightBitIntGemm"; } 232 233 template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder> 234 static void Gemm(bool transpose_a, bool transpose_b, bool transpose_c, 235 const MatrixMap<const Scalar, LhsOrder>& lhs, 236 const MatrixMap<const Scalar, RhsOrder>& rhs, 237 MatrixMap<Scalar, ResultOrder>* result, int lhs_offset, 238 int rhs_offset, int result_offset, int result_mult_int, 239 int result_shift) { 240 ReferenceEightBitIntGemm(transpose_a, transpose_b, transpose_c, lhs.rows(), 241 rhs.cols(), lhs.cols(), lhs.data(), lhs_offset, 242 lhs.stride(), rhs.data(), rhs_offset, rhs.stride(), 243 result->data(), result_offset, result_mult_int, 244 result_shift, result->stride()); 245 } 246 }; 247 248 const char* OrderName(MapOrder order) { 249 return order == MapOrder::ColMajor ? "ColMajor" : "RowMajor"; 250 } 251 252 struct ResultStats { 253 ResultStats() 254 : count(0), 255 med_val(0), 256 mean_signed_diff(0), 257 med_signed_diff(0), 258 med_unsigned_diff(0), 259 max_unsigned_diff(0) {} 260 261 int count; 262 int med_val; 263 float mean_signed_diff; 264 int med_signed_diff; 265 int med_unsigned_diff; 266 int max_unsigned_diff; 267 268 std::vector<int> count_diff_by_pot_slice; 269 }; 270 271 void GetResultStats(const uint8_t* actual, const uint8_t* expected, 272 size_t count, ResultStats* stats) { 273 std::vector<uint8_t> results; 274 std::vector<int16_t> signed_diffs; 275 std::vector<uint8_t> unsigned_diffs; 276 int64_t signed_diffs_sum = 0; 277 for (size_t i = 0; i < count; i++) { 278 results.push_back(actual[i]); 279 int16_t signed_diff = actual[i] - expected[i]; 280 signed_diffs.push_back(signed_diff); 281 unsigned_diffs.push_back(std::abs(signed_diff)); 282 signed_diffs_sum += signed_diff; 283 } 284 285 std::sort(results.begin(), results.end()); 286 std::sort(signed_diffs.begin(), signed_diffs.end()); 287 std::sort(unsigned_diffs.begin(), unsigned_diffs.end()); 288 289 const size_t middle = count / 2; 290 291 stats->count = count; 292 stats->med_val = results[middle]; 293 stats->mean_signed_diff = float(signed_diffs_sum) / count; 294 stats->med_signed_diff = signed_diffs[middle]; 295 stats->med_unsigned_diff = unsigned_diffs[middle]; 296 stats->max_unsigned_diff = unsigned_diffs.back(); 297 298 // Size 9 for 9 different POT values: 2^0, ..., 2^8 299 stats->count_diff_by_pot_slice.resize(9); 300 auto cur = unsigned_diffs.begin(); 301 size_t checksum = 0; 302 for (int exponent = 0; exponent < 9; exponent++) { 303 int pot = 1 << exponent; 304 auto next = std::lower_bound(cur, unsigned_diffs.end(), pot); 305 checksum += stats->count_diff_by_pot_slice[exponent] = next - cur; 306 cur = next; 307 } 308 assert(checksum == count); 309 } 310 311 struct ResultStatsBounds { 312 ResultStatsBounds() 313 : mean_signed_diff(0), 314 med_signed_diff(0), 315 med_unsigned_diff(0), 316 max_unsigned_diff(0) {} 317 318 float mean_signed_diff; 319 int med_signed_diff; 320 int med_unsigned_diff; 321 int max_unsigned_diff; 322 }; 323 324 bool CheckResultStatsBounds(const ResultStats& stats, 325 const ResultStatsBounds& bounds) { 326 return stats.max_unsigned_diff <= bounds.max_unsigned_diff && 327 stats.med_unsigned_diff <= bounds.med_unsigned_diff && 328 std::abs(stats.med_signed_diff) <= bounds.med_signed_diff && 329 std::abs(stats.mean_signed_diff) <= bounds.mean_signed_diff; 330 } 331 332 void ReportResultStats(const ResultStats& stats, 333 const ResultStatsBounds& bounds) { 334 printf(" number of matrix entries: %d\n", stats.count); 335 printf(" median value: %d\n", stats.med_val); 336 printf(" median unsigned diff: %d (tolerating %d)\n", 337 stats.med_unsigned_diff, bounds.med_unsigned_diff); 338 printf(" max unsigned diff: %d (tolerating %d)\n", stats.max_unsigned_diff, 339 bounds.max_unsigned_diff); 340 printf(" median signed diff: %d (tolerating %d)\n", stats.med_signed_diff, 341 bounds.med_signed_diff); 342 printf(" mean signed diff: %.3g (tolerating %.3g)\n", 343 stats.mean_signed_diff, bounds.mean_signed_diff); 344 345 printf("No error: %.2f %% of entries\n", 346 100.f * stats.count_diff_by_pot_slice[0] / stats.count); 347 for (int exponent = 1; exponent < 9; exponent++) { 348 printf("Error in %d..%d range: %.2f %% of entries\n", 1 << (exponent - 1), 349 (1 << exponent) - 1, 350 100.f * stats.count_diff_by_pot_slice[exponent] / stats.count); 351 } 352 } 353 354 // Our approach to choosing result_shift values for testing, is bisection. 355 // This function takes an interval, [result_shift_min .. result_shift_max]. 356 // If too much saturation occurred in either direction, it bisects accordingly, 357 // recursing until the interval contains only one value. 358 // The primary reason why we prefer this over computing optimal shift values, 359 // is that we actually want to exercise some saturation, as there is nontrivial 360 // code handling that in gemmlowp. 361 // Secondarily, this is faster than computing optimal shifts, since in 90% of 362 // cases the first-tried shift value 16 turns out to be good enough. 363 template <typename GemmWrapper, typename LhsType, typename RhsType, 364 typename ResultType> 365 void test_gemm_impl(typename GemmWrapper::Context* context, const LhsType& lhs, 366 const RhsType& rhs, ResultType* result, int lhs_offset, 367 int rhs_offset, int result_offset, int result_mult_int, 368 int result_shift_min, int result_shift_max) { 369 const int rows = lhs.rows(); 370 const int cols = rhs.cols(); 371 Check(lhs.cols() == rhs.rows()); 372 const int depth = lhs.cols(); 373 374 const int result_shift = (result_shift_min + result_shift_max) / 2; 375 376 GemmWrapper::Gemm(context, lhs.const_map(), rhs.const_map(), &result->map(), 377 lhs_offset, rhs_offset, result_offset, result_mult_int, 378 result_shift); 379 380 typedef typename ResultType::Scalar Scalar; 381 static const MapOrder kLhsOrder = LhsType::kOrder; 382 static const MapOrder kRhsOrder = RhsType::kOrder; 383 static const MapOrder kResultOrder = ResultType::kOrder; 384 ResultType ref_result(rows, cols); 385 const bool transpose_c = kResultOrder == MapOrder::RowMajor; 386 const bool transpose_a = kLhsOrder == MapOrder::RowMajor; 387 const bool transpose_b = kRhsOrder == MapOrder::RowMajor; 388 ReferenceEightBitIntGemmWrapper<Scalar>::Gemm( 389 transpose_a, transpose_b, transpose_c, lhs.const_map(), rhs.const_map(), 390 &ref_result.map(), lhs_offset, rhs_offset, result_offset, result_mult_int, 391 result_shift); 392 393 typedef typename GemmWrapper::BitDepthParams BitDepthParams; 394 395 ResultStats stats; 396 GetResultStats(result->data(), ref_result.data(), rows * cols, &stats); 397 398 // Adjust shifts until we get meaningful results 399 int new_result_shift_min = result_shift_min; 400 int new_result_shift_max = result_shift_max; 401 bool retry = false; 402 403 if (stats.med_val < 32) { 404 new_result_shift_max = (result_shift_min + result_shift_max) / 2; 405 retry = true; 406 } 407 408 if (stats.med_val > 224) { 409 new_result_shift_min = (result_shift_min + result_shift_max) / 2; 410 retry = true; 411 } 412 413 if (retry) { 414 if (result_shift_min != result_shift_max) { 415 test_gemm_impl<GemmWrapper>(context, lhs, rhs, result, lhs_offset, 416 rhs_offset, result_offset, result_mult_int, 417 new_result_shift_min, new_result_shift_max); 418 } 419 return; 420 } 421 422 ResultStatsBounds bounds; 423 424 if (BitDepthParams::LhsBitDepth::kBits < 8 || 425 BitDepthParams::RhsBitDepth::kBits < 8) { 426 // We have very lax requirements on unsigned diff. 427 // We have tighter requirements on signed diff (bias), but only 428 // if the matrix is large enough for things to average out. 429 // For very small sizes, we... basically don't test anything. 430 // The problem is that this test uses unrealistic combinations of 431 // result_mult_int 432 // and result_shift, resulting in potentially wild requantization artifacts 433 // on small GEMMs. 434 int adjust_for_small_sizes = 1000 / (rows * cols); 435 bounds.max_unsigned_diff = 436 std::max(stats.med_val / 2, adjust_for_small_sizes); 437 bounds.med_unsigned_diff = 438 std::max(stats.med_val / 8, adjust_for_small_sizes); 439 bounds.med_signed_diff = std::max(2, adjust_for_small_sizes); 440 bounds.mean_signed_diff = std::max(2, adjust_for_small_sizes); 441 } 442 443 // Check results 444 const bool good = CheckResultStatsBounds(stats, bounds); 445 446 printf( 447 "%s: %dx%dx%d %s x %s -> %s, %s, offsets %d/%d/%d, mult %d, shift %d\n", 448 good ? "PASS" : "FAIL", rows, depth, cols, OrderName(kLhsOrder), 449 OrderName(kRhsOrder), OrderName(kResultOrder), GemmWrapper::Name(), 450 lhs_offset, rhs_offset, result_offset, result_mult_int, result_shift); 451 452 if (!good) { 453 ReportResultStats(stats, bounds); 454 455 int bad_coeffs_printed = 0; 456 for (int c = 0; c < result->cols() && bad_coeffs_printed < 20; c++) { 457 for (int r = 0; r < result->rows() && bad_coeffs_printed < 20; r++) { 458 if (ref_result(r, c) != (*result)(r, c)) { 459 printf("bad coeff: at (%d, %d), expected %d, got %d\n", r, c, 460 ref_result(r, c), (*result)(r, c)); 461 bad_coeffs_printed++; 462 } 463 } 464 } 465 } 466 467 Check(good); 468 } 469 470 template <typename GemmWrapper, typename LhsType, typename RhsType, 471 typename ResultType> 472 void test_gemm(typename GemmWrapper::Context* context, const LhsType& lhs, 473 const RhsType& rhs, ResultType* result, int lhs_offset, 474 int rhs_offset, int result_offset, int result_mult_int) { 475 test_gemm_impl<GemmWrapper>(context, lhs, rhs, result, lhs_offset, rhs_offset, 476 result_offset, result_mult_int, 0, 32); 477 } 478 479 enum class WhatParamsToTest { 480 All, 481 OnlyGenericCase, 482 }; 483 484 template <typename GemmWrapper, MapOrder LhsOrder, MapOrder RhsOrder, 485 MapOrder ResultOrder> 486 void test_gemm(typename GemmWrapper::Context* context, int rows, int depth, 487 int cols, WhatParamsToTest params_to_test) { 488 typedef std::uint8_t Scalar; 489 typedef Matrix<Scalar, LhsOrder> LhsType; 490 LhsType lhs(rows, depth); 491 MakeRandom(&lhs, 8); 492 typedef Matrix<Scalar, RhsOrder> RhsType; 493 RhsType rhs(depth, cols); 494 MakeRandom(&rhs, 8); 495 typedef Matrix<Scalar, ResultOrder> ResultType; 496 ResultType result(rows, cols); 497 MakeZero(&result); 498 499 if (params_to_test == WhatParamsToTest::All) { 500 test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 0, 0, 1); 501 test_gemm<GemmWrapper>(context, lhs, rhs, &result, 10, 0, 0, 1); 502 test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 10, 0, 1); 503 test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 0, 10, 1); 504 test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 0, 0, 10); 505 test_gemm<GemmWrapper>(context, lhs, rhs, &result, 10, 10, 10, 10); 506 test_gemm<GemmWrapper>(context, lhs, rhs, &result, 256, 1, 17, 4); 507 } 508 test_gemm<GemmWrapper>(context, lhs, rhs, &result, -75, -91, 74980, 123); 509 } 510 511 enum class WhatOrdersToTest { All, OnlyRCC }; 512 513 template <typename GemmWrapper> 514 void test_gemm(typename GemmWrapper::Context* context, int rows, int depth, 515 int cols, WhatParamsToTest params_to_test, 516 WhatOrdersToTest orders_to_test) { 517 #define GEMMLOWP_ONE_TEST(LhsOrder, RhsOrder, ResultOrder) \ 518 do { \ 519 test_gemm<GemmWrapper, MapOrder::LhsOrder, MapOrder::RhsOrder, \ 520 MapOrder::ResultOrder>(context, rows, depth, cols, \ 521 params_to_test); \ 522 } while (false) 523 524 if (orders_to_test == WhatOrdersToTest::All) { 525 GEMMLOWP_ONE_TEST(ColMajor, ColMajor, ColMajor); 526 GEMMLOWP_ONE_TEST(RowMajor, ColMajor, ColMajor); 527 GEMMLOWP_ONE_TEST(ColMajor, RowMajor, ColMajor); 528 GEMMLOWP_ONE_TEST(RowMajor, RowMajor, ColMajor); 529 530 GEMMLOWP_ONE_TEST(ColMajor, ColMajor, RowMajor); 531 GEMMLOWP_ONE_TEST(RowMajor, ColMajor, RowMajor); 532 GEMMLOWP_ONE_TEST(ColMajor, RowMajor, RowMajor); 533 GEMMLOWP_ONE_TEST(RowMajor, RowMajor, RowMajor); 534 } else { 535 GEMMLOWP_ONE_TEST(RowMajor, ColMajor, ColMajor); 536 } 537 538 #undef GEMMLOWP_ONE_TEST 539 } 540 541 template <typename Kernel> 542 void test_gemm_kernel(MultiThreadGemmContext* context) { 543 typedef MultiThreadGemmWrapper<Kernel, std::uint8_t, 544 DefaultL8R8BitDepthParams> 545 GemmWrapper; 546 test_gemm<GemmWrapper>(context, 1, 1, 1, WhatParamsToTest::OnlyGenericCase, 547 WhatOrdersToTest::OnlyRCC); 548 test_gemm<GemmWrapper>(context, 2, 2, 2, WhatParamsToTest::OnlyGenericCase, 549 WhatOrdersToTest::OnlyRCC); 550 test_gemm<GemmWrapper>(context, 3, 3, 3, WhatParamsToTest::OnlyGenericCase, 551 WhatOrdersToTest::OnlyRCC); 552 test_gemm<GemmWrapper>(context, 4, 4, 4, WhatParamsToTest::OnlyGenericCase, 553 WhatOrdersToTest::OnlyRCC); 554 test_gemm<GemmWrapper>(context, 5, 5, 5, WhatParamsToTest::OnlyGenericCase, 555 WhatOrdersToTest::OnlyRCC); 556 test_gemm<GemmWrapper>(context, 9, 11, 13, WhatParamsToTest::OnlyGenericCase, 557 WhatOrdersToTest::OnlyRCC); 558 test_gemm<GemmWrapper>(context, 50, 50, 50, WhatParamsToTest::All, 559 WhatOrdersToTest::OnlyRCC); 560 test_gemm<GemmWrapper>(context, 200, 200, 200, 561 WhatParamsToTest::OnlyGenericCase, 562 WhatOrdersToTest::All); 563 test_gemm<GemmWrapper>(context, 50, 5000, 50, 564 WhatParamsToTest::OnlyGenericCase, 565 WhatOrdersToTest::OnlyRCC); 566 } 567 568 template <typename GemmWrapper> 569 void test_gemm(typename GemmWrapper::Context* context) { 570 test_gemm<GemmWrapper>(context, 1, 1, 1, WhatParamsToTest::All, 571 WhatOrdersToTest::OnlyRCC); 572 test_gemm<GemmWrapper>(context, 2, 1, 1, WhatParamsToTest::All, 573 WhatOrdersToTest::OnlyRCC); 574 test_gemm<GemmWrapper>(context, 1, 2, 1, WhatParamsToTest::All, 575 WhatOrdersToTest::OnlyRCC); 576 test_gemm<GemmWrapper>(context, 1, 1, 2, WhatParamsToTest::All, 577 WhatOrdersToTest::OnlyRCC); 578 test_gemm<GemmWrapper>(context, 2, 2, 2, WhatParamsToTest::All, 579 WhatOrdersToTest::OnlyRCC); 580 test_gemm<GemmWrapper>(context, 3, 3, 3, WhatParamsToTest::All, 581 WhatOrdersToTest::OnlyRCC); 582 test_gemm<GemmWrapper>(context, 4, 4, 4, WhatParamsToTest::All, 583 WhatOrdersToTest::OnlyRCC); 584 test_gemm<GemmWrapper>(context, 5, 5, 5, WhatParamsToTest::All, 585 WhatOrdersToTest::OnlyRCC); 586 test_gemm<GemmWrapper>(context, 6, 6, 6, WhatParamsToTest::All, 587 WhatOrdersToTest::OnlyRCC); 588 test_gemm<GemmWrapper>(context, 3, 5, 7, WhatParamsToTest::All, 589 WhatOrdersToTest::OnlyRCC); 590 test_gemm<GemmWrapper>(context, 7, 3, 5, WhatParamsToTest::All, 591 WhatOrdersToTest::OnlyRCC); 592 test_gemm<GemmWrapper>(context, 5, 7, 3, WhatParamsToTest::All, 593 WhatOrdersToTest::OnlyRCC); 594 test_gemm<GemmWrapper>(context, 8, 8, 8, WhatParamsToTest::All, 595 WhatOrdersToTest::OnlyRCC); 596 test_gemm<GemmWrapper>(context, 16, 16, 16, WhatParamsToTest::All, 597 WhatOrdersToTest::OnlyRCC); 598 test_gemm<GemmWrapper>(context, 32, 32, 32, WhatParamsToTest::All, 599 WhatOrdersToTest::OnlyRCC); 600 test_gemm<GemmWrapper>(context, 64, 64, 64, WhatParamsToTest::All, 601 WhatOrdersToTest::OnlyRCC); 602 test_gemm<GemmWrapper>(context, 128, 128, 128, WhatParamsToTest::All, 603 WhatOrdersToTest::OnlyRCC); 604 605 test_gemm<GemmWrapper>(context, 16, 17, 16, WhatParamsToTest::All, 606 WhatOrdersToTest::OnlyRCC); 607 test_gemm<GemmWrapper>(context, 37, 55, 73, WhatParamsToTest::All, 608 WhatOrdersToTest::OnlyRCC); 609 test_gemm<GemmWrapper>(context, 57, 87, 117, WhatParamsToTest::All, 610 WhatOrdersToTest::OnlyRCC); 611 test_gemm<GemmWrapper>(context, 93, 83, 73, WhatParamsToTest::All, 612 WhatOrdersToTest::OnlyRCC); 613 test_gemm<GemmWrapper>(context, 109, 89, 99, WhatParamsToTest::All, 614 WhatOrdersToTest::OnlyRCC); 615 test_gemm<GemmWrapper>(context, 78, 101, 82, WhatParamsToTest::All, 616 WhatOrdersToTest::OnlyRCC); 617 618 test_gemm<GemmWrapper>(context, 512, 512, 512, 619 WhatParamsToTest::OnlyGenericCase, 620 WhatOrdersToTest::OnlyRCC); 621 test_gemm<GemmWrapper>(context, 1024, 1024, 1024, 622 WhatParamsToTest::OnlyGenericCase, 623 WhatOrdersToTest::OnlyRCC); 624 test_gemm<GemmWrapper>(context, 567, 2345, 123, 625 WhatParamsToTest::OnlyGenericCase, 626 WhatOrdersToTest::OnlyRCC); 627 test_gemm<GemmWrapper>(context, 100, 5000, 100, 628 WhatParamsToTest::OnlyGenericCase, 629 WhatOrdersToTest::OnlyRCC); 630 test_gemm<GemmWrapper>(context, 1, 1, 1000, WhatParamsToTest::OnlyGenericCase, 631 WhatOrdersToTest::OnlyRCC); 632 test_gemm<GemmWrapper>(context, 1000, 1, 1, WhatParamsToTest::OnlyGenericCase, 633 WhatOrdersToTest::OnlyRCC); 634 test_gemm<GemmWrapper>(context, 1, 1000, 1, WhatParamsToTest::OnlyGenericCase, 635 WhatOrdersToTest::OnlyRCC); 636 test_gemm<GemmWrapper>(context, 1, 1000, 1000, 637 WhatParamsToTest::OnlyGenericCase, 638 WhatOrdersToTest::OnlyRCC); 639 test_gemm<GemmWrapper>(context, 1000, 1, 1000, 640 WhatParamsToTest::OnlyGenericCase, 641 WhatOrdersToTest::OnlyRCC); 642 test_gemm<GemmWrapper>(context, 1000, 1000, 1, 643 WhatParamsToTest::OnlyGenericCase, 644 WhatOrdersToTest::OnlyRCC); 645 test_gemm<GemmWrapper>(context, 777, 3456, 1, 646 WhatParamsToTest::OnlyGenericCase, 647 WhatOrdersToTest::OnlyRCC); 648 test_gemm<GemmWrapper>(context, 4567, 555, 1, 649 WhatParamsToTest::OnlyGenericCase, 650 WhatOrdersToTest::OnlyRCC); 651 652 // Test all storage orders 653 test_gemm<GemmWrapper>(context, 70, 90, 110, WhatParamsToTest::All, 654 WhatOrdersToTest::All); 655 test_gemm<GemmWrapper>(context, 300, 400, 500, 656 WhatParamsToTest::OnlyGenericCase, 657 WhatOrdersToTest::All); 658 } 659 660 template <typename GemmWrapper> 661 void test_gemv(typename GemmWrapper::Context* context) { 662 test_gemm<GemmWrapper>(context, 2, 2, 1, WhatParamsToTest::All, 663 WhatOrdersToTest::OnlyRCC); 664 test_gemm<GemmWrapper>(context, 3, 3, 1, WhatParamsToTest::All, 665 WhatOrdersToTest::OnlyRCC); 666 test_gemm<GemmWrapper>(context, 4, 4, 1, WhatParamsToTest::All, 667 WhatOrdersToTest::OnlyRCC); 668 test_gemm<GemmWrapper>(context, 5, 5, 1, WhatParamsToTest::All, 669 WhatOrdersToTest::OnlyRCC); 670 test_gemm<GemmWrapper>(context, 6, 6, 1, WhatParamsToTest::All, 671 WhatOrdersToTest::OnlyRCC); 672 test_gemm<GemmWrapper>(context, 3, 5, 1, WhatParamsToTest::All, 673 WhatOrdersToTest::OnlyRCC); 674 test_gemm<GemmWrapper>(context, 7, 3, 1, WhatParamsToTest::All, 675 WhatOrdersToTest::OnlyRCC); 676 test_gemm<GemmWrapper>(context, 5, 7, 1, WhatParamsToTest::All, 677 WhatOrdersToTest::OnlyRCC); 678 test_gemm<GemmWrapper>(context, 8, 8, 1, WhatParamsToTest::All, 679 WhatOrdersToTest::OnlyRCC); 680 test_gemm<GemmWrapper>(context, 32, 32, 1, WhatParamsToTest::All, 681 WhatOrdersToTest::OnlyRCC); 682 test_gemm<GemmWrapper>(context, 128, 128, 1, WhatParamsToTest::All, 683 WhatOrdersToTest::OnlyRCC); 684 test_gemm<GemmWrapper>(context, 321, 123, 1, WhatParamsToTest::All, 685 WhatOrdersToTest::OnlyRCC); 686 687 // Test all storage orders 688 test_gemm<GemmWrapper>(context, 70, 90, 1, WhatParamsToTest::All, 689 WhatOrdersToTest::All); 690 test_gemm<GemmWrapper>(context, 300, 400, 1, 691 WhatParamsToTest::OnlyGenericCase, 692 WhatOrdersToTest::All); 693 } 694 695 const char* GetBitDepthName(eight_bit_int_gemm::BitDepthSetting b) { 696 switch (b) { 697 case eight_bit_int_gemm::BitDepthSetting::A8B8: 698 return "Lhs: 8 bit, Rhs: 8 bit"; 699 case eight_bit_int_gemm::BitDepthSetting::A5B7: 700 return "Lhs: 7 bit, Rhs: 5 bit"; 701 default: 702 abort(); 703 return nullptr; 704 } 705 } 706 707 // Runs a small set of hand-picked data for per-channel quantized data. 708 // This test case comes from a set of 2 2x2 convolution filters run over a 3x3 709 // image. 710 void TestWithSmallDataPerChannelQuantization() { 711 const int m = 2; 712 const int n = 9; 713 const int k = 12; 714 715 // 12 x 2, columnwise. 716 const uint8_t a_data[] = { 717 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 718 64, 64, 64, 64, 64, 64, 0, 0, 0, 255, 255, 255 719 }; 720 const int lda = k; 721 int a_offset[] = {0, -64}; 722 MatrixMap<const std::uint8_t, MapOrder::RowMajor> lhs(a_data, m, k, lda); 723 const OffsetColMap lhs_offset(a_offset, m); 724 725 // 12 x 9, columnwise. 726 const uint8_t b_data[] = { 727 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 728 0, 0, 0, 0, 0, 0, 255, 255, 255, 0, 0, 0, 729 0, 0, 0, 127, 127, 127, 0, 0, 0, 127, 127, 127, 730 0, 0, 0, 255, 255, 255, 0, 0, 0, 0, 0, 0, 731 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 732 0, 0, 0, 127, 127, 127, 0, 0, 0, 127, 127, 127, 733 0, 0, 0, 0, 0, 0, 127, 127, 127, 127, 127, 127, 734 0, 0, 0, 0, 0, 0, 127, 127, 127, 127, 127, 127, 735 0, 0, 0, 127, 127, 127, 127, 127, 127, 127, 127, 127 736 }; 737 const int ldb = k; 738 int b_offset = -127; 739 MatrixMap<const std::uint8_t, MapOrder::ColMajor> rhs(b_data, k, n, ldb); 740 const OffsetRowDup rhs_offset(b_offset, rhs.cols()); 741 742 // 2 x 9, columnwise. 743 const uint8_t expected_c_data[] = { 744 255, 255, 745 0, 0, 746 127, 159, 747 0, 64, 748 0, 64, 749 127, 159, 750 127, 127, 751 127, 127, 752 127, 127 753 }; 754 const int ldc = m; 755 int c_offset[] = {97155, 97346}; 756 int c_mult_int[] = {2741, 2741}; 757 const int c_shift = 21; 758 759 const int c_count = m * n; 760 std::unique_ptr<uint8_t[]> output_data(new uint8_t[c_count]); 761 MatrixMap<std::uint8_t, MapOrder::ColMajor> result(output_data.get(), m, n, 762 ldc); 763 const OffsetColMap result_offset(c_offset, m); 764 const OffsetColMap result_mult_int(c_mult_int, m); 765 const int result_shift = c_shift; 766 767 GemmContext gemm_context; 768 auto output_pipeline = MakeStandardOutputPipeline<VectorShape::Col>( 769 result_offset, result_mult_int, result_shift); 770 GemmWithOutputPipelinePC<uint8_t, uint8_t, DefaultL8R8BitDepthParams>( 771 &gemm_context, lhs, rhs, &result, lhs_offset, rhs_offset, 772 output_pipeline); 773 774 ResultStats stats; 775 GetResultStats(output_data.get(), expected_c_data, c_count, &stats); 776 777 ResultStatsBounds bounds; 778 const bool good = CheckResultStatsBounds(stats, bounds); 779 printf("TestWithSmallDataPerChannelQuantization: %s\n", 780 good ? "PASS" : "FAIL"); 781 ReportResultStats(stats, bounds); 782 Check(good); 783 } 784 785 // Runs a larger set of hand-picked data for per-channel quantized data. 786 // This test case comes from a set of 22 3x3 convolution filters run over a 5x5 787 // image. Right now, I have 7 different filters and 15 copies of the first 788 // filter to make sure NEON code path that processes 16 rows at a time is 789 // covered. 790 void TestWithLargeDataPerChannelQuantization() { 791 const int m = 22; 792 const int n = 25; 793 const int k = 27; 794 795 // 27 x 22, column-wise. 796 const uint8_t a_data[] = { 797 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 798 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 799 0, 0, 0, 0, 0, 0, 0, 0, 0, 127, 127, 127, 255, 255, 255, 800 127, 127, 127, 0, 0, 0, 0, 0, 0, 0, 0, 0, 801 0, 0, 0, 127, 127, 127, 0, 0, 0, 0, 0, 0, 255, 255, 255, 802 0, 0, 0, 0, 0, 0, 127, 127, 127, 0, 0, 0, 803 51, 51, 51, 51, 51, 51, 51, 51, 51, 0, 0, 0, 255, 255, 255, 804 0, 0, 0, 51, 51, 51, 51, 51, 51, 51, 51, 51, 805 51, 51, 51, 0, 0, 0, 51, 51, 51, 51, 51, 51, 255, 255, 255, 806 51, 51, 51, 51, 51, 51, 0, 0, 0, 51, 51, 51, 807 0, 0, 0, 64, 64, 64, 0, 0, 0, 64, 64, 64, 255, 255, 255, 808 64, 64, 64, 0, 0, 0, 64, 64, 64, 0, 0, 0, 809 36, 36, 36, 0, 0, 0, 36, 36, 36, 0, 0, 0, 255, 255, 255, 810 0, 0, 0, 36, 36, 36, 0, 0, 0, 36, 36, 36, 811 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 812 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 813 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 814 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 815 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 816 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 817 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 818 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 819 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 820 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 821 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 822 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 823 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 824 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 825 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 826 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 827 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 828 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 829 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 830 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 831 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 832 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 833 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 834 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 835 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 836 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 837 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 838 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 839 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 840 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 841 }; 842 const int lda = k; 843 int a_offset[] = { 844 0, 0, 0, -51, -51, 0, -36, 0, 0, 0, 845 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 846 0, 0 847 }; 848 MatrixMap<const std::uint8_t, MapOrder::RowMajor> lhs(a_data, m, k, lda); 849 const OffsetColMap lhs_offset(a_offset, m); 850 851 // 27 x 25, column-wise. 852 const uint8_t b_data[] = { 853 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 119, 119, 119, 854 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 855 127, 127, 127, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 856 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 857 127, 127, 127, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 858 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 859 127, 127, 127, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 860 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 861 127, 127, 127, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 862 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 863 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 864 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 865 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 866 119, 119, 119, 119, 119, 119, 119, 119, 119, 136, 136, 136, 867 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 868 119, 119, 119, 119, 119, 119, 136, 136, 136, 119, 119, 119, 869 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 870 119, 119, 119, 136, 136, 136, 119, 119, 119, 119, 119, 119, 871 119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 872 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 873 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 874 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 875 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 876 136, 136, 136, 119, 119, 119, 119, 119, 119, 119, 119, 119, 877 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 136, 136, 136, 878 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 879 119, 119, 119, 119, 119, 119, 119, 119, 119, 136, 136, 136, 119, 119, 119, 880 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 881 119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 882 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 883 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 884 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 885 119, 119, 119, 119, 119, 119, 136, 136, 136, 119, 119, 119, 119, 119, 119, 886 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 887 119, 119, 119, 136, 136, 136, 119, 119, 119, 119, 119, 119, 119, 119, 119, 888 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 889 136, 136, 136, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 890 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 891 119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 892 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 893 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 894 119, 119, 119, 127, 127, 127, 127, 127, 127, 127, 127, 127, 895 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 896 119, 119, 119, 127, 127, 127, 127, 127, 127, 127, 127, 127, 897 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 898 119, 119, 119, 127, 127, 127, 127, 127, 127, 127, 127, 127, 899 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 900 119, 119, 119, 127, 127, 127, 127, 127, 127, 127, 127, 127, 901 119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 902 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127 903 }; 904 const int ldb = k; 905 int b_offset = -127; 906 MatrixMap<const std::uint8_t, MapOrder::ColMajor> rhs(b_data, k, n, ldb); 907 const OffsetRowDup rhs_offset(b_offset, rhs.cols()); 908 909 // 22 x 25, column-wise. 910 const uint8_t expected_c_data[] = { 911 7, 37, 37, 67, 67, 39, 79, 7, 7, 7, 7, 7, 7, 7, 7, 912 7, 7, 7, 7, 7, 7, 7, 913 7, 7, 37, 87, 67, 23, 91, 7, 7, 7, 7, 7, 7, 7, 7, 914 7, 7, 7, 7, 7, 7, 7, 915 7, 7, 37, 87, 67, 23, 91, 7, 7, 7, 7, 7, 7, 7, 7, 916 7, 7, 7, 7, 7, 7, 7, 917 7, 7, 37, 87, 67, 23, 91, 7, 7, 7, 7, 7, 7, 7, 7, 918 7, 7, 7, 7, 7, 7, 7, 919 7, 37, 37, 67, 67, 39, 79, 7, 7, 7, 7, 7, 7, 7, 7, 920 7, 7, 7, 7, 7, 7, 7, 921 7, 37, 7, 67, 87, 23, 91, 7, 7, 7, 7, 7, 7, 7, 7, 922 7, 7, 7, 7, 7, 7, 7, 923 7, 7, 7, 87, 87, 7, 103, 7, 7, 7, 7, 7, 7, 7, 7, 924 7, 7, 7, 7, 7, 7, 7, 925 7, 7, 71, 87, 45, 41, 77, 7, 7, 7, 7, 7, 7, 7, 7, 926 7, 7, 7, 7, 7, 7, 7, 927 7, 7, 7, 87, 87, 7, 103, 7, 7, 7, 7, 7, 7, 7, 7, 928 7, 7, 7, 7, 7, 7, 7, 929 7, 37, 7, 67, 87, 23, 91, 7, 7, 7, 7, 7, 7, 7, 7, 930 7, 7, 7, 7, 7, 7, 7, 931 7, 37, 7, 67, 87, 23, 91, 7, 7, 7, 7, 7, 7, 7, 7, 932 7, 7, 7, 7, 7, 7, 7, 933 7, 71, 7, 45, 87, 41, 77, 7, 7, 7, 7, 7, 7, 7, 7, 934 7, 7, 7, 7, 7, 7, 7, 935 255, 135, 135, 255, 255, 143, 255, 255, 255, 255, 255, 255, 255, 255, 255, 936 255, 255, 255, 255, 255, 255, 255, 937 7, 71, 7, 45, 87, 41, 77, 7, 7, 7, 7, 7, 7, 7, 7, 938 7, 7, 7, 7, 7, 7, 7, 939 7, 37, 7, 67, 87, 23, 91, 7, 7, 7, 7, 7, 7, 7, 7, 940 7, 7, 7, 7, 7, 7, 7, 941 7, 37, 7, 67, 87, 23, 91, 7, 7, 7, 7, 7, 7, 7, 7, 942 7, 7, 7, 7, 7, 7, 7, 943 7, 7, 7, 87, 87, 7, 103, 7, 7, 7, 7, 7, 7, 7, 7, 944 7, 7, 7, 7, 7, 7, 7, 945 7, 7, 71, 87, 45, 41, 77, 7, 7, 7, 7, 7, 7, 7, 7, 946 7, 7, 7, 7, 7, 7, 7, 947 7, 7, 7, 87, 87, 7, 103, 7, 7, 7, 7, 7, 7, 7, 7, 948 7, 7, 7, 7, 7, 7, 7, 949 7, 37, 7, 67, 87, 23, 91, 7, 7, 7, 7, 7, 7, 7, 7, 950 7, 7, 7, 7, 7, 7, 7, 951 7, 37, 37, 67, 67, 39, 79, 7, 7, 7, 7, 7, 7, 7, 7, 952 7, 7, 7, 7, 7, 7, 7, 953 7, 7, 37, 87, 67, 23, 91, 7, 7, 7, 7, 7, 7, 7, 7, 954 7, 7, 7, 7, 7, 7, 7, 955 7, 7, 37, 87, 67, 23, 91, 7, 7, 7, 7, 7, 7, 7, 7, 956 7, 7, 7, 7, 7, 7, 7, 957 7, 7, 37, 87, 67, 23, 91, 7, 7, 7, 7, 7, 7, 7, 7, 958 7, 7, 7, 7, 7, 7, 7, 959 7, 37, 37, 67, 67, 39, 79, 7, 7, 7, 7, 7, 7, 7, 7, 960 7, 7, 7, 7, 7, 7, 7, 961 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 962 99, 99, 99, 99, 99, 99, 99, 963 111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 964 111, 111, 111, 111, 111, 111, 111, 965 }; 966 const int ldc = m; 967 int c_offset[] = { 968 6477, 12954, 12954, 7793, 7793, 12954, 9282, 6477, 6477, 6477, 969 6477, 6477, 6477, 6477, 6477, 6477, 6477, 6477, 6477, 6477, 970 6477, 6477, 971 }; 972 int c_mult_int[] = { 973 41121, 20560, 20560, 34267, 34267, 21937, 28784, 41121, 41121, 41121, 974 41121, 41121, 41121, 41121, 41121, 41121, 41121, 41121, 41121, 41121, 975 41121, 41121, 976 }; 977 const int c_shift = 21; 978 979 const int c_count = m * n; 980 std::unique_ptr<uint8_t[]> output_data(new uint8_t[c_count]); 981 MatrixMap<std::uint8_t, MapOrder::ColMajor> result(output_data.get(), m, n, 982 ldc); 983 const OffsetColMap result_offset(c_offset, m); 984 const OffsetColMap result_mult_int(c_mult_int, m); 985 const int result_shift = c_shift; 986 987 GemmContext gemm_context; 988 auto output_pipeline = MakeStandardOutputPipeline<VectorShape::Col>( 989 result_offset, result_mult_int, result_shift); 990 GemmWithOutputPipelinePC<uint8_t, uint8_t, DefaultL8R8BitDepthParams>( 991 &gemm_context, lhs, rhs, &result, lhs_offset, rhs_offset, 992 output_pipeline); 993 994 ResultStats stats; 995 GetResultStats(output_data.get(), expected_c_data, c_count, &stats); 996 997 ResultStatsBounds bounds; 998 const bool good = CheckResultStatsBounds(stats, bounds); 999 printf("TestWithLargeDataPerChannelQuantization: %s\n", 1000 good ? "PASS" : "FAIL"); 1001 ReportResultStats(stats, bounds); 1002 Check(good); 1003 } 1004 1005 // Runs a small set of hand-calculated data through the implementation. 1006 void TestWithSmallData() { 1007 const int m = 4; 1008 const int n = 2; 1009 const int k = 3; 1010 // Matrix A (LHS) is: 1011 // | 7 | 10 | 13 | 16 | 1012 // | 8 | 11 | 14 | 17 | 1013 // | 9 | 12 | 15 | 18 | 1014 const uint8_t a_data[] = {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}; 1015 // Matrix B (RHS) is: 1016 // | 1 | 3 | 5 | 1017 // | 2 | 4 | 6 | 1018 const uint8_t b_data[] = {1, 2, 3, 4, 5, 6}; 1019 // Here are the results we expect, from hand calculations: 1020 // (1 * 7) + (3 * 8) + (5 * 9) = 76 1021 // (2 * 7) + (4 * 8) + (6 * 9) = 100 1022 // (1 * 10) + (3 * 11) + (5 * 12) = 103 1023 // (2 * 10) + (4 * 11) + (6 * 12) = 136 1024 // (1 * 13) + (3 * 14) + (5 * 15) = 130 1025 // (2 * 13) + (4 * 14) + (6 * 15) = 172 1026 // (1 * 16) + (3 * 17) + (5 * 18) = 157 1027 // (2 * 16) + (4 * 17) + (6 * 18) = 208 1028 // That means matrix C should be: 1029 // | 76 | 103 | 130 | 157 | 1030 // | 100 | 136 | 172 | 208 | 1031 const uint8_t expected_data[] = {76, 100, 103, 136, 130, 172, 157, 208}; 1032 1033 const int c_count = m * n; 1034 std::unique_ptr<uint8_t[]> output_data(new uint8_t[c_count]); 1035 1036 const bool is_a_transposed = true; 1037 const bool is_b_transposed = true; 1038 const bool is_c_transposed = true; 1039 const int lda = k; 1040 const int ldb = n; 1041 const int ldc = n; 1042 1043 const int a_offset = 0; 1044 const int b_offset = 0; 1045 const int c_offset = 0; 1046 const int c_mult = 1; 1047 const int c_shift = 0; 1048 1049 gemmlowp::eight_bit_int_gemm::EightBitIntGemm( 1050 is_a_transposed, is_b_transposed, is_c_transposed, m, n, k, a_data, 1051 a_offset, lda, b_data, b_offset, ldb, output_data.get(), c_offset, c_mult, 1052 c_shift, ldc, eight_bit_int_gemm::BitDepthSetting::A8B8); 1053 1054 ResultStats stats; 1055 GetResultStats(output_data.get(), expected_data, c_count, &stats); 1056 1057 ResultStatsBounds bounds; 1058 const bool good = CheckResultStatsBounds(stats, bounds); 1059 printf("TestWithSmallData: %s\n", good ? "PASS" : "FAIL"); 1060 ReportResultStats(stats, bounds); 1061 Check(good); 1062 } 1063 1064 // This is the most realistic test of how we'll be using the low-precision GEMM 1065 // function in applications. It takes in large input matrices that have been 1066 // captured from an actual neural network run. 1067 void TestWithRealData(eight_bit_int_gemm::BitDepthSetting BitDepth, 1068 int tolerance_median, int tolerance_max) { 1069 std::unique_ptr<uint8_t[]> output_data(new uint8_t[test_data::c_count]); 1070 gemmlowp::eight_bit_int_gemm::EightBitIntGemm( 1071 test_data::is_a_transposed, test_data::is_b_transposed, 1072 test_data::is_c_transposed, test_data::m, test_data::n, test_data::k, 1073 test_data::a_data, test_data::a_offset, test_data::k, test_data::b_data, 1074 test_data::b_offset, test_data::k, output_data.get(), test_data::c_offset, 1075 test_data::c_mult_int, test_data::c_shift, test_data::m, BitDepth); 1076 1077 ResultStats stats; 1078 GetResultStats(output_data.get(), test_data::expected_c_data, 1079 test_data::c_count, &stats); 1080 1081 ResultStatsBounds bounds; 1082 if (BitDepth == eight_bit_int_gemm::BitDepthSetting::A5B7) { 1083 bounds.med_unsigned_diff = tolerance_median; 1084 bounds.max_unsigned_diff = tolerance_max; 1085 bounds.med_signed_diff = 0; 1086 bounds.mean_signed_diff = 0.2f; 1087 } 1088 1089 const bool good = CheckResultStatsBounds(stats, bounds); 1090 printf("TestWithRealData: %s with %s\n", good ? "PASS" : "FAIL", 1091 GetBitDepthName(BitDepth)); 1092 ReportResultStats(stats, bounds); 1093 Check(good); 1094 } 1095 1096 template <MapOrder ResultOrder> 1097 void TestOutputStages(int rows, int depth, int cols, int result_offset, 1098 int result_mult_int, int result_shift) { 1099 Matrix<std::uint8_t, MapOrder::RowMajor> lhs(rows, depth); 1100 Matrix<std::uint8_t, MapOrder::ColMajor> rhs(depth, cols); 1101 Matrix<std::int32_t, ResultOrder> result_raw_int32(rows, cols); 1102 MakeRandom(&lhs, 8); 1103 MakeRandom(&rhs, 8); 1104 const int lhs_offset = 12; 1105 const int rhs_offset = -34; 1106 1107 // Test an empty pipeline, i.e. returning raw int32 accumulators. 1108 auto empty_pipeline = std::make_tuple(); 1109 GemmContext context; 1110 GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>( 1111 &context, lhs.const_map(), rhs.const_map(), &result_raw_int32, lhs_offset, 1112 rhs_offset, empty_pipeline); 1113 1114 for (int r = 0; r < rows; r++) { 1115 for (int c = 0; c < cols; c++) { 1116 std::int32_t expected = 0; 1117 for (int d = 0; d < depth; d++) { 1118 std::int32_t lhs_val = 1119 static_cast<std::int32_t>(lhs(r, d)) + lhs_offset; 1120 std::int32_t rhs_val = 1121 static_cast<std::int32_t>(rhs(d, c)) + rhs_offset; 1122 expected += lhs_val * rhs_val; 1123 } 1124 Check(expected == result_raw_int32(r, c)); 1125 } 1126 } 1127 1128 // Test a pipeline with only the quantize-down stage, still returning 1129 // unclamped (but scaled) int32's 1130 OutputStageQuantizeDownInt32ToUint8Scale quantize_down_stage; 1131 quantize_down_stage.result_offset = result_offset; 1132 quantize_down_stage.result_mult_int = result_mult_int; 1133 quantize_down_stage.result_shift = result_shift; 1134 auto quantize_down_pipeline = std::make_tuple(quantize_down_stage); 1135 Matrix<std::int32_t, ResultOrder> result_quantized_down_int32(rows, cols); 1136 GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>( 1137 &context, lhs.const_map(), rhs.const_map(), &result_quantized_down_int32, 1138 lhs_offset, rhs_offset, quantize_down_pipeline); 1139 1140 std::uint64_t sum = 0; 1141 for (int r = 0; r < rows; r++) { 1142 for (int c = 0; c < cols; c++) { 1143 std::int32_t raw = result_raw_int32(r, c); 1144 const std::int32_t rounding = 1145 (result_shift < 1) ? 0 : (1 << (result_shift - 1)); 1146 std::int32_t expected = 1147 ((raw + result_offset) * result_mult_int + rounding) >> result_shift; 1148 Check(expected == result_quantized_down_int32(r, c)); 1149 sum += expected; 1150 } 1151 } 1152 std::uint64_t avg = sum / (rows * cols); 1153 // Test that the average quantized-down value falls reasonably in the 1154 // middle of the [0..255] range. Otherwise, the multiplier / shift need to be 1155 // adjusted. 1156 Check(avg >= 64 && avg <= 192); 1157 1158 // Test the familiar default pipeline consisting of quantize-down and 1159 // clamp-and-cast-to-uint8. 1160 OutputStageSaturatingCastToUint8 saturating_cast_stage; 1161 auto quantize_down_and_saturating_cast_pipeline = 1162 std::make_tuple(quantize_down_stage, saturating_cast_stage); 1163 Matrix<std::uint8_t, ResultOrder> result_quantized_down_saturated_uint8(rows, 1164 cols); 1165 GemmWithOutputPipeline<std::uint8_t, std::uint8_t, DefaultL8R8BitDepthParams>( 1166 &context, lhs.const_map(), rhs.const_map(), 1167 &result_quantized_down_saturated_uint8, lhs_offset, rhs_offset, 1168 quantize_down_and_saturating_cast_pipeline); 1169 1170 for (int r = 0; r < rows; r++) { 1171 for (int c = 0; c < cols; c++) { 1172 std::int32_t quantized = result_quantized_down_int32(r, c); 1173 std::uint8_t expected = std::min(std::max(quantized, 0), 255); 1174 Check(expected == result_quantized_down_saturated_uint8(r, c)); 1175 } 1176 } 1177 1178 // Test a bias-addition with row-vector 1179 std::vector<std::int32_t> row_vector_data(cols); 1180 for (int i = 0; i < cols; i++) { 1181 row_vector_data[i] = (Random() % 1000) - 500; 1182 } 1183 typedef VectorMap<std::int32_t, VectorShape::Row> RowVectorMap; 1184 RowVectorMap row_vector_map(row_vector_data.data(), cols); 1185 OutputStageBiasAddition<RowVectorMap> row_bias_addition_stage; 1186 row_bias_addition_stage.bias_vector = row_vector_map; 1187 auto row_bias_addition_pipeline = std::make_tuple(row_bias_addition_stage); 1188 Matrix<std::int32_t, ResultOrder> result_of_row_bias_addition(rows, cols); 1189 GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>( 1190 &context, lhs.const_map(), rhs.const_map(), &result_of_row_bias_addition, 1191 lhs_offset, rhs_offset, row_bias_addition_pipeline); 1192 for (int r = 0; r < rows; r++) { 1193 for (int c = 0; c < cols; c++) { 1194 std::int32_t expected = result_raw_int32(r, c) + row_vector_data[c]; 1195 Check(expected == result_of_row_bias_addition(r, c)); 1196 } 1197 } 1198 1199 // Test a bias-addition with column-vector 1200 std::vector<std::int32_t> col_vector_data(rows); 1201 for (int i = 0; i < rows; i++) { 1202 col_vector_data[i] = (Random() % 1000) - 500; 1203 } 1204 typedef VectorMap<std::int32_t, VectorShape::Col> ColVectorMap; 1205 ColVectorMap col_vector_map(col_vector_data.data(), rows); 1206 OutputStageBiasAddition<ColVectorMap> col_bias_addition_stage; 1207 col_bias_addition_stage.bias_vector = col_vector_map; 1208 auto col_bias_addition_pipeline = std::make_tuple(col_bias_addition_stage); 1209 Matrix<std::int32_t, ResultOrder> result_of_col_bias_addition(rows, cols); 1210 GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>( 1211 &context, lhs.const_map(), rhs.const_map(), &result_of_col_bias_addition, 1212 lhs_offset, rhs_offset, col_bias_addition_pipeline); 1213 for (int r = 0; r < rows; r++) { 1214 for (int c = 0; c < cols; c++) { 1215 std::int32_t expected = result_raw_int32(r, c) + col_vector_data[r]; 1216 Check(expected == result_of_col_bias_addition(r, c)); 1217 } 1218 } 1219 1220 // Test a clamp 1221 OutputStageClamp clamp_stage; 1222 // Determine min and max of raw int32 accumulators 1223 std::int32_t raw_min = std::numeric_limits<std::int32_t>::max(); 1224 std::int32_t raw_max = std::numeric_limits<std::int32_t>::min(); 1225 for (int r = 0; r < rows; r++) { 1226 for (int c = 0; c < cols; c++) { 1227 raw_min = std::min(raw_min, result_raw_int32(r, c)); 1228 raw_max = std::max(raw_max, result_raw_int32(r, c)); 1229 } 1230 } 1231 // Pick some interesting clamp min/max bounds 1232 clamp_stage.min = static_cast<std::int32_t>(raw_min * 0.7 + raw_max * 0.3); 1233 clamp_stage.max = static_cast<std::int32_t>(raw_min * 0.3 + raw_max * 0.7); 1234 assert(raw_min <= clamp_stage.min && clamp_stage.min <= clamp_stage.max && 1235 clamp_stage.max <= raw_max); 1236 auto clamp_pipeline = std::make_tuple(clamp_stage); 1237 Matrix<std::int32_t, ResultOrder> result_clamped(rows, cols); 1238 GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>( 1239 &context, lhs.const_map(), rhs.const_map(), &result_clamped, lhs_offset, 1240 rhs_offset, clamp_pipeline); 1241 for (int r = 0; r < rows; r++) { 1242 for (int c = 0; c < cols; c++) { 1243 std::int32_t raw = result_raw_int32(r, c); 1244 std::int32_t expected = 1245 std::min(std::max(raw, clamp_stage.min), clamp_stage.max); 1246 Check(expected == result_clamped(r, c)); 1247 } 1248 } 1249 1250 // Test tanh 1251 OutputStageTanh tanh_stage; 1252 const std::int32_t real_zero_as_int32 = (raw_max + raw_min) / 2; 1253 const std::int32_t real_amplitude_as_int32 = (raw_max - raw_min) / 16; 1254 tanh_stage.real_zero_as_int32 = real_zero_as_int32; 1255 tanh_stage.real_amplitude_as_int32 = real_amplitude_as_int32; 1256 auto tanh_pipeline = std::make_tuple(tanh_stage); 1257 Matrix<std::int32_t, ResultOrder> result_tanh(rows, cols); 1258 GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>( 1259 &context, lhs.const_map(), rhs.const_map(), &result_tanh, lhs_offset, 1260 rhs_offset, tanh_pipeline); 1261 for (int r = 0; r < rows; r++) { 1262 for (int c = 0; c < cols; c++) { 1263 std::int32_t raw = result_raw_int32(r, c); 1264 double real_input = 1265 double(raw - real_zero_as_int32) / real_amplitude_as_int32; 1266 double expected = std::tanh(real_input); 1267 std::int32_t actual_int32 = result_tanh(r, c); 1268 double actual = 1269 double(actual_int32 - real_zero_as_int32) / real_amplitude_as_int32; 1270 Check(std::abs(expected - actual) < 2e-4); 1271 } 1272 } 1273 1274 // Test a pipeline with bias and clamp 1275 auto bias_clamp_pipeline = 1276 std::make_tuple(col_bias_addition_stage, clamp_stage); 1277 Matrix<std::int32_t, ResultOrder> result_biased_clamped(rows, cols); 1278 GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>( 1279 &context, lhs.const_map(), rhs.const_map(), &result_biased_clamped, 1280 lhs_offset, rhs_offset, bias_clamp_pipeline); 1281 for (int r = 0; r < rows; r++) { 1282 for (int c = 0; c < cols; c++) { 1283 std::int32_t raw = result_raw_int32(r, c); 1284 std::int32_t biased = raw + col_vector_data[r]; 1285 std::int32_t expected = 1286 std::min(std::max(biased, clamp_stage.min), clamp_stage.max); 1287 Check(expected == result_biased_clamped(r, c)); 1288 } 1289 } 1290 1291 // Test a full pipeline with bias and clamp and quantization down to 8bit 1292 // result 1293 auto bias_clamp_quantize_cast_pipeline = 1294 std::make_tuple(col_bias_addition_stage, clamp_stage, quantize_down_stage, 1295 saturating_cast_stage); 1296 Matrix<std::uint8_t, ResultOrder> result_biased_clamped_quantized_casted( 1297 rows, cols); 1298 GemmWithOutputPipeline<std::uint8_t, std::uint8_t, DefaultL8R8BitDepthParams>( 1299 &context, lhs.const_map(), rhs.const_map(), 1300 &result_biased_clamped_quantized_casted, lhs_offset, rhs_offset, 1301 bias_clamp_quantize_cast_pipeline); 1302 for (int r = 0; r < rows; r++) { 1303 for (int c = 0; c < cols; c++) { 1304 const std::int32_t rounding = 1305 (result_shift < 1) ? 0 : (1 << (result_shift - 1)); 1306 std::int32_t quantized = 1307 ((result_biased_clamped(r, c) + result_offset) * result_mult_int + 1308 rounding) >> 1309 result_shift; 1310 std::uint8_t expected = std::min(std::max(quantized, 0), 255); 1311 Check(expected == result_biased_clamped_quantized_casted(r, c)); 1312 } 1313 } 1314 1315 printf("TestOutputStages: PASS with ResultOrder=%s\n", 1316 OrderName(ResultOrder)); 1317 } 1318 1319 #ifndef GEMMLOWP_SKIP_EXHAUSTIVE_TESTS 1320 void TestExhaustively() { 1321 GemmContext context; 1322 1323 // Test the internal GEMM interfaces 1324 test_gemm<SingleThreadGemmWrapper< 1325 DefaultKernel<KernelFamily::Gemm, DefaultL8R8BitDepthParams>, 1326 std::uint8_t, DefaultL8R8BitDepthParams>>(&context); 1327 1328 test_gemm<MultiThreadGemmWrapper< 1329 DefaultKernel<KernelFamily::Gemm, DefaultL8R8BitDepthParams>, 1330 std::uint8_t, DefaultL8R8BitDepthParams>>(&context); 1331 1332 // Test the public GEMM interfaces 1333 test_gemm<PublicGemmWrapper<uint8_t, DefaultL8R8BitDepthParams>>(&context); 1334 1335 test_gemm<EightBitIntGemmWrapper<uint8_t, 1336 eight_bit_int_gemm::BitDepthSetting::A8B8>>( 1337 &context); 1338 1339 // Test GEMV cases (internal interfaces) 1340 test_gemv<SingleThreadGemmWrapper< 1341 DefaultKernel<KernelFamily::Gemv, DefaultL8R8BitDepthParams>, 1342 std::uint8_t, DefaultL8R8BitDepthParams>>(&context); 1343 1344 test_gemv<MultiThreadGemmWrapper< 1345 DefaultKernel<KernelFamily::Gemv, DefaultL8R8BitDepthParams>, 1346 std::uint8_t, DefaultL8R8BitDepthParams>>(&context); 1347 1348 // Test GEMV cases (public interfaces) 1349 test_gemv<PublicGemmWrapper<uint8_t, DefaultL8R8BitDepthParams>>(&context); 1350 1351 test_gemv<EightBitIntGemmWrapper<uint8_t, 1352 eight_bit_int_gemm::BitDepthSetting::A8B8>>( 1353 &context); 1354 1355 // Test other bit depths 1356 // L7R5 1357 test_gemm<SingleThreadGemmWrapper< 1358 DefaultKernel<KernelFamily::Gemm, DefaultL7R5BitDepthParams>, 1359 std::uint8_t, DefaultL7R5BitDepthParams>>(&context); 1360 1361 test_gemv<SingleThreadGemmWrapper< 1362 DefaultKernel<KernelFamily::Gemv, DefaultL7R5BitDepthParams>, 1363 std::uint8_t, DefaultL7R5BitDepthParams>>(&context); 1364 1365 test_gemm<EightBitIntGemmWrapper<std::uint8_t, 1366 eight_bit_int_gemm::BitDepthSetting::A5B7>>( 1367 &context); 1368 1369 // Test specific kernels with various different formats, 1370 // to exercises corner cases especially in the packing code. 1371 test_gemm_kernel< 1372 ReferenceKernel<KernelFormat<KernelSideFormat<CellFormat<1, 1>, 1>, 1373 KernelSideFormat<CellFormat<1, 1>, 1>>>>( 1374 &context); 1375 1376 test_gemm_kernel< 1377 ReferenceKernel<KernelFormat<KernelSideFormat<CellFormat<4, 2>, 1>, 1378 KernelSideFormat<CellFormat<4, 2>, 2>>>>( 1379 &context); 1380 1381 test_gemm_kernel< 1382 ReferenceKernel<KernelFormat<KernelSideFormat<CellFormat<4, 2>, 4>, 1383 KernelSideFormat<CellFormat<4, 2>, 5>>>>( 1384 &context); 1385 1386 test_gemm_kernel<ReferenceKernel<KernelFormat< 1387 KernelSideFormat<CellFormat<3, 4, CellOrder::DepthMajor>, 2>, 1388 KernelSideFormat<CellFormat<5, 4, CellOrder::DepthMajor>, 3>>>>(&context); 1389 1390 test_gemm_kernel<ReferenceKernel<KernelFormat< 1391 KernelSideFormat<CellFormat<3, 4, CellOrder::WidthMajor>, 2>, 1392 KernelSideFormat<CellFormat<5, 4, CellOrder::WidthMajor>, 3>>>>(&context); 1393 1394 test_gemm_kernel<ReferenceKernel<KernelFormat< 1395 KernelSideFormat<CellFormat<5, 2, CellOrder::WidthMajor>, 3>, 1396 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2>>>>(&context); 1397 1398 test_gemm_kernel<ReferenceKernel<KernelFormat< 1399 KernelSideFormat<CellFormat<5, 2, CellOrder::DepthMajor>, 3>, 1400 KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 2>>>>(&context); 1401 1402 test_gemm_kernel<ReferenceKernel<KernelFormat< 1403 KernelSideFormat<CellFormat<8, 8, CellOrder::Diagonal>, 2>, 1404 KernelSideFormat<CellFormat<3, 8, CellOrder::WidthMajor>, 1>>>>(&context); 1405 1406 test_gemm_kernel<ReferenceKernel<KernelFormat< 1407 KernelSideFormat<CellFormat<1, 4, CellOrder::DepthMajor>, 1>, 1408 KernelSideFormat<CellFormat<4, 4, CellOrder::Diagonal>, 1>>>>(&context); 1409 } 1410 #endif // not GEMMLOWP_SKIP_EXHAUSTIVE_TESTS 1411 1412 void test() { 1413 #ifdef GEMMLOWP_TEST_PROFILE 1414 RegisterCurrentThreadForProfiling(); 1415 StartProfiling(); 1416 #endif 1417 1418 // Run a first quick test against hand-calculated data. 1419 TestWithSmallData(); 1420 1421 #ifndef GEMMLOWP_SKIP_EXHAUSTIVE_TESTS 1422 TestExhaustively(); 1423 #endif 1424 1425 // Run against actual data from a network evaluation. 1426 TestWithRealData(eight_bit_int_gemm::BitDepthSetting::A8B8, 0, 0); 1427 TestWithRealData(eight_bit_int_gemm::BitDepthSetting::A5B7, 2, 10); 1428 1429 // Test non-default output pipelines with various combinations of 1430 // output stages. 1431 TestOutputStages<MapOrder::RowMajor>(63, 10, 127, 5, 17, 14); 1432 TestOutputStages<MapOrder::ColMajor>(63, 10, 127, 5, 17, 14); 1433 1434 // Test per channel quantization. 1435 TestWithSmallDataPerChannelQuantization(); 1436 TestWithLargeDataPerChannelQuantization(); 1437 #ifdef GEMMLOWP_TEST_PROFILE 1438 FinishProfiling(); 1439 #endif 1440 1441 std::cerr << "All tests passed." << std::endl; 1442 1443 // We have been testing the eight_bit_int_gemm, so we should free its 1444 // persistent 1445 // resources now to avoid having leak-checking tools report leaks. 1446 eight_bit_int_gemm::FreePersistentResources(); 1447 } 1448 1449 } // end namespace gemmlowp 1450 1451 // For iOS, we need to define our own main(), so skip it here. 1452 #if !(defined(__APPLE__) && (TARGET_OS_IPHONE || TARGET_IPHONE_SIMULATOR)) 1453 int main() { gemmlowp::test(); } 1454 #endif 1455