Home | History | Annotate | Download | only in Scalar
      1 //===-- InductiveRangeCheckElimination.cpp - ------------------------------===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 // The InductiveRangeCheckElimination pass splits a loop's iteration space into
     10 // three disjoint ranges.  It does that in a way such that the loop running in
     11 // the middle loop provably does not need range checks. As an example, it will
     12 // convert
     13 //
     14 //   len = < known positive >
     15 //   for (i = 0; i < n; i++) {
     16 //     if (0 <= i && i < len) {
     17 //       do_something();
     18 //     } else {
     19 //       throw_out_of_bounds();
     20 //     }
     21 //   }
     22 //
     23 // to
     24 //
     25 //   len = < known positive >
     26 //   limit = smin(n, len)
     27 //   // no first segment
     28 //   for (i = 0; i < limit; i++) {
     29 //     if (0 <= i && i < len) { // this check is fully redundant
     30 //       do_something();
     31 //     } else {
     32 //       throw_out_of_bounds();
     33 //     }
     34 //   }
     35 //   for (i = limit; i < n; i++) {
     36 //     if (0 <= i && i < len) {
     37 //       do_something();
     38 //     } else {
     39 //       throw_out_of_bounds();
     40 //     }
     41 //   }
     42 //===----------------------------------------------------------------------===//
     43 
     44 #include "llvm/ADT/Optional.h"
     45 #include "llvm/Analysis/BranchProbabilityInfo.h"
     46 #include "llvm/Analysis/InstructionSimplify.h"
     47 #include "llvm/Analysis/LoopInfo.h"
     48 #include "llvm/Analysis/LoopPass.h"
     49 #include "llvm/Analysis/ScalarEvolution.h"
     50 #include "llvm/Analysis/ScalarEvolutionExpander.h"
     51 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
     52 #include "llvm/Analysis/ValueTracking.h"
     53 #include "llvm/IR/Dominators.h"
     54 #include "llvm/IR/Function.h"
     55 #include "llvm/IR/IRBuilder.h"
     56 #include "llvm/IR/Instructions.h"
     57 #include "llvm/IR/Module.h"
     58 #include "llvm/IR/PatternMatch.h"
     59 #include "llvm/IR/ValueHandle.h"
     60 #include "llvm/IR/Verifier.h"
     61 #include "llvm/Pass.h"
     62 #include "llvm/Support/Debug.h"
     63 #include "llvm/Support/raw_ostream.h"
     64 #include "llvm/Transforms/Scalar.h"
     65 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
     66 #include "llvm/Transforms/Utils/Cloning.h"
     67 #include "llvm/Transforms/Utils/LoopUtils.h"
     68 #include "llvm/Transforms/Utils/SimplifyIndVar.h"
     69 #include "llvm/Transforms/Utils/UnrollLoop.h"
     70 #include <array>
     71 
     72 using namespace llvm;
     73 
     74 static cl::opt<unsigned> LoopSizeCutoff("irce-loop-size-cutoff", cl::Hidden,
     75                                         cl::init(64));
     76 
     77 static cl::opt<bool> PrintChangedLoops("irce-print-changed-loops", cl::Hidden,
     78                                        cl::init(false));
     79 
     80 static cl::opt<bool> PrintRangeChecks("irce-print-range-checks", cl::Hidden,
     81                                       cl::init(false));
     82 
     83 static cl::opt<int> MaxExitProbReciprocal("irce-max-exit-prob-reciprocal",
     84                                           cl::Hidden, cl::init(10));
     85 
     86 #define DEBUG_TYPE "irce"
     87 
     88 namespace {
     89 
     90 /// An inductive range check is conditional branch in a loop with
     91 ///
     92 ///  1. a very cold successor (i.e. the branch jumps to that successor very
     93 ///     rarely)
     94 ///
     95 ///  and
     96 ///
     97 ///  2. a condition that is provably true for some contiguous range of values
     98 ///     taken by the containing loop's induction variable.
     99 ///
    100 class InductiveRangeCheck {
    101   // Classifies a range check
    102   enum RangeCheckKind : unsigned {
    103     // Range check of the form "0 <= I".
    104     RANGE_CHECK_LOWER = 1,
    105 
    106     // Range check of the form "I < L" where L is known positive.
    107     RANGE_CHECK_UPPER = 2,
    108 
    109     // The logical and of the RANGE_CHECK_LOWER and RANGE_CHECK_UPPER
    110     // conditions.
    111     RANGE_CHECK_BOTH = RANGE_CHECK_LOWER | RANGE_CHECK_UPPER,
    112 
    113     // Unrecognized range check condition.
    114     RANGE_CHECK_UNKNOWN = (unsigned)-1
    115   };
    116 
    117   static const char *rangeCheckKindToStr(RangeCheckKind);
    118 
    119   const SCEV *Offset;
    120   const SCEV *Scale;
    121   Value *Length;
    122   BranchInst *Branch;
    123   RangeCheckKind Kind;
    124 
    125   static RangeCheckKind parseRangeCheckICmp(Loop *L, ICmpInst *ICI,
    126                                             ScalarEvolution &SE, Value *&Index,
    127                                             Value *&Length);
    128 
    129   static InductiveRangeCheck::RangeCheckKind
    130   parseRangeCheck(Loop *L, ScalarEvolution &SE, Value *Condition,
    131                   const SCEV *&Index, Value *&UpperLimit);
    132 
    133   InductiveRangeCheck() :
    134     Offset(nullptr), Scale(nullptr), Length(nullptr), Branch(nullptr) { }
    135 
    136 public:
    137   const SCEV *getOffset() const { return Offset; }
    138   const SCEV *getScale() const { return Scale; }
    139   Value *getLength() const { return Length; }
    140 
    141   void print(raw_ostream &OS) const {
    142     OS << "InductiveRangeCheck:\n";
    143     OS << "  Kind: " << rangeCheckKindToStr(Kind) << "\n";
    144     OS << "  Offset: ";
    145     Offset->print(OS);
    146     OS << "  Scale: ";
    147     Scale->print(OS);
    148     OS << "  Length: ";
    149     if (Length)
    150       Length->print(OS);
    151     else
    152       OS << "(null)";
    153     OS << "\n  Branch: ";
    154     getBranch()->print(OS);
    155     OS << "\n";
    156   }
    157 
    158 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
    159   void dump() {
    160     print(dbgs());
    161   }
    162 #endif
    163 
    164   BranchInst *getBranch() const { return Branch; }
    165 
    166   /// Represents an signed integer range [Range.getBegin(), Range.getEnd()).  If
    167   /// R.getEnd() sle R.getBegin(), then R denotes the empty range.
    168 
    169   class Range {
    170     const SCEV *Begin;
    171     const SCEV *End;
    172 
    173   public:
    174     Range(const SCEV *Begin, const SCEV *End) : Begin(Begin), End(End) {
    175       assert(Begin->getType() == End->getType() && "ill-typed range!");
    176     }
    177 
    178     Type *getType() const { return Begin->getType(); }
    179     const SCEV *getBegin() const { return Begin; }
    180     const SCEV *getEnd() const { return End; }
    181   };
    182 
    183   typedef SpecificBumpPtrAllocator<InductiveRangeCheck> AllocatorTy;
    184 
    185   /// This is the value the condition of the branch needs to evaluate to for the
    186   /// branch to take the hot successor (see (1) above).
    187   bool getPassingDirection() { return true; }
    188 
    189   /// Computes a range for the induction variable (IndVar) in which the range
    190   /// check is redundant and can be constant-folded away.  The induction
    191   /// variable is not required to be the canonical {0,+,1} induction variable.
    192   Optional<Range> computeSafeIterationSpace(ScalarEvolution &SE,
    193                                             const SCEVAddRecExpr *IndVar,
    194                                             IRBuilder<> &B) const;
    195 
    196   /// Create an inductive range check out of BI if possible, else return
    197   /// nullptr.
    198   static InductiveRangeCheck *create(AllocatorTy &Alloc, BranchInst *BI,
    199                                      Loop *L, ScalarEvolution &SE,
    200                                      BranchProbabilityInfo &BPI);
    201 };
    202 
    203 class InductiveRangeCheckElimination : public LoopPass {
    204   InductiveRangeCheck::AllocatorTy Allocator;
    205 
    206 public:
    207   static char ID;
    208   InductiveRangeCheckElimination() : LoopPass(ID) {
    209     initializeInductiveRangeCheckEliminationPass(
    210         *PassRegistry::getPassRegistry());
    211   }
    212 
    213   void getAnalysisUsage(AnalysisUsage &AU) const override {
    214     AU.addRequired<LoopInfoWrapperPass>();
    215     AU.addRequiredID(LoopSimplifyID);
    216     AU.addRequiredID(LCSSAID);
    217     AU.addRequired<ScalarEvolutionWrapperPass>();
    218     AU.addRequired<BranchProbabilityInfoWrapperPass>();
    219   }
    220 
    221   bool runOnLoop(Loop *L, LPPassManager &LPM) override;
    222 };
    223 
    224 char InductiveRangeCheckElimination::ID = 0;
    225 }
    226 
    227 INITIALIZE_PASS_BEGIN(InductiveRangeCheckElimination, "irce",
    228                       "Inductive range check elimination", false, false)
    229 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
    230 INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
    231 INITIALIZE_PASS_DEPENDENCY(LCSSA)
    232 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
    233 INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass)
    234 INITIALIZE_PASS_END(InductiveRangeCheckElimination, "irce",
    235                     "Inductive range check elimination", false, false)
    236 
    237 const char *InductiveRangeCheck::rangeCheckKindToStr(
    238     InductiveRangeCheck::RangeCheckKind RCK) {
    239   switch (RCK) {
    240   case InductiveRangeCheck::RANGE_CHECK_UNKNOWN:
    241     return "RANGE_CHECK_UNKNOWN";
    242 
    243   case InductiveRangeCheck::RANGE_CHECK_UPPER:
    244     return "RANGE_CHECK_UPPER";
    245 
    246   case InductiveRangeCheck::RANGE_CHECK_LOWER:
    247     return "RANGE_CHECK_LOWER";
    248 
    249   case InductiveRangeCheck::RANGE_CHECK_BOTH:
    250     return "RANGE_CHECK_BOTH";
    251   }
    252 
    253   llvm_unreachable("unknown range check type!");
    254 }
    255 
    256 /// Parse a single ICmp instruction, `ICI`, into a range check.  If `ICI`
    257 /// cannot
    258 /// be interpreted as a range check, return `RANGE_CHECK_UNKNOWN` and set
    259 /// `Index` and `Length` to `nullptr`.  Otherwise set `Index` to the value
    260 /// being
    261 /// range checked, and set `Length` to the upper limit `Index` is being range
    262 /// checked with if (and only if) the range check type is stronger or equal to
    263 /// RANGE_CHECK_UPPER.
    264 ///
    265 InductiveRangeCheck::RangeCheckKind
    266 InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI,
    267                                          ScalarEvolution &SE, Value *&Index,
    268                                          Value *&Length) {
    269 
    270   auto IsNonNegativeAndNotLoopVarying = [&SE, L](Value *V) {
    271     const SCEV *S = SE.getSCEV(V);
    272     if (isa<SCEVCouldNotCompute>(S))
    273       return false;
    274 
    275     return SE.getLoopDisposition(S, L) == ScalarEvolution::LoopInvariant &&
    276            SE.isKnownNonNegative(S);
    277   };
    278 
    279   using namespace llvm::PatternMatch;
    280 
    281   ICmpInst::Predicate Pred = ICI->getPredicate();
    282   Value *LHS = ICI->getOperand(0);
    283   Value *RHS = ICI->getOperand(1);
    284 
    285   switch (Pred) {
    286   default:
    287     return RANGE_CHECK_UNKNOWN;
    288 
    289   case ICmpInst::ICMP_SLE:
    290     std::swap(LHS, RHS);
    291   // fallthrough
    292   case ICmpInst::ICMP_SGE:
    293     if (match(RHS, m_ConstantInt<0>())) {
    294       Index = LHS;
    295       return RANGE_CHECK_LOWER;
    296     }
    297     return RANGE_CHECK_UNKNOWN;
    298 
    299   case ICmpInst::ICMP_SLT:
    300     std::swap(LHS, RHS);
    301   // fallthrough
    302   case ICmpInst::ICMP_SGT:
    303     if (match(RHS, m_ConstantInt<-1>())) {
    304       Index = LHS;
    305       return RANGE_CHECK_LOWER;
    306     }
    307 
    308     if (IsNonNegativeAndNotLoopVarying(LHS)) {
    309       Index = RHS;
    310       Length = LHS;
    311       return RANGE_CHECK_UPPER;
    312     }
    313     return RANGE_CHECK_UNKNOWN;
    314 
    315   case ICmpInst::ICMP_ULT:
    316     std::swap(LHS, RHS);
    317   // fallthrough
    318   case ICmpInst::ICMP_UGT:
    319     if (IsNonNegativeAndNotLoopVarying(LHS)) {
    320       Index = RHS;
    321       Length = LHS;
    322       return RANGE_CHECK_BOTH;
    323     }
    324     return RANGE_CHECK_UNKNOWN;
    325   }
    326 
    327   llvm_unreachable("default clause returns!");
    328 }
    329 
    330 /// Parses an arbitrary condition into a range check.  `Length` is set only if
    331 /// the range check is recognized to be `RANGE_CHECK_UPPER` or stronger.
    332 InductiveRangeCheck::RangeCheckKind
    333 InductiveRangeCheck::parseRangeCheck(Loop *L, ScalarEvolution &SE,
    334                                      Value *Condition, const SCEV *&Index,
    335                                      Value *&Length) {
    336   using namespace llvm::PatternMatch;
    337 
    338   Value *A = nullptr;
    339   Value *B = nullptr;
    340 
    341   if (match(Condition, m_And(m_Value(A), m_Value(B)))) {
    342     Value *IndexA = nullptr, *IndexB = nullptr;
    343     Value *LengthA = nullptr, *LengthB = nullptr;
    344     ICmpInst *ICmpA = dyn_cast<ICmpInst>(A), *ICmpB = dyn_cast<ICmpInst>(B);
    345 
    346     if (!ICmpA || !ICmpB)
    347       return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
    348 
    349     auto RCKindA = parseRangeCheckICmp(L, ICmpA, SE, IndexA, LengthA);
    350     auto RCKindB = parseRangeCheckICmp(L, ICmpB, SE, IndexB, LengthB);
    351 
    352     if (RCKindA == InductiveRangeCheck::RANGE_CHECK_UNKNOWN ||
    353         RCKindB == InductiveRangeCheck::RANGE_CHECK_UNKNOWN)
    354       return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
    355 
    356     if (IndexA != IndexB)
    357       return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
    358 
    359     if (LengthA != nullptr && LengthB != nullptr && LengthA != LengthB)
    360       return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
    361 
    362     Index = SE.getSCEV(IndexA);
    363     if (isa<SCEVCouldNotCompute>(Index))
    364       return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
    365 
    366     Length = LengthA == nullptr ? LengthB : LengthA;
    367 
    368     return (InductiveRangeCheck::RangeCheckKind)(RCKindA | RCKindB);
    369   }
    370 
    371   if (ICmpInst *ICI = dyn_cast<ICmpInst>(Condition)) {
    372     Value *IndexVal = nullptr;
    373 
    374     auto RCKind = parseRangeCheckICmp(L, ICI, SE, IndexVal, Length);
    375 
    376     if (RCKind == InductiveRangeCheck::RANGE_CHECK_UNKNOWN)
    377       return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
    378 
    379     Index = SE.getSCEV(IndexVal);
    380     if (isa<SCEVCouldNotCompute>(Index))
    381       return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
    382 
    383     return RCKind;
    384   }
    385 
    386   return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
    387 }
    388 
    389 
    390 InductiveRangeCheck *
    391 InductiveRangeCheck::create(InductiveRangeCheck::AllocatorTy &A, BranchInst *BI,
    392                             Loop *L, ScalarEvolution &SE,
    393                             BranchProbabilityInfo &BPI) {
    394 
    395   if (BI->isUnconditional() || BI->getParent() == L->getLoopLatch())
    396     return nullptr;
    397 
    398   BranchProbability LikelyTaken(15, 16);
    399 
    400   if (BPI.getEdgeProbability(BI->getParent(), (unsigned) 0) < LikelyTaken)
    401     return nullptr;
    402 
    403   Value *Length = nullptr;
    404   const SCEV *IndexSCEV = nullptr;
    405 
    406   auto RCKind = InductiveRangeCheck::parseRangeCheck(L, SE, BI->getCondition(),
    407                                                      IndexSCEV, Length);
    408 
    409   if (RCKind == InductiveRangeCheck::RANGE_CHECK_UNKNOWN)
    410     return nullptr;
    411 
    412   assert(IndexSCEV && "contract with SplitRangeCheckCondition!");
    413   assert((!(RCKind & InductiveRangeCheck::RANGE_CHECK_UPPER) || Length) &&
    414          "contract with SplitRangeCheckCondition!");
    415 
    416   const SCEVAddRecExpr *IndexAddRec = dyn_cast<SCEVAddRecExpr>(IndexSCEV);
    417   bool IsAffineIndex =
    418       IndexAddRec && (IndexAddRec->getLoop() == L) && IndexAddRec->isAffine();
    419 
    420   if (!IsAffineIndex)
    421     return nullptr;
    422 
    423   InductiveRangeCheck *IRC = new (A.Allocate()) InductiveRangeCheck;
    424   IRC->Length = Length;
    425   IRC->Offset = IndexAddRec->getStart();
    426   IRC->Scale = IndexAddRec->getStepRecurrence(SE);
    427   IRC->Branch = BI;
    428   IRC->Kind = RCKind;
    429   return IRC;
    430 }
    431 
    432 namespace {
    433 
    434 // Keeps track of the structure of a loop.  This is similar to llvm::Loop,
    435 // except that it is more lightweight and can track the state of a loop through
    436 // changing and potentially invalid IR.  This structure also formalizes the
    437 // kinds of loops we can deal with -- ones that have a single latch that is also
    438 // an exiting block *and* have a canonical induction variable.
    439 struct LoopStructure {
    440   const char *Tag;
    441 
    442   BasicBlock *Header;
    443   BasicBlock *Latch;
    444 
    445   // `Latch's terminator instruction is `LatchBr', and it's `LatchBrExitIdx'th
    446   // successor is `LatchExit', the exit block of the loop.
    447   BranchInst *LatchBr;
    448   BasicBlock *LatchExit;
    449   unsigned LatchBrExitIdx;
    450 
    451   Value *IndVarNext;
    452   Value *IndVarStart;
    453   Value *LoopExitAt;
    454   bool IndVarIncreasing;
    455 
    456   LoopStructure()
    457       : Tag(""), Header(nullptr), Latch(nullptr), LatchBr(nullptr),
    458         LatchExit(nullptr), LatchBrExitIdx(-1), IndVarNext(nullptr),
    459         IndVarStart(nullptr), LoopExitAt(nullptr), IndVarIncreasing(false) {}
    460 
    461   template <typename M> LoopStructure map(M Map) const {
    462     LoopStructure Result;
    463     Result.Tag = Tag;
    464     Result.Header = cast<BasicBlock>(Map(Header));
    465     Result.Latch = cast<BasicBlock>(Map(Latch));
    466     Result.LatchBr = cast<BranchInst>(Map(LatchBr));
    467     Result.LatchExit = cast<BasicBlock>(Map(LatchExit));
    468     Result.LatchBrExitIdx = LatchBrExitIdx;
    469     Result.IndVarNext = Map(IndVarNext);
    470     Result.IndVarStart = Map(IndVarStart);
    471     Result.LoopExitAt = Map(LoopExitAt);
    472     Result.IndVarIncreasing = IndVarIncreasing;
    473     return Result;
    474   }
    475 
    476   static Optional<LoopStructure> parseLoopStructure(ScalarEvolution &,
    477                                                     BranchProbabilityInfo &BPI,
    478                                                     Loop &,
    479                                                     const char *&);
    480 };
    481 
    482 /// This class is used to constrain loops to run within a given iteration space.
    483 /// The algorithm this class implements is given a Loop and a range [Begin,
    484 /// End).  The algorithm then tries to break out a "main loop" out of the loop
    485 /// it is given in a way that the "main loop" runs with the induction variable
    486 /// in a subset of [Begin, End).  The algorithm emits appropriate pre and post
    487 /// loops to run any remaining iterations.  The pre loop runs any iterations in
    488 /// which the induction variable is < Begin, and the post loop runs any
    489 /// iterations in which the induction variable is >= End.
    490 ///
    491 class LoopConstrainer {
    492   // The representation of a clone of the original loop we started out with.
    493   struct ClonedLoop {
    494     // The cloned blocks
    495     std::vector<BasicBlock *> Blocks;
    496 
    497     // `Map` maps values in the clonee into values in the cloned version
    498     ValueToValueMapTy Map;
    499 
    500     // An instance of `LoopStructure` for the cloned loop
    501     LoopStructure Structure;
    502   };
    503 
    504   // Result of rewriting the range of a loop.  See changeIterationSpaceEnd for
    505   // more details on what these fields mean.
    506   struct RewrittenRangeInfo {
    507     BasicBlock *PseudoExit;
    508     BasicBlock *ExitSelector;
    509     std::vector<PHINode *> PHIValuesAtPseudoExit;
    510     PHINode *IndVarEnd;
    511 
    512     RewrittenRangeInfo()
    513         : PseudoExit(nullptr), ExitSelector(nullptr), IndVarEnd(nullptr) {}
    514   };
    515 
    516   // Calculated subranges we restrict the iteration space of the main loop to.
    517   // See the implementation of `calculateSubRanges' for more details on how
    518   // these fields are computed.  `LowLimit` is None if there is no restriction
    519   // on low end of the restricted iteration space of the main loop.  `HighLimit`
    520   // is None if there is no restriction on high end of the restricted iteration
    521   // space of the main loop.
    522 
    523   struct SubRanges {
    524     Optional<const SCEV *> LowLimit;
    525     Optional<const SCEV *> HighLimit;
    526   };
    527 
    528   // A utility function that does a `replaceUsesOfWith' on the incoming block
    529   // set of a `PHINode' -- replaces instances of `Block' in the `PHINode's
    530   // incoming block list with `ReplaceBy'.
    531   static void replacePHIBlock(PHINode *PN, BasicBlock *Block,
    532                               BasicBlock *ReplaceBy);
    533 
    534   // Compute a safe set of limits for the main loop to run in -- effectively the
    535   // intersection of `Range' and the iteration space of the original loop.
    536   // Return None if unable to compute the set of subranges.
    537   //
    538   Optional<SubRanges> calculateSubRanges() const;
    539 
    540   // Clone `OriginalLoop' and return the result in CLResult.  The IR after
    541   // running `cloneLoop' is well formed except for the PHI nodes in CLResult --
    542   // the PHI nodes say that there is an incoming edge from `OriginalPreheader`
    543   // but there is no such edge.
    544   //
    545   void cloneLoop(ClonedLoop &CLResult, const char *Tag) const;
    546 
    547   // Rewrite the iteration space of the loop denoted by (LS, Preheader). The
    548   // iteration space of the rewritten loop ends at ExitLoopAt.  The start of the
    549   // iteration space is not changed.  `ExitLoopAt' is assumed to be slt
    550   // `OriginalHeaderCount'.
    551   //
    552   // If there are iterations left to execute, control is made to jump to
    553   // `ContinuationBlock', otherwise they take the normal loop exit.  The
    554   // returned `RewrittenRangeInfo' object is populated as follows:
    555   //
    556   //  .PseudoExit is a basic block that unconditionally branches to
    557   //      `ContinuationBlock'.
    558   //
    559   //  .ExitSelector is a basic block that decides, on exit from the loop,
    560   //      whether to branch to the "true" exit or to `PseudoExit'.
    561   //
    562   //  .PHIValuesAtPseudoExit are PHINodes in `PseudoExit' that compute the value
    563   //      for each PHINode in the loop header on taking the pseudo exit.
    564   //
    565   // After changeIterationSpaceEnd, `Preheader' is no longer a legitimate
    566   // preheader because it is made to branch to the loop header only
    567   // conditionally.
    568   //
    569   RewrittenRangeInfo
    570   changeIterationSpaceEnd(const LoopStructure &LS, BasicBlock *Preheader,
    571                           Value *ExitLoopAt,
    572                           BasicBlock *ContinuationBlock) const;
    573 
    574   // The loop denoted by `LS' has `OldPreheader' as its preheader.  This
    575   // function creates a new preheader for `LS' and returns it.
    576   //
    577   BasicBlock *createPreheader(const LoopStructure &LS, BasicBlock *OldPreheader,
    578                               const char *Tag) const;
    579 
    580   // `ContinuationBlockAndPreheader' was the continuation block for some call to
    581   // `changeIterationSpaceEnd' and is the preheader to the loop denoted by `LS'.
    582   // This function rewrites the PHI nodes in `LS.Header' to start with the
    583   // correct value.
    584   void rewriteIncomingValuesForPHIs(
    585       LoopStructure &LS, BasicBlock *ContinuationBlockAndPreheader,
    586       const LoopConstrainer::RewrittenRangeInfo &RRI) const;
    587 
    588   // Even though we do not preserve any passes at this time, we at least need to
    589   // keep the parent loop structure consistent.  The `LPPassManager' seems to
    590   // verify this after running a loop pass.  This function adds the list of
    591   // blocks denoted by BBs to this loops parent loop if required.
    592   void addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs);
    593 
    594   // Some global state.
    595   Function &F;
    596   LLVMContext &Ctx;
    597   ScalarEvolution &SE;
    598 
    599   // Information about the original loop we started out with.
    600   Loop &OriginalLoop;
    601   LoopInfo &OriginalLoopInfo;
    602   const SCEV *LatchTakenCount;
    603   BasicBlock *OriginalPreheader;
    604 
    605   // The preheader of the main loop.  This may or may not be different from
    606   // `OriginalPreheader'.
    607   BasicBlock *MainLoopPreheader;
    608 
    609   // The range we need to run the main loop in.
    610   InductiveRangeCheck::Range Range;
    611 
    612   // The structure of the main loop (see comment at the beginning of this class
    613   // for a definition)
    614   LoopStructure MainLoopStructure;
    615 
    616 public:
    617   LoopConstrainer(Loop &L, LoopInfo &LI, const LoopStructure &LS,
    618                   ScalarEvolution &SE, InductiveRangeCheck::Range R)
    619       : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()),
    620         SE(SE), OriginalLoop(L), OriginalLoopInfo(LI), LatchTakenCount(nullptr),
    621         OriginalPreheader(nullptr), MainLoopPreheader(nullptr), Range(R),
    622         MainLoopStructure(LS) {}
    623 
    624   // Entry point for the algorithm.  Returns true on success.
    625   bool run();
    626 };
    627 
    628 }
    629 
    630 void LoopConstrainer::replacePHIBlock(PHINode *PN, BasicBlock *Block,
    631                                       BasicBlock *ReplaceBy) {
    632   for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
    633     if (PN->getIncomingBlock(i) == Block)
    634       PN->setIncomingBlock(i, ReplaceBy);
    635 }
    636 
    637 static bool CanBeSMax(ScalarEvolution &SE, const SCEV *S) {
    638   APInt SMax =
    639       APInt::getSignedMaxValue(cast<IntegerType>(S->getType())->getBitWidth());
    640   return SE.getSignedRange(S).contains(SMax) &&
    641          SE.getUnsignedRange(S).contains(SMax);
    642 }
    643 
    644 static bool CanBeSMin(ScalarEvolution &SE, const SCEV *S) {
    645   APInt SMin =
    646       APInt::getSignedMinValue(cast<IntegerType>(S->getType())->getBitWidth());
    647   return SE.getSignedRange(S).contains(SMin) &&
    648          SE.getUnsignedRange(S).contains(SMin);
    649 }
    650 
    651 Optional<LoopStructure>
    652 LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BPI,
    653                                   Loop &L, const char *&FailureReason) {
    654   assert(L.isLoopSimplifyForm() && "should follow from addRequired<>");
    655 
    656   BasicBlock *Latch = L.getLoopLatch();
    657   if (!L.isLoopExiting(Latch)) {
    658     FailureReason = "no loop latch";
    659     return None;
    660   }
    661 
    662   BasicBlock *Header = L.getHeader();
    663   BasicBlock *Preheader = L.getLoopPreheader();
    664   if (!Preheader) {
    665     FailureReason = "no preheader";
    666     return None;
    667   }
    668 
    669   BranchInst *LatchBr = dyn_cast<BranchInst>(&*Latch->rbegin());
    670   if (!LatchBr || LatchBr->isUnconditional()) {
    671     FailureReason = "latch terminator not conditional branch";
    672     return None;
    673   }
    674 
    675   unsigned LatchBrExitIdx = LatchBr->getSuccessor(0) == Header ? 1 : 0;
    676 
    677   BranchProbability ExitProbability =
    678     BPI.getEdgeProbability(LatchBr->getParent(), LatchBrExitIdx);
    679 
    680   if (ExitProbability > BranchProbability(1, MaxExitProbReciprocal)) {
    681     FailureReason = "short running loop, not profitable";
    682     return None;
    683   }
    684 
    685   ICmpInst *ICI = dyn_cast<ICmpInst>(LatchBr->getCondition());
    686   if (!ICI || !isa<IntegerType>(ICI->getOperand(0)->getType())) {
    687     FailureReason = "latch terminator branch not conditional on integral icmp";
    688     return None;
    689   }
    690 
    691   const SCEV *LatchCount = SE.getExitCount(&L, Latch);
    692   if (isa<SCEVCouldNotCompute>(LatchCount)) {
    693     FailureReason = "could not compute latch count";
    694     return None;
    695   }
    696 
    697   ICmpInst::Predicate Pred = ICI->getPredicate();
    698   Value *LeftValue = ICI->getOperand(0);
    699   const SCEV *LeftSCEV = SE.getSCEV(LeftValue);
    700   IntegerType *IndVarTy = cast<IntegerType>(LeftValue->getType());
    701 
    702   Value *RightValue = ICI->getOperand(1);
    703   const SCEV *RightSCEV = SE.getSCEV(RightValue);
    704 
    705   // We canonicalize `ICI` such that `LeftSCEV` is an add recurrence.
    706   if (!isa<SCEVAddRecExpr>(LeftSCEV)) {
    707     if (isa<SCEVAddRecExpr>(RightSCEV)) {
    708       std::swap(LeftSCEV, RightSCEV);
    709       std::swap(LeftValue, RightValue);
    710       Pred = ICmpInst::getSwappedPredicate(Pred);
    711     } else {
    712       FailureReason = "no add recurrences in the icmp";
    713       return None;
    714     }
    715   }
    716 
    717   auto HasNoSignedWrap = [&](const SCEVAddRecExpr *AR) {
    718     if (AR->getNoWrapFlags(SCEV::FlagNSW))
    719       return true;
    720 
    721     IntegerType *Ty = cast<IntegerType>(AR->getType());
    722     IntegerType *WideTy =
    723         IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2);
    724 
    725     const SCEVAddRecExpr *ExtendAfterOp =
    726         dyn_cast<SCEVAddRecExpr>(SE.getSignExtendExpr(AR, WideTy));
    727     if (ExtendAfterOp) {
    728       const SCEV *ExtendedStart = SE.getSignExtendExpr(AR->getStart(), WideTy);
    729       const SCEV *ExtendedStep =
    730           SE.getSignExtendExpr(AR->getStepRecurrence(SE), WideTy);
    731 
    732       bool NoSignedWrap = ExtendAfterOp->getStart() == ExtendedStart &&
    733                           ExtendAfterOp->getStepRecurrence(SE) == ExtendedStep;
    734 
    735       if (NoSignedWrap)
    736         return true;
    737     }
    738 
    739     // We may have proved this when computing the sign extension above.
    740     return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap;
    741   };
    742 
    743   auto IsInductionVar = [&](const SCEVAddRecExpr *AR, bool &IsIncreasing) {
    744     if (!AR->isAffine())
    745       return false;
    746 
    747     // Currently we only work with induction variables that have been proved to
    748     // not wrap.  This restriction can potentially be lifted in the future.
    749 
    750     if (!HasNoSignedWrap(AR))
    751       return false;
    752 
    753     if (const SCEVConstant *StepExpr =
    754             dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE))) {
    755       ConstantInt *StepCI = StepExpr->getValue();
    756       if (StepCI->isOne() || StepCI->isMinusOne()) {
    757         IsIncreasing = StepCI->isOne();
    758         return true;
    759       }
    760     }
    761 
    762     return false;
    763   };
    764 
    765   // `ICI` is interpreted as taking the backedge if the *next* value of the
    766   // induction variable satisfies some constraint.
    767 
    768   const SCEVAddRecExpr *IndVarNext = cast<SCEVAddRecExpr>(LeftSCEV);
    769   bool IsIncreasing = false;
    770   if (!IsInductionVar(IndVarNext, IsIncreasing)) {
    771     FailureReason = "LHS in icmp not induction variable";
    772     return None;
    773   }
    774 
    775   ConstantInt *One = ConstantInt::get(IndVarTy, 1);
    776   // TODO: generalize the predicates here to also match their unsigned variants.
    777   if (IsIncreasing) {
    778     bool FoundExpectedPred =
    779         (Pred == ICmpInst::ICMP_SLT && LatchBrExitIdx == 1) ||
    780         (Pred == ICmpInst::ICMP_SGT && LatchBrExitIdx == 0);
    781 
    782     if (!FoundExpectedPred) {
    783       FailureReason = "expected icmp slt semantically, found something else";
    784       return None;
    785     }
    786 
    787     if (LatchBrExitIdx == 0) {
    788       if (CanBeSMax(SE, RightSCEV)) {
    789         // TODO: this restriction is easily removable -- we just have to
    790         // remember that the icmp was an slt and not an sle.
    791         FailureReason = "limit may overflow when coercing sle to slt";
    792         return None;
    793       }
    794 
    795       IRBuilder<> B(&*Preheader->rbegin());
    796       RightValue = B.CreateAdd(RightValue, One);
    797     }
    798 
    799   } else {
    800     bool FoundExpectedPred =
    801         (Pred == ICmpInst::ICMP_SGT && LatchBrExitIdx == 1) ||
    802         (Pred == ICmpInst::ICMP_SLT && LatchBrExitIdx == 0);
    803 
    804     if (!FoundExpectedPred) {
    805       FailureReason = "expected icmp sgt semantically, found something else";
    806       return None;
    807     }
    808 
    809     if (LatchBrExitIdx == 0) {
    810       if (CanBeSMin(SE, RightSCEV)) {
    811         // TODO: this restriction is easily removable -- we just have to
    812         // remember that the icmp was an sgt and not an sge.
    813         FailureReason = "limit may overflow when coercing sge to sgt";
    814         return None;
    815       }
    816 
    817       IRBuilder<> B(&*Preheader->rbegin());
    818       RightValue = B.CreateSub(RightValue, One);
    819     }
    820   }
    821 
    822   const SCEV *StartNext = IndVarNext->getStart();
    823   const SCEV *Addend = SE.getNegativeSCEV(IndVarNext->getStepRecurrence(SE));
    824   const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend);
    825 
    826   BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx);
    827 
    828   assert(SE.getLoopDisposition(LatchCount, &L) ==
    829              ScalarEvolution::LoopInvariant &&
    830          "loop variant exit count doesn't make sense!");
    831 
    832   assert(!L.contains(LatchExit) && "expected an exit block!");
    833   const DataLayout &DL = Preheader->getModule()->getDataLayout();
    834   Value *IndVarStartV =
    835       SCEVExpander(SE, DL, "irce")
    836           .expandCodeFor(IndVarStart, IndVarTy, &*Preheader->rbegin());
    837   IndVarStartV->setName("indvar.start");
    838 
    839   LoopStructure Result;
    840 
    841   Result.Tag = "main";
    842   Result.Header = Header;
    843   Result.Latch = Latch;
    844   Result.LatchBr = LatchBr;
    845   Result.LatchExit = LatchExit;
    846   Result.LatchBrExitIdx = LatchBrExitIdx;
    847   Result.IndVarStart = IndVarStartV;
    848   Result.IndVarNext = LeftValue;
    849   Result.IndVarIncreasing = IsIncreasing;
    850   Result.LoopExitAt = RightValue;
    851 
    852   FailureReason = nullptr;
    853 
    854   return Result;
    855 }
    856 
    857 Optional<LoopConstrainer::SubRanges>
    858 LoopConstrainer::calculateSubRanges() const {
    859   IntegerType *Ty = cast<IntegerType>(LatchTakenCount->getType());
    860 
    861   if (Range.getType() != Ty)
    862     return None;
    863 
    864   LoopConstrainer::SubRanges Result;
    865 
    866   // I think we can be more aggressive here and make this nuw / nsw if the
    867   // addition that feeds into the icmp for the latch's terminating branch is nuw
    868   // / nsw.  In any case, a wrapping 2's complement addition is safe.
    869   ConstantInt *One = ConstantInt::get(Ty, 1);
    870   const SCEV *Start = SE.getSCEV(MainLoopStructure.IndVarStart);
    871   const SCEV *End = SE.getSCEV(MainLoopStructure.LoopExitAt);
    872 
    873   bool Increasing = MainLoopStructure.IndVarIncreasing;
    874 
    875   // We compute `Smallest` and `Greatest` such that [Smallest, Greatest) is the
    876   // range of values the induction variable takes.
    877 
    878   const SCEV *Smallest = nullptr, *Greatest = nullptr;
    879 
    880   if (Increasing) {
    881     Smallest = Start;
    882     Greatest = End;
    883   } else {
    884     // These two computations may sign-overflow.  Here is why that is okay:
    885     //
    886     // We know that the induction variable does not sign-overflow on any
    887     // iteration except the last one, and it starts at `Start` and ends at
    888     // `End`, decrementing by one every time.
    889     //
    890     //  * if `Smallest` sign-overflows we know `End` is `INT_SMAX`. Since the
    891     //    induction variable is decreasing we know that that the smallest value
    892     //    the loop body is actually executed with is `INT_SMIN` == `Smallest`.
    893     //
    894     //  * if `Greatest` sign-overflows, we know it can only be `INT_SMIN`.  In
    895     //    that case, `Clamp` will always return `Smallest` and
    896     //    [`Result.LowLimit`, `Result.HighLimit`) = [`Smallest`, `Smallest`)
    897     //    will be an empty range.  Returning an empty range is always safe.
    898     //
    899 
    900     Smallest = SE.getAddExpr(End, SE.getSCEV(One));
    901     Greatest = SE.getAddExpr(Start, SE.getSCEV(One));
    902   }
    903 
    904   auto Clamp = [this, Smallest, Greatest](const SCEV *S) {
    905     return SE.getSMaxExpr(Smallest, SE.getSMinExpr(Greatest, S));
    906   };
    907 
    908   // In some cases we can prove that we don't need a pre or post loop
    909 
    910   bool ProvablyNoPreloop =
    911       SE.isKnownPredicate(ICmpInst::ICMP_SLE, Range.getBegin(), Smallest);
    912   if (!ProvablyNoPreloop)
    913     Result.LowLimit = Clamp(Range.getBegin());
    914 
    915   bool ProvablyNoPostLoop =
    916       SE.isKnownPredicate(ICmpInst::ICMP_SLE, Greatest, Range.getEnd());
    917   if (!ProvablyNoPostLoop)
    918     Result.HighLimit = Clamp(Range.getEnd());
    919 
    920   return Result;
    921 }
    922 
    923 void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result,
    924                                 const char *Tag) const {
    925   for (BasicBlock *BB : OriginalLoop.getBlocks()) {
    926     BasicBlock *Clone = CloneBasicBlock(BB, Result.Map, Twine(".") + Tag, &F);
    927     Result.Blocks.push_back(Clone);
    928     Result.Map[BB] = Clone;
    929   }
    930 
    931   auto GetClonedValue = [&Result](Value *V) {
    932     assert(V && "null values not in domain!");
    933     auto It = Result.Map.find(V);
    934     if (It == Result.Map.end())
    935       return V;
    936     return static_cast<Value *>(It->second);
    937   };
    938 
    939   Result.Structure = MainLoopStructure.map(GetClonedValue);
    940   Result.Structure.Tag = Tag;
    941 
    942   for (unsigned i = 0, e = Result.Blocks.size(); i != e; ++i) {
    943     BasicBlock *ClonedBB = Result.Blocks[i];
    944     BasicBlock *OriginalBB = OriginalLoop.getBlocks()[i];
    945 
    946     assert(Result.Map[OriginalBB] == ClonedBB && "invariant!");
    947 
    948     for (Instruction &I : *ClonedBB)
    949       RemapInstruction(&I, Result.Map,
    950                        RF_NoModuleLevelChanges | RF_IgnoreMissingEntries);
    951 
    952     // Exit blocks will now have one more predecessor and their PHI nodes need
    953     // to be edited to reflect that.  No phi nodes need to be introduced because
    954     // the loop is in LCSSA.
    955 
    956     for (auto SBBI = succ_begin(OriginalBB), SBBE = succ_end(OriginalBB);
    957          SBBI != SBBE; ++SBBI) {
    958 
    959       if (OriginalLoop.contains(*SBBI))
    960         continue; // not an exit block
    961 
    962       for (Instruction &I : **SBBI) {
    963         if (!isa<PHINode>(&I))
    964           break;
    965 
    966         PHINode *PN = cast<PHINode>(&I);
    967         Value *OldIncoming = PN->getIncomingValueForBlock(OriginalBB);
    968         PN->addIncoming(GetClonedValue(OldIncoming), ClonedBB);
    969       }
    970     }
    971   }
    972 }
    973 
/// Rewrites the loop described by \p LS so that it stops iterating once its
/// induction variable reaches \p ExitSubloopAt (signed slt when
/// LS.IndVarIncreasing, sgt otherwise) and then continues at
/// \p ContinuationBlock through a new ".pseudo.exit" block.  If the original
/// bound (LS.LoopExitAt) has also been reached, a new ".exit.selector" block
/// forwards control to the original latch exit instead.
///
/// The terminator of \p Preheader is replaced by a guard that skips the loop
/// entirely (jumping straight to the pseudo exit) when LS.IndVarStart is
/// already past \p ExitSubloopAt.
///
/// Returns the two new blocks, PHIs in the pseudo exit that carry the latest
/// value of each header PHI, and the PHI `indvar.end` carrying the latest
/// induction variable value.
LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd(
    const LoopStructure &LS, BasicBlock *Preheader, Value *ExitSubloopAt,
    BasicBlock *ContinuationBlock) const {

  // We start with a loop with a single latch:
  //
  //    +--------------------+
  //    |                    |
  //    |     preheader      |
  //    |                    |
  //    +--------+-----------+
  //             |      ----------------\
  //             |     /                |
  //    +--------v----v------+          |
  //    |                    |          |
  //    |      header        |          |
  //    |                    |          |
  //    +--------------------+          |
  //                                    |
  //            .....                   |
  //                                    |
  //    +--------------------+          |
  //    |                    |          |
  //    |       latch        >----------/
  //    |                    |
  //    +-------v------------+
  //            |
  //            |
  //            |   +--------------------+
  //            |   |                    |
  //            +--->   original exit    |
  //                |                    |
  //                +--------------------+
  //
  // We change the control flow to look like
  //
  //
  //    +--------------------+
  //    |                    |
  //    |     preheader      >-------------------------+
  //    |                    |                         |
  //    +--------v-----------+                         |
  //             |    /-------------+                  |
  //             |   /              |                  |
  //    +--------v--v--------+      |                  |
  //    |                    |      |                  |
  //    |      header        |      |   +--------+     |
  //    |                    |      |   |        |     |
  //    +--------------------+      |   |  +-----v-----v-----------+
  //                                |   |  |                       |
  //                                |   |  |     .pseudo.exit      |
  //                                |   |  |                       |
  //                                |   |  +-----------v-----------+
  //                                |   |              |
  //            .....               |   |              |
  //                                |   |     +--------v-------------+
  //    +--------------------+      |   |     |                      |
  //    |                    |      |   |     |   ContinuationBlock  |
  //    |       latch        >------+   |     |                      |
  //    |                    |          |     +----------------------+
  //    +---------v----------+          |
  //              |                     |
  //              |                     |
  //              |     +---------------^-----+
  //              |     |                     |
  //              +----->    .exit.selector   |
  //                    |                     |
  //                    +----------v----------+
  //                               |
  //     +--------------------+    |
  //     |                    |    |
  //     |   original exit    <----+
  //     |                    |
  //     +--------------------+
  //

  RewrittenRangeInfo RRI;

  // Both new blocks are placed directly after the latch in the function's
  // block list.
  auto BBInsertLocation = std::next(Function::iterator(LS.Latch));
  RRI.ExitSelector = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".exit.selector",
                                        &F, &*BBInsertLocation);
  RRI.PseudoExit = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".pseudo.exit", &F,
                                      &*BBInsertLocation);

  // The preheader is expected to end in a BranchInst; it is replaced below.
  BranchInst *PreheaderJump = cast<BranchInst>(&*Preheader->rbegin());
  bool Increasing = LS.IndVarIncreasing;

  IRBuilder<> B(PreheaderJump);

  // EnterLoopCond - is it okay to start executing this `LS'?
  Value *EnterLoopCond = Increasing
                             ? B.CreateICmpSLT(LS.IndVarStart, ExitSubloopAt)
                             : B.CreateICmpSGT(LS.IndVarStart, ExitSubloopAt);

  B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit);
  PreheaderJump->eraseFromParent();

  // The latch now exits to the selector, and its backedge condition tests the
  // next IV value against ExitSubloopAt instead of the original bound.
  LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector);
  B.SetInsertPoint(LS.LatchBr);
  Value *TakeBackedgeLoopCond =
      Increasing ? B.CreateICmpSLT(LS.IndVarNext, ExitSubloopAt)
                 : B.CreateICmpSGT(LS.IndVarNext, ExitSubloopAt);
  // A true condition selects successor 0: when the exit is successor 0 the
  // "take the backedge" condition must be inverted.
  Value *CondForBranch = LS.LatchBrExitIdx == 1
                             ? TakeBackedgeLoopCond
                             : B.CreateNot(TakeBackedgeLoopCond);

  LS.LatchBr->setCondition(CondForBranch);

  B.SetInsertPoint(RRI.ExitSelector);

  // IterationsLeft - are there any more iterations left, given the original
  // upper bound on the induction variable?  If not, we branch to the "real"
  // exit.
  Value *IterationsLeft = Increasing
                              ? B.CreateICmpSLT(LS.IndVarNext, LS.LoopExitAt)
                              : B.CreateICmpSGT(LS.IndVarNext, LS.LoopExitAt);
  B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit);

  BranchInst *BranchToContinuation =
      BranchInst::Create(ContinuationBlock, RRI.PseudoExit);

  // We emit PHI nodes into `RRI.PseudoExit' that compute the "latest" value of
  // each of the PHI nodes in the loop header.  This feeds into the initial
  // value of the same PHI nodes if/when we continue execution.
  // The pseudo exit has exactly two predecessors: Preheader (loop skipped) and
  // the exit selector (loop ran at least once).
  for (Instruction &I : *LS.Header) {
    if (!isa<PHINode>(&I))
      break;

    PHINode *PN = cast<PHINode>(&I);

    PHINode *NewPHI = PHINode::Create(PN->getType(), 2, PN->getName() + ".copy",
                                      BranchToContinuation);

    NewPHI->addIncoming(PN->getIncomingValueForBlock(Preheader), Preheader);
    NewPHI->addIncoming(PN->getIncomingValueForBlock(LS.Latch),
                        RRI.ExitSelector);
    RRI.PHIValuesAtPseudoExit.push_back(NewPHI);
  }

  // Same idea for the induction variable itself: its start value if the loop
  // was skipped, otherwise its last "next" value.
  RRI.IndVarEnd = PHINode::Create(LS.IndVarNext->getType(), 2, "indvar.end",
                                  BranchToContinuation);
  RRI.IndVarEnd->addIncoming(LS.IndVarStart, Preheader);
  RRI.IndVarEnd->addIncoming(LS.IndVarNext, RRI.ExitSelector);

  // The latch exit now has a branch from `RRI.ExitSelector' instead of
  // `LS.Latch'.  The PHI nodes need to be updated to reflect that.
  for (Instruction &I : *LS.LatchExit) {
    if (PHINode *PN = dyn_cast<PHINode>(&I))
      replacePHIBlock(PN, LS.Latch, RRI.ExitSelector);
    else
      break;
  }

  return RRI;
}
   1129 
   1130 void LoopConstrainer::rewriteIncomingValuesForPHIs(
   1131     LoopStructure &LS, BasicBlock *ContinuationBlock,
   1132     const LoopConstrainer::RewrittenRangeInfo &RRI) const {
   1133 
   1134   unsigned PHIIndex = 0;
   1135   for (Instruction &I : *LS.Header) {
   1136     if (!isa<PHINode>(&I))
   1137       break;
   1138 
   1139     PHINode *PN = cast<PHINode>(&I);
   1140 
   1141     for (unsigned i = 0, e = PN->getNumIncomingValues(); i < e; ++i)
   1142       if (PN->getIncomingBlock(i) == ContinuationBlock)
   1143         PN->setIncomingValue(i, RRI.PHIValuesAtPseudoExit[PHIIndex++]);
   1144   }
   1145 
   1146   LS.IndVarStart = RRI.IndVarEnd;
   1147 }
   1148 
   1149 BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS,
   1150                                              BasicBlock *OldPreheader,
   1151                                              const char *Tag) const {
   1152 
   1153   BasicBlock *Preheader = BasicBlock::Create(Ctx, Tag, &F, LS.Header);
   1154   BranchInst::Create(LS.Header, Preheader);
   1155 
   1156   for (Instruction &I : *LS.Header) {
   1157     if (!isa<PHINode>(&I))
   1158       break;
   1159 
   1160     PHINode *PN = cast<PHINode>(&I);
   1161     for (unsigned i = 0, e = PN->getNumIncomingValues(); i < e; ++i)
   1162       replacePHIBlock(PN, OldPreheader, Preheader);
   1163   }
   1164 
   1165   return Preheader;
   1166 }
   1167 
   1168 void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs) {
   1169   Loop *ParentLoop = OriginalLoop.getParentLoop();
   1170   if (!ParentLoop)
   1171     return;
   1172 
   1173   for (BasicBlock *BB : BBs)
   1174     ParentLoop->addBasicBlockToLoop(BB, OriginalLoopInfo);
   1175 }
   1176 
   1177 bool LoopConstrainer::run() {
   1178   BasicBlock *Preheader = nullptr;
   1179   LatchTakenCount = SE.getExitCount(&OriginalLoop, MainLoopStructure.Latch);
   1180   Preheader = OriginalLoop.getLoopPreheader();
   1181   assert(!isa<SCEVCouldNotCompute>(LatchTakenCount) && Preheader != nullptr &&
   1182          "preconditions!");
   1183 
   1184   OriginalPreheader = Preheader;
   1185   MainLoopPreheader = Preheader;
   1186 
   1187   Optional<SubRanges> MaybeSR = calculateSubRanges();
   1188   if (!MaybeSR.hasValue()) {
   1189     DEBUG(dbgs() << "irce: could not compute subranges\n");
   1190     return false;
   1191   }
   1192 
   1193   SubRanges SR = MaybeSR.getValue();
   1194   bool Increasing = MainLoopStructure.IndVarIncreasing;
   1195   IntegerType *IVTy =
   1196       cast<IntegerType>(MainLoopStructure.IndVarNext->getType());
   1197 
   1198   SCEVExpander Expander(SE, F.getParent()->getDataLayout(), "irce");
   1199   Instruction *InsertPt = OriginalPreheader->getTerminator();
   1200 
   1201   // It would have been better to make `PreLoop' and `PostLoop'
   1202   // `Optional<ClonedLoop>'s, but `ValueToValueMapTy' does not have a copy
   1203   // constructor.
   1204   ClonedLoop PreLoop, PostLoop;
   1205   bool NeedsPreLoop =
   1206       Increasing ? SR.LowLimit.hasValue() : SR.HighLimit.hasValue();
   1207   bool NeedsPostLoop =
   1208       Increasing ? SR.HighLimit.hasValue() : SR.LowLimit.hasValue();
   1209 
   1210   Value *ExitPreLoopAt = nullptr;
   1211   Value *ExitMainLoopAt = nullptr;
   1212   const SCEVConstant *MinusOneS =
   1213       cast<SCEVConstant>(SE.getConstant(IVTy, -1, true /* isSigned */));
   1214 
   1215   if (NeedsPreLoop) {
   1216     const SCEV *ExitPreLoopAtSCEV = nullptr;
   1217 
   1218     if (Increasing)
   1219       ExitPreLoopAtSCEV = *SR.LowLimit;
   1220     else {
   1221       if (CanBeSMin(SE, *SR.HighLimit)) {
   1222         DEBUG(dbgs() << "irce: could not prove no-overflow when computing "
   1223                      << "preloop exit limit.  HighLimit = " << *(*SR.HighLimit)
   1224                      << "\n");
   1225         return false;
   1226       }
   1227       ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS);
   1228     }
   1229 
   1230     ExitPreLoopAt = Expander.expandCodeFor(ExitPreLoopAtSCEV, IVTy, InsertPt);
   1231     ExitPreLoopAt->setName("exit.preloop.at");
   1232   }
   1233 
   1234   if (NeedsPostLoop) {
   1235     const SCEV *ExitMainLoopAtSCEV = nullptr;
   1236 
   1237     if (Increasing)
   1238       ExitMainLoopAtSCEV = *SR.HighLimit;
   1239     else {
   1240       if (CanBeSMin(SE, *SR.LowLimit)) {
   1241         DEBUG(dbgs() << "irce: could not prove no-overflow when computing "
   1242                      << "mainloop exit limit.  LowLimit = " << *(*SR.LowLimit)
   1243                      << "\n");
   1244         return false;
   1245       }
   1246       ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS);
   1247     }
   1248 
   1249     ExitMainLoopAt = Expander.expandCodeFor(ExitMainLoopAtSCEV, IVTy, InsertPt);
   1250     ExitMainLoopAt->setName("exit.mainloop.at");
   1251   }
   1252 
   1253   // We clone these ahead of time so that we don't have to deal with changing
   1254   // and temporarily invalid IR as we transform the loops.
   1255   if (NeedsPreLoop)
   1256     cloneLoop(PreLoop, "preloop");
   1257   if (NeedsPostLoop)
   1258     cloneLoop(PostLoop, "postloop");
   1259 
   1260   RewrittenRangeInfo PreLoopRRI;
   1261 
   1262   if (NeedsPreLoop) {
   1263     Preheader->getTerminator()->replaceUsesOfWith(MainLoopStructure.Header,
   1264                                                   PreLoop.Structure.Header);
   1265 
   1266     MainLoopPreheader =
   1267         createPreheader(MainLoopStructure, Preheader, "mainloop");
   1268     PreLoopRRI = changeIterationSpaceEnd(PreLoop.Structure, Preheader,
   1269                                          ExitPreLoopAt, MainLoopPreheader);
   1270     rewriteIncomingValuesForPHIs(MainLoopStructure, MainLoopPreheader,
   1271                                  PreLoopRRI);
   1272   }
   1273 
   1274   BasicBlock *PostLoopPreheader = nullptr;
   1275   RewrittenRangeInfo PostLoopRRI;
   1276 
   1277   if (NeedsPostLoop) {
   1278     PostLoopPreheader =
   1279         createPreheader(PostLoop.Structure, Preheader, "postloop");
   1280     PostLoopRRI = changeIterationSpaceEnd(MainLoopStructure, MainLoopPreheader,
   1281                                           ExitMainLoopAt, PostLoopPreheader);
   1282     rewriteIncomingValuesForPHIs(PostLoop.Structure, PostLoopPreheader,
   1283                                  PostLoopRRI);
   1284   }
   1285 
   1286   BasicBlock *NewMainLoopPreheader =
   1287       MainLoopPreheader != Preheader ? MainLoopPreheader : nullptr;
   1288   BasicBlock *NewBlocks[] = {PostLoopPreheader,        PreLoopRRI.PseudoExit,
   1289                              PreLoopRRI.ExitSelector,  PostLoopRRI.PseudoExit,
   1290                              PostLoopRRI.ExitSelector, NewMainLoopPreheader};
   1291 
   1292   // Some of the above may be nullptr, filter them out before passing to
   1293   // addToParentLoopIfNeeded.
   1294   auto NewBlocksEnd =
   1295       std::remove(std::begin(NewBlocks), std::end(NewBlocks), nullptr);
   1296 
   1297   addToParentLoopIfNeeded(makeArrayRef(std::begin(NewBlocks), NewBlocksEnd));
   1298   addToParentLoopIfNeeded(PreLoop.Blocks);
   1299   addToParentLoopIfNeeded(PostLoop.Blocks);
   1300 
   1301   return true;
   1302 }
   1303 
   1304 /// Computes and returns a range of values for the induction variable (IndVar)
   1305 /// in which the range check can be safely elided.  If it cannot compute such a
   1306 /// range, returns None.
   1307 Optional<InductiveRangeCheck::Range>
   1308 InductiveRangeCheck::computeSafeIterationSpace(ScalarEvolution &SE,
   1309                                                const SCEVAddRecExpr *IndVar,
   1310                                                IRBuilder<> &) const {
   1311   // IndVar is of the form "A + B * I" (where "I" is the canonical induction
   1312   // variable, that may or may not exist as a real llvm::Value in the loop) and
   1313   // this inductive range check is a range check on the "C + D * I" ("C" is
   1314   // getOffset() and "D" is getScale()).  We rewrite the value being range
   1315   // checked to "M + N * IndVar" where "N" = "D * B^(-1)" and "M" = "C - NA".
   1316   // Currently we support this only for "B" = "D" = { 1 or -1 }, but the code
   1317   // can be generalized as needed.
   1318   //
   1319   // The actual inequalities we solve are of the form
   1320   //
   1321   //   0 <= M + 1 * IndVar < L given L >= 0  (i.e. N == 1)
   1322   //
   1323   // The inequality is satisfied by -M <= IndVar < (L - M) [^1].  All additions
   1324   // and subtractions are twos-complement wrapping and comparisons are signed.
   1325   //
   1326   // Proof:
   1327   //
   1328   //   If there exists IndVar such that -M <= IndVar < (L - M) then it follows
   1329   //   that -M <= (-M + L) [== Eq. 1].  Since L >= 0, if (-M + L) sign-overflows
   1330   //   then (-M + L) < (-M).  Hence by [Eq. 1], (-M + L) could not have
   1331   //   overflown.
   1332   //
   1333   //   This means IndVar = t + (-M) for t in [0, L).  Hence (IndVar + M) = t.
   1334   //   Hence 0 <= (IndVar + M) < L
   1335 
   1336   // [^1]: Note that the solution does _not_ apply if L < 0; consider values M =
   1337   // 127, IndVar = 126 and L = -2 in an i8 world.
   1338 
   1339   if (!IndVar->isAffine())
   1340     return None;
   1341 
   1342   const SCEV *A = IndVar->getStart();
   1343   const SCEVConstant *B = dyn_cast<SCEVConstant>(IndVar->getStepRecurrence(SE));
   1344   if (!B)
   1345     return None;
   1346 
   1347   const SCEV *C = getOffset();
   1348   const SCEVConstant *D = dyn_cast<SCEVConstant>(getScale());
   1349   if (D != B)
   1350     return None;
   1351 
   1352   ConstantInt *ConstD = D->getValue();
   1353   if (!(ConstD->isMinusOne() || ConstD->isOne()))
   1354     return None;
   1355 
   1356   const SCEV *M = SE.getMinusSCEV(C, A);
   1357 
   1358   const SCEV *Begin = SE.getNegativeSCEV(M);
   1359   const SCEV *UpperLimit = nullptr;
   1360 
   1361   // We strengthen "0 <= I" to "0 <= I < INT_SMAX" and "I < L" to "0 <= I < L".
   1362   // We can potentially do much better here.
   1363   if (Value *V = getLength()) {
   1364     UpperLimit = SE.getSCEV(V);
   1365   } else {
   1366     assert(Kind == InductiveRangeCheck::RANGE_CHECK_LOWER && "invariant!");
   1367     unsigned BitWidth = cast<IntegerType>(IndVar->getType())->getBitWidth();
   1368     UpperLimit = SE.getConstant(APInt::getSignedMaxValue(BitWidth));
   1369   }
   1370 
   1371   const SCEV *End = SE.getMinusSCEV(UpperLimit, M);
   1372   return InductiveRangeCheck::Range(Begin, End);
   1373 }
   1374 
   1375 static Optional<InductiveRangeCheck::Range>
   1376 IntersectRange(ScalarEvolution &SE,
   1377                const Optional<InductiveRangeCheck::Range> &R1,
   1378                const InductiveRangeCheck::Range &R2, IRBuilder<> &B) {
   1379   if (!R1.hasValue())
   1380     return R2;
   1381   auto &R1Value = R1.getValue();
   1382 
   1383   // TODO: we could widen the smaller range and have this work; but for now we
   1384   // bail out to keep things simple.
   1385   if (R1Value.getType() != R2.getType())
   1386     return None;
   1387 
   1388   const SCEV *NewBegin = SE.getSMaxExpr(R1Value.getBegin(), R2.getBegin());
   1389   const SCEV *NewEnd = SE.getSMinExpr(R1Value.getEnd(), R2.getEnd());
   1390 
   1391   return InductiveRangeCheck::Range(NewBegin, NewEnd);
   1392 }
   1393 
   1394 bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) {
   1395   if (L->getBlocks().size() >= LoopSizeCutoff) {
   1396     DEBUG(dbgs() << "irce: giving up constraining loop, too large\n";);
   1397     return false;
   1398   }
   1399 
   1400   BasicBlock *Preheader = L->getLoopPreheader();
   1401   if (!Preheader) {
   1402     DEBUG(dbgs() << "irce: loop has no preheader, leaving\n");
   1403     return false;
   1404   }
   1405 
   1406   LLVMContext &Context = Preheader->getContext();
   1407   InductiveRangeCheck::AllocatorTy IRCAlloc;
   1408   SmallVector<InductiveRangeCheck *, 16> RangeChecks;
   1409   ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
   1410   BranchProbabilityInfo &BPI =
   1411       getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI();
   1412 
   1413   for (auto BBI : L->getBlocks())
   1414     if (BranchInst *TBI = dyn_cast<BranchInst>(BBI->getTerminator()))
   1415       if (InductiveRangeCheck *IRC =
   1416           InductiveRangeCheck::create(IRCAlloc, TBI, L, SE, BPI))
   1417         RangeChecks.push_back(IRC);
   1418 
   1419   if (RangeChecks.empty())
   1420     return false;
   1421 
   1422   auto PrintRecognizedRangeChecks = [&](raw_ostream &OS) {
   1423     OS << "irce: looking at loop "; L->print(OS);
   1424     OS << "irce: loop has " << RangeChecks.size()
   1425        << " inductive range checks: \n";
   1426     for (InductiveRangeCheck *IRC : RangeChecks)
   1427       IRC->print(OS);
   1428   };
   1429 
   1430   DEBUG(PrintRecognizedRangeChecks(dbgs()));
   1431 
   1432   if (PrintRangeChecks)
   1433     PrintRecognizedRangeChecks(errs());
   1434 
   1435   const char *FailureReason = nullptr;
   1436   Optional<LoopStructure> MaybeLoopStructure =
   1437       LoopStructure::parseLoopStructure(SE, BPI, *L, FailureReason);
   1438   if (!MaybeLoopStructure.hasValue()) {
   1439     DEBUG(dbgs() << "irce: could not parse loop structure: " << FailureReason
   1440                  << "\n";);
   1441     return false;
   1442   }
   1443   LoopStructure LS = MaybeLoopStructure.getValue();
   1444   bool Increasing = LS.IndVarIncreasing;
   1445   const SCEV *MinusOne =
   1446       SE.getConstant(LS.IndVarNext->getType(), Increasing ? -1 : 1, true);
   1447   const SCEVAddRecExpr *IndVar =
   1448       cast<SCEVAddRecExpr>(SE.getAddExpr(SE.getSCEV(LS.IndVarNext), MinusOne));
   1449 
   1450   Optional<InductiveRangeCheck::Range> SafeIterRange;
   1451   Instruction *ExprInsertPt = Preheader->getTerminator();
   1452 
   1453   SmallVector<InductiveRangeCheck *, 4> RangeChecksToEliminate;
   1454 
   1455   IRBuilder<> B(ExprInsertPt);
   1456   for (InductiveRangeCheck *IRC : RangeChecks) {
   1457     auto Result = IRC->computeSafeIterationSpace(SE, IndVar, B);
   1458     if (Result.hasValue()) {
   1459       auto MaybeSafeIterRange =
   1460         IntersectRange(SE, SafeIterRange, Result.getValue(), B);
   1461       if (MaybeSafeIterRange.hasValue()) {
   1462         RangeChecksToEliminate.push_back(IRC);
   1463         SafeIterRange = MaybeSafeIterRange.getValue();
   1464       }
   1465     }
   1466   }
   1467 
   1468   if (!SafeIterRange.hasValue())
   1469     return false;
   1470 
   1471   LoopConstrainer LC(*L, getAnalysis<LoopInfoWrapperPass>().getLoopInfo(), LS,
   1472                      SE, SafeIterRange.getValue());
   1473   bool Changed = LC.run();
   1474 
   1475   if (Changed) {
   1476     auto PrintConstrainedLoopInfo = [L]() {
   1477       dbgs() << "irce: in function ";
   1478       dbgs() << L->getHeader()->getParent()->getName() << ": ";
   1479       dbgs() << "constrained ";
   1480       L->print(dbgs());
   1481     };
   1482 
   1483     DEBUG(PrintConstrainedLoopInfo());
   1484 
   1485     if (PrintChangedLoops)
   1486       PrintConstrainedLoopInfo();
   1487 
   1488     // Optimize away the now-redundant range checks.
   1489 
   1490     for (InductiveRangeCheck *IRC : RangeChecksToEliminate) {
   1491       ConstantInt *FoldedRangeCheck = IRC->getPassingDirection()
   1492                                           ? ConstantInt::getTrue(Context)
   1493                                           : ConstantInt::getFalse(Context);
   1494       IRC->getBranch()->setCondition(FoldedRangeCheck);
   1495     }
   1496   }
   1497 
   1498   return Changed;
   1499 }
   1500 
   1501 Pass *llvm::createInductiveRangeCheckEliminationPass() {
   1502   return new InductiveRangeCheckElimination;
   1503 }
   1504