Home | History | Annotate | Download | only in Scalar
      1 //===- LoadCombine.cpp - Combine Adjacent Loads ---------------------------===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 /// \file
     10 /// This transformation combines adjacent loads.
     11 ///
     12 //===----------------------------------------------------------------------===//
     13 
     14 #include "llvm/Transforms/Scalar.h"
     15 #include "llvm/ADT/DenseMap.h"
     16 #include "llvm/ADT/Statistic.h"
     17 #include "llvm/Analysis/AliasAnalysis.h"
     18 #include "llvm/Analysis/AliasSetTracker.h"
     19 #include "llvm/Analysis/GlobalsModRef.h"
     20 #include "llvm/Analysis/TargetFolder.h"
     21 #include "llvm/IR/DataLayout.h"
     22 #include "llvm/IR/Function.h"
     23 #include "llvm/IR/IRBuilder.h"
     24 #include "llvm/IR/Instructions.h"
     25 #include "llvm/IR/Module.h"
     26 #include "llvm/Pass.h"
     27 #include "llvm/Support/Debug.h"
     28 #include "llvm/Support/MathExtras.h"
     29 #include "llvm/Support/raw_ostream.h"
     30 
     31 using namespace llvm;
     32 
     33 #define DEBUG_TYPE "load-combine"
     34 
     35 STATISTIC(NumLoadsAnalyzed, "Number of loads analyzed for combining");
     36 STATISTIC(NumLoadsCombined, "Number of loads combined");
     37 
     38 #define LDCOMBINE_NAME "Combine Adjacent Loads"
     39 
     40 namespace {
     41 struct PointerOffsetPair {
     42   Value *Pointer;
     43   APInt Offset;
     44 };
     45 
     46 struct LoadPOPPair {
     47   LoadPOPPair() = default;
     48   LoadPOPPair(LoadInst *L, PointerOffsetPair P, unsigned O)
     49       : Load(L), POP(P), InsertOrder(O) {}
     50   LoadInst *Load;
     51   PointerOffsetPair POP;
     52   /// \brief The new load needs to be created before the first load in IR order.
     53   unsigned InsertOrder;
     54 };
     55 
     56 class LoadCombine : public BasicBlockPass {
     57   LLVMContext *C;
     58   AliasAnalysis *AA;
     59 
     60 public:
     61   LoadCombine() : BasicBlockPass(ID), C(nullptr), AA(nullptr) {
     62     initializeLoadCombinePass(*PassRegistry::getPassRegistry());
     63   }
     64 
     65   using llvm::Pass::doInitialization;
     66   bool doInitialization(Function &) override;
     67   bool runOnBasicBlock(BasicBlock &BB) override;
     68   void getAnalysisUsage(AnalysisUsage &AU) const override {
     69     AU.setPreservesCFG();
     70     AU.addRequired<AAResultsWrapperPass>();
     71     AU.addPreserved<GlobalsAAWrapperPass>();
     72   }
     73 
     74   const char *getPassName() const override { return LDCOMBINE_NAME; }
     75   static char ID;
     76 
     77   typedef IRBuilder<TargetFolder> BuilderTy;
     78 
     79 private:
     80   BuilderTy *Builder;
     81 
     82   PointerOffsetPair getPointerOffsetPair(LoadInst &);
     83   bool combineLoads(DenseMap<const Value *, SmallVector<LoadPOPPair, 8>> &);
     84   bool aggregateLoads(SmallVectorImpl<LoadPOPPair> &);
     85   bool combineLoads(SmallVectorImpl<LoadPOPPair> &);
     86 };
     87 }
     88 
     89 bool LoadCombine::doInitialization(Function &F) {
     90   DEBUG(dbgs() << "LoadCombine function: " << F.getName() << "\n");
     91   C = &F.getContext();
     92   return true;
     93 }
     94 
     95 PointerOffsetPair LoadCombine::getPointerOffsetPair(LoadInst &LI) {
     96   auto &DL = LI.getModule()->getDataLayout();
     97 
     98   PointerOffsetPair POP;
     99   POP.Pointer = LI.getPointerOperand();
    100   unsigned BitWidth = DL.getPointerSizeInBits(LI.getPointerAddressSpace());
    101   POP.Offset = APInt(BitWidth, 0);
    102 
    103   while (isa<BitCastInst>(POP.Pointer) || isa<GetElementPtrInst>(POP.Pointer)) {
    104     if (auto *GEP = dyn_cast<GetElementPtrInst>(POP.Pointer)) {
    105       APInt LastOffset = POP.Offset;
    106       if (!GEP->accumulateConstantOffset(DL, POP.Offset)) {
    107         // Can't handle GEPs with variable indices.
    108         POP.Offset = LastOffset;
    109         return POP;
    110       }
    111       POP.Pointer = GEP->getPointerOperand();
    112     } else if (auto *BC = dyn_cast<BitCastInst>(POP.Pointer)) {
    113       POP.Pointer = BC->getOperand(0);
    114     }
    115   }
    116   return POP;
    117 }
    118 
    119 bool LoadCombine::combineLoads(
    120     DenseMap<const Value *, SmallVector<LoadPOPPair, 8>> &LoadMap) {
    121   bool Combined = false;
    122   for (auto &Loads : LoadMap) {
    123     if (Loads.second.size() < 2)
    124       continue;
    125     std::sort(Loads.second.begin(), Loads.second.end(),
    126               [](const LoadPOPPair &A, const LoadPOPPair &B) {
    127                 return A.POP.Offset.slt(B.POP.Offset);
    128               });
    129     if (aggregateLoads(Loads.second))
    130       Combined = true;
    131   }
    132   return Combined;
    133 }
    134 
    135 /// \brief Try to aggregate loads from a sorted list of loads to be combined.
    136 ///
    137 /// It is guaranteed that no writes occur between any of the loads. All loads
    138 /// have the same base pointer. There are at least two loads.
    139 bool LoadCombine::aggregateLoads(SmallVectorImpl<LoadPOPPair> &Loads) {
    140   assert(Loads.size() >= 2 && "Insufficient loads!");
    141   LoadInst *BaseLoad = nullptr;
    142   SmallVector<LoadPOPPair, 8> AggregateLoads;
    143   bool Combined = false;
    144   bool ValidPrevOffset = false;
    145   APInt PrevOffset;
    146   uint64_t PrevSize = 0;
    147   for (auto &L : Loads) {
    148     if (ValidPrevOffset == false) {
    149       BaseLoad = L.Load;
    150       PrevOffset = L.POP.Offset;
    151       PrevSize = L.Load->getModule()->getDataLayout().getTypeStoreSize(
    152           L.Load->getType());
    153       AggregateLoads.push_back(L);
    154       ValidPrevOffset = true;
    155       continue;
    156     }
    157     if (L.Load->getAlignment() > BaseLoad->getAlignment())
    158       continue;
    159     APInt PrevEnd = PrevOffset + PrevSize;
    160     if (L.POP.Offset.sgt(PrevEnd)) {
    161       // No other load will be combinable
    162       if (combineLoads(AggregateLoads))
    163         Combined = true;
    164       AggregateLoads.clear();
    165       ValidPrevOffset = false;
    166       continue;
    167     }
    168     if (L.POP.Offset != PrevEnd)
    169       // This load is offset less than the size of the last load.
    170       // FIXME: We may want to handle this case.
    171       continue;
    172     PrevOffset = L.POP.Offset;
    173     PrevSize = L.Load->getModule()->getDataLayout().getTypeStoreSize(
    174         L.Load->getType());
    175     AggregateLoads.push_back(L);
    176   }
    177   if (combineLoads(AggregateLoads))
    178     Combined = true;
    179   return Combined;
    180 }
    181 
    182 /// \brief Given a list of combinable load. Combine the maximum number of them.
    183 bool LoadCombine::combineLoads(SmallVectorImpl<LoadPOPPair> &Loads) {
    184   // Remove loads from the end while the size is not a power of 2.
    185   unsigned TotalSize = 0;
    186   for (const auto &L : Loads)
    187     TotalSize += L.Load->getType()->getPrimitiveSizeInBits();
    188   while (TotalSize != 0 && !isPowerOf2_32(TotalSize))
    189     TotalSize -= Loads.pop_back_val().Load->getType()->getPrimitiveSizeInBits();
    190   if (Loads.size() < 2)
    191     return false;
    192 
    193   DEBUG({
    194     dbgs() << "***** Combining Loads ******\n";
    195     for (const auto &L : Loads) {
    196       dbgs() << L.POP.Offset << ": " << *L.Load << "\n";
    197     }
    198   });
    199 
    200   // Find first load. This is where we put the new load.
    201   LoadPOPPair FirstLP;
    202   FirstLP.InsertOrder = -1u;
    203   for (const auto &L : Loads)
    204     if (L.InsertOrder < FirstLP.InsertOrder)
    205       FirstLP = L;
    206 
    207   unsigned AddressSpace =
    208       FirstLP.POP.Pointer->getType()->getPointerAddressSpace();
    209 
    210   Builder->SetInsertPoint(FirstLP.Load);
    211   Value *Ptr = Builder->CreateConstGEP1_64(
    212       Builder->CreatePointerCast(Loads[0].POP.Pointer,
    213                                  Builder->getInt8PtrTy(AddressSpace)),
    214       Loads[0].POP.Offset.getSExtValue());
    215   LoadInst *NewLoad = new LoadInst(
    216       Builder->CreatePointerCast(
    217           Ptr, PointerType::get(IntegerType::get(Ptr->getContext(), TotalSize),
    218                                 Ptr->getType()->getPointerAddressSpace())),
    219       Twine(Loads[0].Load->getName()) + ".combined", false,
    220       Loads[0].Load->getAlignment(), FirstLP.Load);
    221 
    222   for (const auto &L : Loads) {
    223     Builder->SetInsertPoint(L.Load);
    224     Value *V = Builder->CreateExtractInteger(
    225         L.Load->getModule()->getDataLayout(), NewLoad,
    226         cast<IntegerType>(L.Load->getType()),
    227         (L.POP.Offset - Loads[0].POP.Offset).getZExtValue(), "combine.extract");
    228     L.Load->replaceAllUsesWith(V);
    229   }
    230 
    231   NumLoadsCombined = NumLoadsCombined + Loads.size();
    232   return true;
    233 }
    234 
    235 bool LoadCombine::runOnBasicBlock(BasicBlock &BB) {
    236   if (skipBasicBlock(BB))
    237     return false;
    238 
    239   AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
    240 
    241   IRBuilder<TargetFolder> TheBuilder(
    242       BB.getContext(), TargetFolder(BB.getModule()->getDataLayout()));
    243   Builder = &TheBuilder;
    244 
    245   DenseMap<const Value *, SmallVector<LoadPOPPair, 8>> LoadMap;
    246   AliasSetTracker AST(*AA);
    247 
    248   bool Combined = false;
    249   unsigned Index = 0;
    250   for (auto &I : BB) {
    251     if (I.mayThrow() || (I.mayWriteToMemory() && AST.containsUnknown(&I))) {
    252       if (combineLoads(LoadMap))
    253         Combined = true;
    254       LoadMap.clear();
    255       AST.clear();
    256       continue;
    257     }
    258     LoadInst *LI = dyn_cast<LoadInst>(&I);
    259     if (!LI)
    260       continue;
    261     ++NumLoadsAnalyzed;
    262     if (!LI->isSimple() || !LI->getType()->isIntegerTy())
    263       continue;
    264     auto POP = getPointerOffsetPair(*LI);
    265     if (!POP.Pointer)
    266       continue;
    267     LoadMap[POP.Pointer].push_back(LoadPOPPair(LI, POP, Index++));
    268     AST.add(LI);
    269   }
    270   if (combineLoads(LoadMap))
    271     Combined = true;
    272   return Combined;
    273 }
    274 
    275 char LoadCombine::ID = 0;
    276 
    277 BasicBlockPass *llvm::createLoadCombinePass() {
    278   return new LoadCombine();
    279 }
    280 
    281 INITIALIZE_PASS_BEGIN(LoadCombine, "load-combine", LDCOMBINE_NAME, false, false)
    282 INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
    283 INITIALIZE_PASS_END(LoadCombine, "load-combine", LDCOMBINE_NAME, false, false)
    284