1 /* 2 * Copyright 2014 Intel Corporation 3 * 4 * Permission is hereby granted, free of charge, to any person obtaining a 5 * copy of this software and associated documentation files (the "Software"), 6 * to deal in the Software without restriction, including without limitation 7 * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 * and/or sell copies of the Software, and to permit persons to whom the 9 * Software is furnished to do so, subject to the following conditions: 10 * 11 * The above copyright notice and this permission notice (including the next 12 * paragraph) shall be included in all copies or substantial portions of the 13 * Software. 14 * 15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 21 * IN THE SOFTWARE. 22 * 23 * Authors: 24 * Jason Ekstrand (jason (at) jlekstrand.net) 25 * 26 */ 27 28 #include <inttypes.h> 29 #include "nir_search.h" 30 31 struct match_state { 32 bool inexact_match; 33 bool has_exact_alu; 34 unsigned variables_seen; 35 nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES]; 36 }; 37 38 static bool 39 match_expression(const nir_search_expression *expr, nir_alu_instr *instr, 40 unsigned num_components, const uint8_t *swizzle, 41 struct match_state *state); 42 43 static const uint8_t identity_swizzle[] = { 0, 1, 2, 3 }; 44 45 /** 46 * Check if a source produces a value of the given type. 47 * 48 * Used for satisfying 'a@type' constraints. 49 */ 50 static bool 51 src_is_type(nir_src src, nir_alu_type type) 52 { 53 assert(type != nir_type_invalid); 54 55 if (!src.is_ssa) 56 return false; 57 58 /* Turn nir_type_bool32 into nir_type_bool...they're the same thing. */ 59 if (nir_alu_type_get_base_type(type) == nir_type_bool) 60 type = nir_type_bool; 61 62 if (src.ssa->parent_instr->type == nir_instr_type_alu) { 63 nir_alu_instr *src_alu = nir_instr_as_alu(src.ssa->parent_instr); 64 nir_alu_type output_type = nir_op_infos[src_alu->op].output_type; 65 66 if (type == nir_type_bool) { 67 switch (src_alu->op) { 68 case nir_op_iand: 69 case nir_op_ior: 70 case nir_op_ixor: 71 return src_is_type(src_alu->src[0].src, nir_type_bool) && 72 src_is_type(src_alu->src[1].src, nir_type_bool); 73 case nir_op_inot: 74 return src_is_type(src_alu->src[0].src, nir_type_bool); 75 default: 76 break; 77 } 78 } 79 80 return nir_alu_type_get_base_type(output_type) == type; 81 } else if (src.ssa->parent_instr->type == nir_instr_type_intrinsic) { 82 nir_intrinsic_instr *intr = nir_instr_as_intrinsic(src.ssa->parent_instr); 83 84 if (type == nir_type_bool) { 85 return intr->intrinsic == nir_intrinsic_load_front_face || 86 intr->intrinsic == nir_intrinsic_load_helper_invocation; 87 } 88 } 89 90 /* don't know */ 91 return false; 92 } 93 94 static bool 95 match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src, 96 unsigned num_components, const uint8_t *swizzle, 97 struct match_state *state) 98 { 99 uint8_t new_swizzle[4]; 100 101 /* Searching only works on SSA values because, if it's not SSA, we can't 102 * know if the value changed between one instance of that value in the 103 * expression and another. Also, the replace operation will place reads of 104 * that value right before the last instruction in the expression we're 105 * replacing so those reads will happen after the original reads and may 106 * not be valid if they're register reads. 107 */ 108 if (!instr->src[src].src.is_ssa) 109 return false; 110 111 /* If the source is an explicitly sized source, then we need to reset 112 * both the number of components and the swizzle. 113 */ 114 if (nir_op_infos[instr->op].input_sizes[src] != 0) { 115 num_components = nir_op_infos[instr->op].input_sizes[src]; 116 swizzle = identity_swizzle; 117 } 118 119 for (unsigned i = 0; i < num_components; ++i) 120 new_swizzle[i] = instr->src[src].swizzle[swizzle[i]]; 121 122 /* If the value has a specific bit size and it doesn't match, bail */ 123 if (value->bit_size && 124 nir_src_bit_size(instr->src[src].src) != value->bit_size) 125 return false; 126 127 switch (value->type) { 128 case nir_search_value_expression: 129 if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu) 130 return false; 131 132 return match_expression(nir_search_value_as_expression(value), 133 nir_instr_as_alu(instr->src[src].src.ssa->parent_instr), 134 num_components, new_swizzle, state); 135 136 case nir_search_value_variable: { 137 nir_search_variable *var = nir_search_value_as_variable(value); 138 assert(var->variable < NIR_SEARCH_MAX_VARIABLES); 139 140 if (state->variables_seen & (1 << var->variable)) { 141 if (state->variables[var->variable].src.ssa != instr->src[src].src.ssa) 142 return false; 143 144 assert(!instr->src[src].abs && !instr->src[src].negate); 145 146 for (unsigned i = 0; i < num_components; ++i) { 147 if (state->variables[var->variable].swizzle[i] != new_swizzle[i]) 148 return false; 149 } 150 151 return true; 152 } else { 153 if (var->is_constant && 154 instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const) 155 return false; 156 157 if (var->cond && !var->cond(instr, src, num_components, new_swizzle)) 158 return false; 159 160 if (var->type != nir_type_invalid && 161 !src_is_type(instr->src[src].src, var->type)) 162 return false; 163 164 state->variables_seen |= (1 << var->variable); 165 state->variables[var->variable].src = instr->src[src].src; 166 state->variables[var->variable].abs = false; 167 state->variables[var->variable].negate = false; 168 169 for (unsigned i = 0; i < 4; ++i) { 170 if (i < num_components) 171 state->variables[var->variable].swizzle[i] = new_swizzle[i]; 172 else 173 state->variables[var->variable].swizzle[i] = 0; 174 } 175 176 return true; 177 } 178 } 179 180 case nir_search_value_constant: { 181 nir_search_constant *const_val = nir_search_value_as_constant(value); 182 183 if (!instr->src[src].src.is_ssa) 184 return false; 185 186 if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const) 187 return false; 188 189 nir_load_const_instr *load = 190 nir_instr_as_load_const(instr->src[src].src.ssa->parent_instr); 191 192 switch (const_val->type) { 193 case nir_type_float: 194 for (unsigned i = 0; i < num_components; ++i) { 195 double val; 196 switch (load->def.bit_size) { 197 case 32: 198 val = load->value.f32[new_swizzle[i]]; 199 break; 200 case 64: 201 val = load->value.f64[new_swizzle[i]]; 202 break; 203 default: 204 unreachable("unknown bit size"); 205 } 206 207 if (val != const_val->data.d) 208 return false; 209 } 210 return true; 211 212 case nir_type_int: 213 case nir_type_uint: 214 case nir_type_bool32: 215 switch (load->def.bit_size) { 216 case 32: 217 for (unsigned i = 0; i < num_components; ++i) { 218 if (load->value.u32[new_swizzle[i]] != 219 (uint32_t)const_val->data.u) 220 return false; 221 } 222 return true; 223 224 case 64: 225 for (unsigned i = 0; i < num_components; ++i) { 226 if (load->value.u64[new_swizzle[i]] != const_val->data.u) 227 return false; 228 } 229 return true; 230 231 default: 232 unreachable("unknown bit size"); 233 } 234 235 default: 236 unreachable("Invalid alu source type"); 237 } 238 } 239 240 default: 241 unreachable("Invalid search value type"); 242 } 243 } 244 245 static bool 246 match_expression(const nir_search_expression *expr, nir_alu_instr *instr, 247 unsigned num_components, const uint8_t *swizzle, 248 struct match_state *state) 249 { 250 if (expr->cond && !expr->cond(instr)) 251 return false; 252 253 if (instr->op != expr->opcode) 254 return false; 255 256 assert(instr->dest.dest.is_ssa); 257 258 if (expr->value.bit_size && 259 instr->dest.dest.ssa.bit_size != expr->value.bit_size) 260 return false; 261 262 state->inexact_match = expr->inexact || state->inexact_match; 263 state->has_exact_alu = instr->exact || state->has_exact_alu; 264 if (state->inexact_match && state->has_exact_alu) 265 return false; 266 267 assert(!instr->dest.saturate); 268 assert(nir_op_infos[instr->op].num_inputs > 0); 269 270 /* If we have an explicitly sized destination, we can only handle the 271 * identity swizzle. While dot(vec3(a, b, c).zxy) is a valid 272 * expression, we don't have the information right now to propagate that 273 * swizzle through. We can only properly propagate swizzles if the 274 * instruction is vectorized. 275 */ 276 if (nir_op_infos[instr->op].output_size != 0) { 277 for (unsigned i = 0; i < num_components; i++) { 278 if (swizzle[i] != i) 279 return false; 280 } 281 } 282 283 /* Stash off the current variables_seen bitmask. This way we can 284 * restore it prior to matching in the commutative case below. 285 */ 286 unsigned variables_seen_stash = state->variables_seen; 287 288 bool matched = true; 289 for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) { 290 if (!match_value(expr->srcs[i], instr, i, num_components, 291 swizzle, state)) { 292 matched = false; 293 break; 294 } 295 } 296 297 if (matched) 298 return true; 299 300 if (nir_op_infos[instr->op].algebraic_properties & NIR_OP_IS_COMMUTATIVE) { 301 assert(nir_op_infos[instr->op].num_inputs == 2); 302 303 /* Restore the variables_seen bitmask. If we don't do this, then we 304 * could end up with an erroneous failure due to variables found in the 305 * first match attempt above not matching those in the second. 306 */ 307 state->variables_seen = variables_seen_stash; 308 309 if (!match_value(expr->srcs[0], instr, 1, num_components, 310 swizzle, state)) 311 return false; 312 313 return match_value(expr->srcs[1], instr, 0, num_components, 314 swizzle, state); 315 } else { 316 return false; 317 } 318 } 319 320 typedef struct bitsize_tree { 321 unsigned num_srcs; 322 struct bitsize_tree *srcs[4]; 323 324 unsigned common_size; 325 bool is_src_sized[4]; 326 bool is_dest_sized; 327 328 unsigned dest_size; 329 unsigned src_size[4]; 330 } bitsize_tree; 331 332 static bitsize_tree * 333 build_bitsize_tree(void *mem_ctx, struct match_state *state, 334 const nir_search_value *value) 335 { 336 bitsize_tree *tree = rzalloc(mem_ctx, bitsize_tree); 337 338 switch (value->type) { 339 case nir_search_value_expression: { 340 nir_search_expression *expr = nir_search_value_as_expression(value); 341 nir_op_info info = nir_op_infos[expr->opcode]; 342 tree->num_srcs = info.num_inputs; 343 tree->common_size = 0; 344 for (unsigned i = 0; i < info.num_inputs; i++) { 345 tree->is_src_sized[i] = !!nir_alu_type_get_type_size(info.input_types[i]); 346 if (tree->is_src_sized[i]) 347 tree->src_size[i] = nir_alu_type_get_type_size(info.input_types[i]); 348 tree->srcs[i] = build_bitsize_tree(mem_ctx, state, expr->srcs[i]); 349 } 350 tree->is_dest_sized = !!nir_alu_type_get_type_size(info.output_type); 351 if (tree->is_dest_sized) 352 tree->dest_size = nir_alu_type_get_type_size(info.output_type); 353 break; 354 } 355 356 case nir_search_value_variable: { 357 nir_search_variable *var = nir_search_value_as_variable(value); 358 tree->num_srcs = 0; 359 tree->is_dest_sized = true; 360 tree->dest_size = nir_src_bit_size(state->variables[var->variable].src); 361 break; 362 } 363 364 case nir_search_value_constant: { 365 tree->num_srcs = 0; 366 tree->is_dest_sized = false; 367 tree->common_size = 0; 368 break; 369 } 370 } 371 372 if (value->bit_size) { 373 assert(!tree->is_dest_sized || tree->dest_size == value->bit_size); 374 tree->common_size = value->bit_size; 375 } 376 377 return tree; 378 } 379 380 static unsigned 381 bitsize_tree_filter_up(bitsize_tree *tree) 382 { 383 for (unsigned i = 0; i < tree->num_srcs; i++) { 384 unsigned src_size = bitsize_tree_filter_up(tree->srcs[i]); 385 if (src_size == 0) 386 continue; 387 388 if (tree->is_src_sized[i]) { 389 assert(src_size == tree->src_size[i]); 390 } else if (tree->common_size != 0) { 391 assert(src_size == tree->common_size); 392 tree->src_size[i] = src_size; 393 } else { 394 tree->common_size = src_size; 395 tree->src_size[i] = src_size; 396 } 397 } 398 399 if (tree->num_srcs && tree->common_size) { 400 if (tree->dest_size == 0) 401 tree->dest_size = tree->common_size; 402 else if (!tree->is_dest_sized) 403 assert(tree->dest_size == tree->common_size); 404 405 for (unsigned i = 0; i < tree->num_srcs; i++) { 406 if (!tree->src_size[i]) 407 tree->src_size[i] = tree->common_size; 408 } 409 } 410 411 return tree->dest_size; 412 } 413 414 static void 415 bitsize_tree_filter_down(bitsize_tree *tree, unsigned size) 416 { 417 if (tree->dest_size) 418 assert(tree->dest_size == size); 419 else 420 tree->dest_size = size; 421 422 if (!tree->is_dest_sized) { 423 if (tree->common_size) 424 assert(tree->common_size == size); 425 else 426 tree->common_size = size; 427 } 428 429 for (unsigned i = 0; i < tree->num_srcs; i++) { 430 if (!tree->src_size[i]) { 431 assert(tree->common_size); 432 tree->src_size[i] = tree->common_size; 433 } 434 bitsize_tree_filter_down(tree->srcs[i], tree->src_size[i]); 435 } 436 } 437 438 static nir_alu_src 439 construct_value(const nir_search_value *value, 440 unsigned num_components, bitsize_tree *bitsize, 441 struct match_state *state, 442 nir_instr *instr, void *mem_ctx) 443 { 444 switch (value->type) { 445 case nir_search_value_expression: { 446 const nir_search_expression *expr = nir_search_value_as_expression(value); 447 448 if (nir_op_infos[expr->opcode].output_size != 0) 449 num_components = nir_op_infos[expr->opcode].output_size; 450 451 nir_alu_instr *alu = nir_alu_instr_create(mem_ctx, expr->opcode); 452 nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components, 453 bitsize->dest_size, NULL); 454 alu->dest.write_mask = (1 << num_components) - 1; 455 alu->dest.saturate = false; 456 457 /* We have no way of knowing what values in a given search expression 458 * map to a particular replacement value. Therefore, if the 459 * expression we are replacing has any exact values, the entire 460 * replacement should be exact. 461 */ 462 alu->exact = state->has_exact_alu; 463 464 for (unsigned i = 0; i < nir_op_infos[expr->opcode].num_inputs; i++) { 465 /* If the source is an explicitly sized source, then we need to reset 466 * the number of components to match. 467 */ 468 if (nir_op_infos[alu->op].input_sizes[i] != 0) 469 num_components = nir_op_infos[alu->op].input_sizes[i]; 470 471 alu->src[i] = construct_value(expr->srcs[i], 472 num_components, bitsize->srcs[i], 473 state, instr, mem_ctx); 474 } 475 476 nir_instr_insert_before(instr, &alu->instr); 477 478 nir_alu_src val; 479 val.src = nir_src_for_ssa(&alu->dest.dest.ssa); 480 val.negate = false; 481 val.abs = false, 482 memcpy(val.swizzle, identity_swizzle, sizeof val.swizzle); 483 484 return val; 485 } 486 487 case nir_search_value_variable: { 488 const nir_search_variable *var = nir_search_value_as_variable(value); 489 assert(state->variables_seen & (1 << var->variable)); 490 491 nir_alu_src val = { NIR_SRC_INIT }; 492 nir_alu_src_copy(&val, &state->variables[var->variable], mem_ctx); 493 494 assert(!var->is_constant); 495 496 return val; 497 } 498 499 case nir_search_value_constant: { 500 const nir_search_constant *c = nir_search_value_as_constant(value); 501 nir_load_const_instr *load = 502 nir_load_const_instr_create(mem_ctx, 1, bitsize->dest_size); 503 504 switch (c->type) { 505 case nir_type_float: 506 load->def.name = ralloc_asprintf(load, "%f", c->data.d); 507 switch (bitsize->dest_size) { 508 case 32: 509 load->value.f32[0] = c->data.d; 510 break; 511 case 64: 512 load->value.f64[0] = c->data.d; 513 break; 514 default: 515 unreachable("unknown bit size"); 516 } 517 break; 518 519 case nir_type_int: 520 load->def.name = ralloc_asprintf(load, "%" PRIi64, c->data.i); 521 switch (bitsize->dest_size) { 522 case 32: 523 load->value.i32[0] = c->data.i; 524 break; 525 case 64: 526 load->value.i64[0] = c->data.i; 527 break; 528 default: 529 unreachable("unknown bit size"); 530 } 531 break; 532 533 case nir_type_uint: 534 load->def.name = ralloc_asprintf(load, "%" PRIu64, c->data.u); 535 switch (bitsize->dest_size) { 536 case 32: 537 load->value.u32[0] = c->data.u; 538 break; 539 case 64: 540 load->value.u64[0] = c->data.u; 541 break; 542 default: 543 unreachable("unknown bit size"); 544 } 545 break; 546 547 case nir_type_bool32: 548 load->value.u32[0] = c->data.u; 549 break; 550 default: 551 unreachable("Invalid alu source type"); 552 } 553 554 nir_instr_insert_before(instr, &load->instr); 555 556 nir_alu_src val; 557 val.src = nir_src_for_ssa(&load->def); 558 val.negate = false; 559 val.abs = false, 560 memset(val.swizzle, 0, sizeof val.swizzle); 561 562 return val; 563 } 564 565 default: 566 unreachable("Invalid search value type"); 567 } 568 } 569 570 nir_alu_instr * 571 nir_replace_instr(nir_alu_instr *instr, const nir_search_expression *search, 572 const nir_search_value *replace, void *mem_ctx) 573 { 574 uint8_t swizzle[4] = { 0, 0, 0, 0 }; 575 576 for (unsigned i = 0; i < instr->dest.dest.ssa.num_components; ++i) 577 swizzle[i] = i; 578 579 assert(instr->dest.dest.is_ssa); 580 581 struct match_state state; 582 state.inexact_match = false; 583 state.has_exact_alu = false; 584 state.variables_seen = 0; 585 586 if (!match_expression(search, instr, instr->dest.dest.ssa.num_components, 587 swizzle, &state)) 588 return NULL; 589 590 void *bitsize_ctx = ralloc_context(NULL); 591 bitsize_tree *tree = build_bitsize_tree(bitsize_ctx, &state, replace); 592 bitsize_tree_filter_up(tree); 593 bitsize_tree_filter_down(tree, instr->dest.dest.ssa.bit_size); 594 595 /* Inserting a mov may be unnecessary. However, it's much easier to 596 * simply let copy propagation clean this up than to try to go through 597 * and rewrite swizzles ourselves. 598 */ 599 nir_alu_instr *mov = nir_alu_instr_create(mem_ctx, nir_op_imov); 600 mov->dest.write_mask = instr->dest.write_mask; 601 nir_ssa_dest_init(&mov->instr, &mov->dest.dest, 602 instr->dest.dest.ssa.num_components, 603 instr->dest.dest.ssa.bit_size, NULL); 604 605 mov->src[0] = construct_value(replace, 606 instr->dest.dest.ssa.num_components, tree, 607 &state, &instr->instr, mem_ctx); 608 nir_instr_insert_before(&instr->instr, &mov->instr); 609 610 nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, 611 nir_src_for_ssa(&mov->dest.dest.ssa)); 612 613 /* We know this one has no more uses because we just rewrote them all, 614 * so we can remove it. The rest of the matched expression, however, we 615 * don't know so much about. We'll just let dead code clean them up. 616 */ 617 nir_instr_remove(&instr->instr); 618 619 ralloc_free(bitsize_ctx); 620 621 return mov; 622 } 623