Home | History | Annotate | Download | only in IPO
      1 //===-- CrossDSOCFI.cpp - Externalize this module's CFI checks ------------===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 // This pass exports all llvm.bitset's found in the module in the form of a
     11 // __cfi_check function, which can be used to verify cross-DSO call targets.
     12 //
     13 //===----------------------------------------------------------------------===//
     14 
     15 #include "llvm/Transforms/IPO.h"
     16 #include "llvm/ADT/DenseSet.h"
     17 #include "llvm/ADT/EquivalenceClasses.h"
     18 #include "llvm/ADT/Statistic.h"
     19 #include "llvm/IR/Constant.h"
     20 #include "llvm/IR/Constants.h"
     21 #include "llvm/IR/Function.h"
     22 #include "llvm/IR/GlobalObject.h"
     23 #include "llvm/IR/GlobalVariable.h"
     24 #include "llvm/IR/IRBuilder.h"
     25 #include "llvm/IR/Instructions.h"
     26 #include "llvm/IR/Intrinsics.h"
     27 #include "llvm/IR/MDBuilder.h"
     28 #include "llvm/IR/Module.h"
     29 #include "llvm/IR/Operator.h"
     30 #include "llvm/Pass.h"
     31 #include "llvm/Support/Debug.h"
     32 #include "llvm/Support/raw_ostream.h"
     33 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
     34 
     35 using namespace llvm;
     36 
     37 #define DEBUG_TYPE "cross-dso-cfi"
     38 
     39 STATISTIC(TypeIds, "Number of unique type identifiers");
     40 
     41 namespace {
     42 
     43 struct CrossDSOCFI : public ModulePass {
     44   static char ID;
     45   CrossDSOCFI() : ModulePass(ID) {
     46     initializeCrossDSOCFIPass(*PassRegistry::getPassRegistry());
     47   }
     48 
     49   Module *M;
     50   MDNode *VeryLikelyWeights;
     51 
     52   ConstantInt *extractBitSetTypeId(MDNode *MD);
     53   void buildCFICheck();
     54 
     55   bool doInitialization(Module &M) override;
     56   bool runOnModule(Module &M) override;
     57 };
     58 
     59 } // anonymous namespace
     60 
     61 INITIALIZE_PASS_BEGIN(CrossDSOCFI, "cross-dso-cfi", "Cross-DSO CFI", false,
     62                       false)
     63 INITIALIZE_PASS_END(CrossDSOCFI, "cross-dso-cfi", "Cross-DSO CFI", false, false)
     64 char CrossDSOCFI::ID = 0;
     65 
     66 ModulePass *llvm::createCrossDSOCFIPass() { return new CrossDSOCFI; }
     67 
     68 bool CrossDSOCFI::doInitialization(Module &Mod) {
     69   M = &Mod;
     70   VeryLikelyWeights =
     71       MDBuilder(M->getContext()).createBranchWeights((1U << 20) - 1, 1);
     72 
     73   return false;
     74 }
     75 
     76 /// extractBitSetTypeId - Extracts TypeId from a hash-based bitset MDNode.
     77 ConstantInt *CrossDSOCFI::extractBitSetTypeId(MDNode *MD) {
     78   // This check excludes vtables for classes inside anonymous namespaces.
     79   auto TM = dyn_cast<ValueAsMetadata>(MD->getOperand(0));
     80   if (!TM)
     81     return nullptr;
     82   auto C = dyn_cast_or_null<ConstantInt>(TM->getValue());
     83   if (!C) return nullptr;
     84   // We are looking for i64 constants.
     85   if (C->getBitWidth() != 64) return nullptr;
     86 
     87   // Sanity check.
     88   auto FM = dyn_cast_or_null<ValueAsMetadata>(MD->getOperand(1));
     89   // Can be null if a function was removed by an optimization.
     90   if (FM) {
     91     auto F = dyn_cast<Function>(FM->getValue());
     92     // But can never be a function declaration.
     93     assert(!F || !F->isDeclaration());
     94     (void)F; // Suppress unused variable warning in the no-asserts build.
     95   }
     96   return C;
     97 }
     98 
     99 /// buildCFICheck - emits __cfi_check for the current module.
    100 void CrossDSOCFI::buildCFICheck() {
    101   // FIXME: verify that __cfi_check ends up near the end of the code section,
    102   // but before the jump slots created in LowerBitSets.
    103   llvm::DenseSet<uint64_t> BitSetIds;
    104   NamedMDNode *BitSetNM = M->getNamedMetadata("llvm.bitsets");
    105 
    106   if (BitSetNM)
    107     for (unsigned I = 0, E = BitSetNM->getNumOperands(); I != E; ++I)
    108       if (ConstantInt *TypeId = extractBitSetTypeId(BitSetNM->getOperand(I)))
    109         BitSetIds.insert(TypeId->getZExtValue());
    110 
    111   LLVMContext &Ctx = M->getContext();
    112   Constant *C = M->getOrInsertFunction(
    113       "__cfi_check",
    114       FunctionType::get(
    115           Type::getVoidTy(Ctx),
    116           {Type::getInt64Ty(Ctx), PointerType::getUnqual(Type::getInt8Ty(Ctx))},
    117           false));
    118   Function *F = dyn_cast<Function>(C);
    119   F->setAlignment(4096);
    120   auto args = F->arg_begin();
    121   Argument &CallSiteTypeId = *(args++);
    122   CallSiteTypeId.setName("CallSiteTypeId");
    123   Argument &Addr = *(args++);
    124   Addr.setName("Addr");
    125   assert(args == F->arg_end());
    126 
    127   BasicBlock *BB = BasicBlock::Create(Ctx, "entry", F);
    128 
    129   BasicBlock *TrapBB = BasicBlock::Create(Ctx, "trap", F);
    130   IRBuilder<> IRBTrap(TrapBB);
    131   Function *TrapFn = Intrinsic::getDeclaration(M, Intrinsic::trap);
    132   llvm::CallInst *TrapCall = IRBTrap.CreateCall(TrapFn);
    133   TrapCall->setDoesNotReturn();
    134   TrapCall->setDoesNotThrow();
    135   IRBTrap.CreateUnreachable();
    136 
    137   BasicBlock *ExitBB = BasicBlock::Create(Ctx, "exit", F);
    138   IRBuilder<> IRBExit(ExitBB);
    139   IRBExit.CreateRetVoid();
    140 
    141   IRBuilder<> IRB(BB);
    142   SwitchInst *SI = IRB.CreateSwitch(&CallSiteTypeId, TrapBB, BitSetIds.size());
    143   for (uint64_t TypeId : BitSetIds) {
    144     ConstantInt *CaseTypeId = ConstantInt::get(Type::getInt64Ty(Ctx), TypeId);
    145     BasicBlock *TestBB = BasicBlock::Create(Ctx, "test", F);
    146     IRBuilder<> IRBTest(TestBB);
    147     Function *BitsetTestFn =
    148         Intrinsic::getDeclaration(M, Intrinsic::bitset_test);
    149 
    150     Value *Test = IRBTest.CreateCall(
    151         BitsetTestFn, {&Addr, MetadataAsValue::get(
    152                                   Ctx, ConstantAsMetadata::get(CaseTypeId))});
    153     BranchInst *BI = IRBTest.CreateCondBr(Test, ExitBB, TrapBB);
    154     BI->setMetadata(LLVMContext::MD_prof, VeryLikelyWeights);
    155 
    156     SI->addCase(CaseTypeId, TestBB);
    157     ++TypeIds;
    158   }
    159 }
    160 
    161 bool CrossDSOCFI::runOnModule(Module &M) {
    162   if (M.getModuleFlag("Cross-DSO CFI") == nullptr)
    163     return false;
    164   buildCFICheck();
    165   return true;
    166 }
    167