Home | History | Annotate | Download | only in Scalar
      1 //===-- StraightLineStrengthReduce.cpp - ------------------------*- C++ -*-===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 // This file implements straight-line strength reduction (SLSR). Unlike loop
     11 // strength reduction, this algorithm is designed to reduce arithmetic
     12 // redundancy in straight-line code instead of loops. It has proven to be
     13 // effective in simplifying arithmetic statements derived from an unrolled loop.
     14 // It can also simplify the logic of SeparateConstOffsetFromGEP.
     15 //
     16 // There are many optimizations we can perform in the domain of SLSR. This file
     17 // for now contains only an initial step. Specifically, we look for strength
     18 // reduction candidates in the following forms:
     19 //
     20 // Form 1: B + i * S
     21 // Form 2: (B + i) * S
     22 // Form 3: &B[i * S]
     23 //
     24 // where S is an integer variable, and i is a constant integer. If we find two
     25 // candidates S1 and S2 in the same form and S1 dominates S2, we may rewrite S2
     26 // in a simpler way with respect to S1. For example,
     27 //
     28 // S1: X = B + i * S
     29 // S2: Y = B + i' * S   => X + (i' - i) * S
     30 //
     31 // S1: X = (B + i) * S
     32 // S2: Y = (B + i') * S => X + (i' - i) * S
     33 //
     34 // S1: X = &B[i * S]
     35 // S2: Y = &B[i' * S]   => &X[(i' - i) * S]
     36 //
     37 // Note: (i' - i) * S is folded to the extent possible.
     38 //
     39 // This rewriting is in general a good idea. The code patterns we focus on
     40 // usually come from loop unrolling, so (i' - i) * S is likely the same
     41 // across iterations and can be reused. When that happens, the optimized form
     42 // takes only one add starting from the second iteration.
     43 //
     44 // When such rewriting is possible, we call S1 a "basis" of S2. When S2 has
     45 // multiple bases, we choose to rewrite S2 with respect to its "immediate"
     46 // basis, the basis that is the closest ancestor in the dominator tree.
     47 //
     48 // TODO:
     49 //
     50 // - Floating-point arithmetic when fast math is enabled.
     51 //
     52 // - SLSR may decrease ILP at the architecture level. Targets that are very
     53 //   sensitive to ILP may want to disable it. Having SLSR to consider ILP is
     54 //   left as future work.
     55 //
     56 // - When (i' - i) is constant but i and i' are not, we could still perform
     57 //   SLSR.
     58 #include <vector>
     59 
     60 #include "llvm/Analysis/ScalarEvolution.h"
     61 #include "llvm/Analysis/TargetTransformInfo.h"
     62 #include "llvm/Analysis/ValueTracking.h"
     63 #include "llvm/IR/DataLayout.h"
     64 #include "llvm/IR/Dominators.h"
     65 #include "llvm/IR/IRBuilder.h"
     66 #include "llvm/IR/Module.h"
     67 #include "llvm/IR/PatternMatch.h"
     68 #include "llvm/Support/raw_ostream.h"
     69 #include "llvm/Transforms/Scalar.h"
     70 #include "llvm/Transforms/Utils/Local.h"
     71 
     72 using namespace llvm;
     73 using namespace PatternMatch;
     74 
     75 namespace {
     76 
     77 static const unsigned UnknownAddressSpace = ~0u;
     78 
     79 class StraightLineStrengthReduce : public FunctionPass {
     80 public:
     81   // SLSR candidate. Such a candidate must be in one of the forms described in
     82   // the header comments.
     83   struct Candidate : public ilist_node<Candidate> {
     84     enum Kind {
     85       Invalid, // reserved for the default constructor
     86       Add,     // B + i * S
     87       Mul,     // (B + i) * S
     88       GEP,     // &B[..][i * S][..]
     89     };
     90 
     91     Candidate()
     92         : CandidateKind(Invalid), Base(nullptr), Index(nullptr),
     93           Stride(nullptr), Ins(nullptr), Basis(nullptr) {}
     94     Candidate(Kind CT, const SCEV *B, ConstantInt *Idx, Value *S,
     95               Instruction *I)
     96         : CandidateKind(CT), Base(B), Index(Idx), Stride(S), Ins(I),
     97           Basis(nullptr) {}
     98     Kind CandidateKind;
     99     const SCEV *Base;
    100     // Note that Index and Stride of a GEP candidate do not necessarily have the
    101     // same integer type. In that case, during rewriting, Stride will be
    102     // sign-extended or truncated to Index's type.
    103     ConstantInt *Index;
    104     Value *Stride;
    105     // The instruction this candidate corresponds to. It helps us to rewrite a
    106     // candidate with respect to its immediate basis. Note that one instruction
    107     // can correspond to multiple candidates depending on how you associate the
    108     // expression. For instance,
    109     //
    110     // (a + 1) * (b + 2)
    111     //
    112     // can be treated as
    113     //
    114     // <Base: a, Index: 1, Stride: b + 2>
    115     //
    116     // or
    117     //
    118     // <Base: b, Index: 2, Stride: a + 1>
    119     Instruction *Ins;
    120     // Points to the immediate basis of this candidate, or nullptr if we cannot
    121     // find any basis for this candidate.
    122     Candidate *Basis;
    123   };
    124 
    125   static char ID;
    126 
    127   StraightLineStrengthReduce()
    128       : FunctionPass(ID), DL(nullptr), DT(nullptr), TTI(nullptr) {
    129     initializeStraightLineStrengthReducePass(*PassRegistry::getPassRegistry());
    130   }
    131 
    132   void getAnalysisUsage(AnalysisUsage &AU) const override {
    133     AU.addRequired<DominatorTreeWrapperPass>();
    134     AU.addRequired<ScalarEvolutionWrapperPass>();
    135     AU.addRequired<TargetTransformInfoWrapperPass>();
    136     // We do not modify the shape of the CFG.
    137     AU.setPreservesCFG();
    138   }
    139 
    140   bool doInitialization(Module &M) override {
    141     DL = &M.getDataLayout();
    142     return false;
    143   }
    144 
    145   bool runOnFunction(Function &F) override;
    146 
    147 private:
    148   // Returns true if Basis is a basis for C, i.e., Basis dominates C and they
    149   // share the same base and stride.
    150   bool isBasisFor(const Candidate &Basis, const Candidate &C);
    151   // Returns whether the candidate can be folded into an addressing mode.
    152   bool isFoldable(const Candidate &C, TargetTransformInfo *TTI,
    153                   const DataLayout *DL);
    154   // Returns true if C is already in a simplest form and not worth being
    155   // rewritten.
    156   bool isSimplestForm(const Candidate &C);
    157   // Checks whether I is in a candidate form. If so, adds all the matching forms
    158   // to Candidates, and tries to find the immediate basis for each of them.
    159   void allocateCandidatesAndFindBasis(Instruction *I);
    160   // Allocate candidates and find bases for Add instructions.
    161   void allocateCandidatesAndFindBasisForAdd(Instruction *I);
    162   // Given I = LHS + RHS, factors RHS into i * S and makes (LHS + i * S) a
    163   // candidate.
    164   void allocateCandidatesAndFindBasisForAdd(Value *LHS, Value *RHS,
    165                                             Instruction *I);
    166   // Allocate candidates and find bases for Mul instructions.
    167   void allocateCandidatesAndFindBasisForMul(Instruction *I);
    168   // Splits LHS into Base + Index and, if succeeds, calls
    169   // allocateCandidatesAndFindBasis.
    170   void allocateCandidatesAndFindBasisForMul(Value *LHS, Value *RHS,
    171                                             Instruction *I);
    172   // Allocate candidates and find bases for GetElementPtr instructions.
    173   void allocateCandidatesAndFindBasisForGEP(GetElementPtrInst *GEP);
    174   // A helper function that scales Idx with ElementSize before invoking
    175   // allocateCandidatesAndFindBasis.
    176   void allocateCandidatesAndFindBasisForGEP(const SCEV *B, ConstantInt *Idx,
    177                                             Value *S, uint64_t ElementSize,
    178                                             Instruction *I);
    179   // Adds the given form <CT, B, Idx, S> to Candidates, and finds its immediate
    180   // basis.
    181   void allocateCandidatesAndFindBasis(Candidate::Kind CT, const SCEV *B,
    182                                       ConstantInt *Idx, Value *S,
    183                                       Instruction *I);
    184   // Rewrites candidate C with respect to Basis.
    185   void rewriteCandidateWithBasis(const Candidate &C, const Candidate &Basis);
    186   // A helper function that factors ArrayIdx to a product of a stride and a
    187   // constant index, and invokes allocateCandidatesAndFindBasis with the
    188   // factorings.
    189   void factorArrayIndex(Value *ArrayIdx, const SCEV *Base, uint64_t ElementSize,
    190                         GetElementPtrInst *GEP);
    191   // Emit code that computes the "bump" from Basis to C. If the candidate is a
    192   // GEP and the bump is not divisible by the element size of the GEP, this
    193   // function sets the BumpWithUglyGEP flag to notify its caller to bump the
    194   // basis using an ugly GEP.
    195   static Value *emitBump(const Candidate &Basis, const Candidate &C,
    196                          IRBuilder<> &Builder, const DataLayout *DL,
    197                          bool &BumpWithUglyGEP);
    198 
    199   const DataLayout *DL;
    200   DominatorTree *DT;
    201   ScalarEvolution *SE;
    202   TargetTransformInfo *TTI;
    203   ilist<Candidate> Candidates;
    204   // Temporarily holds all instructions that are unlinked (but not deleted) by
    205   // rewriteCandidateWithBasis. These instructions will be actually removed
    206   // after all rewriting finishes.
    207   std::vector<Instruction *> UnlinkedInstructions;
    208 };
    209 }  // anonymous namespace
    210 
    211 char StraightLineStrengthReduce::ID = 0;
    212 INITIALIZE_PASS_BEGIN(StraightLineStrengthReduce, "slsr",
    213                       "Straight line strength reduction", false, false)
    214 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
    215 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
    216 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
    217 INITIALIZE_PASS_END(StraightLineStrengthReduce, "slsr",
    218                     "Straight line strength reduction", false, false)
    219 
    220 FunctionPass *llvm::createStraightLineStrengthReducePass() {
    221   return new StraightLineStrengthReduce();
    222 }
    223 
    224 bool StraightLineStrengthReduce::isBasisFor(const Candidate &Basis,
    225                                             const Candidate &C) {
    226   return (Basis.Ins != C.Ins && // skip the same instruction
    227           // They must have the same type too. Basis.Base == C.Base doesn't
    228           // guarantee their types are the same (PR23975).
    229           Basis.Ins->getType() == C.Ins->getType() &&
    230           // Basis must dominate C in order to rewrite C with respect to Basis.
    231           DT->dominates(Basis.Ins->getParent(), C.Ins->getParent()) &&
    232           // They share the same base, stride, and candidate kind.
    233           Basis.Base == C.Base && Basis.Stride == C.Stride &&
    234           Basis.CandidateKind == C.CandidateKind);
    235 }
    236 
    237 static bool isGEPFoldable(GetElementPtrInst *GEP,
    238                           const TargetTransformInfo *TTI) {
    239   SmallVector<const Value*, 4> Indices;
    240   for (auto I = GEP->idx_begin(); I != GEP->idx_end(); ++I)
    241     Indices.push_back(*I);
    242   return TTI->getGEPCost(GEP->getSourceElementType(), GEP->getPointerOperand(),
    243                          Indices) == TargetTransformInfo::TCC_Free;
    244 }
    245 
    246 // Returns whether (Base + Index * Stride) can be folded to an addressing mode.
    247 static bool isAddFoldable(const SCEV *Base, ConstantInt *Index, Value *Stride,
    248                           TargetTransformInfo *TTI) {
    249   // Index->getSExtValue() may crash if Index is wider than 64-bit.
    250   return Index->getBitWidth() <= 64 &&
    251          TTI->isLegalAddressingMode(Base->getType(), nullptr, 0, true,
    252                                     Index->getSExtValue(), UnknownAddressSpace);
    253 }
    254 
    255 bool StraightLineStrengthReduce::isFoldable(const Candidate &C,
    256                                             TargetTransformInfo *TTI,
    257                                             const DataLayout *DL) {
    258   if (C.CandidateKind == Candidate::Add)
    259     return isAddFoldable(C.Base, C.Index, C.Stride, TTI);
    260   if (C.CandidateKind == Candidate::GEP)
    261     return isGEPFoldable(cast<GetElementPtrInst>(C.Ins), TTI);
    262   return false;
    263 }
    264 
    265 // Returns true if GEP has zero or one non-zero index.
    266 static bool hasOnlyOneNonZeroIndex(GetElementPtrInst *GEP) {
    267   unsigned NumNonZeroIndices = 0;
    268   for (auto I = GEP->idx_begin(); I != GEP->idx_end(); ++I) {
    269     ConstantInt *ConstIdx = dyn_cast<ConstantInt>(*I);
    270     if (ConstIdx == nullptr || !ConstIdx->isZero())
    271       ++NumNonZeroIndices;
    272   }
    273   return NumNonZeroIndices <= 1;
    274 }
    275 
    276 bool StraightLineStrengthReduce::isSimplestForm(const Candidate &C) {
    277   if (C.CandidateKind == Candidate::Add) {
    278     // B + 1 * S or B + (-1) * S
    279     return C.Index->isOne() || C.Index->isMinusOne();
    280   }
    281   if (C.CandidateKind == Candidate::Mul) {
    282     // (B + 0) * S
    283     return C.Index->isZero();
    284   }
    285   if (C.CandidateKind == Candidate::GEP) {
    286     // (char*)B + S or (char*)B - S
    287     return ((C.Index->isOne() || C.Index->isMinusOne()) &&
    288             hasOnlyOneNonZeroIndex(cast<GetElementPtrInst>(C.Ins)));
    289   }
    290   return false;
    291 }
    292 
    293 // TODO: We currently implement an algorithm whose time complexity is linear in
    294 // the number of existing candidates. However, we could do better by using
    295 // ScopedHashTable. Specifically, while traversing the dominator tree, we could
    296 // maintain all the candidates that dominate the basic block being traversed in
    297 // a ScopedHashTable. This hash table is indexed by the base and the stride of
    298 // a candidate. Therefore, finding the immediate basis of a candidate boils down
    299 // to one hash-table look up.
    300 void StraightLineStrengthReduce::allocateCandidatesAndFindBasis(
    301     Candidate::Kind CT, const SCEV *B, ConstantInt *Idx, Value *S,
    302     Instruction *I) {
    303   Candidate C(CT, B, Idx, S, I);
    304   // SLSR can complicate an instruction in two cases:
    305   //
    306   // 1. If we can fold I into an addressing mode, computing I is likely free or
    307   // takes only one instruction.
    308   //
    309   // 2. I is already in a simplest form. For example, when
    310   //      X = B + 8 * S
    311   //      Y = B + S,
    312   //    rewriting Y to X - 7 * S is probably a bad idea.
    313   //
    314   // In the above cases, we still add I to the candidate list so that I can be
    315   // the basis of other candidates, but we leave I's basis blank so that I
    316   // won't be rewritten.
    317   if (!isFoldable(C, TTI, DL) && !isSimplestForm(C)) {
    318     // Try to compute the immediate basis of C.
    319     unsigned NumIterations = 0;
    320     // Limit the scan radius to avoid running in quadratice time.
    321     static const unsigned MaxNumIterations = 50;
    322     for (auto Basis = Candidates.rbegin();
    323          Basis != Candidates.rend() && NumIterations < MaxNumIterations;
    324          ++Basis, ++NumIterations) {
    325       if (isBasisFor(*Basis, C)) {
    326         C.Basis = &(*Basis);
    327         break;
    328       }
    329     }
    330   }
    331   // Regardless of whether we find a basis for C, we need to push C to the
    332   // candidate list so that it can be the basis of other candidates.
    333   Candidates.push_back(C);
    334 }
    335 
    336 void StraightLineStrengthReduce::allocateCandidatesAndFindBasis(
    337     Instruction *I) {
    338   switch (I->getOpcode()) {
    339   case Instruction::Add:
    340     allocateCandidatesAndFindBasisForAdd(I);
    341     break;
    342   case Instruction::Mul:
    343     allocateCandidatesAndFindBasisForMul(I);
    344     break;
    345   case Instruction::GetElementPtr:
    346     allocateCandidatesAndFindBasisForGEP(cast<GetElementPtrInst>(I));
    347     break;
    348   }
    349 }
    350 
    351 void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForAdd(
    352     Instruction *I) {
    353   // Try matching B + i * S.
    354   if (!isa<IntegerType>(I->getType()))
    355     return;
    356 
    357   assert(I->getNumOperands() == 2 && "isn't I an add?");
    358   Value *LHS = I->getOperand(0), *RHS = I->getOperand(1);
    359   allocateCandidatesAndFindBasisForAdd(LHS, RHS, I);
    360   if (LHS != RHS)
    361     allocateCandidatesAndFindBasisForAdd(RHS, LHS, I);
    362 }
    363 
    364 void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForAdd(
    365     Value *LHS, Value *RHS, Instruction *I) {
    366   Value *S = nullptr;
    367   ConstantInt *Idx = nullptr;
    368   if (match(RHS, m_Mul(m_Value(S), m_ConstantInt(Idx)))) {
    369     // I = LHS + RHS = LHS + Idx * S
    370     allocateCandidatesAndFindBasis(Candidate::Add, SE->getSCEV(LHS), Idx, S, I);
    371   } else if (match(RHS, m_Shl(m_Value(S), m_ConstantInt(Idx)))) {
    372     // I = LHS + RHS = LHS + (S << Idx) = LHS + S * (1 << Idx)
    373     APInt One(Idx->getBitWidth(), 1);
    374     Idx = ConstantInt::get(Idx->getContext(), One << Idx->getValue());
    375     allocateCandidatesAndFindBasis(Candidate::Add, SE->getSCEV(LHS), Idx, S, I);
    376   } else {
    377     // At least, I = LHS + 1 * RHS
    378     ConstantInt *One = ConstantInt::get(cast<IntegerType>(I->getType()), 1);
    379     allocateCandidatesAndFindBasis(Candidate::Add, SE->getSCEV(LHS), One, RHS,
    380                                    I);
    381   }
    382 }
    383 
    384 // Returns true if A matches B + C where C is constant.
    385 static bool matchesAdd(Value *A, Value *&B, ConstantInt *&C) {
    386   return (match(A, m_Add(m_Value(B), m_ConstantInt(C))) ||
    387           match(A, m_Add(m_ConstantInt(C), m_Value(B))));
    388 }
    389 
    390 // Returns true if A matches B | C where C is constant.
    391 static bool matchesOr(Value *A, Value *&B, ConstantInt *&C) {
    392   return (match(A, m_Or(m_Value(B), m_ConstantInt(C))) ||
    393           match(A, m_Or(m_ConstantInt(C), m_Value(B))));
    394 }
    395 
    396 void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForMul(
    397     Value *LHS, Value *RHS, Instruction *I) {
    398   Value *B = nullptr;
    399   ConstantInt *Idx = nullptr;
    400   if (matchesAdd(LHS, B, Idx)) {
    401     // If LHS is in the form of "Base + Index", then I is in the form of
    402     // "(Base + Index) * RHS".
    403     allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(B), Idx, RHS, I);
    404   } else if (matchesOr(LHS, B, Idx) && haveNoCommonBitsSet(B, Idx, *DL)) {
    405     // If LHS is in the form of "Base | Index" and Base and Index have no common
    406     // bits set, then
    407     //   Base | Index = Base + Index
    408     // and I is thus in the form of "(Base + Index) * RHS".
    409     allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(B), Idx, RHS, I);
    410   } else {
    411     // Otherwise, at least try the form (LHS + 0) * RHS.
    412     ConstantInt *Zero = ConstantInt::get(cast<IntegerType>(I->getType()), 0);
    413     allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(LHS), Zero, RHS,
    414                                    I);
    415   }
    416 }
    417 
    418 void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForMul(
    419     Instruction *I) {
    420   // Try matching (B + i) * S.
    421   // TODO: we could extend SLSR to float and vector types.
    422   if (!isa<IntegerType>(I->getType()))
    423     return;
    424 
    425   assert(I->getNumOperands() == 2 && "isn't I a mul?");
    426   Value *LHS = I->getOperand(0), *RHS = I->getOperand(1);
    427   allocateCandidatesAndFindBasisForMul(LHS, RHS, I);
    428   if (LHS != RHS) {
    429     // Symmetrically, try to split RHS to Base + Index.
    430     allocateCandidatesAndFindBasisForMul(RHS, LHS, I);
    431   }
    432 }
    433 
    434 void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForGEP(
    435     const SCEV *B, ConstantInt *Idx, Value *S, uint64_t ElementSize,
    436     Instruction *I) {
    437   // I = B + sext(Idx *nsw S) * ElementSize
    438   //   = B + (sext(Idx) * sext(S)) * ElementSize
    439   //   = B + (sext(Idx) * ElementSize) * sext(S)
    440   // Casting to IntegerType is safe because we skipped vector GEPs.
    441   IntegerType *IntPtrTy = cast<IntegerType>(DL->getIntPtrType(I->getType()));
    442   ConstantInt *ScaledIdx = ConstantInt::get(
    443       IntPtrTy, Idx->getSExtValue() * (int64_t)ElementSize, true);
    444   allocateCandidatesAndFindBasis(Candidate::GEP, B, ScaledIdx, S, I);
    445 }
    446 
    447 void StraightLineStrengthReduce::factorArrayIndex(Value *ArrayIdx,
    448                                                   const SCEV *Base,
    449                                                   uint64_t ElementSize,
    450                                                   GetElementPtrInst *GEP) {
    451   // At least, ArrayIdx = ArrayIdx *nsw 1.
    452   allocateCandidatesAndFindBasisForGEP(
    453       Base, ConstantInt::get(cast<IntegerType>(ArrayIdx->getType()), 1),
    454       ArrayIdx, ElementSize, GEP);
    455   Value *LHS = nullptr;
    456   ConstantInt *RHS = nullptr;
    457   // One alternative is matching the SCEV of ArrayIdx instead of ArrayIdx
    458   // itself. This would allow us to handle the shl case for free. However,
    459   // matching SCEVs has two issues:
    460   //
    461   // 1. this would complicate rewriting because the rewriting procedure
    462   // would have to translate SCEVs back to IR instructions. This translation
    463   // is difficult when LHS is further evaluated to a composite SCEV.
    464   //
    465   // 2. ScalarEvolution is designed to be control-flow oblivious. It tends
    466   // to strip nsw/nuw flags which are critical for SLSR to trace into
    467   // sext'ed multiplication.
    468   if (match(ArrayIdx, m_NSWMul(m_Value(LHS), m_ConstantInt(RHS)))) {
    469     // SLSR is currently unsafe if i * S may overflow.
    470     // GEP = Base + sext(LHS *nsw RHS) * ElementSize
    471     allocateCandidatesAndFindBasisForGEP(Base, RHS, LHS, ElementSize, GEP);
    472   } else if (match(ArrayIdx, m_NSWShl(m_Value(LHS), m_ConstantInt(RHS)))) {
    473     // GEP = Base + sext(LHS <<nsw RHS) * ElementSize
    474     //     = Base + sext(LHS *nsw (1 << RHS)) * ElementSize
    475     APInt One(RHS->getBitWidth(), 1);
    476     ConstantInt *PowerOf2 =
    477         ConstantInt::get(RHS->getContext(), One << RHS->getValue());
    478     allocateCandidatesAndFindBasisForGEP(Base, PowerOf2, LHS, ElementSize, GEP);
    479   }
    480 }
    481 
    482 void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForGEP(
    483     GetElementPtrInst *GEP) {
    484   // TODO: handle vector GEPs
    485   if (GEP->getType()->isVectorTy())
    486     return;
    487 
    488   SmallVector<const SCEV *, 4> IndexExprs;
    489   for (auto I = GEP->idx_begin(); I != GEP->idx_end(); ++I)
    490     IndexExprs.push_back(SE->getSCEV(*I));
    491 
    492   gep_type_iterator GTI = gep_type_begin(GEP);
    493   for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I) {
    494     if (!isa<SequentialType>(*GTI++))
    495       continue;
    496 
    497     const SCEV *OrigIndexExpr = IndexExprs[I - 1];
    498     IndexExprs[I - 1] = SE->getZero(OrigIndexExpr->getType());
    499 
    500     // The base of this candidate is GEP's base plus the offsets of all
    501     // indices except this current one.
    502     const SCEV *BaseExpr = SE->getGEPExpr(GEP->getSourceElementType(),
    503                                           SE->getSCEV(GEP->getPointerOperand()),
    504                                           IndexExprs, GEP->isInBounds());
    505     Value *ArrayIdx = GEP->getOperand(I);
    506     uint64_t ElementSize = DL->getTypeAllocSize(*GTI);
    507     if (ArrayIdx->getType()->getIntegerBitWidth() <=
    508         DL->getPointerSizeInBits(GEP->getAddressSpace())) {
    509       // Skip factoring if ArrayIdx is wider than the pointer size, because
    510       // ArrayIdx is implicitly truncated to the pointer size.
    511       factorArrayIndex(ArrayIdx, BaseExpr, ElementSize, GEP);
    512     }
    513     // When ArrayIdx is the sext of a value, we try to factor that value as
    514     // well.  Handling this case is important because array indices are
    515     // typically sign-extended to the pointer size.
    516     Value *TruncatedArrayIdx = nullptr;
    517     if (match(ArrayIdx, m_SExt(m_Value(TruncatedArrayIdx))) &&
    518         TruncatedArrayIdx->getType()->getIntegerBitWidth() <=
    519             DL->getPointerSizeInBits(GEP->getAddressSpace())) {
    520       // Skip factoring if TruncatedArrayIdx is wider than the pointer size,
    521       // because TruncatedArrayIdx is implicitly truncated to the pointer size.
    522       factorArrayIndex(TruncatedArrayIdx, BaseExpr, ElementSize, GEP);
    523     }
    524 
    525     IndexExprs[I - 1] = OrigIndexExpr;
    526   }
    527 }
    528 
    529 // A helper function that unifies the bitwidth of A and B.
    530 static void unifyBitWidth(APInt &A, APInt &B) {
    531   if (A.getBitWidth() < B.getBitWidth())
    532     A = A.sext(B.getBitWidth());
    533   else if (A.getBitWidth() > B.getBitWidth())
    534     B = B.sext(A.getBitWidth());
    535 }
    536 
    537 Value *StraightLineStrengthReduce::emitBump(const Candidate &Basis,
    538                                             const Candidate &C,
    539                                             IRBuilder<> &Builder,
    540                                             const DataLayout *DL,
    541                                             bool &BumpWithUglyGEP) {
    542   APInt Idx = C.Index->getValue(), BasisIdx = Basis.Index->getValue();
    543   unifyBitWidth(Idx, BasisIdx);
    544   APInt IndexOffset = Idx - BasisIdx;
    545 
    546   BumpWithUglyGEP = false;
    547   if (Basis.CandidateKind == Candidate::GEP) {
    548     APInt ElementSize(
    549         IndexOffset.getBitWidth(),
    550         DL->getTypeAllocSize(
    551             cast<GetElementPtrInst>(Basis.Ins)->getResultElementType()));
    552     APInt Q, R;
    553     APInt::sdivrem(IndexOffset, ElementSize, Q, R);
    554     if (R == 0)
    555       IndexOffset = Q;
    556     else
    557       BumpWithUglyGEP = true;
    558   }
    559 
    560   // Compute Bump = C - Basis = (i' - i) * S.
    561   // Common case 1: if (i' - i) is 1, Bump = S.
    562   if (IndexOffset == 1)
    563     return C.Stride;
    564   // Common case 2: if (i' - i) is -1, Bump = -S.
    565   if (IndexOffset.isAllOnesValue())
    566     return Builder.CreateNeg(C.Stride);
    567 
    568   // Otherwise, Bump = (i' - i) * sext/trunc(S). Note that (i' - i) and S may
    569   // have different bit widths.
    570   IntegerType *DeltaType =
    571       IntegerType::get(Basis.Ins->getContext(), IndexOffset.getBitWidth());
    572   Value *ExtendedStride = Builder.CreateSExtOrTrunc(C.Stride, DeltaType);
    573   if (IndexOffset.isPowerOf2()) {
    574     // If (i' - i) is a power of 2, Bump = sext/trunc(S) << log(i' - i).
    575     ConstantInt *Exponent = ConstantInt::get(DeltaType, IndexOffset.logBase2());
    576     return Builder.CreateShl(ExtendedStride, Exponent);
    577   }
    578   if ((-IndexOffset).isPowerOf2()) {
    579     // If (i - i') is a power of 2, Bump = -sext/trunc(S) << log(i' - i).
    580     ConstantInt *Exponent =
    581         ConstantInt::get(DeltaType, (-IndexOffset).logBase2());
    582     return Builder.CreateNeg(Builder.CreateShl(ExtendedStride, Exponent));
    583   }
    584   Constant *Delta = ConstantInt::get(DeltaType, IndexOffset);
    585   return Builder.CreateMul(ExtendedStride, Delta);
    586 }
    587 
    588 void StraightLineStrengthReduce::rewriteCandidateWithBasis(
    589     const Candidate &C, const Candidate &Basis) {
    590   assert(C.CandidateKind == Basis.CandidateKind && C.Base == Basis.Base &&
    591          C.Stride == Basis.Stride);
    592   // We run rewriteCandidateWithBasis on all candidates in a post-order, so the
    593   // basis of a candidate cannot be unlinked before the candidate.
    594   assert(Basis.Ins->getParent() != nullptr && "the basis is unlinked");
    595 
    596   // An instruction can correspond to multiple candidates. Therefore, instead of
    597   // simply deleting an instruction when we rewrite it, we mark its parent as
    598   // nullptr (i.e. unlink it) so that we can skip the candidates whose
    599   // instruction is already rewritten.
    600   if (!C.Ins->getParent())
    601     return;
    602 
    603   IRBuilder<> Builder(C.Ins);
    604   bool BumpWithUglyGEP;
    605   Value *Bump = emitBump(Basis, C, Builder, DL, BumpWithUglyGEP);
    606   Value *Reduced = nullptr; // equivalent to but weaker than C.Ins
    607   switch (C.CandidateKind) {
    608   case Candidate::Add:
    609   case Candidate::Mul:
    610     // C = Basis + Bump
    611     if (BinaryOperator::isNeg(Bump)) {
    612       // If Bump is a neg instruction, emit C = Basis - (-Bump).
    613       Reduced =
    614           Builder.CreateSub(Basis.Ins, BinaryOperator::getNegArgument(Bump));
    615       // We only use the negative argument of Bump, and Bump itself may be
    616       // trivially dead.
    617       RecursivelyDeleteTriviallyDeadInstructions(Bump);
    618     } else {
    619       // It's tempting to preserve nsw on Bump and/or Reduced. However, it's
    620       // usually unsound, e.g.,
    621       //
    622       // X = (-2 +nsw 1) *nsw INT_MAX
    623       // Y = (-2 +nsw 3) *nsw INT_MAX
    624       //   =>
    625       // Y = X + 2 * INT_MAX
    626       //
    627       // Neither + and * in the resultant expression are nsw.
    628       Reduced = Builder.CreateAdd(Basis.Ins, Bump);
    629     }
    630     break;
    631   case Candidate::GEP:
    632     {
    633       Type *IntPtrTy = DL->getIntPtrType(C.Ins->getType());
    634       bool InBounds = cast<GetElementPtrInst>(C.Ins)->isInBounds();
    635       if (BumpWithUglyGEP) {
    636         // C = (char *)Basis + Bump
    637         unsigned AS = Basis.Ins->getType()->getPointerAddressSpace();
    638         Type *CharTy = Type::getInt8PtrTy(Basis.Ins->getContext(), AS);
    639         Reduced = Builder.CreateBitCast(Basis.Ins, CharTy);
    640         if (InBounds)
    641           Reduced =
    642               Builder.CreateInBoundsGEP(Builder.getInt8Ty(), Reduced, Bump);
    643         else
    644           Reduced = Builder.CreateGEP(Builder.getInt8Ty(), Reduced, Bump);
    645         Reduced = Builder.CreateBitCast(Reduced, C.Ins->getType());
    646       } else {
    647         // C = gep Basis, Bump
    648         // Canonicalize bump to pointer size.
    649         Bump = Builder.CreateSExtOrTrunc(Bump, IntPtrTy);
    650         if (InBounds)
    651           Reduced = Builder.CreateInBoundsGEP(nullptr, Basis.Ins, Bump);
    652         else
    653           Reduced = Builder.CreateGEP(nullptr, Basis.Ins, Bump);
    654       }
    655     }
    656     break;
    657   default:
    658     llvm_unreachable("C.CandidateKind is invalid");
    659   };
    660   Reduced->takeName(C.Ins);
    661   C.Ins->replaceAllUsesWith(Reduced);
    662   // Unlink C.Ins so that we can skip other candidates also corresponding to
    663   // C.Ins. The actual deletion is postponed to the end of runOnFunction.
    664   C.Ins->removeFromParent();
    665   UnlinkedInstructions.push_back(C.Ins);
    666 }
    667 
    668 bool StraightLineStrengthReduce::runOnFunction(Function &F) {
    669   if (skipFunction(F))
    670     return false;
    671 
    672   TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    673   DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    674   SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
    675   // Traverse the dominator tree in the depth-first order. This order makes sure
    676   // all bases of a candidate are in Candidates when we process it.
    677   for (auto node = GraphTraits<DominatorTree *>::nodes_begin(DT);
    678        node != GraphTraits<DominatorTree *>::nodes_end(DT); ++node) {
    679     for (auto &I : *node->getBlock())
    680       allocateCandidatesAndFindBasis(&I);
    681   }
    682 
    683   // Rewrite candidates in the reverse depth-first order. This order makes sure
    684   // a candidate being rewritten is not a basis for any other candidate.
    685   while (!Candidates.empty()) {
    686     const Candidate &C = Candidates.back();
    687     if (C.Basis != nullptr) {
    688       rewriteCandidateWithBasis(C, *C.Basis);
    689     }
    690     Candidates.pop_back();
    691   }
    692 
    693   // Delete all unlink instructions.
    694   for (auto *UnlinkedInst : UnlinkedInstructions) {
    695     for (unsigned I = 0, E = UnlinkedInst->getNumOperands(); I != E; ++I) {
    696       Value *Op = UnlinkedInst->getOperand(I);
    697       UnlinkedInst->setOperand(I, nullptr);
    698       RecursivelyDeleteTriviallyDeadInstructions(Op);
    699     }
    700     delete UnlinkedInst;
    701   }
    702   bool Ret = !UnlinkedInstructions.empty();
    703   UnlinkedInstructions.clear();
    704   return Ret;
    705 }
    706