Home | History | Annotate | Download | only in Support
      1 //===- BranchProbability.h - Branch Probability Wrapper ---------*- C++ -*-===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 // Definition of BranchProbability shared by IR and Machine Instructions.
     11 //
     12 //===----------------------------------------------------------------------===//
     13 
     14 #ifndef LLVM_SUPPORT_BRANCHPROBABILITY_H
     15 #define LLVM_SUPPORT_BRANCHPROBABILITY_H
     16 
#include "llvm/Support/DataTypes.h"
#include <algorithm>
#include <cassert>
#include <climits>
#include <iterator>
#include <numeric>
     22 
     23 namespace llvm {
     24 
     25 class raw_ostream;
     26 
     27 // This class represents Branch Probability as a non-negative fraction that is
     28 // no greater than 1. It uses a fixed-point-like implementation, in which the
     29 // denominator is always a constant value (here we use 1<<31 for maximum
     30 // precision).
     31 class BranchProbability {
     32   // Numerator
     33   uint32_t N;
     34 
     35   // Denominator, which is a constant value.
     36   static const uint32_t D = 1u << 31;
     37   static const uint32_t UnknownN = UINT32_MAX;
     38 
     39   // Construct a BranchProbability with only numerator assuming the denominator
     40   // is 1<<31. For internal use only.
     41   explicit BranchProbability(uint32_t n) : N(n) {}
     42 
     43 public:
     44   BranchProbability() : N(UnknownN) {}
     45   BranchProbability(uint32_t Numerator, uint32_t Denominator);
     46 
     47   bool isZero() const { return N == 0; }
     48   bool isUnknown() const { return N == UnknownN; }
     49 
     50   static BranchProbability getZero() { return BranchProbability(0); }
     51   static BranchProbability getOne() { return BranchProbability(D); }
     52   static BranchProbability getUnknown() { return BranchProbability(UnknownN); }
     53   // Create a BranchProbability object with the given numerator and 1<<31
     54   // as denominator.
     55   static BranchProbability getRaw(uint32_t N) { return BranchProbability(N); }
     56   // Create a BranchProbability object from 64-bit integers.
     57   static BranchProbability getBranchProbability(uint64_t Numerator,
     58                                                 uint64_t Denominator);
     59 
     60   // Normalize given probabilties so that the sum of them becomes approximate
     61   // one.
     62   template <class ProbabilityIter>
     63   static void normalizeProbabilities(ProbabilityIter Begin,
     64                                      ProbabilityIter End);
     65 
     66   uint32_t getNumerator() const { return N; }
     67   static uint32_t getDenominator() { return D; }
     68 
     69   // Return (1 - Probability).
     70   BranchProbability getCompl() const { return BranchProbability(D - N); }
     71 
     72   raw_ostream &print(raw_ostream &OS) const;
     73 
     74   void dump() const;
     75 
     76   /// \brief Scale a large integer.
     77   ///
     78   /// Scales \c Num.  Guarantees full precision.  Returns the floor of the
     79   /// result.
     80   ///
     81   /// \return \c Num times \c this.
     82   uint64_t scale(uint64_t Num) const;
     83 
     84   /// \brief Scale a large integer by the inverse.
     85   ///
     86   /// Scales \c Num by the inverse of \c this.  Guarantees full precision.
     87   /// Returns the floor of the result.
     88   ///
     89   /// \return \c Num divided by \c this.
     90   uint64_t scaleByInverse(uint64_t Num) const;
     91 
     92   BranchProbability &operator+=(BranchProbability RHS) {
     93     assert(N != UnknownN && RHS.N != UnknownN &&
     94            "Unknown probability cannot participate in arithmetics.");
     95     // Saturate the result in case of overflow.
     96     N = (uint64_t(N) + RHS.N > D) ? D : N + RHS.N;
     97     return *this;
     98   }
     99 
    100   BranchProbability &operator-=(BranchProbability RHS) {
    101     assert(N != UnknownN && RHS.N != UnknownN &&
    102            "Unknown probability cannot participate in arithmetics.");
    103     // Saturate the result in case of underflow.
    104     N = N < RHS.N ? 0 : N - RHS.N;
    105     return *this;
    106   }
    107 
    108   BranchProbability &operator*=(BranchProbability RHS) {
    109     assert(N != UnknownN && RHS.N != UnknownN &&
    110            "Unknown probability cannot participate in arithmetics.");
    111     N = (static_cast<uint64_t>(N) * RHS.N + D / 2) / D;
    112     return *this;
    113   }
    114 
    115   BranchProbability &operator*=(uint32_t RHS) {
    116     assert(N != UnknownN &&
    117            "Unknown probability cannot participate in arithmetics.");
    118     N = (uint64_t(N) * RHS > D) ? D : N * RHS;
    119     return *this;
    120   }
    121 
    122   BranchProbability &operator/=(uint32_t RHS) {
    123     assert(N != UnknownN &&
    124            "Unknown probability cannot participate in arithmetics.");
    125     assert(RHS > 0 && "The divider cannot be zero.");
    126     N /= RHS;
    127     return *this;
    128   }
    129 
    130   BranchProbability operator+(BranchProbability RHS) const {
    131     BranchProbability Prob(*this);
    132     return Prob += RHS;
    133   }
    134 
    135   BranchProbability operator-(BranchProbability RHS) const {
    136     BranchProbability Prob(*this);
    137     return Prob -= RHS;
    138   }
    139 
    140   BranchProbability operator*(BranchProbability RHS) const {
    141     BranchProbability Prob(*this);
    142     return Prob *= RHS;
    143   }
    144 
    145   BranchProbability operator*(uint32_t RHS) const {
    146     BranchProbability Prob(*this);
    147     return Prob *= RHS;
    148   }
    149 
    150   BranchProbability operator/(uint32_t RHS) const {
    151     BranchProbability Prob(*this);
    152     return Prob /= RHS;
    153   }
    154 
    155   bool operator==(BranchProbability RHS) const { return N == RHS.N; }
    156   bool operator!=(BranchProbability RHS) const { return !(*this == RHS); }
    157 
    158   bool operator<(BranchProbability RHS) const {
    159     assert(N != UnknownN && RHS.N != UnknownN &&
    160            "Unknown probability cannot participate in comparisons.");
    161     return N < RHS.N;
    162   }
    163 
    164   bool operator>(BranchProbability RHS) const {
    165     assert(N != UnknownN && RHS.N != UnknownN &&
    166            "Unknown probability cannot participate in comparisons.");
    167     return RHS < *this;
    168   }
    169 
    170   bool operator<=(BranchProbability RHS) const {
    171     assert(N != UnknownN && RHS.N != UnknownN &&
    172            "Unknown probability cannot participate in comparisons.");
    173     return !(RHS < *this);
    174   }
    175 
    176   bool operator>=(BranchProbability RHS) const {
    177     assert(N != UnknownN && RHS.N != UnknownN &&
    178            "Unknown probability cannot participate in comparisons.");
    179     return !(*this < RHS);
    180   }
    181 };
    182 
    183 inline raw_ostream &operator<<(raw_ostream &OS, BranchProbability Prob) {
    184   return Prob.print(OS);
    185 }
    186 
    187 template <class ProbabilityIter>
    188 void BranchProbability::normalizeProbabilities(ProbabilityIter Begin,
    189                                                ProbabilityIter End) {
    190   if (Begin == End)
    191     return;
    192 
    193   unsigned UnknownProbCount = 0;
    194   uint64_t Sum = std::accumulate(Begin, End, uint64_t(0),
    195                                  [&](uint64_t S, const BranchProbability &BP) {
    196                                    if (!BP.isUnknown())
    197                                      return S + BP.N;
    198                                    UnknownProbCount++;
    199                                    return S;
    200                                  });
    201 
    202   if (UnknownProbCount > 0) {
    203     BranchProbability ProbForUnknown = BranchProbability::getZero();
    204     // If the sum of all known probabilities is less than one, evenly distribute
    205     // the complement of sum to unknown probabilities. Otherwise, set unknown
    206     // probabilities to zeros and continue to normalize known probabilities.
    207     if (Sum < BranchProbability::getDenominator())
    208       ProbForUnknown = BranchProbability::getRaw(
    209           (BranchProbability::getDenominator() - Sum) / UnknownProbCount);
    210 
    211     std::replace_if(Begin, End,
    212                     [](const BranchProbability &BP) { return BP.isUnknown(); },
    213                     ProbForUnknown);
    214 
    215     if (Sum <= BranchProbability::getDenominator())
    216       return;
    217   }
    218 
    219   if (Sum == 0) {
    220     BranchProbability BP(1, std::distance(Begin, End));
    221     std::fill(Begin, End, BP);
    222     return;
    223   }
    224 
    225   for (auto I = Begin; I != End; ++I)
    226     I->N = (I->N * uint64_t(D) + Sum / 2) / Sum;
    227 }
    228 
    229 }
    230 
    231 #endif
    232