Home | History | Annotate | Download | only in nir
      1 /*
      2  * Copyright  2014-2015 Broadcom
      3  *
      4  * Permission is hereby granted, free of charge, to any person obtaining a
      5  * copy of this software and associated documentation files (the "Software"),
      6  * to deal in the Software without restriction, including without limitation
      7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
      8  * and/or sell copies of the Software, and to permit persons to whom the
      9  * Software is furnished to do so, subject to the following conditions:
     10  *
     11  * The above copyright notice and this permission notice (including the next
     12  * paragraph) shall be included in all copies or substantial portions of the
     13  * Software.
     14  *
     15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
     16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
     17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
     18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
     19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
     20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
     21  * IN THE SOFTWARE.
     22  */
     23 
     24 #include "nir.h"
     25 #include "nir_builder.h"
     26 
     27 /** @file nir_lower_alu_to_scalar.c
     28  *
     29  * Replaces nir_alu_instr operations with more than one channel used in the
     30  * arguments with individual per-channel operations.
     31  */
     32 
     33 static void
     34 nir_alu_ssa_dest_init(nir_alu_instr *instr, unsigned num_components,
     35                       unsigned bit_size)
     36 {
     37    nir_ssa_dest_init(&instr->instr, &instr->dest.dest, num_components,
     38                      bit_size, NULL);
     39    instr->dest.write_mask = (1 << num_components) - 1;
     40 }
     41 
     42 static void
     43 lower_reduction(nir_alu_instr *instr, nir_op chan_op, nir_op merge_op,
     44                 nir_builder *builder)
     45 {
     46    unsigned num_components = nir_op_infos[instr->op].input_sizes[0];
     47 
     48    nir_ssa_def *last = NULL;
     49    for (unsigned i = 0; i < num_components; i++) {
     50       nir_alu_instr *chan = nir_alu_instr_create(builder->shader, chan_op);
     51       nir_alu_ssa_dest_init(chan, 1, instr->dest.dest.ssa.bit_size);
     52       nir_alu_src_copy(&chan->src[0], &instr->src[0], chan);
     53       chan->src[0].swizzle[0] = chan->src[0].swizzle[i];
     54       if (nir_op_infos[chan_op].num_inputs > 1) {
     55          assert(nir_op_infos[chan_op].num_inputs == 2);
     56          nir_alu_src_copy(&chan->src[1], &instr->src[1], chan);
     57          chan->src[1].swizzle[0] = chan->src[1].swizzle[i];
     58       }
     59       chan->exact = instr->exact;
     60 
     61       nir_builder_instr_insert(builder, &chan->instr);
     62 
     63       if (i == 0) {
     64          last = &chan->dest.dest.ssa;
     65       } else {
     66          last = nir_build_alu(builder, merge_op,
     67                               last, &chan->dest.dest.ssa, NULL, NULL);
     68       }
     69    }
     70 
     71    assert(instr->dest.write_mask == 1);
     72    nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(last));
     73    nir_instr_remove(&instr->instr);
     74 }
     75 
     76 static bool
     77 lower_alu_instr_scalar(nir_alu_instr *instr, nir_builder *b)
     78 {
     79    unsigned num_src = nir_op_infos[instr->op].num_inputs;
     80    unsigned i, chan;
     81 
     82    assert(instr->dest.dest.is_ssa);
     83    assert(instr->dest.write_mask != 0);
     84 
     85    b->cursor = nir_before_instr(&instr->instr);
     86    b->exact = instr->exact;
     87 
     88 #define LOWER_REDUCTION(name, chan, merge) \
     89    case name##2: \
     90    case name##3: \
     91    case name##4: \
     92       lower_reduction(instr, chan, merge, b); \
     93       return true;
     94 
     95    switch (instr->op) {
     96    case nir_op_vec4:
     97    case nir_op_vec3:
     98    case nir_op_vec2:
     99       /* We don't need to scalarize these ops, they're the ones generated to
    100        * group up outputs into a value that can be SSAed.
    101        */
    102       return false;
    103 
    104    case nir_op_pack_half_2x16:
    105       if (!b->shader->options->lower_pack_half_2x16)
    106          return false;
    107 
    108       nir_ssa_def *val =
    109          nir_pack_half_2x16_split(b, nir_channel(b, instr->src[0].src.ssa,
    110                                                  instr->src[0].swizzle[0]),
    111                                      nir_channel(b, instr->src[0].src.ssa,
    112                                                  instr->src[0].swizzle[1]));
    113 
    114       nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(val));
    115       nir_instr_remove(&instr->instr);
    116       return true;
    117 
    118    case nir_op_unpack_unorm_4x8:
    119    case nir_op_unpack_snorm_4x8:
    120    case nir_op_unpack_unorm_2x16:
    121    case nir_op_unpack_snorm_2x16:
    122       /* There is no scalar version of these ops, unless we were to break it
    123        * down to bitshifts and math (which is definitely not intended).
    124        */
    125       return false;
    126 
    127    case nir_op_unpack_half_2x16: {
    128       if (!b->shader->options->lower_unpack_half_2x16)
    129          return false;
    130 
    131       nir_ssa_def *comps[2];
    132       comps[0] = nir_unpack_half_2x16_split_x(b, instr->src[0].src.ssa);
    133       comps[1] = nir_unpack_half_2x16_split_y(b, instr->src[0].src.ssa);
    134       nir_ssa_def *vec = nir_vec(b, comps, 2);
    135 
    136       nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(vec));
    137       nir_instr_remove(&instr->instr);
    138       return true;
    139    }
    140 
    141    case nir_op_pack_uvec2_to_uint: {
    142       assert(b->shader->options->lower_pack_snorm_2x16 ||
    143              b->shader->options->lower_pack_unorm_2x16);
    144 
    145       nir_ssa_def *word =
    146          nir_extract_u16(b, instr->src[0].src.ssa, nir_imm_int(b, 0));
    147       nir_ssa_def *val =
    148          nir_ior(b, nir_ishl(b, nir_channel(b, word, 1), nir_imm_int(b, 16)),
    149                                 nir_channel(b, word, 0));
    150 
    151       nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(val));
    152       nir_instr_remove(&instr->instr);
    153       break;
    154    }
    155 
    156    case nir_op_pack_uvec4_to_uint: {
    157       assert(b->shader->options->lower_pack_snorm_4x8 ||
    158              b->shader->options->lower_pack_unorm_4x8);
    159 
    160       nir_ssa_def *byte =
    161          nir_extract_u8(b, instr->src[0].src.ssa, nir_imm_int(b, 0));
    162       nir_ssa_def *val =
    163          nir_ior(b, nir_ior(b, nir_ishl(b, nir_channel(b, byte, 3), nir_imm_int(b, 24)),
    164                                nir_ishl(b, nir_channel(b, byte, 2), nir_imm_int(b, 16))),
    165                     nir_ior(b, nir_ishl(b, nir_channel(b, byte, 1), nir_imm_int(b, 8)),
    166                                nir_channel(b, byte, 0)));
    167 
    168       nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(val));
    169       nir_instr_remove(&instr->instr);
    170       break;
    171    }
    172 
    173    case nir_op_fdph: {
    174       nir_ssa_def *sum[4];
    175       for (unsigned i = 0; i < 3; i++) {
    176          sum[i] = nir_fmul(b, nir_channel(b, instr->src[0].src.ssa,
    177                                           instr->src[0].swizzle[i]),
    178                               nir_channel(b, instr->src[1].src.ssa,
    179                                           instr->src[1].swizzle[i]));
    180       }
    181       sum[3] = nir_channel(b, instr->src[1].src.ssa, instr->src[1].swizzle[3]);
    182 
    183       nir_ssa_def *val = nir_fadd(b, nir_fadd(b, sum[0], sum[1]),
    184                                      nir_fadd(b, sum[2], sum[3]));
    185 
    186       nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(val));
    187       nir_instr_remove(&instr->instr);
    188       return true;
    189    }
    190 
    191    case nir_op_unpack_double_2x32:
    192       return false;
    193 
    194       LOWER_REDUCTION(nir_op_fdot, nir_op_fmul, nir_op_fadd);
    195       LOWER_REDUCTION(nir_op_ball_fequal, nir_op_feq, nir_op_iand);
    196       LOWER_REDUCTION(nir_op_ball_iequal, nir_op_ieq, nir_op_iand);
    197       LOWER_REDUCTION(nir_op_bany_fnequal, nir_op_fne, nir_op_ior);
    198       LOWER_REDUCTION(nir_op_bany_inequal, nir_op_ine, nir_op_ior);
    199       LOWER_REDUCTION(nir_op_fall_equal, nir_op_seq, nir_op_fand);
    200       LOWER_REDUCTION(nir_op_fany_nequal, nir_op_sne, nir_op_for);
    201 
    202    default:
    203       break;
    204    }
    205 
    206    if (instr->dest.dest.ssa.num_components == 1)
    207       return false;
    208 
    209    unsigned num_components = instr->dest.dest.ssa.num_components;
    210    nir_ssa_def *comps[] = { NULL, NULL, NULL, NULL };
    211 
    212    for (chan = 0; chan < 4; chan++) {
    213       if (!(instr->dest.write_mask & (1 << chan)))
    214          continue;
    215 
    216       nir_alu_instr *lower = nir_alu_instr_create(b->shader, instr->op);
    217       for (i = 0; i < num_src; i++) {
    218          /* We only handle same-size-as-dest (input_sizes[] == 0) or scalar
    219           * args (input_sizes[] == 1).
    220           */
    221          assert(nir_op_infos[instr->op].input_sizes[i] < 2);
    222          unsigned src_chan = (nir_op_infos[instr->op].input_sizes[i] == 1 ?
    223                               0 : chan);
    224 
    225          nir_alu_src_copy(&lower->src[i], &instr->src[i], lower);
    226          for (int j = 0; j < 4; j++)
    227             lower->src[i].swizzle[j] = instr->src[i].swizzle[src_chan];
    228       }
    229 
    230       nir_alu_ssa_dest_init(lower, 1, instr->dest.dest.ssa.bit_size);
    231       lower->dest.saturate = instr->dest.saturate;
    232       comps[chan] = &lower->dest.dest.ssa;
    233       lower->exact = instr->exact;
    234 
    235       nir_builder_instr_insert(b, &lower->instr);
    236    }
    237 
    238    nir_ssa_def *vec = nir_vec(b, comps, num_components);
    239 
    240    nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(vec));
    241 
    242    nir_instr_remove(&instr->instr);
    243    return true;
    244 }
    245 
    246 static bool
    247 nir_lower_alu_to_scalar_impl(nir_function_impl *impl)
    248 {
    249    nir_builder builder;
    250    nir_builder_init(&builder, impl);
    251    bool progress = false;
    252 
    253    nir_foreach_block(block, impl) {
    254       nir_foreach_instr_safe(instr, block) {
    255          if (instr->type == nir_instr_type_alu) {
    256             progress = lower_alu_instr_scalar(nir_instr_as_alu(instr),
    257                                               &builder) || progress;
    258          }
    259       }
    260    }
    261 
    262    nir_metadata_preserve(impl, nir_metadata_block_index |
    263                                nir_metadata_dominance);
    264 
    265    return progress;
    266 }
    267 
    268 bool
    269 nir_lower_alu_to_scalar(nir_shader *shader)
    270 {
    271    bool progress = false;
    272 
    273    nir_foreach_function(function, shader) {
    274       if (function->impl)
    275          progress = nir_lower_alu_to_scalar_impl(function->impl) || progress;
    276    }
    277 
    278    return progress;
    279 }
    280