Home | History | Annotate | Download | only in AMDGPU
      1 //===- AMDGPURewriteOutArgumentsPass.cpp - Create struct returns ----------===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 /// \file This pass attempts to replace out argument usage with a return of a
     11 /// struct.
     12 ///
     13 /// We can support returning a lot of values directly in registers, but
     14 /// idiomatic C code frequently uses a pointer argument to return a second value
     15 /// rather than returning a struct by value. GPU stack access is also quite
     16 /// painful, so we want to avoid that if possible. Passing a stack object
     17 /// pointer to a function also requires an additional address expansion code
     18 /// sequence to convert the pointer to be relative to the kernel's scratch wave
     19 /// offset register since the callee doesn't know what stack frame the incoming
     20 /// pointer is relative to.
     21 ///
     22 /// The goal is to try rewriting code that looks like this:
     23 ///
     24 ///  int foo(int a, int b, int* out) {
     25 ///     *out = bar();
     26 ///     return a + b;
     27 /// }
     28 ///
     29 /// into something like this:
     30 ///
     31 ///  std::pair<int, int> foo(int a, int b) {
     32 ///     return std::make_pair(a + b, bar());
     33 /// }
     34 ///
     35 /// Typically the incoming pointer is a simple alloca for a temporary variable
     36 /// to use the API, which if replaced with a struct return will be easily SROA'd
     37 /// out when the stub function we create is inlined.
     38 ///
     39 /// This pass introduces the struct return, but leaves the unused pointer
     40 /// arguments and introduces a new stub function calling the struct returning
     41 /// body. DeadArgumentElimination should be run after this to clean these up.
     42 //
     43 //===----------------------------------------------------------------------===//
     44 
     45 #include "AMDGPU.h"
     46 #include "Utils/AMDGPUBaseInfo.h"
     47 #include "llvm/Analysis/MemoryDependenceAnalysis.h"
     48 #include "llvm/ADT/DenseMap.h"
     49 #include "llvm/ADT/STLExtras.h"
     50 #include "llvm/ADT/SmallSet.h"
     51 #include "llvm/ADT/SmallVector.h"
     52 #include "llvm/ADT/Statistic.h"
     53 #include "llvm/Analysis/MemoryLocation.h"
     54 #include "llvm/IR/Argument.h"
     55 #include "llvm/IR/Attributes.h"
     56 #include "llvm/IR/BasicBlock.h"
     57 #include "llvm/IR/Constants.h"
     58 #include "llvm/IR/DataLayout.h"
     59 #include "llvm/IR/DerivedTypes.h"
     60 #include "llvm/IR/Function.h"
     61 #include "llvm/IR/IRBuilder.h"
     62 #include "llvm/IR/Instructions.h"
     63 #include "llvm/IR/Module.h"
     64 #include "llvm/IR/Type.h"
     65 #include "llvm/IR/Use.h"
     66 #include "llvm/IR/User.h"
     67 #include "llvm/IR/Value.h"
     68 #include "llvm/Pass.h"
     69 #include "llvm/Support/Casting.h"
     70 #include "llvm/Support/CommandLine.h"
     71 #include "llvm/Support/Debug.h"
     72 #include "llvm/Support/raw_ostream.h"
     73 #include <cassert>
     74 #include <utility>
     75 
     76 #define DEBUG_TYPE "amdgpu-rewrite-out-arguments"
     77 
     78 using namespace llvm;
     79 
     80 static cl::opt<bool> AnyAddressSpace(
     81   "amdgpu-any-address-space-out-arguments",
     82   cl::desc("Replace pointer out arguments with "
     83            "struct returns for non-private address space"),
     84   cl::Hidden,
     85   cl::init(false));
     86 
     87 static cl::opt<unsigned> MaxNumRetRegs(
     88   "amdgpu-max-return-arg-num-regs",
     89   cl::desc("Approximately limit number of return registers for replacing out arguments"),
     90   cl::Hidden,
     91   cl::init(16));
     92 
     93 STATISTIC(NumOutArgumentsReplaced,
     94           "Number out arguments moved to struct return values");
     95 STATISTIC(NumOutArgumentFunctionsReplaced,
     96           "Number of functions with out arguments moved to struct return values");
     97 
     98 namespace {
     99 
    100 class AMDGPURewriteOutArguments : public FunctionPass {
    101 private:
    102   const DataLayout *DL = nullptr;
    103   MemoryDependenceResults *MDA = nullptr;
    104 
    105   bool checkArgumentUses(Value &Arg) const;
    106   bool isOutArgumentCandidate(Argument &Arg) const;
    107 
    108 #ifndef NDEBUG
    109   bool isVec3ToVec4Shuffle(Type *Ty0, Type* Ty1) const;
    110 #endif
    111 
    112 public:
    113   static char ID;
    114 
    115   AMDGPURewriteOutArguments() : FunctionPass(ID) {}
    116 
    117   void getAnalysisUsage(AnalysisUsage &AU) const override {
    118     AU.addRequired<MemoryDependenceWrapperPass>();
    119     FunctionPass::getAnalysisUsage(AU);
    120   }
    121 
    122   bool doInitialization(Module &M) override;
    123   bool runOnFunction(Function &F) override;
    124 };
    125 
    126 } // end anonymous namespace
    127 
    128 INITIALIZE_PASS_BEGIN(AMDGPURewriteOutArguments, DEBUG_TYPE,
    129                       "AMDGPU Rewrite Out Arguments", false, false)
    130 INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass)
    131 INITIALIZE_PASS_END(AMDGPURewriteOutArguments, DEBUG_TYPE,
    132                     "AMDGPU Rewrite Out Arguments", false, false)
    133 
    134 char AMDGPURewriteOutArguments::ID = 0;
    135 
    136 bool AMDGPURewriteOutArguments::checkArgumentUses(Value &Arg) const {
    137   const int MaxUses = 10;
    138   int UseCount = 0;
    139 
    140   for (Use &U : Arg.uses()) {
    141     StoreInst *SI = dyn_cast<StoreInst>(U.getUser());
    142     if (UseCount > MaxUses)
    143       return false;
    144 
    145     if (!SI) {
    146       auto *BCI = dyn_cast<BitCastInst>(U.getUser());
    147       if (!BCI || !BCI->hasOneUse())
    148         return false;
    149 
    150       // We don't handle multiple stores currently, so stores to aggregate
    151       // pointers aren't worth the trouble since they are canonically split up.
    152       Type *DestEltTy = BCI->getType()->getPointerElementType();
    153       if (DestEltTy->isAggregateType())
    154         return false;
    155 
    156       // We could handle these if we had a convenient way to bitcast between
    157       // them.
    158       Type *SrcEltTy = Arg.getType()->getPointerElementType();
    159       if (SrcEltTy->isArrayTy())
    160         return false;
    161 
    162       // Special case handle structs with single members. It is useful to handle
    163       // some casts between structs and non-structs, but we can't bitcast
    164       // directly between them.  directly bitcast between them.  Blender uses
    165       // some casts that look like { <3 x float> }* to <4 x float>*
    166       if ((SrcEltTy->isStructTy() && (SrcEltTy->getNumContainedTypes() != 1)))
    167         return false;
    168 
    169       // Clang emits OpenCL 3-vector type accesses with a bitcast to the
    170       // equivalent 4-element vector and accesses that, and we're looking for
    171       // this pointer cast.
    172       if (DL->getTypeAllocSize(SrcEltTy) != DL->getTypeAllocSize(DestEltTy))
    173         return false;
    174 
    175       return checkArgumentUses(*BCI);
    176     }
    177 
    178     if (!SI->isSimple() ||
    179         U.getOperandNo() != StoreInst::getPointerOperandIndex())
    180       return false;
    181 
    182     ++UseCount;
    183   }
    184 
    185   // Skip unused arguments.
    186   return UseCount > 0;
    187 }
    188 
    189 bool AMDGPURewriteOutArguments::isOutArgumentCandidate(Argument &Arg) const {
    190   const unsigned MaxOutArgSizeBytes = 4 * MaxNumRetRegs;
    191   PointerType *ArgTy = dyn_cast<PointerType>(Arg.getType());
    192 
    193   // TODO: It might be useful for any out arguments, not just privates.
    194   if (!ArgTy || (ArgTy->getAddressSpace() != DL->getAllocaAddrSpace() &&
    195                  !AnyAddressSpace) ||
    196       Arg.hasByValAttr() || Arg.hasStructRetAttr() ||
    197       DL->getTypeStoreSize(ArgTy->getPointerElementType()) > MaxOutArgSizeBytes) {
    198     return false;
    199   }
    200 
    201   return checkArgumentUses(Arg);
    202 }
    203 
    204 bool AMDGPURewriteOutArguments::doInitialization(Module &M) {
    205   DL = &M.getDataLayout();
    206   return false;
    207 }
    208 
    209 #ifndef NDEBUG
    210 bool AMDGPURewriteOutArguments::isVec3ToVec4Shuffle(Type *Ty0, Type* Ty1) const {
    211   VectorType *VT0 = dyn_cast<VectorType>(Ty0);
    212   VectorType *VT1 = dyn_cast<VectorType>(Ty1);
    213   if (!VT0 || !VT1)
    214     return false;
    215 
    216   if (VT0->getNumElements() != 3 ||
    217       VT1->getNumElements() != 4)
    218     return false;
    219 
    220   return DL->getTypeSizeInBits(VT0->getElementType()) ==
    221          DL->getTypeSizeInBits(VT1->getElementType());
    222 }
    223 #endif
    224 
    225 bool AMDGPURewriteOutArguments::runOnFunction(Function &F) {
    226   if (skipFunction(F))
    227     return false;
    228 
    229   // TODO: Could probably handle variadic functions.
    230   if (F.isVarArg() || F.hasStructRetAttr() ||
    231       AMDGPU::isEntryFunctionCC(F.getCallingConv()))
    232     return false;
    233 
    234   MDA = &getAnalysis<MemoryDependenceWrapperPass>().getMemDep();
    235 
    236   unsigned ReturnNumRegs = 0;
    237   SmallSet<int, 4> OutArgIndexes;
    238   SmallVector<Type *, 4> ReturnTypes;
    239   Type *RetTy = F.getReturnType();
    240   if (!RetTy->isVoidTy()) {
    241     ReturnNumRegs = DL->getTypeStoreSize(RetTy) / 4;
    242 
    243     if (ReturnNumRegs >= MaxNumRetRegs)
    244       return false;
    245 
    246     ReturnTypes.push_back(RetTy);
    247   }
    248 
    249   SmallVector<Argument *, 4> OutArgs;
    250   for (Argument &Arg : F.args()) {
    251     if (isOutArgumentCandidate(Arg)) {
    252       LLVM_DEBUG(dbgs() << "Found possible out argument " << Arg
    253                         << " in function " << F.getName() << '\n');
    254       OutArgs.push_back(&Arg);
    255     }
    256   }
    257 
    258   if (OutArgs.empty())
    259     return false;
    260 
    261   using ReplacementVec = SmallVector<std::pair<Argument *, Value *>, 4>;
    262 
    263   DenseMap<ReturnInst *, ReplacementVec> Replacements;
    264 
    265   SmallVector<ReturnInst *, 4> Returns;
    266   for (BasicBlock &BB : F) {
    267     if (ReturnInst *RI = dyn_cast<ReturnInst>(&BB.back()))
    268       Returns.push_back(RI);
    269   }
    270 
    271   if (Returns.empty())
    272     return false;
    273 
    274   bool Changing;
    275 
    276   do {
    277     Changing = false;
    278 
    279     // Keep retrying if we are able to successfully eliminate an argument. This
    280     // helps with cases with multiple arguments which may alias, such as in a
    281     // sincos implemntation. If we have 2 stores to arguments, on the first
    282     // attempt the MDA query will succeed for the second store but not the
    283     // first. On the second iteration we've removed that out clobbering argument
    284     // (by effectively moving it into another function) and will find the second
    285     // argument is OK to move.
    286     for (Argument *OutArg : OutArgs) {
    287       bool ThisReplaceable = true;
    288       SmallVector<std::pair<ReturnInst *, StoreInst *>, 4> ReplaceableStores;
    289 
    290       Type *ArgTy = OutArg->getType()->getPointerElementType();
    291 
    292       // Skip this argument if converting it will push us over the register
    293       // count to return limit.
    294 
    295       // TODO: This is an approximation. When legalized this could be more. We
    296       // can ask TLI for exactly how many.
    297       unsigned ArgNumRegs = DL->getTypeStoreSize(ArgTy) / 4;
    298       if (ArgNumRegs + ReturnNumRegs > MaxNumRetRegs)
    299         continue;
    300 
    301       // An argument is convertible only if all exit blocks are able to replace
    302       // it.
    303       for (ReturnInst *RI : Returns) {
    304         BasicBlock *BB = RI->getParent();
    305 
    306         MemDepResult Q = MDA->getPointerDependencyFrom(MemoryLocation(OutArg),
    307                                                        true, BB->end(), BB, RI);
    308         StoreInst *SI = nullptr;
    309         if (Q.isDef())
    310           SI = dyn_cast<StoreInst>(Q.getInst());
    311 
    312         if (SI) {
    313           LLVM_DEBUG(dbgs() << "Found out argument store: " << *SI << '\n');
    314           ReplaceableStores.emplace_back(RI, SI);
    315         } else {
    316           ThisReplaceable = false;
    317           break;
    318         }
    319       }
    320 
    321       if (!ThisReplaceable)
    322         continue; // Try the next argument candidate.
    323 
    324       for (std::pair<ReturnInst *, StoreInst *> Store : ReplaceableStores) {
    325         Value *ReplVal = Store.second->getValueOperand();
    326 
    327         auto &ValVec = Replacements[Store.first];
    328         if (llvm::find_if(ValVec,
    329               [OutArg](const std::pair<Argument *, Value *> &Entry) {
    330                  return Entry.first == OutArg;}) != ValVec.end()) {
    331           LLVM_DEBUG(dbgs()
    332                      << "Saw multiple out arg stores" << *OutArg << '\n');
    333           // It is possible to see stores to the same argument multiple times,
    334           // but we expect these would have been optimized out already.
    335           ThisReplaceable = false;
    336           break;
    337         }
    338 
    339         ValVec.emplace_back(OutArg, ReplVal);
    340         Store.second->eraseFromParent();
    341       }
    342 
    343       if (ThisReplaceable) {
    344         ReturnTypes.push_back(ArgTy);
    345         OutArgIndexes.insert(OutArg->getArgNo());
    346         ++NumOutArgumentsReplaced;
    347         Changing = true;
    348       }
    349     }
    350   } while (Changing);
    351 
    352   if (Replacements.empty())
    353     return false;
    354 
    355   LLVMContext &Ctx = F.getParent()->getContext();
    356   StructType *NewRetTy = StructType::create(Ctx, ReturnTypes, F.getName());
    357 
    358   FunctionType *NewFuncTy = FunctionType::get(NewRetTy,
    359                                               F.getFunctionType()->params(),
    360                                               F.isVarArg());
    361 
    362   LLVM_DEBUG(dbgs() << "Computed new return type: " << *NewRetTy << '\n');
    363 
    364   Function *NewFunc = Function::Create(NewFuncTy, Function::PrivateLinkage,
    365                                        F.getName() + ".body");
    366   F.getParent()->getFunctionList().insert(F.getIterator(), NewFunc);
    367   NewFunc->copyAttributesFrom(&F);
    368   NewFunc->setComdat(F.getComdat());
    369 
    370   // We want to preserve the function and param attributes, but need to strip
    371   // off any return attributes, e.g. zeroext doesn't make sense with a struct.
    372   NewFunc->stealArgumentListFrom(F);
    373 
    374   AttrBuilder RetAttrs;
    375   RetAttrs.addAttribute(Attribute::SExt);
    376   RetAttrs.addAttribute(Attribute::ZExt);
    377   RetAttrs.addAttribute(Attribute::NoAlias);
    378   NewFunc->removeAttributes(AttributeList::ReturnIndex, RetAttrs);
    379   // TODO: How to preserve metadata?
    380 
    381   // Move the body of the function into the new rewritten function, and replace
    382   // this function with a stub.
    383   NewFunc->getBasicBlockList().splice(NewFunc->begin(), F.getBasicBlockList());
    384 
    385   for (std::pair<ReturnInst *, ReplacementVec> &Replacement : Replacements) {
    386     ReturnInst *RI = Replacement.first;
    387     IRBuilder<> B(RI);
    388     B.SetCurrentDebugLocation(RI->getDebugLoc());
    389 
    390     int RetIdx = 0;
    391     Value *NewRetVal = UndefValue::get(NewRetTy);
    392 
    393     Value *RetVal = RI->getReturnValue();
    394     if (RetVal)
    395       NewRetVal = B.CreateInsertValue(NewRetVal, RetVal, RetIdx++);
    396 
    397     for (std::pair<Argument *, Value *> ReturnPoint : Replacement.second) {
    398       Argument *Arg = ReturnPoint.first;
    399       Value *Val = ReturnPoint.second;
    400       Type *EltTy = Arg->getType()->getPointerElementType();
    401       if (Val->getType() != EltTy) {
    402         Type *EffectiveEltTy = EltTy;
    403         if (StructType *CT = dyn_cast<StructType>(EltTy)) {
    404           assert(CT->getNumContainedTypes() == 1);
    405           EffectiveEltTy = CT->getContainedType(0);
    406         }
    407 
    408         if (DL->getTypeSizeInBits(EffectiveEltTy) !=
    409             DL->getTypeSizeInBits(Val->getType())) {
    410           assert(isVec3ToVec4Shuffle(EffectiveEltTy, Val->getType()));
    411           Val = B.CreateShuffleVector(Val, UndefValue::get(Val->getType()),
    412                                       { 0, 1, 2 });
    413         }
    414 
    415         Val = B.CreateBitCast(Val, EffectiveEltTy);
    416 
    417         // Re-create single element composite.
    418         if (EltTy != EffectiveEltTy)
    419           Val = B.CreateInsertValue(UndefValue::get(EltTy), Val, 0);
    420       }
    421 
    422       NewRetVal = B.CreateInsertValue(NewRetVal, Val, RetIdx++);
    423     }
    424 
    425     if (RetVal)
    426       RI->setOperand(0, NewRetVal);
    427     else {
    428       B.CreateRet(NewRetVal);
    429       RI->eraseFromParent();
    430     }
    431   }
    432 
    433   SmallVector<Value *, 16> StubCallArgs;
    434   for (Argument &Arg : F.args()) {
    435     if (OutArgIndexes.count(Arg.getArgNo())) {
    436       // It's easier to preserve the type of the argument list. We rely on
    437       // DeadArgumentElimination to take care of these.
    438       StubCallArgs.push_back(UndefValue::get(Arg.getType()));
    439     } else {
    440       StubCallArgs.push_back(&Arg);
    441     }
    442   }
    443 
    444   BasicBlock *StubBB = BasicBlock::Create(Ctx, "", &F);
    445   IRBuilder<> B(StubBB);
    446   CallInst *StubCall = B.CreateCall(NewFunc, StubCallArgs);
    447 
    448   int RetIdx = RetTy->isVoidTy() ? 0 : 1;
    449   for (Argument &Arg : F.args()) {
    450     if (!OutArgIndexes.count(Arg.getArgNo()))
    451       continue;
    452 
    453     PointerType *ArgType = cast<PointerType>(Arg.getType());
    454 
    455     auto *EltTy = ArgType->getElementType();
    456     unsigned Align = Arg.getParamAlignment();
    457     if (Align == 0)
    458       Align = DL->getABITypeAlignment(EltTy);
    459 
    460     Value *Val = B.CreateExtractValue(StubCall, RetIdx++);
    461     Type *PtrTy = Val->getType()->getPointerTo(ArgType->getAddressSpace());
    462 
    463     // We can peek through bitcasts, so the type may not match.
    464     Value *PtrVal = B.CreateBitCast(&Arg, PtrTy);
    465 
    466     B.CreateAlignedStore(Val, PtrVal, Align);
    467   }
    468 
    469   if (!RetTy->isVoidTy()) {
    470     B.CreateRet(B.CreateExtractValue(StubCall, 0));
    471   } else {
    472     B.CreateRetVoid();
    473   }
    474 
    475   // The function is now a stub we want to inline.
    476   F.addFnAttr(Attribute::AlwaysInline);
    477 
    478   ++NumOutArgumentFunctionsReplaced;
    479   return true;
    480 }
    481 
    482 FunctionPass *llvm::createAMDGPURewriteOutArgumentsPass() {
    483   return new AMDGPURewriteOutArguments();
    484 }
    485