1 //===-- SITypeRewriter.cpp - Remove unwanted types ------------------------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 /// \file 11 /// This pass removes performs the following type substitution on all 12 /// non-compute shaders: 13 /// 14 /// v16i8 => i128 15 /// - v16i8 is used for constant memory resource descriptors. This type is 16 /// legal for some compute APIs, and we don't want to declare it as legal 17 /// in the backend, because we want the legalizer to expand all v16i8 18 /// operations. 19 /// v1* => * 20 /// - Having v1* types complicates the legalizer and we can easily replace 21 /// - them with the element type. 22 //===----------------------------------------------------------------------===// 23 24 #include "AMDGPU.h" 25 #include "Utils/AMDGPUBaseInfo.h" 26 #include "llvm/IR/IRBuilder.h" 27 #include "llvm/IR/InstVisitor.h" 28 29 using namespace llvm; 30 31 namespace { 32 33 class SITypeRewriter : public FunctionPass, 34 public InstVisitor<SITypeRewriter> { 35 36 static char ID; 37 Module *Mod; 38 Type *v16i8; 39 Type *v4i32; 40 41 public: 42 SITypeRewriter() : FunctionPass(ID) { } 43 bool doInitialization(Module &M) override; 44 bool runOnFunction(Function &F) override; 45 const char *getPassName() const override { 46 return "SI Type Rewriter"; 47 } 48 void visitLoadInst(LoadInst &I); 49 void visitCallInst(CallInst &I); 50 void visitBitCast(BitCastInst &I); 51 }; 52 53 } // End anonymous namespace 54 55 char SITypeRewriter::ID = 0; 56 57 bool SITypeRewriter::doInitialization(Module &M) { 58 Mod = &M; 59 v16i8 = VectorType::get(Type::getInt8Ty(M.getContext()), 16); 60 v4i32 = VectorType::get(Type::getInt32Ty(M.getContext()), 4); 61 return false; 62 } 63 64 bool SITypeRewriter::runOnFunction(Function &F) { 65 if (!AMDGPU::isShader(F.getCallingConv())) 66 return false; 67 68 visit(F); 69 visit(F); 70 71 return false; 72 } 73 74 void SITypeRewriter::visitLoadInst(LoadInst &I) { 75 Value *Ptr = I.getPointerOperand(); 76 Type *PtrTy = Ptr->getType(); 77 Type *ElemTy = PtrTy->getPointerElementType(); 78 IRBuilder<> Builder(&I); 79 if (ElemTy == v16i8) { 80 Value *BitCast = Builder.CreateBitCast(Ptr, 81 PointerType::get(v4i32,PtrTy->getPointerAddressSpace())); 82 LoadInst *Load = Builder.CreateLoad(BitCast); 83 SmallVector<std::pair<unsigned, MDNode *>, 8> MD; 84 I.getAllMetadataOtherThanDebugLoc(MD); 85 for (unsigned i = 0, e = MD.size(); i != e; ++i) { 86 Load->setMetadata(MD[i].first, MD[i].second); 87 } 88 Value *BitCastLoad = Builder.CreateBitCast(Load, I.getType()); 89 I.replaceAllUsesWith(BitCastLoad); 90 I.eraseFromParent(); 91 } 92 } 93 94 void SITypeRewriter::visitCallInst(CallInst &I) { 95 IRBuilder<> Builder(&I); 96 97 SmallVector <Value*, 8> Args; 98 SmallVector <Type*, 8> Types; 99 bool NeedToReplace = false; 100 Function *F = I.getCalledFunction(); 101 if (!F) 102 return; 103 104 std::string Name = F->getName(); 105 for (unsigned i = 0, e = I.getNumArgOperands(); i != e; ++i) { 106 Value *Arg = I.getArgOperand(i); 107 if (Arg->getType() == v16i8) { 108 Args.push_back(Builder.CreateBitCast(Arg, v4i32)); 109 Types.push_back(v4i32); 110 NeedToReplace = true; 111 Name = Name + ".v4i32"; 112 } else if (Arg->getType()->isVectorTy() && 113 Arg->getType()->getVectorNumElements() == 1 && 114 Arg->getType()->getVectorElementType() == 115 Type::getInt32Ty(I.getContext())){ 116 Type *ElementTy = Arg->getType()->getVectorElementType(); 117 std::string TypeName = "i32"; 118 InsertElementInst *Def = cast<InsertElementInst>(Arg); 119 Args.push_back(Def->getOperand(1)); 120 Types.push_back(ElementTy); 121 std::string VecTypeName = "v1" + TypeName; 122 Name = Name.replace(Name.find(VecTypeName), VecTypeName.length(), TypeName); 123 NeedToReplace = true; 124 } else { 125 Args.push_back(Arg); 126 Types.push_back(Arg->getType()); 127 } 128 } 129 130 if (!NeedToReplace) { 131 return; 132 } 133 Function *NewF = Mod->getFunction(Name); 134 if (!NewF) { 135 NewF = Function::Create(FunctionType::get(F->getReturnType(), Types, false), GlobalValue::ExternalLinkage, Name, Mod); 136 NewF->setAttributes(F->getAttributes()); 137 } 138 I.replaceAllUsesWith(Builder.CreateCall(NewF, Args)); 139 I.eraseFromParent(); 140 } 141 142 void SITypeRewriter::visitBitCast(BitCastInst &I) { 143 IRBuilder<> Builder(&I); 144 if (I.getDestTy() != v4i32) { 145 return; 146 } 147 148 if (BitCastInst *Op = dyn_cast<BitCastInst>(I.getOperand(0))) { 149 if (Op->getSrcTy() == v4i32) { 150 I.replaceAllUsesWith(Op->getOperand(0)); 151 I.eraseFromParent(); 152 } 153 } 154 } 155 156 FunctionPass *llvm::createSITypeRewriter() { 157 return new SITypeRewriter(); 158 } 159