Home | History | Annotate | Download | only in nir
      1 /*
      2  * Copyright © 2014 Intel Corporation
      3  *
      4  * Permission is hereby granted, free of charge, to any person obtaining a
      5  * copy of this software and associated documentation files (the "Software"),
      6  * to deal in the Software without restriction, including without limitation
      7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
      8  * and/or sell copies of the Software, and to permit persons to whom the
      9  * Software is furnished to do so, subject to the following conditions:
     10  *
     11  * The above copyright notice and this permission notice (including the next
     12  * paragraph) shall be included in all copies or substantial portions of the
     13  * Software.
     14  *
     15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
     16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
     17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
     18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
     19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
     20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
     21  * IN THE SOFTWARE.
     22  *
     23  * Authors:
     24  *    Jason Ekstrand (jason (at) jlekstrand.net)
     25  *
     26  */
     27 
     28 #include <inttypes.h>
     29 #include "nir_search.h"
     30 
/* Mutable state threaded through match_value() / match_expression() while a
 * search expression is compared against an ALU instruction tree, and later
 * read back by construct_value() when the replacement is built.
 */
struct match_state {
   /* Becomes true once match_expression() has visited a search expression
    * whose `inexact` flag is set.
    */
   bool inexact_match;
   /* Becomes true once match_expression() has visited an ALU instruction
    * whose `exact` flag is set.  An inexact pattern is refused as soon as
    * both flags are true.
    */
   bool has_exact_alu;
   /* Bitmask of bound search variables: bit N is set once variable N has
    * been given a binding in `variables`.
    */
   unsigned variables_seen;
   /* Source (SSA value plus swizzle) bound to each search variable; entry N
    * is only meaningful when bit N of `variables_seen` is set.
    */
   nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES];
};
     37 
/* Forward declaration: match_value() calls match_expression() for nested
 * expressions, and match_expression() calls match_value() for each source.
 */
static bool
match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
                 unsigned num_components, const uint8_t *swizzle,
                 struct match_state *state);

/* Swizzle that selects component i for output component i. */
static const uint8_t identity_swizzle[] = { 0, 1, 2, 3 };
     44 
     45 /**
     46  * Check if a source produces a value of the given type.
     47  *
     48  * Used for satisfying 'a@type' constraints.
     49  */
     50 static bool
     51 src_is_type(nir_src src, nir_alu_type type)
     52 {
     53    assert(type != nir_type_invalid);
     54 
     55    if (!src.is_ssa)
     56       return false;
     57 
     58    /* Turn nir_type_bool32 into nir_type_bool...they're the same thing. */
     59    if (nir_alu_type_get_base_type(type) == nir_type_bool)
     60       type = nir_type_bool;
     61 
     62    if (src.ssa->parent_instr->type == nir_instr_type_alu) {
     63       nir_alu_instr *src_alu = nir_instr_as_alu(src.ssa->parent_instr);
     64       nir_alu_type output_type = nir_op_infos[src_alu->op].output_type;
     65 
     66       if (type == nir_type_bool) {
     67          switch (src_alu->op) {
     68          case nir_op_iand:
     69          case nir_op_ior:
     70          case nir_op_ixor:
     71             return src_is_type(src_alu->src[0].src, nir_type_bool) &&
     72                    src_is_type(src_alu->src[1].src, nir_type_bool);
     73          case nir_op_inot:
     74             return src_is_type(src_alu->src[0].src, nir_type_bool);
     75          default:
     76             break;
     77          }
     78       }
     79 
     80       return nir_alu_type_get_base_type(output_type) == type;
     81    } else if (src.ssa->parent_instr->type == nir_instr_type_intrinsic) {
     82       nir_intrinsic_instr *intr = nir_instr_as_intrinsic(src.ssa->parent_instr);
     83 
     84       if (type == nir_type_bool) {
     85          return intr->intrinsic == nir_intrinsic_load_front_face ||
     86                 intr->intrinsic == nir_intrinsic_load_helper_invocation;
     87       }
     88    }
     89 
     90    /* don't know */
     91    return false;
     92 }
     93 
     94 static bool
     95 match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
     96             unsigned num_components, const uint8_t *swizzle,
     97             struct match_state *state)
     98 {
     99    uint8_t new_swizzle[4];
    100 
    101    /* Searching only works on SSA values because, if it's not SSA, we can't
    102     * know if the value changed between one instance of that value in the
    103     * expression and another.  Also, the replace operation will place reads of
    104     * that value right before the last instruction in the expression we're
    105     * replacing so those reads will happen after the original reads and may
    106     * not be valid if they're register reads.
    107     */
    108    if (!instr->src[src].src.is_ssa)
    109       return false;
    110 
    111    /* If the source is an explicitly sized source, then we need to reset
    112     * both the number of components and the swizzle.
    113     */
    114    if (nir_op_infos[instr->op].input_sizes[src] != 0) {
    115       num_components = nir_op_infos[instr->op].input_sizes[src];
    116       swizzle = identity_swizzle;
    117    }
    118 
    119    for (unsigned i = 0; i < num_components; ++i)
    120       new_swizzle[i] = instr->src[src].swizzle[swizzle[i]];
    121 
    122    /* If the value has a specific bit size and it doesn't match, bail */
    123    if (value->bit_size &&
    124        nir_src_bit_size(instr->src[src].src) != value->bit_size)
    125       return false;
    126 
    127    switch (value->type) {
    128    case nir_search_value_expression:
    129       if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
    130          return false;
    131 
    132       return match_expression(nir_search_value_as_expression(value),
    133                               nir_instr_as_alu(instr->src[src].src.ssa->parent_instr),
    134                               num_components, new_swizzle, state);
    135 
    136    case nir_search_value_variable: {
    137       nir_search_variable *var = nir_search_value_as_variable(value);
    138       assert(var->variable < NIR_SEARCH_MAX_VARIABLES);
    139 
    140       if (state->variables_seen & (1 << var->variable)) {
    141          if (state->variables[var->variable].src.ssa != instr->src[src].src.ssa)
    142             return false;
    143 
    144          assert(!instr->src[src].abs && !instr->src[src].negate);
    145 
    146          for (unsigned i = 0; i < num_components; ++i) {
    147             if (state->variables[var->variable].swizzle[i] != new_swizzle[i])
    148                return false;
    149          }
    150 
    151          return true;
    152       } else {
    153          if (var->is_constant &&
    154              instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
    155             return false;
    156 
    157          if (var->cond && !var->cond(instr, src, num_components, new_swizzle))
    158             return false;
    159 
    160          if (var->type != nir_type_invalid &&
    161              !src_is_type(instr->src[src].src, var->type))
    162             return false;
    163 
    164          state->variables_seen |= (1 << var->variable);
    165          state->variables[var->variable].src = instr->src[src].src;
    166          state->variables[var->variable].abs = false;
    167          state->variables[var->variable].negate = false;
    168 
    169          for (unsigned i = 0; i < 4; ++i) {
    170             if (i < num_components)
    171                state->variables[var->variable].swizzle[i] = new_swizzle[i];
    172             else
    173                state->variables[var->variable].swizzle[i] = 0;
    174          }
    175 
    176          return true;
    177       }
    178    }
    179 
    180    case nir_search_value_constant: {
    181       nir_search_constant *const_val = nir_search_value_as_constant(value);
    182 
    183       if (!instr->src[src].src.is_ssa)
    184          return false;
    185 
    186       if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
    187          return false;
    188 
    189       nir_load_const_instr *load =
    190          nir_instr_as_load_const(instr->src[src].src.ssa->parent_instr);
    191 
    192       switch (const_val->type) {
    193       case nir_type_float:
    194          for (unsigned i = 0; i < num_components; ++i) {
    195             double val;
    196             switch (load->def.bit_size) {
    197             case 32:
    198                val = load->value.f32[new_swizzle[i]];
    199                break;
    200             case 64:
    201                val = load->value.f64[new_swizzle[i]];
    202                break;
    203             default:
    204                unreachable("unknown bit size");
    205             }
    206 
    207             if (val != const_val->data.d)
    208                return false;
    209          }
    210          return true;
    211 
    212       case nir_type_int:
    213       case nir_type_uint:
    214       case nir_type_bool32:
    215          switch (load->def.bit_size) {
    216          case 32:
    217             for (unsigned i = 0; i < num_components; ++i) {
    218                if (load->value.u32[new_swizzle[i]] !=
    219                    (uint32_t)const_val->data.u)
    220                   return false;
    221             }
    222             return true;
    223 
    224          case 64:
    225             for (unsigned i = 0; i < num_components; ++i) {
    226                if (load->value.u64[new_swizzle[i]] != const_val->data.u)
    227                   return false;
    228             }
    229             return true;
    230 
    231          default:
    232             unreachable("unknown bit size");
    233          }
    234 
    235       default:
    236          unreachable("Invalid alu source type");
    237       }
    238    }
    239 
    240    default:
    241       unreachable("Invalid search value type");
    242    }
    243 }
    244 
    245 static bool
    246 match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
    247                  unsigned num_components, const uint8_t *swizzle,
    248                  struct match_state *state)
    249 {
    250    if (expr->cond && !expr->cond(instr))
    251       return false;
    252 
    253    if (instr->op != expr->opcode)
    254       return false;
    255 
    256    assert(instr->dest.dest.is_ssa);
    257 
    258    if (expr->value.bit_size &&
    259        instr->dest.dest.ssa.bit_size != expr->value.bit_size)
    260       return false;
    261 
    262    state->inexact_match = expr->inexact || state->inexact_match;
    263    state->has_exact_alu = instr->exact || state->has_exact_alu;
    264    if (state->inexact_match && state->has_exact_alu)
    265       return false;
    266 
    267    assert(!instr->dest.saturate);
    268    assert(nir_op_infos[instr->op].num_inputs > 0);
    269 
    270    /* If we have an explicitly sized destination, we can only handle the
    271     * identity swizzle.  While dot(vec3(a, b, c).zxy) is a valid
    272     * expression, we don't have the information right now to propagate that
    273     * swizzle through.  We can only properly propagate swizzles if the
    274     * instruction is vectorized.
    275     */
    276    if (nir_op_infos[instr->op].output_size != 0) {
    277       for (unsigned i = 0; i < num_components; i++) {
    278          if (swizzle[i] != i)
    279             return false;
    280       }
    281    }
    282 
    283    /* Stash off the current variables_seen bitmask.  This way we can
    284     * restore it prior to matching in the commutative case below.
    285     */
    286    unsigned variables_seen_stash = state->variables_seen;
    287 
    288    bool matched = true;
    289    for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
    290       if (!match_value(expr->srcs[i], instr, i, num_components,
    291                        swizzle, state)) {
    292          matched = false;
    293          break;
    294       }
    295    }
    296 
    297    if (matched)
    298       return true;
    299 
    300    if (nir_op_infos[instr->op].algebraic_properties & NIR_OP_IS_COMMUTATIVE) {
    301       assert(nir_op_infos[instr->op].num_inputs == 2);
    302 
    303       /* Restore the variables_seen bitmask.  If we don't do this, then we
    304        * could end up with an erroneous failure due to variables found in the
    305        * first match attempt above not matching those in the second.
    306        */
    307       state->variables_seen = variables_seen_stash;
    308 
    309       if (!match_value(expr->srcs[0], instr, 1, num_components,
    310                        swizzle, state))
    311          return false;
    312 
    313       return match_value(expr->srcs[1], instr, 0, num_components,
    314                          swizzle, state);
    315    } else {
    316       return false;
    317    }
    318 }
    319 
/* Mirror of a replacement search value used to work out the bit size of
 * every value that construct_value() is going to create.  One node is built
 * per search value by build_bitsize_tree(); sizes are then resolved by
 * bitsize_tree_filter_up() followed by bitsize_tree_filter_down().
 */
typedef struct bitsize_tree {
   /* Number of valid entries in srcs / is_src_sized / src_size
    * (0 for variables and constants).
    */
   unsigned num_srcs;
   struct bitsize_tree *srcs[4];

   /* Bit size shared by every source (and the destination) that has no
    * explicit size of its own; 0 while still unknown.
    */
   unsigned common_size;
   /* True when the opcode's input type for that source carries a size. */
   bool is_src_sized[4];
   /* True when the destination size is fixed independently of common_size. */
   bool is_dest_sized;

   /* Resolved bit sizes; 0 while still unknown. */
   unsigned dest_size;
   unsigned src_size[4];
} bitsize_tree;
    331 
    332 static bitsize_tree *
    333 build_bitsize_tree(void *mem_ctx, struct match_state *state,
    334                    const nir_search_value *value)
    335 {
    336    bitsize_tree *tree = rzalloc(mem_ctx, bitsize_tree);
    337 
    338    switch (value->type) {
    339    case nir_search_value_expression: {
    340       nir_search_expression *expr = nir_search_value_as_expression(value);
    341       nir_op_info info = nir_op_infos[expr->opcode];
    342       tree->num_srcs = info.num_inputs;
    343       tree->common_size = 0;
    344       for (unsigned i = 0; i < info.num_inputs; i++) {
    345          tree->is_src_sized[i] = !!nir_alu_type_get_type_size(info.input_types[i]);
    346          if (tree->is_src_sized[i])
    347             tree->src_size[i] = nir_alu_type_get_type_size(info.input_types[i]);
    348          tree->srcs[i] = build_bitsize_tree(mem_ctx, state, expr->srcs[i]);
    349       }
    350       tree->is_dest_sized = !!nir_alu_type_get_type_size(info.output_type);
    351       if (tree->is_dest_sized)
    352          tree->dest_size = nir_alu_type_get_type_size(info.output_type);
    353       break;
    354    }
    355 
    356    case nir_search_value_variable: {
    357       nir_search_variable *var = nir_search_value_as_variable(value);
    358       tree->num_srcs = 0;
    359       tree->is_dest_sized = true;
    360       tree->dest_size = nir_src_bit_size(state->variables[var->variable].src);
    361       break;
    362    }
    363 
    364    case nir_search_value_constant: {
    365       tree->num_srcs = 0;
    366       tree->is_dest_sized = false;
    367       tree->common_size = 0;
    368       break;
    369    }
    370    }
    371 
    372    if (value->bit_size) {
    373       assert(!tree->is_dest_sized || tree->dest_size == value->bit_size);
    374       tree->common_size = value->bit_size;
    375    }
    376 
    377    return tree;
    378 }
    379 
    380 static unsigned
    381 bitsize_tree_filter_up(bitsize_tree *tree)
    382 {
    383    for (unsigned i = 0; i < tree->num_srcs; i++) {
    384       unsigned src_size = bitsize_tree_filter_up(tree->srcs[i]);
    385       if (src_size == 0)
    386          continue;
    387 
    388       if (tree->is_src_sized[i]) {
    389          assert(src_size == tree->src_size[i]);
    390       } else if (tree->common_size != 0) {
    391          assert(src_size == tree->common_size);
    392          tree->src_size[i] = src_size;
    393       } else {
    394          tree->common_size = src_size;
    395          tree->src_size[i] = src_size;
    396       }
    397    }
    398 
    399    if (tree->num_srcs && tree->common_size) {
    400       if (tree->dest_size == 0)
    401          tree->dest_size = tree->common_size;
    402       else if (!tree->is_dest_sized)
    403          assert(tree->dest_size == tree->common_size);
    404 
    405       for (unsigned i = 0; i < tree->num_srcs; i++) {
    406          if (!tree->src_size[i])
    407             tree->src_size[i] = tree->common_size;
    408       }
    409    }
    410 
    411    return tree->dest_size;
    412 }
    413 
    414 static void
    415 bitsize_tree_filter_down(bitsize_tree *tree, unsigned size)
    416 {
    417    if (tree->dest_size)
    418       assert(tree->dest_size == size);
    419    else
    420       tree->dest_size = size;
    421 
    422    if (!tree->is_dest_sized) {
    423       if (tree->common_size)
    424          assert(tree->common_size == size);
    425       else
    426          tree->common_size = size;
    427    }
    428 
    429    for (unsigned i = 0; i < tree->num_srcs; i++) {
    430       if (!tree->src_size[i]) {
    431          assert(tree->common_size);
    432          tree->src_size[i] = tree->common_size;
    433       }
    434       bitsize_tree_filter_down(tree->srcs[i], tree->src_size[i]);
    435    }
    436 }
    437 
    438 static nir_alu_src
    439 construct_value(const nir_search_value *value,
    440                 unsigned num_components, bitsize_tree *bitsize,
    441                 struct match_state *state,
    442                 nir_instr *instr, void *mem_ctx)
    443 {
    444    switch (value->type) {
    445    case nir_search_value_expression: {
    446       const nir_search_expression *expr = nir_search_value_as_expression(value);
    447 
    448       if (nir_op_infos[expr->opcode].output_size != 0)
    449          num_components = nir_op_infos[expr->opcode].output_size;
    450 
    451       nir_alu_instr *alu = nir_alu_instr_create(mem_ctx, expr->opcode);
    452       nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components,
    453                         bitsize->dest_size, NULL);
    454       alu->dest.write_mask = (1 << num_components) - 1;
    455       alu->dest.saturate = false;
    456 
    457       /* We have no way of knowing what values in a given search expression
    458        * map to a particular replacement value.  Therefore, if the
    459        * expression we are replacing has any exact values, the entire
    460        * replacement should be exact.
    461        */
    462       alu->exact = state->has_exact_alu;
    463 
    464       for (unsigned i = 0; i < nir_op_infos[expr->opcode].num_inputs; i++) {
    465          /* If the source is an explicitly sized source, then we need to reset
    466           * the number of components to match.
    467           */
    468          if (nir_op_infos[alu->op].input_sizes[i] != 0)
    469             num_components = nir_op_infos[alu->op].input_sizes[i];
    470 
    471          alu->src[i] = construct_value(expr->srcs[i],
    472                                        num_components, bitsize->srcs[i],
    473                                        state, instr, mem_ctx);
    474       }
    475 
    476       nir_instr_insert_before(instr, &alu->instr);
    477 
    478       nir_alu_src val;
    479       val.src = nir_src_for_ssa(&alu->dest.dest.ssa);
    480       val.negate = false;
    481       val.abs = false,
    482       memcpy(val.swizzle, identity_swizzle, sizeof val.swizzle);
    483 
    484       return val;
    485    }
    486 
    487    case nir_search_value_variable: {
    488       const nir_search_variable *var = nir_search_value_as_variable(value);
    489       assert(state->variables_seen & (1 << var->variable));
    490 
    491       nir_alu_src val = { NIR_SRC_INIT };
    492       nir_alu_src_copy(&val, &state->variables[var->variable], mem_ctx);
    493 
    494       assert(!var->is_constant);
    495 
    496       return val;
    497    }
    498 
    499    case nir_search_value_constant: {
    500       const nir_search_constant *c = nir_search_value_as_constant(value);
    501       nir_load_const_instr *load =
    502          nir_load_const_instr_create(mem_ctx, 1, bitsize->dest_size);
    503 
    504       switch (c->type) {
    505       case nir_type_float:
    506          load->def.name = ralloc_asprintf(load, "%f", c->data.d);
    507          switch (bitsize->dest_size) {
    508          case 32:
    509             load->value.f32[0] = c->data.d;
    510             break;
    511          case 64:
    512             load->value.f64[0] = c->data.d;
    513             break;
    514          default:
    515             unreachable("unknown bit size");
    516          }
    517          break;
    518 
    519       case nir_type_int:
    520          load->def.name = ralloc_asprintf(load, "%" PRIi64, c->data.i);
    521          switch (bitsize->dest_size) {
    522          case 32:
    523             load->value.i32[0] = c->data.i;
    524             break;
    525          case 64:
    526             load->value.i64[0] = c->data.i;
    527             break;
    528          default:
    529             unreachable("unknown bit size");
    530          }
    531          break;
    532 
    533       case nir_type_uint:
    534          load->def.name = ralloc_asprintf(load, "%" PRIu64, c->data.u);
    535          switch (bitsize->dest_size) {
    536          case 32:
    537             load->value.u32[0] = c->data.u;
    538             break;
    539          case 64:
    540             load->value.u64[0] = c->data.u;
    541             break;
    542          default:
    543             unreachable("unknown bit size");
    544          }
    545          break;
    546 
    547       case nir_type_bool32:
    548          load->value.u32[0] = c->data.u;
    549          break;
    550       default:
    551          unreachable("Invalid alu source type");
    552       }
    553 
    554       nir_instr_insert_before(instr, &load->instr);
    555 
    556       nir_alu_src val;
    557       val.src = nir_src_for_ssa(&load->def);
    558       val.negate = false;
    559       val.abs = false,
    560       memset(val.swizzle, 0, sizeof val.swizzle);
    561 
    562       return val;
    563    }
    564 
    565    default:
    566       unreachable("Invalid search value type");
    567    }
    568 }
    569 
    570 nir_alu_instr *
    571 nir_replace_instr(nir_alu_instr *instr, const nir_search_expression *search,
    572                   const nir_search_value *replace, void *mem_ctx)
    573 {
    574    uint8_t swizzle[4] = { 0, 0, 0, 0 };
    575 
    576    for (unsigned i = 0; i < instr->dest.dest.ssa.num_components; ++i)
    577       swizzle[i] = i;
    578 
    579    assert(instr->dest.dest.is_ssa);
    580 
    581    struct match_state state;
    582    state.inexact_match = false;
    583    state.has_exact_alu = false;
    584    state.variables_seen = 0;
    585 
    586    if (!match_expression(search, instr, instr->dest.dest.ssa.num_components,
    587                          swizzle, &state))
    588       return NULL;
    589 
    590    void *bitsize_ctx = ralloc_context(NULL);
    591    bitsize_tree *tree = build_bitsize_tree(bitsize_ctx, &state, replace);
    592    bitsize_tree_filter_up(tree);
    593    bitsize_tree_filter_down(tree, instr->dest.dest.ssa.bit_size);
    594 
    595    /* Inserting a mov may be unnecessary.  However, it's much easier to
    596     * simply let copy propagation clean this up than to try to go through
    597     * and rewrite swizzles ourselves.
    598     */
    599    nir_alu_instr *mov = nir_alu_instr_create(mem_ctx, nir_op_imov);
    600    mov->dest.write_mask = instr->dest.write_mask;
    601    nir_ssa_dest_init(&mov->instr, &mov->dest.dest,
    602                      instr->dest.dest.ssa.num_components,
    603                      instr->dest.dest.ssa.bit_size, NULL);
    604 
    605    mov->src[0] = construct_value(replace,
    606                                  instr->dest.dest.ssa.num_components, tree,
    607                                  &state, &instr->instr, mem_ctx);
    608    nir_instr_insert_before(&instr->instr, &mov->instr);
    609 
    610    nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa,
    611                             nir_src_for_ssa(&mov->dest.dest.ssa));
    612 
    613    /* We know this one has no more uses because we just rewrote them all,
    614     * so we can remove it.  The rest of the matched expression, however, we
    615     * don't know so much about.  We'll just let dead code clean them up.
    616     */
    617    nir_instr_remove(&instr->instr);
    618 
    619    ralloc_free(bitsize_ctx);
    620 
    621    return mov;
    622 }
    623