//===-- LowerTypeTests.cpp - type metadata lowering pass ------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This pass lowers type metadata and calls to the llvm.type.test intrinsic.
// See http://llvm.org/docs/TypeMetadata.html for more information.
//
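// As an illustrative sketch (with a hypothetical type identifier, not a
// verbatim example from the docs), a check such as
//
//   %ok = call i1 @llvm.type.test(i8* %ptr, metadata !"typeid")
//
// is lowered to a range and alignment check against a combined global,
// followed where necessary by a lookup in a compact bitset or byte array.
//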
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/IPO/LowerTypeTests.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/Triple.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalObject.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

using namespace llvm;
using namespace lowertypetests;

#define DEBUG_TYPE "lowertypetests"

STATISTIC(ByteArraySizeBits, "Byte array size in bits");
STATISTIC(ByteArraySizeBytes, "Byte array size in bytes");
STATISTIC(NumByteArraysCreated, "Number of byte arrays created");
STATISTIC(NumTypeTestCallsLowered, "Number of type test calls lowered");
STATISTIC(NumTypeIdDisjointSets, "Number of disjoint sets of type identifiers");

static cl::opt<bool> AvoidReuse(
    "lowertypetests-avoid-reuse",
    cl::desc("Try to avoid reuse of byte array addresses using aliases"),
    cl::Hidden, cl::init(true));

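// Returns true if the global offset Offset corresponds to a set bit. For
// example, with ByteOffset == 16 and AlignLog2 == 2, offsets 16, 20 and 24 map
// to bit offsets 0, 1 and 2, while an offset of 18 fails the alignment check.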
bool BitSetInfo::containsGlobalOffset(uint64_t Offset) const {
  if (Offset < ByteOffset)
    return false;

  if ((Offset - ByteOffset) % (uint64_t(1) << AlignLog2) != 0)
    return false;

  uint64_t BitOffset = (Offset - ByteOffset) >> AlignLog2;
  if (BitOffset >= BitSize)
    return false;

  return Bits.count(BitOffset);
}

bool BitSetInfo::containsValue(
    const DataLayout &DL,
    const DenseMap<GlobalObject *, uint64_t> &GlobalLayout, Value *V,
    uint64_t COffset) const {
  if (auto GV = dyn_cast<GlobalObject>(V)) {
    auto I = GlobalLayout.find(GV);
    if (I == GlobalLayout.end())
      return false;
    return containsGlobalOffset(I->second + COffset);
  }

  if (auto GEP = dyn_cast<GEPOperator>(V)) {
    APInt APOffset(DL.getPointerSizeInBits(0), 0);
    bool Result = GEP->accumulateConstantOffset(DL, APOffset);
    if (!Result)
      return false;
    COffset += APOffset.getZExtValue();
    return containsValue(DL, GlobalLayout, GEP->getPointerOperand(),
                         COffset);
  }

  if (auto Op = dyn_cast<Operator>(V)) {
    if (Op->getOpcode() == Instruction::BitCast)
      return containsValue(DL, GlobalLayout, Op->getOperand(0), COffset);

    if (Op->getOpcode() == Instruction::Select)
      return containsValue(DL, GlobalLayout, Op->getOperand(1), COffset) &&
             containsValue(DL, GlobalLayout, Op->getOperand(2), COffset);
  }

  return false;
}

void BitSetInfo::print(raw_ostream &OS) const {
  OS << "offset " << ByteOffset << " size " << BitSize << " align "
     << (1 << AlignLog2);

  if (isAllOnes()) {
    OS << " all-ones\n";
    return;
  }

  OS << " { ";
  for (uint64_t B : Bits)
    OS << B << ' ';
  OS << "}\n";
}

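// Worked example: for the offsets {8, 16, 24}, Min is 8, so the normalized
// offsets {0, 8, 16} OR together to 0b11000. That gives AlignLog2 == 3,
// BitSize == ((24 - 8) >> 3) + 1 == 3 and Bits == {0, 1, 2}.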
BitSetInfo BitSetBuilder::build() {
  if (Min > Max)
    Min = 0;

  // Normalize each offset against the minimum observed offset, and compute
  // the bitwise OR of each of the offsets. The number of trailing zeros
  // in the mask gives us the log2 of the alignment of all offsets, which
  // allows us to compress the bitset by only storing one bit per aligned
  // address.
  uint64_t Mask = 0;
  for (uint64_t &Offset : Offsets) {
    Offset -= Min;
    Mask |= Offset;
  }

  BitSetInfo BSI;
  BSI.ByteOffset = Min;

  BSI.AlignLog2 = 0;
  if (Mask != 0)
    BSI.AlignLog2 = countTrailingZeros(Mask, ZB_Undefined);

  // Build the compressed bitset while normalizing the offsets against the
  // computed alignment.
  BSI.BitSize = ((Max - Min) >> BSI.AlignLog2) + 1;
  for (uint64_t Offset : Offsets) {
    Offset >>= BSI.AlignLog2;
    BSI.Bits.insert(Offset);
  }

  return BSI;
}

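// An entry of 0 in FragmentMap means the object index has not been assigned to
// a fragment yet. For example, adding the fragment {0, 1} and then {1, 2}
// leaves a single live fragment laid out as 0, 1, 2, because the second call
// absorbs the fragment created by the first.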
void GlobalLayoutBuilder::addFragment(const std::set<uint64_t> &F) {
  // Create a new fragment to hold the layout for F.
  Fragments.emplace_back();
  std::vector<uint64_t> &Fragment = Fragments.back();
  uint64_t FragmentIndex = Fragments.size() - 1;

  for (auto ObjIndex : F) {
    uint64_t OldFragmentIndex = FragmentMap[ObjIndex];
    if (OldFragmentIndex == 0) {
      // We haven't seen this object index before, so just add it to the current
      // fragment.
      Fragment.push_back(ObjIndex);
    } else {
      // This index belongs to an existing fragment. Copy the elements of the
      // old fragment into this one and clear the old fragment. We don't update
      // the fragment map just yet; this ensures that any further references to
      // indices from the old fragment in this fragment do not insert any more
      // indices.
      std::vector<uint64_t> &OldFragment = Fragments[OldFragmentIndex];
      Fragment.insert(Fragment.end(), OldFragment.begin(), OldFragment.end());
      OldFragment.clear();
    }
  }

  // Update the fragment map to point our object indices to this fragment.
  for (uint64_t ObjIndex : Fragment)
    FragmentMap[ObjIndex] = FragmentIndex;
}

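// Each byte of the byte array can hold one bit from up to eight different bit
// sets, one per bit position within the byte. allocate() picks the bit
// position with the lowest high-water mark in BitAllocs and appends this bit
// set's bits there, so several bit sets can share the same range of bytes
// under different masks.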
void ByteArrayBuilder::allocate(const std::set<uint64_t> &Bits,
                                uint64_t BitSize, uint64_t &AllocByteOffset,
                                uint8_t &AllocMask) {
  // Find the smallest current allocation.
  unsigned Bit = 0;
  for (unsigned I = 1; I != BitsPerByte; ++I)
    if (BitAllocs[I] < BitAllocs[Bit])
      Bit = I;

  AllocByteOffset = BitAllocs[Bit];

  // Add our size to it.
  unsigned ReqSize = AllocByteOffset + BitSize;
  BitAllocs[Bit] = ReqSize;
  if (Bytes.size() < ReqSize)
    Bytes.resize(ReqSize);

  // Set our bits.
  AllocMask = 1 << Bit;
  for (uint64_t B : Bits)
    Bytes[AllocByteOffset + B] |= AllocMask;
}

namespace {

struct ByteArrayInfo {
  std::set<uint64_t> Bits;
  uint64_t BitSize;
  GlobalVariable *ByteArray;
  Constant *Mask;
};

struct LowerTypeTests : public ModulePass {
  static char ID;
  LowerTypeTests() : ModulePass(ID) {
    initializeLowerTypeTestsPass(*PassRegistry::getPassRegistry());
  }

  Module *M;

  bool LinkerSubsectionsViaSymbols;
  Triple::ArchType Arch;
  Triple::ObjectFormatType ObjectFormat;
  IntegerType *Int1Ty;
  IntegerType *Int8Ty;
  IntegerType *Int32Ty;
  Type *Int32PtrTy;
  IntegerType *Int64Ty;
  IntegerType *IntPtrTy;

  // Mapping from type identifiers to the call sites that test them.
  DenseMap<Metadata *, std::vector<CallInst *>> TypeTestCallSites;

  std::vector<ByteArrayInfo> ByteArrayInfos;

  BitSetInfo
  buildBitSet(Metadata *TypeId,
              const DenseMap<GlobalObject *, uint64_t> &GlobalLayout);
  ByteArrayInfo *createByteArray(BitSetInfo &BSI);
  void allocateByteArrays();
  Value *createBitSetTest(IRBuilder<> &B, BitSetInfo &BSI, ByteArrayInfo *&BAI,
                          Value *BitOffset);
  void
  lowerTypeTestCalls(ArrayRef<Metadata *> TypeIds, Constant *CombinedGlobalAddr,
                     const DenseMap<GlobalObject *, uint64_t> &GlobalLayout);
  Value *
  lowerBitSetCall(CallInst *CI, BitSetInfo &BSI, ByteArrayInfo *&BAI,
                  Constant *CombinedGlobal,
                  const DenseMap<GlobalObject *, uint64_t> &GlobalLayout);
  void buildBitSetsFromGlobalVariables(ArrayRef<Metadata *> TypeIds,
                                       ArrayRef<GlobalVariable *> Globals);
  unsigned getJumpTableEntrySize();
  Type *getJumpTableEntryType();
  Constant *createJumpTableEntry(GlobalObject *Src, Function *Dest,
                                 unsigned Distance);
  void verifyTypeMDNode(GlobalObject *GO, MDNode *Type);
  void buildBitSetsFromFunctions(ArrayRef<Metadata *> TypeIds,
                                 ArrayRef<Function *> Functions);
  void buildBitSetsFromDisjointSet(ArrayRef<Metadata *> TypeIds,
                                   ArrayRef<GlobalObject *> Globals);
  bool lower();
  bool runOnModule(Module &M) override;
};

} // anonymous namespace

INITIALIZE_PASS(LowerTypeTests, "lowertypetests", "Lower type metadata", false,
                false)
char LowerTypeTests::ID = 0;

ModulePass *llvm::createLowerTypeTestsPass() { return new LowerTypeTests; }

/// Build a bit set for TypeId using the object layouts in
/// GlobalLayout.
BitSetInfo LowerTypeTests::buildBitSet(
    Metadata *TypeId,
    const DenseMap<GlobalObject *, uint64_t> &GlobalLayout) {
  BitSetBuilder BSB;

  // Compute the byte offset of each address associated with this type
  // identifier.
  SmallVector<MDNode *, 2> Types;
  for (auto &GlobalAndOffset : GlobalLayout) {
    Types.clear();
    GlobalAndOffset.first->getMetadata(LLVMContext::MD_type, Types);
    for (MDNode *Type : Types) {
      if (Type->getOperand(1) != TypeId)
        continue;
      uint64_t Offset =
          cast<ConstantInt>(cast<ConstantAsMetadata>(Type->getOperand(0))
                                ->getValue())->getZExtValue();
      BSB.addOffset(GlobalAndOffset.second + Offset);
    }
  }

  return BSB.build();
}

/// Build a test that bit BitOffset mod sizeof(Bits)*8 is set in
/// Bits. This pattern matches to the bt instruction on x86.
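/// For example, for a 32-bit Bits value the generated IR is along the lines
/// of:
///   %idx  = and i32 %offset, 31
///   %mask = shl i32 1, %idx
///   %and  = and i32 %bits, %mask
///   %res  = icmp ne i32 %and, 0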
static Value *createMaskedBitTest(IRBuilder<> &B, Value *Bits,
                                  Value *BitOffset) {
  auto BitsType = cast<IntegerType>(Bits->getType());
  unsigned BitWidth = BitsType->getBitWidth();

  BitOffset = B.CreateZExtOrTrunc(BitOffset, BitsType);
  Value *BitIndex =
      B.CreateAnd(BitOffset, ConstantInt::get(BitsType, BitWidth - 1));
  Value *BitMask = B.CreateShl(ConstantInt::get(BitsType, 1), BitIndex);
  Value *MaskedBits = B.CreateAnd(Bits, BitMask);
  return B.CreateICmpNE(MaskedBits, ConstantInt::get(BitsType, 0));
}

ByteArrayInfo *LowerTypeTests::createByteArray(BitSetInfo &BSI) {
  // Create globals to stand in for byte arrays and masks. These never actually
  // get initialized; we RAUW and erase them later in allocateByteArrays(),
  // once we know the offset and mask to use.
  auto ByteArrayGlobal = new GlobalVariable(
      *M, Int8Ty, /*isConstant=*/true, GlobalValue::PrivateLinkage, nullptr);
  auto MaskGlobal = new GlobalVariable(
      *M, Int8Ty, /*isConstant=*/true, GlobalValue::PrivateLinkage, nullptr);

  ByteArrayInfos.emplace_back();
  ByteArrayInfo *BAI = &ByteArrayInfos.back();

  BAI->Bits = BSI.Bits;
  BAI->BitSize = BSI.BitSize;
  BAI->ByteArray = ByteArrayGlobal;
  BAI->Mask = ConstantExpr::getPtrToInt(MaskGlobal, Int8Ty);
  return BAI;
}

void LowerTypeTests::allocateByteArrays() {
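  // Handle the byte arrays in order of decreasing bit size; allocating the
  // largest arrays first tends to keep the eight per-bit-position allocations
  // in ByteArrayBuilder close in length.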
  std::stable_sort(ByteArrayInfos.begin(), ByteArrayInfos.end(),
                   [](const ByteArrayInfo &BAI1, const ByteArrayInfo &BAI2) {
                     return BAI1.BitSize > BAI2.BitSize;
                   });

  std::vector<uint64_t> ByteArrayOffsets(ByteArrayInfos.size());

  ByteArrayBuilder BAB;
  for (unsigned I = 0; I != ByteArrayInfos.size(); ++I) {
    ByteArrayInfo *BAI = &ByteArrayInfos[I];

    uint8_t Mask;
    BAB.allocate(BAI->Bits, BAI->BitSize, ByteArrayOffsets[I], Mask);

    BAI->Mask->replaceAllUsesWith(ConstantInt::get(Int8Ty, Mask));
    cast<GlobalVariable>(BAI->Mask->getOperand(0))->eraseFromParent();
  }

  Constant *ByteArrayConst = ConstantDataArray::get(M->getContext(), BAB.Bytes);
  auto ByteArray =
      new GlobalVariable(*M, ByteArrayConst->getType(), /*isConstant=*/true,
                         GlobalValue::PrivateLinkage, ByteArrayConst);

  for (unsigned I = 0; I != ByteArrayInfos.size(); ++I) {
    ByteArrayInfo *BAI = &ByteArrayInfos[I];

    Constant *Idxs[] = {ConstantInt::get(IntPtrTy, 0),
                        ConstantInt::get(IntPtrTy, ByteArrayOffsets[I])};
    Constant *GEP = ConstantExpr::getInBoundsGetElementPtr(
        ByteArrayConst->getType(), ByteArray, Idxs);

    // Create an alias instead of RAUW'ing the gep directly. On x86 this ensures
    // that the pc-relative displacement is folded into the lea instead of the
    // test instruction getting another displacement.
    if (LinkerSubsectionsViaSymbols) {
      BAI->ByteArray->replaceAllUsesWith(GEP);
    } else {
      GlobalAlias *Alias = GlobalAlias::create(
          Int8Ty, 0, GlobalValue::PrivateLinkage, "bits", GEP, M);
      BAI->ByteArray->replaceAllUsesWith(Alias);
    }
    BAI->ByteArray->eraseFromParent();
  }

  ByteArraySizeBits = BAB.BitAllocs[0] + BAB.BitAllocs[1] + BAB.BitAllocs[2] +
                      BAB.BitAllocs[3] + BAB.BitAllocs[4] + BAB.BitAllocs[5] +
                      BAB.BitAllocs[6] + BAB.BitAllocs[7];
  ByteArraySizeBytes = BAB.Bytes.size();
}

/// Build a test that bit BitOffset is set in the bit set BSI, loading from
/// the byte array referenced by BAI if the bit set is too large to be tested
/// against an integer constant.
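/// For example, Bits == {0, 2, 3} yields the 32-bit constant 0b1101, which is
/// then tested with createMaskedBitTest.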
Value *LowerTypeTests::createBitSetTest(IRBuilder<> &B, BitSetInfo &BSI,
                                        ByteArrayInfo *&BAI, Value *BitOffset) {
  if (BSI.BitSize <= 64) {
    // If the bit set is sufficiently small, we can avoid a load by bit testing
    // a constant.
    IntegerType *BitsTy;
    if (BSI.BitSize <= 32)
      BitsTy = Int32Ty;
    else
      BitsTy = Int64Ty;

    uint64_t Bits = 0;
    for (auto Bit : BSI.Bits)
      Bits |= uint64_t(1) << Bit;
    Constant *BitsConst = ConstantInt::get(BitsTy, Bits);
    return createMaskedBitTest(B, BitsConst, BitOffset);
  } else {
    if (!BAI) {
      ++NumByteArraysCreated;
      BAI = createByteArray(BSI);
    }

    Constant *ByteArray = BAI->ByteArray;
    Type *Ty = BAI->ByteArray->getValueType();
    if (!LinkerSubsectionsViaSymbols && AvoidReuse) {
      // Each use of the byte array uses a different alias. This makes the
      // backend less likely to reuse previously computed byte array addresses,
      // improving the security of the CFI mechanism based on this pass.
      ByteArray = GlobalAlias::create(BAI->ByteArray->getValueType(), 0,
                                      GlobalValue::PrivateLinkage, "bits_use",
                                      ByteArray, M);
    }

    Value *ByteAddr = B.CreateGEP(Ty, ByteArray, BitOffset);
    Value *Byte = B.CreateLoad(ByteAddr);

    Value *ByteAndMask = B.CreateAnd(Byte, BAI->Mask);
    return B.CreateICmpNE(ByteAndMask, ConstantInt::get(Int8Ty, 0));
  }
}

/// Lower a llvm.type.test call to its implementation. Returns the value to
/// replace the call with.
Value *LowerTypeTests::lowerBitSetCall(
    CallInst *CI, BitSetInfo &BSI, ByteArrayInfo *&BAI,
    Constant *CombinedGlobalIntAddr,
    const DenseMap<GlobalObject *, uint64_t> &GlobalLayout) {
  Value *Ptr = CI->getArgOperand(0);
  const DataLayout &DL = M->getDataLayout();

  if (BSI.containsValue(DL, GlobalLayout, Ptr))
    return ConstantInt::getTrue(M->getContext());

  Constant *OffsetedGlobalAsInt = ConstantExpr::getAdd(
      CombinedGlobalIntAddr, ConstantInt::get(IntPtrTy, BSI.ByteOffset));

  BasicBlock *InitialBB = CI->getParent();

  IRBuilder<> B(CI);

  Value *PtrAsInt = B.CreatePtrToInt(Ptr, IntPtrTy);

  if (BSI.isSingleOffset())
    return B.CreateICmpEQ(PtrAsInt, OffsetedGlobalAsInt);

  Value *PtrOffset = B.CreateSub(PtrAsInt, OffsetedGlobalAsInt);

  Value *BitOffset;
  if (BSI.AlignLog2 == 0) {
    BitOffset = PtrOffset;
  } else {
    // We need to check that the offset both falls within our range and is
    // suitably aligned. We can check both properties at the same time by
    // performing a right rotate by log2(alignment) followed by an integer
    // comparison against the bitset size. The rotate will move the lower
    // order bits that need to be zero into the higher order bits of the
    // result, causing the comparison to fail if they are nonzero. The rotate
    // also conveniently gives us a bit offset to use during the load from
    // the bitset.
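    // For example, with AlignLog2 == 3 on a 64-bit target, an offset of 24
    // becomes (24 >> 3) | (24 << 61) == 3, which passes the range check
    // provided BitSize > 3, whereas a misaligned offset of 20 becomes
    // (20 >> 3) | (20 << 61), whose high bits are set, so the unsigned
    // comparison against BitSize fails.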
    Value *OffsetSHR =
        B.CreateLShr(PtrOffset, ConstantInt::get(IntPtrTy, BSI.AlignLog2));
    Value *OffsetSHL = B.CreateShl(
        PtrOffset,
        ConstantInt::get(IntPtrTy, DL.getPointerSizeInBits(0) - BSI.AlignLog2));
    BitOffset = B.CreateOr(OffsetSHR, OffsetSHL);
  }

  Constant *BitSizeConst = ConstantInt::get(IntPtrTy, BSI.BitSize);
  Value *OffsetInRange = B.CreateICmpULT(BitOffset, BitSizeConst);

  // If the bit set is all ones, testing against it is unnecessary.
  if (BSI.isAllOnes())
    return OffsetInRange;

  TerminatorInst *Term = SplitBlockAndInsertIfThen(OffsetInRange, CI, false);
  IRBuilder<> ThenB(Term);

  // Now that we know that the offset is in range and aligned, load the
  // appropriate bit from the bitset.
  Value *Bit = createBitSetTest(ThenB, BSI, BAI, BitOffset);

  // The value we want is 0 if we came directly from the initial block
  // (having failed the range or alignment checks), or the loaded bit if
  // we came from the block in which we loaded it.
  B.SetInsertPoint(CI);
  PHINode *P = B.CreatePHI(Int1Ty, 2);
  P->addIncoming(ConstantInt::get(Int1Ty, 0), InitialBB);
  P->addIncoming(Bit, ThenB.GetInsertBlock());
  return P;
}

/// Given a disjoint set of type identifiers and globals, lay out the globals,
/// build the bit sets and lower the llvm.type.test calls.
void LowerTypeTests::buildBitSetsFromGlobalVariables(
    ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalVariable *> Globals) {
  // Build a new global with the combined contents of the referenced globals.
  // This global is a struct whose even-indexed elements contain the original
  // contents of the referenced globals and whose odd-indexed elements contain
  // any padding required to align the next element to the next power of 2.
  std::vector<Constant *> GlobalInits;
  const DataLayout &DL = M->getDataLayout();
  for (GlobalVariable *G : Globals) {
    GlobalInits.push_back(G->getInitializer());
    uint64_t InitSize = DL.getTypeAllocSize(G->getValueType());

    // Compute the amount of padding required.
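    // For example, a 20-byte initializer gets 12 bytes of padding (rounding
    // its slot up to 32 bytes), while a 16-byte initializer gets none.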
    uint64_t Padding = NextPowerOf2(InitSize - 1) - InitSize;

    // A cap of 128 was found experimentally to give a good data/instruction
    // overhead tradeoff.
    if (Padding > 128)
      Padding = alignTo(InitSize, 128) - InitSize;

    GlobalInits.push_back(
        ConstantAggregateZero::get(ArrayType::get(Int8Ty, Padding)));
  }
  if (!GlobalInits.empty())
    GlobalInits.pop_back();
  Constant *NewInit = ConstantStruct::getAnon(M->getContext(), GlobalInits);
  auto *CombinedGlobal =
      new GlobalVariable(*M, NewInit->getType(), /*isConstant=*/true,
                         GlobalValue::PrivateLinkage, NewInit);

  StructType *NewTy = cast<StructType>(NewInit->getType());
  const StructLayout *CombinedGlobalLayout = DL.getStructLayout(NewTy);

  // Compute the offsets of the original globals within the new global.
  DenseMap<GlobalObject *, uint64_t> GlobalLayout;
  for (unsigned I = 0; I != Globals.size(); ++I)
    // Multiply by 2 to account for padding elements.
    GlobalLayout[Globals[I]] = CombinedGlobalLayout->getElementOffset(I * 2);

  lowerTypeTestCalls(TypeIds, CombinedGlobal, GlobalLayout);

  // Build aliases pointing to offsets into the combined global for each
  // global from which we built the combined global, and replace references
  // to the original globals with references to the aliases.
  for (unsigned I = 0; I != Globals.size(); ++I) {
    // Multiply by 2 to account for padding elements.
    Constant *CombinedGlobalIdxs[] = {ConstantInt::get(Int32Ty, 0),
                                      ConstantInt::get(Int32Ty, I * 2)};
    Constant *CombinedGlobalElemPtr = ConstantExpr::getGetElementPtr(
        NewInit->getType(), CombinedGlobal, CombinedGlobalIdxs);
    if (LinkerSubsectionsViaSymbols) {
      Globals[I]->replaceAllUsesWith(CombinedGlobalElemPtr);
    } else {
      assert(Globals[I]->getType()->getAddressSpace() == 0);
      GlobalAlias *GAlias = GlobalAlias::create(NewTy->getElementType(I * 2), 0,
                                                Globals[I]->getLinkage(), "",
                                                CombinedGlobalElemPtr, M);
      GAlias->setVisibility(Globals[I]->getVisibility());
      GAlias->takeName(Globals[I]);
      Globals[I]->replaceAllUsesWith(GAlias);
    }
    Globals[I]->eraseFromParent();
  }
}

void LowerTypeTests::lowerTypeTestCalls(
    ArrayRef<Metadata *> TypeIds, Constant *CombinedGlobalAddr,
    const DenseMap<GlobalObject *, uint64_t> &GlobalLayout) {
  Constant *CombinedGlobalIntAddr =
      ConstantExpr::getPtrToInt(CombinedGlobalAddr, IntPtrTy);

  // For each type identifier in this disjoint set...
  for (Metadata *TypeId : TypeIds) {
    // Build the bitset.
    BitSetInfo BSI = buildBitSet(TypeId, GlobalLayout);
    DEBUG({
      if (auto MDS = dyn_cast<MDString>(TypeId))
        dbgs() << MDS->getString() << ": ";
      else
        dbgs() << "<unnamed>: ";
      BSI.print(dbgs());
    });

    ByteArrayInfo *BAI = nullptr;

    // Lower each call to llvm.type.test for this type identifier.
    for (CallInst *CI : TypeTestCallSites[TypeId]) {
      ++NumTypeTestCallsLowered;
      Value *Lowered =
          lowerBitSetCall(CI, BSI, BAI, CombinedGlobalIntAddr, GlobalLayout);
      CI->replaceAllUsesWith(Lowered);
      CI->eraseFromParent();
    }
  }
}

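// A type metadata node attached to a global has two operands: a byte offset
// into the global and the type identifier, e.g. !{i64 0, !"typeid"}. The
// checks below reject forms this pass cannot handle.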
void LowerTypeTests::verifyTypeMDNode(GlobalObject *GO, MDNode *Type) {
  if (Type->getNumOperands() != 2)
    report_fatal_error(
        "All operands of type metadata must have 2 elements");

  if (GO->isThreadLocal())
    report_fatal_error("Bit set element may not be thread-local");
  if (isa<GlobalVariable>(GO) && GO->hasSection())
    report_fatal_error(
        "A member of a type identifier may not have an explicit section");

  if (isa<GlobalVariable>(GO) && GO->isDeclarationForLinker())
    report_fatal_error(
        "A global var member of a type identifier must be a definition");

  auto OffsetConstMD = dyn_cast<ConstantAsMetadata>(Type->getOperand(0));
  if (!OffsetConstMD)
    report_fatal_error("Type offset must be a constant");
  auto OffsetInt = dyn_cast<ConstantInt>(OffsetConstMD->getValue());
  if (!OffsetInt)
    report_fatal_error("Type offset must be an integer constant");
}

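// On x86 and x86-64 each jump table entry is 8 bytes: a 5-byte 'jmp rel32'
// followed by three int3 padding bytes, so entries are spaced at a power of
// two (see createJumpTableEntry below).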
static const unsigned kX86JumpTableEntrySize = 8;

unsigned LowerTypeTests::getJumpTableEntrySize() {
  if (Arch != Triple::x86 && Arch != Triple::x86_64)
    report_fatal_error("Unsupported architecture for jump tables");

  return kX86JumpTableEntrySize;
}

// Create a constant representing a jump table entry for the target. This
// consists of an instruction sequence containing a relative branch to Dest. The
// constant will be laid out at address Src+(Len*Distance) where Len is the
// target-specific jump table entry size.
Constant *LowerTypeTests::createJumpTableEntry(GlobalObject *Src,
                                               Function *Dest,
                                               unsigned Distance) {
  if (Arch != Triple::x86 && Arch != Triple::x86_64)
    report_fatal_error("Unsupported architecture for jump tables");

  const unsigned kJmpPCRel32Code = 0xe9;
  const unsigned kInt3Code = 0xcc;

  ConstantInt *Jmp = ConstantInt::get(Int8Ty, kJmpPCRel32Code);

  // Build a constant representing the displacement between the constant's
  // address and Dest. This will resolve to a PC32 relocation referring to Dest.
  Constant *DestInt = ConstantExpr::getPtrToInt(Dest, IntPtrTy);
  Constant *SrcInt = ConstantExpr::getPtrToInt(Src, IntPtrTy);
  Constant *Disp = ConstantExpr::getSub(DestInt, SrcInt);
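  // The rel32 operand is measured from the end of the jmp instruction, which
  // lies Distance * kX86JumpTableEntrySize + 5 bytes past Src, hence the
  // adjustment below.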
  ConstantInt *DispOffset =
      ConstantInt::get(IntPtrTy, Distance * kX86JumpTableEntrySize + 5);
  Constant *OffsetedDisp = ConstantExpr::getSub(Disp, DispOffset);
  OffsetedDisp = ConstantExpr::getTruncOrBitCast(OffsetedDisp, Int32Ty);

  ConstantInt *Int3 = ConstantInt::get(Int8Ty, kInt3Code);

  Constant *Fields[] = {
      Jmp, OffsetedDisp, Int3, Int3, Int3,
  };
  return ConstantStruct::getAnon(Fields, /*Packed=*/true);
}

Type *LowerTypeTests::getJumpTableEntryType() {
  if (Arch != Triple::x86 && Arch != Triple::x86_64)
    report_fatal_error("Unsupported architecture for jump tables");

  return StructType::get(M->getContext(),
                         {Int8Ty, Int32Ty, Int8Ty, Int8Ty, Int8Ty},
                         /*Packed=*/true);
}

/// Given a disjoint set of type identifiers and functions, build a jump table
/// for the functions, build the bit sets and lower the llvm.type.test calls.
void LowerTypeTests::buildBitSetsFromFunctions(ArrayRef<Metadata *> TypeIds,
                                               ArrayRef<Function *> Functions) {
  // Unlike the global bitset builder, the function bitset builder cannot
  // re-arrange functions in a particular order and base its calculations on the
  // layout of the functions' entry points, as we have no idea how large a
  // particular function will end up being (the size could even depend on what
  // this pass does!) Instead, we build a jump table, which is a block of code
  // consisting of one branch instruction for each of the functions in the bit
  // set that branches to the target function, and redirect any taken function
  // addresses to the corresponding jump table entry. In the object file's
  // symbol table, the symbols for the target functions also refer to the jump
  // table entries, so that addresses taken outside the module will pass any
  // verification done inside the module.
  //
  // In more concrete terms, suppose we have three functions f, g, h which are
  // of the same type, and a function foo that returns their addresses:
  //
  // f:
  // mov 0, %eax
  // ret
  //
  // g:
  // mov 1, %eax
  // ret
  //
  // h:
  // mov 2, %eax
  // ret
  //
  // foo:
  // mov f, %eax
  // mov g, %edx
  // mov h, %ecx
  // ret
  //
  // To create a jump table for these functions, we instruct the LLVM code
  // generator to output a jump table in the .text section. This is done by
  // representing the instructions in the jump table as an LLVM constant and
  // placing them in a global variable in the .text section. The end result will
  // (conceptually) look like this:
  //
  // f:
  // jmp .Ltmp0 ; 5 bytes
  // int3       ; 1 byte
  // int3       ; 1 byte
  // int3       ; 1 byte
  //
  // g:
  // jmp .Ltmp1 ; 5 bytes
  // int3       ; 1 byte
  // int3       ; 1 byte
  // int3       ; 1 byte
  //
  // h:
  // jmp .Ltmp2 ; 5 bytes
  // int3       ; 1 byte
  // int3       ; 1 byte
  // int3       ; 1 byte
  //
  // .Ltmp0:
  // mov 0, %eax
  // ret
  //
  // .Ltmp1:
  // mov 1, %eax
  // ret
  //
  // .Ltmp2:
  // mov 2, %eax
  // ret
  //
  // foo:
  // mov f, %eax
  // mov g, %edx
  // mov h, %ecx
  // ret
  //
  // Because the addresses of f, g, h are evenly spaced at a power of 2, in the
  // normal case the check can be carried out using the same kind of simple
  // arithmetic that we normally use for globals.

  assert(!Functions.empty());

  // Build a simple layout based on the regular layout of jump tables.
  DenseMap<GlobalObject *, uint64_t> GlobalLayout;
  unsigned EntrySize = getJumpTableEntrySize();
  for (unsigned I = 0; I != Functions.size(); ++I)
    GlobalLayout[Functions[I]] = I * EntrySize;

  // Create a constant to hold the jump table.
  ArrayType *JumpTableType =
      ArrayType::get(getJumpTableEntryType(), Functions.size());
  auto JumpTable = new GlobalVariable(*M, JumpTableType,
                                      /*isConstant=*/true,
                                      GlobalValue::PrivateLinkage, nullptr);
  JumpTable->setSection(ObjectFormat == Triple::MachO
                            ? "__TEXT,__text,regular,pure_instructions"
                            : ".text");
  lowerTypeTestCalls(TypeIds, JumpTable, GlobalLayout);

  // Build aliases pointing to offsets into the jump table, and replace
  // references to the original functions with references to the aliases.
  for (unsigned I = 0; I != Functions.size(); ++I) {
    Constant *CombinedGlobalElemPtr = ConstantExpr::getBitCast(
        ConstantExpr::getGetElementPtr(
            JumpTableType, JumpTable,
            ArrayRef<Constant *>{ConstantInt::get(IntPtrTy, 0),
                                 ConstantInt::get(IntPtrTy, I)}),
        Functions[I]->getType());
    if (LinkerSubsectionsViaSymbols || Functions[I]->isDeclarationForLinker()) {
      Functions[I]->replaceAllUsesWith(CombinedGlobalElemPtr);
    } else {
      assert(Functions[I]->getType()->getAddressSpace() == 0);
      GlobalAlias *GAlias = GlobalAlias::create(Functions[I]->getValueType(), 0,
                                                Functions[I]->getLinkage(), "",
                                                CombinedGlobalElemPtr, M);
      GAlias->setVisibility(Functions[I]->getVisibility());
      GAlias->takeName(Functions[I]);
      Functions[I]->replaceAllUsesWith(GAlias);
    }
    if (!Functions[I]->isDeclarationForLinker())
      Functions[I]->setLinkage(GlobalValue::PrivateLinkage);
  }

  // Build and set the jump table's initializer.
  std::vector<Constant *> JumpTableEntries;
  for (unsigned I = 0; I != Functions.size(); ++I)
    JumpTableEntries.push_back(
        createJumpTableEntry(JumpTable, Functions[I], I));
  JumpTable->setInitializer(
      ConstantArray::get(JumpTableType, JumpTableEntries));
}

void LowerTypeTests::buildBitSetsFromDisjointSet(
    ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalObject *> Globals) {
  llvm::DenseMap<Metadata *, uint64_t> TypeIdIndices;
  for (unsigned I = 0; I != TypeIds.size(); ++I)
    TypeIdIndices[TypeIds[I]] = I;

  // For each type identifier, build a set of indices that refer to members of
  // the type identifier.
  std::vector<std::set<uint64_t>> TypeMembers(TypeIds.size());
  SmallVector<MDNode *, 2> Types;
  unsigned GlobalIndex = 0;
  for (GlobalObject *GO : Globals) {
    Types.clear();
    GO->getMetadata(LLVMContext::MD_type, Types);
    for (MDNode *Type : Types) {
      // Type = { offset, type identifier }
      unsigned TypeIdIndex = TypeIdIndices[Type->getOperand(1)];
      TypeMembers[TypeIdIndex].insert(GlobalIndex);
    }
    GlobalIndex++;
  }

  // Order the sets of indices by size. The GlobalLayoutBuilder works best
  // when given small index sets first.
  std::stable_sort(
      TypeMembers.begin(), TypeMembers.end(),
      [](const std::set<uint64_t> &O1, const std::set<uint64_t> &O2) {
        return O1.size() < O2.size();
      });

  // Create a GlobalLayoutBuilder and provide it with index sets as layout
  // fragments. The GlobalLayoutBuilder tries to lay out members of fragments as
  // close together as possible.
  GlobalLayoutBuilder GLB(Globals.size());
  for (auto &&MemSet : TypeMembers)
    GLB.addFragment(MemSet);

  // Build the bitsets from this disjoint set.
  if (Globals.empty() || isa<GlobalVariable>(Globals[0])) {
    // Build a vector of global variables with the computed layout.
    std::vector<GlobalVariable *> OrderedGVs(Globals.size());
    auto OGI = OrderedGVs.begin();
    for (auto &&F : GLB.Fragments) {
      for (auto &&Offset : F) {
        auto GV = dyn_cast<GlobalVariable>(Globals[Offset]);
        if (!GV)
          report_fatal_error("Type identifier may not contain both global "
                             "variables and functions");
        *OGI++ = GV;
      }
    }

    buildBitSetsFromGlobalVariables(TypeIds, OrderedGVs);
  } else {
    // Build a vector of functions with the computed layout.
    std::vector<Function *> OrderedFns(Globals.size());
    auto OFI = OrderedFns.begin();
    for (auto &&F : GLB.Fragments) {
      for (auto &&Offset : F) {
        auto Fn = dyn_cast<Function>(Globals[Offset]);
        if (!Fn)
          report_fatal_error("Type identifier may not contain both global "
                             "variables and functions");
        *OFI++ = Fn;
      }
    }

    buildBitSetsFromFunctions(TypeIds, OrderedFns);
  }
}

/// Lower all type tests in this module.
bool LowerTypeTests::lower() {
  Function *TypeTestFunc =
      M->getFunction(Intrinsic::getName(Intrinsic::type_test));
  if (!TypeTestFunc || TypeTestFunc->use_empty())
    return false;

  // Equivalence class set containing type identifiers and the globals that
  // reference them. This is used to partition the set of type identifiers in
  // the module into disjoint sets.
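  // For example, if two type identifiers are both attached to the same global,
  // the two identifiers and that global all end up in one equivalence class
  // and are lowered together.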
  typedef EquivalenceClasses<PointerUnion<GlobalObject *, Metadata *>>
      GlobalClassesTy;
  GlobalClassesTy GlobalClasses;

  // Verify the type metadata and build a mapping from type identifiers to their
  // last observed index in the list of globals. This will be used later to
  // deterministically order the list of type identifiers.
  llvm::DenseMap<Metadata *, unsigned> TypeIdIndices;
  unsigned I = 0;
  SmallVector<MDNode *, 2> Types;
  for (GlobalObject &GO : M->global_objects()) {
    Types.clear();
    GO.getMetadata(LLVMContext::MD_type, Types);
    for (MDNode *Type : Types) {
      verifyTypeMDNode(&GO, Type);
      TypeIdIndices[cast<MDNode>(Type)->getOperand(1)] = ++I;
    }
  }

  for (const Use &U : TypeTestFunc->uses()) {
    auto CI = cast<CallInst>(U.getUser());

    auto BitSetMDVal = dyn_cast<MetadataAsValue>(CI->getArgOperand(1));
    if (!BitSetMDVal)
      report_fatal_error(
          "Second argument of llvm.type.test must be metadata");
    auto BitSet = BitSetMDVal->getMetadata();

    // Add the call site to the list of call sites for this type identifier. We
    // also use TypeTestCallSites to keep track of whether we have seen this
    // type identifier before. If we have, we don't need to re-add the
    // referenced globals to the equivalence class.
    std::pair<DenseMap<Metadata *, std::vector<CallInst *>>::iterator, bool>
        Ins = TypeTestCallSites.insert(
            std::make_pair(BitSet, std::vector<CallInst *>()));
    Ins.first->second.push_back(CI);
    if (!Ins.second)
      continue;

    // Add the type identifier to the equivalence class.
    GlobalClassesTy::iterator GCI = GlobalClasses.insert(BitSet);
    GlobalClassesTy::member_iterator CurSet = GlobalClasses.findLeader(GCI);

    // Add the referenced globals to the type identifier's equivalence class.
    for (GlobalObject &GO : M->global_objects()) {
      Types.clear();
      GO.getMetadata(LLVMContext::MD_type, Types);
      for (MDNode *Type : Types)
        if (Type->getOperand(1) == BitSet)
          CurSet = GlobalClasses.unionSets(
              CurSet, GlobalClasses.findLeader(GlobalClasses.insert(&GO)));
    }
  }

  if (GlobalClasses.empty())
    return false;

  // Build a list of disjoint sets ordered by their maximum global index for
  // determinism.
  std::vector<std::pair<GlobalClassesTy::iterator, unsigned>> Sets;
  for (GlobalClassesTy::iterator I = GlobalClasses.begin(),
                                 E = GlobalClasses.end();
       I != E; ++I) {
    if (!I->isLeader()) continue;
    ++NumTypeIdDisjointSets;

    unsigned MaxIndex = 0;
    for (GlobalClassesTy::member_iterator MI = GlobalClasses.member_begin(I);
         MI != GlobalClasses.member_end(); ++MI) {
      if ((*MI).is<Metadata *>())
        MaxIndex = std::max(MaxIndex, TypeIdIndices[MI->get<Metadata *>()]);
    }
    Sets.emplace_back(I, MaxIndex);
  }
  std::sort(Sets.begin(), Sets.end(),
            [](const std::pair<GlobalClassesTy::iterator, unsigned> &S1,
               const std::pair<GlobalClassesTy::iterator, unsigned> &S2) {
              return S1.second < S2.second;
            });

  // For each disjoint set we found...
  for (const auto &S : Sets) {
    // Build the list of type identifiers in this disjoint set.
    std::vector<Metadata *> TypeIds;
    std::vector<GlobalObject *> Globals;
    for (GlobalClassesTy::member_iterator MI =
             GlobalClasses.member_begin(S.first);
         MI != GlobalClasses.member_end(); ++MI) {
      if ((*MI).is<Metadata *>())
        TypeIds.push_back(MI->get<Metadata *>());
      else
        Globals.push_back(MI->get<GlobalObject *>());
    }

    // Order type identifiers by global index for determinism. This ordering is
    // stable as there is a one-to-one mapping between metadata and indices.
    std::sort(TypeIds.begin(), TypeIds.end(), [&](Metadata *M1, Metadata *M2) {
      return TypeIdIndices[M1] < TypeIdIndices[M2];
    });

    // Build bitsets for this disjoint set.
    buildBitSetsFromDisjointSet(TypeIds, Globals);
  }

  allocateByteArrays();

  return true;
}

// Initialization helper shared by the old and the new PM.
static void init(LowerTypeTests *LTT, Module &M) {
  LTT->M = &M;
  const DataLayout &DL = M.getDataLayout();
  Triple TargetTriple(M.getTargetTriple());
  LTT->LinkerSubsectionsViaSymbols = TargetTriple.isMacOSX();
  LTT->Arch = TargetTriple.getArch();
  LTT->ObjectFormat = TargetTriple.getObjectFormat();
  LTT->Int1Ty = Type::getInt1Ty(M.getContext());
  LTT->Int8Ty = Type::getInt8Ty(M.getContext());
  LTT->Int32Ty = Type::getInt32Ty(M.getContext());
  LTT->Int32PtrTy = PointerType::getUnqual(LTT->Int32Ty);
  LTT->Int64Ty = Type::getInt64Ty(M.getContext());
  LTT->IntPtrTy = DL.getIntPtrType(M.getContext(), 0);
  LTT->TypeTestCallSites.clear();
}

bool LowerTypeTests::runOnModule(Module &M) {
  if (skipModule(M))
    return false;
  init(this, M);
  return lower();
}

PreservedAnalyses LowerTypeTestsPass::run(Module &M,
                                          AnalysisManager<Module> &AM) {
  LowerTypeTests Impl;
  init(&Impl, M);
  bool Changed = Impl.lower();
  if (!Changed)
    return PreservedAnalyses::all();
  return PreservedAnalyses::none();
}
   1021