Home | History | Annotate | Download | only in Scalar
      1 //===- LoopUnrollAndJam.cpp - Loop unroll and jam pass --------------------===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 // This pass implements an unroll and jam pass. Most of the work is done by
     11 // Utils/UnrollLoopAndJam.cpp.
     12 //===----------------------------------------------------------------------===//
     13 
     14 #include "llvm/Transforms/Scalar/LoopUnrollAndJamPass.h"
     15 #include "llvm/ADT/None.h"
     16 #include "llvm/ADT/STLExtras.h"
     17 #include "llvm/ADT/SmallPtrSet.h"
     18 #include "llvm/ADT/StringRef.h"
     19 #include "llvm/Analysis/AssumptionCache.h"
     20 #include "llvm/Analysis/CodeMetrics.h"
     21 #include "llvm/Analysis/DependenceAnalysis.h"
     22 #include "llvm/Analysis/LoopAnalysisManager.h"
     23 #include "llvm/Analysis/LoopInfo.h"
     24 #include "llvm/Analysis/LoopPass.h"
     25 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
     26 #include "llvm/Analysis/ScalarEvolution.h"
     27 #include "llvm/Analysis/TargetTransformInfo.h"
     28 #include "llvm/IR/BasicBlock.h"
     29 #include "llvm/IR/CFG.h"
     30 #include "llvm/IR/Constant.h"
     31 #include "llvm/IR/Constants.h"
     32 #include "llvm/IR/Dominators.h"
     33 #include "llvm/IR/Function.h"
     34 #include "llvm/IR/Instruction.h"
     35 #include "llvm/IR/Instructions.h"
     36 #include "llvm/IR/IntrinsicInst.h"
     37 #include "llvm/IR/Metadata.h"
     38 #include "llvm/IR/PassManager.h"
     39 #include "llvm/Pass.h"
     40 #include "llvm/Support/Casting.h"
     41 #include "llvm/Support/CommandLine.h"
     42 #include "llvm/Support/Debug.h"
     43 #include "llvm/Support/ErrorHandling.h"
     44 #include "llvm/Support/raw_ostream.h"
     45 #include "llvm/Transforms/Scalar.h"
     46 #include "llvm/Transforms/Scalar/LoopPassManager.h"
     47 #include "llvm/Transforms/Utils.h"
     48 #include "llvm/Transforms/Utils/LoopUtils.h"
     49 #include "llvm/Transforms/Utils/UnrollLoop.h"
     50 #include <algorithm>
     51 #include <cassert>
     52 #include <cstdint>
     53 #include <string>
     54 
     55 using namespace llvm;
     56 
     57 #define DEBUG_TYPE "loop-unroll-and-jam"
     58 
     59 static cl::opt<bool>
     60     AllowUnrollAndJam("allow-unroll-and-jam", cl::Hidden,
     61                       cl::desc("Allows loops to be unroll-and-jammed."));
     62 
     63 static cl::opt<unsigned> UnrollAndJamCount(
     64     "unroll-and-jam-count", cl::Hidden,
     65     cl::desc("Use this unroll count for all loops including those with "
     66              "unroll_and_jam_count pragma values, for testing purposes"));
     67 
     68 static cl::opt<unsigned> UnrollAndJamThreshold(
     69     "unroll-and-jam-threshold", cl::init(60), cl::Hidden,
     70     cl::desc("Threshold to use for inner loop when doing unroll and jam."));
     71 
     72 static cl::opt<unsigned> PragmaUnrollAndJamThreshold(
     73     "pragma-unroll-and-jam-threshold", cl::init(1024), cl::Hidden,
     74     cl::desc("Unrolled size limit for loops with an unroll_and_jam(full) or "
     75              "unroll_count pragma."));
     76 
     77 // Returns the loop hint metadata node with the given name (for example,
     78 // "llvm.loop.unroll.count").  If no such metadata node exists, then nullptr is
     79 // returned.
     80 static MDNode *GetUnrollMetadataForLoop(const Loop *L, StringRef Name) {
     81   if (MDNode *LoopID = L->getLoopID())
     82     return GetUnrollMetadata(LoopID, Name);
     83   return nullptr;
     84 }
     85 
     86 // Returns true if the loop has any metadata starting with Prefix. For example a
     87 // Prefix of "llvm.loop.unroll." returns true if we have any unroll metadata.
     88 static bool HasAnyUnrollPragma(const Loop *L, StringRef Prefix) {
     89   if (MDNode *LoopID = L->getLoopID()) {
     90     // First operand should refer to the loop id itself.
     91     assert(LoopID->getNumOperands() > 0 && "requires at least one operand");
     92     assert(LoopID->getOperand(0) == LoopID && "invalid loop id");
     93 
     94     for (unsigned i = 1, e = LoopID->getNumOperands(); i < e; ++i) {
     95       MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(i));
     96       if (!MD)
     97         continue;
     98 
     99       MDString *S = dyn_cast<MDString>(MD->getOperand(0));
    100       if (!S)
    101         continue;
    102 
    103       if (S->getString().startswith(Prefix))
    104         return true;
    105     }
    106   }
    107   return false;
    108 }
    109 
    110 // Returns true if the loop has an unroll_and_jam(enable) pragma.
    111 static bool HasUnrollAndJamEnablePragma(const Loop *L) {
    112   return GetUnrollMetadataForLoop(L, "llvm.loop.unroll_and_jam.enable");
    113 }
    114 
    115 // Returns true if the loop has an unroll_and_jam(disable) pragma.
    116 static bool HasUnrollAndJamDisablePragma(const Loop *L) {
    117   return GetUnrollMetadataForLoop(L, "llvm.loop.unroll_and_jam.disable");
    118 }
    119 
    120 // If loop has an unroll_and_jam_count pragma return the (necessarily
    121 // positive) value from the pragma.  Otherwise return 0.
    122 static unsigned UnrollAndJamCountPragmaValue(const Loop *L) {
    123   MDNode *MD = GetUnrollMetadataForLoop(L, "llvm.loop.unroll_and_jam.count");
    124   if (MD) {
    125     assert(MD->getNumOperands() == 2 &&
    126            "Unroll count hint metadata should have two operands.");
    127     unsigned Count =
    128         mdconst::extract<ConstantInt>(MD->getOperand(1))->getZExtValue();
    129     assert(Count >= 1 && "Unroll count must be positive.");
    130     return Count;
    131   }
    132   return 0;
    133 }
    134 
    135 // Returns loop size estimation for unrolled loop.
    136 static uint64_t
    137 getUnrollAndJammedLoopSize(unsigned LoopSize,
    138                            TargetTransformInfo::UnrollingPreferences &UP) {
    139   assert(LoopSize >= UP.BEInsns && "LoopSize should not be less than BEInsns!");
    140   return static_cast<uint64_t>(LoopSize - UP.BEInsns) * UP.Count + UP.BEInsns;
    141 }
    142 
    143 // Calculates unroll and jam count and writes it to UP.Count. Returns true if
    144 // unroll count was set explicitly.
    145 static bool computeUnrollAndJamCount(
    146     Loop *L, Loop *SubLoop, const TargetTransformInfo &TTI, DominatorTree &DT,
    147     LoopInfo *LI, ScalarEvolution &SE,
    148     const SmallPtrSetImpl<const Value *> &EphValues,
    149     OptimizationRemarkEmitter *ORE, unsigned OuterTripCount,
    150     unsigned OuterTripMultiple, unsigned OuterLoopSize, unsigned InnerTripCount,
    151     unsigned InnerLoopSize, TargetTransformInfo::UnrollingPreferences &UP) {
    152   // Check for explicit Count from the "unroll-and-jam-count" option.
    153   bool UserUnrollCount = UnrollAndJamCount.getNumOccurrences() > 0;
    154   if (UserUnrollCount) {
    155     UP.Count = UnrollAndJamCount;
    156     UP.Force = true;
    157     if (UP.AllowRemainder &&
    158         getUnrollAndJammedLoopSize(OuterLoopSize, UP) < UP.Threshold &&
    159         getUnrollAndJammedLoopSize(InnerLoopSize, UP) <
    160             UP.UnrollAndJamInnerLoopThreshold)
    161       return true;
    162   }
    163 
    164   // Check for unroll_and_jam pragmas
    165   unsigned PragmaCount = UnrollAndJamCountPragmaValue(L);
    166   if (PragmaCount > 0) {
    167     UP.Count = PragmaCount;
    168     UP.Runtime = true;
    169     UP.Force = true;
    170     if ((UP.AllowRemainder || (OuterTripMultiple % PragmaCount == 0)) &&
    171         getUnrollAndJammedLoopSize(OuterLoopSize, UP) < UP.Threshold &&
    172         getUnrollAndJammedLoopSize(InnerLoopSize, UP) <
    173             UP.UnrollAndJamInnerLoopThreshold)
    174       return true;
    175   }
    176 
    177   // Use computeUnrollCount from the loop unroller to get a sensible count
    178   // for the unrolling the outer loop. This uses UP.Threshold /
    179   // UP.PartialThreshold / UP.MaxCount to come up with sensible loop values.
    180   // We have already checked that the loop has no unroll.* pragmas.
    181   unsigned MaxTripCount = 0;
    182   bool UseUpperBound = false;
    183   bool ExplicitUnroll = computeUnrollCount(
    184       L, TTI, DT, LI, SE, EphValues, ORE, OuterTripCount, MaxTripCount,
    185       OuterTripMultiple, OuterLoopSize, UP, UseUpperBound);
    186   if (ExplicitUnroll || UseUpperBound) {
    187     // If the user explicitly set the loop as unrolled, dont UnJ it. Leave it
    188     // for the unroller instead.
    189     UP.Count = 0;
    190     return false;
    191   }
    192 
    193   bool PragmaEnableUnroll = HasUnrollAndJamEnablePragma(L);
    194   ExplicitUnroll = PragmaCount > 0 || PragmaEnableUnroll || UserUnrollCount;
    195 
    196   // If the loop has an unrolling pragma, we want to be more aggressive with
    197   // unrolling limits.
    198   if (ExplicitUnroll && OuterTripCount != 0)
    199     UP.UnrollAndJamInnerLoopThreshold = PragmaUnrollAndJamThreshold;
    200 
    201   if (!UP.AllowRemainder && getUnrollAndJammedLoopSize(InnerLoopSize, UP) >=
    202                                 UP.UnrollAndJamInnerLoopThreshold) {
    203     UP.Count = 0;
    204     return false;
    205   }
    206 
    207   // If the inner loop count is known and small, leave the entire loop nest to
    208   // be the unroller
    209   if (!ExplicitUnroll && InnerTripCount &&
    210       InnerLoopSize * InnerTripCount < UP.Threshold) {
    211     UP.Count = 0;
    212     return false;
    213   }
    214 
    215   // We have a sensible limit for the outer loop, now adjust it for the inner
    216   // loop and UP.UnrollAndJamInnerLoopThreshold.
    217   while (UP.Count != 0 && UP.AllowRemainder &&
    218          getUnrollAndJammedLoopSize(InnerLoopSize, UP) >=
    219              UP.UnrollAndJamInnerLoopThreshold)
    220     UP.Count--;
    221 
    222   if (!ExplicitUnroll) {
    223     // Check for situations where UnJ is likely to be unprofitable. Including
    224     // subloops with more than 1 block.
    225     if (SubLoop->getBlocks().size() != 1) {
    226       UP.Count = 0;
    227       return false;
    228     }
    229 
    230     // Limit to loops where there is something to gain from unrolling and
    231     // jamming the loop. In this case, look for loads that are invariant in the
    232     // outer loop and can become shared.
    233     unsigned NumInvariant = 0;
    234     for (BasicBlock *BB : SubLoop->getBlocks()) {
    235       for (Instruction &I : *BB) {
    236         if (auto *Ld = dyn_cast<LoadInst>(&I)) {
    237           Value *V = Ld->getPointerOperand();
    238           const SCEV *LSCEV = SE.getSCEVAtScope(V, L);
    239           if (SE.isLoopInvariant(LSCEV, L))
    240             NumInvariant++;
    241         }
    242       }
    243     }
    244     if (NumInvariant == 0) {
    245       UP.Count = 0;
    246       return false;
    247     }
    248   }
    249 
    250   return ExplicitUnroll;
    251 }
    252 
    253 static LoopUnrollResult
    254 tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI,
    255                       ScalarEvolution &SE, const TargetTransformInfo &TTI,
    256                       AssumptionCache &AC, DependenceInfo &DI,
    257                       OptimizationRemarkEmitter &ORE, int OptLevel) {
    258   // Quick checks of the correct loop form
    259   if (!L->isLoopSimplifyForm() || L->getSubLoops().size() != 1)
    260     return LoopUnrollResult::Unmodified;
    261   Loop *SubLoop = L->getSubLoops()[0];
    262   if (!SubLoop->isLoopSimplifyForm())
    263     return LoopUnrollResult::Unmodified;
    264 
    265   BasicBlock *Latch = L->getLoopLatch();
    266   BasicBlock *Exit = L->getExitingBlock();
    267   BasicBlock *SubLoopLatch = SubLoop->getLoopLatch();
    268   BasicBlock *SubLoopExit = SubLoop->getExitingBlock();
    269 
    270   if (Latch != Exit || SubLoopLatch != SubLoopExit)
    271     return LoopUnrollResult::Unmodified;
    272 
    273   TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences(
    274       L, SE, TTI, OptLevel, None, None, None, None, None, None);
    275   if (AllowUnrollAndJam.getNumOccurrences() > 0)
    276     UP.UnrollAndJam = AllowUnrollAndJam;
    277   if (UnrollAndJamThreshold.getNumOccurrences() > 0)
    278     UP.UnrollAndJamInnerLoopThreshold = UnrollAndJamThreshold;
    279   // Exit early if unrolling is disabled.
    280   if (!UP.UnrollAndJam || UP.UnrollAndJamInnerLoopThreshold == 0)
    281     return LoopUnrollResult::Unmodified;
    282 
    283   LLVM_DEBUG(dbgs() << "Loop Unroll and Jam: F["
    284                     << L->getHeader()->getParent()->getName() << "] Loop %"
    285                     << L->getHeader()->getName() << "\n");
    286 
    287   // A loop with any unroll pragma (enabling/disabling/count/etc) is left for
    288   // the unroller, so long as it does not explicitly have unroll_and_jam
    289   // metadata. This means #pragma nounroll will disable unroll and jam as well
    290   // as unrolling
    291   if (HasUnrollAndJamDisablePragma(L) ||
    292       (HasAnyUnrollPragma(L, "llvm.loop.unroll.") &&
    293        !HasAnyUnrollPragma(L, "llvm.loop.unroll_and_jam."))) {
    294     LLVM_DEBUG(dbgs() << "  Disabled due to pragma.\n");
    295     return LoopUnrollResult::Unmodified;
    296   }
    297 
    298   if (!isSafeToUnrollAndJam(L, SE, DT, DI)) {
    299     LLVM_DEBUG(dbgs() << "  Disabled due to not being safe.\n");
    300     return LoopUnrollResult::Unmodified;
    301   }
    302 
    303   // Approximate the loop size and collect useful info
    304   unsigned NumInlineCandidates;
    305   bool NotDuplicatable;
    306   bool Convergent;
    307   SmallPtrSet<const Value *, 32> EphValues;
    308   CodeMetrics::collectEphemeralValues(L, &AC, EphValues);
    309   unsigned InnerLoopSize =
    310       ApproximateLoopSize(SubLoop, NumInlineCandidates, NotDuplicatable,
    311                           Convergent, TTI, EphValues, UP.BEInsns);
    312   unsigned OuterLoopSize =
    313       ApproximateLoopSize(L, NumInlineCandidates, NotDuplicatable, Convergent,
    314                           TTI, EphValues, UP.BEInsns);
    315   LLVM_DEBUG(dbgs() << "  Outer Loop Size: " << OuterLoopSize << "\n");
    316   LLVM_DEBUG(dbgs() << "  Inner Loop Size: " << InnerLoopSize << "\n");
    317   if (NotDuplicatable) {
    318     LLVM_DEBUG(dbgs() << "  Not unrolling loop which contains non-duplicatable "
    319                          "instructions.\n");
    320     return LoopUnrollResult::Unmodified;
    321   }
    322   if (NumInlineCandidates != 0) {
    323     LLVM_DEBUG(dbgs() << "  Not unrolling loop with inlinable calls.\n");
    324     return LoopUnrollResult::Unmodified;
    325   }
    326   if (Convergent) {
    327     LLVM_DEBUG(
    328         dbgs() << "  Not unrolling loop with convergent instructions.\n");
    329     return LoopUnrollResult::Unmodified;
    330   }
    331 
    332   // Find trip count and trip multiple
    333   unsigned OuterTripCount = SE.getSmallConstantTripCount(L, Latch);
    334   unsigned OuterTripMultiple = SE.getSmallConstantTripMultiple(L, Latch);
    335   unsigned InnerTripCount = SE.getSmallConstantTripCount(SubLoop, SubLoopLatch);
    336 
    337   // Decide if, and by how much, to unroll
    338   bool IsCountSetExplicitly = computeUnrollAndJamCount(
    339       L, SubLoop, TTI, DT, LI, SE, EphValues, &ORE, OuterTripCount,
    340       OuterTripMultiple, OuterLoopSize, InnerTripCount, InnerLoopSize, UP);
    341   if (UP.Count <= 1)
    342     return LoopUnrollResult::Unmodified;
    343   // Unroll factor (Count) must be less or equal to TripCount.
    344   if (OuterTripCount && UP.Count > OuterTripCount)
    345     UP.Count = OuterTripCount;
    346 
    347   LoopUnrollResult UnrollResult =
    348       UnrollAndJamLoop(L, UP.Count, OuterTripCount, OuterTripMultiple,
    349                        UP.UnrollRemainder, LI, &SE, &DT, &AC, &ORE);
    350 
    351   // If loop has an unroll count pragma or unrolled by explicitly set count
    352   // mark loop as unrolled to prevent unrolling beyond that requested.
    353   if (UnrollResult != LoopUnrollResult::FullyUnrolled && IsCountSetExplicitly)
    354     L->setLoopAlreadyUnrolled();
    355 
    356   return UnrollResult;
    357 }
    358 
    359 namespace {
    360 
    361 class LoopUnrollAndJam : public LoopPass {
    362 public:
    363   static char ID; // Pass ID, replacement for typeid
    364   unsigned OptLevel;
    365 
    366   LoopUnrollAndJam(int OptLevel = 2) : LoopPass(ID), OptLevel(OptLevel) {
    367     initializeLoopUnrollAndJamPass(*PassRegistry::getPassRegistry());
    368   }
    369 
    370   bool runOnLoop(Loop *L, LPPassManager &LPM) override {
    371     if (skipLoop(L))
    372       return false;
    373 
    374     Function &F = *L->getHeader()->getParent();
    375 
    376     auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    377     LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
    378     ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
    379     const TargetTransformInfo &TTI =
    380         getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    381     auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
    382     auto &DI = getAnalysis<DependenceAnalysisWrapperPass>().getDI();
    383     // For the old PM, we can't use OptimizationRemarkEmitter as an analysis
    384     // pass.  Function analyses need to be preserved across loop transformations
    385     // but ORE cannot be preserved (see comment before the pass definition).
    386     OptimizationRemarkEmitter ORE(&F);
    387 
    388     LoopUnrollResult Result =
    389         tryToUnrollAndJamLoop(L, DT, LI, SE, TTI, AC, DI, ORE, OptLevel);
    390 
    391     if (Result == LoopUnrollResult::FullyUnrolled)
    392       LPM.markLoopAsDeleted(*L);
    393 
    394     return Result != LoopUnrollResult::Unmodified;
    395   }
    396 
    397   /// This transformation requires natural loop information & requires that
    398   /// loop preheaders be inserted into the CFG...
    399   void getAnalysisUsage(AnalysisUsage &AU) const override {
    400     AU.addRequired<AssumptionCacheTracker>();
    401     AU.addRequired<TargetTransformInfoWrapperPass>();
    402     AU.addRequired<DependenceAnalysisWrapperPass>();
    403     getLoopAnalysisUsage(AU);
    404   }
    405 };
    406 
    407 } // end anonymous namespace
    408 
    409 char LoopUnrollAndJam::ID = 0;
    410 
    411 INITIALIZE_PASS_BEGIN(LoopUnrollAndJam, "loop-unroll-and-jam",
    412                       "Unroll and Jam loops", false, false)
    413 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
    414 INITIALIZE_PASS_DEPENDENCY(LoopPass)
    415 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
    416 INITIALIZE_PASS_DEPENDENCY(DependenceAnalysisWrapperPass)
    417 INITIALIZE_PASS_END(LoopUnrollAndJam, "loop-unroll-and-jam",
    418                     "Unroll and Jam loops", false, false)
    419 
    420 Pass *llvm::createLoopUnrollAndJamPass(int OptLevel) {
    421   return new LoopUnrollAndJam(OptLevel);
    422 }
    423 
    424 PreservedAnalyses LoopUnrollAndJamPass::run(Loop &L, LoopAnalysisManager &AM,
    425                                             LoopStandardAnalysisResults &AR,
    426                                             LPMUpdater &) {
    427   const auto &FAM =
    428       AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager();
    429   Function *F = L.getHeader()->getParent();
    430 
    431   auto *ORE = FAM.getCachedResult<OptimizationRemarkEmitterAnalysis>(*F);
    432   // FIXME: This should probably be optional rather than required.
    433   if (!ORE)
    434     report_fatal_error(
    435         "LoopUnrollAndJamPass: OptimizationRemarkEmitterAnalysis not cached at "
    436         "a higher level");
    437 
    438   DependenceInfo DI(F, &AR.AA, &AR.SE, &AR.LI);
    439 
    440   LoopUnrollResult Result = tryToUnrollAndJamLoop(
    441       &L, AR.DT, &AR.LI, AR.SE, AR.TTI, AR.AC, DI, *ORE, OptLevel);
    442 
    443   if (Result == LoopUnrollResult::Unmodified)
    444     return PreservedAnalyses::all();
    445 
    446   return getLoopPassPreservedAnalyses();
    447 }
    448