Home | History | Annotate | Download | only in Scalar
      1 //===----------------------- AlignmentFromAssumptions.cpp -----------------===//
      2 //                  Set Load/Store Alignments From Assumptions
      3 //
      4 //                     The LLVM Compiler Infrastructure
      5 //
      6 // This file is distributed under the University of Illinois Open Source
      7 // License. See LICENSE.TXT for details.
      8 //
      9 //===----------------------------------------------------------------------===//
     10 //
     11 // This file implements a ScalarEvolution-based transformation to set
     12 // the alignments of load, stores and memory intrinsics based on the truth
     13 // expressions of assume intrinsics. The primary motivation is to handle
     14 // complex alignment assumptions that apply to vector loads and stores that
     15 // appear after vectorization and unrolling.
     16 //
     17 //===----------------------------------------------------------------------===//
     18 
     19 #define AA_NAME "alignment-from-assumptions"
     20 #define DEBUG_TYPE AA_NAME
     21 #include "llvm/Transforms/Scalar/AlignmentFromAssumptions.h"
     22 #include "llvm/Transforms/Scalar.h"
     23 #include "llvm/ADT/SmallPtrSet.h"
     24 #include "llvm/ADT/Statistic.h"
     25 #include "llvm/Analysis/AliasAnalysis.h"
     26 #include "llvm/Analysis/GlobalsModRef.h"
     27 #include "llvm/Analysis/AssumptionCache.h"
     28 #include "llvm/Analysis/LoopInfo.h"
     29 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
     30 #include "llvm/Analysis/ValueTracking.h"
     31 #include "llvm/IR/Constant.h"
     32 #include "llvm/IR/Dominators.h"
     33 #include "llvm/IR/Instruction.h"
     34 #include "llvm/IR/Intrinsics.h"
     35 #include "llvm/IR/Module.h"
     36 #include "llvm/Support/Debug.h"
     37 #include "llvm/Support/raw_ostream.h"
     38 using namespace llvm;
     39 
     40 STATISTIC(NumLoadAlignChanged,
     41   "Number of loads changed by alignment assumptions");
     42 STATISTIC(NumStoreAlignChanged,
     43   "Number of stores changed by alignment assumptions");
     44 STATISTIC(NumMemIntAlignChanged,
     45   "Number of memory intrinsics changed by alignment assumptions");
     46 
     47 namespace {
     48 struct AlignmentFromAssumptions : public FunctionPass {
     49   static char ID; // Pass identification, replacement for typeid
     50   AlignmentFromAssumptions() : FunctionPass(ID) {
     51     initializeAlignmentFromAssumptionsPass(*PassRegistry::getPassRegistry());
     52   }
     53 
     54   bool runOnFunction(Function &F) override;
     55 
     56   void getAnalysisUsage(AnalysisUsage &AU) const override {
     57     AU.addRequired<AssumptionCacheTracker>();
     58     AU.addRequired<ScalarEvolutionWrapperPass>();
     59     AU.addRequired<DominatorTreeWrapperPass>();
     60 
     61     AU.setPreservesCFG();
     62     AU.addPreserved<AAResultsWrapperPass>();
     63     AU.addPreserved<GlobalsAAWrapperPass>();
     64     AU.addPreserved<LoopInfoWrapperPass>();
     65     AU.addPreserved<DominatorTreeWrapperPass>();
     66     AU.addPreserved<ScalarEvolutionWrapperPass>();
     67   }
     68 
     69   AlignmentFromAssumptionsPass Impl;
     70 };
     71 }
     72 
     73 char AlignmentFromAssumptions::ID = 0;
     74 static const char aip_name[] = "Alignment from assumptions";
     75 INITIALIZE_PASS_BEGIN(AlignmentFromAssumptions, AA_NAME,
     76                       aip_name, false, false)
     77 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
     78 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
     79 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
     80 INITIALIZE_PASS_END(AlignmentFromAssumptions, AA_NAME,
     81                     aip_name, false, false)
     82 
     83 FunctionPass *llvm::createAlignmentFromAssumptionsPass() {
     84   return new AlignmentFromAssumptions();
     85 }
     86 
     87 // Given an expression for the (constant) alignment, AlignSCEV, and an
     88 // expression for the displacement between a pointer and the aligned address,
     89 // DiffSCEV, compute the alignment of the displaced pointer if it can be reduced
     90 // to a constant. Using SCEV to compute alignment handles the case where
     91 // DiffSCEV is a recurrence with constant start such that the aligned offset
     92 // is constant. e.g. {16,+,32} % 32 -> 16.
     93 static unsigned getNewAlignmentDiff(const SCEV *DiffSCEV,
     94                                     const SCEV *AlignSCEV,
     95                                     ScalarEvolution *SE) {
     96   // DiffUnits = Diff % int64_t(Alignment)
     97   const SCEV *DiffAlignDiv = SE->getUDivExpr(DiffSCEV, AlignSCEV);
     98   const SCEV *DiffAlign = SE->getMulExpr(DiffAlignDiv, AlignSCEV);
     99   const SCEV *DiffUnitsSCEV = SE->getMinusSCEV(DiffAlign, DiffSCEV);
    100 
    101   DEBUG(dbgs() << "\talignment relative to " << *AlignSCEV << " is " <<
    102                   *DiffUnitsSCEV << " (diff: " << *DiffSCEV << ")\n");
    103 
    104   if (const SCEVConstant *ConstDUSCEV =
    105       dyn_cast<SCEVConstant>(DiffUnitsSCEV)) {
    106     int64_t DiffUnits = ConstDUSCEV->getValue()->getSExtValue();
    107 
    108     // If the displacement is an exact multiple of the alignment, then the
    109     // displaced pointer has the same alignment as the aligned pointer, so
    110     // return the alignment value.
    111     if (!DiffUnits)
    112       return (unsigned)
    113         cast<SCEVConstant>(AlignSCEV)->getValue()->getSExtValue();
    114 
    115     // If the displacement is not an exact multiple, but the remainder is a
    116     // constant, then return this remainder (but only if it is a power of 2).
    117     uint64_t DiffUnitsAbs = std::abs(DiffUnits);
    118     if (isPowerOf2_64(DiffUnitsAbs))
    119       return (unsigned) DiffUnitsAbs;
    120   }
    121 
    122   return 0;
    123 }
    124 
    125 // There is an address given by an offset OffSCEV from AASCEV which has an
    126 // alignment AlignSCEV. Use that information, if possible, to compute a new
    127 // alignment for Ptr.
    128 static unsigned getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV,
    129                                 const SCEV *OffSCEV, Value *Ptr,
    130                                 ScalarEvolution *SE) {
    131   const SCEV *PtrSCEV = SE->getSCEV(Ptr);
    132   const SCEV *DiffSCEV = SE->getMinusSCEV(PtrSCEV, AASCEV);
    133 
    134   // On 32-bit platforms, DiffSCEV might now have type i32 -- we've always
    135   // sign-extended OffSCEV to i64, so make sure they agree again.
    136   DiffSCEV = SE->getNoopOrSignExtend(DiffSCEV, OffSCEV->getType());
    137 
    138   // What we really want to know is the overall offset to the aligned
    139   // address. This address is displaced by the provided offset.
    140   DiffSCEV = SE->getMinusSCEV(DiffSCEV, OffSCEV);
    141 
    142   DEBUG(dbgs() << "AFI: alignment of " << *Ptr << " relative to " <<
    143                   *AlignSCEV << " and offset " << *OffSCEV <<
    144                   " using diff " << *DiffSCEV << "\n");
    145 
    146   unsigned NewAlignment = getNewAlignmentDiff(DiffSCEV, AlignSCEV, SE);
    147   DEBUG(dbgs() << "\tnew alignment: " << NewAlignment << "\n");
    148 
    149   if (NewAlignment) {
    150     return NewAlignment;
    151   } else if (const SCEVAddRecExpr *DiffARSCEV =
    152              dyn_cast<SCEVAddRecExpr>(DiffSCEV)) {
    153     // The relative offset to the alignment assumption did not yield a constant,
    154     // but we should try harder: if we assume that a is 32-byte aligned, then in
    155     // for (i = 0; i < 1024; i += 4) r += a[i]; not all of the loads from a are
    156     // 32-byte aligned, but instead alternate between 32 and 16-byte alignment.
    157     // As a result, the new alignment will not be a constant, but can still
    158     // be improved over the default (of 4) to 16.
    159 
    160     const SCEV *DiffStartSCEV = DiffARSCEV->getStart();
    161     const SCEV *DiffIncSCEV = DiffARSCEV->getStepRecurrence(*SE);
    162 
    163     DEBUG(dbgs() << "\ttrying start/inc alignment using start " <<
    164                     *DiffStartSCEV << " and inc " << *DiffIncSCEV << "\n");
    165 
    166     // Now compute the new alignment using the displacement to the value in the
    167     // first iteration, and also the alignment using the per-iteration delta.
    168     // If these are the same, then use that answer. Otherwise, use the smaller
    169     // one, but only if it divides the larger one.
    170     NewAlignment = getNewAlignmentDiff(DiffStartSCEV, AlignSCEV, SE);
    171     unsigned NewIncAlignment = getNewAlignmentDiff(DiffIncSCEV, AlignSCEV, SE);
    172 
    173     DEBUG(dbgs() << "\tnew start alignment: " << NewAlignment << "\n");
    174     DEBUG(dbgs() << "\tnew inc alignment: " << NewIncAlignment << "\n");
    175 
    176     if (!NewAlignment || !NewIncAlignment) {
    177       return 0;
    178     } else if (NewAlignment > NewIncAlignment) {
    179       if (NewAlignment % NewIncAlignment == 0) {
    180         DEBUG(dbgs() << "\tnew start/inc alignment: " <<
    181                         NewIncAlignment << "\n");
    182         return NewIncAlignment;
    183       }
    184     } else if (NewIncAlignment > NewAlignment) {
    185       if (NewIncAlignment % NewAlignment == 0) {
    186         DEBUG(dbgs() << "\tnew start/inc alignment: " <<
    187                         NewAlignment << "\n");
    188         return NewAlignment;
    189       }
    190     } else if (NewIncAlignment == NewAlignment) {
    191       DEBUG(dbgs() << "\tnew start/inc alignment: " <<
    192                       NewAlignment << "\n");
    193       return NewAlignment;
    194     }
    195   }
    196 
    197   return 0;
    198 }
    199 
    200 bool AlignmentFromAssumptionsPass::extractAlignmentInfo(CallInst *I,
    201                                                         Value *&AAPtr,
    202                                                         const SCEV *&AlignSCEV,
    203                                                         const SCEV *&OffSCEV) {
    204   // An alignment assume must be a statement about the least-significant
    205   // bits of the pointer being zero, possibly with some offset.
    206   ICmpInst *ICI = dyn_cast<ICmpInst>(I->getArgOperand(0));
    207   if (!ICI)
    208     return false;
    209 
    210   // This must be an expression of the form: x & m == 0.
    211   if (ICI->getPredicate() != ICmpInst::ICMP_EQ)
    212     return false;
    213 
    214   // Swap things around so that the RHS is 0.
    215   Value *CmpLHS = ICI->getOperand(0);
    216   Value *CmpRHS = ICI->getOperand(1);
    217   const SCEV *CmpLHSSCEV = SE->getSCEV(CmpLHS);
    218   const SCEV *CmpRHSSCEV = SE->getSCEV(CmpRHS);
    219   if (CmpLHSSCEV->isZero())
    220     std::swap(CmpLHS, CmpRHS);
    221   else if (!CmpRHSSCEV->isZero())
    222     return false;
    223 
    224   BinaryOperator *CmpBO = dyn_cast<BinaryOperator>(CmpLHS);
    225   if (!CmpBO || CmpBO->getOpcode() != Instruction::And)
    226     return false;
    227 
    228   // Swap things around so that the right operand of the and is a constant
    229   // (the mask); we cannot deal with variable masks.
    230   Value *AndLHS = CmpBO->getOperand(0);
    231   Value *AndRHS = CmpBO->getOperand(1);
    232   const SCEV *AndLHSSCEV = SE->getSCEV(AndLHS);
    233   const SCEV *AndRHSSCEV = SE->getSCEV(AndRHS);
    234   if (isa<SCEVConstant>(AndLHSSCEV)) {
    235     std::swap(AndLHS, AndRHS);
    236     std::swap(AndLHSSCEV, AndRHSSCEV);
    237   }
    238 
    239   const SCEVConstant *MaskSCEV = dyn_cast<SCEVConstant>(AndRHSSCEV);
    240   if (!MaskSCEV)
    241     return false;
    242 
    243   // The mask must have some trailing ones (otherwise the condition is
    244   // trivial and tells us nothing about the alignment of the left operand).
    245   unsigned TrailingOnes = MaskSCEV->getAPInt().countTrailingOnes();
    246   if (!TrailingOnes)
    247     return false;
    248 
    249   // Cap the alignment at the maximum with which LLVM can deal (and make sure
    250   // we don't overflow the shift).
    251   uint64_t Alignment;
    252   TrailingOnes = std::min(TrailingOnes,
    253     unsigned(sizeof(unsigned) * CHAR_BIT - 1));
    254   Alignment = std::min(1u << TrailingOnes, +Value::MaximumAlignment);
    255 
    256   Type *Int64Ty = Type::getInt64Ty(I->getParent()->getParent()->getContext());
    257   AlignSCEV = SE->getConstant(Int64Ty, Alignment);
    258 
    259   // The LHS might be a ptrtoint instruction, or it might be the pointer
    260   // with an offset.
    261   AAPtr = nullptr;
    262   OffSCEV = nullptr;
    263   if (PtrToIntInst *PToI = dyn_cast<PtrToIntInst>(AndLHS)) {
    264     AAPtr = PToI->getPointerOperand();
    265     OffSCEV = SE->getZero(Int64Ty);
    266   } else if (const SCEVAddExpr* AndLHSAddSCEV =
    267              dyn_cast<SCEVAddExpr>(AndLHSSCEV)) {
    268     // Try to find the ptrtoint; subtract it and the rest is the offset.
    269     for (SCEVAddExpr::op_iterator J = AndLHSAddSCEV->op_begin(),
    270          JE = AndLHSAddSCEV->op_end(); J != JE; ++J)
    271       if (const SCEVUnknown *OpUnk = dyn_cast<SCEVUnknown>(*J))
    272         if (PtrToIntInst *PToI = dyn_cast<PtrToIntInst>(OpUnk->getValue())) {
    273           AAPtr = PToI->getPointerOperand();
    274           OffSCEV = SE->getMinusSCEV(AndLHSAddSCEV, *J);
    275           break;
    276         }
    277   }
    278 
    279   if (!AAPtr)
    280     return false;
    281 
    282   // Sign extend the offset to 64 bits (so that it is like all of the other
    283   // expressions).
    284   unsigned OffSCEVBits = OffSCEV->getType()->getPrimitiveSizeInBits();
    285   if (OffSCEVBits < 64)
    286     OffSCEV = SE->getSignExtendExpr(OffSCEV, Int64Ty);
    287   else if (OffSCEVBits > 64)
    288     return false;
    289 
    290   AAPtr = AAPtr->stripPointerCasts();
    291   return true;
    292 }
    293 
    294 bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall) {
    295   Value *AAPtr;
    296   const SCEV *AlignSCEV, *OffSCEV;
    297   if (!extractAlignmentInfo(ACall, AAPtr, AlignSCEV, OffSCEV))
    298     return false;
    299 
    300   const SCEV *AASCEV = SE->getSCEV(AAPtr);
    301 
    302   // Apply the assumption to all other users of the specified pointer.
    303   SmallPtrSet<Instruction *, 32> Visited;
    304   SmallVector<Instruction*, 16> WorkList;
    305   for (User *J : AAPtr->users()) {
    306     if (J == ACall)
    307       continue;
    308 
    309     if (Instruction *K = dyn_cast<Instruction>(J))
    310       if (isValidAssumeForContext(ACall, K, DT))
    311         WorkList.push_back(K);
    312   }
    313 
    314   while (!WorkList.empty()) {
    315     Instruction *J = WorkList.pop_back_val();
    316 
    317     if (LoadInst *LI = dyn_cast<LoadInst>(J)) {
    318       unsigned NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
    319         LI->getPointerOperand(), SE);
    320 
    321       if (NewAlignment > LI->getAlignment()) {
    322         LI->setAlignment(NewAlignment);
    323         ++NumLoadAlignChanged;
    324       }
    325     } else if (StoreInst *SI = dyn_cast<StoreInst>(J)) {
    326       unsigned NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
    327         SI->getPointerOperand(), SE);
    328 
    329       if (NewAlignment > SI->getAlignment()) {
    330         SI->setAlignment(NewAlignment);
    331         ++NumStoreAlignChanged;
    332       }
    333     } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(J)) {
    334       unsigned NewDestAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
    335         MI->getDest(), SE);
    336 
    337       // For memory transfers, we need a common alignment for both the
    338       // source and destination. If we have a new alignment for this
    339       // instruction, but only for one operand, save it. If we reach the
    340       // other operand through another assumption later, then we may
    341       // change the alignment at that point.
    342       if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) {
    343         unsigned NewSrcAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
    344           MTI->getSource(), SE);
    345 
    346         DenseMap<MemTransferInst *, unsigned>::iterator DI =
    347           NewDestAlignments.find(MTI);
    348         unsigned AltDestAlignment = (DI == NewDestAlignments.end()) ?
    349                                     0 : DI->second;
    350 
    351         DenseMap<MemTransferInst *, unsigned>::iterator SI =
    352           NewSrcAlignments.find(MTI);
    353         unsigned AltSrcAlignment = (SI == NewSrcAlignments.end()) ?
    354                                    0 : SI->second;
    355 
    356         DEBUG(dbgs() << "\tmem trans: " << NewDestAlignment << " " <<
    357                         AltDestAlignment << " " << NewSrcAlignment <<
    358                         " " << AltSrcAlignment << "\n");
    359 
    360         // Of these four alignments, pick the largest possible...
    361         unsigned NewAlignment = 0;
    362         if (NewDestAlignment <= std::max(NewSrcAlignment, AltSrcAlignment))
    363           NewAlignment = std::max(NewAlignment, NewDestAlignment);
    364         if (AltDestAlignment <= std::max(NewSrcAlignment, AltSrcAlignment))
    365           NewAlignment = std::max(NewAlignment, AltDestAlignment);
    366         if (NewSrcAlignment <= std::max(NewDestAlignment, AltDestAlignment))
    367           NewAlignment = std::max(NewAlignment, NewSrcAlignment);
    368         if (AltSrcAlignment <= std::max(NewDestAlignment, AltDestAlignment))
    369           NewAlignment = std::max(NewAlignment, AltSrcAlignment);
    370 
    371         if (NewAlignment > MI->getAlignment()) {
    372           MI->setAlignment(ConstantInt::get(Type::getInt32Ty(
    373             MI->getParent()->getContext()), NewAlignment));
    374           ++NumMemIntAlignChanged;
    375         }
    376 
    377         NewDestAlignments.insert(std::make_pair(MTI, NewDestAlignment));
    378         NewSrcAlignments.insert(std::make_pair(MTI, NewSrcAlignment));
    379       } else if (NewDestAlignment > MI->getAlignment()) {
    380         assert((!isa<MemIntrinsic>(MI) || isa<MemSetInst>(MI)) &&
    381                "Unknown memory intrinsic");
    382 
    383         MI->setAlignment(ConstantInt::get(Type::getInt32Ty(
    384           MI->getParent()->getContext()), NewDestAlignment));
    385         ++NumMemIntAlignChanged;
    386       }
    387     }
    388 
    389     // Now that we've updated that use of the pointer, look for other uses of
    390     // the pointer to update.
    391     Visited.insert(J);
    392     for (User *UJ : J->users()) {
    393       Instruction *K = cast<Instruction>(UJ);
    394       if (!Visited.count(K) && isValidAssumeForContext(ACall, K, DT))
    395         WorkList.push_back(K);
    396     }
    397   }
    398 
    399   return true;
    400 }
    401 
    402 bool AlignmentFromAssumptions::runOnFunction(Function &F) {
    403   if (skipFunction(F))
    404     return false;
    405 
    406   auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
    407   ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
    408   DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    409 
    410   return Impl.runImpl(F, AC, SE, DT);
    411 }
    412 
    413 bool AlignmentFromAssumptionsPass::runImpl(Function &F, AssumptionCache &AC,
    414                                            ScalarEvolution *SE_,
    415                                            DominatorTree *DT_) {
    416   SE = SE_;
    417   DT = DT_;
    418 
    419   NewDestAlignments.clear();
    420   NewSrcAlignments.clear();
    421 
    422   bool Changed = false;
    423   for (auto &AssumeVH : AC.assumptions())
    424     if (AssumeVH)
    425       Changed |= processAssumption(cast<CallInst>(AssumeVH));
    426 
    427   return Changed;
    428 }
    429 
    430 PreservedAnalyses
    431 AlignmentFromAssumptionsPass::run(Function &F, FunctionAnalysisManager &AM) {
    432 
    433   AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
    434   ScalarEvolution &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
    435   DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
    436   bool Changed = runImpl(F, AC, &SE, &DT);
    437   if (!Changed)
    438     return PreservedAnalyses::all();
    439   PreservedAnalyses PA;
    440   PA.preserve<AAManager>();
    441   PA.preserve<ScalarEvolutionAnalysis>();
    442   PA.preserve<GlobalsAA>();
    443   PA.preserve<LoopAnalysis>();
    444   PA.preserve<DominatorTreeAnalysis>();
    445   return PA;
    446 }
    447