Home | History | Annotate | Download | only in CodeGen
      1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 // Instrumentation-based profile-guided optimization
     11 //
     12 //===----------------------------------------------------------------------===//
     13 
     14 #include "CodeGenPGO.h"
     15 #include "CodeGenFunction.h"
     16 #include "clang/AST/RecursiveASTVisitor.h"
     17 #include "clang/AST/StmtVisitor.h"
     18 #include "llvm/IR/MDBuilder.h"
     19 #include "llvm/ProfileData/InstrProfReader.h"
     20 #include "llvm/Support/Endian.h"
     21 #include "llvm/Support/FileSystem.h"
     22 #include "llvm/Support/MD5.h"
     23 
     24 using namespace clang;
     25 using namespace CodeGen;
     26 
     27 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
     28   RawFuncName = Fn->getName();
     29 
     30   // Function names may be prefixed with a binary '1' to indicate
     31   // that the backend should not modify the symbols due to any platform
     32   // naming convention. Do not include that '1' in the PGO profile name.
     33   if (RawFuncName[0] == '\1')
     34     RawFuncName = RawFuncName.substr(1);
     35 
     36   if (!Fn->hasLocalLinkage()) {
     37     PrefixedFuncName.reset(new std::string(RawFuncName));
     38     return;
     39   }
     40 
     41   // For local symbols, prepend the main file name to distinguish them.
     42   // Do not include the full path in the file name since there's no guarantee
     43   // that it will stay the same, e.g., if the files are checked out from
     44   // version control in different locations.
     45   PrefixedFuncName.reset(new std::string(CGM.getCodeGenOpts().MainFileName));
     46   if (PrefixedFuncName->empty())
     47     PrefixedFuncName->assign("<unknown>");
     48   PrefixedFuncName->append(":");
     49   PrefixedFuncName->append(RawFuncName);
     50 }
     51 
     52 static llvm::Function *getRegisterFunc(CodeGenModule &CGM) {
     53   return CGM.getModule().getFunction("__llvm_profile_register_functions");
     54 }
     55 
     56 static llvm::BasicBlock *getOrInsertRegisterBB(CodeGenModule &CGM) {
     57   // Don't do this for Darwin.  compiler-rt uses linker magic.
     58   if (CGM.getTarget().getTriple().isOSDarwin())
     59     return nullptr;
     60 
     61   // Only need to insert this once per module.
     62   if (llvm::Function *RegisterF = getRegisterFunc(CGM))
     63     return &RegisterF->getEntryBlock();
     64 
     65   // Construct the function.
     66   auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext());
     67   auto *RegisterFTy = llvm::FunctionType::get(VoidTy, false);
     68   auto *RegisterF = llvm::Function::Create(RegisterFTy,
     69                                            llvm::GlobalValue::InternalLinkage,
     70                                            "__llvm_profile_register_functions",
     71                                            &CGM.getModule());
     72   RegisterF->setUnnamedAddr(true);
     73   if (CGM.getCodeGenOpts().DisableRedZone)
     74     RegisterF->addFnAttr(llvm::Attribute::NoRedZone);
     75 
     76   // Construct and return the entry block.
     77   auto *BB = llvm::BasicBlock::Create(CGM.getLLVMContext(), "", RegisterF);
     78   CGBuilderTy Builder(BB);
     79   Builder.CreateRetVoid();
     80   return BB;
     81 }
     82 
     83 static llvm::Constant *getOrInsertRuntimeRegister(CodeGenModule &CGM) {
     84   auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext());
     85   auto *VoidPtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
     86   auto *RuntimeRegisterTy = llvm::FunctionType::get(VoidTy, VoidPtrTy, false);
     87   return CGM.getModule().getOrInsertFunction("__llvm_profile_register_function",
     88                                              RuntimeRegisterTy);
     89 }
     90 
     91 static bool isMachO(const CodeGenModule &CGM) {
     92   return CGM.getTarget().getTriple().isOSBinFormatMachO();
     93 }
     94 
     95 static StringRef getCountersSection(const CodeGenModule &CGM) {
     96   return isMachO(CGM) ? "__DATA,__llvm_prf_cnts" : "__llvm_prf_cnts";
     97 }
     98 
     99 static StringRef getNameSection(const CodeGenModule &CGM) {
    100   return isMachO(CGM) ? "__DATA,__llvm_prf_names" : "__llvm_prf_names";
    101 }
    102 
    103 static StringRef getDataSection(const CodeGenModule &CGM) {
    104   return isMachO(CGM) ? "__DATA,__llvm_prf_data" : "__llvm_prf_data";
    105 }
    106 
    107 llvm::GlobalVariable *CodeGenPGO::buildDataVar() {
    108   // Create name variable.
    109   llvm::LLVMContext &Ctx = CGM.getLLVMContext();
    110   auto *VarName = llvm::ConstantDataArray::getString(Ctx, getFuncName(),
    111                                                      false);
    112   auto *Name = new llvm::GlobalVariable(CGM.getModule(), VarName->getType(),
    113                                         true, VarLinkage, VarName,
    114                                         getFuncVarName("name"));
    115   Name->setSection(getNameSection(CGM));
    116   Name->setAlignment(1);
    117 
    118   // Create data variable.
    119   auto *Int32Ty = llvm::Type::getInt32Ty(Ctx);
    120   auto *Int64Ty = llvm::Type::getInt64Ty(Ctx);
    121   auto *Int8PtrTy = llvm::Type::getInt8PtrTy(Ctx);
    122   auto *Int64PtrTy = llvm::Type::getInt64PtrTy(Ctx);
    123   llvm::Type *DataTypes[] = {
    124     Int32Ty, Int32Ty, Int64Ty, Int8PtrTy, Int64PtrTy
    125   };
    126   auto *DataTy = llvm::StructType::get(Ctx, makeArrayRef(DataTypes));
    127   llvm::Constant *DataVals[] = {
    128     llvm::ConstantInt::get(Int32Ty, getFuncName().size()),
    129     llvm::ConstantInt::get(Int32Ty, NumRegionCounters),
    130     llvm::ConstantInt::get(Int64Ty, FunctionHash),
    131     llvm::ConstantExpr::getBitCast(Name, Int8PtrTy),
    132     llvm::ConstantExpr::getBitCast(RegionCounters, Int64PtrTy)
    133   };
    134   auto *Data =
    135     new llvm::GlobalVariable(CGM.getModule(), DataTy, true, VarLinkage,
    136                              llvm::ConstantStruct::get(DataTy, DataVals),
    137                              getFuncVarName("data"));
    138 
    139   // All the data should be packed into an array in its own section.
    140   Data->setSection(getDataSection(CGM));
    141   Data->setAlignment(8);
    142 
    143   // Hide all these symbols so that we correctly get a copy for each
    144   // executable.  The profile format expects names and counters to be
    145   // contiguous, so references into shared objects would be invalid.
    146   if (!llvm::GlobalValue::isLocalLinkage(VarLinkage)) {
    147     Name->setVisibility(llvm::GlobalValue::HiddenVisibility);
    148     Data->setVisibility(llvm::GlobalValue::HiddenVisibility);
    149     RegionCounters->setVisibility(llvm::GlobalValue::HiddenVisibility);
    150   }
    151 
    152   // Make sure the data doesn't get deleted.
    153   CGM.addUsedGlobal(Data);
    154   return Data;
    155 }
    156 
    157 void CodeGenPGO::emitInstrumentationData() {
    158   if (!RegionCounters)
    159     return;
    160 
    161   // Build the data.
    162   auto *Data = buildDataVar();
    163 
    164   // Register the data.
    165   auto *RegisterBB = getOrInsertRegisterBB(CGM);
    166   if (!RegisterBB)
    167     return;
    168   CGBuilderTy Builder(RegisterBB->getTerminator());
    169   auto *VoidPtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
    170   Builder.CreateCall(getOrInsertRuntimeRegister(CGM),
    171                      Builder.CreateBitCast(Data, VoidPtrTy));
    172 }
    173 
    174 llvm::Function *CodeGenPGO::emitInitialization(CodeGenModule &CGM) {
    175   if (!CGM.getCodeGenOpts().ProfileInstrGenerate)
    176     return nullptr;
    177 
    178   assert(CGM.getModule().getFunction("__llvm_profile_init") == nullptr &&
    179          "profile initialization already emitted");
    180 
    181   // Get the function to call at initialization.
    182   llvm::Constant *RegisterF = getRegisterFunc(CGM);
    183   if (!RegisterF)
    184     return nullptr;
    185 
    186   // Create the initialization function.
    187   auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext());
    188   auto *F = llvm::Function::Create(llvm::FunctionType::get(VoidTy, false),
    189                                    llvm::GlobalValue::InternalLinkage,
    190                                    "__llvm_profile_init", &CGM.getModule());
    191   F->setUnnamedAddr(true);
    192   F->addFnAttr(llvm::Attribute::NoInline);
    193   if (CGM.getCodeGenOpts().DisableRedZone)
    194     F->addFnAttr(llvm::Attribute::NoRedZone);
    195 
    196   // Add the basic block and the necessary calls.
    197   CGBuilderTy Builder(llvm::BasicBlock::Create(CGM.getLLVMContext(), "", F));
    198   Builder.CreateCall(RegisterF);
    199   Builder.CreateRetVoid();
    200 
    201   return F;
    202 }
    203 
    204 namespace {
    205 /// \brief Stable hasher for PGO region counters.
    206 ///
    207 /// PGOHash produces a stable hash of a given function's control flow.
    208 ///
    209 /// Changing the output of this hash will invalidate all previously generated
    210 /// profiles -- i.e., don't do it.
    211 ///
    212 /// \note  When this hash does eventually change (years?), we still need to
    213 /// support old hashes.  We'll need to pull in the version number from the
    214 /// profile data format and use the matching hash function.
    215 class PGOHash {
    216   uint64_t Working;
    217   unsigned Count;
    218   llvm::MD5 MD5;
    219 
    220   static const int NumBitsPerType = 6;
    221   static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
    222   static const unsigned TooBig = 1u << NumBitsPerType;
    223 
    224 public:
    225   /// \brief Hash values for AST nodes.
    226   ///
    227   /// Distinct values for AST nodes that have region counters attached.
    228   ///
    229   /// These values must be stable.  All new members must be added at the end,
    230   /// and no members should be removed.  Changing the enumeration value for an
    231   /// AST node will affect the hash of every function that contains that node.
    232   enum HashType : unsigned char {
    233     None = 0,
    234     LabelStmt = 1,
    235     WhileStmt,
    236     DoStmt,
    237     ForStmt,
    238     CXXForRangeStmt,
    239     ObjCForCollectionStmt,
    240     SwitchStmt,
    241     CaseStmt,
    242     DefaultStmt,
    243     IfStmt,
    244     CXXTryStmt,
    245     CXXCatchStmt,
    246     ConditionalOperator,
    247     BinaryOperatorLAnd,
    248     BinaryOperatorLOr,
    249     BinaryConditionalOperator,
    250 
    251     // Keep this last.  It's for the static assert that follows.
    252     LastHashType
    253   };
    254   static_assert(LastHashType <= TooBig, "Too many types in HashType");
    255 
    256   // TODO: When this format changes, take in a version number here, and use the
    257   // old hash calculation for file formats that used the old hash.
    258   PGOHash() : Working(0), Count(0) {}
    259   void combine(HashType Type);
    260   uint64_t finalize();
    261 };
    262 const int PGOHash::NumBitsPerType;
    263 const unsigned PGOHash::NumTypesPerWord;
    264 const unsigned PGOHash::TooBig;
    265 
    266   /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
    267   struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
    268     /// The next counter value to assign.
    269     unsigned NextCounter;
    270     /// The function hash.
    271     PGOHash Hash;
    272     /// The map of statements to counters.
    273     llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
    274 
    275     MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
    276         : NextCounter(0), CounterMap(CounterMap) {}
    277 
    278     // Blocks and lambdas are handled as separate functions, so we need not
    279     // traverse them in the parent context.
    280     bool TraverseBlockExpr(BlockExpr *BE) { return true; }
    281     bool TraverseLambdaBody(LambdaExpr *LE) { return true; }
    282     bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
    283 
    284     bool VisitDecl(const Decl *D) {
    285       switch (D->getKind()) {
    286       default:
    287         break;
    288       case Decl::Function:
    289       case Decl::CXXMethod:
    290       case Decl::CXXConstructor:
    291       case Decl::CXXDestructor:
    292       case Decl::CXXConversion:
    293       case Decl::ObjCMethod:
    294       case Decl::Block:
    295       case Decl::Captured:
    296         CounterMap[D->getBody()] = NextCounter++;
    297         break;
    298       }
    299       return true;
    300     }
    301 
    302     bool VisitStmt(const Stmt *S) {
    303       auto Type = getHashType(S);
    304       if (Type == PGOHash::None)
    305         return true;
    306 
    307       CounterMap[S] = NextCounter++;
    308       Hash.combine(Type);
    309       return true;
    310     }
    311     PGOHash::HashType getHashType(const Stmt *S) {
    312       switch (S->getStmtClass()) {
    313       default:
    314         break;
    315       case Stmt::LabelStmtClass:
    316         return PGOHash::LabelStmt;
    317       case Stmt::WhileStmtClass:
    318         return PGOHash::WhileStmt;
    319       case Stmt::DoStmtClass:
    320         return PGOHash::DoStmt;
    321       case Stmt::ForStmtClass:
    322         return PGOHash::ForStmt;
    323       case Stmt::CXXForRangeStmtClass:
    324         return PGOHash::CXXForRangeStmt;
    325       case Stmt::ObjCForCollectionStmtClass:
    326         return PGOHash::ObjCForCollectionStmt;
    327       case Stmt::SwitchStmtClass:
    328         return PGOHash::SwitchStmt;
    329       case Stmt::CaseStmtClass:
    330         return PGOHash::CaseStmt;
    331       case Stmt::DefaultStmtClass:
    332         return PGOHash::DefaultStmt;
    333       case Stmt::IfStmtClass:
    334         return PGOHash::IfStmt;
    335       case Stmt::CXXTryStmtClass:
    336         return PGOHash::CXXTryStmt;
    337       case Stmt::CXXCatchStmtClass:
    338         return PGOHash::CXXCatchStmt;
    339       case Stmt::ConditionalOperatorClass:
    340         return PGOHash::ConditionalOperator;
    341       case Stmt::BinaryConditionalOperatorClass:
    342         return PGOHash::BinaryConditionalOperator;
    343       case Stmt::BinaryOperatorClass: {
    344         const BinaryOperator *BO = cast<BinaryOperator>(S);
    345         if (BO->getOpcode() == BO_LAnd)
    346           return PGOHash::BinaryOperatorLAnd;
    347         if (BO->getOpcode() == BO_LOr)
    348           return PGOHash::BinaryOperatorLOr;
    349         break;
    350       }
    351       }
    352       return PGOHash::None;
    353     }
    354   };
    355 
    356   /// A StmtVisitor that propagates the raw counts through the AST and
    357   /// records the count at statements where the value may change.
    358   struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
    359     /// PGO state.
    360     CodeGenPGO &PGO;
    361 
    362     /// A flag that is set when the current count should be recorded on the
    363     /// next statement, such as at the exit of a loop.
    364     bool RecordNextStmtCount;
    365 
    366     /// The map of statements to count values.
    367     llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
    368 
    369     /// BreakContinueStack - Keep counts of breaks and continues inside loops.
    370     struct BreakContinue {
    371       uint64_t BreakCount;
    372       uint64_t ContinueCount;
    373       BreakContinue() : BreakCount(0), ContinueCount(0) {}
    374     };
    375     SmallVector<BreakContinue, 8> BreakContinueStack;
    376 
    377     ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
    378                         CodeGenPGO &PGO)
    379         : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
    380 
    381     void RecordStmtCount(const Stmt *S) {
    382       if (RecordNextStmtCount) {
    383         CountMap[S] = PGO.getCurrentRegionCount();
    384         RecordNextStmtCount = false;
    385       }
    386     }
    387 
    388     void VisitStmt(const Stmt *S) {
    389       RecordStmtCount(S);
    390       for (Stmt::const_child_range I = S->children(); I; ++I) {
    391         if (*I)
    392          this->Visit(*I);
    393       }
    394     }
    395 
    396     void VisitFunctionDecl(const FunctionDecl *D) {
    397       // Counter tracks entry to the function body.
    398       RegionCounter Cnt(PGO, D->getBody());
    399       Cnt.beginRegion();
    400       CountMap[D->getBody()] = PGO.getCurrentRegionCount();
    401       Visit(D->getBody());
    402     }
    403 
    404     // Skip lambda expressions. We visit these as FunctionDecls when we're
    405     // generating them and aren't interested in the body when generating a
    406     // parent context.
    407     void VisitLambdaExpr(const LambdaExpr *LE) {}
    408 
    409     void VisitCapturedDecl(const CapturedDecl *D) {
    410       // Counter tracks entry to the capture body.
    411       RegionCounter Cnt(PGO, D->getBody());
    412       Cnt.beginRegion();
    413       CountMap[D->getBody()] = PGO.getCurrentRegionCount();
    414       Visit(D->getBody());
    415     }
    416 
    417     void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
    418       // Counter tracks entry to the method body.
    419       RegionCounter Cnt(PGO, D->getBody());
    420       Cnt.beginRegion();
    421       CountMap[D->getBody()] = PGO.getCurrentRegionCount();
    422       Visit(D->getBody());
    423     }
    424 
    425     void VisitBlockDecl(const BlockDecl *D) {
    426       // Counter tracks entry to the block body.
    427       RegionCounter Cnt(PGO, D->getBody());
    428       Cnt.beginRegion();
    429       CountMap[D->getBody()] = PGO.getCurrentRegionCount();
    430       Visit(D->getBody());
    431     }
    432 
    433     void VisitReturnStmt(const ReturnStmt *S) {
    434       RecordStmtCount(S);
    435       if (S->getRetValue())
    436         Visit(S->getRetValue());
    437       PGO.setCurrentRegionUnreachable();
    438       RecordNextStmtCount = true;
    439     }
    440 
    441     void VisitGotoStmt(const GotoStmt *S) {
    442       RecordStmtCount(S);
    443       PGO.setCurrentRegionUnreachable();
    444       RecordNextStmtCount = true;
    445     }
    446 
    447     void VisitLabelStmt(const LabelStmt *S) {
    448       RecordNextStmtCount = false;
    449       // Counter tracks the block following the label.
    450       RegionCounter Cnt(PGO, S);
    451       Cnt.beginRegion();
    452       CountMap[S] = PGO.getCurrentRegionCount();
    453       Visit(S->getSubStmt());
    454     }
    455 
    456     void VisitBreakStmt(const BreakStmt *S) {
    457       RecordStmtCount(S);
    458       assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
    459       BreakContinueStack.back().BreakCount += PGO.getCurrentRegionCount();
    460       PGO.setCurrentRegionUnreachable();
    461       RecordNextStmtCount = true;
    462     }
    463 
    464     void VisitContinueStmt(const ContinueStmt *S) {
    465       RecordStmtCount(S);
    466       assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
    467       BreakContinueStack.back().ContinueCount += PGO.getCurrentRegionCount();
    468       PGO.setCurrentRegionUnreachable();
    469       RecordNextStmtCount = true;
    470     }
    471 
    472     void VisitWhileStmt(const WhileStmt *S) {
    473       RecordStmtCount(S);
    474       // Counter tracks the body of the loop.
    475       RegionCounter Cnt(PGO, S);
    476       BreakContinueStack.push_back(BreakContinue());
    477       // Visit the body region first so the break/continue adjustments can be
    478       // included when visiting the condition.
    479       Cnt.beginRegion();
    480       CountMap[S->getBody()] = PGO.getCurrentRegionCount();
    481       Visit(S->getBody());
    482       Cnt.adjustForControlFlow();
    483 
    484       // ...then go back and propagate counts through the condition. The count
    485       // at the start of the condition is the sum of the incoming edges,
    486       // the backedge from the end of the loop body, and the edges from
    487       // continue statements.
    488       BreakContinue BC = BreakContinueStack.pop_back_val();
    489       Cnt.setCurrentRegionCount(Cnt.getParentCount() +
    490                                 Cnt.getAdjustedCount() + BC.ContinueCount);
    491       CountMap[S->getCond()] = PGO.getCurrentRegionCount();
    492       Visit(S->getCond());
    493       Cnt.adjustForControlFlow();
    494       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
    495       RecordNextStmtCount = true;
    496     }
    497 
    498     void VisitDoStmt(const DoStmt *S) {
    499       RecordStmtCount(S);
    500       // Counter tracks the body of the loop.
    501       RegionCounter Cnt(PGO, S);
    502       BreakContinueStack.push_back(BreakContinue());
    503       Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
    504       CountMap[S->getBody()] = PGO.getCurrentRegionCount();
    505       Visit(S->getBody());
    506       Cnt.adjustForControlFlow();
    507 
    508       BreakContinue BC = BreakContinueStack.pop_back_val();
    509       // The count at the start of the condition is equal to the count at the
    510       // end of the body. The adjusted count does not include either the
    511       // fall-through count coming into the loop or the continue count, so add
    512       // both of those separately. This is coincidentally the same equation as
    513       // with while loops but for different reasons.
    514       Cnt.setCurrentRegionCount(Cnt.getParentCount() +
    515                                 Cnt.getAdjustedCount() + BC.ContinueCount);
    516       CountMap[S->getCond()] = PGO.getCurrentRegionCount();
    517       Visit(S->getCond());
    518       Cnt.adjustForControlFlow();
    519       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
    520       RecordNextStmtCount = true;
    521     }
    522 
    523     void VisitForStmt(const ForStmt *S) {
    524       RecordStmtCount(S);
    525       if (S->getInit())
    526         Visit(S->getInit());
    527       // Counter tracks the body of the loop.
    528       RegionCounter Cnt(PGO, S);
    529       BreakContinueStack.push_back(BreakContinue());
    530       // Visit the body region first. (This is basically the same as a while
    531       // loop; see further comments in VisitWhileStmt.)
    532       Cnt.beginRegion();
    533       CountMap[S->getBody()] = PGO.getCurrentRegionCount();
    534       Visit(S->getBody());
    535       Cnt.adjustForControlFlow();
    536 
    537       // The increment is essentially part of the body but it needs to include
    538       // the count for all the continue statements.
    539       if (S->getInc()) {
    540         Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
    541                                   BreakContinueStack.back().ContinueCount);
    542         CountMap[S->getInc()] = PGO.getCurrentRegionCount();
    543         Visit(S->getInc());
    544         Cnt.adjustForControlFlow();
    545       }
    546 
    547       BreakContinue BC = BreakContinueStack.pop_back_val();
    548 
    549       // ...then go back and propagate counts through the condition.
    550       if (S->getCond()) {
    551         Cnt.setCurrentRegionCount(Cnt.getParentCount() +
    552                                   Cnt.getAdjustedCount() +
    553                                   BC.ContinueCount);
    554         CountMap[S->getCond()] = PGO.getCurrentRegionCount();
    555         Visit(S->getCond());
    556         Cnt.adjustForControlFlow();
    557       }
    558       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
    559       RecordNextStmtCount = true;
    560     }
    561 
    562     void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
    563       RecordStmtCount(S);
    564       Visit(S->getRangeStmt());
    565       Visit(S->getBeginEndStmt());
    566       // Counter tracks the body of the loop.
    567       RegionCounter Cnt(PGO, S);
    568       BreakContinueStack.push_back(BreakContinue());
    569       // Visit the body region first. (This is basically the same as a while
    570       // loop; see further comments in VisitWhileStmt.)
    571       Cnt.beginRegion();
    572       CountMap[S->getLoopVarStmt()] = PGO.getCurrentRegionCount();
    573       Visit(S->getLoopVarStmt());
    574       Visit(S->getBody());
    575       Cnt.adjustForControlFlow();
    576 
    577       // The increment is essentially part of the body but it needs to include
    578       // the count for all the continue statements.
    579       Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
    580                                 BreakContinueStack.back().ContinueCount);
    581       CountMap[S->getInc()] = PGO.getCurrentRegionCount();
    582       Visit(S->getInc());
    583       Cnt.adjustForControlFlow();
    584 
    585       BreakContinue BC = BreakContinueStack.pop_back_val();
    586 
    587       // ...then go back and propagate counts through the condition.
    588       Cnt.setCurrentRegionCount(Cnt.getParentCount() +
    589                                 Cnt.getAdjustedCount() +
    590                                 BC.ContinueCount);
    591       CountMap[S->getCond()] = PGO.getCurrentRegionCount();
    592       Visit(S->getCond());
    593       Cnt.adjustForControlFlow();
    594       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
    595       RecordNextStmtCount = true;
    596     }
    597 
    598     void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
    599       RecordStmtCount(S);
    600       Visit(S->getElement());
    601       // Counter tracks the body of the loop.
    602       RegionCounter Cnt(PGO, S);
    603       BreakContinueStack.push_back(BreakContinue());
    604       Cnt.beginRegion();
    605       CountMap[S->getBody()] = PGO.getCurrentRegionCount();
    606       Visit(S->getBody());
    607       BreakContinue BC = BreakContinueStack.pop_back_val();
    608       Cnt.adjustForControlFlow();
    609       Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
    610       RecordNextStmtCount = true;
    611     }
    612 
    613     void VisitSwitchStmt(const SwitchStmt *S) {
    614       RecordStmtCount(S);
    615       Visit(S->getCond());
    616       PGO.setCurrentRegionUnreachable();
    617       BreakContinueStack.push_back(BreakContinue());
    618       Visit(S->getBody());
    619       // If the switch is inside a loop, add the continue counts.
    620       BreakContinue BC = BreakContinueStack.pop_back_val();
    621       if (!BreakContinueStack.empty())
    622         BreakContinueStack.back().ContinueCount += BC.ContinueCount;
    623       // Counter tracks the exit block of the switch.
    624       RegionCounter ExitCnt(PGO, S);
    625       ExitCnt.beginRegion();
    626       RecordNextStmtCount = true;
    627     }
    628 
    629     void VisitCaseStmt(const CaseStmt *S) {
    630       RecordNextStmtCount = false;
    631       // Counter for this particular case. This counts only jumps from the
    632       // switch header and does not include fallthrough from the case before
    633       // this one.
    634       RegionCounter Cnt(PGO, S);
    635       Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
    636       CountMap[S] = Cnt.getCount();
    637       RecordNextStmtCount = true;
    638       Visit(S->getSubStmt());
    639     }
    640 
    641     void VisitDefaultStmt(const DefaultStmt *S) {
    642       RecordNextStmtCount = false;
    643       // Counter for this default case. This does not include fallthrough from
    644       // the previous case.
    645       RegionCounter Cnt(PGO, S);
    646       Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
    647       CountMap[S] = Cnt.getCount();
    648       RecordNextStmtCount = true;
    649       Visit(S->getSubStmt());
    650     }
    651 
    652     void VisitIfStmt(const IfStmt *S) {
    653       RecordStmtCount(S);
    654       // Counter tracks the "then" part of an if statement. The count for
    655       // the "else" part, if it exists, will be calculated from this counter.
    656       RegionCounter Cnt(PGO, S);
    657       Visit(S->getCond());
    658 
    659       Cnt.beginRegion();
    660       CountMap[S->getThen()] = PGO.getCurrentRegionCount();
    661       Visit(S->getThen());
    662       Cnt.adjustForControlFlow();
    663 
    664       if (S->getElse()) {
    665         Cnt.beginElseRegion();
    666         CountMap[S->getElse()] = PGO.getCurrentRegionCount();
    667         Visit(S->getElse());
    668         Cnt.adjustForControlFlow();
    669       }
    670       Cnt.applyAdjustmentsToRegion(0);
    671       RecordNextStmtCount = true;
    672     }
    673 
    674     void VisitCXXTryStmt(const CXXTryStmt *S) {
    675       RecordStmtCount(S);
    676       Visit(S->getTryBlock());
    677       for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
    678         Visit(S->getHandler(I));
    679       // Counter tracks the continuation block of the try statement.
    680       RegionCounter Cnt(PGO, S);
    681       Cnt.beginRegion();
    682       RecordNextStmtCount = true;
    683     }
    684 
    685     void VisitCXXCatchStmt(const CXXCatchStmt *S) {
    686       RecordNextStmtCount = false;
    687       // Counter tracks the catch statement's handler block.
    688       RegionCounter Cnt(PGO, S);
    689       Cnt.beginRegion();
    690       CountMap[S] = PGO.getCurrentRegionCount();
    691       Visit(S->getHandlerBlock());
    692     }
    693 
    694     void VisitAbstractConditionalOperator(
    695         const AbstractConditionalOperator *E) {
    696       RecordStmtCount(E);
    697       // Counter tracks the "true" part of a conditional operator. The
    698       // count in the "false" part will be calculated from this counter.
    699       RegionCounter Cnt(PGO, E);
    700       Visit(E->getCond());
    701 
    702       Cnt.beginRegion();
    703       CountMap[E->getTrueExpr()] = PGO.getCurrentRegionCount();
    704       Visit(E->getTrueExpr());
    705       Cnt.adjustForControlFlow();
    706 
    707       Cnt.beginElseRegion();
    708       CountMap[E->getFalseExpr()] = PGO.getCurrentRegionCount();
    709       Visit(E->getFalseExpr());
    710       Cnt.adjustForControlFlow();
    711 
    712       Cnt.applyAdjustmentsToRegion(0);
    713       RecordNextStmtCount = true;
    714     }
    715 
    716     void VisitBinLAnd(const BinaryOperator *E) {
    717       RecordStmtCount(E);
    718       // Counter tracks the right hand side of a logical and operator.
    719       RegionCounter Cnt(PGO, E);
    720       Visit(E->getLHS());
    721       Cnt.beginRegion();
    722       CountMap[E->getRHS()] = PGO.getCurrentRegionCount();
    723       Visit(E->getRHS());
    724       Cnt.adjustForControlFlow();
    725       Cnt.applyAdjustmentsToRegion(0);
    726       RecordNextStmtCount = true;
    727     }
    728 
    729     void VisitBinLOr(const BinaryOperator *E) {
    730       RecordStmtCount(E);
    731       // Counter tracks the right hand side of a logical or operator.
    732       RegionCounter Cnt(PGO, E);
    733       Visit(E->getLHS());
    734       Cnt.beginRegion();
    735       CountMap[E->getRHS()] = PGO.getCurrentRegionCount();
    736       Visit(E->getRHS());
    737       Cnt.adjustForControlFlow();
    738       Cnt.applyAdjustmentsToRegion(0);
    739       RecordNextStmtCount = true;
    740     }
    741   };
    742 }
    743 
    744 void PGOHash::combine(HashType Type) {
    745   // Check that we never combine 0 and only have six bits.
    746   assert(Type && "Hash is invalid: unexpected type 0");
    747   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
    748 
    749   // Pass through MD5 if enough work has built up.
    750   if (Count && Count % NumTypesPerWord == 0) {
    751     using namespace llvm::support;
    752     uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
    753     MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
    754     Working = 0;
    755   }
    756 
    757   // Accumulate the current type.
    758   ++Count;
    759   Working = Working << NumBitsPerType | Type;
    760 }
    761 
    762 uint64_t PGOHash::finalize() {
    763   // Use Working as the hash directly if we never used MD5.
    764   if (Count <= NumTypesPerWord)
    765     // No need to byte swap here, since none of the math was endian-dependent.
    766     // This number will be byte-swapped as required on endianness transitions,
    767     // so we will see the same value on the other side.
    768     return Working;
    769 
    770   // Check for remaining work in Working.
    771   if (Working)
    772     MD5.update(Working);
    773 
    774   // Finalize the MD5 and return the hash.
    775   llvm::MD5::MD5Result Result;
    776   MD5.final(Result);
    777   using namespace llvm::support;
    778   return endian::read<uint64_t, little, unaligned>(Result);
    779 }
    780 
    781 static void emitRuntimeHook(CodeGenModule &CGM) {
    782   const char *const RuntimeVarName = "__llvm_profile_runtime";
    783   const char *const RuntimeUserName = "__llvm_profile_runtime_user";
    784   if (CGM.getModule().getGlobalVariable(RuntimeVarName))
    785     return;
    786 
    787   // Declare the runtime hook.
    788   llvm::LLVMContext &Ctx = CGM.getLLVMContext();
    789   auto *Int32Ty = llvm::Type::getInt32Ty(Ctx);
    790   auto *Var = new llvm::GlobalVariable(CGM.getModule(), Int32Ty, false,
    791                                        llvm::GlobalValue::ExternalLinkage,
    792                                        nullptr, RuntimeVarName);
    793 
    794   // Make a function that uses it.
    795   auto *User = llvm::Function::Create(llvm::FunctionType::get(Int32Ty, false),
    796                                       llvm::GlobalValue::LinkOnceODRLinkage,
    797                                       RuntimeUserName, &CGM.getModule());
    798   User->addFnAttr(llvm::Attribute::NoInline);
    799   if (CGM.getCodeGenOpts().DisableRedZone)
    800     User->addFnAttr(llvm::Attribute::NoRedZone);
    801   CGBuilderTy Builder(llvm::BasicBlock::Create(CGM.getLLVMContext(), "", User));
    802   auto *Load = Builder.CreateLoad(Var);
    803   Builder.CreateRet(Load);
    804 
    805   // Create a use of the function.  Now the definition of the runtime variable
    806   // should get pulled in, along with any static initializears.
    807   CGM.addUsedGlobal(User);
    808 }
    809 
    810 void CodeGenPGO::assignRegionCounters(const Decl *D, llvm::Function *Fn) {
    811   bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate;
    812   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
    813   if (!InstrumentRegions && !PGOReader)
    814     return;
    815   if (D->isImplicit())
    816     return;
    817   setFuncName(Fn);
    818 
    819   // Set the linkage for variables based on the function linkage.  Usually, we
    820   // want to match it, but available_externally and extern_weak both have the
    821   // wrong semantics.
    822   VarLinkage = Fn->getLinkage();
    823   switch (VarLinkage) {
    824   case llvm::GlobalValue::ExternalWeakLinkage:
    825     VarLinkage = llvm::GlobalValue::LinkOnceAnyLinkage;
    826     break;
    827   case llvm::GlobalValue::AvailableExternallyLinkage:
    828     VarLinkage = llvm::GlobalValue::LinkOnceODRLinkage;
    829     break;
    830   default:
    831     break;
    832   }
    833 
    834   mapRegionCounters(D);
    835   if (InstrumentRegions) {
    836     emitRuntimeHook(CGM);
    837     emitCounterVariables();
    838   }
    839   if (PGOReader) {
    840     SourceManager &SM = CGM.getContext().getSourceManager();
    841     loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
    842     computeRegionCounts(D);
    843     applyFunctionAttributes(PGOReader, Fn);
    844   }
    845 }
    846 
    847 void CodeGenPGO::mapRegionCounters(const Decl *D) {
    848   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
    849   MapRegionCounters Walker(*RegionCounterMap);
    850   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
    851     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
    852   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
    853     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
    854   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
    855     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
    856   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
    857     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
    858   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
    859   NumRegionCounters = Walker.NextCounter;
    860   FunctionHash = Walker.Hash.finalize();
    861 }
    862 
    863 void CodeGenPGO::computeRegionCounts(const Decl *D) {
    864   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
    865   ComputeRegionCounts Walker(*StmtCountMap, *this);
    866   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
    867     Walker.VisitFunctionDecl(FD);
    868   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
    869     Walker.VisitObjCMethodDecl(MD);
    870   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
    871     Walker.VisitBlockDecl(BD);
    872   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
    873     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
    874 }
    875 
    876 void
    877 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
    878                                     llvm::Function *Fn) {
    879   if (!haveRegionCounts())
    880     return;
    881 
    882   uint64_t MaxFunctionCount = PGOReader->getMaximumFunctionCount();
    883   uint64_t FunctionCount = getRegionCount(0);
    884   if (FunctionCount >= (uint64_t)(0.3 * (double)MaxFunctionCount))
    885     // Turn on InlineHint attribute for hot functions.
    886     // FIXME: 30% is from preliminary tuning on SPEC, it may not be optimal.
    887     Fn->addFnAttr(llvm::Attribute::InlineHint);
    888   else if (FunctionCount <= (uint64_t)(0.01 * (double)MaxFunctionCount))
    889     // Turn on Cold attribute for cold functions.
    890     // FIXME: 1% is from preliminary tuning on SPEC, it may not be optimal.
    891     Fn->addFnAttr(llvm::Attribute::Cold);
    892 }
    893 
    894 void CodeGenPGO::emitCounterVariables() {
    895   llvm::LLVMContext &Ctx = CGM.getLLVMContext();
    896   llvm::ArrayType *CounterTy = llvm::ArrayType::get(llvm::Type::getInt64Ty(Ctx),
    897                                                     NumRegionCounters);
    898   RegionCounters =
    899     new llvm::GlobalVariable(CGM.getModule(), CounterTy, false, VarLinkage,
    900                              llvm::Constant::getNullValue(CounterTy),
    901                              getFuncVarName("counters"));
    902   RegionCounters->setAlignment(8);
    903   RegionCounters->setSection(getCountersSection(CGM));
    904 }
    905 
    906 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, unsigned Counter) {
    907   if (!RegionCounters)
    908     return;
    909   llvm::Value *Addr =
    910     Builder.CreateConstInBoundsGEP2_64(RegionCounters, 0, Counter);
    911   llvm::Value *Count = Builder.CreateLoad(Addr, "pgocount");
    912   Count = Builder.CreateAdd(Count, Builder.getInt64(1));
    913   Builder.CreateStore(Count, Addr);
    914 }
    915 
    916 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
    917                                   bool IsInMainFile) {
    918   CGM.getPGOStats().addVisited(IsInMainFile);
    919   RegionCounts.reset(new std::vector<uint64_t>);
    920   uint64_t Hash;
    921   if (PGOReader->getFunctionCounts(getFuncName(), Hash, *RegionCounts)) {
    922     CGM.getPGOStats().addMissing(IsInMainFile);
    923     RegionCounts.reset();
    924   } else if (Hash != FunctionHash ||
    925              RegionCounts->size() != NumRegionCounters) {
    926     CGM.getPGOStats().addMismatched(IsInMainFile);
    927     RegionCounts.reset();
    928   }
    929 }
    930 
    931 void CodeGenPGO::destroyRegionCounters() {
    932   RegionCounterMap.reset();
    933   StmtCountMap.reset();
    934   RegionCounts.reset();
    935   RegionCounters = nullptr;
    936 }
    937 
    /// \brief Calculate what to divide by to scale weights.
    ///
    /// Given the maximum weight, returns a divisor such that every weight no
    /// greater than \p MaxWeight, once divided by it, is strictly less than
    /// UINT32_MAX. Weights that already fit get a divisor of 1.
    static uint64_t calculateWeightScale(uint64_t MaxWeight) {
      if (MaxWeight < UINT32_MAX)
        return 1;
      return MaxWeight / UINT32_MAX + 1;
    }
    945 
    /// \brief Scale an individual branch weight (and add 1).
    ///
    /// Divides the 64-bit \p Weight by \p Scale so that it fits in 32 bits.
    ///
    /// According to Laplace's Rule of Succession, it is better to compute the
    /// weight based on the count plus 1, so 1 is always added to the result.
    ///
    /// \pre \p Scale is non-zero and was calculated by
    /// \a calculateWeightScale() from a maximum weight no less than
    /// \p Weight; otherwise the scaled value may not fit in 32 bits.
    static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
      assert(Scale && "scale by 0?");
      const uint64_t ScaledPlusOne = Weight / Scale + 1;
      assert(ScaledPlusOne <= UINT32_MAX && "overflow 32-bits");
      return static_cast<uint32_t>(ScaledPlusOne);
    }
    961 
    962 llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount,
    963                                               uint64_t FalseCount) {
    964   // Check for empty weights.
    965   if (!TrueCount && !FalseCount)
    966     return nullptr;
    967 
    968   // Calculate how to scale down to 32-bits.
    969   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
    970 
    971   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
    972   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
    973                                       scaleBranchWeight(FalseCount, Scale));
    974 }
    975 
    976 llvm::MDNode *CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) {
    977   // We need at least two elements to create meaningful weights.
    978   if (Weights.size() < 2)
    979     return nullptr;
    980 
    981   // Check for empty weights.
    982   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
    983   if (MaxWeight == 0)
    984     return nullptr;
    985 
    986   // Calculate how to scale down to 32-bits.
    987   uint64_t Scale = calculateWeightScale(MaxWeight);
    988 
    989   SmallVector<uint32_t, 16> ScaledWeights;
    990   ScaledWeights.reserve(Weights.size());
    991   for (uint64_t W : Weights)
    992     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
    993 
    994   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
    995   return MDHelper.createBranchWeights(ScaledWeights);
    996 }
    997 
    998 llvm::MDNode *CodeGenPGO::createLoopWeights(const Stmt *Cond,
    999                                             RegionCounter &Cnt) {
   1000   if (!haveRegionCounts())
   1001     return nullptr;
   1002   uint64_t LoopCount = Cnt.getCount();
   1003   uint64_t CondCount = 0;
   1004   bool Found = getStmtCount(Cond, CondCount);
   1005   assert(Found && "missing expected loop condition count");
   1006   (void)Found;
   1007   if (CondCount == 0)
   1008     return nullptr;
   1009   return createBranchWeights(LoopCount,
   1010                              std::max(CondCount, LoopCount) - LoopCount);
   1011 }
   1012