Home | History | Annotate | Download | only in CodeGen
      1 //===-- StackProtector.cpp - Stack Protector Insertion --------------------===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 // This pass inserts stack protectors into functions which need them. A variable
     11 // with a random value in it is stored onto the stack before the local variables
     12 // are allocated. Upon exiting the block, the stored value is checked. If it's
     13 // changed, then there was some sort of violation and the program aborts.
     14 //
     15 //===----------------------------------------------------------------------===//
     16 
     17 #define DEBUG_TYPE "stack-protector"
     18 #include "llvm/CodeGen/Passes.h"
     19 #include "llvm/ADT/SmallPtrSet.h"
     20 #include "llvm/ADT/Statistic.h"
     21 #include "llvm/ADT/Triple.h"
     22 #include "llvm/Analysis/Dominators.h"
     23 #include "llvm/IR/Attributes.h"
     24 #include "llvm/IR/Constants.h"
     25 #include "llvm/IR/DataLayout.h"
     26 #include "llvm/IR/DerivedTypes.h"
     27 #include "llvm/IR/Function.h"
     28 #include "llvm/IR/Instructions.h"
     29 #include "llvm/IR/Intrinsics.h"
     30 #include "llvm/IR/Module.h"
     31 #include "llvm/Pass.h"
     32 #include "llvm/Support/CommandLine.h"
     33 #include "llvm/Target/TargetLowering.h"
     34 using namespace llvm;
     35 
     36 STATISTIC(NumFunProtected, "Number of functions protected");
     37 STATISTIC(NumAddrTaken, "Number of local variables that have their address"
     38                         " taken.");
     39 
     40 namespace {
     41   class StackProtector : public FunctionPass {
     42     /// TLI - Keep a pointer of a TargetLowering to consult for determining
     43     /// target type sizes.
     44     const TargetLoweringBase *TLI;
     45 
     46     Function *F;
     47     Module *M;
     48 
     49     DominatorTree *DT;
     50 
     51     /// VisitedPHIs - The set of PHI nodes visited when determining
     52     /// if a variable's reference has been taken.  This set
     53     /// is maintained to ensure we don't visit the same PHI node multiple
     54     /// times.
     55     SmallPtrSet<const PHINode*, 16> VisitedPHIs;
     56 
     57     /// InsertStackProtectors - Insert code into the prologue and epilogue of
     58     /// the function.
     59     ///
     60     ///  - The prologue code loads and stores the stack guard onto the stack.
     61     ///  - The epilogue checks the value stored in the prologue against the
     62     ///    original value. It calls __stack_chk_fail if they differ.
     63     bool InsertStackProtectors();
     64 
     65     /// CreateFailBB - Create a basic block to jump to when the stack protector
     66     /// check fails.
     67     BasicBlock *CreateFailBB();
     68 
     69     /// ContainsProtectableArray - Check whether the type either is an array or
     70     /// contains an array of sufficient size so that we need stack protectors
     71     /// for it.
     72     bool ContainsProtectableArray(Type *Ty, bool Strong = false,
     73                                   bool InStruct = false) const;
     74 
     75     /// \brief Check whether a stack allocation has its address taken.
     76     bool HasAddressTaken(const Instruction *AI);
     77 
     78     /// RequiresStackProtector - Check whether or not this function needs a
     79     /// stack protector based upon the stack protector level.
     80     bool RequiresStackProtector();
     81   public:
     82     static char ID;             // Pass identification, replacement for typeid.
     83     StackProtector() : FunctionPass(ID), TLI(0) {
     84       initializeStackProtectorPass(*PassRegistry::getPassRegistry());
     85     }
     86     StackProtector(const TargetLoweringBase *tli)
     87       : FunctionPass(ID), TLI(tli) {
     88       initializeStackProtectorPass(*PassRegistry::getPassRegistry());
     89     }
     90 
     91     virtual void getAnalysisUsage(AnalysisUsage &AU) const {
     92       AU.addPreserved<DominatorTree>();
     93     }
     94 
     95     virtual bool runOnFunction(Function &Fn);
     96   };
     97 } // end anonymous namespace
     98 
     99 char StackProtector::ID = 0;
    100 INITIALIZE_PASS(StackProtector, "stack-protector",
    101                 "Insert stack protectors", false, false)
    102 
    103 FunctionPass *llvm::createStackProtectorPass(const TargetLoweringBase *tli) {
    104   return new StackProtector(tli);
    105 }
    106 
    107 bool StackProtector::runOnFunction(Function &Fn) {
    108   F = &Fn;
    109   M = F->getParent();
    110   DT = getAnalysisIfAvailable<DominatorTree>();
    111 
    112   if (!RequiresStackProtector()) return false;
    113 
    114   ++NumFunProtected;
    115   return InsertStackProtectors();
    116 }
    117 
    118 /// ContainsProtectableArray - Check whether the type either is an array or
    119 /// contains a char array of sufficient size so that we need stack protectors
    120 /// for it.
    121 bool StackProtector::ContainsProtectableArray(Type *Ty, bool Strong,
    122                                               bool InStruct) const {
    123   if (!Ty) return false;
    124   if (ArrayType *AT = dyn_cast<ArrayType>(Ty)) {
    125     // In strong mode any array, regardless of type and size, triggers a
    126     // protector
    127     if (Strong)
    128       return true;
    129     const TargetMachine &TM = TLI->getTargetMachine();
    130     if (!AT->getElementType()->isIntegerTy(8)) {
    131       Triple Trip(TM.getTargetTriple());
    132 
    133       // If we're on a non-Darwin platform or we're inside of a structure, don't
    134       // add stack protectors unless the array is a character array.
    135       if (InStruct || !Trip.isOSDarwin())
    136           return false;
    137     }
    138 
    139     // If an array has more than SSPBufferSize bytes of allocated space, then we
    140     // emit stack protectors.
    141     if (TM.Options.SSPBufferSize <= TLI->getDataLayout()->getTypeAllocSize(AT))
    142       return true;
    143   }
    144 
    145   const StructType *ST = dyn_cast<StructType>(Ty);
    146   if (!ST) return false;
    147 
    148   for (StructType::element_iterator I = ST->element_begin(),
    149          E = ST->element_end(); I != E; ++I)
    150     if (ContainsProtectableArray(*I, Strong, true))
    151       return true;
    152 
    153   return false;
    154 }
    155 
    156 bool StackProtector::HasAddressTaken(const Instruction *AI) {
    157   for (Value::const_use_iterator UI = AI->use_begin(), UE = AI->use_end();
    158         UI != UE; ++UI) {
    159     const User *U = *UI;
    160     if (const StoreInst *SI = dyn_cast<StoreInst>(U)) {
    161       if (AI == SI->getValueOperand())
    162         return true;
    163     } else if (const PtrToIntInst *SI = dyn_cast<PtrToIntInst>(U)) {
    164       if (AI == SI->getOperand(0))
    165         return true;
    166     } else if (isa<CallInst>(U)) {
    167       return true;
    168     } else if (isa<InvokeInst>(U)) {
    169       return true;
    170     } else if (const SelectInst *SI = dyn_cast<SelectInst>(U)) {
    171       if (HasAddressTaken(SI))
    172         return true;
    173     } else if (const PHINode *PN = dyn_cast<PHINode>(U)) {
    174       // Keep track of what PHI nodes we have already visited to ensure
    175       // they are only visited once.
    176       if (VisitedPHIs.insert(PN))
    177         if (HasAddressTaken(PN))
    178           return true;
    179     } else if (const GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
    180       if (HasAddressTaken(GEP))
    181         return true;
    182     } else if (const BitCastInst *BI = dyn_cast<BitCastInst>(U)) {
    183       if (HasAddressTaken(BI))
    184         return true;
    185     }
    186   }
    187   return false;
    188 }
    189 
    190 /// \brief Check whether or not this function needs a stack protector based
    191 /// upon the stack protector level.
    192 ///
    193 /// We use two heuristics: a standard (ssp) and strong (sspstrong).
    194 /// The standard heuristic which will add a guard variable to functions that
    195 /// call alloca with a either a variable size or a size >= SSPBufferSize,
    196 /// functions with character buffers larger than SSPBufferSize, and functions
    197 /// with aggregates containing character buffers larger than SSPBufferSize. The
    198 /// strong heuristic will add a guard variables to functions that call alloca
    199 /// regardless of size, functions with any buffer regardless of type and size,
    200 /// functions with aggregates that contain any buffer regardless of type and
    201 /// size, and functions that contain stack-based variables that have had their
    202 /// address taken.
    203 bool StackProtector::RequiresStackProtector() {
    204   bool Strong = false;
    205   if (F->getAttributes().hasAttribute(AttributeSet::FunctionIndex,
    206                                       Attribute::StackProtectReq))
    207     return true;
    208   else if (F->getAttributes().hasAttribute(AttributeSet::FunctionIndex,
    209                                            Attribute::StackProtectStrong))
    210     Strong = true;
    211   else if (!F->getAttributes().hasAttribute(AttributeSet::FunctionIndex,
    212                                             Attribute::StackProtect))
    213     return false;
    214 
    215   for (Function::iterator I = F->begin(), E = F->end(); I != E; ++I) {
    216     BasicBlock *BB = I;
    217 
    218     for (BasicBlock::iterator
    219            II = BB->begin(), IE = BB->end(); II != IE; ++II) {
    220       if (AllocaInst *AI = dyn_cast<AllocaInst>(II)) {
    221         if (AI->isArrayAllocation()) {
    222           // SSP-Strong: Enable protectors for any call to alloca, regardless
    223           // of size.
    224           if (Strong)
    225             return true;
    226 
    227           if (const ConstantInt *CI =
    228                dyn_cast<ConstantInt>(AI->getArraySize())) {
    229             unsigned BufferSize = TLI->getTargetMachine().Options.SSPBufferSize;
    230             if (CI->getLimitedValue(BufferSize) >= BufferSize)
    231               // A call to alloca with size >= SSPBufferSize requires
    232               // stack protectors.
    233               return true;
    234           } else // A call to alloca with a variable size requires protectors.
    235             return true;
    236         }
    237 
    238         if (ContainsProtectableArray(AI->getAllocatedType(), Strong))
    239           return true;
    240 
    241         if (Strong && HasAddressTaken(AI)) {
    242           ++NumAddrTaken;
    243           return true;
    244         }
    245       }
    246     }
    247   }
    248 
    249   return false;
    250 }
    251 
    252 /// InsertStackProtectors - Insert code into the prologue and epilogue of the
    253 /// function.
    254 ///
    255 ///  - The prologue code loads and stores the stack guard onto the stack.
    256 ///  - The epilogue checks the value stored in the prologue against the original
    257 ///    value. It calls __stack_chk_fail if they differ.
    258 bool StackProtector::InsertStackProtectors() {
    259   BasicBlock *FailBB = 0;       // The basic block to jump to if check fails.
    260   BasicBlock *FailBBDom = 0;    // FailBB's dominator.
    261   AllocaInst *AI = 0;           // Place on stack that stores the stack guard.
    262   Value *StackGuardVar = 0;  // The stack guard variable.
    263 
    264   for (Function::iterator I = F->begin(), E = F->end(); I != E; ) {
    265     BasicBlock *BB = I++;
    266     ReturnInst *RI = dyn_cast<ReturnInst>(BB->getTerminator());
    267     if (!RI) continue;
    268 
    269     if (!FailBB) {
    270       // Insert code into the entry block that stores the __stack_chk_guard
    271       // variable onto the stack:
    272       //
    273       //   entry:
    274       //     StackGuardSlot = alloca i8*
    275       //     StackGuard = load __stack_chk_guard
    276       //     call void @llvm.stackprotect.create(StackGuard, StackGuardSlot)
    277       //
    278       PointerType *PtrTy = Type::getInt8PtrTy(RI->getContext());
    279       unsigned AddressSpace, Offset;
    280       if (TLI->getStackCookieLocation(AddressSpace, Offset)) {
    281         Constant *OffsetVal =
    282           ConstantInt::get(Type::getInt32Ty(RI->getContext()), Offset);
    283 
    284         StackGuardVar = ConstantExpr::getIntToPtr(OffsetVal,
    285                                       PointerType::get(PtrTy, AddressSpace));
    286       } else {
    287         StackGuardVar = M->getOrInsertGlobal("__stack_chk_guard", PtrTy);
    288       }
    289 
    290       BasicBlock &Entry = F->getEntryBlock();
    291       Instruction *InsPt = &Entry.front();
    292 
    293       AI = new AllocaInst(PtrTy, "StackGuardSlot", InsPt);
    294       LoadInst *LI = new LoadInst(StackGuardVar, "StackGuard", false, InsPt);
    295 
    296       Value *Args[] = { LI, AI };
    297       CallInst::
    298         Create(Intrinsic::getDeclaration(M, Intrinsic::stackprotector),
    299                Args, "", InsPt);
    300 
    301       // Create the basic block to jump to when the guard check fails.
    302       FailBB = CreateFailBB();
    303     }
    304 
    305     // For each block with a return instruction, convert this:
    306     //
    307     //   return:
    308     //     ...
    309     //     ret ...
    310     //
    311     // into this:
    312     //
    313     //   return:
    314     //     ...
    315     //     %1 = load __stack_chk_guard
    316     //     %2 = load StackGuardSlot
    317     //     %3 = cmp i1 %1, %2
    318     //     br i1 %3, label %SP_return, label %CallStackCheckFailBlk
    319     //
    320     //   SP_return:
    321     //     ret ...
    322     //
    323     //   CallStackCheckFailBlk:
    324     //     call void @__stack_chk_fail()
    325     //     unreachable
    326 
    327     // Split the basic block before the return instruction.
    328     BasicBlock *NewBB = BB->splitBasicBlock(RI, "SP_return");
    329 
    330     if (DT && DT->isReachableFromEntry(BB)) {
    331       DT->addNewBlock(NewBB, BB);
    332       FailBBDom = FailBBDom ? DT->findNearestCommonDominator(FailBBDom, BB) :BB;
    333     }
    334 
    335     // Remove default branch instruction to the new BB.
    336     BB->getTerminator()->eraseFromParent();
    337 
    338     // Move the newly created basic block to the point right after the old basic
    339     // block so that it's in the "fall through" position.
    340     NewBB->moveAfter(BB);
    341 
    342     // Generate the stack protector instructions in the old basic block.
    343     LoadInst *LI1 = new LoadInst(StackGuardVar, "", false, BB);
    344     LoadInst *LI2 = new LoadInst(AI, "", true, BB);
    345     ICmpInst *Cmp = new ICmpInst(*BB, CmpInst::ICMP_EQ, LI1, LI2, "");
    346     BranchInst::Create(NewBB, FailBB, Cmp, BB);
    347   }
    348 
    349   // Return if we didn't modify any basic blocks. I.e., there are no return
    350   // statements in the function.
    351   if (!FailBB) return false;
    352 
    353   if (DT && FailBBDom)
    354     DT->addNewBlock(FailBB, FailBBDom);
    355 
    356   return true;
    357 }
    358 
    359 /// CreateFailBB - Create a basic block to jump to when the stack protector
    360 /// check fails.
    361 BasicBlock *StackProtector::CreateFailBB() {
    362   BasicBlock *FailBB = BasicBlock::Create(F->getContext(),
    363                                           "CallStackCheckFailBlk", F);
    364   Constant *StackChkFail =
    365     M->getOrInsertFunction("__stack_chk_fail",
    366                            Type::getVoidTy(F->getContext()), NULL);
    367   CallInst::Create(StackChkFail, "", FailBB);
    368   new UnreachableInst(F->getContext(), FailBB);
    369   return FailBB;
    370 }
    371