1 /* 2 * Copyright 2016 Google Inc. 3 * 4 * Use of this source code is governed by a BSD-style license that can 5 * be found in the LICENSE file. 6 * 7 */ 8 9 #include <stdio.h> 10 #include <stdlib.h> 11 12 // 13 // 14 // 15 16 #include "gen.h" 17 #include "transpose.h" 18 19 #include "common/util.h" 20 #include "common/macros.h" 21 22 // 23 // 24 // 25 26 struct hsg_transpose_state 27 { 28 FILE * header; 29 struct hsg_config const * config; 30 }; 31 32 static 33 char 34 hsg_transpose_reg_prefix(uint32_t const cols_log2) 35 { 36 return 'a' + (('r' + cols_log2 - 'a') % 26); 37 } 38 39 static 40 void 41 hsg_transpose_blend(uint32_t const cols_log2, 42 uint32_t const row_ll, // lower-left 43 uint32_t const row_ur, // upper-right 44 void * blend) 45 { 46 struct hsg_transpose_state * const state = blend; 47 48 // we're starting register names at '1' for now 49 fprintf(state->header, 50 " HS_TRANSPOSE_BLEND( %c, %c, %2u, %3u, %3u ) \\\n", 51 hsg_transpose_reg_prefix(cols_log2-1), 52 hsg_transpose_reg_prefix(cols_log2), 53 cols_log2,row_ll+1,row_ur+1); 54 } 55 56 static 57 void 58 hsg_transpose_remap(uint32_t const row_from, 59 uint32_t const row_to, 60 void * remap) 61 { 62 struct hsg_transpose_state * const state = remap; 63 64 // we're starting register names at '1' for now 65 fprintf(state->header, 66 " HS_TRANSPOSE_REMAP( %c, %3u, %3u ) \\\n", 67 hsg_transpose_reg_prefix(state->config->warp.lanes_log2), 68 row_from+1,row_to+1); 69 } 70 71 // 72 // 73 // 74 75 static 76 void 77 hsg_copyright(FILE * file) 78 { 79 fprintf(file, 80 "// \n" 81 "// Copyright 2016 Google Inc. \n" 82 "// \n" 83 "// Use of this source code is governed by a BSD-style \n" 84 "// license that can be found in the LICENSE file. \n" 85 "// \n" 86 "\n"); 87 } 88 89 static 90 void 91 hsg_macros(FILE * file) 92 { 93 fprintf(file, 94 "// target-specific config \n" 95 "#include \"hs_config.h\" \n" 96 " \n" 97 "// GLSL preamble \n" 98 "#include \"hs_glsl_preamble.h\"\n" 99 " \n" 100 "// arch/target-specific macros \n" 101 "#include \"hs_glsl_macros.h\" \n" 102 " \n" 103 "// \n" 104 "// \n" 105 "// \n" 106 "\n"); 107 } 108 109 // 110 // 111 // 112 113 struct hsg_target_state 114 { 115 FILE * header; 116 FILE * modules; 117 FILE * source; 118 }; 119 120 // 121 // 122 // 123 124 void 125 hsg_target_glsl(struct hsg_target * const target, 126 struct hsg_config const * const config, 127 struct hsg_merge const * const merge, 128 struct hsg_op const * const ops, 129 uint32_t const depth) 130 { 131 switch (ops->type) 132 { 133 case HSG_OP_TYPE_END: 134 fprintf(target->state->source, 135 "}\n"); 136 137 if (depth == 0) { 138 fclose(target->state->source); 139 target->state->source = NULL; 140 } 141 break; 142 143 case HSG_OP_TYPE_BEGIN: 144 fprintf(target->state->source, 145 "{\n"); 146 break; 147 148 case HSG_OP_TYPE_ELSE: 149 fprintf(target->state->source, 150 "else\n"); 151 break; 152 153 case HSG_OP_TYPE_TARGET_BEGIN: 154 { 155 // allocate state 156 target->state = malloc(sizeof(*target->state)); 157 158 // allocate files 159 target->state->header = fopen("hs_config.h", "wb"); 160 target->state->modules = fopen("hs_modules.h","wb"); 161 162 hsg_copyright(target->state->header); 163 hsg_copyright(target->state->modules); 164 165 // initialize header 166 uint32_t const bc_max = msb_idx_u32(pow2_rd_u32(merge->warps)); 167 168 fprintf(target->state->header, 169 "#ifndef HS_GLSL_ONCE \n" 170 "#define HS_GLSL_ONCE \n" 171 " \n" 172 "#define HS_SLAB_THREADS_LOG2 %u \n" 173 "#define HS_SLAB_THREADS (1 << HS_SLAB_THREADS_LOG2) \n" 174 "#define HS_SLAB_WIDTH_LOG2 %u \n" 175 "#define HS_SLAB_WIDTH (1 << HS_SLAB_WIDTH_LOG2) \n" 176 "#define HS_SLAB_HEIGHT %u \n" 177 "#define HS_SLAB_KEYS (HS_SLAB_WIDTH * HS_SLAB_HEIGHT)\n" 178 "#define HS_REG_LAST(c) c##%u \n" 179 "#define HS_KEY_WORDS %u \n" 180 "#define HS_VAL_WORDS 0 \n" 181 "#define HS_BS_SLABS %u \n" 182 "#define HS_BS_SLABS_LOG2_RU %u \n" 183 "#define HS_BC_SLABS_LOG2_MAX %u \n" 184 "#define HS_FM_BLOCK_HEIGHT %u \n" 185 "#define HS_FM_SCALE_MIN %u \n" 186 "#define HS_FM_SCALE_MAX %u \n" 187 "#define HS_HM_BLOCK_HEIGHT %u \n" 188 "#define HS_HM_SCALE_MIN %u \n" 189 "#define HS_HM_SCALE_MAX %u \n" 190 "#define HS_EMPTY \n" 191 " \n", 192 config->warp.lanes_log2, // FIXME -- this matters for SIMD 193 config->warp.lanes_log2, 194 config->thread.regs, 195 config->thread.regs, 196 config->type.words, 197 merge->warps, 198 msb_idx_u32(pow2_ru_u32(merge->warps)), 199 bc_max, 200 config->merge.flip.warps, 201 config->merge.flip.lo, 202 config->merge.flip.hi, 203 config->merge.half.warps, 204 config->merge.half.lo, 205 config->merge.half.hi); 206 207 if (target->define != NULL) 208 fprintf(target->state->header,"#define %s\n\n",target->define); 209 210 fprintf(target->state->header, 211 "#define HS_SLAB_ROWS() \\\n"); 212 213 for (uint32_t ii=1; ii<=config->thread.regs; ii++) 214 fprintf(target->state->header, 215 " HS_SLAB_ROW( %3u, %3u ) \\\n",ii,ii-1); 216 217 fprintf(target->state->header, 218 " HS_EMPTY\n" 219 " \n"); 220 221 fprintf(target->state->header, 222 "#define HS_TRANSPOSE_SLAB() \\\n"); 223 224 for (uint32_t ii=1; ii<=config->warp.lanes_log2; ii++) 225 fprintf(target->state->header, 226 " HS_TRANSPOSE_STAGE( %u ) \\\n",ii); 227 228 struct hsg_transpose_state state[1] = 229 { 230 { .header = target->state->header, 231 .config = config 232 } 233 }; 234 235 hsg_transpose(config->warp.lanes_log2, 236 config->thread.regs, 237 hsg_transpose_blend,state, 238 hsg_transpose_remap,state); 239 240 fprintf(target->state->header, 241 " HS_EMPTY\n" 242 " \n"); 243 } 244 break; 245 246 case HSG_OP_TYPE_TARGET_END: 247 // decorate the files 248 fprintf(target->state->header, 249 "#endif \n" 250 " \n" 251 "// \n" 252 "// \n" 253 "// \n" 254 " \n"); 255 256 // close files 257 fclose(target->state->header); 258 fclose(target->state->modules); 259 260 // free state 261 free(target->state); 262 break; 263 264 case HSG_OP_TYPE_TRANSPOSE_KERNEL_PROTO: 265 { 266 fprintf(target->state->modules, 267 "#include \"hs_transpose.len.xxd\"\n,\n" 268 "#include \"hs_transpose.spv.xxd\"\n,\n"); 269 270 target->state->source = fopen("hs_transpose.comp","w+"); 271 272 hsg_copyright(target->state->source); 273 274 hsg_macros(target->state->source); 275 276 fprintf(target->state->source, 277 "HS_TRANSPOSE_KERNEL_PROTO()\n"); 278 } 279 break; 280 281 case HSG_OP_TYPE_TRANSPOSE_KERNEL_PREAMBLE: 282 { 283 fprintf(target->state->source, 284 "HS_SUBGROUP_PREAMBLE();\n"); 285 286 fprintf(target->state->source, 287 "HS_SLAB_GLOBAL_PREAMBLE();\n"); 288 } 289 break; 290 291 case HSG_OP_TYPE_TRANSPOSE_KERNEL_BODY: 292 { 293 fprintf(target->state->source, 294 "HS_TRANSPOSE_SLAB()\n"); 295 } 296 break; 297 298 case HSG_OP_TYPE_BS_KERNEL_PROTO: 299 { 300 struct hsg_merge const * const m = merge + ops->a; 301 302 uint32_t const bs = pow2_ru_u32(m->warps); 303 uint32_t const msb = msb_idx_u32(bs); 304 305 fprintf(target->state->modules, 306 "#include \"hs_bs_%u.len.xxd\"\n,\n" 307 "#include \"hs_bs_%u.spv.xxd\"\n,\n", 308 msb, 309 msb); 310 311 char filename[] = { "hs_bs_XX.comp" }; 312 sprintf(filename,"hs_bs_%u.comp",msb); 313 314 target->state->source = fopen(filename,"w+"); 315 316 hsg_copyright(target->state->source); 317 318 hsg_macros(target->state->source); 319 320 if (m->warps > 1) 321 { 322 fprintf(target->state->source, 323 "HS_BLOCK_LOCAL_MEM_DECL(%u,%u);\n\n", 324 m->warps * config->warp.lanes, 325 m->rows_bs); 326 } 327 328 fprintf(target->state->source, 329 "HS_BS_KERNEL_PROTO(%u,%u)\n", 330 m->warps,msb); 331 } 332 break; 333 334 case HSG_OP_TYPE_BS_KERNEL_PREAMBLE: 335 { 336 fprintf(target->state->source, 337 "HS_SUBGROUP_PREAMBLE();\n"); 338 339 fprintf(target->state->source, 340 "HS_SLAB_GLOBAL_PREAMBLE();\n"); 341 } 342 break; 343 344 case HSG_OP_TYPE_BC_KERNEL_PROTO: 345 { 346 struct hsg_merge const * const m = merge + ops->a; 347 348 uint32_t const msb = msb_idx_u32(m->warps); 349 350 fprintf(target->state->modules, 351 "#include \"hs_bc_%u.len.xxd\"\n,\n" 352 "#include \"hs_bc_%u.spv.xxd\"\n,\n", 353 msb, 354 msb); 355 356 char filename[] = { "hs_bc_XX.comp" }; 357 sprintf(filename,"hs_bc_%u.comp",msb); 358 359 target->state->source = fopen(filename,"w+"); 360 361 hsg_copyright(target->state->source); 362 363 hsg_macros(target->state->source); 364 365 if (m->warps > 1) 366 { 367 fprintf(target->state->source, 368 "HS_BLOCK_LOCAL_MEM_DECL(%u,%u);\n\n", 369 m->warps * config->warp.lanes, 370 m->rows_bc); 371 } 372 373 fprintf(target->state->source, 374 "HS_BC_KERNEL_PROTO(%u,%u)\n", 375 m->warps,msb); 376 } 377 break; 378 379 case HSG_OP_TYPE_BC_KERNEL_PREAMBLE: 380 { 381 fprintf(target->state->source, 382 "HS_SUBGROUP_PREAMBLE()\n"); 383 384 fprintf(target->state->source, 385 "HS_SLAB_GLOBAL_PREAMBLE();\n"); 386 } 387 break; 388 389 case HSG_OP_TYPE_FM_KERNEL_PROTO: 390 { 391 fprintf(target->state->modules, 392 "#include \"hs_fm_%u_%u.len.xxd\"\n,\n" 393 "#include \"hs_fm_%u_%u.spv.xxd\"\n,\n", 394 ops->a,ops->b, 395 ops->a,ops->b); 396 397 char filename[] = { "hs_fm_X_XX.comp" }; 398 sprintf(filename,"hs_fm_%u_%u.comp",ops->a,ops->b); 399 400 target->state->source = fopen(filename,"w+"); 401 402 hsg_copyright(target->state->source); 403 404 hsg_macros(target->state->source); 405 406 fprintf(target->state->source, 407 "HS_FM_KERNEL_PROTO(%u,%u)\n", 408 ops->a,ops->b); 409 } 410 break; 411 412 case HSG_OP_TYPE_FM_KERNEL_PREAMBLE: 413 { 414 fprintf(target->state->source, 415 "HS_SUBGROUP_PREAMBLE()\n"); 416 417 fprintf(target->state->source, 418 "HS_FM_PREAMBLE(%u);\n", 419 ops->a); 420 } 421 break; 422 423 case HSG_OP_TYPE_HM_KERNEL_PROTO: 424 { 425 fprintf(target->state->modules, 426 "#include \"hs_hm_%u.len.xxd\"\n,\n" 427 "#include \"hs_hm_%u.spv.xxd\"\n,\n", 428 ops->a, 429 ops->a); 430 431 char filename[] = { "hs_hm_X.comp" }; 432 sprintf(filename,"hs_hm_%u.comp",ops->a); 433 434 target->state->source = fopen(filename,"w+"); 435 436 hsg_copyright(target->state->source); 437 438 hsg_macros(target->state->source); 439 440 fprintf(target->state->source, 441 "HS_HM_KERNEL_PROTO(%u)\n", 442 ops->a); 443 } 444 break; 445 446 case HSG_OP_TYPE_HM_KERNEL_PREAMBLE: 447 { 448 fprintf(target->state->source, 449 "HS_SUBGROUP_PREAMBLE()\n"); 450 451 fprintf(target->state->source, 452 "HS_HM_PREAMBLE(%u);\n", 453 ops->a); 454 } 455 break; 456 457 case HSG_OP_TYPE_BX_REG_GLOBAL_LOAD: 458 { 459 static char const * const vstr[] = { "vin", "vout" }; 460 461 fprintf(target->state->source, 462 "HS_KEY_TYPE r%-3u = HS_SLAB_GLOBAL_LOAD(%s,%u);\n", 463 ops->n,vstr[ops->v],ops->n-1); 464 } 465 break; 466 467 case HSG_OP_TYPE_BX_REG_GLOBAL_STORE: 468 fprintf(target->state->source, 469 "HS_SLAB_GLOBAL_STORE(%u,r%u);\n", 470 ops->n-1,ops->n); 471 break; 472 473 case HSG_OP_TYPE_HM_REG_GLOBAL_LOAD: 474 fprintf(target->state->source, 475 "HS_KEY_TYPE r%-3u = HS_XM_GLOBAL_LOAD_L(%u);\n", 476 ops->a,ops->b); 477 break; 478 479 case HSG_OP_TYPE_HM_REG_GLOBAL_STORE: 480 fprintf(target->state->source, 481 "HS_XM_GLOBAL_STORE_L(%-3u,r%u);\n", 482 ops->b,ops->a); 483 break; 484 485 case HSG_OP_TYPE_FM_REG_GLOBAL_LOAD_LEFT: 486 fprintf(target->state->source, 487 "HS_KEY_TYPE r%-3u = HS_XM_GLOBAL_LOAD_L(%u);\n", 488 ops->a,ops->b); 489 break; 490 491 case HSG_OP_TYPE_FM_REG_GLOBAL_STORE_LEFT: 492 fprintf(target->state->source, 493 "HS_XM_GLOBAL_STORE_L(%-3u,r%u);\n", 494 ops->b,ops->a); 495 break; 496 497 case HSG_OP_TYPE_FM_REG_GLOBAL_LOAD_RIGHT: 498 fprintf(target->state->source, 499 "HS_KEY_TYPE r%-3u = HS_FM_GLOBAL_LOAD_R(%u);\n", 500 ops->b,ops->a); 501 break; 502 503 case HSG_OP_TYPE_FM_REG_GLOBAL_STORE_RIGHT: 504 fprintf(target->state->source, 505 "HS_FM_GLOBAL_STORE_R(%-3u,r%u);\n", 506 ops->a,ops->b); 507 break; 508 509 case HSG_OP_TYPE_FM_MERGE_RIGHT_PRED: 510 { 511 if (ops->a <= ops->b) 512 { 513 fprintf(target->state->source, 514 "if (HS_FM_IS_NOT_LAST_SPAN() || (fm_frac == 0))\n"); 515 } 516 else if (ops->b > 1) 517 { 518 fprintf(target->state->source, 519 "else if (fm_frac == %u)\n", 520 ops->b); 521 } 522 else 523 { 524 fprintf(target->state->source, 525 "else\n"); 526 } 527 } 528 break; 529 530 case HSG_OP_TYPE_SLAB_FLIP: 531 fprintf(target->state->source, 532 "HS_SLAB_FLIP_PREAMBLE(%u);\n", 533 ops->n-1); 534 break; 535 536 case HSG_OP_TYPE_SLAB_HALF: 537 fprintf(target->state->source, 538 "HS_SLAB_HALF_PREAMBLE(%u);\n", 539 ops->n / 2); 540 break; 541 542 case HSG_OP_TYPE_CMP_FLIP: 543 fprintf(target->state->source, 544 "HS_CMP_FLIP(%-3u,r%-3u,r%-3u);\n",ops->a,ops->b,ops->c); 545 break; 546 547 case HSG_OP_TYPE_CMP_HALF: 548 fprintf(target->state->source, 549 "HS_CMP_HALF(%-3u,r%-3u);\n",ops->a,ops->b); 550 break; 551 552 case HSG_OP_TYPE_CMP_XCHG: 553 if (ops->c == UINT32_MAX) 554 { 555 fprintf(target->state->source, 556 "HS_CMP_XCHG(r%-3u,r%-3u);\n", 557 ops->a,ops->b); 558 } 559 else 560 { 561 fprintf(target->state->source, 562 "HS_CMP_XCHG(r%u_%u,r%u_%u);\n", 563 ops->c,ops->a,ops->c,ops->b); 564 } 565 break; 566 567 case HSG_OP_TYPE_BS_REG_SHARED_STORE_V: 568 fprintf(target->state->source, 569 "HS_BX_LOCAL_V(%-3u * HS_SLAB_THREADS * %-3u) = r%u;\n", 570 merge[ops->a].warps,ops->c,ops->b); 571 break; 572 573 case HSG_OP_TYPE_BS_REG_SHARED_LOAD_V: 574 fprintf(target->state->source, 575 "r%-3u = HS_BX_LOCAL_V(%-3u * HS_SLAB_THREADS * %-3u);\n", 576 ops->b,merge[ops->a].warps,ops->c); 577 break; 578 579 case HSG_OP_TYPE_BC_REG_SHARED_LOAD_V: 580 fprintf(target->state->source, 581 "HS_KEY_TYPE r%-3u = HS_BX_LOCAL_V(%-3u * HS_SLAB_THREADS * %-3u);\n", 582 ops->b,ops->a,ops->c); 583 break; 584 585 case HSG_OP_TYPE_BX_REG_SHARED_STORE_LEFT: 586 fprintf(target->state->source, 587 "HS_SLAB_LOCAL_L(%5u) = r%u_%u;\n", 588 ops->b * config->warp.lanes, 589 ops->c, 590 ops->a); 591 break; 592 593 case HSG_OP_TYPE_BS_REG_SHARED_STORE_RIGHT: 594 fprintf(target->state->source, 595 "HS_SLAB_LOCAL_R(%5u) = r%u_%u;\n", 596 ops->b * config->warp.lanes, 597 ops->c, 598 ops->a); 599 break; 600 601 case HSG_OP_TYPE_BS_REG_SHARED_LOAD_LEFT: 602 fprintf(target->state->source, 603 "HS_KEY_TYPE r%u_%-3u = HS_SLAB_LOCAL_L(%u);\n", 604 ops->c, 605 ops->a, 606 ops->b * config->warp.lanes); 607 break; 608 609 case HSG_OP_TYPE_BS_REG_SHARED_LOAD_RIGHT: 610 fprintf(target->state->source, 611 "HS_KEY_TYPE r%u_%-3u = HS_SLAB_LOCAL_R(%u);\n", 612 ops->c, 613 ops->a, 614 ops->b * config->warp.lanes); 615 break; 616 617 case HSG_OP_TYPE_BC_REG_GLOBAL_LOAD_LEFT: 618 fprintf(target->state->source, 619 "HS_KEY_TYPE r%u_%-3u = HS_BC_GLOBAL_LOAD_L(%u);\n", 620 ops->c, 621 ops->a, 622 ops->b); 623 break; 624 625 case HSG_OP_TYPE_BLOCK_SYNC: 626 fprintf(target->state->source, 627 "HS_BLOCK_BARRIER();\n"); 628 // 629 // FIXME - Named barriers to allow coordinating warps to proceed? 630 // 631 break; 632 633 case HSG_OP_TYPE_BS_FRAC_PRED: 634 { 635 if (ops->m == 0) 636 { 637 fprintf(target->state->source, 638 "if (warp_idx < bs_full)\n"); 639 } 640 else 641 { 642 fprintf(target->state->source, 643 "else if (bs_frac == %u)\n", 644 ops->w); 645 } 646 } 647 break; 648 649 case HSG_OP_TYPE_BS_MERGE_H_PREAMBLE: 650 { 651 struct hsg_merge const * const m = merge + ops->a; 652 653 fprintf(target->state->source, 654 "HS_BS_MERGE_H_PREAMBLE(%u);\n", 655 m->warps); 656 } 657 break; 658 659 case HSG_OP_TYPE_BC_MERGE_H_PREAMBLE: 660 { 661 struct hsg_merge const * const m = merge + ops->a; 662 663 fprintf(target->state->source, 664 "HS_BC_MERGE_H_PREAMBLE(%u);\n", 665 m->warps); 666 } 667 break; 668 669 case HSG_OP_TYPE_BX_MERGE_H_PRED: 670 fprintf(target->state->source, 671 "if (HS_SUBGROUP_ID() < %u)\n", 672 ops->a); 673 break; 674 675 case HSG_OP_TYPE_BS_ACTIVE_PRED: 676 { 677 struct hsg_merge const * const m = merge + ops->a; 678 679 if (m->warps <= 32) 680 { 681 fprintf(target->state->source, 682 "if (((1u << HS_SUBGROUP_ID()) & 0x%08X) != 0)\n", 683 m->levels[ops->b].active.b32a2[0]); 684 } 685 else 686 { 687 fprintf(target->state->source, 688 "if (((1UL << HS_SUBGROUP_ID()) & 0x%08X%08XL) != 0L)\n", 689 m->levels[ops->b].active.b32a2[1], 690 m->levels[ops->b].active.b32a2[0]); 691 } 692 } 693 break; 694 695 default: 696 fprintf(stderr,"type not found: %s\n",hsg_op_type_string[ops->type]); 697 exit(EXIT_FAILURE); 698 break; 699 } 700 } 701 702 // 703 // 704 // 705