Home | History | Annotate | Download | only in AMDGPU
      1 //===-- AMDGPUOpenCLImageTypeLoweringPass.cpp -----------------------------===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 /// \file
     11 /// This pass resolves calls to OpenCL image attribute, image resource ID and
     12 /// sampler resource ID getter functions.
     13 ///
     14 /// Image attributes (size and format) are expected to be passed to the kernel
     15 /// as kernel arguments immediately following the image argument itself,
     16 /// therefore this pass adds image size and format arguments to the kernel
     17 /// functions in the module. The kernel functions with image arguments are
/// re-created using the new signature. The new arguments are added to the
/// kernel metadata with kernel_arg_type set to "__llvm_image_size" or
/// "__llvm_image_format".
     20 /// Note: this pass may invalidate pointers to functions.
     21 ///
     22 /// Resource IDs of read-only images, write-only images and samplers are
     23 /// defined to be their index among the kernel arguments of the same
     24 /// type and access qualifier.
     25 //===----------------------------------------------------------------------===//
     26 
     27 #include "AMDGPU.h"
     28 #include "llvm/ADT/STLExtras.h"
     29 #include "llvm/ADT/SmallVector.h"
     30 #include "llvm/Analysis/Passes.h"
     31 #include "llvm/IR/Constants.h"
     32 #include "llvm/IR/Function.h"
     33 #include "llvm/IR/Instructions.h"
     34 #include "llvm/IR/Module.h"
     35 #include "llvm/Transforms/Utils/Cloning.h"
     36 
     37 using namespace llvm;
     38 
     39 namespace {
     40 
     41 StringRef GetImageSizeFunc =         "llvm.OpenCL.image.get.size";
     42 StringRef GetImageFormatFunc =       "llvm.OpenCL.image.get.format";
     43 StringRef GetImageResourceIDFunc =   "llvm.OpenCL.image.get.resource.id";
     44 StringRef GetSamplerResourceIDFunc = "llvm.OpenCL.sampler.get.resource.id";
     45 
     46 StringRef ImageSizeArgMDType =   "__llvm_image_size";
     47 StringRef ImageFormatArgMDType = "__llvm_image_format";
     48 
     49 StringRef KernelsMDNodeName = "opencl.kernels";
     50 StringRef KernelArgMDNodeNames[] = {
     51   "kernel_arg_addr_space",
     52   "kernel_arg_access_qual",
     53   "kernel_arg_type",
     54   "kernel_arg_base_type",
     55   "kernel_arg_type_qual"};
     56 const unsigned NumKernelArgMDNodes = 5;
     57 
     58 typedef SmallVector<Metadata *, 8> MDVector;
     59 struct KernelArgMD {
     60   MDVector ArgVector[NumKernelArgMDNodes];
     61 };
     62 
     63 } // end anonymous namespace
     64 
     65 static inline bool
     66 IsImageType(StringRef TypeString) {
     67   return TypeString == "image2d_t" || TypeString == "image3d_t";
     68 }
     69 
     70 static inline bool
     71 IsSamplerType(StringRef TypeString) {
     72   return TypeString == "sampler_t";
     73 }
     74 
     75 static Function *
     76 GetFunctionFromMDNode(MDNode *Node) {
     77   if (!Node)
     78     return nullptr;
     79 
     80   size_t NumOps = Node->getNumOperands();
     81   if (NumOps != NumKernelArgMDNodes + 1)
     82     return nullptr;
     83 
     84   auto F = mdconst::dyn_extract<Function>(Node->getOperand(0));
     85   if (!F)
     86     return nullptr;
     87 
     88   // Sanity checks.
     89   size_t ExpectNumArgNodeOps = F->arg_size() + 1;
     90   for (size_t i = 0; i < NumKernelArgMDNodes; ++i) {
     91     MDNode *ArgNode = dyn_cast_or_null<MDNode>(Node->getOperand(i + 1));
     92     if (ArgNode->getNumOperands() != ExpectNumArgNodeOps)
     93       return nullptr;
     94     if (!ArgNode->getOperand(0))
     95       return nullptr;
     96 
     97     // FIXME: It should be possible to do image lowering when some metadata
     98     // args missing or not in the expected order.
     99     MDString *StringNode = dyn_cast<MDString>(ArgNode->getOperand(0));
    100     if (!StringNode || StringNode->getString() != KernelArgMDNodeNames[i])
    101       return nullptr;
    102   }
    103 
    104   return F;
    105 }
    106 
    107 static StringRef
    108 AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
    109   MDNode *ArgAQNode = cast<MDNode>(KernelMDNode->getOperand(2));
    110   return cast<MDString>(ArgAQNode->getOperand(ArgIdx + 1))->getString();
    111 }
    112 
    113 static StringRef
    114 ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
    115   MDNode *ArgTypeNode = cast<MDNode>(KernelMDNode->getOperand(3));
    116   return cast<MDString>(ArgTypeNode->getOperand(ArgIdx + 1))->getString();
    117 }
    118 
    119 static MDVector
    120 GetArgMD(MDNode *KernelMDNode, unsigned OpIdx) {
    121   MDVector Res;
    122   for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
    123     MDNode *Node = cast<MDNode>(KernelMDNode->getOperand(i + 1));
    124     Res.push_back(Node->getOperand(OpIdx));
    125   }
    126   return Res;
    127 }
    128 
    129 static void
    130 PushArgMD(KernelArgMD &MD, const MDVector &V) {
    131   assert(V.size() == NumKernelArgMDNodes);
    132   for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
    133     MD.ArgVector[i].push_back(V[i]);
    134   }
    135 }
    136 
    137 namespace {
    138 
    139 class AMDGPUOpenCLImageTypeLoweringPass : public ModulePass {
    140   static char ID;
    141 
    142   LLVMContext *Context;
    143   Type *Int32Type;
    144   Type *ImageSizeType;
    145   Type *ImageFormatType;
    146   SmallVector<Instruction *, 4> InstsToErase;
    147 
    148   bool replaceImageUses(Argument &ImageArg, uint32_t ResourceID,
    149                         Argument &ImageSizeArg,
    150                         Argument &ImageFormatArg) {
    151     bool Modified = false;
    152 
    153     for (auto &Use : ImageArg.uses()) {
    154       auto Inst = dyn_cast<CallInst>(Use.getUser());
    155       if (!Inst) {
    156         continue;
    157       }
    158 
    159       Function *F = Inst->getCalledFunction();
    160       if (!F)
    161         continue;
    162 
    163       Value *Replacement = nullptr;
    164       StringRef Name = F->getName();
    165       if (Name.startswith(GetImageResourceIDFunc)) {
    166         Replacement = ConstantInt::get(Int32Type, ResourceID);
    167       } else if (Name.startswith(GetImageSizeFunc)) {
    168         Replacement = &ImageSizeArg;
    169       } else if (Name.startswith(GetImageFormatFunc)) {
    170         Replacement = &ImageFormatArg;
    171       } else {
    172         continue;
    173       }
    174 
    175       Inst->replaceAllUsesWith(Replacement);
    176       InstsToErase.push_back(Inst);
    177       Modified = true;
    178     }
    179 
    180     return Modified;
    181   }
    182 
    183   bool replaceSamplerUses(Argument &SamplerArg, uint32_t ResourceID) {
    184     bool Modified = false;
    185 
    186     for (const auto &Use : SamplerArg.uses()) {
    187       auto Inst = dyn_cast<CallInst>(Use.getUser());
    188       if (!Inst) {
    189         continue;
    190       }
    191 
    192       Function *F = Inst->getCalledFunction();
    193       if (!F)
    194         continue;
    195 
    196       Value *Replacement = nullptr;
    197       StringRef Name = F->getName();
    198       if (Name == GetSamplerResourceIDFunc) {
    199         Replacement = ConstantInt::get(Int32Type, ResourceID);
    200       } else {
    201         continue;
    202       }
    203 
    204       Inst->replaceAllUsesWith(Replacement);
    205       InstsToErase.push_back(Inst);
    206       Modified = true;
    207     }
    208 
    209     return Modified;
    210   }
    211 
    212   bool replaceImageAndSamplerUses(Function *F, MDNode *KernelMDNode) {
    213     uint32_t NumReadOnlyImageArgs = 0;
    214     uint32_t NumWriteOnlyImageArgs = 0;
    215     uint32_t NumSamplerArgs = 0;
    216 
    217     bool Modified = false;
    218     InstsToErase.clear();
    219     for (auto ArgI = F->arg_begin(); ArgI != F->arg_end(); ++ArgI) {
    220       Argument &Arg = *ArgI;
    221       StringRef Type = ArgTypeFromMD(KernelMDNode, Arg.getArgNo());
    222 
    223       // Handle image types.
    224       if (IsImageType(Type)) {
    225         StringRef AccessQual = AccessQualFromMD(KernelMDNode, Arg.getArgNo());
    226         uint32_t ResourceID;
    227         if (AccessQual == "read_only") {
    228           ResourceID = NumReadOnlyImageArgs++;
    229         } else if (AccessQual == "write_only") {
    230           ResourceID = NumWriteOnlyImageArgs++;
    231         } else {
    232           llvm_unreachable("Wrong image access qualifier.");
    233         }
    234 
    235         Argument &SizeArg = *(++ArgI);
    236         Argument &FormatArg = *(++ArgI);
    237         Modified |= replaceImageUses(Arg, ResourceID, SizeArg, FormatArg);
    238 
    239       // Handle sampler type.
    240       } else if (IsSamplerType(Type)) {
    241         uint32_t ResourceID = NumSamplerArgs++;
    242         Modified |= replaceSamplerUses(Arg, ResourceID);
    243       }
    244     }
    245     for (unsigned i = 0; i < InstsToErase.size(); ++i) {
    246       InstsToErase[i]->eraseFromParent();
    247     }
    248 
    249     return Modified;
    250   }
    251 
    252   std::tuple<Function *, MDNode *>
    253   addImplicitArgs(Function *F, MDNode *KernelMDNode) {
    254     bool Modified = false;
    255 
    256     FunctionType *FT = F->getFunctionType();
    257     SmallVector<Type *, 8> ArgTypes;
    258 
    259     // Metadata operands for new MDNode.
    260     KernelArgMD NewArgMDs;
    261     PushArgMD(NewArgMDs, GetArgMD(KernelMDNode, 0));
    262 
    263     // Add implicit arguments to the signature.
    264     for (unsigned i = 0; i < FT->getNumParams(); ++i) {
    265       ArgTypes.push_back(FT->getParamType(i));
    266       MDVector ArgMD = GetArgMD(KernelMDNode, i + 1);
    267       PushArgMD(NewArgMDs, ArgMD);
    268 
    269       if (!IsImageType(ArgTypeFromMD(KernelMDNode, i)))
    270         continue;
    271 
    272       // Add size implicit argument.
    273       ArgTypes.push_back(ImageSizeType);
    274       ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageSizeArgMDType);
    275       PushArgMD(NewArgMDs, ArgMD);
    276 
    277       // Add format implicit argument.
    278       ArgTypes.push_back(ImageFormatType);
    279       ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageFormatArgMDType);
    280       PushArgMD(NewArgMDs, ArgMD);
    281 
    282       Modified = true;
    283     }
    284     if (!Modified) {
    285       return std::make_tuple(nullptr, nullptr);
    286     }
    287 
    288     // Create function with new signature and clone the old body into it.
    289     auto NewFT = FunctionType::get(FT->getReturnType(), ArgTypes, false);
    290     auto NewF = Function::Create(NewFT, F->getLinkage(), F->getName());
    291     ValueToValueMapTy VMap;
    292     auto NewFArgIt = NewF->arg_begin();
    293     for (auto &Arg: F->args()) {
    294       auto ArgName = Arg.getName();
    295       NewFArgIt->setName(ArgName);
    296       VMap[&Arg] = &(*NewFArgIt++);
    297       if (IsImageType(ArgTypeFromMD(KernelMDNode, Arg.getArgNo()))) {
    298         (NewFArgIt++)->setName(Twine("__size_") + ArgName);
    299         (NewFArgIt++)->setName(Twine("__format_") + ArgName);
    300       }
    301     }
    302     SmallVector<ReturnInst*, 8> Returns;
    303     CloneFunctionInto(NewF, F, VMap, /*ModuleLevelChanges=*/false, Returns);
    304 
    305     // Build new MDNode.
    306     SmallVector<llvm::Metadata *, 6> KernelMDArgs;
    307     KernelMDArgs.push_back(ConstantAsMetadata::get(NewF));
    308     for (unsigned i = 0; i < NumKernelArgMDNodes; ++i)
    309       KernelMDArgs.push_back(MDNode::get(*Context, NewArgMDs.ArgVector[i]));
    310     MDNode *NewMDNode = MDNode::get(*Context, KernelMDArgs);
    311 
    312     return std::make_tuple(NewF, NewMDNode);
    313   }
    314 
    315   bool transformKernels(Module &M) {
    316     NamedMDNode *KernelsMDNode = M.getNamedMetadata(KernelsMDNodeName);
    317     if (!KernelsMDNode)
    318       return false;
    319 
    320     bool Modified = false;
    321     for (unsigned i = 0; i < KernelsMDNode->getNumOperands(); ++i) {
    322       MDNode *KernelMDNode = KernelsMDNode->getOperand(i);
    323       Function *F = GetFunctionFromMDNode(KernelMDNode);
    324       if (!F)
    325         continue;
    326 
    327       Function *NewF;
    328       MDNode *NewMDNode;
    329       std::tie(NewF, NewMDNode) = addImplicitArgs(F, KernelMDNode);
    330       if (NewF) {
    331         // Replace old function and metadata with new ones.
    332         F->eraseFromParent();
    333         M.getFunctionList().push_back(NewF);
    334         M.getOrInsertFunction(NewF->getName(), NewF->getFunctionType(),
    335                               NewF->getAttributes());
    336         KernelsMDNode->setOperand(i, NewMDNode);
    337 
    338         F = NewF;
    339         KernelMDNode = NewMDNode;
    340         Modified = true;
    341       }
    342 
    343       Modified |= replaceImageAndSamplerUses(F, KernelMDNode);
    344     }
    345 
    346     return Modified;
    347   }
    348 
    349  public:
    350   AMDGPUOpenCLImageTypeLoweringPass() : ModulePass(ID) {}
    351 
    352   bool runOnModule(Module &M) override {
    353     Context = &M.getContext();
    354     Int32Type = Type::getInt32Ty(M.getContext());
    355     ImageSizeType = ArrayType::get(Int32Type, 3);
    356     ImageFormatType = ArrayType::get(Int32Type, 2);
    357 
    358     return transformKernels(M);
    359   }
    360 
    361   const char *getPassName() const override {
    362     return "AMDGPU OpenCL Image Type Pass";
    363   }
    364 };
    365 
    366 char AMDGPUOpenCLImageTypeLoweringPass::ID = 0;
    367 
    368 } // end anonymous namespace
    369 
    370 ModulePass *llvm::createAMDGPUOpenCLImageTypeLoweringPass() {
    371   return new AMDGPUOpenCLImageTypeLoweringPass();
    372 }
    373