Home | History | Annotate | Download | only in nir
      1 /*
      2  * Copyright  2016 Broadcom
      3  *
      4  * Permission is hereby granted, free of charge, to any person obtaining a
      5  * copy of this software and associated documentation files (the "Software"),
      6  * to deal in the Software without restriction, including without limitation
      7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
      8  * and/or sell copies of the Software, and to permit persons to whom the
      9  * Software is furnished to do so, subject to the following conditions:
     10  *
     11  * The above copyright notice and this permission notice (including the next
     12  * paragraph) shall be included in all copies or substantial portions of the
     13  * Software.
     14  *
     15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
     16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
     17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
     18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
     19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
     20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
     21  * IN THE SOFTWARE.
     22  */
     23 
     24 #include "nir.h"
     25 #include "nir_builder.h"
     26 
     27 /** @file nir_lower_io_to_scalar.c
     28  *
     29  * Replaces nir_load_input/nir_store_output operations with num_components !=
     30  * 1 with individual per-channel operations.
     31  */
     32 
     33 static void
     34 lower_load_input_to_scalar(nir_builder *b, nir_intrinsic_instr *intr)
     35 {
     36    b->cursor = nir_before_instr(&intr->instr);
     37 
     38    assert(intr->dest.is_ssa);
     39 
     40    nir_ssa_def *loads[4];
     41 
     42    for (unsigned i = 0; i < intr->num_components; i++) {
     43       nir_intrinsic_instr *chan_intr =
     44          nir_intrinsic_instr_create(b->shader, intr->intrinsic);
     45       nir_ssa_dest_init(&chan_intr->instr, &chan_intr->dest,
     46                         1, intr->dest.ssa.bit_size, NULL);
     47       chan_intr->num_components = 1;
     48 
     49       nir_intrinsic_set_base(chan_intr, nir_intrinsic_base(intr));
     50       nir_intrinsic_set_component(chan_intr, nir_intrinsic_component(intr) + i);
     51       /* offset */
     52       nir_src_copy(&chan_intr->src[0], &intr->src[0], chan_intr);
     53 
     54       nir_builder_instr_insert(b, &chan_intr->instr);
     55 
     56       loads[i] = &chan_intr->dest.ssa;
     57    }
     58 
     59    nir_ssa_def_rewrite_uses(&intr->dest.ssa,
     60                             nir_src_for_ssa(nir_vec(b, loads,
     61                                                     intr->num_components)));
     62    nir_instr_remove(&intr->instr);
     63 }
     64 
     65 static void
     66 lower_store_output_to_scalar(nir_builder *b, nir_intrinsic_instr *intr)
     67 {
     68    b->cursor = nir_before_instr(&intr->instr);
     69 
     70    nir_ssa_def *value = nir_ssa_for_src(b, intr->src[0], intr->num_components);
     71 
     72    for (unsigned i = 0; i < intr->num_components; i++) {
     73       if (!(nir_intrinsic_write_mask(intr) & (1 << i)))
     74          continue;
     75 
     76       nir_intrinsic_instr *chan_intr =
     77          nir_intrinsic_instr_create(b->shader, intr->intrinsic);
     78       chan_intr->num_components = 1;
     79 
     80       nir_intrinsic_set_base(chan_intr, nir_intrinsic_base(intr));
     81       nir_intrinsic_set_write_mask(chan_intr, 0x1);
     82       nir_intrinsic_set_component(chan_intr, nir_intrinsic_component(intr) + i);
     83 
     84       /* value */
     85       chan_intr->src[0] = nir_src_for_ssa(nir_channel(b, value, i));
     86       /* offset */
     87       nir_src_copy(&chan_intr->src[1], &intr->src[1], chan_intr);
     88 
     89       nir_builder_instr_insert(b, &chan_intr->instr);
     90    }
     91 
     92    nir_instr_remove(&intr->instr);
     93 }
     94 
     95 void
     96 nir_lower_io_to_scalar(nir_shader *shader, nir_variable_mode mask)
     97 {
     98    nir_foreach_function(function, shader) {
     99       if (function->impl) {
    100          nir_builder b;
    101          nir_builder_init(&b, function->impl);
    102 
    103          nir_foreach_block(block, function->impl) {
    104             nir_foreach_instr_safe(instr, block) {
    105                if (instr->type != nir_instr_type_intrinsic)
    106                   continue;
    107 
    108                nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
    109 
    110                if (intr->num_components == 1)
    111                   continue;
    112 
    113                switch (intr->intrinsic) {
    114                case nir_intrinsic_load_input:
    115                   if (mask & nir_var_shader_in)
    116                      lower_load_input_to_scalar(&b, intr);
    117                   break;
    118                case nir_intrinsic_store_output:
    119                   if (mask & nir_var_shader_out)
    120                      lower_store_output_to_scalar(&b, intr);
    121                   break;
    122                default:
    123                   break;
    124                }
    125             }
    126          }
    127       }
    128    }
    129 }
    130 
    131 static nir_variable **
    132 get_channel_variables(struct hash_table *ht, nir_variable *var)
    133 {
    134    nir_variable **chan_vars;
    135    struct hash_entry *entry = _mesa_hash_table_search(ht, var);
    136    if (!entry) {
    137       chan_vars = (nir_variable **) calloc(4, sizeof(nir_variable *));
    138       _mesa_hash_table_insert(ht, var, chan_vars);
    139    } else {
    140       chan_vars = (nir_variable **) entry->data;
    141    }
    142 
    143    return chan_vars;
    144 }
    145 
    146 /*
    147  * This function differs from nir_deref_clone() in that it gets its type from
    148  * the parent deref rather than our source deref. This is useful when splitting
    149  * vectors because we want to use the scalar type of the new parent rather than
    150  * then the old vector type.
    151  */
    152 static nir_deref_array *
    153 clone_deref_array(const nir_deref_array *darr, nir_deref *parent)
    154 {
    155    nir_deref_array *ndarr = nir_deref_array_create(parent);
    156 
    157    ndarr->deref.type = glsl_get_array_element(parent->type);
    158    if (darr->deref.child)
    159       ndarr->deref.child =
    160          &clone_deref_array(nir_deref_as_array(darr->deref.child),
    161                             &ndarr->deref)->deref;
    162 
    163    ndarr->deref_array_type = darr->deref_array_type;
    164    ndarr->base_offset = darr->base_offset;
    165    if (ndarr->deref_array_type == nir_deref_array_type_indirect)
    166      nir_src_copy(&ndarr->indirect, &darr->indirect, parent);
    167 
    168    return ndarr;
    169 }
    170 
    171 static void
    172 lower_load_to_scalar_early(nir_builder *b, nir_intrinsic_instr *intr,
    173                            nir_variable *var, struct hash_table *split_inputs,
    174                            struct hash_table *split_outputs)
    175 {
    176    b->cursor = nir_before_instr(&intr->instr);
    177 
    178    assert(intr->dest.is_ssa);
    179 
    180    nir_ssa_def *loads[4];
    181 
    182    nir_variable **chan_vars;
    183    if (var->data.mode == nir_var_shader_in) {
    184       chan_vars = get_channel_variables(split_inputs, var);
    185    } else {
    186       chan_vars = get_channel_variables(split_outputs, var);
    187    }
    188 
    189    for (unsigned i = 0; i < intr->num_components; i++) {
    190       nir_variable *chan_var = chan_vars[var->data.location_frac + i];
    191       if (!chan_vars[var->data.location_frac + i]) {
    192          chan_var = nir_variable_clone(var, b->shader);
    193          chan_var->data.location_frac =  var->data.location_frac + i;
    194          chan_var->type = glsl_channel_type(chan_var->type);
    195 
    196          chan_vars[var->data.location_frac + i] = chan_var;
    197 
    198          nir_shader_add_variable(b->shader, chan_var);
    199       }
    200 
    201       nir_intrinsic_instr *chan_intr =
    202          nir_intrinsic_instr_create(b->shader, intr->intrinsic);
    203       nir_ssa_dest_init(&chan_intr->instr, &chan_intr->dest,
    204                         1, intr->dest.ssa.bit_size, NULL);
    205       chan_intr->num_components = 1;
    206       chan_intr->variables[0] = nir_deref_var_create(chan_intr, chan_var);
    207 
    208       if (intr->variables[0]->deref.child) {
    209          chan_intr->variables[0]->deref.child =
    210             &clone_deref_array(nir_deref_as_array(intr->variables[0]->deref.child),
    211                                &chan_intr->variables[0]->deref)->deref;
    212       }
    213 
    214       if (intr->intrinsic == nir_intrinsic_interp_var_at_offset ||
    215           intr->intrinsic == nir_intrinsic_interp_var_at_sample)
    216          nir_src_copy(chan_intr->src, intr->src, &chan_intr->instr);
    217 
    218       nir_builder_instr_insert(b, &chan_intr->instr);
    219 
    220       loads[i] = &chan_intr->dest.ssa;
    221    }
    222 
    223    nir_ssa_def_rewrite_uses(&intr->dest.ssa,
    224                             nir_src_for_ssa(nir_vec(b, loads,
    225                                                     intr->num_components)));
    226 
    227    /* Remove the old load intrinsic */
    228    nir_instr_remove(&intr->instr);
    229 }
    230 
    231 static void
    232 lower_store_output_to_scalar_early(nir_builder *b, nir_intrinsic_instr *intr,
    233                                    nir_variable *var,
    234                                    struct hash_table *split_outputs)
    235 {
    236    b->cursor = nir_before_instr(&intr->instr);
    237 
    238    nir_ssa_def *value = nir_ssa_for_src(b, intr->src[0], intr->num_components);
    239 
    240    nir_variable **chan_vars = get_channel_variables(split_outputs, var);
    241    for (unsigned i = 0; i < intr->num_components; i++) {
    242       if (!(nir_intrinsic_write_mask(intr) & (1 << i)))
    243          continue;
    244 
    245       nir_variable *chan_var = chan_vars[var->data.location_frac + i];
    246       if (!chan_vars[var->data.location_frac + i]) {
    247          chan_var = nir_variable_clone(var, b->shader);
    248          chan_var->data.location_frac =  var->data.location_frac + i;
    249          chan_var->type = glsl_channel_type(chan_var->type);
    250 
    251          chan_vars[var->data.location_frac + i] = chan_var;
    252 
    253          nir_shader_add_variable(b->shader, chan_var);
    254       }
    255 
    256       nir_intrinsic_instr *chan_intr =
    257          nir_intrinsic_instr_create(b->shader, intr->intrinsic);
    258       chan_intr->num_components = 1;
    259 
    260       nir_intrinsic_set_write_mask(chan_intr, 0x1);
    261 
    262       chan_intr->variables[0] = nir_deref_var_create(chan_intr, chan_var);
    263       chan_intr->src[0] = nir_src_for_ssa(nir_channel(b, value, i));
    264 
    265       if (intr->variables[0]->deref.child) {
    266          chan_intr->variables[0]->deref.child =
    267             &clone_deref_array(nir_deref_as_array(intr->variables[0]->deref.child),
    268                                &chan_intr->variables[0]->deref)->deref;
    269       }
    270 
    271       nir_builder_instr_insert(b, &chan_intr->instr);
    272    }
    273 
    274    /* Remove the old store intrinsic */
    275    nir_instr_remove(&intr->instr);
    276 }
    277 
    278 /*
    279  * This function is intended to be called earlier than nir_lower_io_to_scalar()
    280  * i.e. before nir_lower_io() is called.
    281  */
    282 void
    283 nir_lower_io_to_scalar_early(nir_shader *shader, nir_variable_mode mask)
    284 {
    285    struct hash_table *split_inputs =
    286       _mesa_hash_table_create(NULL, _mesa_hash_pointer,
    287                               _mesa_key_pointer_equal);
    288    struct hash_table *split_outputs =
    289       _mesa_hash_table_create(NULL, _mesa_hash_pointer,
    290                               _mesa_key_pointer_equal);
    291 
    292    nir_foreach_function(function, shader) {
    293       if (function->impl) {
    294          nir_builder b;
    295          nir_builder_init(&b, function->impl);
    296 
    297          nir_foreach_block(block, function->impl) {
    298             nir_foreach_instr_safe(instr, block) {
    299                if (instr->type != nir_instr_type_intrinsic)
    300                   continue;
    301 
    302                nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
    303 
    304                if (intr->num_components == 1)
    305                   continue;
    306 
    307                if (intr->intrinsic != nir_intrinsic_load_var &&
    308                    intr->intrinsic != nir_intrinsic_store_var &&
    309                    intr->intrinsic != nir_intrinsic_interp_var_at_centroid &&
    310                    intr->intrinsic != nir_intrinsic_interp_var_at_sample &&
    311                    intr->intrinsic != nir_intrinsic_interp_var_at_offset)
    312                   continue;
    313 
    314                nir_variable *var = intr->variables[0]->var;
    315                nir_variable_mode mode = var->data.mode;
    316 
    317                /* TODO: add patch support */
    318                if (var->data.patch)
    319                   continue;
    320 
    321                /* TODO: add doubles support */
    322                if (glsl_type_is_64bit(glsl_without_array(var->type)))
    323                   continue;
    324 
    325                if (var->data.location < VARYING_SLOT_VAR0 &&
    326                    var->data.location >= 0)
    327                   continue;
    328 
    329                /* Don't bother splitting if we can't opt away any unused
    330                 * components.
    331                 */
    332                if (var->data.always_active_io)
    333                   continue;
    334 
    335               /* Skip types we cannot split */
    336               if (glsl_type_is_matrix(glsl_without_array(var->type)) ||
    337                   glsl_type_is_struct(glsl_without_array(var->type)))
    338                  continue;
    339 
    340                switch (intr->intrinsic) {
    341                case nir_intrinsic_interp_var_at_centroid:
    342                case nir_intrinsic_interp_var_at_sample:
    343                case nir_intrinsic_interp_var_at_offset:
    344                case nir_intrinsic_load_var:
    345                   if ((mask & nir_var_shader_in && mode == nir_var_shader_in) ||
    346                       (mask & nir_var_shader_out && mode == nir_var_shader_out))
    347                      lower_load_to_scalar_early(&b, intr, var, split_inputs,
    348                                                 split_outputs);
    349                   break;
    350                case nir_intrinsic_store_var:
    351                   if (mask & nir_var_shader_out &&
    352                       mode == nir_var_shader_out)
    353                      lower_store_output_to_scalar_early(&b, intr, var,
    354                                                         split_outputs);
    355                   break;
    356                default:
    357                   break;
    358                }
    359             }
    360          }
    361       }
    362    }
    363 
    364    /* Remove old input from the shaders inputs list */
    365    struct hash_entry *entry;
    366    hash_table_foreach(split_inputs, entry) {
    367       nir_variable *var = (nir_variable *) entry->key;
    368       exec_node_remove(&var->node);
    369 
    370       free(entry->data);
    371    }
    372 
    373    /* Remove old output from the shaders outputs list */
    374    hash_table_foreach(split_outputs, entry) {
    375       nir_variable *var = (nir_variable *) entry->key;
    376       exec_node_remove(&var->node);
    377 
    378       free(entry->data);
    379    }
    380 
    381    _mesa_hash_table_destroy(split_inputs, NULL);
    382    _mesa_hash_table_destroy(split_outputs, NULL);
    383 }
    384