Home | History | Annotate | Download | only in nir
      1 /*
      2  * Copyright © 2016 Intel Corporation
      3  *
      4  * Permission is hereby granted, free of charge, to any person obtaining a
      5  * copy of this software and associated documentation files (the "Software"),
      6  * to deal in the Software without restriction, including without limitation
      7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
      8  * and/or sell copies of the Software, and to permit persons to whom the
      9  * Software is furnished to do so, subject to the following conditions:
     10  *
     11  * The above copyright notice and this permission notice (including the next
     12  * paragraph) shall be included in all copies or substantial portions of the
     13  * Software.
     14  *
     15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
     16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
     17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
     18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
     19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
     20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
     21  * DEALINGS IN THE SOFTWARE.
     22  */
     23 
     24 #include "nir.h"
     25 #include "nir_builder.h"
     26 #include "nir_control_flow.h"
     27 #include "nir_loop_analyze.h"
     28 
     29 /* Prepare this loop for unrolling by first converting to lcssa and then
     30  * converting the phis from the loops first block and the block that follows
     31  * the loop into regs.  Partially converting out of SSA allows us to unroll
     32  * the loop without having to keep track of and update phis along the way
     33  * which gets tricky and doesn't add much value over conveting to regs.
     34  *
     35  * The loop may have a continue instruction at the end of the loop which does
     36  * nothing.  Once we're out of SSA, we can safely delete it so we don't have
     37  * to deal with it later.
     38  */
     39 static void
     40 loop_prepare_for_unroll(nir_loop *loop)
     41 {
     42    nir_convert_loop_to_lcssa(loop);
     43 
     44    nir_lower_phis_to_regs_block(nir_loop_first_block(loop));
     45 
     46    nir_block *block_after_loop =
     47       nir_cf_node_as_block(nir_cf_node_next(&loop->cf_node));
     48 
     49    nir_lower_phis_to_regs_block(block_after_loop);
     50 
     51    nir_instr *last_instr = nir_block_last_instr(nir_loop_last_block(loop));
     52    if (last_instr && last_instr->type == nir_instr_type_jump) {
     53       assert(nir_instr_as_jump(last_instr)->type == nir_jump_continue);
     54       nir_instr_remove(last_instr);
     55    }
     56 }
     57 
     58 static void
     59 get_first_blocks_in_terminator(nir_loop_terminator *term,
     60                                nir_block **first_break_block,
     61                                nir_block **first_continue_block)
     62 {
     63    if (term->continue_from_then) {
     64       *first_continue_block = nir_if_first_then_block(term->nif);
     65       *first_break_block = nir_if_first_else_block(term->nif);
     66    } else {
     67       *first_continue_block = nir_if_first_else_block(term->nif);
     68       *first_break_block = nir_if_first_then_block(term->nif);
     69    }
     70 }
     71 
     72 /**
     73  * Unroll a loop where we know exactly how many iterations there are and there
     74  * is only a single exit point.  Note here we can unroll loops with multiple
     75  * theoretical exits that only have a single terminating exit that we always
     76  * know is the "real" exit.
     77  *
     78  *     loop {
     79  *         ...instrs...
     80  *     }
     81  *
     82  * And the iteration count is 3, the output will be:
     83  *
     84  *     ...instrs... ...instrs... ...instrs...
     85  */
     86 static void
     87 simple_unroll(nir_loop *loop)
     88 {
     89    nir_loop_terminator *limiting_term = loop->info->limiting_terminator;
     90    assert(nir_is_trivial_loop_if(limiting_term->nif,
     91                                  limiting_term->break_block));
     92 
     93    loop_prepare_for_unroll(loop);
     94 
     95    /* Skip over loop terminator and get the loop body. */
     96    list_for_each_entry(nir_loop_terminator, terminator,
     97                        &loop->info->loop_terminator_list,
     98                        loop_terminator_link) {
     99 
    100       /* Remove all but the limiting terminator as we know the other exit
    101        * conditions can never be met. Note we need to extract any instructions
    102        * in the continue from branch and insert then into the loop body before
    103        * removing it.
    104        */
    105       if (terminator->nif != limiting_term->nif) {
    106          nir_block *first_break_block;
    107          nir_block *first_continue_block;
    108          get_first_blocks_in_terminator(terminator, &first_break_block,
    109                                         &first_continue_block);
    110 
    111          assert(nir_is_trivial_loop_if(terminator->nif,
    112                                        terminator->break_block));
    113 
    114          nir_cf_list continue_from_lst;
    115          nir_cf_extract(&continue_from_lst,
    116                         nir_before_block(first_continue_block),
    117                         nir_after_block(terminator->continue_from_block));
    118          nir_cf_reinsert(&continue_from_lst,
    119                          nir_after_cf_node(&terminator->nif->cf_node));
    120 
    121          nir_cf_node_remove(&terminator->nif->cf_node);
    122       }
    123    }
    124 
    125    nir_block *first_break_block;
    126    nir_block *first_continue_block;
    127    get_first_blocks_in_terminator(limiting_term, &first_break_block,
    128                                   &first_continue_block);
    129 
    130    /* Pluck out the loop header */
    131    nir_block *header_blk = nir_loop_first_block(loop);
    132    nir_cf_list lp_header;
    133    nir_cf_extract(&lp_header, nir_before_block(header_blk),
    134                   nir_before_cf_node(&limiting_term->nif->cf_node));
    135 
    136    /* Add the continue from block of the limiting terminator to the loop body
    137     */
    138    nir_cf_list continue_from_lst;
    139    nir_cf_extract(&continue_from_lst, nir_before_block(first_continue_block),
    140                   nir_after_block(limiting_term->continue_from_block));
    141    nir_cf_reinsert(&continue_from_lst,
    142                    nir_after_cf_node(&limiting_term->nif->cf_node));
    143 
    144    /* Pluck out the loop body */
    145    nir_cf_list loop_body;
    146    nir_cf_extract(&loop_body, nir_after_cf_node(&limiting_term->nif->cf_node),
    147                   nir_after_block(nir_loop_last_block(loop)));
    148 
    149    struct hash_table *remap_table =
    150       _mesa_hash_table_create(NULL, _mesa_hash_pointer,
    151                               _mesa_key_pointer_equal);
    152 
    153    /* Clone the loop header */
    154    nir_cf_list cloned_header;
    155    nir_cf_list_clone(&cloned_header, &lp_header, loop->cf_node.parent,
    156                      remap_table);
    157 
    158    /* Insert cloned loop header before the loop */
    159    nir_cf_reinsert(&cloned_header, nir_before_cf_node(&loop->cf_node));
    160 
    161    /* Temp list to store the cloned loop body as we unroll */
    162    nir_cf_list unrolled_lp_body;
    163 
    164    /* Clone loop header and append to the loop body */
    165    for (unsigned i = 0; i < loop->info->trip_count; i++) {
    166       /* Clone loop body */
    167       nir_cf_list_clone(&unrolled_lp_body, &loop_body, loop->cf_node.parent,
    168                         remap_table);
    169 
    170       /* Insert unrolled loop body before the loop */
    171       nir_cf_reinsert(&unrolled_lp_body, nir_before_cf_node(&loop->cf_node));
    172 
    173       /* Clone loop header */
    174       nir_cf_list_clone(&cloned_header, &lp_header, loop->cf_node.parent,
    175                         remap_table);
    176 
    177       /* Insert loop header after loop body */
    178       nir_cf_reinsert(&cloned_header, nir_before_cf_node(&loop->cf_node));
    179    }
    180 
    181    /* Remove the break from the loop terminator and add instructions from
    182     * the break block after the unrolled loop.
    183     */
    184    nir_instr *break_instr = nir_block_last_instr(limiting_term->break_block);
    185    nir_instr_remove(break_instr);
    186    nir_cf_list break_list;
    187    nir_cf_extract(&break_list, nir_before_block(first_break_block),
    188                   nir_after_block(limiting_term->break_block));
    189 
    190    /* Clone so things get properly remapped */
    191    nir_cf_list cloned_break_list;
    192    nir_cf_list_clone(&cloned_break_list, &break_list, loop->cf_node.parent,
    193                      remap_table);
    194 
    195    nir_cf_reinsert(&cloned_break_list, nir_before_cf_node(&loop->cf_node));
    196 
    197    /* Remove the loop */
    198    nir_cf_node_remove(&loop->cf_node);
    199 
    200    /* Delete the original loop body, break block & header */
    201    nir_cf_delete(&lp_header);
    202    nir_cf_delete(&loop_body);
    203    nir_cf_delete(&break_list);
    204 
    205    _mesa_hash_table_destroy(remap_table, NULL);
    206 }
    207 
    208 static void
    209 move_cf_list_into_loop_term(nir_cf_list *lst, nir_loop_terminator *term)
    210 {
    211    /* Move the rest of the loop inside the continue-from-block */
    212    nir_cf_reinsert(lst, nir_after_block(term->continue_from_block));
    213 
    214    /* Remove the break */
    215    nir_instr_remove(nir_block_last_instr(term->break_block));
    216 }
    217 
    218 static nir_cursor
    219 get_complex_unroll_insert_location(nir_cf_node *node, bool continue_from_then)
    220 {
    221    if (node->type == nir_cf_node_loop) {
    222       return nir_before_cf_node(node);
    223    } else {
    224       nir_if *if_stmt = nir_cf_node_as_if(node);
    225       if (continue_from_then) {
    226          return nir_after_block(nir_if_last_then_block(if_stmt));
    227       } else {
    228          return nir_after_block(nir_if_last_else_block(if_stmt));
    229       }
    230    }
    231 }
    232 
    233 /**
    234  * Unroll a loop with two exists when the trip count of one of the exits is
    235  * unknown.  If continue_from_then is true, the loop is repeated only when the
    236  * "then" branch of the if is taken; otherwise it is repeated only
    237  * when the "else" branch of the if is taken.
    238  *
    239  * For example, if the input is:
    240  *
    241  *      loop {
    242  *         ...phis/condition...
    243  *         if condition {
    244  *            ...then instructions...
    245  *         } else {
    246  *            ...continue instructions...
    247  *            break
    248  *         }
    249  *         ...body...
    250  *      }
    251  *
    252  * And the iteration count is 3, and unlimit_term->continue_from_then is true,
    253  * then the output will be:
    254  *
    255  *      ...condition...
    256  *      if condition {
    257  *         ...then instructions...
    258  *         ...body...
    259  *         if condition {
    260  *            ...then instructions...
    261  *            ...body...
    262  *            if condition {
    263  *               ...then instructions...
    264  *               ...body...
    265  *            } else {
    266  *               ...continue instructions...
    267  *            }
    268  *         } else {
    269  *            ...continue instructions...
    270  *         }
    271  *      } else {
    272  *         ...continue instructions...
    273  *      }
    274  */
    275 static void
    276 complex_unroll(nir_loop *loop, nir_loop_terminator *unlimit_term,
    277                bool limiting_term_second)
    278 {
    279    assert(nir_is_trivial_loop_if(unlimit_term->nif,
    280                                  unlimit_term->break_block));
    281 
    282    nir_loop_terminator *limiting_term = loop->info->limiting_terminator;
    283    assert(nir_is_trivial_loop_if(limiting_term->nif,
    284                                  limiting_term->break_block));
    285 
    286    loop_prepare_for_unroll(loop);
    287 
    288    nir_block *header_blk = nir_loop_first_block(loop);
    289 
    290    nir_cf_list lp_header;
    291    nir_cf_list limit_break_list;
    292    unsigned num_times_to_clone;
    293    if (limiting_term_second) {
    294       /* Pluck out the loop header */
    295       nir_cf_extract(&lp_header, nir_before_block(header_blk),
    296                      nir_before_cf_node(&unlimit_term->nif->cf_node));
    297 
    298       /* We need some special handling when its the second terminator causing
    299        * us to exit the loop for example:
    300        *
    301        *   for (int i = 0; i < uniform_lp_count; i++) {
    302        *      colour = vec4(0.0, 1.0, 0.0, 1.0);
    303        *
    304        *      if (i == 1) {
    305        *         break;
    306        *      }
    307        *      ... any further code is unreachable after i == 1 ...
    308        *   }
    309        */
    310       nir_cf_list after_lt;
    311       nir_if *limit_if = limiting_term->nif;
    312       nir_cf_extract(&after_lt, nir_after_cf_node(&limit_if->cf_node),
    313                      nir_after_block(nir_loop_last_block(loop)));
    314       move_cf_list_into_loop_term(&after_lt, limiting_term);
    315 
    316       /* Because the trip count is the number of times we pass over the entire
    317        * loop before hitting a break when the second terminator is the
    318        * limiting terminator we can actually execute code inside the loop when
    319        * trip count == 0 e.g. the code above the break.  So we need to bump
    320        * the trip_count in order for the code below to clone anything.  When
    321        * trip count == 1 we execute the code above the break twice and the
    322        * code below it once so we need clone things twice and so on.
    323        */
    324       num_times_to_clone = loop->info->trip_count + 1;
    325    } else {
    326       /* Pluck out the loop header */
    327       nir_cf_extract(&lp_header, nir_before_block(header_blk),
    328                      nir_before_cf_node(&limiting_term->nif->cf_node));
    329 
    330       nir_block *first_break_block;
    331       nir_block *first_continue_block;
    332       get_first_blocks_in_terminator(limiting_term, &first_break_block,
    333                                      &first_continue_block);
    334 
    335       /* Remove the break then extract instructions from the break block so we
    336        * can insert them in the innermost else of the unrolled loop.
    337        */
    338       nir_instr *break_instr = nir_block_last_instr(limiting_term->break_block);
    339       nir_instr_remove(break_instr);
    340       nir_cf_extract(&limit_break_list, nir_before_block(first_break_block),
    341                      nir_after_block(limiting_term->break_block));
    342 
    343       nir_cf_list continue_list;
    344       nir_cf_extract(&continue_list, nir_before_block(first_continue_block),
    345                      nir_after_block(limiting_term->continue_from_block));
    346 
    347       nir_cf_reinsert(&continue_list,
    348                       nir_after_cf_node(&limiting_term->nif->cf_node));
    349 
    350       nir_cf_node_remove(&limiting_term->nif->cf_node);
    351 
    352       num_times_to_clone = loop->info->trip_count;
    353    }
    354 
    355    /* In the terminator that we have no trip count for move everything after
    356     * the terminator into the continue from branch.
    357     */
    358    nir_cf_list loop_end;
    359    nir_cf_extract(&loop_end, nir_after_cf_node(&unlimit_term->nif->cf_node),
    360                   nir_after_block(nir_loop_last_block(loop)));
    361    move_cf_list_into_loop_term(&loop_end, unlimit_term);
    362 
    363    /* Pluck out the loop body. */
    364    nir_cf_list loop_body;
    365    nir_cf_extract(&loop_body, nir_before_block(nir_loop_first_block(loop)),
    366                   nir_after_block(nir_loop_last_block(loop)));
    367 
    368    struct hash_table *remap_table =
    369       _mesa_hash_table_create(NULL, _mesa_hash_pointer,
    370                               _mesa_key_pointer_equal);
    371 
    372    /* Set unroll_loc to the loop as we will insert the unrolled loop before it
    373     */
    374    nir_cf_node *unroll_loc = &loop->cf_node;
    375 
    376    /* Temp lists to store the cloned loop as we unroll */
    377    nir_cf_list unrolled_lp_body;
    378    nir_cf_list cloned_header;
    379 
    380    for (unsigned i = 0; i < num_times_to_clone; i++) {
    381       /* Clone loop header */
    382       nir_cf_list_clone(&cloned_header, &lp_header, loop->cf_node.parent,
    383                         remap_table);
    384 
    385       nir_cursor cursor =
    386          get_complex_unroll_insert_location(unroll_loc,
    387                                             unlimit_term->continue_from_then);
    388 
    389       /* Insert cloned loop header */
    390       nir_cf_reinsert(&cloned_header, cursor);
    391 
    392       cursor =
    393          get_complex_unroll_insert_location(unroll_loc,
    394                                             unlimit_term->continue_from_then);
    395 
    396       /* Clone loop body */
    397       nir_cf_list_clone(&unrolled_lp_body, &loop_body, loop->cf_node.parent,
    398                         remap_table);
    399 
    400       unroll_loc = exec_node_data(nir_cf_node,
    401                                   exec_list_get_tail(&unrolled_lp_body.list),
    402                                   node);
    403       assert(unroll_loc->type == nir_cf_node_block &&
    404              exec_list_is_empty(&nir_cf_node_as_block(unroll_loc)->instr_list));
    405 
    406       /* Get the unrolled if node */
    407       unroll_loc = nir_cf_node_prev(unroll_loc);
    408 
    409       /* Insert unrolled loop body */
    410       nir_cf_reinsert(&unrolled_lp_body, cursor);
    411    }
    412 
    413    if (!limiting_term_second) {
    414       assert(unroll_loc->type == nir_cf_node_if);
    415 
    416       nir_cf_list_clone(&cloned_header, &lp_header, loop->cf_node.parent,
    417                         remap_table);
    418 
    419       nir_cursor cursor =
    420          get_complex_unroll_insert_location(unroll_loc,
    421                                             unlimit_term->continue_from_then);
    422 
    423       /* Insert cloned loop header */
    424       nir_cf_reinsert(&cloned_header, cursor);
    425 
    426       /* Clone so things get properly remapped, and insert break block from
    427        * the limiting terminator.
    428        */
    429       nir_cf_list cloned_break_blk;
    430       nir_cf_list_clone(&cloned_break_blk, &limit_break_list,
    431                         loop->cf_node.parent, remap_table);
    432 
    433       cursor =
    434          get_complex_unroll_insert_location(unroll_loc,
    435                                             unlimit_term->continue_from_then);
    436 
    437       nir_cf_reinsert(&cloned_break_blk, cursor);
    438       nir_cf_delete(&limit_break_list);
    439    }
    440 
    441    /* The loop has been unrolled so remove it. */
    442    nir_cf_node_remove(&loop->cf_node);
    443 
    444    /* Delete the original loop header and body */
    445    nir_cf_delete(&lp_header);
    446    nir_cf_delete(&loop_body);
    447 
    448    _mesa_hash_table_destroy(remap_table, NULL);
    449 }
    450 
    451 static bool
    452 is_loop_small_enough_to_unroll(nir_shader *shader, nir_loop_info *li)
    453 {
    454    unsigned max_iter = shader->options->max_unroll_iterations;
    455 
    456    if (li->trip_count > max_iter)
    457       return false;
    458 
    459    if (li->force_unroll)
    460       return true;
    461 
    462    bool loop_not_too_large =
    463       li->num_instructions * li->trip_count <= max_iter * 25;
    464 
    465    return loop_not_too_large;
    466 }
    467 
    468 static bool
    469 process_loops(nir_shader *sh, nir_cf_node *cf_node, bool *innermost_loop)
    470 {
    471    bool progress = false;
    472    nir_loop *loop;
    473 
    474    switch (cf_node->type) {
    475    case nir_cf_node_block:
    476       return progress;
    477    case nir_cf_node_if: {
    478       nir_if *if_stmt = nir_cf_node_as_if(cf_node);
    479       foreach_list_typed_safe(nir_cf_node, nested_node, node, &if_stmt->then_list)
    480          progress |= process_loops(sh, nested_node, innermost_loop);
    481       foreach_list_typed_safe(nir_cf_node, nested_node, node, &if_stmt->else_list)
    482          progress |= process_loops(sh, nested_node, innermost_loop);
    483       return progress;
    484    }
    485    case nir_cf_node_loop: {
    486       loop = nir_cf_node_as_loop(cf_node);
    487       foreach_list_typed_safe(nir_cf_node, nested_node, node, &loop->body)
    488          progress |= process_loops(sh, nested_node, innermost_loop);
    489       break;
    490    }
    491    default:
    492       unreachable("unknown cf node type");
    493    }
    494 
    495    if (*innermost_loop) {
    496       /* Don't attempt to unroll outer loops or a second inner loop in
    497        * this pass wait until the next pass as we have altered the cf.
    498        */
    499       *innermost_loop = false;
    500 
    501       if (loop->info->limiting_terminator == NULL)
    502          return progress;
    503 
    504       if (!is_loop_small_enough_to_unroll(sh, loop->info))
    505          return progress;
    506 
    507       if (loop->info->is_trip_count_known) {
    508          simple_unroll(loop);
    509          progress = true;
    510       } else {
    511          /* Attempt to unroll loops with two terminators. */
    512          unsigned num_lt = list_length(&loop->info->loop_terminator_list);
    513          if (num_lt == 2) {
    514             bool limiting_term_second = true;
    515             nir_loop_terminator *terminator =
    516                list_last_entry(&loop->info->loop_terminator_list,
    517                                 nir_loop_terminator, loop_terminator_link);
    518 
    519 
    520             if (terminator->nif == loop->info->limiting_terminator->nif) {
    521                limiting_term_second = false;
    522                terminator =
    523                   list_first_entry(&loop->info->loop_terminator_list,
    524                                   nir_loop_terminator, loop_terminator_link);
    525             }
    526 
    527             /* If the first terminator has a trip count of zero and is the
    528              * limiting terminator just do a simple unroll as the second
    529              * terminator can never be reached.
    530              */
    531             if (loop->info->trip_count == 0 && !limiting_term_second) {
    532                simple_unroll(loop);
    533             } else {
    534                complex_unroll(loop, terminator, limiting_term_second);
    535             }
    536             progress = true;
    537          }
    538       }
    539    }
    540 
    541    return progress;
    542 }
    543 
    544 static bool
    545 nir_opt_loop_unroll_impl(nir_function_impl *impl,
    546                          nir_variable_mode indirect_mask)
    547 {
    548    bool progress = false;
    549    nir_metadata_require(impl, nir_metadata_loop_analysis, indirect_mask);
    550    nir_metadata_require(impl, nir_metadata_block_index);
    551 
    552    foreach_list_typed_safe(nir_cf_node, node, node, &impl->body) {
    553       bool innermost_loop = true;
    554       progress |= process_loops(impl->function->shader, node,
    555                                 &innermost_loop);
    556    }
    557 
    558    if (progress)
    559       nir_lower_regs_to_ssa_impl(impl);
    560 
    561    return progress;
    562 }
    563 
    564 bool
    565 nir_opt_loop_unroll(nir_shader *shader, nir_variable_mode indirect_mask)
    566 {
    567    bool progress = false;
    568 
    569    nir_foreach_function(function, shader) {
    570       if (function->impl) {
    571          progress |= nir_opt_loop_unroll_impl(function->impl, indirect_mask);
    572       }
    573    }
    574    return progress;
    575 }
    576