Home | History | Annotate | Download | only in CodeGen
      1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 // Instrumentation-based profile-guided optimization
     11 //
     12 //===----------------------------------------------------------------------===//
     13 
     14 #include "CodeGenPGO.h"
     15 #include "CodeGenFunction.h"
     16 #include "CoverageMappingGen.h"
     17 #include "clang/AST/RecursiveASTVisitor.h"
     18 #include "clang/AST/StmtVisitor.h"
     19 #include "llvm/IR/Intrinsics.h"
     20 #include "llvm/IR/MDBuilder.h"
     21 #include "llvm/ProfileData/InstrProfReader.h"
     22 #include "llvm/Support/Endian.h"
     23 #include "llvm/Support/FileSystem.h"
     24 #include "llvm/Support/MD5.h"
     25 
     26 using namespace clang;
     27 using namespace CodeGen;
     28 
     29 void CodeGenPGO::setFuncName(StringRef Name,
     30                              llvm::GlobalValue::LinkageTypes Linkage) {
     31   StringRef RawFuncName = Name;
     32 
     33   // Function names may be prefixed with a binary '1' to indicate
     34   // that the backend should not modify the symbols due to any platform
     35   // naming convention. Do not include that '1' in the PGO profile name.
     36   if (RawFuncName[0] == '\1')
     37     RawFuncName = RawFuncName.substr(1);
     38 
     39   FuncName = RawFuncName;
     40   if (llvm::GlobalValue::isLocalLinkage(Linkage)) {
     41     // For local symbols, prepend the main file name to distinguish them.
     42     // Do not include the full path in the file name since there's no guarantee
     43     // that it will stay the same, e.g., if the files are checked out from
     44     // version control in different locations.
     45     if (CGM.getCodeGenOpts().MainFileName.empty())
     46       FuncName = FuncName.insert(0, "<unknown>:");
     47     else
     48       FuncName = FuncName.insert(0, CGM.getCodeGenOpts().MainFileName + ":");
     49   }
     50 
     51   // If we're generating a profile, create a variable for the name.
     52   if (CGM.getCodeGenOpts().ProfileInstrGenerate)
     53     createFuncNameVar(Linkage);
     54 }
     55 
     56 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
     57   setFuncName(Fn->getName(), Fn->getLinkage());
     58 }
     59 
     60 void CodeGenPGO::createFuncNameVar(llvm::GlobalValue::LinkageTypes Linkage) {
     61   // We generally want to match the function's linkage, but available_externally
     62   // and extern_weak both have the wrong semantics, and anything that doesn't
     63   // need to link across compilation units doesn't need to be visible at all.
     64   if (Linkage == llvm::GlobalValue::ExternalWeakLinkage)
     65     Linkage = llvm::GlobalValue::LinkOnceAnyLinkage;
     66   else if (Linkage == llvm::GlobalValue::AvailableExternallyLinkage)
     67     Linkage = llvm::GlobalValue::LinkOnceODRLinkage;
     68   else if (Linkage == llvm::GlobalValue::InternalLinkage ||
     69            Linkage == llvm::GlobalValue::ExternalLinkage)
     70     Linkage = llvm::GlobalValue::PrivateLinkage;
     71 
     72   auto *Value =
     73       llvm::ConstantDataArray::getString(CGM.getLLVMContext(), FuncName, false);
     74   FuncNameVar =
     75       new llvm::GlobalVariable(CGM.getModule(), Value->getType(), true, Linkage,
     76                                Value, "__llvm_profile_name_" + FuncName);
     77 
     78   // Hide the symbol so that we correctly get a copy for each executable.
     79   if (!llvm::GlobalValue::isLocalLinkage(FuncNameVar->getLinkage()))
     80     FuncNameVar->setVisibility(llvm::GlobalValue::HiddenVisibility);
     81 }
     82 
     83 namespace {
     84 /// \brief Stable hasher for PGO region counters.
     85 ///
     86 /// PGOHash produces a stable hash of a given function's control flow.
     87 ///
     88 /// Changing the output of this hash will invalidate all previously generated
     89 /// profiles -- i.e., don't do it.
     90 ///
     91 /// \note  When this hash does eventually change (years?), we still need to
     92 /// support old hashes.  We'll need to pull in the version number from the
     93 /// profile data format and use the matching hash function.
     94 class PGOHash {
     95   uint64_t Working;
     96   unsigned Count;
     97   llvm::MD5 MD5;
     98 
     99   static const int NumBitsPerType = 6;
    100   static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
    101   static const unsigned TooBig = 1u << NumBitsPerType;
    102 
    103 public:
    104   /// \brief Hash values for AST nodes.
    105   ///
    106   /// Distinct values for AST nodes that have region counters attached.
    107   ///
    108   /// These values must be stable.  All new members must be added at the end,
    109   /// and no members should be removed.  Changing the enumeration value for an
    110   /// AST node will affect the hash of every function that contains that node.
    111   enum HashType : unsigned char {
    112     None = 0,
    113     LabelStmt = 1,
    114     WhileStmt,
    115     DoStmt,
    116     ForStmt,
    117     CXXForRangeStmt,
    118     ObjCForCollectionStmt,
    119     SwitchStmt,
    120     CaseStmt,
    121     DefaultStmt,
    122     IfStmt,
    123     CXXTryStmt,
    124     CXXCatchStmt,
    125     ConditionalOperator,
    126     BinaryOperatorLAnd,
    127     BinaryOperatorLOr,
    128     BinaryConditionalOperator,
    129 
    130     // Keep this last.  It's for the static assert that follows.
    131     LastHashType
    132   };
    133   static_assert(LastHashType <= TooBig, "Too many types in HashType");
    134 
    135   // TODO: When this format changes, take in a version number here, and use the
    136   // old hash calculation for file formats that used the old hash.
    137   PGOHash() : Working(0), Count(0) {}
    138   void combine(HashType Type);
    139   uint64_t finalize();
    140 };
    141 const int PGOHash::NumBitsPerType;
    142 const unsigned PGOHash::NumTypesPerWord;
    143 const unsigned PGOHash::TooBig;
    144 
    145 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
    146 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
    147   /// The next counter value to assign.
    148   unsigned NextCounter;
    149   /// The function hash.
    150   PGOHash Hash;
    151   /// The map of statements to counters.
    152   llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
    153 
    154   MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
    155       : NextCounter(0), CounterMap(CounterMap) {}
    156 
    157   // Blocks and lambdas are handled as separate functions, so we need not
    158   // traverse them in the parent context.
    159   bool TraverseBlockExpr(BlockExpr *BE) { return true; }
    160   bool TraverseLambdaBody(LambdaExpr *LE) { return true; }
    161   bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
    162 
    163   bool VisitDecl(const Decl *D) {
    164     switch (D->getKind()) {
    165     default:
    166       break;
    167     case Decl::Function:
    168     case Decl::CXXMethod:
    169     case Decl::CXXConstructor:
    170     case Decl::CXXDestructor:
    171     case Decl::CXXConversion:
    172     case Decl::ObjCMethod:
    173     case Decl::Block:
    174     case Decl::Captured:
    175       CounterMap[D->getBody()] = NextCounter++;
    176       break;
    177     }
    178     return true;
    179   }
    180 
    181   bool VisitStmt(const Stmt *S) {
    182     auto Type = getHashType(S);
    183     if (Type == PGOHash::None)
    184       return true;
    185 
    186     CounterMap[S] = NextCounter++;
    187     Hash.combine(Type);
    188     return true;
    189   }
    190   PGOHash::HashType getHashType(const Stmt *S) {
    191     switch (S->getStmtClass()) {
    192     default:
    193       break;
    194     case Stmt::LabelStmtClass:
    195       return PGOHash::LabelStmt;
    196     case Stmt::WhileStmtClass:
    197       return PGOHash::WhileStmt;
    198     case Stmt::DoStmtClass:
    199       return PGOHash::DoStmt;
    200     case Stmt::ForStmtClass:
    201       return PGOHash::ForStmt;
    202     case Stmt::CXXForRangeStmtClass:
    203       return PGOHash::CXXForRangeStmt;
    204     case Stmt::ObjCForCollectionStmtClass:
    205       return PGOHash::ObjCForCollectionStmt;
    206     case Stmt::SwitchStmtClass:
    207       return PGOHash::SwitchStmt;
    208     case Stmt::CaseStmtClass:
    209       return PGOHash::CaseStmt;
    210     case Stmt::DefaultStmtClass:
    211       return PGOHash::DefaultStmt;
    212     case Stmt::IfStmtClass:
    213       return PGOHash::IfStmt;
    214     case Stmt::CXXTryStmtClass:
    215       return PGOHash::CXXTryStmt;
    216     case Stmt::CXXCatchStmtClass:
    217       return PGOHash::CXXCatchStmt;
    218     case Stmt::ConditionalOperatorClass:
    219       return PGOHash::ConditionalOperator;
    220     case Stmt::BinaryConditionalOperatorClass:
    221       return PGOHash::BinaryConditionalOperator;
    222     case Stmt::BinaryOperatorClass: {
    223       const BinaryOperator *BO = cast<BinaryOperator>(S);
    224       if (BO->getOpcode() == BO_LAnd)
    225         return PGOHash::BinaryOperatorLAnd;
    226       if (BO->getOpcode() == BO_LOr)
    227         return PGOHash::BinaryOperatorLOr;
    228       break;
    229     }
    230     }
    231     return PGOHash::None;
    232   }
    233 };
    234 
    235 /// A StmtVisitor that propagates the raw counts through the AST and
    236 /// records the count at statements where the value may change.
    237 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
    238   /// PGO state.
    239   CodeGenPGO &PGO;
    240 
    241   /// A flag that is set when the current count should be recorded on the
    242   /// next statement, such as at the exit of a loop.
    243   bool RecordNextStmtCount;
    244 
    245   /// The map of statements to count values.
    246   llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
    247 
    248   /// BreakContinueStack - Keep counts of breaks and continues inside loops.
    249   struct BreakContinue {
    250     uint64_t BreakCount;
    251     uint64_t ContinueCount;
    252     BreakContinue() : BreakCount(0), ContinueCount(0) {}
    253   };
    254   SmallVector<BreakContinue, 8> BreakContinueStack;
    255 
    256   ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
    257                       CodeGenPGO &PGO)
    258       : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
    259 
    260   void RecordStmtCount(const Stmt *S) {
    261     if (RecordNextStmtCount) {
    262       CountMap[S] = PGO.getCurrentRegionCount();
    263       RecordNextStmtCount = false;
    264     }
    265   }
    266 
    267   void VisitStmt(const Stmt *S) {
    268     RecordStmtCount(S);
    269     for (Stmt::const_child_range I = S->children(); I; ++I) {
    270       if (*I)
    271         this->Visit(*I);
    272     }
    273   }
    274 
    275   void VisitFunctionDecl(const FunctionDecl *D) {
    276     // Counter tracks entry to the function body.
    277     RegionCounter Cnt(PGO, D->getBody());
    278     Cnt.beginRegion();
    279     CountMap[D->getBody()] = PGO.getCurrentRegionCount();
    280     Visit(D->getBody());
    281   }
    282 
    283   // Skip lambda expressions. We visit these as FunctionDecls when we're
    284   // generating them and aren't interested in the body when generating a
    285   // parent context.
    286   void VisitLambdaExpr(const LambdaExpr *LE) {}
    287 
    288   void VisitCapturedDecl(const CapturedDecl *D) {
    289     // Counter tracks entry to the capture body.
    290     RegionCounter Cnt(PGO, D->getBody());
    291     Cnt.beginRegion();
    292     CountMap[D->getBody()] = PGO.getCurrentRegionCount();
    293     Visit(D->getBody());
    294   }
    295 
    296   void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
    297     // Counter tracks entry to the method body.
    298     RegionCounter Cnt(PGO, D->getBody());
    299     Cnt.beginRegion();
    300     CountMap[D->getBody()] = PGO.getCurrentRegionCount();
    301     Visit(D->getBody());
    302   }
    303 
    304   void VisitBlockDecl(const BlockDecl *D) {
    305     // Counter tracks entry to the block body.
    306     RegionCounter Cnt(PGO, D->getBody());
    307     Cnt.beginRegion();
    308     CountMap[D->getBody()] = PGO.getCurrentRegionCount();
    309     Visit(D->getBody());
    310   }
    311 
    312   void VisitReturnStmt(const ReturnStmt *S) {
    313     RecordStmtCount(S);
    314     if (S->getRetValue())
    315       Visit(S->getRetValue());
    316     PGO.setCurrentRegionUnreachable();
    317     RecordNextStmtCount = true;
    318   }
    319 
    320   void VisitGotoStmt(const GotoStmt *S) {
    321     RecordStmtCount(S);
    322     PGO.setCurrentRegionUnreachable();
    323     RecordNextStmtCount = true;
    324   }
    325 
    326   void VisitLabelStmt(const LabelStmt *S) {
    327     RecordNextStmtCount = false;
    328     // Counter tracks the block following the label.
    329     RegionCounter Cnt(PGO, S);
    330     Cnt.beginRegion();
    331     CountMap[S] = PGO.getCurrentRegionCount();
    332     Visit(S->getSubStmt());
    333   }
    334 
    335   void VisitBreakStmt(const BreakStmt *S) {
    336     RecordStmtCount(S);
    337     assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
    338     BreakContinueStack.back().BreakCount += PGO.getCurrentRegionCount();
    339     PGO.setCurrentRegionUnreachable();
    340     RecordNextStmtCount = true;
    341   }
    342 
    343   void VisitContinueStmt(const ContinueStmt *S) {
    344     RecordStmtCount(S);
    345     assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
    346     BreakContinueStack.back().ContinueCount += PGO.getCurrentRegionCount();
    347     PGO.setCurrentRegionUnreachable();
    348     RecordNextStmtCount = true;
    349   }
    350 
    351   void VisitWhileStmt(const WhileStmt *S) {
    352     RecordStmtCount(S);
    353     // Counter tracks the body of the loop.
    354     RegionCounter Cnt(PGO, S);
    355     BreakContinueStack.push_back(BreakContinue());
    356     // Visit the body region first so the break/continue adjustments can be
    357     // included when visiting the condition.
    358     Cnt.beginRegion();
    359     CountMap[S->getBody()] = PGO.getCurrentRegionCount();
    360     Visit(S->getBody());
    361     Cnt.adjustForControlFlow();
    362 
    363     // ...then go back and propagate counts through the condition. The count
    364     // at the start of the condition is the sum of the incoming edges,
    365     // the backedge from the end of the loop body, and the edges from
    366     // continue statements.
    367     BreakContinue BC = BreakContinueStack.pop_back_val();
    368     Cnt.setCurrentRegionCount(Cnt.getParentCount() + Cnt.getAdjustedCount() +
    369                               BC.ContinueCount);
    370     CountMap[S->getCond()] = PGO.getCurrentRegionCount();
    371     Visit(S->getCond());
    372     Cnt.adjustForControlFlow();
    373     Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
    374     RecordNextStmtCount = true;
    375   }
    376 
    377   void VisitDoStmt(const DoStmt *S) {
    378     RecordStmtCount(S);
    379     // Counter tracks the body of the loop.
    380     RegionCounter Cnt(PGO, S);
    381     BreakContinueStack.push_back(BreakContinue());
    382     Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
    383     CountMap[S->getBody()] = PGO.getCurrentRegionCount();
    384     Visit(S->getBody());
    385     Cnt.adjustForControlFlow();
    386 
    387     BreakContinue BC = BreakContinueStack.pop_back_val();
    388     // The count at the start of the condition is equal to the count at the
    389     // end of the body. The adjusted count does not include either the
    390     // fall-through count coming into the loop or the continue count, so add
    391     // both of those separately. This is coincidentally the same equation as
    392     // with while loops but for different reasons.
    393     Cnt.setCurrentRegionCount(Cnt.getParentCount() + Cnt.getAdjustedCount() +
    394                               BC.ContinueCount);
    395     CountMap[S->getCond()] = PGO.getCurrentRegionCount();
    396     Visit(S->getCond());
    397     Cnt.adjustForControlFlow();
    398     Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
    399     RecordNextStmtCount = true;
    400   }
    401 
    402   void VisitForStmt(const ForStmt *S) {
    403     RecordStmtCount(S);
    404     if (S->getInit())
    405       Visit(S->getInit());
    406     // Counter tracks the body of the loop.
    407     RegionCounter Cnt(PGO, S);
    408     BreakContinueStack.push_back(BreakContinue());
    409     // Visit the body region first. (This is basically the same as a while
    410     // loop; see further comments in VisitWhileStmt.)
    411     Cnt.beginRegion();
    412     CountMap[S->getBody()] = PGO.getCurrentRegionCount();
    413     Visit(S->getBody());
    414     Cnt.adjustForControlFlow();
    415 
    416     // The increment is essentially part of the body but it needs to include
    417     // the count for all the continue statements.
    418     if (S->getInc()) {
    419       Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
    420                                 BreakContinueStack.back().ContinueCount);
    421       CountMap[S->getInc()] = PGO.getCurrentRegionCount();
    422       Visit(S->getInc());
    423       Cnt.adjustForControlFlow();
    424     }
    425 
    426     BreakContinue BC = BreakContinueStack.pop_back_val();
    427 
    428     // ...then go back and propagate counts through the condition.
    429     if (S->getCond()) {
    430       Cnt.setCurrentRegionCount(Cnt.getParentCount() + Cnt.getAdjustedCount() +
    431                                 BC.ContinueCount);
    432       CountMap[S->getCond()] = PGO.getCurrentRegionCount();
    433       Visit(S->getCond());
    434       Cnt.adjustForControlFlow();
    435     }
    436     Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
    437     RecordNextStmtCount = true;
    438   }
    439 
    440   void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
    441     RecordStmtCount(S);
    442     Visit(S->getRangeStmt());
    443     Visit(S->getBeginEndStmt());
    444     // Counter tracks the body of the loop.
    445     RegionCounter Cnt(PGO, S);
    446     BreakContinueStack.push_back(BreakContinue());
    447     // Visit the body region first. (This is basically the same as a while
    448     // loop; see further comments in VisitWhileStmt.)
    449     Cnt.beginRegion();
    450     CountMap[S->getLoopVarStmt()] = PGO.getCurrentRegionCount();
    451     Visit(S->getLoopVarStmt());
    452     Visit(S->getBody());
    453     Cnt.adjustForControlFlow();
    454 
    455     // The increment is essentially part of the body but it needs to include
    456     // the count for all the continue statements.
    457     Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
    458                               BreakContinueStack.back().ContinueCount);
    459     CountMap[S->getInc()] = PGO.getCurrentRegionCount();
    460     Visit(S->getInc());
    461     Cnt.adjustForControlFlow();
    462 
    463     BreakContinue BC = BreakContinueStack.pop_back_val();
    464 
    465     // ...then go back and propagate counts through the condition.
    466     Cnt.setCurrentRegionCount(Cnt.getParentCount() + Cnt.getAdjustedCount() +
    467                               BC.ContinueCount);
    468     CountMap[S->getCond()] = PGO.getCurrentRegionCount();
    469     Visit(S->getCond());
    470     Cnt.adjustForControlFlow();
    471     Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
    472     RecordNextStmtCount = true;
    473   }
    474 
    475   void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
    476     RecordStmtCount(S);
    477     Visit(S->getElement());
    478     // Counter tracks the body of the loop.
    479     RegionCounter Cnt(PGO, S);
    480     BreakContinueStack.push_back(BreakContinue());
    481     Cnt.beginRegion();
    482     CountMap[S->getBody()] = PGO.getCurrentRegionCount();
    483     Visit(S->getBody());
    484     BreakContinue BC = BreakContinueStack.pop_back_val();
    485     Cnt.adjustForControlFlow();
    486     Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
    487     RecordNextStmtCount = true;
    488   }
    489 
    490   void VisitSwitchStmt(const SwitchStmt *S) {
    491     RecordStmtCount(S);
    492     Visit(S->getCond());
    493     PGO.setCurrentRegionUnreachable();
    494     BreakContinueStack.push_back(BreakContinue());
    495     Visit(S->getBody());
    496     // If the switch is inside a loop, add the continue counts.
    497     BreakContinue BC = BreakContinueStack.pop_back_val();
    498     if (!BreakContinueStack.empty())
    499       BreakContinueStack.back().ContinueCount += BC.ContinueCount;
    500     // Counter tracks the exit block of the switch.
    501     RegionCounter ExitCnt(PGO, S);
    502     ExitCnt.beginRegion();
    503     RecordNextStmtCount = true;
    504   }
    505 
    506   void VisitCaseStmt(const CaseStmt *S) {
    507     RecordNextStmtCount = false;
    508     // Counter for this particular case. This counts only jumps from the
    509     // switch header and does not include fallthrough from the case before
    510     // this one.
    511     RegionCounter Cnt(PGO, S);
    512     Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
    513     CountMap[S] = Cnt.getCount();
    514     RecordNextStmtCount = true;
    515     Visit(S->getSubStmt());
    516   }
    517 
    518   void VisitDefaultStmt(const DefaultStmt *S) {
    519     RecordNextStmtCount = false;
    520     // Counter for this default case. This does not include fallthrough from
    521     // the previous case.
    522     RegionCounter Cnt(PGO, S);
    523     Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
    524     CountMap[S] = Cnt.getCount();
    525     RecordNextStmtCount = true;
    526     Visit(S->getSubStmt());
    527   }
    528 
    529   void VisitIfStmt(const IfStmt *S) {
    530     RecordStmtCount(S);
    531     // Counter tracks the "then" part of an if statement. The count for
    532     // the "else" part, if it exists, will be calculated from this counter.
    533     RegionCounter Cnt(PGO, S);
    534     Visit(S->getCond());
    535 
    536     Cnt.beginRegion();
    537     CountMap[S->getThen()] = PGO.getCurrentRegionCount();
    538     Visit(S->getThen());
    539     Cnt.adjustForControlFlow();
    540 
    541     if (S->getElse()) {
    542       Cnt.beginElseRegion();
    543       CountMap[S->getElse()] = PGO.getCurrentRegionCount();
    544       Visit(S->getElse());
    545       Cnt.adjustForControlFlow();
    546     }
    547     Cnt.applyAdjustmentsToRegion(0);
    548     RecordNextStmtCount = true;
    549   }
    550 
    551   void VisitCXXTryStmt(const CXXTryStmt *S) {
    552     RecordStmtCount(S);
    553     Visit(S->getTryBlock());
    554     for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
    555       Visit(S->getHandler(I));
    556     // Counter tracks the continuation block of the try statement.
    557     RegionCounter Cnt(PGO, S);
    558     Cnt.beginRegion();
    559     RecordNextStmtCount = true;
    560   }
    561 
    562   void VisitCXXCatchStmt(const CXXCatchStmt *S) {
    563     RecordNextStmtCount = false;
    564     // Counter tracks the catch statement's handler block.
    565     RegionCounter Cnt(PGO, S);
    566     Cnt.beginRegion();
    567     CountMap[S] = PGO.getCurrentRegionCount();
    568     Visit(S->getHandlerBlock());
    569   }
    570 
    571   void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
    572     RecordStmtCount(E);
    573     // Counter tracks the "true" part of a conditional operator. The
    574     // count in the "false" part will be calculated from this counter.
    575     RegionCounter Cnt(PGO, E);
    576     Visit(E->getCond());
    577 
    578     Cnt.beginRegion();
    579     CountMap[E->getTrueExpr()] = PGO.getCurrentRegionCount();
    580     Visit(E->getTrueExpr());
    581     Cnt.adjustForControlFlow();
    582 
    583     Cnt.beginElseRegion();
    584     CountMap[E->getFalseExpr()] = PGO.getCurrentRegionCount();
    585     Visit(E->getFalseExpr());
    586     Cnt.adjustForControlFlow();
    587 
    588     Cnt.applyAdjustmentsToRegion(0);
    589     RecordNextStmtCount = true;
    590   }
    591 
    592   void VisitBinLAnd(const BinaryOperator *E) {
    593     RecordStmtCount(E);
    594     // Counter tracks the right hand side of a logical and operator.
    595     RegionCounter Cnt(PGO, E);
    596     Visit(E->getLHS());
    597     Cnt.beginRegion();
    598     CountMap[E->getRHS()] = PGO.getCurrentRegionCount();
    599     Visit(E->getRHS());
    600     Cnt.adjustForControlFlow();
    601     Cnt.applyAdjustmentsToRegion(0);
    602     RecordNextStmtCount = true;
    603   }
    604 
    605   void VisitBinLOr(const BinaryOperator *E) {
    606     RecordStmtCount(E);
    607     // Counter tracks the right hand side of a logical or operator.
    608     RegionCounter Cnt(PGO, E);
    609     Visit(E->getLHS());
    610     Cnt.beginRegion();
    611     CountMap[E->getRHS()] = PGO.getCurrentRegionCount();
    612     Visit(E->getRHS());
    613     Cnt.adjustForControlFlow();
    614     Cnt.applyAdjustmentsToRegion(0);
    615     RecordNextStmtCount = true;
    616   }
    617 };
    618 }
    619 
    620 void PGOHash::combine(HashType Type) {
    621   // Check that we never combine 0 and only have six bits.
    622   assert(Type && "Hash is invalid: unexpected type 0");
    623   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
    624 
    625   // Pass through MD5 if enough work has built up.
    626   if (Count && Count % NumTypesPerWord == 0) {
    627     using namespace llvm::support;
    628     uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
    629     MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
    630     Working = 0;
    631   }
    632 
    633   // Accumulate the current type.
    634   ++Count;
    635   Working = Working << NumBitsPerType | Type;
    636 }
    637 
    638 uint64_t PGOHash::finalize() {
    639   // Use Working as the hash directly if we never used MD5.
    640   if (Count <= NumTypesPerWord)
    641     // No need to byte swap here, since none of the math was endian-dependent.
    642     // This number will be byte-swapped as required on endianness transitions,
    643     // so we will see the same value on the other side.
    644     return Working;
    645 
    646   // Check for remaining work in Working.
    647   if (Working)
    648     MD5.update(Working);
    649 
    650   // Finalize the MD5 and return the hash.
    651   llvm::MD5::MD5Result Result;
    652   MD5.final(Result);
    653   using namespace llvm::support;
    654   return endian::read<uint64_t, little, unaligned>(Result);
    655 }
    656 
    657 void CodeGenPGO::checkGlobalDecl(GlobalDecl GD) {
    658   // Make sure we only emit coverage mapping for one constructor/destructor.
    659   // Clang emits several functions for the constructor and the destructor of
    660   // a class. Every function is instrumented, but we only want to provide
    661   // coverage for one of them. Because of that we only emit the coverage mapping
    662   // for the base constructor/destructor.
    663   if ((isa<CXXConstructorDecl>(GD.getDecl()) &&
    664        GD.getCtorType() != Ctor_Base) ||
    665       (isa<CXXDestructorDecl>(GD.getDecl()) &&
    666        GD.getDtorType() != Dtor_Base)) {
    667     SkipCoverageMapping = true;
    668   }
    669 }
    670 
    671 void CodeGenPGO::assignRegionCounters(const Decl *D, llvm::Function *Fn) {
    672   bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate;
    673   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
    674   if (!InstrumentRegions && !PGOReader)
    675     return;
    676   if (D->isImplicit())
    677     return;
    678   CGM.ClearUnusedCoverageMapping(D);
    679   setFuncName(Fn);
    680 
    681   mapRegionCounters(D);
    682   if (CGM.getCodeGenOpts().CoverageMapping)
    683     emitCounterRegionMapping(D);
    684   if (PGOReader) {
    685     SourceManager &SM = CGM.getContext().getSourceManager();
    686     loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
    687     computeRegionCounts(D);
    688     applyFunctionAttributes(PGOReader, Fn);
    689   }
    690 }
    691 
    692 void CodeGenPGO::mapRegionCounters(const Decl *D) {
    693   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
    694   MapRegionCounters Walker(*RegionCounterMap);
    695   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
    696     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
    697   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
    698     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
    699   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
    700     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
    701   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
    702     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
    703   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
    704   NumRegionCounters = Walker.NextCounter;
    705   FunctionHash = Walker.Hash.finalize();
    706 }
    707 
    708 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
    709   if (SkipCoverageMapping)
    710     return;
    711   // Don't map the functions inside the system headers
    712   auto Loc = D->getBody()->getLocStart();
    713   if (CGM.getContext().getSourceManager().isInSystemHeader(Loc))
    714     return;
    715 
    716   std::string CoverageMapping;
    717   llvm::raw_string_ostream OS(CoverageMapping);
    718   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
    719                                 CGM.getContext().getSourceManager(),
    720                                 CGM.getLangOpts(), RegionCounterMap.get());
    721   MappingGen.emitCounterMapping(D, OS);
    722   OS.flush();
    723 
    724   if (CoverageMapping.empty())
    725     return;
    726 
    727   CGM.getCoverageMapping()->addFunctionMappingRecord(
    728       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
    729 }
    730 
    731 void
    732 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef FuncName,
    733                                     llvm::GlobalValue::LinkageTypes Linkage) {
    734   if (SkipCoverageMapping)
    735     return;
    736   // Don't map the functions inside the system headers
    737   auto Loc = D->getBody()->getLocStart();
    738   if (CGM.getContext().getSourceManager().isInSystemHeader(Loc))
    739     return;
    740 
    741   std::string CoverageMapping;
    742   llvm::raw_string_ostream OS(CoverageMapping);
    743   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
    744                                 CGM.getContext().getSourceManager(),
    745                                 CGM.getLangOpts());
    746   MappingGen.emitEmptyMapping(D, OS);
    747   OS.flush();
    748 
    749   if (CoverageMapping.empty())
    750     return;
    751 
    752   setFuncName(FuncName, Linkage);
    753   CGM.getCoverageMapping()->addFunctionMappingRecord(
    754       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
    755 }
    756 
    757 void CodeGenPGO::computeRegionCounts(const Decl *D) {
    758   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
    759   ComputeRegionCounts Walker(*StmtCountMap, *this);
    760   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
    761     Walker.VisitFunctionDecl(FD);
    762   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
    763     Walker.VisitObjCMethodDecl(MD);
    764   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
    765     Walker.VisitBlockDecl(BD);
    766   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
    767     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
    768 }
    769 
    770 void
    771 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
    772                                     llvm::Function *Fn) {
    773   if (!haveRegionCounts())
    774     return;
    775 
    776   uint64_t MaxFunctionCount = PGOReader->getMaximumFunctionCount();
    777   uint64_t FunctionCount = getRegionCount(0);
    778   if (FunctionCount >= (uint64_t)(0.3 * (double)MaxFunctionCount))
    779     // Turn on InlineHint attribute for hot functions.
    780     // FIXME: 30% is from preliminary tuning on SPEC, it may not be optimal.
    781     Fn->addFnAttr(llvm::Attribute::InlineHint);
    782   else if (FunctionCount <= (uint64_t)(0.01 * (double)MaxFunctionCount))
    783     // Turn on Cold attribute for cold functions.
    784     // FIXME: 1% is from preliminary tuning on SPEC, it may not be optimal.
    785     Fn->addFnAttr(llvm::Attribute::Cold);
    786 }
    787 
    788 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, unsigned Counter) {
    789   if (!CGM.getCodeGenOpts().ProfileInstrGenerate || !RegionCounterMap)
    790     return;
    791   if (!Builder.GetInsertPoint())
    792     return;
    793   auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
    794   Builder.CreateCall4(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
    795                       llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
    796                       Builder.getInt64(FunctionHash),
    797                       Builder.getInt32(NumRegionCounters),
    798                       Builder.getInt32(Counter));
    799 }
    800 
    801 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
    802                                   bool IsInMainFile) {
    803   CGM.getPGOStats().addVisited(IsInMainFile);
    804   RegionCounts.clear();
    805   if (std::error_code EC =
    806           PGOReader->getFunctionCounts(FuncName, FunctionHash, RegionCounts)) {
    807     if (EC == llvm::instrprof_error::unknown_function)
    808       CGM.getPGOStats().addMissing(IsInMainFile);
    809     else if (EC == llvm::instrprof_error::hash_mismatch)
    810       CGM.getPGOStats().addMismatched(IsInMainFile);
    811     else if (EC == llvm::instrprof_error::malformed)
    812       // TODO: Consider a more specific warning for this case.
    813       CGM.getPGOStats().addMismatched(IsInMainFile);
    814     RegionCounts.clear();
    815   }
    816 }
    817 
    /// \brief Compute the divisor used to scale 64-bit weights down to 32 bits.
    ///
    /// Returns 1 when \p MaxWeight is already below UINT32_MAX. Otherwise returns
    /// \c MaxWeight / UINT32_MAX + 1, so that dividing any weight no larger than
    /// \p MaxWeight by the result yields a value strictly less than UINT32_MAX.
    static uint64_t calculateWeightScale(uint64_t MaxWeight) {
      if (MaxWeight < UINT32_MAX)
        return 1;
      return MaxWeight / UINT32_MAX + 1;
    }
    825 
    /// \brief Scale one branch weight into 32 bits (and add 1).
    ///
    /// Divides the 64-bit \p Weight by \p Scale and adds 1 to the quotient. The
    /// increment follows Laplace's Rule of Succession: weights are better
    /// computed from the count plus 1, so it is applied to every value.
    ///
    /// \pre \p Scale is non-zero and was obtained from \a calculateWeightScale()
    /// for a maximum weight at least as large as \p Weight, so that the result
    /// fits in 32 bits.
    static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
      assert(Scale && "scale by 0?");
      const uint64_t Quotient = Weight / Scale;
      const uint64_t Result = Quotient + 1;
      assert(Result <= UINT32_MAX && "overflow 32-bits");
      return static_cast<uint32_t>(Result);
    }
    841 
    842 llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount,
    843                                               uint64_t FalseCount) {
    844   // Check for empty weights.
    845   if (!TrueCount && !FalseCount)
    846     return nullptr;
    847 
    848   // Calculate how to scale down to 32-bits.
    849   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
    850 
    851   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
    852   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
    853                                       scaleBranchWeight(FalseCount, Scale));
    854 }
    855 
    856 llvm::MDNode *CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) {
    857   // We need at least two elements to create meaningful weights.
    858   if (Weights.size() < 2)
    859     return nullptr;
    860 
    861   // Check for empty weights.
    862   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
    863   if (MaxWeight == 0)
    864     return nullptr;
    865 
    866   // Calculate how to scale down to 32-bits.
    867   uint64_t Scale = calculateWeightScale(MaxWeight);
    868 
    869   SmallVector<uint32_t, 16> ScaledWeights;
    870   ScaledWeights.reserve(Weights.size());
    871   for (uint64_t W : Weights)
    872     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
    873 
    874   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
    875   return MDHelper.createBranchWeights(ScaledWeights);
    876 }
    877 
    878 llvm::MDNode *CodeGenPGO::createLoopWeights(const Stmt *Cond,
    879                                             RegionCounter &Cnt) {
    880   if (!haveRegionCounts())
    881     return nullptr;
    882   uint64_t LoopCount = Cnt.getCount();
    883   Optional<uint64_t> CondCount = getStmtCount(Cond);
    884   assert(CondCount.hasValue() && "missing expected loop condition count");
    885   if (*CondCount == 0)
    886     return nullptr;
    887   return createBranchWeights(LoopCount,
    888                              std::max(*CondCount, LoopCount) - LoopCount);
    889 }
    890