Home | History | Annotate | Download | only in R600
      1 //===-- SITypeRewriter.cpp - Remove unwanted types ------------------------===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 /// \file
     11 /// This pass performs the following type substitutions on all
     12 /// non-compute shaders:
     13 ///
     14 /// v16i8 => i128
     15 ///   - v16i8 is used for constant memory resource descriptors.  This type is
     16 ///      legal for some compute APIs, and we don't want to declare it as legal
     17 ///      in the backend, because we want the legalizer to expand all v16i8
     18 ///      operations.
     19 /// v1* => *
     20 ///   - Having v1* types complicates the legalizer and we can easily replace
     21 ///     them with the element type.
     22 //===----------------------------------------------------------------------===//
     23 
     24 #include "AMDGPU.h"
     25 #include "llvm/IR/IRBuilder.h"
     26 #include "llvm/IR/InstVisitor.h"
     27 
     28 using namespace llvm;
     29 
     30 namespace {
     31 
     32 class SITypeRewriter : public FunctionPass,
     33                        public InstVisitor<SITypeRewriter> {
     34 
     35   static char ID;
     36   Module *Mod;
     37   Type *v16i8;
     38   Type *v4i32;
     39 
     40 public:
     41   SITypeRewriter() : FunctionPass(ID) { }
     42   bool doInitialization(Module &M) override;
     43   bool runOnFunction(Function &F) override;
     44   const char *getPassName() const override {
     45     return "SI Type Rewriter";
     46   }
     47   void visitLoadInst(LoadInst &I);
     48   void visitCallInst(CallInst &I);
     49   void visitBitCast(BitCastInst &I);
     50 };
     51 
     52 } // End anonymous namespace
     53 
     54 char SITypeRewriter::ID = 0;
     55 
     56 bool SITypeRewriter::doInitialization(Module &M) {
     57   Mod = &M;
     58   v16i8 = VectorType::get(Type::getInt8Ty(M.getContext()), 16);
     59   v4i32 = VectorType::get(Type::getInt32Ty(M.getContext()), 4);
     60   return false;
     61 }
     62 
     63 bool SITypeRewriter::runOnFunction(Function &F) {
     64   AttributeSet Set = F.getAttributes();
     65   Attribute A = Set.getAttribute(AttributeSet::FunctionIndex, "ShaderType");
     66 
     67   unsigned ShaderType = ShaderType::COMPUTE;
     68   if (A.isStringAttribute()) {
     69     StringRef Str = A.getValueAsString();
     70     Str.getAsInteger(0, ShaderType);
     71   }
     72   if (ShaderType == ShaderType::COMPUTE)
     73     return false;
     74 
     75   visit(F);
     76   visit(F);
     77 
     78   return false;
     79 }
     80 
     81 void SITypeRewriter::visitLoadInst(LoadInst &I) {
     82   Value *Ptr = I.getPointerOperand();
     83   Type *PtrTy = Ptr->getType();
     84   Type *ElemTy = PtrTy->getPointerElementType();
     85   IRBuilder<> Builder(&I);
     86   if (ElemTy == v16i8)  {
     87     Value *BitCast = Builder.CreateBitCast(Ptr,
     88         PointerType::get(v4i32,PtrTy->getPointerAddressSpace()));
     89     LoadInst *Load = Builder.CreateLoad(BitCast);
     90     SmallVector <std::pair<unsigned, MDNode*>, 8> MD;
     91     I.getAllMetadataOtherThanDebugLoc(MD);
     92     for (unsigned i = 0, e = MD.size(); i != e; ++i) {
     93       Load->setMetadata(MD[i].first, MD[i].second);
     94     }
     95     Value *BitCastLoad = Builder.CreateBitCast(Load, I.getType());
     96     I.replaceAllUsesWith(BitCastLoad);
     97     I.eraseFromParent();
     98   }
     99 }
    100 
    101 void SITypeRewriter::visitCallInst(CallInst &I) {
    102   IRBuilder<> Builder(&I);
    103 
    104   SmallVector <Value*, 8> Args;
    105   SmallVector <Type*, 8> Types;
    106   bool NeedToReplace = false;
    107   Function *F = I.getCalledFunction();
    108   std::string Name = F->getName().str();
    109   for (unsigned i = 0, e = I.getNumArgOperands(); i != e; ++i) {
    110     Value *Arg = I.getArgOperand(i);
    111     if (Arg->getType() == v16i8) {
    112       Args.push_back(Builder.CreateBitCast(Arg, v4i32));
    113       Types.push_back(v4i32);
    114       NeedToReplace = true;
    115       Name = Name + ".v4i32";
    116     } else if (Arg->getType()->isVectorTy() &&
    117                Arg->getType()->getVectorNumElements() == 1 &&
    118                Arg->getType()->getVectorElementType() ==
    119                                               Type::getInt32Ty(I.getContext())){
    120       Type *ElementTy = Arg->getType()->getVectorElementType();
    121       std::string TypeName = "i32";
    122       InsertElementInst *Def = cast<InsertElementInst>(Arg);
    123       Args.push_back(Def->getOperand(1));
    124       Types.push_back(ElementTy);
    125       std::string VecTypeName = "v1" + TypeName;
    126       Name = Name.replace(Name.find(VecTypeName), VecTypeName.length(), TypeName);
    127       NeedToReplace = true;
    128     } else {
    129       Args.push_back(Arg);
    130       Types.push_back(Arg->getType());
    131     }
    132   }
    133 
    134   if (!NeedToReplace) {
    135     return;
    136   }
    137   Function *NewF = Mod->getFunction(Name);
    138   if (!NewF) {
    139     NewF = Function::Create(FunctionType::get(F->getReturnType(), Types, false), GlobalValue::ExternalLinkage, Name, Mod);
    140     NewF->setAttributes(F->getAttributes());
    141   }
    142   I.replaceAllUsesWith(Builder.CreateCall(NewF, Args));
    143   I.eraseFromParent();
    144 }
    145 
    146 void SITypeRewriter::visitBitCast(BitCastInst &I) {
    147   IRBuilder<> Builder(&I);
    148   if (I.getDestTy() != v4i32) {
    149     return;
    150   }
    151 
    152   if (BitCastInst *Op = dyn_cast<BitCastInst>(I.getOperand(0))) {
    153     if (Op->getSrcTy() == v4i32) {
    154       I.replaceAllUsesWith(Op->getOperand(0));
    155       I.eraseFromParent();
    156     }
    157   }
    158 }
    159 
    160 FunctionPass *llvm::createSITypeRewriter() {
    161   return new SITypeRewriter();
    162 }
    163