1 // Copyright 2016 The Gemmlowp Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #include <unistd.h> 16 #ifdef __APPLE__ 17 #include <sys/time.h> 18 #endif 19 20 #include <cstdint> 21 #include <cstdlib> 22 #include <ctime> 23 #include <iomanip> 24 #include <iostream> 25 #include <map> 26 #include <memory> 27 #include <vector> 28 29 #include "multi_thread_gemm.h" 30 #include "quantized_mul_kernels.h" 31 #include "single_thread_gemm.h" 32 #include "streams.h" 33 34 #define LHS_OFFSET (-127) 35 #define RHS_OFFSET (-127) 36 #define SUM_OFFSET (127) 37 #define MUL_OFFSET (1) 38 #define SHIFT (7) 39 #define FLOAT_SCALE (0.333f) 40 41 using namespace gemmlowp::meta; 42 43 // Input, output & kernel setups. 44 45 typedef GemmParams<std::uint8_t, std::uint8_t, RowMajorWithSum, ColumnMajorWithSum, 46 QuantizedStaticPreprocessed, RowMajor> 47 ParamsColumnMajor; 48 49 typedef GemmParams<std::uint8_t, std::uint8_t, RowMajorWithSum, RowMajorWithSum, 50 QuantizedStaticPreprocessed, RowMajor> 51 ParamsRowMajor; 52 53 typedef GemmParams<std::uint8_t, float, RowMajorWithSum, ColumnMajorWithSum, 54 QuantizedStaticPreprocessedAsFloat, RowMajor> 55 ParamsColumnMajorAsFloat; 56 57 typedef GemmParams<std::uint8_t, float, RowMajorWithSum, RowMajorWithSum, 58 QuantizedStaticPreprocessedAsFloat, RowMajor> 59 ParamsRowMajorAsFloat; 60 61 typedef GemmParams<std::uint8_t, std::int32_t, RowMajorWithSum, ColumnMajorWithSum, 62 QuantizedStaticPreprocessedAsInt32, RowMajor> 63 ParamsColumnMajorAsInt32; 64 65 typedef GemmParams<std::uint8_t, std::int32_t, RowMajorWithSum, RowMajorWithSum, 66 QuantizedStaticPreprocessedAsInt32, RowMajor> 67 ParamsRowMajorAsInt32; 68 69 typedef gemmlowp::WorkersPool Pool; 70 typedef SimpleContext<gemmlowp::WorkersPool> Context; 71 72 #ifdef LHS_PACK 73 typedef GemmExecutorPackLHSCacheFriendly<> Executor; 74 #else 75 typedef GemmExecutorPackRHSCacheFriendly<> Executor; 76 #endif 77 78 // Testing helper functions. 79 80 void prepare_test_data(std::uint8_t* data, std::int32_t rows, std::int32_t cols, 81 std::int32_t seed, std::int32_t seed_2) { 82 std::int32_t value = seed; 83 for (int i = 0; i < rows * cols; ++i) { 84 data[i] = static_cast<std::uint8_t>(value); 85 value = ((value * seed_2) + seed) % 256; 86 } 87 } 88 89 template <typename CLEAR_TYPE> 90 void clear(int rows, int cols, CLEAR_TYPE* data) { 91 for (int i = 0; i < rows * cols; ++i) { 92 data[i] = 0; 93 } 94 } 95 96 bool check_row_row(std::uint8_t* lhs, std::uint8_t* rhs, std::uint8_t* results, int rows, 97 int cols, int depth) { 98 int wrong = 0; 99 int rounding = (1 << (SHIFT - 1)); 100 for (int i = 0; i < rows; ++i) { 101 for (int j = 0; j < cols; ++j) { 102 int expected = 0; 103 for (int k = 0; k < depth; ++k) { 104 expected += (static_cast<int>(lhs[depth * i + k]) + LHS_OFFSET) * 105 (static_cast<int>(rhs[depth * j + k]) + RHS_OFFSET); 106 } 107 expected += SUM_OFFSET * depth; 108 expected *= MUL_OFFSET; 109 expected += rounding; 110 expected = (expected >> SHIFT); 111 if (expected < 0) { 112 expected = 0; 113 } else if (expected > 255) { 114 expected = 255; 115 } 116 expected = static_cast<int>(static_cast<std::uint8_t>(expected)); 117 int actual = static_cast<int>(results[i * cols + j]); 118 if (actual != expected) { 119 std::cout << "Wrong @" << i << "x" << j << " : " << actual 120 << " != " << expected << std::endl; 121 wrong++; 122 } 123 } 124 } 125 if (wrong != 0) { 126 std::cout << wrong << "/" << (rows * cols) << std::endl; 127 } 128 return wrong == 0; 129 } 130 131 bool check_row_col(std::uint8_t* lhs, std::uint8_t* rhs, std::uint8_t* results, int rows, 132 int cols, int depth) { 133 int wrong = 0; 134 int rounding = (1 << (SHIFT - 1)); 135 for (int i = 0; i < rows; ++i) { 136 for (int j = 0; j < cols; ++j) { 137 int expected = 0; 138 for (int k = 0; k < depth; ++k) { 139 expected += (static_cast<int>(lhs[depth * i + k]) + LHS_OFFSET) * 140 (static_cast<int>(rhs[j + k * cols]) + RHS_OFFSET); 141 } 142 expected += SUM_OFFSET * depth; 143 expected *= MUL_OFFSET; 144 expected += rounding; 145 expected = (expected >> SHIFT); 146 if (expected < 0) { 147 expected = 0; 148 } else if (expected > 255) { 149 expected = 255; 150 } 151 expected = static_cast<int>(static_cast<std::uint8_t>(expected)); 152 int actual = static_cast<int>(results[i * cols + j]); 153 if (actual != expected) { 154 wrong++; 155 } 156 } 157 } 158 return wrong == 0; 159 } 160 161 bool check_row_row_f(std::uint8_t* lhs, std::uint8_t* rhs, float* results, int rows, 162 int cols, int depth) { 163 int wrong = 0; 164 for (int i = 0; i < rows; ++i) { 165 for (int j = 0; j < cols; ++j) { 166 int expected = 0; 167 for (int k = 0; k < depth; ++k) { 168 expected += (static_cast<int>(lhs[depth * i + k]) + LHS_OFFSET) * 169 (static_cast<int>(rhs[depth * j + k]) + RHS_OFFSET); 170 } 171 float expected_float = static_cast<float>(expected) * FLOAT_SCALE; 172 float actual = results[i * cols + j]; 173 if (actual != expected_float) { 174 wrong++; 175 } 176 } 177 } 178 return wrong == 0; 179 } 180 181 bool check_row_col_f(std::uint8_t* lhs, std::uint8_t* rhs, float* results, int rows, 182 int cols, int depth) { 183 int wrong = 0; 184 for (int i = 0; i < rows; ++i) { 185 for (int j = 0; j < cols; ++j) { 186 int expected = 0; 187 for (int k = 0; k < depth; ++k) { 188 expected += (static_cast<int>(lhs[depth * i + k]) + LHS_OFFSET) * 189 (static_cast<int>(rhs[j + k * cols]) + RHS_OFFSET); 190 } 191 float expected_float = static_cast<float>(expected) * FLOAT_SCALE; 192 float actual = results[i * cols + j]; 193 if (actual != expected_float) { 194 wrong++; 195 } 196 } 197 } 198 return wrong == 0; 199 } 200 201 bool check_row_row_i32(std::uint8_t* lhs, std::uint8_t* rhs, std::int32_t* results, int rows, 202 int cols, int depth) { 203 int wrong = 0; 204 for (int i = 0; i < rows; ++i) { 205 for (int j = 0; j < cols; ++j) { 206 int expected = 0; 207 for (int k = 0; k < depth; ++k) { 208 expected += (static_cast<int>(lhs[depth * i + k]) + LHS_OFFSET) * 209 (static_cast<int>(rhs[depth * j + k]) + RHS_OFFSET); 210 } 211 int actual = results[i * cols + j]; 212 if (actual != expected) { 213 wrong++; 214 } 215 } 216 } 217 return wrong == 0; 218 } 219 220 bool check_row_col_i32(std::uint8_t* lhs, std::uint8_t* rhs, std::int32_t* results, int rows, 221 int cols, int depth) { 222 int wrong = 0; 223 for (int i = 0; i < rows; ++i) { 224 for (int j = 0; j < cols; ++j) { 225 int expected = 0; 226 for (int k = 0; k < depth; ++k) { 227 expected += (static_cast<int>(lhs[depth * i + k]) + LHS_OFFSET) * 228 (static_cast<int>(rhs[j + k * cols]) + RHS_OFFSET); 229 } 230 int actual = results[i * cols + j]; 231 if (actual != expected) { 232 wrong++; 233 } 234 } 235 } 236 return wrong == 0; 237 } 238 239 template <typename PARAMS, typename RESULT_TYPE> 240 void setup_params(std::uint8_t* lhs, std::uint8_t* rhs, RESULT_TYPE* result, 241 std::uint8_t* scratch, PARAMS* params) { 242 params->lhs = lhs; 243 params->rhs = rhs; 244 params->result = result; 245 params->scratch = scratch; 246 247 params->left_stream.multiplicative_sum_offset = RHS_OFFSET; 248 params->left_stream.additive_sum_offset = 0; 249 250 params->right_stream.multiplicative_sum_offset = LHS_OFFSET; 251 params->right_stream.additive_sum_offset = 0; 252 } 253 254 void setup_row_row(int m, int n, int k, ParamsRowMajor* params) { 255 params->m = m; 256 params->n = n; 257 params->k = k; 258 params->left_stream.count = k; 259 params->left_stream.stride = k; 260 params->left_stream.additive_sum_offset = 261 SUM_OFFSET * k + k * LHS_OFFSET * RHS_OFFSET; 262 params->right_stream.count = k; 263 params->right_stream.stride = k; 264 params->fused_kernel.kernel.count = k; 265 params->fused_kernel.kernel.multiplicative_offset = MUL_OFFSET; 266 params->fused_kernel.kernel.rounding_offset = (1 << (SHIFT - 1)); 267 params->fused_kernel.kernel.shift = -SHIFT; 268 params->fused_kernel.output_stream.stride = n; 269 } 270 271 void setup_row_col(int m, int n, int k, ParamsColumnMajor* params) { 272 params->m = m; 273 params->n = n; 274 params->k = k; 275 params->left_stream.count = k; 276 params->left_stream.stride = k; 277 params->left_stream.additive_sum_offset = 278 SUM_OFFSET * k + k * LHS_OFFSET * RHS_OFFSET; 279 params->right_stream.count = k; 280 params->right_stream.stride = n; 281 params->fused_kernel.kernel.count = k; 282 params->fused_kernel.kernel.multiplicative_offset = MUL_OFFSET; 283 params->fused_kernel.kernel.rounding_offset = (1 << (SHIFT - 1)); 284 params->fused_kernel.kernel.shift = -SHIFT; 285 params->fused_kernel.output_stream.stride = n; 286 } 287 288 void setup_row_row_f(int m, int n, int k, ParamsRowMajorAsFloat* params) { 289 params->m = m; 290 params->n = n; 291 params->k = k; 292 params->left_stream.count = k; 293 params->left_stream.stride = k; 294 params->left_stream.additive_sum_offset = k * LHS_OFFSET * RHS_OFFSET; 295 params->right_stream.count = k; 296 params->right_stream.stride = k; 297 params->fused_kernel.kernel.count = k; 298 params->fused_kernel.kernel.scale = FLOAT_SCALE; 299 params->fused_kernel.output_stream.stride = n * sizeof(float); 300 } 301 302 void setup_row_col_f(int m, int n, int k, ParamsColumnMajorAsFloat* params) { 303 params->m = m; 304 params->n = n; 305 params->k = k; 306 params->left_stream.count = k; 307 params->left_stream.stride = k; 308 params->left_stream.additive_sum_offset = k * LHS_OFFSET * RHS_OFFSET; 309 params->right_stream.count = k; 310 params->right_stream.stride = n; 311 params->fused_kernel.kernel.count = k; 312 params->fused_kernel.kernel.scale = FLOAT_SCALE; 313 params->fused_kernel.output_stream.stride = n * sizeof(float); 314 } 315 316 void setup_row_row_i32(int m, int n, int k, ParamsRowMajorAsInt32* params) { 317 params->m = m; 318 params->n = n; 319 params->k = k; 320 params->left_stream.count = k; 321 params->left_stream.stride = k; 322 params->left_stream.additive_sum_offset = k * LHS_OFFSET * RHS_OFFSET; 323 params->right_stream.count = k; 324 params->right_stream.stride = k; 325 params->fused_kernel.kernel.count = k; 326 params->fused_kernel.output_stream.stride = n * sizeof(std::int32_t); 327 } 328 329 void setup_row_col_i32(int m, int n, int k, ParamsColumnMajorAsInt32* params) { 330 params->m = m; 331 params->n = n; 332 params->k = k; 333 params->left_stream.count = k; 334 params->left_stream.stride = k; 335 params->left_stream.additive_sum_offset = k * LHS_OFFSET * RHS_OFFSET; 336 params->right_stream.count = k; 337 params->right_stream.stride = n; 338 params->fused_kernel.kernel.count = k; 339 params->fused_kernel.output_stream.stride = n * sizeof(std::int32_t); 340 } 341 342 int main() { 343 ParamsRowMajor params_row; 344 ParamsColumnMajor params_col; 345 ParamsRowMajorAsFloat params_row_f; 346 ParamsColumnMajorAsFloat params_col_f; 347 ParamsRowMajorAsInt32 params_row_i32; 348 ParamsColumnMajorAsInt32 params_col_i32; 349 350 std::unique_ptr<std::uint8_t> lhs(new std::uint8_t[1024 * 1024]); 351 std::unique_ptr<std::uint8_t> rhs(new std::uint8_t[1024 * 1024]); 352 std::unique_ptr<std::uint8_t> result(new std::uint8_t[1024 * 1024]); 353 std::unique_ptr<float> result_f(new float[1024 * 1024]); 354 std::unique_ptr<std::int32_t> result_i32(new std::int32_t[1024 * 1024]); 355 std::unique_ptr<std::uint8_t> scratch(new std::uint8_t[4048 * 1024]); 356 357 setup_params(lhs.get(), rhs.get(), result.get(), scratch.get(), ¶ms_row); 358 setup_params(lhs.get(), rhs.get(), result.get(), scratch.get(), ¶ms_col); 359 setup_params(lhs.get(), rhs.get(), result_f.get(), scratch.get(), 360 ¶ms_row_f); 361 setup_params(lhs.get(), rhs.get(), result_f.get(), scratch.get(), 362 ¶ms_col_f); 363 setup_params(lhs.get(), rhs.get(), result_i32.get(), scratch.get(), 364 ¶ms_row_i32); 365 setup_params(lhs.get(), rhs.get(), result_i32.get(), scratch.get(), 366 ¶ms_col_i32); 367 368 Pool pool; 369 Context context(4, &pool); 370 371 for (int i = 1; i < 16; ++i) { 372 for (int j = 1; j < 16; ++j) { 373 for (int k = 1; k < 24; ++k) { 374 prepare_test_data(lhs.get(), i, k, 11, 13); 375 prepare_test_data(rhs.get(), j, k, 13, 17); 376 377 clear(i, j, result.get()); 378 setup_row_row(i, j, k, ¶ms_row); 379 Gemm<Executor, ParamsRowMajor, 2, 4, 8>(params_row); 380 if (!check_row_row(lhs.get(), rhs.get(), result.get(), i, j, k)) { 381 std::cout << "Row: " << i << "x" << j << "x" << k << " : ERROR" 382 << std::endl; 383 std::cout << "Exiting." << std::endl; 384 std::exit(1); 385 } 386 387 clear(i, j, result.get()); 388 setup_row_col(i, j, k, ¶ms_col); 389 Gemm<Executor, ParamsColumnMajor, 2, 4, 8>(params_col); 390 if (!check_row_col(lhs.get(), rhs.get(), result.get(), i, j, k)) { 391 std::cout << "Column: " << i << "x" << j << "x" << k << " : ERROR" 392 << std::endl; 393 std::cout << "Exiting." << std::endl; 394 std::exit(1); 395 } 396 397 clear(i, j, result_f.get()); 398 setup_row_row_f(i, j, k, ¶ms_row_f); 399 Gemm<Executor, ParamsRowMajorAsFloat, 2, 4, 8>(params_row_f); 400 if (!check_row_row_f(lhs.get(), rhs.get(), result_f.get(), i, j, k)) { 401 std::cout << "RowAsFloat: " << i << "x" << j << "x" << k << " : ERROR" 402 << std::endl; 403 std::cout << "Exiting." << std::endl; 404 std::exit(1); 405 } 406 407 clear(i, j, result_f.get()); 408 setup_row_col_f(i, j, k, ¶ms_col_f); 409 Gemm<Executor, ParamsColumnMajorAsFloat, 2, 4, 8>(params_col_f); 410 if (!check_row_col_f(lhs.get(), rhs.get(), result_f.get(), i, j, k)) { 411 std::cout << "ColumnAsFloat: " << i << "x" << j << "x" << k 412 << " : ERROR" << std::endl; 413 std::cout << "Exiting." << std::endl; 414 std::exit(1); 415 } 416 417 clear(i, j, result_i32.get()); 418 setup_row_row_i32(i, j, k, ¶ms_row_i32); 419 Gemm<Executor, ParamsRowMajorAsInt32, 2, 4, 8>(params_row_i32); 420 if (!check_row_row_i32(lhs.get(), rhs.get(), result_i32.get(), i, j, 421 k)) { 422 std::cout << "RowAsInt32: " << i << "x" << j << "x" << k << " : ERROR" 423 << std::endl; 424 std::cout << "Exiting." << std::endl; 425 std::exit(1); 426 } 427 428 clear(i, j, result_i32.get()); 429 setup_row_col_i32(i, j, k, ¶ms_col_i32); 430 Gemm<Executor, ParamsColumnMajorAsInt32, 2, 4, 8>(params_col_i32); 431 if (!check_row_col_i32(lhs.get(), rhs.get(), result_i32.get(), i, j, 432 k)) { 433 std::cout << "ColumnAsInt32: " << i << "x" << j << "x" << k 434 << " : ERROR" << std::endl; 435 std::cout << "Exiting." << std::endl; 436 std::exit(1); 437 } 438 } 439 } 440 } 441 442 for (int i = 1; i < 1024; i += 211) { 443 for (int j = 1; j < 1024; j += 211) { 444 for (int k = 8; k < 1024; k += 111) { 445 prepare_test_data(lhs.get(), i, k, 11, 13); 446 prepare_test_data(rhs.get(), j, k, 13, 17); 447 448 clear(i, j, result.get()); 449 setup_row_row(i, j, k, ¶ms_row); 450 MultiThreadGemm<Context, Executor, ParamsRowMajor, 2, 4, 8>(&context, 451 params_row); 452 if (!check_row_row(lhs.get(), rhs.get(), result.get(), i, j, k)) { 453 std::cout << "Row(MT): " << i << "x" << j << "x" << k << " : ERROR" 454 << std::endl; 455 std::cout << "Exiting." << std::endl; 456 std::exit(1); 457 } 458 459 clear(i, j, result.get()); 460 setup_row_col(i, j, k, ¶ms_col); 461 MultiThreadGemm<Context, Executor, ParamsColumnMajor, 2, 4, 8>( 462 &context, params_col); 463 if (!check_row_col(lhs.get(), rhs.get(), result.get(), i, j, k)) { 464 std::cout << "Column(MT): " << i << "x" << j << "x" << k << " : ERROR" 465 << std::endl; 466 std::cout << "Exiting." << std::endl; 467 std::exit(1); 468 } 469 470 clear(i, j, result_f.get()); 471 setup_row_row_f(i, j, k, ¶ms_row_f); 472 MultiThreadGemm<Context, Executor, ParamsRowMajorAsFloat, 2, 4, 8>( 473 &context, params_row_f); 474 if (!check_row_row_f(lhs.get(), rhs.get(), result_f.get(), i, j, k)) { 475 std::cout << "RowAsFloat(MT): " << i << "x" << j << "x" << k 476 << " : ERROR" << std::endl; 477 std::cout << "Exiting." << std::endl; 478 std::exit(1); 479 } 480 481 clear(i, j, result_f.get()); 482 setup_row_col_f(i, j, k, ¶ms_col_f); 483 MultiThreadGemm<Context, Executor, ParamsColumnMajorAsFloat, 2, 4, 8>( 484 &context, params_col_f); 485 if (!check_row_col_f(lhs.get(), rhs.get(), result_f.get(), i, j, k)) { 486 std::cout << "ColumnAsFloat(MT): " << i << "x" << j << "x" << k 487 << " : ERROR" << std::endl; 488 std::cout << "Exiting." << std::endl; 489 std::exit(1); 490 } 491 492 clear(i, j, result_i32.get()); 493 setup_row_row_i32(i, j, k, ¶ms_row_i32); 494 MultiThreadGemm<Context, Executor, ParamsRowMajorAsInt32, 2, 4, 8>( 495 &context, params_row_i32); 496 if (!check_row_row_i32(lhs.get(), rhs.get(), result_i32.get(), i, j, 497 k)) { 498 std::cout << "RowAsInt32(MT): " << i << "x" << j << "x" << k 499 << " : ERROR" << std::endl; 500 std::cout << "Exiting." << std::endl; 501 std::exit(1); 502 } 503 504 clear(i, j, result_i32.get()); 505 setup_row_col_i32(i, j, k, ¶ms_col_i32); 506 MultiThreadGemm<Context, Executor, ParamsColumnMajorAsInt32, 2, 4, 8>( 507 &context, params_col_i32); 508 if (!check_row_col_i32(lhs.get(), rhs.get(), result_i32.get(), i, j, 509 k)) { 510 std::cout << "ColumnAsInt32(MT): " << i << "x" << j << "x" << k 511 << " : ERROR" << std::endl; 512 std::cout << "Exiting." << std::endl; 513 std::exit(1); 514 } 515 } 516 } 517 } 518 519 std::cout << "OK." << std::endl; 520 return 0; 521 } 522