Home | History | Annotate | Download | only in IPO
      1 //===- LoopExtractor.cpp - Extract each loop into a new function ----------===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
// A pass wrapper around the CodeExtractor utility to extract each top-level
// loop into its own new function. A loop is left untouched when its function
// is only a minimal wrapper around it (the entry block branches straight to the
// loop header and every loop exit block ends in a return). This pass is most
// useful for debugging via bugpoint.
     14 //
     15 //===----------------------------------------------------------------------===//
     16 
     17 #include "llvm/Transforms/IPO.h"
     18 #include "llvm/ADT/Statistic.h"
     19 #include "llvm/Analysis/LoopPass.h"
     20 #include "llvm/IR/Dominators.h"
     21 #include "llvm/IR/Instructions.h"
     22 #include "llvm/IR/Module.h"
     23 #include "llvm/Pass.h"
     24 #include "llvm/Support/CommandLine.h"
     25 #include "llvm/Transforms/Scalar.h"
     26 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
     27 #include "llvm/Transforms/Utils/CodeExtractor.h"
     28 #include <fstream>
     29 #include <set>
     30 using namespace llvm;
     31 
     32 #define DEBUG_TYPE "loop-extract"
     33 
     34 STATISTIC(NumExtracted, "Number of loops extracted");
     35 
     36 namespace {
     37   struct LoopExtractor : public LoopPass {
     38     static char ID; // Pass identification, replacement for typeid
     39     unsigned NumLoops;
     40 
     41     explicit LoopExtractor(unsigned numLoops = ~0)
     42       : LoopPass(ID), NumLoops(numLoops) {
     43         initializeLoopExtractorPass(*PassRegistry::getPassRegistry());
     44       }
     45 
     46     bool runOnLoop(Loop *L, LPPassManager &) override;
     47 
     48     void getAnalysisUsage(AnalysisUsage &AU) const override {
     49       AU.addRequiredID(BreakCriticalEdgesID);
     50       AU.addRequiredID(LoopSimplifyID);
     51       AU.addRequired<DominatorTreeWrapperPass>();
     52       AU.addRequired<LoopInfoWrapperPass>();
     53     }
     54   };
     55 }
     56 
     57 char LoopExtractor::ID = 0;
     58 INITIALIZE_PASS_BEGIN(LoopExtractor, "loop-extract",
     59                       "Extract loops into new functions", false, false)
     60 INITIALIZE_PASS_DEPENDENCY(BreakCriticalEdges)
     61 INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
     62 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
     63 INITIALIZE_PASS_END(LoopExtractor, "loop-extract",
     64                     "Extract loops into new functions", false, false)
     65 
     66 namespace {
     67   /// SingleLoopExtractor - For bugpoint.
     68   struct SingleLoopExtractor : public LoopExtractor {
     69     static char ID; // Pass identification, replacement for typeid
     70     SingleLoopExtractor() : LoopExtractor(1) {}
     71   };
     72 } // End anonymous namespace
     73 
     74 char SingleLoopExtractor::ID = 0;
     75 INITIALIZE_PASS(SingleLoopExtractor, "loop-extract-single",
     76                 "Extract at most one loop into a new function", false, false)
     77 
     78 // createLoopExtractorPass - This pass extracts all natural loops from the
     79 // program into a function if it can.
     80 //
     81 Pass *llvm::createLoopExtractorPass() { return new LoopExtractor(); }
     82 
     83 bool LoopExtractor::runOnLoop(Loop *L, LPPassManager &) {
     84   if (skipLoop(L))
     85     return false;
     86 
     87   // Only visit top-level loops.
     88   if (L->getParentLoop())
     89     return false;
     90 
     91   // If LoopSimplify form is not available, stay out of trouble.
     92   if (!L->isLoopSimplifyForm())
     93     return false;
     94 
     95   DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
     96   LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
     97   bool Changed = false;
     98 
     99   // If there is more than one top-level loop in this function, extract all of
    100   // the loops. Otherwise there is exactly one top-level loop; in this case if
    101   // this function is more than a minimal wrapper around the loop, extract
    102   // the loop.
    103   bool ShouldExtractLoop = false;
    104 
    105   // Extract the loop if the entry block doesn't branch to the loop header.
    106   TerminatorInst *EntryTI =
    107     L->getHeader()->getParent()->getEntryBlock().getTerminator();
    108   if (!isa<BranchInst>(EntryTI) ||
    109       !cast<BranchInst>(EntryTI)->isUnconditional() ||
    110       EntryTI->getSuccessor(0) != L->getHeader()) {
    111     ShouldExtractLoop = true;
    112   } else {
    113     // Check to see if any exits from the loop are more than just return
    114     // blocks.
    115     SmallVector<BasicBlock*, 8> ExitBlocks;
    116     L->getExitBlocks(ExitBlocks);
    117     for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i)
    118       if (!isa<ReturnInst>(ExitBlocks[i]->getTerminator())) {
    119         ShouldExtractLoop = true;
    120         break;
    121       }
    122   }
    123 
    124   if (ShouldExtractLoop) {
    125     // We must omit EH pads. EH pads must accompany the invoke
    126     // instruction. But this would result in a loop in the extracted
    127     // function. An infinite cycle occurs when it tries to extract that loop as
    128     // well.
    129     SmallVector<BasicBlock*, 8> ExitBlocks;
    130     L->getExitBlocks(ExitBlocks);
    131     for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i)
    132       if (ExitBlocks[i]->isEHPad()) {
    133         ShouldExtractLoop = false;
    134         break;
    135       }
    136   }
    137 
    138   if (ShouldExtractLoop) {
    139     if (NumLoops == 0) return Changed;
    140     --NumLoops;
    141     CodeExtractor Extractor(DT, *L);
    142     if (Extractor.extractCodeRegion() != nullptr) {
    143       Changed = true;
    144       // After extraction, the loop is replaced by a function call, so
    145       // we shouldn't try to run any more loop passes on it.
    146       LI.markAsRemoved(L);
    147     }
    148     ++NumExtracted;
    149   }
    150 
    151   return Changed;
    152 }
    153 
    154 // createSingleLoopExtractorPass - This pass extracts one natural loop from the
    155 // program into a function if it can.  This is used by bugpoint.
    156 //
    157 Pass *llvm::createSingleLoopExtractorPass() {
    158   return new SingleLoopExtractor();
    159 }
    160 
    161 
    162 // BlockFile - A file which contains a list of blocks that should not be
    163 // extracted.
    164 static cl::opt<std::string>
    165 BlockFile("extract-blocks-file", cl::value_desc("filename"),
    166           cl::desc("A file containing list of basic blocks to not extract"),
    167           cl::Hidden);
    168 
    169 namespace {
    170   /// BlockExtractorPass - This pass is used by bugpoint to extract all blocks
    171   /// from the module into their own functions except for those specified by the
    172   /// BlocksToNotExtract list.
    173   class BlockExtractorPass : public ModulePass {
    174     void LoadFile(const char *Filename);
    175     void SplitLandingPadPreds(Function *F);
    176 
    177     std::vector<BasicBlock*> BlocksToNotExtract;
    178     std::vector<std::pair<std::string, std::string> > BlocksToNotExtractByName;
    179   public:
    180     static char ID; // Pass identification, replacement for typeid
    181     BlockExtractorPass() : ModulePass(ID) {
    182       if (!BlockFile.empty())
    183         LoadFile(BlockFile.c_str());
    184     }
    185 
    186     bool runOnModule(Module &M) override;
    187   };
    188 }
    189 
    190 char BlockExtractorPass::ID = 0;
    191 INITIALIZE_PASS(BlockExtractorPass, "extract-blocks",
    192                 "Extract Basic Blocks From Module (for bugpoint use)",
    193                 false, false)
    194 
    195 // createBlockExtractorPass - This pass extracts all blocks (except those
    196 // specified in the argument list) from the functions in the module.
    197 //
    198 ModulePass *llvm::createBlockExtractorPass() {
    199   return new BlockExtractorPass();
    200 }
    201 
    202 void BlockExtractorPass::LoadFile(const char *Filename) {
    203   // Load the BlockFile...
    204   std::ifstream In(Filename);
    205   if (!In.good()) {
    206     errs() << "WARNING: BlockExtractor couldn't load file '" << Filename
    207            << "'!\n";
    208     return;
    209   }
    210   while (In) {
    211     std::string FunctionName, BlockName;
    212     In >> FunctionName;
    213     In >> BlockName;
    214     if (!BlockName.empty())
    215       BlocksToNotExtractByName.push_back(
    216           std::make_pair(FunctionName, BlockName));
    217   }
    218 }
    219 
    220 /// SplitLandingPadPreds - The landing pad needs to be extracted with the invoke
    221 /// instruction. The critical edge breaker will refuse to break critical edges
    222 /// to a landing pad. So do them here. After this method runs, all landing pads
    223 /// should have only one predecessor.
    224 void BlockExtractorPass::SplitLandingPadPreds(Function *F) {
    225   for (Function::iterator I = F->begin(), E = F->end(); I != E; ++I) {
    226     InvokeInst *II = dyn_cast<InvokeInst>(I);
    227     if (!II) continue;
    228     BasicBlock *Parent = II->getParent();
    229     BasicBlock *LPad = II->getUnwindDest();
    230 
    231     // Look through the landing pad's predecessors. If one of them ends in an
    232     // 'invoke', then we want to split the landing pad.
    233     bool Split = false;
    234     for (pred_iterator
    235            PI = pred_begin(LPad), PE = pred_end(LPad); PI != PE; ++PI) {
    236       BasicBlock *BB = *PI;
    237       if (BB->isLandingPad() && BB != Parent &&
    238           isa<InvokeInst>(Parent->getTerminator())) {
    239         Split = true;
    240         break;
    241       }
    242     }
    243 
    244     if (!Split) continue;
    245 
    246     SmallVector<BasicBlock*, 2> NewBBs;
    247     SplitLandingPadPredecessors(LPad, Parent, ".1", ".2", NewBBs);
    248   }
    249 }
    250 
    251 bool BlockExtractorPass::runOnModule(Module &M) {
    252   if (skipModule(M))
    253     return false;
    254 
    255   std::set<BasicBlock*> TranslatedBlocksToNotExtract;
    256   for (unsigned i = 0, e = BlocksToNotExtract.size(); i != e; ++i) {
    257     BasicBlock *BB = BlocksToNotExtract[i];
    258     Function *F = BB->getParent();
    259 
    260     // Map the corresponding function in this module.
    261     Function *MF = M.getFunction(F->getName());
    262     assert(MF->getFunctionType() == F->getFunctionType() && "Wrong function?");
    263 
    264     // Figure out which index the basic block is in its function.
    265     Function::iterator BBI = MF->begin();
    266     std::advance(BBI, std::distance(F->begin(), Function::iterator(BB)));
    267     TranslatedBlocksToNotExtract.insert(&*BBI);
    268   }
    269 
    270   while (!BlocksToNotExtractByName.empty()) {
    271     // There's no way to find BBs by name without looking at every BB inside
    272     // every Function. Fortunately, this is always empty except when used by
    273     // bugpoint in which case correctness is more important than performance.
    274 
    275     std::string &FuncName  = BlocksToNotExtractByName.back().first;
    276     std::string &BlockName = BlocksToNotExtractByName.back().second;
    277 
    278     for (Function &F : M) {
    279       if (F.getName() != FuncName) continue;
    280 
    281       for (BasicBlock &BB : F) {
    282         if (BB.getName() != BlockName) continue;
    283 
    284         TranslatedBlocksToNotExtract.insert(&BB);
    285       }
    286     }
    287 
    288     BlocksToNotExtractByName.pop_back();
    289   }
    290 
    291   // Now that we know which blocks to not extract, figure out which ones we WANT
    292   // to extract.
    293   std::vector<BasicBlock*> BlocksToExtract;
    294   for (Function &F : M) {
    295     SplitLandingPadPreds(&F);
    296     for (BasicBlock &BB : F)
    297       if (!TranslatedBlocksToNotExtract.count(&BB))
    298         BlocksToExtract.push_back(&BB);
    299   }
    300 
    301   for (BasicBlock *BlockToExtract : BlocksToExtract) {
    302     SmallVector<BasicBlock*, 2> BlocksToExtractVec;
    303     BlocksToExtractVec.push_back(BlockToExtract);
    304     if (const InvokeInst *II =
    305             dyn_cast<InvokeInst>(BlockToExtract->getTerminator()))
    306       BlocksToExtractVec.push_back(II->getUnwindDest());
    307     CodeExtractor(BlocksToExtractVec).extractCodeRegion();
    308   }
    309 
    310   return !BlocksToExtract.empty();
    311 }
    312