Home | History | Annotate | Download | only in AMDGPU
      1 //===-- AMDGPUOpenCLImageTypeLoweringPass.cpp -----------------------------===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 /// \file
     11 /// This pass resolves calls to OpenCL image attribute, image resource ID and
     12 /// sampler resource ID getter functions.
     13 ///
     14 /// Image attributes (size and format) are expected to be passed to the kernel
     15 /// as kernel arguments immediately following the image argument itself,
     16 /// therefore this pass adds image size and format arguments to the kernel
     17 /// functions in the module. The kernel functions with image arguments are
     18 /// re-created using the new signature. The new arguments are added to the
     19 /// kernel metadata with kernel_arg_type set to "__llvm_image_size" or
     20 /// "__llvm_image_format". Note: this pass may invalidate pointers to functions.
     21 ///
     22 /// Resource IDs of read-only images, write-only images and samplers are
     23 /// defined to be their index among the kernel arguments of the same
     24 /// type and access qualifier.
     25 //===----------------------------------------------------------------------===//
     26 
     27 #include "AMDGPU.h"
     28 #include "llvm/ADT/DenseSet.h"
     29 #include "llvm/ADT/STLExtras.h"
     30 #include "llvm/ADT/SmallVector.h"
     31 #include "llvm/Analysis/Passes.h"
     32 #include "llvm/IR/Constants.h"
     33 #include "llvm/IR/Function.h"
     34 #include "llvm/IR/Instructions.h"
     35 #include "llvm/IR/Module.h"
     36 #include "llvm/Transforms/Utils/Cloning.h"
     37 
     38 using namespace llvm;
     39 
     40 namespace {
     41 
     42 StringRef GetImageSizeFunc =         "llvm.OpenCL.image.get.size";
     43 StringRef GetImageFormatFunc =       "llvm.OpenCL.image.get.format";
     44 StringRef GetImageResourceIDFunc =   "llvm.OpenCL.image.get.resource.id";
     45 StringRef GetSamplerResourceIDFunc = "llvm.OpenCL.sampler.get.resource.id";
     46 
     47 StringRef ImageSizeArgMDType =   "__llvm_image_size";
     48 StringRef ImageFormatArgMDType = "__llvm_image_format";
     49 
     50 StringRef KernelsMDNodeName = "opencl.kernels";
     51 StringRef KernelArgMDNodeNames[] = {
     52   "kernel_arg_addr_space",
     53   "kernel_arg_access_qual",
     54   "kernel_arg_type",
     55   "kernel_arg_base_type",
     56   "kernel_arg_type_qual"};
     57 const unsigned NumKernelArgMDNodes = 5;
     58 
     59 typedef SmallVector<Metadata *, 8> MDVector;
     60 struct KernelArgMD {
     61   MDVector ArgVector[NumKernelArgMDNodes];
     62 };
     63 
     64 } // end anonymous namespace
     65 
     66 static inline bool
     67 IsImageType(StringRef TypeString) {
     68   return TypeString == "image2d_t" || TypeString == "image3d_t";
     69 }
     70 
     71 static inline bool
     72 IsSamplerType(StringRef TypeString) {
     73   return TypeString == "sampler_t";
     74 }
     75 
     76 static Function *
     77 GetFunctionFromMDNode(MDNode *Node) {
     78   if (!Node)
     79     return nullptr;
     80 
     81   size_t NumOps = Node->getNumOperands();
     82   if (NumOps != NumKernelArgMDNodes + 1)
     83     return nullptr;
     84 
     85   auto F = mdconst::dyn_extract<Function>(Node->getOperand(0));
     86   if (!F)
     87     return nullptr;
     88 
     89   // Sanity checks.
     90   size_t ExpectNumArgNodeOps = F->arg_size() + 1;
     91   for (size_t i = 0; i < NumKernelArgMDNodes; ++i) {
     92     MDNode *ArgNode = dyn_cast_or_null<MDNode>(Node->getOperand(i + 1));
     93     if (ArgNode->getNumOperands() != ExpectNumArgNodeOps)
     94       return nullptr;
     95     if (!ArgNode->getOperand(0))
     96       return nullptr;
     97 
     98     // FIXME: It should be possible to do image lowering when some metadata
     99     // args missing or not in the expected order.
    100     MDString *StringNode = dyn_cast<MDString>(ArgNode->getOperand(0));
    101     if (!StringNode || StringNode->getString() != KernelArgMDNodeNames[i])
    102       return nullptr;
    103   }
    104 
    105   return F;
    106 }
    107 
    108 static StringRef
    109 AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
    110   MDNode *ArgAQNode = cast<MDNode>(KernelMDNode->getOperand(2));
    111   return cast<MDString>(ArgAQNode->getOperand(ArgIdx + 1))->getString();
    112 }
    113 
    114 static StringRef
    115 ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
    116   MDNode *ArgTypeNode = cast<MDNode>(KernelMDNode->getOperand(3));
    117   return cast<MDString>(ArgTypeNode->getOperand(ArgIdx + 1))->getString();
    118 }
    119 
    120 static MDVector
    121 GetArgMD(MDNode *KernelMDNode, unsigned OpIdx) {
    122   MDVector Res;
    123   for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
    124     MDNode *Node = cast<MDNode>(KernelMDNode->getOperand(i + 1));
    125     Res.push_back(Node->getOperand(OpIdx));
    126   }
    127   return Res;
    128 }
    129 
    130 static void
    131 PushArgMD(KernelArgMD &MD, const MDVector &V) {
    132   assert(V.size() == NumKernelArgMDNodes);
    133   for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
    134     MD.ArgVector[i].push_back(V[i]);
    135   }
    136 }
    137 
    138 namespace {
    139 
    140 class AMDGPUOpenCLImageTypeLoweringPass : public ModulePass {
    141   static char ID;
    142 
    143   LLVMContext *Context;
    144   Type *Int32Type;
    145   Type *ImageSizeType;
    146   Type *ImageFormatType;
    147   SmallVector<Instruction *, 4> InstsToErase;
    148 
    149   bool replaceImageUses(Argument &ImageArg, uint32_t ResourceID,
    150                         Argument &ImageSizeArg,
    151                         Argument &ImageFormatArg) {
    152     bool Modified = false;
    153 
    154     for (auto &Use : ImageArg.uses()) {
    155       auto Inst = dyn_cast<CallInst>(Use.getUser());
    156       if (!Inst) {
    157         continue;
    158       }
    159 
    160       Function *F = Inst->getCalledFunction();
    161       if (!F)
    162         continue;
    163 
    164       Value *Replacement = nullptr;
    165       StringRef Name = F->getName();
    166       if (Name.startswith(GetImageResourceIDFunc)) {
    167         Replacement = ConstantInt::get(Int32Type, ResourceID);
    168       } else if (Name.startswith(GetImageSizeFunc)) {
    169         Replacement = &ImageSizeArg;
    170       } else if (Name.startswith(GetImageFormatFunc)) {
    171         Replacement = &ImageFormatArg;
    172       } else {
    173         continue;
    174       }
    175 
    176       Inst->replaceAllUsesWith(Replacement);
    177       InstsToErase.push_back(Inst);
    178       Modified = true;
    179     }
    180 
    181     return Modified;
    182   }
    183 
    184   bool replaceSamplerUses(Argument &SamplerArg, uint32_t ResourceID) {
    185     bool Modified = false;
    186 
    187     for (const auto &Use : SamplerArg.uses()) {
    188       auto Inst = dyn_cast<CallInst>(Use.getUser());
    189       if (!Inst) {
    190         continue;
    191       }
    192 
    193       Function *F = Inst->getCalledFunction();
    194       if (!F)
    195         continue;
    196 
    197       Value *Replacement = nullptr;
    198       StringRef Name = F->getName();
    199       if (Name == GetSamplerResourceIDFunc) {
    200         Replacement = ConstantInt::get(Int32Type, ResourceID);
    201       } else {
    202         continue;
    203       }
    204 
    205       Inst->replaceAllUsesWith(Replacement);
    206       InstsToErase.push_back(Inst);
    207       Modified = true;
    208     }
    209 
    210     return Modified;
    211   }
    212 
    213   bool replaceImageAndSamplerUses(Function *F, MDNode *KernelMDNode) {
    214     uint32_t NumReadOnlyImageArgs = 0;
    215     uint32_t NumWriteOnlyImageArgs = 0;
    216     uint32_t NumSamplerArgs = 0;
    217 
    218     bool Modified = false;
    219     InstsToErase.clear();
    220     for (auto ArgI = F->arg_begin(); ArgI != F->arg_end(); ++ArgI) {
    221       Argument &Arg = *ArgI;
    222       StringRef Type = ArgTypeFromMD(KernelMDNode, Arg.getArgNo());
    223 
    224       // Handle image types.
    225       if (IsImageType(Type)) {
    226         StringRef AccessQual = AccessQualFromMD(KernelMDNode, Arg.getArgNo());
    227         uint32_t ResourceID;
    228         if (AccessQual == "read_only") {
    229           ResourceID = NumReadOnlyImageArgs++;
    230         } else if (AccessQual == "write_only") {
    231           ResourceID = NumWriteOnlyImageArgs++;
    232         } else {
    233           llvm_unreachable("Wrong image access qualifier.");
    234         }
    235 
    236         Argument &SizeArg = *(++ArgI);
    237         Argument &FormatArg = *(++ArgI);
    238         Modified |= replaceImageUses(Arg, ResourceID, SizeArg, FormatArg);
    239 
    240       // Handle sampler type.
    241       } else if (IsSamplerType(Type)) {
    242         uint32_t ResourceID = NumSamplerArgs++;
    243         Modified |= replaceSamplerUses(Arg, ResourceID);
    244       }
    245     }
    246     for (unsigned i = 0; i < InstsToErase.size(); ++i) {
    247       InstsToErase[i]->eraseFromParent();
    248     }
    249 
    250     return Modified;
    251   }
    252 
    253   std::tuple<Function *, MDNode *>
    254   addImplicitArgs(Function *F, MDNode *KernelMDNode) {
    255     bool Modified = false;
    256 
    257     FunctionType *FT = F->getFunctionType();
    258     SmallVector<Type *, 8> ArgTypes;
    259 
    260     // Metadata operands for new MDNode.
    261     KernelArgMD NewArgMDs;
    262     PushArgMD(NewArgMDs, GetArgMD(KernelMDNode, 0));
    263 
    264     // Add implicit arguments to the signature.
    265     for (unsigned i = 0; i < FT->getNumParams(); ++i) {
    266       ArgTypes.push_back(FT->getParamType(i));
    267       MDVector ArgMD = GetArgMD(KernelMDNode, i + 1);
    268       PushArgMD(NewArgMDs, ArgMD);
    269 
    270       if (!IsImageType(ArgTypeFromMD(KernelMDNode, i)))
    271         continue;
    272 
    273       // Add size implicit argument.
    274       ArgTypes.push_back(ImageSizeType);
    275       ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageSizeArgMDType);
    276       PushArgMD(NewArgMDs, ArgMD);
    277 
    278       // Add format implicit argument.
    279       ArgTypes.push_back(ImageFormatType);
    280       ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageFormatArgMDType);
    281       PushArgMD(NewArgMDs, ArgMD);
    282 
    283       Modified = true;
    284     }
    285     if (!Modified) {
    286       return std::make_tuple(nullptr, nullptr);
    287     }
    288 
    289     // Create function with new signature and clone the old body into it.
    290     auto NewFT = FunctionType::get(FT->getReturnType(), ArgTypes, false);
    291     auto NewF = Function::Create(NewFT, F->getLinkage(), F->getName());
    292     ValueToValueMapTy VMap;
    293     auto NewFArgIt = NewF->arg_begin();
    294     for (auto &Arg: F->args()) {
    295       auto ArgName = Arg.getName();
    296       NewFArgIt->setName(ArgName);
    297       VMap[&Arg] = &(*NewFArgIt++);
    298       if (IsImageType(ArgTypeFromMD(KernelMDNode, Arg.getArgNo()))) {
    299         (NewFArgIt++)->setName(Twine("__size_") + ArgName);
    300         (NewFArgIt++)->setName(Twine("__format_") + ArgName);
    301       }
    302     }
    303     SmallVector<ReturnInst*, 8> Returns;
    304     CloneFunctionInto(NewF, F, VMap, /*ModuleLevelChanges=*/false, Returns);
    305 
    306     // Build new MDNode.
    307     SmallVector<llvm::Metadata *, 6> KernelMDArgs;
    308     KernelMDArgs.push_back(ConstantAsMetadata::get(NewF));
    309     for (unsigned i = 0; i < NumKernelArgMDNodes; ++i)
    310       KernelMDArgs.push_back(MDNode::get(*Context, NewArgMDs.ArgVector[i]));
    311     MDNode *NewMDNode = MDNode::get(*Context, KernelMDArgs);
    312 
    313     return std::make_tuple(NewF, NewMDNode);
    314   }
    315 
    316   bool transformKernels(Module &M) {
    317     NamedMDNode *KernelsMDNode = M.getNamedMetadata(KernelsMDNodeName);
    318     if (!KernelsMDNode)
    319       return false;
    320 
    321     bool Modified = false;
    322     for (unsigned i = 0; i < KernelsMDNode->getNumOperands(); ++i) {
    323       MDNode *KernelMDNode = KernelsMDNode->getOperand(i);
    324       Function *F = GetFunctionFromMDNode(KernelMDNode);
    325       if (!F)
    326         continue;
    327 
    328       Function *NewF;
    329       MDNode *NewMDNode;
    330       std::tie(NewF, NewMDNode) = addImplicitArgs(F, KernelMDNode);
    331       if (NewF) {
    332         // Replace old function and metadata with new ones.
    333         F->eraseFromParent();
    334         M.getFunctionList().push_back(NewF);
    335         M.getOrInsertFunction(NewF->getName(), NewF->getFunctionType(),
    336                               NewF->getAttributes());
    337         KernelsMDNode->setOperand(i, NewMDNode);
    338 
    339         F = NewF;
    340         KernelMDNode = NewMDNode;
    341         Modified = true;
    342       }
    343 
    344       Modified |= replaceImageAndSamplerUses(F, KernelMDNode);
    345     }
    346 
    347     return Modified;
    348   }
    349 
    350  public:
    351   AMDGPUOpenCLImageTypeLoweringPass() : ModulePass(ID) {}
    352 
    353   bool runOnModule(Module &M) override {
    354     Context = &M.getContext();
    355     Int32Type = Type::getInt32Ty(M.getContext());
    356     ImageSizeType = ArrayType::get(Int32Type, 3);
    357     ImageFormatType = ArrayType::get(Int32Type, 2);
    358 
    359     return transformKernels(M);
    360   }
    361 
    362   const char *getPassName() const override {
    363     return "AMDGPU OpenCL Image Type Pass";
    364   }
    365 };
    366 
    367 char AMDGPUOpenCLImageTypeLoweringPass::ID = 0;
    368 
    369 } // end anonymous namespace
    370 
    371 ModulePass *llvm::createAMDGPUOpenCLImageTypeLoweringPass() {
    372   return new AMDGPUOpenCLImageTypeLoweringPass();
    373 }
    374