1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // Instrumentation-based profile-guided optimization 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "CodeGenPGO.h" 15 #include "CodeGenFunction.h" 16 #include "CoverageMappingGen.h" 17 #include "clang/AST/RecursiveASTVisitor.h" 18 #include "clang/AST/StmtVisitor.h" 19 #include "llvm/IR/Intrinsics.h" 20 #include "llvm/IR/MDBuilder.h" 21 #include "llvm/Support/Endian.h" 22 #include "llvm/Support/FileSystem.h" 23 #include "llvm/Support/MD5.h" 24 25 static llvm::cl::opt<bool> EnableValueProfiling( 26 "enable-value-profiling", llvm::cl::ZeroOrMore, 27 llvm::cl::desc("Enable value profiling"), llvm::cl::init(false)); 28 29 using namespace clang; 30 using namespace CodeGen; 31 32 void CodeGenPGO::setFuncName(StringRef Name, 33 llvm::GlobalValue::LinkageTypes Linkage) { 34 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader(); 35 FuncName = llvm::getPGOFuncName( 36 Name, Linkage, CGM.getCodeGenOpts().MainFileName, 37 PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version); 38 39 // If we're generating a profile, create a variable for the name. 40 if (CGM.getCodeGenOpts().hasProfileClangInstr()) 41 FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName); 42 } 43 44 void CodeGenPGO::setFuncName(llvm::Function *Fn) { 45 setFuncName(Fn->getName(), Fn->getLinkage()); 46 // Create PGOFuncName meta data. 47 llvm::createPGOFuncNameMetadata(*Fn, FuncName); 48 } 49 50 namespace { 51 /// \brief Stable hasher for PGO region counters. 52 /// 53 /// PGOHash produces a stable hash of a given function's control flow. 
54 /// 55 /// Changing the output of this hash will invalidate all previously generated 56 /// profiles -- i.e., don't do it. 57 /// 58 /// \note When this hash does eventually change (years?), we still need to 59 /// support old hashes. We'll need to pull in the version number from the 60 /// profile data format and use the matching hash function. 61 class PGOHash { 62 uint64_t Working; 63 unsigned Count; 64 llvm::MD5 MD5; 65 66 static const int NumBitsPerType = 6; 67 static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType; 68 static const unsigned TooBig = 1u << NumBitsPerType; 69 70 public: 71 /// \brief Hash values for AST nodes. 72 /// 73 /// Distinct values for AST nodes that have region counters attached. 74 /// 75 /// These values must be stable. All new members must be added at the end, 76 /// and no members should be removed. Changing the enumeration value for an 77 /// AST node will affect the hash of every function that contains that node. 78 enum HashType : unsigned char { 79 None = 0, 80 LabelStmt = 1, 81 WhileStmt, 82 DoStmt, 83 ForStmt, 84 CXXForRangeStmt, 85 ObjCForCollectionStmt, 86 SwitchStmt, 87 CaseStmt, 88 DefaultStmt, 89 IfStmt, 90 CXXTryStmt, 91 CXXCatchStmt, 92 ConditionalOperator, 93 BinaryOperatorLAnd, 94 BinaryOperatorLOr, 95 BinaryConditionalOperator, 96 97 // Keep this last. It's for the static assert that follows. 98 LastHashType 99 }; 100 static_assert(LastHashType <= TooBig, "Too many types in HashType"); 101 102 // TODO: When this format changes, take in a version number here, and use the 103 // old hash calculation for file formats that used the old hash. 104 PGOHash() : Working(0), Count(0) {} 105 void combine(HashType Type); 106 uint64_t finalize(); 107 }; 108 const int PGOHash::NumBitsPerType; 109 const unsigned PGOHash::NumTypesPerWord; 110 const unsigned PGOHash::TooBig; 111 112 /// A RecursiveASTVisitor that fills a map of statements to PGO counters. 
///
/// Counters are numbered in traversal order starting at 0. Every statement
/// that receives a counter also contributes its PGOHash::HashType to the
/// function hash.
struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
  /// The next counter value to assign.
  unsigned NextCounter;
  /// The function hash.
  PGOHash Hash;
  /// The map of statements to counters.
  llvm::DenseMap<const Stmt *, unsigned> &CounterMap;

  MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
      : NextCounter(0), CounterMap(CounterMap) {}

  // Blocks and lambdas are handled as separate functions, so we need not
  // traverse them in the parent context.
  bool TraverseBlockExpr(BlockExpr *BE) { return true; }
  bool TraverseLambdaBody(LambdaExpr *LE) { return true; }
  bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }

  /// Assign a counter to the body of each function-like declaration.
  ///
  /// Declarations of any other kind are ignored. Note that the body does not
  /// contribute to the hash; only statements seen by VisitStmt do.
  bool VisitDecl(const Decl *D) {
    switch (D->getKind()) {
    default:
      break;
    case Decl::Function:
    case Decl::CXXMethod:
    case Decl::CXXConstructor:
    case Decl::CXXDestructor:
    case Decl::CXXConversion:
    case Decl::ObjCMethod:
    case Decl::Block:
    case Decl::Captured:
      CounterMap[D->getBody()] = NextCounter++;
      break;
    }
    return true;
  }

  /// Assign a counter to \p S and fold its type into the hash, provided
  /// getHashType() maps it to something other than PGOHash::None.
  bool VisitStmt(const Stmt *S) {
    auto Type = getHashType(S);
    if (Type == PGOHash::None)
      return true;

    CounterMap[S] = NextCounter++;
    Hash.combine(Type);
    return true;
  }

  /// Map a statement class to its stable hash type.
  ///
  /// \returns PGOHash::None for statements that get no counter. Of the binary
  /// operators, only logical-and and logical-or are counted.
  PGOHash::HashType getHashType(const Stmt *S) {
    switch (S->getStmtClass()) {
    default:
      break;
    case Stmt::LabelStmtClass:
      return PGOHash::LabelStmt;
    case Stmt::WhileStmtClass:
      return PGOHash::WhileStmt;
    case Stmt::DoStmtClass:
      return PGOHash::DoStmt;
    case Stmt::ForStmtClass:
      return PGOHash::ForStmt;
    case Stmt::CXXForRangeStmtClass:
      return PGOHash::CXXForRangeStmt;
    case Stmt::ObjCForCollectionStmtClass:
      return PGOHash::ObjCForCollectionStmt;
    case Stmt::SwitchStmtClass:
      return PGOHash::SwitchStmt;
    case Stmt::CaseStmtClass:
      return PGOHash::CaseStmt;
    case Stmt::DefaultStmtClass:
      return PGOHash::DefaultStmt;
    case Stmt::IfStmtClass:
      return PGOHash::IfStmt;
    case Stmt::CXXTryStmtClass:
      return PGOHash::CXXTryStmt;
    case Stmt::CXXCatchStmtClass:
      return PGOHash::CXXCatchStmt;
    case Stmt::ConditionalOperatorClass:
      return PGOHash::ConditionalOperator;
    case Stmt::BinaryConditionalOperatorClass:
      return PGOHash::BinaryConditionalOperator;
    case Stmt::BinaryOperatorClass: {
      const BinaryOperator *BO = cast<BinaryOperator>(S);
      if (BO->getOpcode() == BO_LAnd)
        return PGOHash::BinaryOperatorLAnd;
      if (BO->getOpcode() == BO_LOr)
        return PGOHash::BinaryOperatorLOr;
      break;
    }
    }
    return PGOHash::None;
  }
};

/// A StmtVisitor that propagates the raw counts through the AST and
/// records the count at statements where the value may change.
///
/// Each visitor below reads the region counter for its statement from the
/// PGO state, derives the counts of the statement's sub-parts by simple
/// arithmetic on the counts flowing in and out, and leaves CurrentCount set
/// to the count of whatever follows the statement.
struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
  /// PGO state.
  CodeGenPGO &PGO;

  /// A flag that is set when the current count should be recorded on the
  /// next statement, such as at the exit of a loop.
  bool RecordNextStmtCount;

  /// The count at the current location in the traversal.
  ///
  /// Not initialized by the constructor; the Visit*Decl entry points set it
  /// via setCount() before traversing a body.
  uint64_t CurrentCount;

  /// The map of statements to count values.
  llvm::DenseMap<const Stmt *, uint64_t> &CountMap;

  /// BreakContinueStack - Keep counts of breaks and continues inside loops.
  struct BreakContinue {
    uint64_t BreakCount;
    uint64_t ContinueCount;
    BreakContinue() : BreakCount(0), ContinueCount(0) {}
  };
  SmallVector<BreakContinue, 8> BreakContinueStack;

  ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
                      CodeGenPGO &PGO)
      : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}

  /// Record CurrentCount for \p S if a previous statement asked for the next
  /// statement's count to be recorded, and clear that request.
  void RecordStmtCount(const Stmt *S) {
    if (RecordNextStmtCount) {
      CountMap[S] = CurrentCount;
      RecordNextStmtCount = false;
    }
  }

  /// Set and return the current count.
  uint64_t setCount(uint64_t Count) {
    CurrentCount = Count;
    return Count;
  }

  /// Default handling: record the pending count, if any, then visit every
  /// non-null child in order.
  void VisitStmt(const Stmt *S) {
    RecordStmtCount(S);
    for (const Stmt *Child : S->children())
      if (Child)
        this->Visit(Child);
  }

  void VisitFunctionDecl(const FunctionDecl *D) {
    // Counter tracks entry to the function body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  // Skip lambda expressions. We visit these as FunctionDecls when we're
  // generating them and aren't interested in the body when generating a
  // parent context.
  void VisitLambdaExpr(const LambdaExpr *LE) {}

  void VisitCapturedDecl(const CapturedDecl *D) {
    // Counter tracks entry to the capture body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
    // Counter tracks entry to the method body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitBlockDecl(const BlockDecl *D) {
    // Counter tracks entry to the block body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  /// Control does not fall through a return: the count that follows it is
  /// reset to zero and flagged to be recorded on the next statement.
  void VisitReturnStmt(const ReturnStmt *S) {
    RecordStmtCount(S);
    if (S->getRetValue())
      Visit(S->getRetValue());
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  /// Same treatment as a return: nothing falls through a throw.
  void VisitCXXThrowExpr(const CXXThrowExpr *E) {
    RecordStmtCount(E);
    if (E->getSubExpr())
      Visit(E->getSubExpr());
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  /// Same treatment as a return: nothing falls through a goto.
  void VisitGotoStmt(const GotoStmt *S) {
    RecordStmtCount(S);
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitLabelStmt(const LabelStmt *S) {
    // The label has its own counter, so any pending request to record the
    // incoming count is dropped.
    RecordNextStmtCount = false;
    // Counter tracks the block following the label.
    uint64_t BlockCount = setCount(PGO.getRegionCount(S));
    CountMap[S] = BlockCount;
    Visit(S->getSubStmt());
  }

  /// Credit the current count to the innermost loop/switch's break total;
  /// nothing falls through a break.
  void VisitBreakStmt(const BreakStmt *S) {
    RecordStmtCount(S);
    assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
    BreakContinueStack.back().BreakCount += CurrentCount;
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  /// Credit the current count to the innermost entry's continue total;
  /// nothing falls through a continue.
  void VisitContinueStmt(const ContinueStmt *S) {
    RecordStmtCount(S);
    assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
    BreakContinueStack.back().ContinueCount += CurrentCount;
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitWhileStmt(const WhileStmt *S) {
    RecordStmtCount(S);
    uint64_t ParentCount = CurrentCount;

    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first so the break/continue adjustments can be
    // included when visiting the condition.
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = CurrentCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;

    // ...then go back and propagate counts through the condition. The count
    // at the start of the condition is the sum of the incoming edges,
    // the backedge from the end of the loop body, and the edges from
    // continue statements.
    BreakContinue BC = BreakContinueStack.pop_back_val();
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    // The loop exit count is the breaks plus the condition evaluations that
    // did not enter the body.
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitDoStmt(const DoStmt *S) {
    RecordStmtCount(S);
    uint64_t LoopCount = PGO.getRegionCount(S);

    BreakContinueStack.push_back(BreakContinue());
    // The count doesn't include the fallthrough from the parent scope. Add it.
    uint64_t BodyCount = setCount(LoopCount + CurrentCount);
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;

    BreakContinue BC = BreakContinueStack.pop_back_val();
    // The count at the start of the condition is equal to the count at the
    // end of the body, plus any continues.
    uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    // The loop exit count is the breaks plus the condition evaluations that
    // did not loop back.
    setCount(BC.BreakCount + CondCount - LoopCount);
    RecordNextStmtCount = true;
  }

  void VisitForStmt(const ForStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());

    uint64_t ParentCount = CurrentCount;

    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first. (This is basically the same as a while
    // loop; see further comments in VisitWhileStmt.)
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // The increment is essentially part of the body but it needs to include
    // the count for all the continue statements.
    if (S->getInc()) {
      uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
      CountMap[S->getInc()] = IncCount;
      Visit(S->getInc());
    }

    // ...then go back and propagate counts through the condition.
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    if (S->getCond()) {
      CountMap[S->getCond()] = CondCount;
      Visit(S->getCond());
    }
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
    RecordStmtCount(S);
    // These sub-statements are visited at the parent's count, before the
    // loop proper.
    Visit(S->getLoopVarStmt());
    Visit(S->getRangeStmt());
    Visit(S->getBeginStmt());
    Visit(S->getEndStmt());

    uint64_t ParentCount = CurrentCount;
    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first. (This is basically the same as a while
    // loop; see further comments in VisitWhileStmt.)
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // The increment is essentially part of the body but it needs to include
    // the count for all the continue statements.
    uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
    CountMap[S->getInc()] = IncCount;
    Visit(S->getInc());

    // ...then go back and propagate counts through the condition.
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
    RecordStmtCount(S);
    Visit(S->getElement());
    uint64_t ParentCount = CurrentCount;
    BreakContinueStack.push_back(BreakContinue());
    // Counter tracks the body of the loop.
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // No separate condition is visited here; the exit count is computed
    // directly from the edges into and out of the body.
    setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
             BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitSwitchStmt(const SwitchStmt *S) {
    RecordStmtCount(S);
    Visit(S->getCond());
    // Nothing falls into the switch body; counts come from the case labels.
    CurrentCount = 0;
    BreakContinueStack.push_back(BreakContinue());
    Visit(S->getBody());
    // If the switch is inside a loop, add the continue counts.
    BreakContinue BC = BreakContinueStack.pop_back_val();
    if (!BreakContinueStack.empty())
      BreakContinueStack.back().ContinueCount += BC.ContinueCount;
    // Counter tracks the exit block of the switch.
    setCount(PGO.getRegionCount(S));
    RecordNextStmtCount = true;
  }

  void VisitSwitchCase(const SwitchCase *S) {
    RecordNextStmtCount = false;
    // Counter for this particular case. This counts only jumps from the
    // switch header and does not include fallthrough from the case before
    // this one.
    uint64_t CaseCount = PGO.getRegionCount(S);
    setCount(CurrentCount + CaseCount);
    // We need the count without fallthrough in the mapping, so it's more useful
    // for branch probabilities.
    CountMap[S] = CaseCount;
    RecordNextStmtCount = true;
    Visit(S->getSubStmt());
  }

  void VisitIfStmt(const IfStmt *S) {
    RecordStmtCount(S);
    uint64_t ParentCount = CurrentCount;
    Visit(S->getCond());

    // Counter tracks the "then" part of an if statement. The count for
    // the "else" part, if it exists, will be calculated from this counter.
    uint64_t ThenCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getThen()] = ThenCount;
    Visit(S->getThen());
    uint64_t OutCount = CurrentCount;

    uint64_t ElseCount = ParentCount - ThenCount;
    if (S->getElse()) {
      setCount(ElseCount);
      CountMap[S->getElse()] = ElseCount;
      Visit(S->getElse());
      OutCount += CurrentCount;
    } else
      OutCount += ElseCount;
    // The count after the if is whatever flows out of both arms.
    setCount(OutCount);
    RecordNextStmtCount = true;
  }

  void VisitCXXTryStmt(const CXXTryStmt *S) {
    RecordStmtCount(S);
    Visit(S->getTryBlock());
    for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
      Visit(S->getHandler(I));
    // Counter tracks the continuation block of the try statement.
    setCount(PGO.getRegionCount(S));
    RecordNextStmtCount = true;
  }

  void VisitCXXCatchStmt(const CXXCatchStmt *S) {
    RecordNextStmtCount = false;
    // Counter tracks the catch statement's handler block.
    uint64_t CatchCount = setCount(PGO.getRegionCount(S));
    CountMap[S] = CatchCount;
    Visit(S->getHandlerBlock());
  }

  void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getCond());

    // Counter tracks the "true" part of a conditional operator. The
    // count in the "false" part will be calculated from this counter.
    uint64_t TrueCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getTrueExpr()] = TrueCount;
    Visit(E->getTrueExpr());
    uint64_t OutCount = CurrentCount;

    uint64_t FalseCount = setCount(ParentCount - TrueCount);
    CountMap[E->getFalseExpr()] = FalseCount;
    Visit(E->getFalseExpr());
    OutCount += CurrentCount;

    setCount(OutCount);
    RecordNextStmtCount = true;
  }

  void VisitBinLAnd(const BinaryOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getLHS());
    // Counter tracks the right hand side of a logical and operator.
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getRHS()] = RHSCount;
    Visit(E->getRHS());
    // CurrentCount now holds the count at the end of the RHS.
    setCount(ParentCount + RHSCount - CurrentCount);
    RecordNextStmtCount = true;
  }

  void VisitBinLOr(const BinaryOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getLHS());
    // Counter tracks the right hand side of a logical or operator.
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getRHS()] = RHSCount;
    Visit(E->getRHS());
    // CurrentCount now holds the count at the end of the RHS.
    setCount(ParentCount + RHSCount - CurrentCount);
    RecordNextStmtCount = true;
  }
};
} // end anonymous namespace

/// Fold one AST node type into the hash.
///
/// Types are packed into Working, NumBitsPerType bits each. Each time
/// NumTypesPerWord types have accumulated, the word is fed to MD5 in
/// little-endian byte order and Working starts again from zero.
void PGOHash::combine(HashType Type) {
  // Check that we never combine 0 and only have six bits.
  assert(Type && "Hash is invalid: unexpected type 0");
  assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");

  // Pass through MD5 if enough work has built up.
  if (Count && Count % NumTypesPerWord == 0) {
    using namespace llvm::support;
    uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
    MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
    Working = 0;
  }

  // Accumulate the current type.
  ++Count;
  Working = Working << NumBitsPerType | Type;
}

/// Produce the final 64-bit hash.
///
/// \returns Working itself when at most NumTypesPerWord types were combined
/// (MD5 was never used); otherwise the first eight bytes of the MD5 digest,
/// read as a little-endian integer.
uint64_t PGOHash::finalize() {
  // Use Working as the hash directly if we never used MD5.
  if (Count <= NumTypesPerWord)
    // No need to byte swap here, since none of the math was endian-dependent.
    // This number will be byte-swapped as required on endianness transitions,
    // so we will see the same value on the other side.
    return Working;

  // Check for remaining work in Working.
  // NOTE(review): unlike combine(), this passes the raw integer rather than a
  // byte-swapped 8-byte array. Confirm which MD5::update overload/conversion
  // is selected and whether all eight bytes are hashed. Do not "fix" this in
  // place: any change alters the hash and invalidates existing profiles.
  if (Working)
    MD5.update(Working);

  // Finalize the MD5 and return the hash.
  llvm::MD5::MD5Result Result;
  MD5.final(Result);
  using namespace llvm::support;
  return endian::read<uint64_t, little, unaligned>(Result);
}

/// Set up PGO state for the function \p Fn generated from \p GD.
///
/// Does nothing unless clang instrumentation is enabled or a profile reader
/// is available, and skips implicit declarations. Otherwise it names the
/// function, maps its region counters, optionally emits the coverage mapping
/// and, when a profile reader is available, loads the counts, propagates
/// them and applies function attributes.
void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
  const Decl *D = GD.getDecl();
  bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  if (!InstrumentRegions && !PGOReader)
    return;
  if (D->isImplicit())
    return;
  // Constructors and destructors may be represented by several functions in IR.
  // If so, instrument only base variant, others are implemented by delegation
  // to the base one, it would be counted twice otherwise.
  if (CGM.getTarget().getCXXABI().hasConstructorVariants() &&
      ((isa<CXXConstructorDecl>(GD.getDecl()) &&
        GD.getCtorType() != Ctor_Base) ||
       (isa<CXXDestructorDecl>(GD.getDecl()) &&
        GD.getDtorType() != Dtor_Base))) {
    return;
  }
  CGM.ClearUnusedCoverageMapping(D);
  setFuncName(Fn);

  mapRegionCounters(D);
  if (CGM.getCodeGenOpts().CoverageMapping)
    emitCounterRegionMapping(D);
  if (PGOReader) {
    SourceManager &SM = CGM.getContext().getSourceManager();
    loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
    computeRegionCounts(D);
    applyFunctionAttributes(PGOReader, Fn);
  }
}

/// Walk \p D with MapRegionCounters to build RegionCounterMap, then store the
/// number of counters assigned and the resulting function hash.
///
/// Only function, Objective-C method, block and captured declarations are
/// traversed.
void CodeGenPGO::mapRegionCounters(const Decl *D) {
  RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
  MapRegionCounters Walker(*RegionCounterMap);
  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
    Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
    Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
    Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
    Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
  assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
  NumRegionCounters = Walker.NextCounter;
  FunctionHash = Walker.Hash.finalize();
}

/// \returns true when no coverage mapping should be emitted for \p D: either
/// SkipCoverageMapping is set or the body starts in a system header.
///
/// The body is dereferenced unconditionally, so \p D must have one.
bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
  if (SkipCoverageMapping)
    return true;

  // Don't map the functions in system headers.
  const auto &SM = CGM.getContext().getSourceManager();
  auto Loc = D->getBody()->getLocStart();
  return SM.isInSystemHeader(Loc);
}

/// Generate the coverage mapping for \p D from RegionCounterMap and register
/// it with the module's coverage mapping, unless the declaration is skipped
/// or the generated mapping is empty.
void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
  if (skipRegionMappingForDecl(D))
    return;

  std::string CoverageMapping;
  llvm::raw_string_ostream OS(CoverageMapping);
  CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
                                CGM.getContext().getSourceManager(),
                                CGM.getLangOpts(), RegionCounterMap.get());
  MappingGen.emitCounterMapping(D, OS);
  OS.flush();

  if (CoverageMapping.empty())
    return;

  CGM.getCoverageMapping()->addFunctionMappingRecord(
      FuncNameVar, FuncName, FunctionHash, CoverageMapping);
}

/// Generate an empty coverage mapping for \p D and register it under the PGO
/// name derived from \p Name and \p Linkage.
///
/// No counter map is passed to the generator here, and the record is added
/// with an extra trailing \c false argument.
void
CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
                                    llvm::GlobalValue::LinkageTypes Linkage) {
  if (skipRegionMappingForDecl(D))
    return;

  std::string CoverageMapping;
  llvm::raw_string_ostream OS(CoverageMapping);
  CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
                                CGM.getContext().getSourceManager(),
                                CGM.getLangOpts());
  MappingGen.emitEmptyMapping(D, OS);
  OS.flush();

  if (CoverageMapping.empty())
    return;

  setFuncName(Name, Linkage);
  CGM.getCoverageMapping()->addFunctionMappingRecord(
      FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
}

/// Build StmtCountMap for \p D by running ComputeRegionCounts from the entry
/// point that matches the kind of declaration.
void CodeGenPGO::computeRegionCounts(const Decl *D) {
  StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
  ComputeRegionCounts Walker(*StmtCountMap, *this);
  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
    Walker.VisitFunctionDecl(FD);
  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
    Walker.VisitObjCMethodDecl(MD);
  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
    Walker.VisitBlockDecl(BD);
  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
    Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
}

/// Set the entry count of \p Fn from the loaded profile, if region counts
/// are available. \p PGOReader is not used in this body.
void
CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
                                    llvm::Function *Fn) {
  if (!haveRegionCounts())
    return;

  // NOTE(review): the count is looked up with a null statement; presumably
  // this resolves to the function entry counter -- confirm in getRegionCount.
  uint64_t FunctionCount = getRegionCount(nullptr);
  Fn->setEntryCount(FunctionCount);
}

/// Emit a call to the instrprof_increment intrinsic for the counter mapped to
/// \p S at the builder's current insertion point.
///
/// Does nothing unless clang instrumentation is enabled, a counter map
/// exists and the builder has an insertion block.
void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S) {
  if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
    return;
  if (!Builder.GetInsertBlock())
    return;

  unsigned Counter = (*RegionCounterMap)[S];
  auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
  Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
                     {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
                      Builder.getInt64(FunctionHash),
                      Builder.getInt32(NumRegionCounters),
                      Builder.getInt32(Counter)});
}

// This method either inserts a call to the profile run-time during
// instrumentation or puts profile data into metadata for PGO use.
753 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind, 754 llvm::Instruction *ValueSite, llvm::Value *ValuePtr) { 755 756 if (!EnableValueProfiling) 757 return; 758 759 if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock()) 760 return; 761 762 if (isa<llvm::Constant>(ValuePtr)) 763 return; 764 765 bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr(); 766 if (InstrumentValueSites && RegionCounterMap) { 767 auto BuilderInsertPoint = Builder.saveIP(); 768 Builder.SetInsertPoint(ValueSite); 769 llvm::Value *Args[5] = { 770 llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()), 771 Builder.getInt64(FunctionHash), 772 Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()), 773 Builder.getInt32(ValueKind), 774 Builder.getInt32(NumValueSites[ValueKind]++) 775 }; 776 Builder.CreateCall( 777 CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args); 778 Builder.restoreIP(BuilderInsertPoint); 779 return; 780 } 781 782 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader(); 783 if (PGOReader && haveRegionCounts()) { 784 // We record the top most called three functions at each call site. 785 // Profile metadata contains "VP" string identifying this metadata 786 // as value profiling data, then a uint32_t value for the value profiling 787 // kind, a uint64_t value for the total number of times the call is 788 // executed, followed by the function hash and execution count (uint64_t) 789 // pairs for each function. 
790 if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind)) 791 return; 792 793 llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord, 794 (llvm::InstrProfValueKind)ValueKind, 795 NumValueSites[ValueKind]); 796 797 NumValueSites[ValueKind]++; 798 } 799 } 800 801 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader, 802 bool IsInMainFile) { 803 CGM.getPGOStats().addVisited(IsInMainFile); 804 RegionCounts.clear(); 805 llvm::Expected<llvm::InstrProfRecord> RecordExpected = 806 PGOReader->getInstrProfRecord(FuncName, FunctionHash); 807 if (auto E = RecordExpected.takeError()) { 808 auto IPE = llvm::InstrProfError::take(std::move(E)); 809 if (IPE == llvm::instrprof_error::unknown_function) 810 CGM.getPGOStats().addMissing(IsInMainFile); 811 else if (IPE == llvm::instrprof_error::hash_mismatch) 812 CGM.getPGOStats().addMismatched(IsInMainFile); 813 else if (IPE == llvm::instrprof_error::malformed) 814 // TODO: Consider a more specific warning for this case. 815 CGM.getPGOStats().addMismatched(IsInMainFile); 816 return; 817 } 818 ProfRecord = 819 llvm::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get())); 820 RegionCounts = ProfRecord->Counts; 821 } 822 823 /// \brief Calculate what to divide by to scale weights. 824 /// 825 /// Given the maximum weight, calculate a divisor that will scale all the 826 /// weights to strictly less than UINT32_MAX. 827 static uint64_t calculateWeightScale(uint64_t MaxWeight) { 828 return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1; 829 } 830 831 /// \brief Scale an individual branch weight (and add 1). 832 /// 833 /// Scale a 64-bit weight down to 32-bits using \c Scale. 834 /// 835 /// According to Laplace's Rule of Succession, it is better to compute the 836 /// weight based on the count plus 1, so universally add 1 to the value. 837 /// 838 /// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no 839 /// greater than \c Weight. 
840 static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) { 841 assert(Scale && "scale by 0?"); 842 uint64_t Scaled = Weight / Scale + 1; 843 assert(Scaled <= UINT32_MAX && "overflow 32-bits"); 844 return Scaled; 845 } 846 847 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount, 848 uint64_t FalseCount) { 849 // Check for empty weights. 850 if (!TrueCount && !FalseCount) 851 return nullptr; 852 853 // Calculate how to scale down to 32-bits. 854 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount)); 855 856 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 857 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale), 858 scaleBranchWeight(FalseCount, Scale)); 859 } 860 861 llvm::MDNode * 862 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) { 863 // We need at least two elements to create meaningful weights. 864 if (Weights.size() < 2) 865 return nullptr; 866 867 // Check for empty weights. 868 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end()); 869 if (MaxWeight == 0) 870 return nullptr; 871 872 // Calculate how to scale down to 32-bits. 873 uint64_t Scale = calculateWeightScale(MaxWeight); 874 875 SmallVector<uint32_t, 16> ScaledWeights; 876 ScaledWeights.reserve(Weights.size()); 877 for (uint64_t W : Weights) 878 ScaledWeights.push_back(scaleBranchWeight(W, Scale)); 879 880 llvm::MDBuilder MDHelper(CGM.getLLVMContext()); 881 return MDHelper.createBranchWeights(ScaledWeights); 882 } 883 884 llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond, 885 uint64_t LoopCount) { 886 if (!PGO.haveRegionCounts()) 887 return nullptr; 888 Optional<uint64_t> CondCount = PGO.getStmtCount(Cond); 889 assert(CondCount.hasValue() && "missing expected loop condition count"); 890 if (*CondCount == 0) 891 return nullptr; 892 return createProfileWeights(LoopCount, 893 std::max(*CondCount, LoopCount) - LoopCount); 894 } 895