Home | History | Annotate | Download | only in CodeGen
      1 //===-- JumpInstrTables.cpp: Jump-Instruction Tables ----------------------===//
      2 //
      3 // This file is distributed under the University of Illinois Open Source
      4 // License. See LICENSE.TXT for details.
      5 //
      6 //===----------------------------------------------------------------------===//
      7 ///
      8 /// \file
      9 /// \brief An implementation of jump-instruction tables.
     10 ///
     11 //===----------------------------------------------------------------------===//
     12 
     13 #define DEBUG_TYPE "jt"
     14 
     15 #include "llvm/CodeGen/JumpInstrTables.h"
     16 
     17 #include "llvm/ADT/Statistic.h"
     18 #include "llvm/Analysis/JumpInstrTableInfo.h"
     19 #include "llvm/CodeGen/Passes.h"
     20 #include "llvm/IR/Attributes.h"
     21 #include "llvm/IR/CallSite.h"
     22 #include "llvm/IR/Constants.h"
     23 #include "llvm/IR/DerivedTypes.h"
     24 #include "llvm/IR/Function.h"
     25 #include "llvm/IR/LLVMContext.h"
     26 #include "llvm/IR/Module.h"
     27 #include "llvm/IR/Operator.h"
     28 #include "llvm/IR/Type.h"
     29 #include "llvm/IR/Verifier.h"
     30 #include "llvm/Support/CommandLine.h"
     31 #include "llvm/Support/Debug.h"
     32 #include "llvm/Support/raw_ostream.h"
     33 
     34 #include <vector>
     35 
     36 using namespace llvm;
     37 
     38 char JumpInstrTables::ID = 0;
     39 
     40 INITIALIZE_PASS_BEGIN(JumpInstrTables, "jump-instr-tables",
     41                       "Jump-Instruction Tables", true, true)
     42 INITIALIZE_PASS_DEPENDENCY(JumpInstrTableInfo);
     43 INITIALIZE_PASS_END(JumpInstrTables, "jump-instr-tables",
     44                     "Jump-Instruction Tables", true, true)
     45 
     46 STATISTIC(NumJumpTables, "Number of indirect call tables generated");
     47 STATISTIC(NumFuncsInJumpTables, "Number of functions in the jump tables");
     48 
     49 ModulePass *llvm::createJumpInstrTablesPass() {
     50   // The default implementation uses a single table for all functions.
     51   return new JumpInstrTables(JumpTable::Single);
     52 }
     53 
     54 ModulePass *llvm::createJumpInstrTablesPass(JumpTable::JumpTableType JTT) {
     55   return new JumpInstrTables(JTT);
     56 }
     57 
     58 namespace {
     59 static const char jump_func_prefix[] = "__llvm_jump_instr_table_";
     60 static const char jump_section_prefix[] = ".jump.instr.table.text.";
     61 
     62 // Checks to see if a given CallSite is making an indirect call, including
     63 // cases where the indirect call is made through a bitcast.
     64 bool isIndirectCall(CallSite &CS) {
     65   if (CS.getCalledFunction())
     66     return false;
     67 
     68   // Check the value to see if it is merely a bitcast of a function. In
     69   // this case, it will translate to a direct function call in the resulting
     70   // assembly, so we won't treat it as an indirect call here.
     71   const Value *V = CS.getCalledValue();
     72   if (const ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) {
     73     return !(CE->isCast() && isa<Function>(CE->getOperand(0)));
     74   }
     75 
     76   // Otherwise, since we know it's a call, it must be an indirect call
     77   return true;
     78 }
     79 
     80 // Replaces Functions and GlobalAliases with a different Value.
     81 bool replaceGlobalValueIndirectUse(GlobalValue *GV, Value *V, Use *U) {
     82   User *Us = U->getUser();
     83   if (!Us)
     84     return false;
     85   if (Instruction *I = dyn_cast<Instruction>(Us)) {
     86     CallSite CS(I);
     87 
     88     // Don't do the replacement if this use is a direct call to this function.
     89     // If the use is not the called value, then replace it.
     90     if (CS && (isIndirectCall(CS) || CS.isCallee(U))) {
     91       return false;
     92     }
     93 
     94     U->set(V);
     95   } else if (Constant *C = dyn_cast<Constant>(Us)) {
     96     // Don't replace calls to bitcasts of function symbols, since they get
     97     // translated to direct calls.
     98     if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Us)) {
     99       if (CE->getOpcode() == Instruction::BitCast) {
    100         // This bitcast must have exactly one user.
    101         if (CE->user_begin() != CE->user_end()) {
    102           User *ParentUs = *CE->user_begin();
    103           if (CallInst *CI = dyn_cast<CallInst>(ParentUs)) {
    104             CallSite CS(CI);
    105             Use &CEU = *CE->use_begin();
    106             if (CS.isCallee(&CEU)) {
    107               return false;
    108             }
    109           }
    110         }
    111       }
    112     }
    113 
    114     // GlobalAlias doesn't support replaceUsesOfWithOnConstant. And the verifier
    115     // requires alias to point to a defined function. So, GlobalAlias is handled
    116     // as a separate case in runOnModule.
    117     if (!isa<GlobalAlias>(C))
    118       C->replaceUsesOfWithOnConstant(GV, V, U);
    119   } else {
    120     assert(false && "The Use of a Function symbol is neither an instruction nor"
    121                     " a constant");
    122   }
    123 
    124   return true;
    125 }
    126 
    127 // Replaces all replaceable address-taken uses of GV with a pointer to a
    128 // jump-instruction table entry.
    129 void replaceValueWithFunction(GlobalValue *GV, Function *F) {
    130   // Go through all uses of this function and replace the uses of GV with the
    131   // jump-table version of the function. Get the uses as a vector before
    132   // replacing them, since replacing them changes the use list and invalidates
    133   // the iterator otherwise.
    134   for (Value::use_iterator I = GV->use_begin(), E = GV->use_end(); I != E;) {
    135     Use &U = *I++;
    136 
    137     // Replacement of constants replaces all instances in the constant. So, some
    138     // uses might have already been handled by the time we reach them here.
    139     if (U.get() == GV)
    140       replaceGlobalValueIndirectUse(GV, F, &U);
    141   }
    142 
    143   return;
    144 }
    145 } // end anonymous namespace
    146 
    147 JumpInstrTables::JumpInstrTables()
    148     : ModulePass(ID), Metadata(), JITI(nullptr), TableCount(0),
    149       JTType(JumpTable::Single) {
    150   initializeJumpInstrTablesPass(*PassRegistry::getPassRegistry());
    151 }
    152 
    153 JumpInstrTables::JumpInstrTables(JumpTable::JumpTableType JTT)
    154     : ModulePass(ID), Metadata(), JITI(nullptr), TableCount(0), JTType(JTT) {
    155   initializeJumpInstrTablesPass(*PassRegistry::getPassRegistry());
    156 }
    157 
    158 JumpInstrTables::~JumpInstrTables() {}
    159 
    160 void JumpInstrTables::getAnalysisUsage(AnalysisUsage &AU) const {
    161   AU.addRequired<JumpInstrTableInfo>();
    162 }
    163 
    164 Function *JumpInstrTables::insertEntry(Module &M, Function *Target) {
    165   FunctionType *OrigFunTy = Target->getFunctionType();
    166   FunctionType *FunTy = transformType(OrigFunTy);
    167 
    168   JumpMap::iterator it = Metadata.find(FunTy);
    169   if (Metadata.end() == it) {
    170     struct TableMeta Meta;
    171     Meta.TableNum = TableCount;
    172     Meta.Count = 0;
    173     Metadata[FunTy] = Meta;
    174     it = Metadata.find(FunTy);
    175     ++NumJumpTables;
    176     ++TableCount;
    177   }
    178 
    179   it->second.Count++;
    180 
    181   std::string NewName(jump_func_prefix);
    182   NewName += (Twine(it->second.TableNum) + "_" + Twine(it->second.Count)).str();
    183   Function *JumpFun =
    184       Function::Create(OrigFunTy, GlobalValue::ExternalLinkage, NewName, &M);
    185   // The section for this table
    186   JumpFun->setSection((jump_section_prefix + Twine(it->second.TableNum)).str());
    187   JITI->insertEntry(FunTy, Target, JumpFun);
    188 
    189   ++NumFuncsInJumpTables;
    190   return JumpFun;
    191 }
    192 
    193 bool JumpInstrTables::hasTable(FunctionType *FunTy) {
    194   FunctionType *TransTy = transformType(FunTy);
    195   return Metadata.end() != Metadata.find(TransTy);
    196 }
    197 
    198 FunctionType *JumpInstrTables::transformType(FunctionType *FunTy) {
    199   // Returning nullptr forces all types into the same table, since all types map
    200   // to the same type
    201   Type *VoidPtrTy = Type::getInt8PtrTy(FunTy->getContext());
    202 
    203   // Ignore the return type.
    204   Type *RetTy = VoidPtrTy;
    205   bool IsVarArg = FunTy->isVarArg();
    206   std::vector<Type *> ParamTys(FunTy->getNumParams());
    207   FunctionType::param_iterator PI, PE;
    208   int i = 0;
    209 
    210   std::vector<Type *> EmptyParams;
    211   Type *Int32Ty = Type::getInt32Ty(FunTy->getContext());
    212   FunctionType *VoidFnTy = FunctionType::get(
    213       Type::getVoidTy(FunTy->getContext()), EmptyParams, false);
    214   switch (JTType) {
    215   case JumpTable::Single:
    216 
    217     return FunctionType::get(RetTy, EmptyParams, false);
    218   case JumpTable::Arity:
    219     // Transform all types to void* so that all functions with the same arity
    220     // end up in the same table.
    221     for (PI = FunTy->param_begin(), PE = FunTy->param_end(); PI != PE;
    222          PI++, i++) {
    223       ParamTys[i] = VoidPtrTy;
    224     }
    225 
    226     return FunctionType::get(RetTy, ParamTys, IsVarArg);
    227   case JumpTable::Simplified:
    228     // Project all parameters types to one of 3 types: composite, integer, and
    229     // function, matching the three subclasses of Type.
    230     for (PI = FunTy->param_begin(), PE = FunTy->param_end(); PI != PE;
    231          ++PI, ++i) {
    232       assert((isa<IntegerType>(*PI) || isa<FunctionType>(*PI) ||
    233               isa<CompositeType>(*PI)) &&
    234              "This type is not an Integer or a Composite or a Function");
    235       if (isa<CompositeType>(*PI)) {
    236         ParamTys[i] = VoidPtrTy;
    237       } else if (isa<FunctionType>(*PI)) {
    238         ParamTys[i] = VoidFnTy;
    239       } else if (isa<IntegerType>(*PI)) {
    240         ParamTys[i] = Int32Ty;
    241       }
    242     }
    243 
    244     return FunctionType::get(RetTy, ParamTys, IsVarArg);
    245   case JumpTable::Full:
    246     // Don't transform this type at all.
    247     return FunTy;
    248   }
    249 
    250   return nullptr;
    251 }
    252 
    253 bool JumpInstrTables::runOnModule(Module &M) {
    254   // Make sure the module is well-formed, especially with respect to jumptable.
    255   if (verifyModule(M))
    256     return false;
    257 
    258   JITI = &getAnalysis<JumpInstrTableInfo>();
    259 
    260   // Get the set of jumptable-annotated functions.
    261   DenseMap<Function *, Function *> Functions;
    262   for (Function &F : M) {
    263     if (F.hasFnAttribute(Attribute::JumpTable)) {
    264       assert(F.hasUnnamedAddr() &&
    265              "Attribute 'jumptable' requires 'unnamed_addr'");
    266       Functions[&F] = nullptr;
    267     }
    268   }
    269 
    270   // Create the jump-table functions.
    271   for (auto &KV : Functions) {
    272     Function *F = KV.first;
    273     KV.second = insertEntry(M, F);
    274   }
    275 
    276   // GlobalAlias is a special case, because the target of an alias statement
    277   // must be a defined function. So, instead of replacing a given function in
    278   // the alias, we replace all uses of aliases that target jumptable functions.
    279   // Note that there's no need to create these functions, since only aliases
    280   // that target known jumptable functions are replaced, and there's no way to
    281   // put the jumptable annotation on a global alias.
    282   DenseMap<GlobalAlias *, Function *> Aliases;
    283   for (GlobalAlias &GA : M.aliases()) {
    284     Constant *Aliasee = GA.getAliasee();
    285     if (Function *F = dyn_cast<Function>(Aliasee)) {
    286       auto it = Functions.find(F);
    287       if (it != Functions.end()) {
    288         Aliases[&GA] = it->second;
    289       }
    290     }
    291   }
    292 
    293   // Replace each address taken function with its jump-instruction table entry.
    294   for (auto &KV : Functions)
    295     replaceValueWithFunction(KV.first, KV.second);
    296 
    297   for (auto &KV : Aliases)
    298     replaceValueWithFunction(KV.first, KV.second);
    299 
    300   return !Functions.empty();
    301 }
    302