1 // Copyright (c) 2018 Google LLC. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #include "source/opt/licm_pass.h" 16 17 #include <queue> 18 #include <utility> 19 20 #include "source/opt/module.h" 21 #include "source/opt/pass.h" 22 23 namespace spvtools { 24 namespace opt { 25 26 Pass::Status LICMPass::Process() { return ProcessIRContext(); } 27 28 Pass::Status LICMPass::ProcessIRContext() { 29 Status status = Status::SuccessWithoutChange; 30 Module* module = get_module(); 31 32 // Process each function in the module 33 for (auto func = module->begin(); 34 func != module->end() && status != Status::Failure; ++func) { 35 status = CombineStatus(status, ProcessFunction(&*func)); 36 } 37 return status; 38 } 39 40 Pass::Status LICMPass::ProcessFunction(Function* f) { 41 Status status = Status::SuccessWithoutChange; 42 LoopDescriptor* loop_descriptor = context()->GetLoopDescriptor(f); 43 44 // Process each loop in the function 45 for (auto it = loop_descriptor->begin(); 46 it != loop_descriptor->end() && status != Status::Failure; ++it) { 47 Loop& loop = *it; 48 // Ignore nested loops, as we will process them in order in ProcessLoop 49 if (loop.IsNested()) { 50 continue; 51 } 52 status = CombineStatus(status, ProcessLoop(&loop, f)); 53 } 54 return status; 55 } 56 57 Pass::Status LICMPass::ProcessLoop(Loop* loop, Function* f) { 58 Status status = Status::SuccessWithoutChange; 59 60 // Process all nested loops first 61 for (auto nl = loop->begin(); nl != loop->end() && status != Status::Failure; 62 ++nl) { 63 Loop* nested_loop = *nl; 64 status = CombineStatus(status, ProcessLoop(nested_loop, f)); 65 } 66 67 std::vector<BasicBlock*> loop_bbs{}; 68 status = CombineStatus( 69 status, 70 AnalyseAndHoistFromBB(loop, f, loop->GetHeaderBlock(), &loop_bbs)); 71 72 for (size_t i = 0; i < loop_bbs.size() && status != Status::Failure; ++i) { 73 BasicBlock* bb = loop_bbs[i]; 74 // do not delete the element 75 status = 76 CombineStatus(status, AnalyseAndHoistFromBB(loop, f, bb, &loop_bbs)); 77 } 78 79 return status; 80 } 81 82 Pass::Status LICMPass::AnalyseAndHoistFromBB( 83 Loop* loop, Function* f, BasicBlock* bb, 84 std::vector<BasicBlock*>* loop_bbs) { 85 bool modified = false; 86 std::function<bool(Instruction*)> hoist_inst = 87 [this, &loop, &modified](Instruction* inst) { 88 if (loop->ShouldHoistInstruction(this->context(), inst)) { 89 if (!HoistInstruction(loop, inst)) { 90 return false; 91 } 92 modified = true; 93 } 94 return true; 95 }; 96 97 if (IsImmediatelyContainedInLoop(loop, f, bb)) { 98 if (!bb->WhileEachInst(hoist_inst, false)) { 99 return Status::Failure; 100 } 101 } 102 103 DominatorAnalysis* dom_analysis = context()->GetDominatorAnalysis(f); 104 DominatorTree& dom_tree = dom_analysis->GetDomTree(); 105 106 for (DominatorTreeNode* child_dom_tree_node : *dom_tree.GetTreeNode(bb)) { 107 if (loop->IsInsideLoop(child_dom_tree_node->bb_)) { 108 loop_bbs->push_back(child_dom_tree_node->bb_); 109 } 110 } 111 112 return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); 113 } 114 115 bool LICMPass::IsImmediatelyContainedInLoop(Loop* loop, Function* f, 116 BasicBlock* bb) { 117 LoopDescriptor* loop_descriptor = context()->GetLoopDescriptor(f); 118 return loop == (*loop_descriptor)[bb->id()]; 119 } 120 121 bool LICMPass::HoistInstruction(Loop* loop, Instruction* inst) { 122 // TODO(1841): Handle failure to create pre-header. 123 BasicBlock* pre_header_bb = loop->GetOrCreatePreHeaderBlock(); 124 if (!pre_header_bb) { 125 return false; 126 } 127 Instruction* insertion_point = &*pre_header_bb->tail(); 128 Instruction* previous_node = insertion_point->PreviousNode(); 129 if (previous_node && (previous_node->opcode() == SpvOpLoopMerge || 130 previous_node->opcode() == SpvOpSelectionMerge)) { 131 insertion_point = previous_node; 132 } 133 134 inst->InsertBefore(insertion_point); 135 context()->set_instr_block(inst, pre_header_bb); 136 return true; 137 } 138 139 } // namespace opt 140 } // namespace spvtools 141