Home | History | Annotate | Download | only in Utils
      1 //===- LoopVersioning.cpp - Utility to version a loop ---------------------===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 // This file defines a utility class to perform loop versioning.  The versioned
     11 // loop speculates that otherwise may-aliasing memory accesses don't overlap and
     12 // emits checks to prove this.
     13 //
     14 //===----------------------------------------------------------------------===//
     15 
     16 #include "llvm/Transforms/Utils/LoopVersioning.h"
     17 #include "llvm/Analysis/LoopAccessAnalysis.h"
     18 #include "llvm/Analysis/LoopInfo.h"
     19 #include "llvm/Analysis/ScalarEvolutionExpander.h"
     20 #include "llvm/IR/Dominators.h"
     21 #include "llvm/IR/MDBuilder.h"
     22 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
     23 #include "llvm/Transforms/Utils/Cloning.h"
     24 
     25 using namespace llvm;
     26 
     27 static cl::opt<bool>
     28     AnnotateNoAlias("loop-version-annotate-no-alias", cl::init(true),
     29                     cl::Hidden,
     30                     cl::desc("Add no-alias annotation for instructions that "
     31                              "are disambiguated by memchecks"));
     32 
     33 LoopVersioning::LoopVersioning(const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI,
     34                                DominatorTree *DT, ScalarEvolution *SE,
     35                                bool UseLAIChecks)
     36     : VersionedLoop(L), NonVersionedLoop(nullptr), LAI(LAI), LI(LI), DT(DT),
     37       SE(SE) {
     38   assert(L->getExitBlock() && "No single exit block");
     39   assert(L->getLoopPreheader() && "No preheader");
     40   if (UseLAIChecks) {
     41     setAliasChecks(LAI.getRuntimePointerChecking()->getChecks());
     42     setSCEVChecks(LAI.getPSE().getUnionPredicate());
     43   }
     44 }
     45 
     46 void LoopVersioning::setAliasChecks(
     47     SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks) {
     48   AliasChecks = std::move(Checks);
     49 }
     50 
     51 void LoopVersioning::setSCEVChecks(SCEVUnionPredicate Check) {
     52   Preds = std::move(Check);
     53 }
     54 
     55 void LoopVersioning::versionLoop(
     56     const SmallVectorImpl<Instruction *> &DefsUsedOutside) {
     57   Instruction *FirstCheckInst;
     58   Instruction *MemRuntimeCheck;
     59   Value *SCEVRuntimeCheck;
     60   Value *RuntimeCheck = nullptr;
     61 
     62   // Add the memcheck in the original preheader (this is empty initially).
     63   BasicBlock *RuntimeCheckBB = VersionedLoop->getLoopPreheader();
     64   std::tie(FirstCheckInst, MemRuntimeCheck) =
     65       LAI.addRuntimeChecks(RuntimeCheckBB->getTerminator(), AliasChecks);
     66 
     67   const SCEVUnionPredicate &Pred = LAI.getPSE().getUnionPredicate();
     68   SCEVExpander Exp(*SE, RuntimeCheckBB->getModule()->getDataLayout(),
     69                    "scev.check");
     70   SCEVRuntimeCheck =
     71       Exp.expandCodeForPredicate(&Pred, RuntimeCheckBB->getTerminator());
     72   auto *CI = dyn_cast<ConstantInt>(SCEVRuntimeCheck);
     73 
     74   // Discard the SCEV runtime check if it is always true.
     75   if (CI && CI->isZero())
     76     SCEVRuntimeCheck = nullptr;
     77 
     78   if (MemRuntimeCheck && SCEVRuntimeCheck) {
     79     RuntimeCheck = BinaryOperator::Create(Instruction::Or, MemRuntimeCheck,
     80                                           SCEVRuntimeCheck, "lver.safe");
     81     if (auto *I = dyn_cast<Instruction>(RuntimeCheck))
     82       I->insertBefore(RuntimeCheckBB->getTerminator());
     83   } else
     84     RuntimeCheck = MemRuntimeCheck ? MemRuntimeCheck : SCEVRuntimeCheck;
     85 
     86   assert(RuntimeCheck && "called even though we don't need "
     87                          "any runtime checks");
     88 
     89   // Rename the block to make the IR more readable.
     90   RuntimeCheckBB->setName(VersionedLoop->getHeader()->getName() +
     91                           ".lver.check");
     92 
     93   // Create empty preheader for the loop (and after cloning for the
     94   // non-versioned loop).
     95   BasicBlock *PH =
     96       SplitBlock(RuntimeCheckBB, RuntimeCheckBB->getTerminator(), DT, LI);
     97   PH->setName(VersionedLoop->getHeader()->getName() + ".ph");
     98 
     99   // Clone the loop including the preheader.
    100   //
    101   // FIXME: This does not currently preserve SimplifyLoop because the exit
    102   // block is a join between the two loops.
    103   SmallVector<BasicBlock *, 8> NonVersionedLoopBlocks;
    104   NonVersionedLoop =
    105       cloneLoopWithPreheader(PH, RuntimeCheckBB, VersionedLoop, VMap,
    106                              ".lver.orig", LI, DT, NonVersionedLoopBlocks);
    107   remapInstructionsInBlocks(NonVersionedLoopBlocks, VMap);
    108 
    109   // Insert the conditional branch based on the result of the memchecks.
    110   Instruction *OrigTerm = RuntimeCheckBB->getTerminator();
    111   BranchInst::Create(NonVersionedLoop->getLoopPreheader(),
    112                      VersionedLoop->getLoopPreheader(), RuntimeCheck, OrigTerm);
    113   OrigTerm->eraseFromParent();
    114 
    115   // The loops merge in the original exit block.  This is now dominated by the
    116   // memchecking block.
    117   DT->changeImmediateDominator(VersionedLoop->getExitBlock(), RuntimeCheckBB);
    118 
    119   // Adds the necessary PHI nodes for the versioned loops based on the
    120   // loop-defined values used outside of the loop.
    121   addPHINodes(DefsUsedOutside);
    122 }
    123 
    124 void LoopVersioning::addPHINodes(
    125     const SmallVectorImpl<Instruction *> &DefsUsedOutside) {
    126   BasicBlock *PHIBlock = VersionedLoop->getExitBlock();
    127   assert(PHIBlock && "No single successor to loop exit block");
    128   PHINode *PN;
    129 
    130   // First add a single-operand PHI for each DefsUsedOutside if one does not
    131   // exists yet.
    132   for (auto *Inst : DefsUsedOutside) {
    133     // See if we have a single-operand PHI with the value defined by the
    134     // original loop.
    135     for (auto I = PHIBlock->begin(); (PN = dyn_cast<PHINode>(I)); ++I) {
    136       if (PN->getIncomingValue(0) == Inst)
    137         break;
    138     }
    139     // If not create it.
    140     if (!PN) {
    141       PN = PHINode::Create(Inst->getType(), 2, Inst->getName() + ".lver",
    142                            &PHIBlock->front());
    143       for (auto *User : Inst->users())
    144         if (!VersionedLoop->contains(cast<Instruction>(User)->getParent()))
    145           User->replaceUsesOfWith(Inst, PN);
    146       PN->addIncoming(Inst, VersionedLoop->getExitingBlock());
    147     }
    148   }
    149 
    150   // Then for each PHI add the operand for the edge from the cloned loop.
    151   for (auto I = PHIBlock->begin(); (PN = dyn_cast<PHINode>(I)); ++I) {
    152     assert(PN->getNumOperands() == 1 &&
    153            "Exit block should only have on predecessor");
    154 
    155     // If the definition was cloned used that otherwise use the same value.
    156     Value *ClonedValue = PN->getIncomingValue(0);
    157     auto Mapped = VMap.find(ClonedValue);
    158     if (Mapped != VMap.end())
    159       ClonedValue = Mapped->second;
    160 
    161     PN->addIncoming(ClonedValue, NonVersionedLoop->getExitingBlock());
    162   }
    163 }
    164 
    165 void LoopVersioning::prepareNoAliasMetadata() {
    166   // We need to turn the no-alias relation between pointer checking groups into
    167   // no-aliasing annotations between instructions.
    168   //
    169   // We accomplish this by mapping each pointer checking group (a set of
    170   // pointers memchecked together) to an alias scope and then also mapping each
    171   // group to the list of scopes it can't alias.
    172 
    173   const RuntimePointerChecking *RtPtrChecking = LAI.getRuntimePointerChecking();
    174   LLVMContext &Context = VersionedLoop->getHeader()->getContext();
    175 
    176   // First allocate an aliasing scope for each pointer checking group.
    177   //
    178   // While traversing through the checking groups in the loop, also create a
    179   // reverse map from pointers to the pointer checking group they were assigned
    180   // to.
    181   MDBuilder MDB(Context);
    182   MDNode *Domain = MDB.createAnonymousAliasScopeDomain("LVerDomain");
    183 
    184   for (const auto &Group : RtPtrChecking->CheckingGroups) {
    185     GroupToScope[&Group] = MDB.createAnonymousAliasScope(Domain);
    186 
    187     for (unsigned PtrIdx : Group.Members)
    188       PtrToGroup[RtPtrChecking->getPointerInfo(PtrIdx).PointerValue] = &Group;
    189   }
    190 
    191   // Go through the checks and for each pointer group, collect the scopes for
    192   // each non-aliasing pointer group.
    193   DenseMap<const RuntimePointerChecking::CheckingPtrGroup *,
    194            SmallVector<Metadata *, 4>>
    195       GroupToNonAliasingScopes;
    196 
    197   for (const auto &Check : AliasChecks)
    198     GroupToNonAliasingScopes[Check.first].push_back(GroupToScope[Check.second]);
    199 
    200   // Finally, transform the above to actually map to scope list which is what
    201   // the metadata uses.
    202 
    203   for (auto Pair : GroupToNonAliasingScopes)
    204     GroupToNonAliasingScopeList[Pair.first] = MDNode::get(Context, Pair.second);
    205 }
    206 
    207 void LoopVersioning::annotateLoopWithNoAlias() {
    208   if (!AnnotateNoAlias)
    209     return;
    210 
    211   // First prepare the maps.
    212   prepareNoAliasMetadata();
    213 
    214   // Add the scope and no-alias metadata to the instructions.
    215   for (Instruction *I : LAI.getDepChecker().getMemoryInstructions()) {
    216     annotateInstWithNoAlias(I);
    217   }
    218 }
    219 
    220 void LoopVersioning::annotateInstWithNoAlias(Instruction *VersionedInst,
    221                                              const Instruction *OrigInst) {
    222   if (!AnnotateNoAlias)
    223     return;
    224 
    225   LLVMContext &Context = VersionedLoop->getHeader()->getContext();
    226   const Value *Ptr = isa<LoadInst>(OrigInst)
    227                          ? cast<LoadInst>(OrigInst)->getPointerOperand()
    228                          : cast<StoreInst>(OrigInst)->getPointerOperand();
    229 
    230   // Find the group for the pointer and then add the scope metadata.
    231   auto Group = PtrToGroup.find(Ptr);
    232   if (Group != PtrToGroup.end()) {
    233     VersionedInst->setMetadata(
    234         LLVMContext::MD_alias_scope,
    235         MDNode::concatenate(
    236             VersionedInst->getMetadata(LLVMContext::MD_alias_scope),
    237             MDNode::get(Context, GroupToScope[Group->second])));
    238 
    239     // Add the no-alias metadata.
    240     auto NonAliasingScopeList = GroupToNonAliasingScopeList.find(Group->second);
    241     if (NonAliasingScopeList != GroupToNonAliasingScopeList.end())
    242       VersionedInst->setMetadata(
    243           LLVMContext::MD_noalias,
    244           MDNode::concatenate(
    245               VersionedInst->getMetadata(LLVMContext::MD_noalias),
    246               NonAliasingScopeList->second));
    247   }
    248 }
    249 
    250 namespace {
    251 /// \brief Also expose this is a pass.  Currently this is only used for
    252 /// unit-testing.  It adds all memchecks necessary to remove all may-aliasing
    253 /// array accesses from the loop.
    254 class LoopVersioningPass : public FunctionPass {
    255 public:
    256   LoopVersioningPass() : FunctionPass(ID) {
    257     initializeLoopVersioningPassPass(*PassRegistry::getPassRegistry());
    258   }
    259 
    260   bool runOnFunction(Function &F) override {
    261     auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
    262     auto *LAA = &getAnalysis<LoopAccessLegacyAnalysis>();
    263     auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    264     auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
    265 
    266     // Build up a worklist of inner-loops to version. This is necessary as the
    267     // act of versioning a loop creates new loops and can invalidate iterators
    268     // across the loops.
    269     SmallVector<Loop *, 8> Worklist;
    270 
    271     for (Loop *TopLevelLoop : *LI)
    272       for (Loop *L : depth_first(TopLevelLoop))
    273         // We only handle inner-most loops.
    274         if (L->empty())
    275           Worklist.push_back(L);
    276 
    277     // Now walk the identified inner loops.
    278     bool Changed = false;
    279     for (Loop *L : Worklist) {
    280       const LoopAccessInfo &LAI = LAA->getInfo(L);
    281       if (LAI.getNumRuntimePointerChecks() ||
    282           !LAI.getPSE().getUnionPredicate().isAlwaysTrue()) {
    283         LoopVersioning LVer(LAI, L, LI, DT, SE);
    284         LVer.versionLoop();
    285         LVer.annotateLoopWithNoAlias();
    286         Changed = true;
    287       }
    288     }
    289 
    290     return Changed;
    291   }
    292 
    293   void getAnalysisUsage(AnalysisUsage &AU) const override {
    294     AU.addRequired<LoopInfoWrapperPass>();
    295     AU.addPreserved<LoopInfoWrapperPass>();
    296     AU.addRequired<LoopAccessLegacyAnalysis>();
    297     AU.addRequired<DominatorTreeWrapperPass>();
    298     AU.addPreserved<DominatorTreeWrapperPass>();
    299     AU.addRequired<ScalarEvolutionWrapperPass>();
    300   }
    301 
    302   static char ID;
    303 };
    304 }
    305 
// Command-line name of the pass; also reused as this file's DEBUG_TYPE.
#define LVER_OPTION "loop-versioning"
#define DEBUG_TYPE LVER_OPTION

// Storage for the pass ID referenced by FunctionPass(ID) in the constructor
// and by the registration macros below.
char LoopVersioningPass::ID;
// Human-readable pass name passed to the registration macros.
static const char LVer_name[] = "Loop Versioning";

// Register the pass and the analyses it depends on with the legacy pass
// registry.  The trailing (false, false) arguments are the macro's
// CFG-only / is-analysis flags.
INITIALIZE_PASS_BEGIN(LoopVersioningPass, LVER_OPTION, LVer_name, false, false)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_END(LoopVersioningPass, LVER_OPTION, LVer_name, false, false)
    318 
    319 namespace llvm {
    320 FunctionPass *createLoopVersioningPass() {
    321   return new LoopVersioningPass();
    322 }
    323 }
    324