// Home | History | Annotate | Download | only in opt  (code-browser navigation residue, not source)
      1 // Copyright (c) 2018 Google LLC.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 #include "source/opt/loop_unswitch_pass.h"
     16 
     17 #include <functional>
     18 #include <list>
     19 #include <memory>
     20 #include <type_traits>
     21 #include <unordered_map>
     22 #include <unordered_set>
     23 #include <utility>
     24 #include <vector>
     25 
     26 #include "source/opt/basic_block.h"
     27 #include "source/opt/dominator_tree.h"
     28 #include "source/opt/fold.h"
     29 #include "source/opt/function.h"
     30 #include "source/opt/instruction.h"
     31 #include "source/opt/ir_builder.h"
     32 #include "source/opt/ir_context.h"
     33 #include "source/opt/loop_descriptor.h"
     34 
     35 #include "source/opt/loop_utils.h"
     36 
     37 namespace spvtools {
     38 namespace opt {
     39 namespace {
     40 
     41 static const uint32_t kTypePointerStorageClassInIdx = 0;
     42 static const uint32_t kBranchCondTrueLabIdInIdx = 1;
     43 static const uint32_t kBranchCondFalseLabIdInIdx = 2;
     44 
     45 }  // anonymous namespace
     46 
     47 namespace {
     48 
     49 // This class handle the unswitch procedure for a given loop.
     50 // The unswitch will not happen if:
     51 //  - The loop has any instruction that will prevent it;
     52 //  - The loop invariant condition is not uniform.
     53 class LoopUnswitch {
     54  public:
  // Builds an unswitch helper for |loop|, which belongs to |function|.
  // |loop_desc| is the loop descriptor owning |loop|; it is kept up to date
  // while blocks and loops are created or destroyed by the transformation.
  // The candidate branch block (|switch_block_|) starts out unset; it is
  // discovered lazily by CanUnswitchLoop().
  LoopUnswitch(IRContext* context, Function* function, Loop* loop,
               LoopDescriptor* loop_desc)
      : function_(function),
        loop_(loop),
        loop_desc_(*loop_desc),
        context_(context),
        switch_block_(nullptr) {}
     62 
     63   // Returns true if the loop can be unswitched.
     64   // Can be unswitch if:
     65   //  - The loop has no instructions that prevents it (such as barrier);
     66   //  - The loop has one conditional branch or switch that do not depends on the
     67   //  loop;
     68   //  - The loop invariant condition is uniform;
     69   bool CanUnswitchLoop() {
     70     if (switch_block_) return true;
     71     if (loop_->IsSafeToClone()) return false;
     72 
     73     CFG& cfg = *context_->cfg();
     74 
     75     for (uint32_t bb_id : loop_->GetBlocks()) {
     76       BasicBlock* bb = cfg.block(bb_id);
     77       if (bb->terminator()->IsBranch() &&
     78           bb->terminator()->opcode() != SpvOpBranch) {
     79         if (IsConditionLoopInvariant(bb->terminator())) {
     80           switch_block_ = bb;
     81           break;
     82         }
     83       }
     84     }
     85 
     86     return switch_block_;
     87   }
     88 
     89   // Return the iterator to the basic block |bb|.
     90   Function::iterator FindBasicBlockPosition(BasicBlock* bb_to_find) {
     91     Function::iterator it = function_->FindBlock(bb_to_find->id());
     92     assert(it != function_->end() && "Basic Block not found");
     93     return it;
     94   }
     95 
     96   // Creates a new basic block and insert it into the function |fn| at the
     97   // position |ip|. This function preserves the def/use and instr to block
     98   // managers.
     99   BasicBlock* CreateBasicBlock(Function::iterator ip) {
    100     analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
    101 
    102     // TODO(1841): Handle id overflow.
    103     BasicBlock* bb = &*ip.InsertBefore(std::unique_ptr<BasicBlock>(
    104         new BasicBlock(std::unique_ptr<Instruction>(new Instruction(
    105             context_, SpvOpLabel, 0, context_->TakeNextId(), {})))));
    106     bb->SetParent(function_);
    107     def_use_mgr->AnalyzeInstDef(bb->GetLabelInst());
    108     context_->set_instr_block(bb->GetLabelInst(), bb);
    109 
    110     return bb;
    111   }
    112 
  // Unswitches |loop_|: hoists the loop-invariant conditional found by
  // CanUnswitchLoop() out of the loop. One specialized clone of the loop is
  // created per non-default branch target, then the original loop is
  // specialized for the remaining value, and finally the old pre-header is
  // turned into the hoisted branch/switch dispatching to the clones.
  // Requires: a pre-header, LCSSA form, and a prior successful
  // CanUnswitchLoop() call (asserted below).
  void PerformUnswitch() {
    assert(CanUnswitchLoop() &&
           "Cannot unswitch if there is not constant condition");
    assert(loop_->GetPreHeaderBlock() && "This loop has no pre-header block");
    assert(loop_->IsLCSSA() && "This loop is not in LCSSA form");

    CFG& cfg = *context_->cfg();
    DominatorTree* dom_tree =
        &context_->GetDominatorAnalysis(function_)->GetDomTree();
    analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
    LoopUtils loop_utils(context_, loop_);

    //////////////////////////////////////////////////////////////////////////////
    // Step 1: Create the if merge block for structured modules.
    //    To do so, the |loop_| merge block will become the if's one and we
    //    create a merge for the loop. This will limit the amount of duplicated
    //    code the structured control flow imposes.
    //    For non structured program, the new loop will be connected to
    //    the old loop's exit blocks.
    //////////////////////////////////////////////////////////////////////////////

    // Get the merge block if it exists.
    BasicBlock* if_merge_block = loop_->GetMergeBlock();
    // The merge block is only created if the loop has a unique exit block. We
    // have this guarantee for structured loops, for compute loop it will
    // trivially help maintain both a structured-like form and LCSAA.
    BasicBlock* loop_merge_block =
        if_merge_block
            ? CreateBasicBlock(FindBasicBlockPosition(if_merge_block))
            : nullptr;
    if (loop_merge_block) {
      // Add the instruction and update managers.
      InstructionBuilder builder(
          context_, loop_merge_block,
          IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
      builder.AddBranch(if_merge_block->id());
      builder.SetInsertPoint(&*loop_merge_block->begin());
      cfg.RegisterBlock(loop_merge_block);
      def_use_mgr->AnalyzeInstDef(loop_merge_block->GetLabelInst());
      // Update CFG.
      // Move each phi of the old merge block into the new loop merge block,
      // leaving behind a single-predecessor phi that forwards the moved value.
      if_merge_block->ForEachPhiInst(
          [loop_merge_block, &builder, this](Instruction* phi) {
            Instruction* cloned = phi->Clone(context_);
            builder.AddInstruction(std::unique_ptr<Instruction>(cloned));
            phi->SetInOperand(0, {cloned->result_id()});
            phi->SetInOperand(1, {loop_merge_block->id()});
            for (uint32_t j = phi->NumInOperands() - 1; j > 1; j--)
              phi->RemoveInOperand(j);
          });
      // Copy the predecessor list (will get invalidated otherwise).
      std::vector<uint32_t> preds = cfg.preds(if_merge_block->id());
      for (uint32_t pid : preds) {
        if (pid == loop_merge_block->id()) continue;
        BasicBlock* p_bb = cfg.block(pid);
        p_bb->ForEachSuccessorLabel(
            [if_merge_block, loop_merge_block](uint32_t* id) {
              if (*id == if_merge_block->id()) *id = loop_merge_block->id();
            });
        cfg.AddEdge(pid, loop_merge_block->id());
      }
      cfg.RemoveNonExistingEdges(if_merge_block->id());
      // Update loop descriptor.
      if (Loop* ploop = loop_->GetParent()) {
        ploop->AddBasicBlock(loop_merge_block);
        loop_desc_.SetBasicBlockToLoop(loop_merge_block->id(), ploop);
      }

      // Update the dominator tree.
      // The new merge block takes if_merge_block's position in the tree and
      // adopts if_merge_block as its only child.
      DominatorTreeNode* loop_merge_dtn =
          dom_tree->GetOrInsertNode(loop_merge_block);
      DominatorTreeNode* if_merge_block_dtn =
          dom_tree->GetOrInsertNode(if_merge_block);
      loop_merge_dtn->parent_ = if_merge_block_dtn->parent_;
      loop_merge_dtn->children_.push_back(if_merge_block_dtn);
      loop_merge_dtn->parent_->children_.push_back(loop_merge_dtn);
      if_merge_block_dtn->parent_->children_.erase(std::find(
          if_merge_block_dtn->parent_->children_.begin(),
          if_merge_block_dtn->parent_->children_.end(), if_merge_block_dtn));

      loop_->SetMergeBlock(loop_merge_block);
    }

    ////////////////////////////////////////////////////////////////////////////
    // Step 2: Build a new preheader for |loop_|, use the old one
    //         for the constant branch.
    ////////////////////////////////////////////////////////////////////////////

    BasicBlock* if_block = loop_->GetPreHeaderBlock();
    // If this preheader is the parent loop header,
    // we need to create a dedicated block for the if.
    BasicBlock* loop_pre_header =
        CreateBasicBlock(++FindBasicBlockPosition(if_block));
    InstructionBuilder(
        context_, loop_pre_header,
        IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping)
        .AddBranch(loop_->GetHeaderBlock()->id());

    if_block->tail()->SetInOperand(0, {loop_pre_header->id()});

    // Update loop descriptor.
    if (Loop* ploop = loop_desc_[if_block]) {
      ploop->AddBasicBlock(loop_pre_header);
      loop_desc_.SetBasicBlockToLoop(loop_pre_header->id(), ploop);
    }

    // Update the CFG.
    cfg.RegisterBlock(loop_pre_header);
    def_use_mgr->AnalyzeInstDef(loop_pre_header->GetLabelInst());
    cfg.AddEdge(if_block->id(), loop_pre_header->id());
    cfg.RemoveNonExistingEdges(loop_->GetHeaderBlock()->id());

    // Header phis that referenced the old pre-header now come from the new one.
    loop_->GetHeaderBlock()->ForEachPhiInst(
        [loop_pre_header, if_block](Instruction* phi) {
          phi->ForEachInId([loop_pre_header, if_block](uint32_t* id) {
            if (*id == if_block->id()) {
              *id = loop_pre_header->id();
            }
          });
        });
    loop_->SetPreHeaderBlock(loop_pre_header);

    // Update the dominator tree.
    DominatorTreeNode* loop_pre_header_dtn =
        dom_tree->GetOrInsertNode(loop_pre_header);
    DominatorTreeNode* if_block_dtn = dom_tree->GetTreeNode(if_block);
    loop_pre_header_dtn->parent_ = if_block_dtn;
    assert(
        if_block_dtn->children_.size() == 1 &&
        "A loop preheader should only have the header block as a child in the "
        "dominator tree");
    loop_pre_header_dtn->children_.push_back(if_block_dtn->children_[0]);
    if_block_dtn->children_.clear();
    if_block_dtn->children_.push_back(loop_pre_header_dtn);

    // Make domination queries valid.
    dom_tree->ResetDFNumbering();

    // Compute an ordered list of basic block to clone: loop blocks + pre-header
    // + merge block.
    loop_->ComputeLoopStructuredOrder(&ordered_loop_blocks_, true, true);

    /////////////////////////////
    // Do the actual unswitch: //
    //   - Clone the loop      //
    //   - Connect exits       //
    //   - Specialize the loop //
    /////////////////////////////

    Instruction* iv_condition = &*switch_block_->tail();
    SpvOp iv_opcode = iv_condition->opcode();
    Instruction* condition =
        def_use_mgr->GetDef(iv_condition->GetOperand(0).words[0]);

    analysis::ConstantManager* cst_mgr = context_->get_constant_mgr();
    const analysis::Type* cond_type =
        context_->get_type_mgr()->GetType(condition->type_id());

    // Build the list of value for which we need to clone and specialize the
    // loop.
    std::vector<std::pair<Instruction*, BasicBlock*>> constant_branch;
    // Special case for the original loop
    Instruction* original_loop_constant_value;
    BasicBlock* original_loop_target;
    if (iv_opcode == SpvOpBranchConditional) {
      // Clone for "false" ({0}); the original loop is kept for "true" ({1}).
      constant_branch.emplace_back(
          cst_mgr->GetDefiningInstruction(cst_mgr->GetConstant(cond_type, {0})),
          nullptr);
      original_loop_constant_value =
          cst_mgr->GetDefiningInstruction(cst_mgr->GetConstant(cond_type, {1}));
    } else {
      // We are looking to take the default branch, so we can't provide a
      // specific value.
      original_loop_constant_value = nullptr;
      // One clone per OpSwitch case literal (in-operands: selector, default,
      // then literal/label pairs starting at index 2).
      for (uint32_t i = 2; i < iv_condition->NumInOperands(); i += 2) {
        constant_branch.emplace_back(
            cst_mgr->GetDefiningInstruction(cst_mgr->GetConstant(
                cond_type, iv_condition->GetInOperand(i).words)),
            nullptr);
      }
    }

    // Get the loop landing pads.
    std::unordered_set<uint32_t> if_merging_blocks;
    std::function<bool(uint32_t)> is_from_original_loop;
    if (loop_->GetHeaderBlock()->GetLoopMergeInst()) {
      if_merging_blocks.insert(if_merge_block->id());
      is_from_original_loop = [this](uint32_t id) {
        return loop_->IsInsideLoop(id) || loop_->GetMergeBlock()->id() == id;
      };
    } else {
      loop_->GetExitBlocks(&if_merging_blocks);
      is_from_original_loop = [this](uint32_t id) {
        return loop_->IsInsideLoop(id);
      };
    }

    for (auto& specialisation_pair : constant_branch) {
      Instruction* specialisation_value = specialisation_pair.first;
      //////////////////////////////////////////////////////////
      // Step 3: Duplicate |loop_|.
      //////////////////////////////////////////////////////////
      LoopUtils::LoopCloningResult clone_result;

      Loop* cloned_loop =
          loop_utils.CloneLoop(&clone_result, ordered_loop_blocks_);
      specialisation_pair.second = cloned_loop->GetPreHeaderBlock();

      ////////////////////////////////////
      // Step 4: Specialize the loop.   //
      ////////////////////////////////////

      {
        std::unordered_set<uint32_t> dead_blocks;
        std::unordered_set<uint32_t> unreachable_merges;
        SimplifyLoop(
            make_range(
                UptrVectorIterator<BasicBlock>(&clone_result.cloned_bb_,
                                               clone_result.cloned_bb_.begin()),
                UptrVectorIterator<BasicBlock>(&clone_result.cloned_bb_,
                                               clone_result.cloned_bb_.end())),
            cloned_loop, condition, specialisation_value, &dead_blocks);

        // We tagged dead blocks, create the loop before we invalidate any basic
        // block.
        cloned_loop =
            CleanLoopNest(cloned_loop, dead_blocks, &unreachable_merges);
        CleanUpCFG(
            UptrVectorIterator<BasicBlock>(&clone_result.cloned_bb_,
                                           clone_result.cloned_bb_.begin()),
            dead_blocks, unreachable_merges);

        ///////////////////////////////////////////////////////////
        // Step 5: Connect convergent edges to the landing pads. //
        ///////////////////////////////////////////////////////////

        for (uint32_t merge_bb_id : if_merging_blocks) {
          BasicBlock* merge = context_->cfg()->block(merge_bb_id);
          // We are in LCSSA so we only care about phi instructions.
          merge->ForEachPhiInst([is_from_original_loop, &dead_blocks,
                                 &clone_result](Instruction* phi) {
            uint32_t num_in_operands = phi->NumInOperands();
            for (uint32_t i = 0; i < num_in_operands; i += 2) {
              uint32_t pred = phi->GetSingleWordInOperand(i + 1);
              if (is_from_original_loop(pred)) {
                pred = clone_result.value_map_.at(pred);
                if (!dead_blocks.count(pred)) {
                  uint32_t incoming_value_id = phi->GetSingleWordInOperand(i);
                  // Not all the incoming value are coming from the loop.
                  ValueMapTy::iterator new_value =
                      clone_result.value_map_.find(incoming_value_id);
                  if (new_value != clone_result.value_map_.end()) {
                    incoming_value_id = new_value->second;
                  }
                  phi->AddOperand({SPV_OPERAND_TYPE_ID, {incoming_value_id}});
                  phi->AddOperand({SPV_OPERAND_TYPE_ID, {pred}});
                }
              }
            }
          });
        }
      }
      function_->AddBasicBlocks(clone_result.cloned_bb_.begin(),
                                clone_result.cloned_bb_.end(),
                                ++FindBasicBlockPosition(if_block));
    }

    // Same as above but specialize the existing loop
    {
      std::unordered_set<uint32_t> dead_blocks;
      std::unordered_set<uint32_t> unreachable_merges;
      SimplifyLoop(make_range(function_->begin(), function_->end()), loop_,
                   condition, original_loop_constant_value, &dead_blocks);

      for (uint32_t merge_bb_id : if_merging_blocks) {
        BasicBlock* merge = context_->cfg()->block(merge_bb_id);
        // LCSSA, so we only care about phi instructions.
        // If we the phi is reduced to a single incoming branch, do not
        // propagate it to preserve LCSSA.
        PatchPhis(merge, dead_blocks, true);
      }
      if (if_merge_block) {
        bool has_live_pred = false;
        for (uint32_t pid : cfg.preds(if_merge_block->id())) {
          if (!dead_blocks.count(pid)) {
            has_live_pred = true;
            break;
          }
        }
        if (!has_live_pred) unreachable_merges.insert(if_merge_block->id());
      }
      original_loop_target = loop_->GetPreHeaderBlock();
      // We tagged dead blocks, prune the loop descriptor from any dead loops.
      // After this call, |loop_| can be nullptr (i.e. the unswitch killed this
      // loop).
      loop_ = CleanLoopNest(loop_, dead_blocks, &unreachable_merges);

      CleanUpCFG(function_->begin(), dead_blocks, unreachable_merges);
    }

    /////////////////////////////////////
    // Finally: connect the new loops. //
    /////////////////////////////////////

    // Delete the old jump
    context_->KillInst(&*if_block->tail());
    InstructionBuilder builder(context_, if_block);
    if (iv_opcode == SpvOpBranchConditional) {
      assert(constant_branch.size() == 1);
      builder.AddConditionalBranch(
          condition->result_id(), original_loop_target->id(),
          constant_branch[0].second->id(),
          if_merge_block ? if_merge_block->id() : kInvalidId);
    } else {
      std::vector<std::pair<Operand::OperandData, uint32_t>> targets;
      for (auto& t : constant_branch) {
        targets.emplace_back(t.first->GetInOperand(0).words, t.second->id());
      }

      builder.AddSwitch(condition->result_id(), original_loop_target->id(),
                        targets,
                        if_merge_block ? if_merge_block->id() : kInvalidId);
    }

    switch_block_ = nullptr;
    ordered_loop_blocks_.clear();

    // Everything except the loop analysis was mutated piecemeal above;
    // invalidate so later passes recompute from scratch.
    context_->InvalidateAnalysesExceptFor(
        IRContext::Analysis::kAnalysisLoopAnalysis);
  }
    443 
    444   // Returns true if the unswitch killed the original |loop_|.
    445   bool WasLoopKilled() const { return loop_ == nullptr; }
    446 
 private:
  // Maps an original id to the id of its clone (see LoopCloningResult usage).
  using ValueMapTy = std::unordered_map<uint32_t, uint32_t>;
  using BlockMapTy = std::unordered_map<uint32_t, BasicBlock*>;

  // Function containing |loop_|.
  Function* function_;
  // Loop being unswitched; set to nullptr when the transformation kills it.
  Loop* loop_;
  // Loop descriptor kept in sync as blocks/loops are created or removed.
  LoopDescriptor& loop_desc_;
  // IR context providing the analyses (CFG, def-use, constants, types, ...).
  IRContext* context_;

  // Block holding the loop-invariant branch/switch to hoist; set by
  // CanUnswitchLoop() and cleared at the end of PerformUnswitch().
  BasicBlock* switch_block_;
  // Map between instructions and if they are dynamically uniform.
  std::unordered_map<uint32_t, bool> dynamically_uniform_;
  // The loop basic blocks in structured order.
  std::vector<BasicBlock*> ordered_loop_blocks_;
    461 
  // Returns the next usable id for the context.
  // NOTE(review): thin wrapper over IRContext::TakeNextId(), presumably kept
  // so the id-overflow TODO has a single place to be addressed.
  uint32_t TakeNextId() {
    // TODO(1841): Handle id overflow.
    return context_->TakeNextId();
  }
    467 
  // Patches |bb|'s phi instruction by removing incoming value from unexisting
  // or tagged as dead branches.
  // If a phi is left with a single incoming edge and |preserve_phi| is false,
  // the phi is replaced by its remaining value and killed. Callers pass
  // |preserve_phi| = true to keep single-operand phis and preserve LCSSA.
  void PatchPhis(BasicBlock* bb,
                 const std::unordered_set<uint32_t>& dead_blocks,
                 bool preserve_phi) {
    CFG& cfg = *context_->cfg();

    std::vector<Instruction*> phi_to_kill;
    const std::vector<uint32_t>& bb_preds = cfg.preds(bb->id());
    // An incoming branch is dead if its block is tagged dead or is no longer a
    // CFG predecessor of |bb|.
    auto is_branch_dead = [&bb_preds, &dead_blocks](uint32_t id) {
      return dead_blocks.count(id) ||
             std::find(bb_preds.begin(), bb_preds.end(), id) == bb_preds.end();
    };
    bb->ForEachPhiInst(
        [&phi_to_kill, &is_branch_dead, preserve_phi, this](Instruction* insn) {
          // Walk the (value, predecessor) operand pairs in place; |i| only
          // advances when the pair is kept, since removal shifts operands down.
          uint32_t i = 0;
          while (i < insn->NumInOperands()) {
            uint32_t incoming_id = insn->GetSingleWordInOperand(i + 1);
            if (is_branch_dead(incoming_id)) {
              // Remove the incoming block id operand.
              insn->RemoveInOperand(i + 1);
              // Remove the definition id operand.
              insn->RemoveInOperand(i);
              continue;
            }
            i += 2;
          }
          // If there is only 1 remaining edge, propagate the value and
          // kill the instruction.
          if (insn->NumInOperands() == 2 && !preserve_phi) {
            phi_to_kill.push_back(insn);
            context_->ReplaceAllUsesWith(insn->result_id(),
                                         insn->GetSingleWordInOperand(0));
          }
        });
    // Kill after the traversal: ForEachPhiInst must not see deletions.
    for (Instruction* insn : phi_to_kill) {
      context_->KillInst(insn);
    }
  }
    507 
  // Removes any block that is tagged as dead, if the block is in
  // |unreachable_merges| then all block's instructions are replaced by a
  // OpUnreachable.
  // Iteration starts at |bb_it| and runs to the end of its underlying block
  // vector; erasing a dead block re-seats the iterator.
  void CleanUpCFG(UptrVectorIterator<BasicBlock> bb_it,
                  const std::unordered_set<uint32_t>& dead_blocks,
                  const std::unordered_set<uint32_t>& unreachable_merges) {
    CFG& cfg = *context_->cfg();

    while (bb_it != bb_it.End()) {
      BasicBlock& bb = *bb_it;

      if (unreachable_merges.count(bb.id())) {
        // Keep the block (merge blocks are structurally required) but gut it,
        // unless it is already just a label + OpUnreachable.
        if (bb.begin() != bb.tail() ||
            bb.terminator()->opcode() != SpvOpUnreachable) {
          // Make unreachable, but leave the label.
          bb.KillAllInsts(false);
          InstructionBuilder(context_, &bb).AddUnreachable();
          cfg.RemoveNonExistingEdges(bb.id());
        }
        ++bb_it;
      } else if (dead_blocks.count(bb.id())) {
        cfg.ForgetBlock(&bb);
        // Kill this block.
        bb.KillAllInsts(true);
        bb_it = bb_it.Erase();
      } else {
        cfg.RemoveNonExistingEdges(bb.id());
        ++bb_it;
      }
    }
  }
    539 
    540   // Return true if |c_inst| is a Boolean constant and set |cond_val| with the
    541   // value that |c_inst|
    542   bool GetConstCondition(const Instruction* c_inst, bool* cond_val) {
    543     bool cond_is_const;
    544     switch (c_inst->opcode()) {
    545       case SpvOpConstantFalse: {
    546         *cond_val = false;
    547         cond_is_const = true;
    548       } break;
    549       case SpvOpConstantTrue: {
    550         *cond_val = true;
    551         cond_is_const = true;
    552       } break;
    553       default: { cond_is_const = false; } break;
    554     }
    555     return cond_is_const;
    556   }
    557 
  // Simplifies |loop| assuming the instruction |to_version_insn| takes the
  // value |cst_value|. |block_range| is an iterator range returning the loop
  // basic blocks in a structured order (dominator first).
  // The function will ignore basic blocks returned by |block_range| if they
  // does not belong to the loop.
  // The set |dead_blocks| will contain all the dead basic blocks.
  //
  // Requirements:
  //   - |loop| must be in the LCSSA form;
  //   - |cst_value| must be constant or null (to represent the default target
  //   of an OpSwitch).
  void SimplifyLoop(IteratorRange<UptrVectorIterator<BasicBlock>> block_range,
                    Loop* loop, Instruction* to_version_insn,
                    Instruction* cst_value,
                    std::unordered_set<uint32_t>* dead_blocks) {
    CFG& cfg = *context_->cfg();
    analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();

    std::function<bool(uint32_t)> ignore_node;
    ignore_node = [loop](uint32_t bb_id) { return !loop->IsInsideLoop(bb_id); };

    // Snapshot the in-loop uses first: mutating instructions while iterating
    // ForEachUse would invalidate the traversal.
    std::vector<std::pair<Instruction*, uint32_t>> use_list;
    def_use_mgr->ForEachUse(to_version_insn,
                            [&use_list, &ignore_node, this](
                                Instruction* inst, uint32_t operand_index) {
                              BasicBlock* bb = context_->get_instr_block(inst);

                              if (!bb || ignore_node(bb->id())) {
                                // Out of the loop, the specialization does not
                                // apply any more.
                                return;
                              }
                              use_list.emplace_back(inst, operand_index);
                            });

    // First pass: inject the specialized value into the loop (and only the
    // loop).
    for (auto use : use_list) {
      Instruction* inst = use.first;
      uint32_t operand_index = use.second;
      BasicBlock* bb = context_->get_instr_block(inst);

      // If it is not a branch, simply inject the value.
      if (!inst->IsBranch()) {
        // To also handle switch, cst_value can be nullptr: this case
        // means that we are looking to branch to the default target of
        // the switch. We don't actually know its value so we don't touch
        // it if it not a switch.
        if (cst_value) {
          inst->SetOperand(operand_index, {cst_value->result_id()});
          def_use_mgr->AnalyzeInstUse(inst);
        }
      }

      // The user is a branch, kill dead branches.
      uint32_t live_target = 0;
      // NOTE(review): |dead_branches| is declared but never used below.
      std::unordered_set<uint32_t> dead_branches;
      switch (inst->opcode()) {
        case SpvOpBranchConditional: {
          assert(cst_value && "No constant value to specialize !");
          bool branch_cond = false;
          if (GetConstCondition(cst_value, &branch_cond)) {
            uint32_t true_label =
                inst->GetSingleWordInOperand(kBranchCondTrueLabIdInIdx);
            uint32_t false_label =
                inst->GetSingleWordInOperand(kBranchCondFalseLabIdInIdx);
            live_target = branch_cond ? true_label : false_label;
            uint32_t dead_target = !branch_cond ? true_label : false_label;
            cfg.RemoveEdge(bb->id(), dead_target);
          }
          break;
        }
        case SpvOpSwitch: {
          // In-operands: selector, default label, then (literal, label) pairs.
          live_target = inst->GetSingleWordInOperand(1);
          if (cst_value) {
            if (!cst_value->IsConstant()) break;
            const Operand& cst = cst_value->GetInOperand(0);
            for (uint32_t i = 2; i < inst->NumInOperands(); i += 2) {
              const Operand& literal = inst->GetInOperand(i);
              if (literal == cst) {
                live_target = inst->GetSingleWordInOperand(i + 1);
                break;
              }
            }
          }
          // Drop every CFG edge except the surviving target.
          for (uint32_t i = 1; i < inst->NumInOperands(); i += 2) {
            uint32_t id = inst->GetSingleWordInOperand(i);
            if (id != live_target) {
              cfg.RemoveEdge(bb->id(), id);
            }
          }
        }
        // NOTE(review): the OpSwitch case falls through to the empty default;
        // harmless, but an explicit break would be clearer.
        default:
          break;
      }
      if (live_target != 0) {
        // Check for the presence of the merge block.
        if (Instruction* merge = bb->GetMergeInst()) context_->KillInst(merge);
        context_->KillInst(&*bb->tail());
        InstructionBuilder builder(context_, bb,
                                   IRContext::kAnalysisDefUse |
                                       IRContext::kAnalysisInstrToBlockMapping);
        builder.AddBranch(live_target);
      }
    }

    // Go through the loop basic block and tag all blocks that are obviously
    // dead.
    std::unordered_set<uint32_t> visited;
    for (BasicBlock& bb : block_range) {
      if (ignore_node(bb.id())) continue;
      visited.insert(bb.id());

      // Check if this block is dead, if so tag it as dead otherwise patch phi
      // instructions.
      bool has_live_pred = false;
      for (uint32_t pid : cfg.preds(bb.id())) {
        if (!dead_blocks->count(pid)) {
          has_live_pred = true;
          break;
        }
      }
      if (!has_live_pred) {
        dead_blocks->insert(bb.id());
        const BasicBlock& cbb = bb;
        // Patch the phis for any back-edge.
        cbb.ForEachSuccessorLabel(
            [dead_blocks, &visited, &cfg, this](uint32_t id) {
              if (!visited.count(id) || dead_blocks->count(id)) return;
              BasicBlock* succ = cfg.block(id);
              PatchPhis(succ, *dead_blocks, false);
            });
        continue;
      }
      // Update the phi instructions, some incoming branch have/will disappear.
      PatchPhis(&bb, *dead_blocks, /* preserve_phi = */ false);
    }
  }
    696 
    697   // Returns true if the header is not reachable or tagged as dead or if we
    698   // never loop back.
    699   bool IsLoopDead(BasicBlock* header, BasicBlock* latch,
    700                   const std::unordered_set<uint32_t>& dead_blocks) {
    701     if (!header || dead_blocks.count(header->id())) return true;
    702     if (!latch || dead_blocks.count(latch->id())) return true;
    703     for (uint32_t pid : context_->cfg()->preds(header->id())) {
    704       if (!dead_blocks.count(pid)) {
    705         // Seems reachable.
    706         return false;
    707       }
    708     }
    709     return true;
    710   }
    711 
    712   // Cleans the loop nest under |loop| and reflect changes to the loop
    713   // descriptor. This will kill all descriptors that represent dead loops.
    714   // If |loop_| is killed, it will be set to nullptr.
    715   // Any merge blocks that become unreachable will be added to
    716   // |unreachable_merges|.
    717   // The function returns the pointer to |loop| or nullptr if the loop was
    718   // killed.
  Loop* CleanLoopNest(Loop* loop,
                      const std::unordered_set<uint32_t>& dead_blocks,
                      std::unordered_set<uint32_t>* unreachable_merges) {
    // This represent the pair of dead loop and nearest alive parent (nullptr if
    // no parent).
    std::unordered_map<Loop*, Loop*> dead_loops;
    // Returns the nearest live ancestor recorded for |l|, or nullptr if |l|
    // has not been visited yet (only already-visited ancestors can be found
    // because TreeDFIterator walks parents before children).
    auto get_parent = [&dead_loops](Loop* l) -> Loop* {
      std::unordered_map<Loop*, Loop*>::iterator it = dead_loops.find(l);
      if (it != dead_loops.end()) return it->second;
      return nullptr;
    };

    bool is_main_loop_dead =
        IsLoopDead(loop->GetHeaderBlock(), loop->GetLatchBlock(), dead_blocks);
    if (is_main_loop_dead) {
      // The loop structure is gone; remove its OpLoopMerge so the header
      // becomes a plain block.
      if (Instruction* merge = loop->GetHeaderBlock()->GetLoopMergeInst()) {
        context_->KillInst(merge);
      }
      dead_loops[loop] = loop->GetParent();
    } else {
      // Map the live root to itself so children of dead sub-loops re-parent
      // onto it via get_parent.
      dead_loops[loop] = loop;
    }

    // For each loop, check if we killed it. If we did, find a suitable parent
    // for its children.
    for (Loop& sub_loop :
         make_range(++TreeDFIterator<Loop>(loop), TreeDFIterator<Loop>())) {
      if (IsLoopDead(sub_loop.GetHeaderBlock(), sub_loop.GetLatchBlock(),
                     dead_blocks)) {
        if (Instruction* merge =
                sub_loop.GetHeaderBlock()->GetLoopMergeInst()) {
          context_->KillInst(merge);
        }
        dead_loops[&sub_loop] = get_parent(&sub_loop);
      } else {
        // The loop is alive, check if its merge block is dead, if it is, tag it
        // as required.
        if (sub_loop.GetMergeBlock()) {
          uint32_t merge_id = sub_loop.GetMergeBlock()->id();
          if (dead_blocks.count(merge_id)) {
            unreachable_merges->insert(sub_loop.GetMergeBlock()->id());
          }
        }
      }
    }
    // The root was only inserted as a parent sentinel; do not kill it below.
    if (!is_main_loop_dead) dead_loops.erase(loop);

    // Remove dead blocks from live loops.
    for (uint32_t bb_id : dead_blocks) {
      Loop* l = loop_desc_[bb_id];
      if (l) {
        l->RemoveBasicBlock(bb_id);
        loop_desc_.ForgetBasicBlock(bb_id);
      }
    }

    // Kill every dead loop descriptor; clear |loop| if it was among them so
    // the caller knows the unswitched loop no longer exists.
    std::for_each(
        dead_loops.begin(), dead_loops.end(),
        [&loop,
         this](std::unordered_map<Loop*, Loop*>::iterator::reference it) {
          if (it.first == loop) loop = nullptr;
          loop_desc_.RemoveLoop(it.first);
        });

    return loop;
  }
    785 
    786   // Returns true if |var| is dynamically uniform.
    787   // Note: this is currently approximated as uniform.
  // Returns true if |var| is dynamically uniform.
  // Note: this is currently approximated as uniform.
  bool IsDynamicallyUniform(Instruction* var, const BasicBlock* entry,
                            const DominatorTree& post_dom_tree) {
    assert(post_dom_tree.IsPostDominator());
    analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();

    // Memoized result: each id is analyzed at most once per unswitch.
    auto it = dynamically_uniform_.find(var->result_id());

    if (it != dynamically_uniform_.end()) return it->second;

    analysis::DecorationManager* dec_mgr = context_->get_decoration_mgr();

    // Insert the cache entry up front (default false). Keeping a reference is
    // safe: std::unordered_map never invalidates references on insertion, so
    // the recursive calls below cannot dangle it. Pre-seeding with false also
    // terminates any recursion cycle conservatively.
    bool& is_uniform = dynamically_uniform_[var->result_id()];
    is_uniform = false;

    // An explicit Uniform decoration settles the question immediately.
    dec_mgr->WhileEachDecoration(var->result_id(), SpvDecorationUniform,
                                 [&is_uniform](const Instruction&) {
                                   is_uniform = true;
                                   return false;
                                 });
    if (is_uniform) {
      return is_uniform;
    }

    BasicBlock* parent = context_->get_instr_block(var);
    if (!parent) {
      // No parent block: |var| is a module-scope value (e.g. a constant),
      // which is trivially uniform.
      return is_uniform = true;
    }

    // The defining block must post-dominate the entry, i.e. it executes on
    // every path, otherwise different invocations may see different values.
    if (!post_dom_tree.Dominates(parent->id(), entry->id())) {
      return is_uniform = false;
    }
    if (var->opcode() == SpvOpLoad) {
      // A load is uniform only if it reads Uniform/UniformConstant storage.
      const uint32_t PtrTypeId =
          def_use_mgr->GetDef(var->GetSingleWordInOperand(0))->type_id();
      const Instruction* PtrTypeInst = def_use_mgr->GetDef(PtrTypeId);
      uint32_t storage_class =
          PtrTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx);
      if (storage_class != SpvStorageClassUniform &&
          storage_class != SpvStorageClassUniformConstant) {
        return is_uniform = false;
      }
    } else {
      // Non-load instructions must be pure combinators (no side effects) for
      // uniformity of inputs to imply uniformity of the result.
      if (!context_->IsCombinatorInstruction(var)) {
        return is_uniform = false;
      }
    }

    // The result is uniform iff every operand is dynamically uniform.
    return is_uniform = var->WhileEachInId([entry, &post_dom_tree,
                                            this](const uint32_t* id) {
      return IsDynamicallyUniform(context_->get_def_use_mgr()->GetDef(*id),
                                  entry, post_dom_tree);
    });
  }
    841 
    842   // Returns true if |insn| is constant and dynamically uniform within the loop.
    843   bool IsConditionLoopInvariant(Instruction* insn) {
    844     assert(insn->IsBranch());
    845     assert(insn->opcode() != SpvOpBranch);
    846     analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
    847 
    848     Instruction* condition = def_use_mgr->GetDef(insn->GetOperand(0).words[0]);
    849     return !loop_->IsInsideLoop(condition) &&
    850            IsDynamicallyUniform(
    851                condition, function_->entry().get(),
    852                context_->GetPostDominatorAnalysis(function_)->GetDomTree());
    853   }
    854 };
    855 
    856 }  // namespace
    857 
    858 Pass::Status LoopUnswitchPass::Process() {
    859   bool modified = false;
    860   Module* module = context()->module();
    861 
    862   // Process each function in the module
    863   for (Function& f : *module) {
    864     modified |= ProcessFunction(&f);
    865   }
    866 
    867   return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
    868 }
    869 
// Unswitches every unswitchable loop of |f| to a fixed point. Returns true if
// the function was modified.
bool LoopUnswitchPass::ProcessFunction(Function* f) {
  bool modified = false;
  // Loops already handled in a previous sweep; unswitching creates new loops,
  // so the descriptor is re-walked until no loop changes.
  std::unordered_set<Loop*> processed_loop;

  LoopDescriptor& loop_descriptor = *context()->GetLoopDescriptor(f);

  bool loop_changed = true;
  while (loop_changed) {
    loop_changed = false;
    // Depth-first walk of the loop forest, skipping the dummy root.
    for (Loop& loop :
         make_range(++TreeDFIterator<Loop>(loop_descriptor.GetDummyRootLoop()),
                    TreeDFIterator<Loop>())) {
      if (processed_loop.count(&loop)) continue;
      processed_loop.insert(&loop);

      LoopUnswitch unswitcher(context(), f, &loop, &loop_descriptor);
      // Keep unswitching this loop until it is killed or has no remaining
      // invariant condition.
      while (!unswitcher.WasLoopKilled() && unswitcher.CanUnswitchLoop()) {
        // Unswitching requires the loop to be in loop-closed SSA form.
        if (!loop.IsLCSSA()) {
          LoopUtils(context(), &loop).MakeLoopClosedSSA();
        }
        modified = true;
        loop_changed = true;
        unswitcher.PerformUnswitch();
      }
      // Unswitching restructured the loop forest: the iterator over
      // |loop_descriptor| is stale, so restart the outer sweep.
      if (loop_changed) break;
    }
  }

  return modified;
}
    900 
    901 }  // namespace opt
    902 }  // namespace spvtools
    903