Home | History | Annotate | Download | only in AMDGPU
      1 //===-- AMDGPUPromoteAlloca.cpp - Promote Allocas -------------------------===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 // This pass eliminates allocas by either converting them into vectors or
     11 // by migrating them to local address space.
     12 //
     13 //===----------------------------------------------------------------------===//
     14 
     15 #include "AMDGPU.h"
     16 #include "AMDGPUSubtarget.h"
     17 #include "llvm/Analysis/ValueTracking.h"
     18 #include "llvm/IR/IRBuilder.h"
     19 #include "llvm/IR/InstVisitor.h"
     20 #include "llvm/Support/Debug.h"
     21 #include "llvm/Support/raw_ostream.h"
     22 
     23 #define DEBUG_TYPE "amdgpu-promote-alloca"
     24 
     25 using namespace llvm;
     26 
     27 namespace {
     28 
     29 class AMDGPUPromoteAlloca : public FunctionPass,
     30                        public InstVisitor<AMDGPUPromoteAlloca> {
     31 
     32   static char ID;
     33   Module *Mod;
     34   const AMDGPUSubtarget &ST;
     35   int LocalMemAvailable;
     36 
     37 public:
     38   AMDGPUPromoteAlloca(const AMDGPUSubtarget &st) : FunctionPass(ID), ST(st),
     39                                                    LocalMemAvailable(0) { }
     40   bool doInitialization(Module &M) override;
     41   bool runOnFunction(Function &F) override;
     42   const char *getPassName() const override { return "AMDGPU Promote Alloca"; }
     43   void visitAlloca(AllocaInst &I);
     44 };
     45 
     46 } // End anonymous namespace
     47 
     48 char AMDGPUPromoteAlloca::ID = 0;
     49 
     50 bool AMDGPUPromoteAlloca::doInitialization(Module &M) {
     51   Mod = &M;
     52   return false;
     53 }
     54 
     55 bool AMDGPUPromoteAlloca::runOnFunction(Function &F) {
     56 
     57   FunctionType *FTy = F.getFunctionType();
     58 
     59   LocalMemAvailable = ST.getLocalMemorySize();
     60 
     61 
     62   // If the function has any arguments in the local address space, then it's
     63   // possible these arguments require the entire local memory space, so
     64   // we cannot use local memory in the pass.
     65   for (unsigned i = 0, e = FTy->getNumParams(); i != e; ++i) {
     66     Type *ParamTy = FTy->getParamType(i);
     67     if (ParamTy->isPointerTy() &&
     68         ParamTy->getPointerAddressSpace() == AMDGPUAS::LOCAL_ADDRESS) {
     69       LocalMemAvailable = 0;
     70       DEBUG(dbgs() << "Function has local memory argument.  Promoting to "
     71                       "local memory disabled.\n");
     72       break;
     73     }
     74   }
     75 
     76   if (LocalMemAvailable > 0) {
     77     // Check how much local memory is being used by global objects
     78     for (Module::global_iterator I = Mod->global_begin(),
     79                                  E = Mod->global_end(); I != E; ++I) {
     80       GlobalVariable *GV = &*I;
     81       PointerType *GVTy = GV->getType();
     82       if (GVTy->getAddressSpace() != AMDGPUAS::LOCAL_ADDRESS)
     83         continue;
     84       for (Value::use_iterator U = GV->use_begin(),
     85                                UE = GV->use_end(); U != UE; ++U) {
     86         Instruction *Use = dyn_cast<Instruction>(*U);
     87         if (!Use)
     88           continue;
     89         if (Use->getParent()->getParent() == &F)
     90           LocalMemAvailable -=
     91               Mod->getDataLayout().getTypeAllocSize(GVTy->getElementType());
     92       }
     93     }
     94   }
     95 
     96   LocalMemAvailable = std::max(0, LocalMemAvailable);
     97   DEBUG(dbgs() << LocalMemAvailable << "bytes free in local memory.\n");
     98 
     99   visit(F);
    100 
    101   return false;
    102 }
    103 
    104 static VectorType *arrayTypeToVecType(Type *ArrayTy) {
    105   return VectorType::get(ArrayTy->getArrayElementType(),
    106                          ArrayTy->getArrayNumElements());
    107 }
    108 
    109 static Value *
    110 calculateVectorIndex(Value *Ptr,
    111                      const std::map<GetElementPtrInst *, Value *> &GEPIdx) {
    112   if (isa<AllocaInst>(Ptr))
    113     return Constant::getNullValue(Type::getInt32Ty(Ptr->getContext()));
    114 
    115   GetElementPtrInst *GEP = cast<GetElementPtrInst>(Ptr);
    116 
    117   auto I = GEPIdx.find(GEP);
    118   return I == GEPIdx.end() ? nullptr : I->second;
    119 }
    120 
    121 static Value* GEPToVectorIndex(GetElementPtrInst *GEP) {
    122   // FIXME we only support simple cases
    123   if (GEP->getNumOperands() != 3)
    124     return NULL;
    125 
    126   ConstantInt *I0 = dyn_cast<ConstantInt>(GEP->getOperand(1));
    127   if (!I0 || !I0->isZero())
    128     return NULL;
    129 
    130   return GEP->getOperand(2);
    131 }
    132 
    133 // Not an instruction handled below to turn into a vector.
    134 //
    135 // TODO: Check isTriviallyVectorizable for calls and handle other
    136 // instructions.
    137 static bool canVectorizeInst(Instruction *Inst, User *User) {
    138   switch (Inst->getOpcode()) {
    139   case Instruction::Load:
    140   case Instruction::BitCast:
    141   case Instruction::AddrSpaceCast:
    142     return true;
    143   case Instruction::Store: {
    144     // Must be the stored pointer operand, not a stored value.
    145     StoreInst *SI = cast<StoreInst>(Inst);
    146     return SI->getPointerOperand() == User;
    147   }
    148   default:
    149     return false;
    150   }
    151 }
    152 
    153 static bool tryPromoteAllocaToVector(AllocaInst *Alloca) {
    154   Type *AllocaTy = Alloca->getAllocatedType();
    155 
    156   DEBUG(dbgs() << "Alloca Candidate for vectorization \n");
    157 
    158   // FIXME: There is no reason why we can't support larger arrays, we
    159   // are just being conservative for now.
    160   if (!AllocaTy->isArrayTy() ||
    161       AllocaTy->getArrayElementType()->isVectorTy() ||
    162       AllocaTy->getArrayNumElements() > 4) {
    163 
    164     DEBUG(dbgs() << "  Cannot convert type to vector");
    165     return false;
    166   }
    167 
    168   std::map<GetElementPtrInst*, Value*> GEPVectorIdx;
    169   std::vector<Value*> WorkList;
    170   for (User *AllocaUser : Alloca->users()) {
    171     GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(AllocaUser);
    172     if (!GEP) {
    173       if (!canVectorizeInst(cast<Instruction>(AllocaUser), Alloca))
    174         return false;
    175 
    176       WorkList.push_back(AllocaUser);
    177       continue;
    178     }
    179 
    180     Value *Index = GEPToVectorIndex(GEP);
    181 
    182     // If we can't compute a vector index from this GEP, then we can't
    183     // promote this alloca to vector.
    184     if (!Index) {
    185       DEBUG(dbgs() << "  Cannot compute vector index for GEP " << *GEP << '\n');
    186       return false;
    187     }
    188 
    189     GEPVectorIdx[GEP] = Index;
    190     for (User *GEPUser : AllocaUser->users()) {
    191       if (!canVectorizeInst(cast<Instruction>(GEPUser), AllocaUser))
    192         return false;
    193 
    194       WorkList.push_back(GEPUser);
    195     }
    196   }
    197 
    198   VectorType *VectorTy = arrayTypeToVecType(AllocaTy);
    199 
    200   DEBUG(dbgs() << "  Converting alloca to vector "
    201         << *AllocaTy << " -> " << *VectorTy << '\n');
    202 
    203   for (std::vector<Value*>::iterator I = WorkList.begin(),
    204                                      E = WorkList.end(); I != E; ++I) {
    205     Instruction *Inst = cast<Instruction>(*I);
    206     IRBuilder<> Builder(Inst);
    207     switch (Inst->getOpcode()) {
    208     case Instruction::Load: {
    209       Value *Ptr = Inst->getOperand(0);
    210       Value *Index = calculateVectorIndex(Ptr, GEPVectorIdx);
    211       Value *BitCast = Builder.CreateBitCast(Alloca, VectorTy->getPointerTo(0));
    212       Value *VecValue = Builder.CreateLoad(BitCast);
    213       Value *ExtractElement = Builder.CreateExtractElement(VecValue, Index);
    214       Inst->replaceAllUsesWith(ExtractElement);
    215       Inst->eraseFromParent();
    216       break;
    217     }
    218     case Instruction::Store: {
    219       Value *Ptr = Inst->getOperand(1);
    220       Value *Index = calculateVectorIndex(Ptr, GEPVectorIdx);
    221       Value *BitCast = Builder.CreateBitCast(Alloca, VectorTy->getPointerTo(0));
    222       Value *VecValue = Builder.CreateLoad(BitCast);
    223       Value *NewVecValue = Builder.CreateInsertElement(VecValue,
    224                                                        Inst->getOperand(0),
    225                                                        Index);
    226       Builder.CreateStore(NewVecValue, BitCast);
    227       Inst->eraseFromParent();
    228       break;
    229     }
    230     case Instruction::BitCast:
    231     case Instruction::AddrSpaceCast:
    232       break;
    233 
    234     default:
    235       Inst->dump();
    236       llvm_unreachable("Inconsistency in instructions promotable to vector");
    237     }
    238   }
    239   return true;
    240 }
    241 
    242 static bool collectUsesWithPtrTypes(Value *Val, std::vector<Value*> &WorkList) {
    243   bool Success = true;
    244   for (User *User : Val->users()) {
    245     if(std::find(WorkList.begin(), WorkList.end(), User) != WorkList.end())
    246       continue;
    247     if (CallInst *CI = dyn_cast<CallInst>(User)) {
    248       // TODO: We might be able to handle some cases where the callee is a
    249       // constantexpr bitcast of a function.
    250       if (!CI->getCalledFunction())
    251         return false;
    252 
    253       WorkList.push_back(User);
    254       continue;
    255     }
    256 
    257     // FIXME: Correctly handle ptrtoint instructions.
    258     Instruction *UseInst = dyn_cast<Instruction>(User);
    259     if (UseInst && UseInst->getOpcode() == Instruction::PtrToInt)
    260       return false;
    261 
    262     if (StoreInst *SI = dyn_cast_or_null<StoreInst>(UseInst)) {
    263       // Reject if the stored value is not the pointer operand.
    264       if (SI->getPointerOperand() != Val)
    265         return false;
    266     }
    267 
    268     if (!User->getType()->isPointerTy())
    269       continue;
    270 
    271     WorkList.push_back(User);
    272 
    273     Success &= collectUsesWithPtrTypes(User, WorkList);
    274   }
    275   return Success;
    276 }
    277 
    278 void AMDGPUPromoteAlloca::visitAlloca(AllocaInst &I) {
    279   if (!I.isStaticAlloca())
    280     return;
    281 
    282   IRBuilder<> Builder(&I);
    283 
    284   // First try to replace the alloca with a vector
    285   Type *AllocaTy = I.getAllocatedType();
    286 
    287   DEBUG(dbgs() << "Trying to promote " << I << '\n');
    288 
    289   if (tryPromoteAllocaToVector(&I))
    290     return;
    291 
    292   DEBUG(dbgs() << " alloca is not a candidate for vectorization.\n");
    293 
    294   // FIXME: This is the maximum work group size.  We should try to get
    295   // value from the reqd_work_group_size function attribute if it is
    296   // available.
    297   unsigned WorkGroupSize = 256;
    298   int AllocaSize =
    299       WorkGroupSize * Mod->getDataLayout().getTypeAllocSize(AllocaTy);
    300 
    301   if (AllocaSize > LocalMemAvailable) {
    302     DEBUG(dbgs() << " Not enough local memory to promote alloca.\n");
    303     return;
    304   }
    305 
    306   std::vector<Value*> WorkList;
    307 
    308   if (!collectUsesWithPtrTypes(&I, WorkList)) {
    309     DEBUG(dbgs() << " Do not know how to convert all uses\n");
    310     return;
    311   }
    312 
    313   DEBUG(dbgs() << "Promoting alloca to local memory\n");
    314   LocalMemAvailable -= AllocaSize;
    315 
    316   Type *GVTy = ArrayType::get(I.getAllocatedType(), 256);
    317   GlobalVariable *GV = new GlobalVariable(
    318       *Mod, GVTy, false, GlobalValue::ExternalLinkage, 0, I.getName(), 0,
    319       GlobalVariable::NotThreadLocal, AMDGPUAS::LOCAL_ADDRESS);
    320 
    321   FunctionType *FTy = FunctionType::get(
    322       Type::getInt32Ty(Mod->getContext()), false);
    323   AttributeSet AttrSet;
    324   AttrSet.addAttribute(Mod->getContext(), 0, Attribute::ReadNone);
    325 
    326   Value *ReadLocalSizeY = Mod->getOrInsertFunction(
    327       "llvm.r600.read.local.size.y", FTy, AttrSet);
    328   Value *ReadLocalSizeZ = Mod->getOrInsertFunction(
    329       "llvm.r600.read.local.size.z", FTy, AttrSet);
    330   Value *ReadTIDIGX = Mod->getOrInsertFunction(
    331       "llvm.r600.read.tidig.x", FTy, AttrSet);
    332   Value *ReadTIDIGY = Mod->getOrInsertFunction(
    333       "llvm.r600.read.tidig.y", FTy, AttrSet);
    334   Value *ReadTIDIGZ = Mod->getOrInsertFunction(
    335       "llvm.r600.read.tidig.z", FTy, AttrSet);
    336 
    337   Value *TCntY = Builder.CreateCall(ReadLocalSizeY, {});
    338   Value *TCntZ = Builder.CreateCall(ReadLocalSizeZ, {});
    339   Value *TIdX = Builder.CreateCall(ReadTIDIGX, {});
    340   Value *TIdY = Builder.CreateCall(ReadTIDIGY, {});
    341   Value *TIdZ = Builder.CreateCall(ReadTIDIGZ, {});
    342 
    343   Value *Tmp0 = Builder.CreateMul(TCntY, TCntZ);
    344   Tmp0 = Builder.CreateMul(Tmp0, TIdX);
    345   Value *Tmp1 = Builder.CreateMul(TIdY, TCntZ);
    346   Value *TID = Builder.CreateAdd(Tmp0, Tmp1);
    347   TID = Builder.CreateAdd(TID, TIdZ);
    348 
    349   std::vector<Value*> Indices;
    350   Indices.push_back(Constant::getNullValue(Type::getInt32Ty(Mod->getContext())));
    351   Indices.push_back(TID);
    352 
    353   Value *Offset = Builder.CreateGEP(GVTy, GV, Indices);
    354   I.mutateType(Offset->getType());
    355   I.replaceAllUsesWith(Offset);
    356   I.eraseFromParent();
    357 
    358   for (std::vector<Value*>::iterator i = WorkList.begin(),
    359                                      e = WorkList.end(); i != e; ++i) {
    360     Value *V = *i;
    361     CallInst *Call = dyn_cast<CallInst>(V);
    362     if (!Call) {
    363       Type *EltTy = V->getType()->getPointerElementType();
    364       PointerType *NewTy = PointerType::get(EltTy, AMDGPUAS::LOCAL_ADDRESS);
    365 
    366       // The operand's value should be corrected on its own.
    367       if (isa<AddrSpaceCastInst>(V))
    368         continue;
    369 
    370       // FIXME: It doesn't really make sense to try to do this for all
    371       // instructions.
    372       V->mutateType(NewTy);
    373       continue;
    374     }
    375 
    376     IntrinsicInst *Intr = dyn_cast<IntrinsicInst>(Call);
    377     if (!Intr) {
    378       std::vector<Type*> ArgTypes;
    379       for (unsigned ArgIdx = 0, ArgEnd = Call->getNumArgOperands();
    380                                 ArgIdx != ArgEnd; ++ArgIdx) {
    381         ArgTypes.push_back(Call->getArgOperand(ArgIdx)->getType());
    382       }
    383       Function *F = Call->getCalledFunction();
    384       FunctionType *NewType = FunctionType::get(Call->getType(), ArgTypes,
    385                                                 F->isVarArg());
    386       Constant *C = Mod->getOrInsertFunction((F->getName() + ".local").str(),
    387                                              NewType, F->getAttributes());
    388       Function *NewF = cast<Function>(C);
    389       Call->setCalledFunction(NewF);
    390       continue;
    391     }
    392 
    393     Builder.SetInsertPoint(Intr);
    394     switch (Intr->getIntrinsicID()) {
    395     case Intrinsic::lifetime_start:
    396     case Intrinsic::lifetime_end:
    397       // These intrinsics are for address space 0 only
    398       Intr->eraseFromParent();
    399       continue;
    400     case Intrinsic::memcpy: {
    401       MemCpyInst *MemCpy = cast<MemCpyInst>(Intr);
    402       Builder.CreateMemCpy(MemCpy->getRawDest(), MemCpy->getRawSource(),
    403                            MemCpy->getLength(), MemCpy->getAlignment(),
    404                            MemCpy->isVolatile());
    405       Intr->eraseFromParent();
    406       continue;
    407     }
    408     case Intrinsic::memset: {
    409       MemSetInst *MemSet = cast<MemSetInst>(Intr);
    410       Builder.CreateMemSet(MemSet->getRawDest(), MemSet->getValue(),
    411                            MemSet->getLength(), MemSet->getAlignment(),
    412                            MemSet->isVolatile());
    413       Intr->eraseFromParent();
    414       continue;
    415     }
    416     default:
    417       Intr->dump();
    418       llvm_unreachable("Don't know how to promote alloca intrinsic use.");
    419     }
    420   }
    421 }
    422 
    423 FunctionPass *llvm::createAMDGPUPromoteAlloca(const AMDGPUSubtarget &ST) {
    424   return new AMDGPUPromoteAlloca(ST);
    425 }
    426