// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "code_sink.h"

#include <set>
#include <unordered_set>
#include <vector>

#include "source/opt/instruction.h"
#include "source/opt/ir_builder.h"
#include "source/opt/ir_context.h"
#include "source/util/bit_vector.h"

namespace spvtools {
namespace opt {

Pass::Status CodeSinkingPass::Process() {
  bool modified = false;
  for (Function& function : *get_module()) {
    cfg()->ForEachBlockInPostOrder(function.entry().get(),
                                   [&modified, this](BasicBlock* bb) {
                                     if (SinkInstructionsInBB(bb)) {
                                       modified = true;
                                     }
                                   });
  }
  return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
}

bool CodeSinkingPass::SinkInstructionsInBB(BasicBlock* bb) {
  bool modified = false;
  for (auto inst = bb->rbegin(); inst != bb->rend(); ++inst) {
    if (SinkInstruction(&*inst)) {
      // Sinking moves the instruction into another block, which invalidates
      // the reverse iterator, so restart the scan from the end of the block.
      inst = bb->rbegin();
      modified = true;
    }
  }
  return modified;
}

bool CodeSinkingPass::SinkInstruction(Instruction* inst) {
  // Only loads and access chains are candidates for sinking.
  if (inst->opcode() != SpvOpLoad && inst->opcode() != SpvOpAccessChain) {
    return false;
  }

  if (ReferencesMutableMemory(inst)) {
    return false;
  }

  if (BasicBlock* target_bb = FindNewBasicBlockFor(inst)) {
    // Insert |inst| after the OpPhi instructions at the top of |target_bb|.
    Instruction* pos = &*target_bb->begin();
    while (pos->opcode() == SpvOpPhi) {
      pos = pos->NextNode();
    }

    inst->InsertBefore(pos);
    context()->set_instr_block(inst, target_bb);
    return true;
  }
  return false;
}

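// Returns a basic block, dominated by the block that currently contains
// |inst|, into which |inst| can be sunk without being executed more often and
// while still dominating all of its uses.  Returns nullptr if there is no
// better block than the one |inst| is already in.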
BasicBlock* CodeSinkingPass::FindNewBasicBlockFor(Instruction* inst) {
  assert(inst->result_id() != 0 && "Instruction should have a result.");
  BasicBlock* original_bb = context()->get_instr_block(inst);
  BasicBlock* bb = original_bb;

  // Collect the ids of the basic blocks that contain a use of |inst|.  For a
  // use in an OpPhi, the relevant block is the predecessor associated with
  // the incoming value, not the block containing the OpPhi itself.
  std::unordered_set<uint32_t> bbs_with_uses;
  get_def_use_mgr()->ForEachUse(
      inst, [&bbs_with_uses, this](Instruction* use, uint32_t idx) {
        if (use->opcode() != SpvOpPhi) {
          bbs_with_uses.insert(context()->get_instr_block(use)->id());
        } else {
          bbs_with_uses.insert(use->GetSingleWordOperand(idx + 1));
        }
      });

  while (true) {
    // If |inst| is used in |bb|, then |inst| cannot be moved any further.
    if (bbs_with_uses.count(bb->id())) {
      break;
    }

    // If |bb| has one successor (succ_bb), and |bb| is the only predecessor
    // of succ_bb, then |inst| can be moved to succ_bb.  If succ_bb has more
    // than one predecessor, then moving |inst| into succ_bb could cause it to
    // be executed more often, so the search has to stop.
    if (bb->terminator()->opcode() == SpvOpBranch) {
      uint32_t succ_bb_id = bb->terminator()->GetSingleWordInOperand(0);
      if (cfg()->preds(succ_bb_id).size() == 1) {
        bb = context()->get_instr_block(succ_bb_id);
        continue;
      } else {
        break;
      }
    }

    // The remaining checks need to know the merge node.  If there is no merge
    // instruction, or the merge is an OpLoopMerge, then the branch is a break
    // or continue.  We could figure out where to sink in that case, but it is
    // not worth doing now.
    Instruction* merge_inst = bb->GetMergeInst();
    if (merge_inst == nullptr || merge_inst->opcode() != SpvOpSelectionMerge) {
      break;
    }

    // Check all of the successors of |bb| to see which lead to a use of
    // |inst| before reaching the merge node.
    bool used_in_multiple_blocks = false;
    uint32_t bb_used_in = 0;
    bb->ForEachSuccessorLabel([this, bb, &bb_used_in, &used_in_multiple_blocks,
                               &bbs_with_uses](uint32_t* succ_bb_id) {
      if (IntersectsPath(*succ_bb_id, bb->MergeBlockIdIfAny(), bbs_with_uses)) {
        if (bb_used_in == 0) {
          bb_used_in = *succ_bb_id;
        } else {
          used_in_multiple_blocks = true;
        }
      }
    });

    // If paths through more than one successor lead to a use of |inst|, then
    // |inst| has to stay in |bb| because none of the successors dominates all
    // of the uses of |inst|.
    if (used_in_multiple_blocks) {
      break;
    }

    if (bb_used_in == 0) {
      // If |inst| is not used before reaching the merge node, then we can move
      // |inst| to the merge node.
      bb = context()->get_instr_block(bb->MergeBlockIdIfAny());
    } else {
      // If the only successor that leads to a use of |inst| has more than 1
      // predecessor, then moving |inst| could cause it to be executed more
      // often, so we cannot move it.
      if (cfg()->preds(bb_used_in).size() != 1) {
        break;
      }

      // If |inst| is used after the merge block, then |bb_used_in| does not
      // dominate all of the uses, so we cannot move |inst| any further.
      if (IntersectsPath(bb->MergeBlockIdIfAny(), original_bb->id(),
                         bbs_with_uses)) {
        break;
      }

      // Otherwise, |bb_used_in| dominates all uses, so move |inst| into that
      // block.
      bb = context()->get_instr_block(bb_used_in);
    }
    continue;
  }
  return (bb != original_bb ? bb : nullptr);
}

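// Returns true if |inst| is a load whose source memory could be modified:
// that is, anything other than a load from a read-only variable or from a
// Uniform-storage variable that is never stored to and never subject to
// uniform-memory synchronization.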
bool CodeSinkingPass::ReferencesMutableMemory(Instruction* inst) {
  if (!inst->IsLoad()) {
    return false;
  }

  Instruction* base_ptr = inst->GetBaseAddress();
  if (base_ptr->opcode() != SpvOpVariable) {
    return true;
  }

  if (base_ptr->IsReadOnlyVariable()) {
    return false;
  }

  if (HasUniformMemorySync()) {
    return true;
  }

  if (base_ptr->GetSingleWordInOperand(0) != SpvStorageClassUniform) {
    return true;
  }

  return HasPossibleStore(base_ptr);
}

bool CodeSinkingPass::HasUniformMemorySync() {
  if (checked_for_uniform_sync_) {
    return has_uniform_sync_;
  }

  // Scan the whole module once for any barrier or atomic whose memory
  // semantics synchronize uniform memory, and cache the result.
  bool has_sync = false;
  get_module()->ForEachInst([this, &has_sync](Instruction* inst) {
    switch (inst->opcode()) {
      case SpvOpMemoryBarrier: {
        uint32_t mem_semantics_id = inst->GetSingleWordInOperand(1);
        if (IsSyncOnUniform(mem_semantics_id)) {
          has_sync = true;
        }
        break;
      }
      case SpvOpControlBarrier:
      case SpvOpAtomicLoad:
      case SpvOpAtomicStore:
      case SpvOpAtomicExchange:
      case SpvOpAtomicIIncrement:
      case SpvOpAtomicIDecrement:
      case SpvOpAtomicIAdd:
      case SpvOpAtomicISub:
      case SpvOpAtomicSMin:
      case SpvOpAtomicUMin:
      case SpvOpAtomicSMax:
      case SpvOpAtomicUMax:
      case SpvOpAtomicAnd:
      case SpvOpAtomicOr:
      case SpvOpAtomicXor:
      case SpvOpAtomicFlagTestAndSet:
      case SpvOpAtomicFlagClear: {
        uint32_t mem_semantics_id = inst->GetSingleWordInOperand(2);
        if (IsSyncOnUniform(mem_semantics_id)) {
          has_sync = true;
        }
        break;
      }
      case SpvOpAtomicCompareExchange:
      case SpvOpAtomicCompareExchangeWeak:
        if (IsSyncOnUniform(inst->GetSingleWordInOperand(2)) ||
            IsSyncOnUniform(inst->GetSingleWordInOperand(3))) {
          has_sync = true;
        }
        break;
      default:
        break;
    }
  });
  has_uniform_sync_ = has_sync;
  checked_for_uniform_sync_ = true;
  return has_sync;
}

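// Returns true if the memory semantics value defined by the constant with id
// |mem_semantics_id| applies to UniformMemory and includes Acquire, Release,
// or AcquireRelease.  For example, using the SPIR-V mask values, semantics of
// 0x42 (UniformMemory | Acquire) return true, while 0x40 (UniformMemory
// alone) or 0x2 (Acquire alone) return false.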
bool CodeSinkingPass::IsSyncOnUniform(uint32_t mem_semantics_id) const {
  const analysis::Constant* mem_semantics_const =
      context()->get_constant_mgr()->FindDeclaredConstant(mem_semantics_id);
  assert(mem_semantics_const != nullptr &&
         "Expecting memory semantics id to be a constant.");
  assert(mem_semantics_const->AsIntConstant() &&
         "Memory semantics should be an integer.");
  uint32_t mem_semantics_int = mem_semantics_const->GetU32();

  // If the semantics do not include uniform memory, then they do not apply to
  // uniform memory.
  if ((mem_semantics_int & SpvMemorySemanticsUniformMemoryMask) == 0) {
    return false;
  }

  // Check if there is an acquire or release.  If not, the operation does not
  // add any memory-ordering constraints.
  return (mem_semantics_int & (SpvMemorySemanticsAcquireMask |
                               SpvMemorySemanticsAcquireReleaseMask |
                               SpvMemorySemanticsReleaseMask)) != 0;
}

bool CodeSinkingPass::HasPossibleStore(Instruction* var_inst) {
  assert(var_inst->opcode() == SpvOpVariable ||
         var_inst->opcode() == SpvOpAccessChain ||
         var_inst->opcode() == SpvOpPtrAccessChain);

  // Return true if any user of |var_inst|, directly or through an access
  // chain, is a store.  WhileEachUser stops and returns false as soon as the
  // callback returns false, so the callback returns false exactly when a
  // possible store has been found.
  return !get_def_use_mgr()->WhileEachUser(var_inst, [this](Instruction* use) {
    switch (use->opcode()) {
      case SpvOpStore:
        return false;
      case SpvOpAccessChain:
      case SpvOpPtrAccessChain:
        return !HasPossibleStore(use);
      default:
        return true;
    }
  });
}

bool CodeSinkingPass::IntersectsPath(uint32_t start, uint32_t end,
                                     const std::unordered_set<uint32_t>& set) {
  std::vector<uint32_t> worklist;
  worklist.push_back(start);
  std::unordered_set<uint32_t> already_done;
  already_done.insert(start);

  while (!worklist.empty()) {
    BasicBlock* bb = context()->get_instr_block(worklist.back());
    worklist.pop_back();

    if (bb->id() == end) {
      continue;
    }

    if (set.count(bb->id())) {
      return true;
    }

    bb->ForEachSuccessorLabel([&already_done, &worklist](uint32_t* succ_bb_id) {
      if (already_done.insert(*succ_bb_id).second) {
        worklist.push_back(*succ_bb_id);
      }
    });
  }
  return false;
}

}  // namespace opt
}  // namespace spvtools
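
// Usage sketch (not part of the pass): this pass is normally driven through
// the public optimizer API rather than instantiated directly.  The snippet
// below is a minimal sketch assuming the standard "spirv-tools/optimizer.hpp"
// header exposes the pass as CreateCodeSinkingPass(), and that the caller
// supplies a valid SPIR-V binary; the helper name RunCodeSinking is
// hypothetical.
//
//   #include <cstdint>
//   #include <vector>
//   #include "spirv-tools/optimizer.hpp"
//
//   // Runs only the code-sinking pass on |binary|; writes the result to
//   // |optimized| and returns true on success.
//   bool RunCodeSinking(const std::vector<uint32_t>& binary,
//                       std::vector<uint32_t>* optimized) {
//     spvtools::Optimizer opt(SPV_ENV_UNIVERSAL_1_3);
//     opt.RegisterPass(spvtools::CreateCodeSinkingPass());
//     return opt.Run(binary.data(), binary.size(), optimized);
//   }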