1 //===- BranchProbability.h - Branch Probability Wrapper ---------*- C++ -*-===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // Definition of BranchProbability shared by IR and Machine Instructions. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef LLVM_SUPPORT_BRANCHPROBABILITY_H 15 #define LLVM_SUPPORT_BRANCHPROBABILITY_H 16 17 #include "llvm/Support/DataTypes.h" 18 #include <algorithm> 19 #include <cassert> 20 #include <climits> 21 #include <numeric> 22 23 namespace llvm { 24 25 class raw_ostream; 26 27 // This class represents Branch Probability as a non-negative fraction that is 28 // no greater than 1. It uses a fixed-point-like implementation, in which the 29 // denominator is always a constant value (here we use 1<<31 for maximum 30 // precision). 31 class BranchProbability { 32 // Numerator 33 uint32_t N; 34 35 // Denominator, which is a constant value. 36 static const uint32_t D = 1u << 31; 37 static const uint32_t UnknownN = UINT32_MAX; 38 39 // Construct a BranchProbability with only numerator assuming the denominator 40 // is 1<<31. For internal use only. 41 explicit BranchProbability(uint32_t n) : N(n) {} 42 43 public: 44 BranchProbability() : N(UnknownN) {} 45 BranchProbability(uint32_t Numerator, uint32_t Denominator); 46 47 bool isZero() const { return N == 0; } 48 bool isUnknown() const { return N == UnknownN; } 49 50 static BranchProbability getZero() { return BranchProbability(0); } 51 static BranchProbability getOne() { return BranchProbability(D); } 52 static BranchProbability getUnknown() { return BranchProbability(UnknownN); } 53 // Create a BranchProbability object with the given numerator and 1<<31 54 // as denominator. 55 static BranchProbability getRaw(uint32_t N) { return BranchProbability(N); } 56 // Create a BranchProbability object from 64-bit integers. 57 static BranchProbability getBranchProbability(uint64_t Numerator, 58 uint64_t Denominator); 59 60 // Normalize given probabilties so that the sum of them becomes approximate 61 // one. 62 template <class ProbabilityIter> 63 static void normalizeProbabilities(ProbabilityIter Begin, 64 ProbabilityIter End); 65 66 uint32_t getNumerator() const { return N; } 67 static uint32_t getDenominator() { return D; } 68 69 // Return (1 - Probability). 70 BranchProbability getCompl() const { return BranchProbability(D - N); } 71 72 raw_ostream &print(raw_ostream &OS) const; 73 74 void dump() const; 75 76 /// \brief Scale a large integer. 77 /// 78 /// Scales \c Num. Guarantees full precision. Returns the floor of the 79 /// result. 80 /// 81 /// \return \c Num times \c this. 82 uint64_t scale(uint64_t Num) const; 83 84 /// \brief Scale a large integer by the inverse. 85 /// 86 /// Scales \c Num by the inverse of \c this. Guarantees full precision. 87 /// Returns the floor of the result. 88 /// 89 /// \return \c Num divided by \c this. 90 uint64_t scaleByInverse(uint64_t Num) const; 91 92 BranchProbability &operator+=(BranchProbability RHS) { 93 assert(N != UnknownN && RHS.N != UnknownN && 94 "Unknown probability cannot participate in arithmetics."); 95 // Saturate the result in case of overflow. 96 N = (uint64_t(N) + RHS.N > D) ? D : N + RHS.N; 97 return *this; 98 } 99 100 BranchProbability &operator-=(BranchProbability RHS) { 101 assert(N != UnknownN && RHS.N != UnknownN && 102 "Unknown probability cannot participate in arithmetics."); 103 // Saturate the result in case of underflow. 104 N = N < RHS.N ? 0 : N - RHS.N; 105 return *this; 106 } 107 108 BranchProbability &operator*=(BranchProbability RHS) { 109 assert(N != UnknownN && RHS.N != UnknownN && 110 "Unknown probability cannot participate in arithmetics."); 111 N = (static_cast<uint64_t>(N) * RHS.N + D / 2) / D; 112 return *this; 113 } 114 115 BranchProbability &operator*=(uint32_t RHS) { 116 assert(N != UnknownN && 117 "Unknown probability cannot participate in arithmetics."); 118 N = (uint64_t(N) * RHS > D) ? D : N * RHS; 119 return *this; 120 } 121 122 BranchProbability &operator/=(uint32_t RHS) { 123 assert(N != UnknownN && 124 "Unknown probability cannot participate in arithmetics."); 125 assert(RHS > 0 && "The divider cannot be zero."); 126 N /= RHS; 127 return *this; 128 } 129 130 BranchProbability operator+(BranchProbability RHS) const { 131 BranchProbability Prob(*this); 132 return Prob += RHS; 133 } 134 135 BranchProbability operator-(BranchProbability RHS) const { 136 BranchProbability Prob(*this); 137 return Prob -= RHS; 138 } 139 140 BranchProbability operator*(BranchProbability RHS) const { 141 BranchProbability Prob(*this); 142 return Prob *= RHS; 143 } 144 145 BranchProbability operator*(uint32_t RHS) const { 146 BranchProbability Prob(*this); 147 return Prob *= RHS; 148 } 149 150 BranchProbability operator/(uint32_t RHS) const { 151 BranchProbability Prob(*this); 152 return Prob /= RHS; 153 } 154 155 bool operator==(BranchProbability RHS) const { return N == RHS.N; } 156 bool operator!=(BranchProbability RHS) const { return !(*this == RHS); } 157 158 bool operator<(BranchProbability RHS) const { 159 assert(N != UnknownN && RHS.N != UnknownN && 160 "Unknown probability cannot participate in comparisons."); 161 return N < RHS.N; 162 } 163 164 bool operator>(BranchProbability RHS) const { 165 assert(N != UnknownN && RHS.N != UnknownN && 166 "Unknown probability cannot participate in comparisons."); 167 return RHS < *this; 168 } 169 170 bool operator<=(BranchProbability RHS) const { 171 assert(N != UnknownN && RHS.N != UnknownN && 172 "Unknown probability cannot participate in comparisons."); 173 return !(RHS < *this); 174 } 175 176 bool operator>=(BranchProbability RHS) const { 177 assert(N != UnknownN && RHS.N != UnknownN && 178 "Unknown probability cannot participate in comparisons."); 179 return !(*this < RHS); 180 } 181 }; 182 183 inline raw_ostream &operator<<(raw_ostream &OS, BranchProbability Prob) { 184 return Prob.print(OS); 185 } 186 187 template <class ProbabilityIter> 188 void BranchProbability::normalizeProbabilities(ProbabilityIter Begin, 189 ProbabilityIter End) { 190 if (Begin == End) 191 return; 192 193 unsigned UnknownProbCount = 0; 194 uint64_t Sum = std::accumulate(Begin, End, uint64_t(0), 195 [&](uint64_t S, const BranchProbability &BP) { 196 if (!BP.isUnknown()) 197 return S + BP.N; 198 UnknownProbCount++; 199 return S; 200 }); 201 202 if (UnknownProbCount > 0) { 203 BranchProbability ProbForUnknown = BranchProbability::getZero(); 204 // If the sum of all known probabilities is less than one, evenly distribute 205 // the complement of sum to unknown probabilities. Otherwise, set unknown 206 // probabilities to zeros and continue to normalize known probabilities. 207 if (Sum < BranchProbability::getDenominator()) 208 ProbForUnknown = BranchProbability::getRaw( 209 (BranchProbability::getDenominator() - Sum) / UnknownProbCount); 210 211 std::replace_if(Begin, End, 212 [](const BranchProbability &BP) { return BP.isUnknown(); }, 213 ProbForUnknown); 214 215 if (Sum <= BranchProbability::getDenominator()) 216 return; 217 } 218 219 if (Sum == 0) { 220 BranchProbability BP(1, std::distance(Begin, End)); 221 std::fill(Begin, End, BP); 222 return; 223 } 224 225 for (auto I = Begin; I != End; ++I) 226 I->N = (I->N * uint64_t(D) + Sum / 2) / Sum; 227 } 228 229 } 230 231 #endif 232