Home | History | Annotate | Download | only in Support
      1 //===- BranchProbability.h - Branch Probability Wrapper ---------*- C++ -*-===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 // Definition of BranchProbability shared by IR and Machine Instructions.
     11 //
     12 //===----------------------------------------------------------------------===//
     13 
     14 #ifndef LLVM_SUPPORT_BRANCHPROBABILITY_H
     15 #define LLVM_SUPPORT_BRANCHPROBABILITY_H
     16 
     17 #include "llvm/Support/DataTypes.h"
     18 #include <algorithm>
     19 #include <cassert>
     20 #include <climits>
     21 #include <numeric>
     22 
     23 namespace llvm {
     24 
     25 class raw_ostream;
     26 
     27 // This class represents Branch Probability as a non-negative fraction that is
     28 // no greater than 1. It uses a fixed-point-like implementation, in which the
     29 // denominator is always a constant value (here we use 1<<31 for maximum
     30 // precision).
     31 class BranchProbability {
     32   // Numerator
     33   uint32_t N;
     34 
     35   // Denominator, which is a constant value.
     36   static const uint32_t D = 1u << 31;
     37   static const uint32_t UnknownN = UINT32_MAX;
     38 
     39   // Construct a BranchProbability with only numerator assuming the denominator
     40   // is 1<<31. For internal use only.
     41   explicit BranchProbability(uint32_t n) : N(n) {}
     42 
     43 public:
     44   BranchProbability() : N(UnknownN) {}
     45   BranchProbability(uint32_t Numerator, uint32_t Denominator);
     46 
     47   bool isZero() const { return N == 0; }
     48   bool isUnknown() const { return N == UnknownN; }
     49 
     50   static BranchProbability getZero() { return BranchProbability(0); }
     51   static BranchProbability getOne() { return BranchProbability(D); }
     52   static BranchProbability getUnknown() { return BranchProbability(UnknownN); }
     53   // Create a BranchProbability object with the given numerator and 1<<31
     54   // as denominator.
     55   static BranchProbability getRaw(uint32_t N) { return BranchProbability(N); }
     56   // Create a BranchProbability object from 64-bit integers.
     57   static BranchProbability getBranchProbability(uint64_t Numerator,
     58                                                 uint64_t Denominator);
     59 
     60   // Normalize given probabilties so that the sum of them becomes approximate
     61   // one.
     62   template <class ProbabilityIter>
     63   static void normalizeProbabilities(ProbabilityIter Begin,
     64                                      ProbabilityIter End);
     65 
     66   // Normalize a list of weights by scaling them down so that the sum of them
     67   // doesn't exceed UINT32_MAX.
     68   template <class WeightListIter>
     69   static void normalizeEdgeWeights(WeightListIter Begin, WeightListIter End);
     70 
     71   uint32_t getNumerator() const { return N; }
     72   static uint32_t getDenominator() { return D; }
     73 
     74   // Return (1 - Probability).
     75   BranchProbability getCompl() const { return BranchProbability(D - N); }
     76 
     77   raw_ostream &print(raw_ostream &OS) const;
     78 
     79   void dump() const;
     80 
     81   /// \brief Scale a large integer.
     82   ///
     83   /// Scales \c Num.  Guarantees full precision.  Returns the floor of the
     84   /// result.
     85   ///
     86   /// \return \c Num times \c this.
     87   uint64_t scale(uint64_t Num) const;
     88 
     89   /// \brief Scale a large integer by the inverse.
     90   ///
     91   /// Scales \c Num by the inverse of \c this.  Guarantees full precision.
     92   /// Returns the floor of the result.
     93   ///
     94   /// \return \c Num divided by \c this.
     95   uint64_t scaleByInverse(uint64_t Num) const;
     96 
     97   BranchProbability &operator+=(BranchProbability RHS) {
     98     assert(N != UnknownN && RHS.N != UnknownN &&
     99            "Unknown probability cannot participate in arithmetics.");
    100     // Saturate the result in case of overflow.
    101     N = (uint64_t(N) + RHS.N > D) ? D : N + RHS.N;
    102     return *this;
    103   }
    104 
    105   BranchProbability &operator-=(BranchProbability RHS) {
    106     assert(N != UnknownN && RHS.N != UnknownN &&
    107            "Unknown probability cannot participate in arithmetics.");
    108     // Saturate the result in case of underflow.
    109     N = N < RHS.N ? 0 : N - RHS.N;
    110     return *this;
    111   }
    112 
    113   BranchProbability &operator*=(BranchProbability RHS) {
    114     assert(N != UnknownN && RHS.N != UnknownN &&
    115            "Unknown probability cannot participate in arithmetics.");
    116     N = (static_cast<uint64_t>(N) * RHS.N + D / 2) / D;
    117     return *this;
    118   }
    119 
    120   BranchProbability &operator/=(uint32_t RHS) {
    121     assert(N != UnknownN &&
    122            "Unknown probability cannot participate in arithmetics.");
    123     assert(RHS > 0 && "The divider cannot be zero.");
    124     N /= RHS;
    125     return *this;
    126   }
    127 
    128   BranchProbability operator+(BranchProbability RHS) const {
    129     BranchProbability Prob(*this);
    130     return Prob += RHS;
    131   }
    132 
    133   BranchProbability operator-(BranchProbability RHS) const {
    134     BranchProbability Prob(*this);
    135     return Prob -= RHS;
    136   }
    137 
    138   BranchProbability operator*(BranchProbability RHS) const {
    139     BranchProbability Prob(*this);
    140     return Prob *= RHS;
    141   }
    142 
    143   BranchProbability operator/(uint32_t RHS) const {
    144     BranchProbability Prob(*this);
    145     return Prob /= RHS;
    146   }
    147 
    148   bool operator==(BranchProbability RHS) const { return N == RHS.N; }
    149   bool operator!=(BranchProbability RHS) const { return !(*this == RHS); }
    150 
    151   bool operator<(BranchProbability RHS) const {
    152     assert(N != UnknownN && RHS.N != UnknownN &&
    153            "Unknown probability cannot participate in comparisons.");
    154     return N < RHS.N;
    155   }
    156 
    157   bool operator>(BranchProbability RHS) const {
    158     assert(N != UnknownN && RHS.N != UnknownN &&
    159            "Unknown probability cannot participate in comparisons.");
    160     return RHS < *this;
    161   }
    162 
    163   bool operator<=(BranchProbability RHS) const {
    164     assert(N != UnknownN && RHS.N != UnknownN &&
    165            "Unknown probability cannot participate in comparisons.");
    166     return !(RHS < *this);
    167   }
    168 
    169   bool operator>=(BranchProbability RHS) const {
    170     assert(N != UnknownN && RHS.N != UnknownN &&
    171            "Unknown probability cannot participate in comparisons.");
    172     return !(*this < RHS);
    173   }
    174 };
    175 
    176 inline raw_ostream &operator<<(raw_ostream &OS, BranchProbability Prob) {
    177   return Prob.print(OS);
    178 }
    179 
    180 template <class ProbabilityIter>
    181 void BranchProbability::normalizeProbabilities(ProbabilityIter Begin,
    182                                                ProbabilityIter End) {
    183   if (Begin == End)
    184     return;
    185 
    186   unsigned UnknownProbCount = 0;
    187   uint64_t Sum = std::accumulate(Begin, End, uint64_t(0),
    188                                  [&](uint64_t S, const BranchProbability &BP) {
    189                                    if (!BP.isUnknown())
    190                                      return S + BP.N;
    191                                    UnknownProbCount++;
    192                                    return S;
    193                                  });
    194 
    195   if (UnknownProbCount > 0) {
    196     BranchProbability ProbForUnknown = BranchProbability::getZero();
    197     // If the sum of all known probabilities is less than one, evenly distribute
    198     // the complement of sum to unknown probabilities. Otherwise, set unknown
    199     // probabilities to zeros and continue to normalize known probabilities.
    200     if (Sum < BranchProbability::getDenominator())
    201       ProbForUnknown = BranchProbability::getRaw(
    202           (BranchProbability::getDenominator() - Sum) / UnknownProbCount);
    203 
    204     std::replace_if(Begin, End,
    205                     [](const BranchProbability &BP) { return BP.isUnknown(); },
    206                     ProbForUnknown);
    207 
    208     if (Sum <= BranchProbability::getDenominator())
    209       return;
    210   }
    211 
    212   if (Sum == 0) {
    213     BranchProbability BP(1, std::distance(Begin, End));
    214     std::fill(Begin, End, BP);
    215     return;
    216   }
    217 
    218   for (auto I = Begin; I != End; ++I)
    219     I->N = (I->N * uint64_t(D) + Sum / 2) / Sum;
    220 }
    221 
    222 template <class WeightListIter>
    223 void BranchProbability::normalizeEdgeWeights(WeightListIter Begin,
    224                                              WeightListIter End) {
    225   // First we compute the sum with 64-bits of precision.
    226   uint64_t Sum = std::accumulate(Begin, End, uint64_t(0));
    227 
    228   if (Sum > UINT32_MAX) {
    229     // Compute the scale necessary to cause the weights to fit, and re-sum with
    230     // that scale applied.
    231     assert(Sum / UINT32_MAX < UINT32_MAX &&
    232            "The sum of weights exceeds UINT32_MAX^2!");
    233     uint32_t Scale = Sum / UINT32_MAX + 1;
    234     for (auto I = Begin; I != End; ++I)
    235       *I /= Scale;
    236     Sum = std::accumulate(Begin, End, uint64_t(0));
    237   }
    238 
    239   // Eliminate zero weights.
    240   auto ZeroWeightNum = std::count(Begin, End, 0u);
    241   if (ZeroWeightNum > 0) {
    242     // If all weights are zeros, replace them by 1.
    243     if (Sum == 0)
    244       std::fill(Begin, End, 1u);
    245     else {
    246       // We are converting zeros into ones, and here we need to make sure that
    247       // after this the sum won't exceed UINT32_MAX.
    248       if (Sum + ZeroWeightNum > UINT32_MAX) {
    249         for (auto I = Begin; I != End; ++I)
    250           *I /= 2;
    251         ZeroWeightNum = std::count(Begin, End, 0u);
    252         Sum = std::accumulate(Begin, End, uint64_t(0));
    253       }
    254       // Scale up non-zero weights and turn zero weights into ones.
    255       uint64_t ScalingFactor = (UINT32_MAX - ZeroWeightNum) / Sum;
    256       assert(ScalingFactor >= 1);
    257       if (ScalingFactor > 1)
    258         for (auto I = Begin; I != End; ++I)
    259           *I *= ScalingFactor;
    260       std::replace(Begin, End, 0u, 1u);
    261     }
    262   }
    263 }
    264 
    265 }
    266 
    267 #endif
    268