Home | History | Annotate | Download | only in Instrumentation
      1 //===-- PGOInstrumentation.cpp - MST-based PGO Instrumentation ------------===//
      2 //
      3 //                      The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 // This file implements PGO instrumentation using a minimum spanning tree based
     11 // on the following paper:
     12 //   [1] Donald E. Knuth, Francis R. Stevenson. Optimal measurement of points
     13 //   for program frequency counts. BIT Numerical Mathematics 1973, Volume 13,
     14 //   Issue 3, pp 313-322
// The algorithm is based on the fact that for each node (except for the entry
// and exit), the sum of incoming edge counts equals the sum of outgoing edge
// counts. The count of an edge on the spanning tree can be derived from the
// counts of the edges not on the spanning tree. Knuth proves this method
// instruments the minimum number of edges.
     20 //
     21 // The minimal spanning tree here is actually a maximum weight tree -- on-tree
     22 // edges have higher frequencies (more likely to execute). The idea is to
     23 // instrument those less frequently executed edges to reduce the runtime
     24 // overhead of instrumented binaries.
     25 //
     26 // This file contains two passes:
     27 // (1) Pass PGOInstrumentationGen which instruments the IR to generate edge
     28 // count profile, and generates the instrumentation for indirect call
     29 // profiling.
// (2) Pass PGOInstrumentationUse which reads the edge count profile and
// annotates the branch weights. It also reads the indirect call value
// profiling records and annotates the indirect call instructions.
     33 //
// To get precise counter information, these two passes need to be invoked at
// the same compilation point (so they see the same IR). For pass
// PGOInstrumentationGen, the real work is done in instrumentOneFunc(). For
// pass PGOInstrumentationUse, the real work is done in class PGOUseFunc and
// the profile is opened at module level and passed to each PGOUseFunc instance.
     39 // The shared code for PGOInstrumentationGen and PGOInstrumentationUse is put
     40 // in class FuncPGOInstrumentation.
     41 //
     42 // Class PGOEdge represents a CFG edge and some auxiliary information. Class
     43 // BBInfo contains auxiliary information for each BB. These two classes are used
// in pass PGOInstrumentationGen. Classes PGOUseEdge and UseBBInfo are the
// derived classes of PGOEdge and BBInfo, respectively. They contain extra data
// structures used in populating profile counters.
// The MST implementation is in class CFGMST (CFGMST.h).
     48 //
     49 //===----------------------------------------------------------------------===//
     50 
     51 #include "llvm/Transforms/PGOInstrumentation.h"
     52 #include "CFGMST.h"
     53 #include "llvm/ADT/STLExtras.h"
     54 #include "llvm/ADT/Statistic.h"
     55 #include "llvm/ADT/Triple.h"
     56 #include "llvm/Analysis/BlockFrequencyInfo.h"
     57 #include "llvm/Analysis/BranchProbabilityInfo.h"
     58 #include "llvm/Analysis/CFG.h"
     59 #include "llvm/Analysis/IndirectCallSiteVisitor.h"
     60 #include "llvm/IR/CallSite.h"
     61 #include "llvm/IR/DiagnosticInfo.h"
     62 #include "llvm/IR/IRBuilder.h"
     63 #include "llvm/IR/InstIterator.h"
     64 #include "llvm/IR/Instructions.h"
     65 #include "llvm/IR/IntrinsicInst.h"
     66 #include "llvm/IR/MDBuilder.h"
     67 #include "llvm/IR/Module.h"
     68 #include "llvm/Pass.h"
     69 #include "llvm/ProfileData/InstrProfReader.h"
     70 #include "llvm/ProfileData/ProfileCommon.h"
     71 #include "llvm/Support/BranchProbability.h"
     72 #include "llvm/Support/Debug.h"
     73 #include "llvm/Support/JamCRC.h"
     74 #include "llvm/Transforms/Instrumentation.h"
     75 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
     76 #include <algorithm>
     77 #include <string>
     78 #include <utility>
     79 #include <vector>
     80 
using namespace llvm;

// Category name used by the DEBUG() output and the STATISTICs in this file.
#define DEBUG_TYPE "pgo-instrumentation"

// The next three are updated in the FuncPGOInstrumentation constructor, which
// runs for both the instrumentation and the profile-use pass.
STATISTIC(NumOfPGOInstrument, "Number of edges instrumented.");
STATISTIC(NumOfPGOEdge, "Number of edges.");
STATISTIC(NumOfPGOBB, "Number of basic-blocks.");
// Updated in getInstrBB() each time a critical edge is split.
STATISTIC(NumOfPGOSplit, "Number of critical edge splits.");
// The next three are updated by PGOUseFunc::readCounters() (profile use).
STATISTIC(NumOfPGOFunc, "Number of functions having valid profile counts.");
STATISTIC(NumOfPGOMismatch, "Number of functions having mismatch profile.");
STATISTIC(NumOfPGOMissing, "Number of functions without profile.");
// Updated by instrumentOneFunc() (profile generation).
STATISTIC(NumOfPGOICall, "Number of indirect call value instrumentations.");
     93 
     94 // Command line option to specify the file to read profile from. This is
     95 // mainly used for testing.
     96 static cl::opt<std::string>
     97     PGOTestProfileFile("pgo-test-profile-file", cl::init(""), cl::Hidden,
     98                        cl::value_desc("filename"),
     99                        cl::desc("Specify the path of profile data file. This is"
    100                                 "mainly for test purpose."));
    101 
    102 // Command line option to disable value profiling. The default is false:
    103 // i.e. value profiling is enabled by default. This is for debug purpose.
    104 static cl::opt<bool> DisableValueProfiling("disable-vp", cl::init(false),
    105                                            cl::Hidden,
    106                                            cl::desc("Disable Value Profiling"));
    107 
    108 // Command line option to set the maximum number of VP annotations to write to
    109 // the metadata for a single indirect call callsite.
    110 static cl::opt<unsigned> MaxNumAnnotations(
    111     "icp-max-annotations", cl::init(3), cl::Hidden, cl::ZeroOrMore,
    112     cl::desc("Max number of annotations for a single indirect "
    113              "call callsite"));
    114 
    115 // Command line option to enable/disable the warning about missing profile
    116 // information.
    117 static cl::opt<bool> NoPGOWarnMissing("no-pgo-warn-missing", cl::init(false),
    118                                       cl::Hidden);
    119 
    120 // Command line option to enable/disable the warning about a hash mismatch in
    121 // the profile data.
    122 static cl::opt<bool> NoPGOWarnMismatch("no-pgo-warn-mismatch", cl::init(false),
    123                                        cl::Hidden);
    124 
    125 namespace {
    126 class PGOInstrumentationGenLegacyPass : public ModulePass {
    127 public:
    128   static char ID;
    129 
    130   PGOInstrumentationGenLegacyPass() : ModulePass(ID) {
    131     initializePGOInstrumentationGenLegacyPassPass(
    132         *PassRegistry::getPassRegistry());
    133   }
    134 
    135   const char *getPassName() const override {
    136     return "PGOInstrumentationGenPass";
    137   }
    138 
    139 private:
    140   bool runOnModule(Module &M) override;
    141 
    142   void getAnalysisUsage(AnalysisUsage &AU) const override {
    143     AU.addRequired<BlockFrequencyInfoWrapperPass>();
    144   }
    145 };
    146 
    147 class PGOInstrumentationUseLegacyPass : public ModulePass {
    148 public:
    149   static char ID;
    150 
    151   // Provide the profile filename as the parameter.
    152   PGOInstrumentationUseLegacyPass(std::string Filename = "")
    153       : ModulePass(ID), ProfileFileName(std::move(Filename)) {
    154     if (!PGOTestProfileFile.empty())
    155       ProfileFileName = PGOTestProfileFile;
    156     initializePGOInstrumentationUseLegacyPassPass(
    157         *PassRegistry::getPassRegistry());
    158   }
    159 
    160   const char *getPassName() const override {
    161     return "PGOInstrumentationUsePass";
    162   }
    163 
    164 private:
    165   std::string ProfileFileName;
    166 
    167   bool runOnModule(Module &M) override;
    168   void getAnalysisUsage(AnalysisUsage &AU) const override {
    169     AU.addRequired<BlockFrequencyInfoWrapperPass>();
    170   }
    171 };
    172 } // end anonymous namespace
    173 
// Pass identifier handed to ModulePass(ID) in the gen pass constructor.
char PGOInstrumentationGenLegacyPass::ID = 0;
// Register the gen pass under the name "pgo-instr-gen" together with the
// analyses it depends on (BFI and BPI).
INITIALIZE_PASS_BEGIN(PGOInstrumentationGenLegacyPass, "pgo-instr-gen",
                      "PGO instrumentation.", false, false)
INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass)
INITIALIZE_PASS_END(PGOInstrumentationGenLegacyPass, "pgo-instr-gen",
                    "PGO instrumentation.", false, false)

// Factory for the gen pass; the caller takes ownership of the returned pass.
ModulePass *llvm::createPGOInstrumentationGenLegacyPass() {
  return new PGOInstrumentationGenLegacyPass();
}

// Pass identifier handed to ModulePass(ID) in the use pass constructor.
char PGOInstrumentationUseLegacyPass::ID = 0;
// Register the use pass under the name "pgo-instr-use" together with the
// analyses it depends on (BFI and BPI).
INITIALIZE_PASS_BEGIN(PGOInstrumentationUseLegacyPass, "pgo-instr-use",
                      "Read PGO instrumentation profile.", false, false)
INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass)
INITIALIZE_PASS_END(PGOInstrumentationUseLegacyPass, "pgo-instr-use",
                    "Read PGO instrumentation profile.", false, false)

// Factory for the use pass reading the profile in \p Filename; the caller
// takes ownership of the returned pass.
ModulePass *llvm::createPGOInstrumentationUseLegacyPass(StringRef Filename) {
  return new PGOInstrumentationUseLegacyPass(Filename.str());
}
    197 
    198 namespace {
    199 /// \brief An MST based instrumentation for PGO
    200 ///
    201 /// Implements a Minimum Spanning Tree (MST) based instrumentation for PGO
    202 /// in the function level.
    203 struct PGOEdge {
    204   // This class implements the CFG edges. Note the CFG can be a multi-graph.
    205   // So there might be multiple edges with same SrcBB and DestBB.
    206   const BasicBlock *SrcBB;
    207   const BasicBlock *DestBB;
    208   uint64_t Weight;
    209   bool InMST;
    210   bool Removed;
    211   bool IsCritical;
    212   PGOEdge(const BasicBlock *Src, const BasicBlock *Dest, unsigned W = 1)
    213       : SrcBB(Src), DestBB(Dest), Weight(W), InMST(false), Removed(false),
    214         IsCritical(false) {}
    215   // Return the information string of an edge.
    216   const std::string infoString() const {
    217     return (Twine(Removed ? "-" : " ") + (InMST ? " " : "*") +
    218             (IsCritical ? "c" : " ") + "  W=" + Twine(Weight)).str();
    219   }
    220 };
    221 
    222 // This class stores the auxiliary information for each BB.
    223 struct BBInfo {
    224   BBInfo *Group;
    225   uint32_t Index;
    226   uint32_t Rank;
    227 
    228   BBInfo(unsigned IX) : Group(this), Index(IX), Rank(0) {}
    229 
    230   // Return the information string of this object.
    231   const std::string infoString() const {
    232     return (Twine("Index=") + Twine(Index)).str();
    233   }
    234 };
    235 
    236 // This class implements the CFG edges. Note the CFG can be a multi-graph.
    237 template <class Edge, class BBInfo> class FuncPGOInstrumentation {
    238 private:
    239   Function &F;
    240   void computeCFGHash();
    241 
    242 public:
    243   std::string FuncName;
    244   GlobalVariable *FuncNameVar;
    245   // CFG hash value for this function.
    246   uint64_t FunctionHash;
    247 
    248   // The Minimum Spanning Tree of function CFG.
    249   CFGMST<Edge, BBInfo> MST;
    250 
    251   // Give an edge, find the BB that will be instrumented.
    252   // Return nullptr if there is no BB to be instrumented.
    253   BasicBlock *getInstrBB(Edge *E);
    254 
    255   // Return the auxiliary BB information.
    256   BBInfo &getBBInfo(const BasicBlock *BB) const { return MST.getBBInfo(BB); }
    257 
    258   // Dump edges and BB information.
    259   void dumpInfo(std::string Str = "") const {
    260     MST.dumpEdges(dbgs(), Twine("Dump Function ") + FuncName + " Hash: " +
    261                               Twine(FunctionHash) + "\t" + Str);
    262   }
    263 
    264   FuncPGOInstrumentation(Function &Func, bool CreateGlobalVar = false,
    265                          BranchProbabilityInfo *BPI = nullptr,
    266                          BlockFrequencyInfo *BFI = nullptr)
    267       : F(Func), FunctionHash(0), MST(F, BPI, BFI) {
    268     FuncName = getPGOFuncName(F);
    269     computeCFGHash();
    270     DEBUG(dumpInfo("after CFGMST"));
    271 
    272     NumOfPGOBB += MST.BBInfos.size();
    273     for (auto &E : MST.AllEdges) {
    274       if (E->Removed)
    275         continue;
    276       NumOfPGOEdge++;
    277       if (!E->InMST)
    278         NumOfPGOInstrument++;
    279     }
    280 
    281     if (CreateGlobalVar)
    282       FuncNameVar = createPGOFuncNameVar(F, FuncName);
    283   }
    284 };
    285 
    286 // Compute Hash value for the CFG: the lower 32 bits are CRC32 of the index
    287 // value of each BB in the CFG. The higher 32 bits record the number of edges.
    288 template <class Edge, class BBInfo>
    289 void FuncPGOInstrumentation<Edge, BBInfo>::computeCFGHash() {
    290   std::vector<char> Indexes;
    291   JamCRC JC;
    292   for (auto &BB : F) {
    293     const TerminatorInst *TI = BB.getTerminator();
    294     for (unsigned I = 0, E = TI->getNumSuccessors(); I != E; ++I) {
    295       BasicBlock *Succ = TI->getSuccessor(I);
    296       uint32_t Index = getBBInfo(Succ).Index;
    297       for (int J = 0; J < 4; J++)
    298         Indexes.push_back((char)(Index >> (J * 8)));
    299     }
    300   }
    301   JC.update(Indexes);
    302   FunctionHash = (uint64_t)MST.AllEdges.size() << 32 | JC.getCRC();
    303 }
    304 
    305 // Given a CFG E to be instrumented, find which BB to place the instrumented
    306 // code. The function will split the critical edge if necessary.
    307 template <class Edge, class BBInfo>
    308 BasicBlock *FuncPGOInstrumentation<Edge, BBInfo>::getInstrBB(Edge *E) {
    309   if (E->InMST || E->Removed)
    310     return nullptr;
    311 
    312   BasicBlock *SrcBB = const_cast<BasicBlock *>(E->SrcBB);
    313   BasicBlock *DestBB = const_cast<BasicBlock *>(E->DestBB);
    314   // For a fake edge, instrument the real BB.
    315   if (SrcBB == nullptr)
    316     return DestBB;
    317   if (DestBB == nullptr)
    318     return SrcBB;
    319 
    320   // Instrument the SrcBB if it has a single successor,
    321   // otherwise, the DestBB if this is not a critical edge.
    322   TerminatorInst *TI = SrcBB->getTerminator();
    323   if (TI->getNumSuccessors() <= 1)
    324     return SrcBB;
    325   if (!E->IsCritical)
    326     return DestBB;
    327 
    328   // For a critical edge, we have to split. Instrument the newly
    329   // created BB.
    330   NumOfPGOSplit++;
    331   DEBUG(dbgs() << "Split critical edge: " << getBBInfo(SrcBB).Index << " --> "
    332                << getBBInfo(DestBB).Index << "\n");
    333   unsigned SuccNum = GetSuccessorNumber(SrcBB, DestBB);
    334   BasicBlock *InstrBB = SplitCriticalEdge(TI, SuccNum);
    335   assert(InstrBB && "Critical edge is not split");
    336 
    337   E->Removed = true;
    338   return InstrBB;
    339 }
    340 
    341 // Visit all edge and instrument the edges not in MST, and do value profiling.
    342 // Critical edges will be split.
    343 static void instrumentOneFunc(Function &F, Module *M,
    344                               BranchProbabilityInfo *BPI,
    345                               BlockFrequencyInfo *BFI) {
    346   unsigned NumCounters = 0;
    347   FuncPGOInstrumentation<PGOEdge, BBInfo> FuncInfo(F, true, BPI, BFI);
    348   for (auto &E : FuncInfo.MST.AllEdges) {
    349     if (!E->InMST && !E->Removed)
    350       NumCounters++;
    351   }
    352 
    353   uint32_t I = 0;
    354   Type *I8PtrTy = Type::getInt8PtrTy(M->getContext());
    355   for (auto &E : FuncInfo.MST.AllEdges) {
    356     BasicBlock *InstrBB = FuncInfo.getInstrBB(E.get());
    357     if (!InstrBB)
    358       continue;
    359 
    360     IRBuilder<> Builder(InstrBB, InstrBB->getFirstInsertionPt());
    361     assert(Builder.GetInsertPoint() != InstrBB->end() &&
    362            "Cannot get the Instrumentation point");
    363     Builder.CreateCall(
    364         Intrinsic::getDeclaration(M, Intrinsic::instrprof_increment),
    365         {llvm::ConstantExpr::getBitCast(FuncInfo.FuncNameVar, I8PtrTy),
    366          Builder.getInt64(FuncInfo.FunctionHash), Builder.getInt32(NumCounters),
    367          Builder.getInt32(I++)});
    368   }
    369 
    370   if (DisableValueProfiling)
    371     return;
    372 
    373   unsigned NumIndirectCallSites = 0;
    374   for (auto &I : findIndirectCallSites(F)) {
    375     CallSite CS(I);
    376     Value *Callee = CS.getCalledValue();
    377     DEBUG(dbgs() << "Instrument one indirect call: CallSite Index = "
    378                  << NumIndirectCallSites << "\n");
    379     IRBuilder<> Builder(I);
    380     assert(Builder.GetInsertPoint() != I->getParent()->end() &&
    381            "Cannot get the Instrumentation point");
    382     Builder.CreateCall(
    383         Intrinsic::getDeclaration(M, Intrinsic::instrprof_value_profile),
    384         {llvm::ConstantExpr::getBitCast(FuncInfo.FuncNameVar, I8PtrTy),
    385          Builder.getInt64(FuncInfo.FunctionHash),
    386          Builder.CreatePtrToInt(Callee, Builder.getInt64Ty()),
    387          Builder.getInt32(llvm::InstrProfValueKind::IPVK_IndirectCallTarget),
    388          Builder.getInt32(NumIndirectCallSites++)});
    389   }
    390   NumOfPGOICall += NumIndirectCallSites;
    391 }
    392 
    393 // This class represents a CFG edge in profile use compilation.
    394 struct PGOUseEdge : public PGOEdge {
    395   bool CountValid;
    396   uint64_t CountValue;
    397   PGOUseEdge(const BasicBlock *Src, const BasicBlock *Dest, unsigned W = 1)
    398       : PGOEdge(Src, Dest, W), CountValid(false), CountValue(0) {}
    399 
    400   // Set edge count value
    401   void setEdgeCount(uint64_t Value) {
    402     CountValue = Value;
    403     CountValid = true;
    404   }
    405 
    406   // Return the information string for this object.
    407   const std::string infoString() const {
    408     if (!CountValid)
    409       return PGOEdge::infoString();
    410     return (Twine(PGOEdge::infoString()) + "  Count=" + Twine(CountValue))
    411         .str();
    412   }
    413 };
    414 
    415 typedef SmallVector<PGOUseEdge *, 2> DirectEdges;
    416 
    417 // This class stores the auxiliary information for each BB.
    418 struct UseBBInfo : public BBInfo {
    419   uint64_t CountValue;
    420   bool CountValid;
    421   int32_t UnknownCountInEdge;
    422   int32_t UnknownCountOutEdge;
    423   DirectEdges InEdges;
    424   DirectEdges OutEdges;
    425   UseBBInfo(unsigned IX)
    426       : BBInfo(IX), CountValue(0), CountValid(false), UnknownCountInEdge(0),
    427         UnknownCountOutEdge(0) {}
    428   UseBBInfo(unsigned IX, uint64_t C)
    429       : BBInfo(IX), CountValue(C), CountValid(true), UnknownCountInEdge(0),
    430         UnknownCountOutEdge(0) {}
    431 
    432   // Set the profile count value for this BB.
    433   void setBBInfoCount(uint64_t Value) {
    434     CountValue = Value;
    435     CountValid = true;
    436   }
    437 
    438   // Return the information string of this object.
    439   const std::string infoString() const {
    440     if (!CountValid)
    441       return BBInfo::infoString();
    442     return (Twine(BBInfo::infoString()) + "  Count=" + Twine(CountValue)).str();
    443   }
    444 };
    445 
    446 // Sum up the count values for all the edges.
    447 static uint64_t sumEdgeCount(const ArrayRef<PGOUseEdge *> Edges) {
    448   uint64_t Total = 0;
    449   for (auto &E : Edges) {
    450     if (E->Removed)
    451       continue;
    452     Total += E->CountValue;
    453   }
    454   return Total;
    455 }
    456 
    457 class PGOUseFunc {
    458 public:
    459   PGOUseFunc(Function &Func, Module *Modu, BranchProbabilityInfo *BPI = nullptr,
    460              BlockFrequencyInfo *BFI = nullptr)
    461       : F(Func), M(Modu), FuncInfo(Func, false, BPI, BFI),
    462         FreqAttr(FFA_Normal) {}
    463 
    464   // Read counts for the instrumented BB from profile.
    465   bool readCounters(IndexedInstrProfReader *PGOReader);
    466 
    467   // Populate the counts for all BBs.
    468   void populateCounters();
    469 
    470   // Set the branch weights based on the count values.
    471   void setBranchWeights();
    472 
    473   // Annotate the indirect call sites.
    474   void annotateIndirectCallSites();
    475 
    476   // The hotness of the function from the profile count.
    477   enum FuncFreqAttr { FFA_Normal, FFA_Cold, FFA_Hot };
    478 
    479   // Return the function hotness from the profile.
    480   FuncFreqAttr getFuncFreqAttr() const { return FreqAttr; }
    481 
    482   // Return the profile record for this function;
    483   InstrProfRecord &getProfileRecord() { return ProfileRecord; }
    484 
    485 private:
    486   Function &F;
    487   Module *M;
    488   // This member stores the shared information with class PGOGenFunc.
    489   FuncPGOInstrumentation<PGOUseEdge, UseBBInfo> FuncInfo;
    490 
    491   // Return the auxiliary BB information.
    492   UseBBInfo &getBBInfo(const BasicBlock *BB) const {
    493     return FuncInfo.getBBInfo(BB);
    494   }
    495 
    496   // The maximum count value in the profile. This is only used in PGO use
    497   // compilation.
    498   uint64_t ProgramMaxCount;
    499 
    500   // ProfileRecord for this function.
    501   InstrProfRecord ProfileRecord;
    502 
    503   // Function hotness info derived from profile.
    504   FuncFreqAttr FreqAttr;
    505 
    506   // Find the Instrumented BB and set the value.
    507   void setInstrumentedCounts(const std::vector<uint64_t> &CountFromProfile);
    508 
    509   // Set the edge counter value for the unknown edge -- there should be only
    510   // one unknown edge.
    511   void setEdgeCount(DirectEdges &Edges, uint64_t Value);
    512 
    513   // Return FuncName string;
    514   const std::string getFuncName() const { return FuncInfo.FuncName; }
    515 
    516   // Set the hot/cold inline hints based on the count values.
    517   // FIXME: This function should be removed once the functionality in
    518   // the inliner is implemented.
    519   void markFunctionAttributes(uint64_t EntryCount, uint64_t MaxCount) {
    520     if (ProgramMaxCount == 0)
    521       return;
    522     // Threshold of the hot functions.
    523     const BranchProbability HotFunctionThreshold(1, 100);
    524     // Threshold of the cold functions.
    525     const BranchProbability ColdFunctionThreshold(2, 10000);
    526     if (EntryCount >= HotFunctionThreshold.scale(ProgramMaxCount))
    527       FreqAttr = FFA_Hot;
    528     else if (MaxCount <= ColdFunctionThreshold.scale(ProgramMaxCount))
    529       FreqAttr = FFA_Cold;
    530   }
    531 };
    532 
    533 // Visit all the edges and assign the count value for the instrumented
    534 // edges and the BB.
    535 void PGOUseFunc::setInstrumentedCounts(
    536     const std::vector<uint64_t> &CountFromProfile) {
    537 
    538   // Use a worklist as we will update the vector during the iteration.
    539   std::vector<PGOUseEdge *> WorkList;
    540   for (auto &E : FuncInfo.MST.AllEdges)
    541     WorkList.push_back(E.get());
    542 
    543   uint32_t I = 0;
    544   for (auto &E : WorkList) {
    545     BasicBlock *InstrBB = FuncInfo.getInstrBB(E);
    546     if (!InstrBB)
    547       continue;
    548     uint64_t CountValue = CountFromProfile[I++];
    549     if (!E->Removed) {
    550       getBBInfo(InstrBB).setBBInfoCount(CountValue);
    551       E->setEdgeCount(CountValue);
    552       continue;
    553     }
    554 
    555     // Need to add two new edges.
    556     BasicBlock *SrcBB = const_cast<BasicBlock *>(E->SrcBB);
    557     BasicBlock *DestBB = const_cast<BasicBlock *>(E->DestBB);
    558     // Add new edge of SrcBB->InstrBB.
    559     PGOUseEdge &NewEdge = FuncInfo.MST.addEdge(SrcBB, InstrBB, 0);
    560     NewEdge.setEdgeCount(CountValue);
    561     // Add new edge of InstrBB->DestBB.
    562     PGOUseEdge &NewEdge1 = FuncInfo.MST.addEdge(InstrBB, DestBB, 0);
    563     NewEdge1.setEdgeCount(CountValue);
    564     NewEdge1.InMST = true;
    565     getBBInfo(InstrBB).setBBInfoCount(CountValue);
    566   }
    567 }
    568 
    569 // Set the count value for the unknown edge. There should be one and only one
    570 // unknown edge in Edges vector.
    571 void PGOUseFunc::setEdgeCount(DirectEdges &Edges, uint64_t Value) {
    572   for (auto &E : Edges) {
    573     if (E->CountValid)
    574       continue;
    575     E->setEdgeCount(Value);
    576 
    577     getBBInfo(E->SrcBB).UnknownCountOutEdge--;
    578     getBBInfo(E->DestBB).UnknownCountInEdge--;
    579     return;
    580   }
    581   llvm_unreachable("Cannot find the unknown count edge");
    582 }
    583 
    584 // Read the profile from ProfileFileName and assign the value to the
    585 // instrumented BB and the edges. This function also updates ProgramMaxCount.
    586 // Return true if the profile are successfully read, and false on errors.
    587 bool PGOUseFunc::readCounters(IndexedInstrProfReader *PGOReader) {
    588   auto &Ctx = M->getContext();
    589   Expected<InstrProfRecord> Result =
    590       PGOReader->getInstrProfRecord(FuncInfo.FuncName, FuncInfo.FunctionHash);
    591   if (Error E = Result.takeError()) {
    592     handleAllErrors(std::move(E), [&](const InstrProfError &IPE) {
    593       auto Err = IPE.get();
    594       bool SkipWarning = false;
    595       if (Err == instrprof_error::unknown_function) {
    596         NumOfPGOMissing++;
    597         SkipWarning = NoPGOWarnMissing;
    598       } else if (Err == instrprof_error::hash_mismatch ||
    599                  Err == instrprof_error::malformed) {
    600         NumOfPGOMismatch++;
    601         SkipWarning = NoPGOWarnMismatch;
    602       }
    603 
    604       if (SkipWarning)
    605         return;
    606 
    607       std::string Msg = IPE.message() + std::string(" ") + F.getName().str();
    608       Ctx.diagnose(
    609           DiagnosticInfoPGOProfile(M->getName().data(), Msg, DS_Warning));
    610     });
    611     return false;
    612   }
    613   ProfileRecord = std::move(Result.get());
    614   std::vector<uint64_t> &CountFromProfile = ProfileRecord.Counts;
    615 
    616   NumOfPGOFunc++;
    617   DEBUG(dbgs() << CountFromProfile.size() << " counts\n");
    618   uint64_t ValueSum = 0;
    619   for (unsigned I = 0, S = CountFromProfile.size(); I < S; I++) {
    620     DEBUG(dbgs() << "  " << I << ": " << CountFromProfile[I] << "\n");
    621     ValueSum += CountFromProfile[I];
    622   }
    623 
    624   DEBUG(dbgs() << "SUM =  " << ValueSum << "\n");
    625 
    626   getBBInfo(nullptr).UnknownCountOutEdge = 2;
    627   getBBInfo(nullptr).UnknownCountInEdge = 2;
    628 
    629   setInstrumentedCounts(CountFromProfile);
    630   ProgramMaxCount = PGOReader->getMaximumFunctionCount();
    631   return true;
    632 }
    633 
    634 // Populate the counters from instrumented BBs to all BBs.
    635 // In the end of this operation, all BBs should have a valid count value.
    636 void PGOUseFunc::populateCounters() {
    637   // First set up Count variable for all BBs.
    638   for (auto &E : FuncInfo.MST.AllEdges) {
    639     if (E->Removed)
    640       continue;
    641 
    642     const BasicBlock *SrcBB = E->SrcBB;
    643     const BasicBlock *DestBB = E->DestBB;
    644     UseBBInfo &SrcInfo = getBBInfo(SrcBB);
    645     UseBBInfo &DestInfo = getBBInfo(DestBB);
    646     SrcInfo.OutEdges.push_back(E.get());
    647     DestInfo.InEdges.push_back(E.get());
    648     SrcInfo.UnknownCountOutEdge++;
    649     DestInfo.UnknownCountInEdge++;
    650 
    651     if (!E->CountValid)
    652       continue;
    653     DestInfo.UnknownCountInEdge--;
    654     SrcInfo.UnknownCountOutEdge--;
    655   }
    656 
    657   bool Changes = true;
    658   unsigned NumPasses = 0;
    659   while (Changes) {
    660     NumPasses++;
    661     Changes = false;
    662 
    663     // For efficient traversal, it's better to start from the end as most
    664     // of the instrumented edges are at the end.
    665     for (auto &BB : reverse(F)) {
    666       UseBBInfo &Count = getBBInfo(&BB);
    667       if (!Count.CountValid) {
    668         if (Count.UnknownCountOutEdge == 0) {
    669           Count.CountValue = sumEdgeCount(Count.OutEdges);
    670           Count.CountValid = true;
    671           Changes = true;
    672         } else if (Count.UnknownCountInEdge == 0) {
    673           Count.CountValue = sumEdgeCount(Count.InEdges);
    674           Count.CountValid = true;
    675           Changes = true;
    676         }
    677       }
    678       if (Count.CountValid) {
    679         if (Count.UnknownCountOutEdge == 1) {
    680           uint64_t Total = Count.CountValue - sumEdgeCount(Count.OutEdges);
    681           setEdgeCount(Count.OutEdges, Total);
    682           Changes = true;
    683         }
    684         if (Count.UnknownCountInEdge == 1) {
    685           uint64_t Total = Count.CountValue - sumEdgeCount(Count.InEdges);
    686           setEdgeCount(Count.InEdges, Total);
    687           Changes = true;
    688         }
    689       }
    690     }
    691   }
    692 
    693   DEBUG(dbgs() << "Populate counts in " << NumPasses << " passes.\n");
    694 #ifndef NDEBUG
    695   // Assert every BB has a valid counter.
    696   for (auto &BB : F)
    697     assert(getBBInfo(&BB).CountValid && "BB count is not valid");
    698 #endif
    699   uint64_t FuncEntryCount = getBBInfo(&*F.begin()).CountValue;
    700   F.setEntryCount(FuncEntryCount);
    701   uint64_t FuncMaxCount = FuncEntryCount;
    702   for (auto &BB : F)
    703     FuncMaxCount = std::max(FuncMaxCount, getBBInfo(&BB).CountValue);
    704   markFunctionAttributes(FuncEntryCount, FuncMaxCount);
    705 
    706   DEBUG(FuncInfo.dumpInfo("after reading profile."));
    707 }
    708 
    709 // Assign the scaled count values to the BB with multiple out edges.
    710 void PGOUseFunc::setBranchWeights() {
    711   // Generate MD_prof metadata for every branch instruction.
    712   DEBUG(dbgs() << "\nSetting branch weights.\n");
    713   MDBuilder MDB(M->getContext());
    714   for (auto &BB : F) {
    715     TerminatorInst *TI = BB.getTerminator();
    716     if (TI->getNumSuccessors() < 2)
    717       continue;
    718     if (!isa<BranchInst>(TI) && !isa<SwitchInst>(TI))
    719       continue;
    720     if (getBBInfo(&BB).CountValue == 0)
    721       continue;
    722 
    723     // We have a non-zero Branch BB.
    724     const UseBBInfo &BBCountInfo = getBBInfo(&BB);
    725     unsigned Size = BBCountInfo.OutEdges.size();
    726     SmallVector<unsigned, 2> EdgeCounts(Size, 0);
    727     uint64_t MaxCount = 0;
    728     for (unsigned s = 0; s < Size; s++) {
    729       const PGOUseEdge *E = BBCountInfo.OutEdges[s];
    730       const BasicBlock *SrcBB = E->SrcBB;
    731       const BasicBlock *DestBB = E->DestBB;
    732       if (DestBB == nullptr)
    733         continue;
    734       unsigned SuccNum = GetSuccessorNumber(SrcBB, DestBB);
    735       uint64_t EdgeCount = E->CountValue;
    736       if (EdgeCount > MaxCount)
    737         MaxCount = EdgeCount;
    738       EdgeCounts[SuccNum] = EdgeCount;
    739     }
    740     assert(MaxCount > 0 && "Bad max count");
    741     uint64_t Scale = calculateCountScale(MaxCount);
    742     SmallVector<unsigned, 4> Weights;
    743     for (const auto &ECI : EdgeCounts)
    744       Weights.push_back(scaleBranchCount(ECI, Scale));
    745 
    746     TI->setMetadata(llvm::LLVMContext::MD_prof,
    747                     MDB.createBranchWeights(Weights));
    748     DEBUG(dbgs() << "Weight is: ";
    749           for (const auto &W : Weights) { dbgs() << W << " "; }
    750           dbgs() << "\n";);
    751   }
    752 }
    753 
    754 // Traverse all the indirect callsites and annotate the instructions.
    755 void PGOUseFunc::annotateIndirectCallSites() {
    756   if (DisableValueProfiling)
    757     return;
    758 
    759   // Create the PGOFuncName meta data.
    760   createPGOFuncNameMetadata(F, FuncInfo.FuncName);
    761 
    762   unsigned IndirectCallSiteIndex = 0;
    763   auto IndirectCallSites = findIndirectCallSites(F);
    764   unsigned NumValueSites =
    765       ProfileRecord.getNumValueSites(IPVK_IndirectCallTarget);
    766   if (NumValueSites != IndirectCallSites.size()) {
    767     std::string Msg =
    768         std::string("Inconsistent number of indirect call sites: ") +
    769         F.getName().str();
    770     auto &Ctx = M->getContext();
    771     Ctx.diagnose(
    772         DiagnosticInfoPGOProfile(M->getName().data(), Msg, DS_Warning));
    773     return;
    774   }
    775 
    776   for (auto &I : IndirectCallSites) {
    777     DEBUG(dbgs() << "Read one indirect call instrumentation: Index="
    778                  << IndirectCallSiteIndex << " out of " << NumValueSites
    779                  << "\n");
    780     annotateValueSite(*M, *I, ProfileRecord, IPVK_IndirectCallTarget,
    781                       IndirectCallSiteIndex, MaxNumAnnotations);
    782     IndirectCallSiteIndex++;
    783   }
    784 }
    785 } // end anonymous namespace
    786 
    787 // Create a COMDAT variable IR_LEVEL_PROF_VARNAME to make the runtime
    788 // aware this is an ir_level profile so it can set the version flag.
    789 static void createIRLevelProfileFlagVariable(Module &M) {
    790   Type *IntTy64 = Type::getInt64Ty(M.getContext());
    791   uint64_t ProfileVersion = (INSTR_PROF_RAW_VERSION | VARIANT_MASK_IR_PROF);
    792   auto IRLevelVersionVariable = new GlobalVariable(
    793       M, IntTy64, true, GlobalVariable::ExternalLinkage,
    794       Constant::getIntegerValue(IntTy64, APInt(64, ProfileVersion)),
    795       INSTR_PROF_QUOTE(IR_LEVEL_PROF_VERSION_VAR));
    796   IRLevelVersionVariable->setVisibility(GlobalValue::DefaultVisibility);
    797   Triple TT(M.getTargetTriple());
    798   if (!TT.supportsCOMDAT())
    799     IRLevelVersionVariable->setLinkage(GlobalValue::WeakAnyLinkage);
    800   else
    801     IRLevelVersionVariable->setComdat(M.getOrInsertComdat(
    802         StringRef(INSTR_PROF_QUOTE(IR_LEVEL_PROF_VERSION_VAR))));
    803 }
    804 
    805 static bool InstrumentAllFunctions(
    806     Module &M, function_ref<BranchProbabilityInfo *(Function &)> LookupBPI,
    807     function_ref<BlockFrequencyInfo *(Function &)> LookupBFI) {
    808   createIRLevelProfileFlagVariable(M);
    809   for (auto &F : M) {
    810     if (F.isDeclaration())
    811       continue;
    812     auto *BPI = LookupBPI(F);
    813     auto *BFI = LookupBFI(F);
    814     instrumentOneFunc(F, &M, BPI, BFI);
    815   }
    816   return true;
    817 }
    818 
    819 bool PGOInstrumentationGenLegacyPass::runOnModule(Module &M) {
    820   if (skipModule(M))
    821     return false;
    822 
    823   auto LookupBPI = [this](Function &F) {
    824     return &this->getAnalysis<BranchProbabilityInfoWrapperPass>(F).getBPI();
    825   };
    826   auto LookupBFI = [this](Function &F) {
    827     return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
    828   };
    829   return InstrumentAllFunctions(M, LookupBPI, LookupBFI);
    830 }
    831 
    832 PreservedAnalyses PGOInstrumentationGen::run(Module &M,
    833                                              AnalysisManager<Module> &AM) {
    834 
    835   auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
    836   auto LookupBPI = [&FAM](Function &F) {
    837     return &FAM.getResult<BranchProbabilityAnalysis>(F);
    838   };
    839 
    840   auto LookupBFI = [&FAM](Function &F) {
    841     return &FAM.getResult<BlockFrequencyAnalysis>(F);
    842   };
    843 
    844   if (!InstrumentAllFunctions(M, LookupBPI, LookupBFI))
    845     return PreservedAnalyses::all();
    846 
    847   return PreservedAnalyses::none();
    848 }
    849 
    850 static bool annotateAllFunctions(
    851     Module &M, StringRef ProfileFileName,
    852     function_ref<BranchProbabilityInfo *(Function &)> LookupBPI,
    853     function_ref<BlockFrequencyInfo *(Function &)> LookupBFI) {
    854   DEBUG(dbgs() << "Read in profile counters: ");
    855   auto &Ctx = M.getContext();
    856   // Read the counter array from file.
    857   auto ReaderOrErr = IndexedInstrProfReader::create(ProfileFileName);
    858   if (Error E = ReaderOrErr.takeError()) {
    859     handleAllErrors(std::move(E), [&](const ErrorInfoBase &EI) {
    860       Ctx.diagnose(
    861           DiagnosticInfoPGOProfile(ProfileFileName.data(), EI.message()));
    862     });
    863     return false;
    864   }
    865 
    866   std::unique_ptr<IndexedInstrProfReader> PGOReader =
    867       std::move(ReaderOrErr.get());
    868   if (!PGOReader) {
    869     Ctx.diagnose(DiagnosticInfoPGOProfile(ProfileFileName.data(),
    870                                           StringRef("Cannot get PGOReader")));
    871     return false;
    872   }
    873   // TODO: might need to change the warning once the clang option is finalized.
    874   if (!PGOReader->isIRLevelProfile()) {
    875     Ctx.diagnose(DiagnosticInfoPGOProfile(
    876         ProfileFileName.data(), "Not an IR level instrumentation profile"));
    877     return false;
    878   }
    879 
    880   std::vector<Function *> HotFunctions;
    881   std::vector<Function *> ColdFunctions;
    882   for (auto &F : M) {
    883     if (F.isDeclaration())
    884       continue;
    885     auto *BPI = LookupBPI(F);
    886     auto *BFI = LookupBFI(F);
    887     PGOUseFunc Func(F, &M, BPI, BFI);
    888     if (!Func.readCounters(PGOReader.get()))
    889       continue;
    890     Func.populateCounters();
    891     Func.setBranchWeights();
    892     Func.annotateIndirectCallSites();
    893     PGOUseFunc::FuncFreqAttr FreqAttr = Func.getFuncFreqAttr();
    894     if (FreqAttr == PGOUseFunc::FFA_Cold)
    895       ColdFunctions.push_back(&F);
    896     else if (FreqAttr == PGOUseFunc::FFA_Hot)
    897       HotFunctions.push_back(&F);
    898   }
    899   M.setProfileSummary(PGOReader->getSummary().getMD(M.getContext()));
    900   // Set function hotness attribute from the profile.
    901   // We have to apply these attributes at the end because their presence
    902   // can affect the BranchProbabilityInfo of any callers, resulting in an
    903   // inconsistent MST between prof-gen and prof-use.
    904   for (auto &F : HotFunctions) {
    905     F->addFnAttr(llvm::Attribute::InlineHint);
    906     DEBUG(dbgs() << "Set inline attribute to function: " << F->getName()
    907                  << "\n");
    908   }
    909   for (auto &F : ColdFunctions) {
    910     F->addFnAttr(llvm::Attribute::Cold);
    911     DEBUG(dbgs() << "Set cold attribute to function: " << F->getName() << "\n");
    912   }
    913 
    914   return true;
    915 }
    916 
    917 PGOInstrumentationUse::PGOInstrumentationUse(std::string Filename)
    918     : ProfileFileName(std::move(Filename)) {
    919   if (!PGOTestProfileFile.empty())
    920     ProfileFileName = PGOTestProfileFile;
    921 }
    922 
    923 PreservedAnalyses PGOInstrumentationUse::run(Module &M,
    924                                              AnalysisManager<Module> &AM) {
    925 
    926   auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
    927   auto LookupBPI = [&FAM](Function &F) {
    928     return &FAM.getResult<BranchProbabilityAnalysis>(F);
    929   };
    930 
    931   auto LookupBFI = [&FAM](Function &F) {
    932     return &FAM.getResult<BlockFrequencyAnalysis>(F);
    933   };
    934 
    935   if (!annotateAllFunctions(M, ProfileFileName, LookupBPI, LookupBFI))
    936     return PreservedAnalyses::all();
    937 
    938   return PreservedAnalyses::none();
    939 }
    940 
    941 bool PGOInstrumentationUseLegacyPass::runOnModule(Module &M) {
    942   if (skipModule(M))
    943     return false;
    944 
    945   auto LookupBPI = [this](Function &F) {
    946     return &this->getAnalysis<BranchProbabilityInfoWrapperPass>(F).getBPI();
    947   };
    948   auto LookupBFI = [this](Function &F) {
    949     return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
    950   };
    951 
    952   return annotateAllFunctions(M, ProfileFileName, LookupBPI, LookupBFI);
    953 }
    954