Home | History | Annotate | Download | only in Utils
      1 #define DEBUG_TYPE "lower-expect-intrinsic"
      2 #include "llvm/Constants.h"
      3 #include "llvm/Function.h"
      4 #include "llvm/BasicBlock.h"
      5 #include "llvm/LLVMContext.h"
      6 #include "llvm/Instructions.h"
      7 #include "llvm/Intrinsics.h"
      8 #include "llvm/Metadata.h"
      9 #include "llvm/Pass.h"
     10 #include "llvm/Transforms/Scalar.h"
     11 #include "llvm/Support/CommandLine.h"
     12 #include "llvm/Support/Debug.h"
     13 #include "llvm/ADT/Statistic.h"
     14 #include <vector>
     15 
     16 using namespace llvm;
     17 
     18 STATISTIC(IfHandled, "Number of 'expect' intrinsic intructions handled");
     19 
     20 static cl::opt<uint32_t>
     21 LikelyBranchWeight("likely-branch-weight", cl::Hidden, cl::init(64),
     22                    cl::desc("Weight of the branch likely to be taken (default = 64)"));
     23 static cl::opt<uint32_t>
     24 UnlikelyBranchWeight("unlikely-branch-weight", cl::Hidden, cl::init(4),
     25                    cl::desc("Weight of the branch unlikely to be taken (default = 4)"));
     26 
     27 namespace {
     28 
     29   class LowerExpectIntrinsic : public FunctionPass {
     30 
     31     bool HandleSwitchExpect(SwitchInst *SI);
     32 
     33     bool HandleIfExpect(BranchInst *BI);
     34 
     35   public:
     36     static char ID;
     37     LowerExpectIntrinsic() : FunctionPass(ID) {
     38       initializeLowerExpectIntrinsicPass(*PassRegistry::getPassRegistry());
     39     }
     40 
     41     bool runOnFunction(Function &F);
     42   };
     43 }
     44 
     45 
     46 bool LowerExpectIntrinsic::HandleSwitchExpect(SwitchInst *SI) {
     47   CallInst *CI = dyn_cast<CallInst>(SI->getCondition());
     48   if (!CI)
     49     return false;
     50 
     51   Function *Fn = CI->getCalledFunction();
     52   if (!Fn || Fn->getIntrinsicID() != Intrinsic::expect)
     53     return false;
     54 
     55   Value *ArgValue = CI->getArgOperand(0);
     56   ConstantInt *ExpectedValue = dyn_cast<ConstantInt>(CI->getArgOperand(1));
     57   if (!ExpectedValue)
     58     return false;
     59 
     60   LLVMContext &Context = CI->getContext();
     61   Type *Int32Ty = Type::getInt32Ty(Context);
     62 
     63   unsigned caseNo = SI->findCaseValue(ExpectedValue);
     64   std::vector<Value *> Vec;
     65   unsigned n = SI->getNumCases();
     66   Vec.resize(n + 1); // +1 for MDString
     67 
     68   Vec[0] = MDString::get(Context, "branch_weights");
     69   for (unsigned i = 0; i < n; ++i) {
     70     Vec[i + 1] = ConstantInt::get(Int32Ty, i == caseNo ? LikelyBranchWeight : UnlikelyBranchWeight);
     71   }
     72 
     73   MDNode *WeightsNode = llvm::MDNode::get(Context, Vec);
     74   SI->setMetadata(LLVMContext::MD_prof, WeightsNode);
     75 
     76   SI->setCondition(ArgValue);
     77   return true;
     78 }
     79 
     80 
     81 bool LowerExpectIntrinsic::HandleIfExpect(BranchInst *BI) {
     82   if (BI->isUnconditional())
     83     return false;
     84 
     85   // Handle non-optimized IR code like:
     86   //   %expval = call i64 @llvm.expect.i64.i64(i64 %conv1, i64 1)
     87   //   %tobool = icmp ne i64 %expval, 0
     88   //   br i1 %tobool, label %if.then, label %if.end
     89 
     90   ICmpInst *CmpI = dyn_cast<ICmpInst>(BI->getCondition());
     91   if (!CmpI || CmpI->getPredicate() != CmpInst::ICMP_NE)
     92     return false;
     93 
     94   CallInst *CI = dyn_cast<CallInst>(CmpI->getOperand(0));
     95   if (!CI)
     96     return false;
     97 
     98   Function *Fn = CI->getCalledFunction();
     99   if (!Fn || Fn->getIntrinsicID() != Intrinsic::expect)
    100     return false;
    101 
    102   Value *ArgValue = CI->getArgOperand(0);
    103   ConstantInt *ExpectedValue = dyn_cast<ConstantInt>(CI->getArgOperand(1));
    104   if (!ExpectedValue)
    105     return false;
    106 
    107   LLVMContext &Context = CI->getContext();
    108   Type *Int32Ty = Type::getInt32Ty(Context);
    109   bool Likely = ExpectedValue->isOne();
    110 
    111   // If expect value is equal to 1 it means that we are more likely to take
    112   // branch 0, in other case more likely is branch 1.
    113   Value *Ops[] = {
    114     MDString::get(Context, "branch_weights"),
    115     ConstantInt::get(Int32Ty, Likely ? LikelyBranchWeight : UnlikelyBranchWeight),
    116     ConstantInt::get(Int32Ty, Likely ? UnlikelyBranchWeight : LikelyBranchWeight)
    117   };
    118 
    119   MDNode *WeightsNode = MDNode::get(Context, Ops);
    120   BI->setMetadata(LLVMContext::MD_prof, WeightsNode);
    121 
    122   CmpI->setOperand(0, ArgValue);
    123   return true;
    124 }
    125 
    126 
    127 bool LowerExpectIntrinsic::runOnFunction(Function &F) {
    128   for (Function::iterator I = F.begin(), E = F.end(); I != E;) {
    129     BasicBlock *BB = I++;
    130 
    131     // Create "block_weights" metadata.
    132     if (BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator())) {
    133       if (HandleIfExpect(BI))
    134         IfHandled++;
    135     } else if (SwitchInst *SI = dyn_cast<SwitchInst>(BB->getTerminator())) {
    136       if (HandleSwitchExpect(SI))
    137         IfHandled++;
    138     }
    139 
    140     // remove llvm.expect intrinsics.
    141     for (BasicBlock::iterator BI = BB->begin(), BE = BB->end();
    142          BI != BE; ) {
    143       CallInst *CI = dyn_cast<CallInst>(BI++);
    144       if (!CI)
    145         continue;
    146 
    147       Function *Fn = CI->getCalledFunction();
    148       if (Fn && Fn->getIntrinsicID() == Intrinsic::expect) {
    149         Value *Exp = CI->getArgOperand(0);
    150         CI->replaceAllUsesWith(Exp);
    151         CI->eraseFromParent();
    152       }
    153     }
    154   }
    155 
    156   return false;
    157 }
    158 
    159 
    160 char LowerExpectIntrinsic::ID = 0;
    161 INITIALIZE_PASS(LowerExpectIntrinsic, "lower-expect", "Lower 'expect' "
    162                 "Intrinsics", false, false)
    163 
    164 FunctionPass *llvm::createLowerExpectIntrinsicPass() {
    165   return new LowerExpectIntrinsic();
    166 }
    167