1 //===-- SITypeRewriter.cpp - Remove unwanted types ------------------------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 /// \file 11 /// This pass removes performs the following type substitution on all 12 /// non-compute shaders: 13 /// 14 /// v16i8 => i128 15 /// - v16i8 is used for constant memory resource descriptors. This type is 16 /// legal for some compute APIs, and we don't want to declare it as legal 17 /// in the backend, because we want the legalizer to expand all v16i8 18 /// operations. 19 /// v1* => * 20 /// - Having v1* types complicates the legalizer and we can easily replace 21 /// - them with the element type. 22 //===----------------------------------------------------------------------===// 23 24 #include "AMDGPU.h" 25 #include "llvm/IR/IRBuilder.h" 26 #include "llvm/IR/InstVisitor.h" 27 28 using namespace llvm; 29 30 namespace { 31 32 class SITypeRewriter : public FunctionPass, 33 public InstVisitor<SITypeRewriter> { 34 35 static char ID; 36 Module *Mod; 37 Type *v16i8; 38 Type *v4i32; 39 40 public: 41 SITypeRewriter() : FunctionPass(ID) { } 42 bool doInitialization(Module &M) override; 43 bool runOnFunction(Function &F) override; 44 const char *getPassName() const override { 45 return "SI Type Rewriter"; 46 } 47 void visitLoadInst(LoadInst &I); 48 void visitCallInst(CallInst &I); 49 void visitBitCast(BitCastInst &I); 50 }; 51 52 } // End anonymous namespace 53 54 char SITypeRewriter::ID = 0; 55 56 bool SITypeRewriter::doInitialization(Module &M) { 57 Mod = &M; 58 v16i8 = VectorType::get(Type::getInt8Ty(M.getContext()), 16); 59 v4i32 = VectorType::get(Type::getInt32Ty(M.getContext()), 4); 60 return false; 61 } 62 63 bool SITypeRewriter::runOnFunction(Function &F) { 64 AttributeSet Set = F.getAttributes(); 65 Attribute A = Set.getAttribute(AttributeSet::FunctionIndex, "ShaderType"); 66 67 unsigned ShaderType = ShaderType::COMPUTE; 68 if (A.isStringAttribute()) { 69 StringRef Str = A.getValueAsString(); 70 Str.getAsInteger(0, ShaderType); 71 } 72 if (ShaderType == ShaderType::COMPUTE) 73 return false; 74 75 visit(F); 76 visit(F); 77 78 return false; 79 } 80 81 void SITypeRewriter::visitLoadInst(LoadInst &I) { 82 Value *Ptr = I.getPointerOperand(); 83 Type *PtrTy = Ptr->getType(); 84 Type *ElemTy = PtrTy->getPointerElementType(); 85 IRBuilder<> Builder(&I); 86 if (ElemTy == v16i8) { 87 Value *BitCast = Builder.CreateBitCast(Ptr, 88 PointerType::get(v4i32,PtrTy->getPointerAddressSpace())); 89 LoadInst *Load = Builder.CreateLoad(BitCast); 90 SmallVector <std::pair<unsigned, MDNode*>, 8> MD; 91 I.getAllMetadataOtherThanDebugLoc(MD); 92 for (unsigned i = 0, e = MD.size(); i != e; ++i) { 93 Load->setMetadata(MD[i].first, MD[i].second); 94 } 95 Value *BitCastLoad = Builder.CreateBitCast(Load, I.getType()); 96 I.replaceAllUsesWith(BitCastLoad); 97 I.eraseFromParent(); 98 } 99 } 100 101 void SITypeRewriter::visitCallInst(CallInst &I) { 102 IRBuilder<> Builder(&I); 103 104 SmallVector <Value*, 8> Args; 105 SmallVector <Type*, 8> Types; 106 bool NeedToReplace = false; 107 Function *F = I.getCalledFunction(); 108 std::string Name = F->getName().str(); 109 for (unsigned i = 0, e = I.getNumArgOperands(); i != e; ++i) { 110 Value *Arg = I.getArgOperand(i); 111 if (Arg->getType() == v16i8) { 112 Args.push_back(Builder.CreateBitCast(Arg, v4i32)); 113 Types.push_back(v4i32); 114 NeedToReplace = true; 115 Name = Name + ".v4i32"; 116 } else if (Arg->getType()->isVectorTy() && 117 Arg->getType()->getVectorNumElements() == 1 && 118 Arg->getType()->getVectorElementType() == 119 Type::getInt32Ty(I.getContext())){ 120 Type *ElementTy = Arg->getType()->getVectorElementType(); 121 std::string TypeName = "i32"; 122 InsertElementInst *Def = cast<InsertElementInst>(Arg); 123 Args.push_back(Def->getOperand(1)); 124 Types.push_back(ElementTy); 125 std::string VecTypeName = "v1" + TypeName; 126 Name = Name.replace(Name.find(VecTypeName), VecTypeName.length(), TypeName); 127 NeedToReplace = true; 128 } else { 129 Args.push_back(Arg); 130 Types.push_back(Arg->getType()); 131 } 132 } 133 134 if (!NeedToReplace) { 135 return; 136 } 137 Function *NewF = Mod->getFunction(Name); 138 if (!NewF) { 139 NewF = Function::Create(FunctionType::get(F->getReturnType(), Types, false), GlobalValue::ExternalLinkage, Name, Mod); 140 NewF->setAttributes(F->getAttributes()); 141 } 142 I.replaceAllUsesWith(Builder.CreateCall(NewF, Args)); 143 I.eraseFromParent(); 144 } 145 146 void SITypeRewriter::visitBitCast(BitCastInst &I) { 147 IRBuilder<> Builder(&I); 148 if (I.getDestTy() != v4i32) { 149 return; 150 } 151 152 if (BitCastInst *Op = dyn_cast<BitCastInst>(I.getOperand(0))) { 153 if (Op->getSrcTy() == v4i32) { 154 I.replaceAllUsesWith(Op->getOperand(0)); 155 I.eraseFromParent(); 156 } 157 } 158 } 159 160 FunctionPass *llvm::createSITypeRewriter() { 161 return new SITypeRewriter(); 162 } 163