Home | History | Annotate | Download | only in opt
      1 // Copyright (c) 2019 Google LLC
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 #include "code_sink.h"
     16 
     17 #include <set>
     18 #include <vector>
     19 
     20 #include "source/opt/instruction.h"
     21 #include "source/opt/ir_builder.h"
     22 #include "source/opt/ir_context.h"
     23 #include "source/util/bit_vector.h"
     24 
     25 namespace spvtools {
     26 namespace opt {
     27 
     28 Pass::Status CodeSinkingPass::Process() {
     29   bool modified = false;
     30   for (Function& function : *get_module()) {
     31     cfg()->ForEachBlockInPostOrder(function.entry().get(),
     32                                    [&modified, this](BasicBlock* bb) {
     33                                      if (SinkInstructionsInBB(bb)) {
     34                                        modified = true;
     35                                      }
     36                                    });
     37   }
     38   return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
     39 }
     40 
     41 bool CodeSinkingPass::SinkInstructionsInBB(BasicBlock* bb) {
     42   bool modified = false;
     43   for (auto inst = bb->rbegin(); inst != bb->rend(); ++inst) {
     44     if (SinkInstruction(&*inst)) {
     45       inst = bb->rbegin();
     46       modified = true;
     47     }
     48   }
     49   return modified;
     50 }
     51 
     52 bool CodeSinkingPass::SinkInstruction(Instruction* inst) {
     53   if (inst->opcode() != SpvOpLoad && inst->opcode() != SpvOpAccessChain) {
     54     return false;
     55   }
     56 
     57   if (ReferencesMutableMemory(inst)) {
     58     return false;
     59   }
     60 
     61   if (BasicBlock* target_bb = FindNewBasicBlockFor(inst)) {
     62     Instruction* pos = &*target_bb->begin();
     63     while (pos->opcode() == SpvOpPhi) {
     64       pos = pos->NextNode();
     65     }
     66 
     67     inst->InsertBefore(pos);
     68     context()->set_instr_block(inst, target_bb);
     69     return true;
     70   }
     71   return false;
     72 }
     73 
     74 BasicBlock* CodeSinkingPass::FindNewBasicBlockFor(Instruction* inst) {
     75   assert(inst->result_id() != 0 && "Instruction should have a result.");
     76   BasicBlock* original_bb = context()->get_instr_block(inst);
     77   BasicBlock* bb = original_bb;
     78 
     79   std::unordered_set<uint32_t> bbs_with_uses;
     80   get_def_use_mgr()->ForEachUse(
     81       inst, [&bbs_with_uses, this](Instruction* use, uint32_t idx) {
     82         if (use->opcode() != SpvOpPhi) {
     83           bbs_with_uses.insert(context()->get_instr_block(use)->id());
     84         } else {
     85           bbs_with_uses.insert(use->GetSingleWordOperand(idx + 1));
     86         }
     87       });
     88 
     89   while (true) {
     90     // If |inst| is used in |bb|, then |inst| cannot be moved any further.
     91     if (bbs_with_uses.count(bb->id())) {
     92       break;
     93     }
     94 
     95     // If |bb| has one successor (succ_bb), and |bb| is the only predecessor
     96     // of succ_bb, then |inst| can be moved to succ_bb.  If succ_bb, has move
     97     // then one predecessor, then moving |inst| into succ_bb could cause it to
     98     // be executed more often, so the search has to stop.
     99     if (bb->terminator()->opcode() == SpvOpBranch) {
    100       uint32_t succ_bb_id = bb->terminator()->GetSingleWordInOperand(0);
    101       if (cfg()->preds(succ_bb_id).size() == 1) {
    102         bb = context()->get_instr_block(succ_bb_id);
    103         continue;
    104       } else {
    105         break;
    106       }
    107     }
    108 
    109     // The remaining checks need to know the merge node.  If there is no merge
    110     // instruction or an OpLoopMerge, then it is a break or continue.  We could
    111     // figure it out, but not worth doing it now.
    112     Instruction* merge_inst = bb->GetMergeInst();
    113     if (merge_inst == nullptr || merge_inst->opcode() != SpvOpSelectionMerge) {
    114       break;
    115     }
    116 
    117     // Check all of the successors of |bb| it see which lead to a use of |inst|
    118     // before reaching the merge node.
    119     bool used_in_multiple_blocks = false;
    120     uint32_t bb_used_in = 0;
    121     bb->ForEachSuccessorLabel([this, bb, &bb_used_in, &used_in_multiple_blocks,
    122                                &bbs_with_uses](uint32_t* succ_bb_id) {
    123       if (IntersectsPath(*succ_bb_id, bb->MergeBlockIdIfAny(), bbs_with_uses)) {
    124         if (bb_used_in == 0) {
    125           bb_used_in = *succ_bb_id;
    126         } else {
    127           used_in_multiple_blocks = true;
    128         }
    129       }
    130     });
    131 
    132     // If more than one successor, which is not the merge block, uses |inst|
    133     // then we have to leave |inst| in bb because there is none of the
    134     // successors dominate all uses of |inst|.
    135     if (used_in_multiple_blocks) {
    136       break;
    137     }
    138 
    139     if (bb_used_in == 0) {
    140       // If |inst| is not used before reaching the merge node, then we can move
    141       // |inst| to the merge node.
    142       bb = context()->get_instr_block(bb->MergeBlockIdIfAny());
    143     } else {
    144       // If the only successor that leads to a used of |inst| has more than 1
    145       // predecessor, then moving |inst| could cause it to be executed more
    146       // often, so we cannot move it.
    147       if (cfg()->preds(bb_used_in).size() != 1) {
    148         break;
    149       }
    150 
    151       // If |inst| is used after the merge block, then |bb_used_in| does not
    152       // dominate all of the uses.  So we cannot move |inst| any further.
    153       if (IntersectsPath(bb->MergeBlockIdIfAny(), original_bb->id(),
    154                          bbs_with_uses)) {
    155         break;
    156       }
    157 
    158       // Otherwise, |bb_used_in| dominates all uses, so move |inst| into that
    159       // block.
    160       bb = context()->get_instr_block(bb_used_in);
    161     }
    162     continue;
    163   }
    164   return (bb != original_bb ? bb : nullptr);
    165 }
    166 
    167 bool CodeSinkingPass::ReferencesMutableMemory(Instruction* inst) {
    168   if (!inst->IsLoad()) {
    169     return false;
    170   }
    171 
    172   Instruction* base_ptr = inst->GetBaseAddress();
    173   if (base_ptr->opcode() != SpvOpVariable) {
    174     return true;
    175   }
    176 
    177   if (base_ptr->IsReadOnlyVariable()) {
    178     return false;
    179   }
    180 
    181   if (HasUniformMemorySync()) {
    182     return true;
    183   }
    184 
    185   if (base_ptr->GetSingleWordInOperand(0) != SpvStorageClassUniform) {
    186     return true;
    187   }
    188 
    189   return HasPossibleStore(base_ptr);
    190 }
    191 
    192 bool CodeSinkingPass::HasUniformMemorySync() {
    193   if (checked_for_uniform_sync_) {
    194     return has_uniform_sync_;
    195   }
    196 
    197   bool has_sync = false;
    198   get_module()->ForEachInst([this, &has_sync](Instruction* inst) {
    199     switch (inst->opcode()) {
    200       case SpvOpMemoryBarrier: {
    201         uint32_t mem_semantics_id = inst->GetSingleWordInOperand(1);
    202         if (IsSyncOnUniform(mem_semantics_id)) {
    203           has_sync = true;
    204         }
    205         break;
    206       }
    207       case SpvOpControlBarrier:
    208       case SpvOpAtomicLoad:
    209       case SpvOpAtomicStore:
    210       case SpvOpAtomicExchange:
    211       case SpvOpAtomicIIncrement:
    212       case SpvOpAtomicIDecrement:
    213       case SpvOpAtomicIAdd:
    214       case SpvOpAtomicISub:
    215       case SpvOpAtomicSMin:
    216       case SpvOpAtomicUMin:
    217       case SpvOpAtomicSMax:
    218       case SpvOpAtomicUMax:
    219       case SpvOpAtomicAnd:
    220       case SpvOpAtomicOr:
    221       case SpvOpAtomicXor:
    222       case SpvOpAtomicFlagTestAndSet:
    223       case SpvOpAtomicFlagClear: {
    224         uint32_t mem_semantics_id = inst->GetSingleWordInOperand(2);
    225         if (IsSyncOnUniform(mem_semantics_id)) {
    226           has_sync = true;
    227         }
    228         break;
    229       }
    230       case SpvOpAtomicCompareExchange:
    231       case SpvOpAtomicCompareExchangeWeak:
    232         if (IsSyncOnUniform(inst->GetSingleWordInOperand(2)) ||
    233             IsSyncOnUniform(inst->GetSingleWordInOperand(3))) {
    234           has_sync = true;
    235         }
    236         break;
    237       default:
    238         break;
    239     }
    240   });
    241   has_uniform_sync_ = has_sync;
    242   return has_sync;
    243 }
    244 
    245 bool CodeSinkingPass::IsSyncOnUniform(uint32_t mem_semantics_id) const {
    246   const analysis::Constant* mem_semantics_const =
    247       context()->get_constant_mgr()->FindDeclaredConstant(mem_semantics_id);
    248   assert(mem_semantics_const != nullptr &&
    249          "Expecting memory semantics id to be a constant.");
    250   assert(mem_semantics_const->AsIntConstant() &&
    251          "Memory semantics should be an integer.");
    252   uint32_t mem_semantics_int = mem_semantics_const->GetU32();
    253 
    254   // If it does not affect uniform memory, then it is does not apply to uniform
    255   // memory.
    256   if ((mem_semantics_int & SpvMemorySemanticsUniformMemoryMask) == 0) {
    257     return false;
    258   }
    259 
    260   // Check if there is an acquire or release.  If so not, this it does not add
    261   // any memory constraints.
    262   return (mem_semantics_int & (SpvMemorySemanticsAcquireMask |
    263                                SpvMemorySemanticsAcquireReleaseMask |
    264                                SpvMemorySemanticsReleaseMask)) != 0;
    265 }
    266 
    267 bool CodeSinkingPass::HasPossibleStore(Instruction* var_inst) {
    268   assert(var_inst->opcode() == SpvOpVariable ||
    269          var_inst->opcode() == SpvOpAccessChain ||
    270          var_inst->opcode() == SpvOpPtrAccessChain);
    271 
    272   return get_def_use_mgr()->WhileEachUser(var_inst, [this](Instruction* use) {
    273     switch (use->opcode()) {
    274       case SpvOpStore:
    275         return true;
    276       case SpvOpAccessChain:
    277       case SpvOpPtrAccessChain:
    278         return HasPossibleStore(use);
    279       default:
    280         return false;
    281     }
    282   });
    283 }
    284 
    285 bool CodeSinkingPass::IntersectsPath(uint32_t start, uint32_t end,
    286                                      const std::unordered_set<uint32_t>& set) {
    287   std::vector<uint32_t> worklist;
    288   worklist.push_back(start);
    289   std::unordered_set<uint32_t> already_done;
    290   already_done.insert(start);
    291 
    292   while (!worklist.empty()) {
    293     BasicBlock* bb = context()->get_instr_block(worklist.back());
    294     worklist.pop_back();
    295 
    296     if (bb->id() == end) {
    297       continue;
    298     }
    299 
    300     if (set.count(bb->id())) {
    301       return true;
    302     }
    303 
    304     bb->ForEachSuccessorLabel([&already_done, &worklist](uint32_t* succ_bb_id) {
    305       if (already_done.insert(*succ_bb_id).second) {
    306         worklist.push_back(*succ_bb_id);
    307       }
    308     });
    309   }
    310   return false;
    311 }
    312 
    313 // namespace opt
    314 
    315 }  // namespace opt
    316 }  // namespace spvtools
    317