Home | History | Annotate | Download | only in AMDGPU
      1 //===- R600OpenCLImageTypeLoweringPass.cpp ------------------------------===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 /// \file
     11 /// This pass resolves calls to OpenCL image attribute, image resource ID and
     12 /// sampler resource ID getter functions.
     13 ///
     14 /// Image attributes (size and format) are expected to be passed to the kernel
     15 /// as kernel arguments immediately following the image argument itself,
     16 /// therefore this pass adds image size and format arguments to the kernel
     17 /// functions in the module. The kernel functions with image arguments are
/// re-created using the new signature. The new arguments are added to the
/// kernel metadata with kernel_arg_type and kernel_arg_base_type set to
/// "__llvm_image_size" or "__llvm_image_format", respectively.
     20 /// Note: this pass may invalidate pointers to functions.
     21 ///
     22 /// Resource IDs of read-only images, write-only images and samplers are
     23 /// defined to be their index among the kernel arguments of the same
     24 /// type and access qualifier.
     25 //
     26 //===----------------------------------------------------------------------===//
     27 
     28 #include "AMDGPU.h"
     29 #include "llvm/ADT/SmallVector.h"
     30 #include "llvm/ADT/StringRef.h"
     31 #include "llvm/ADT/Twine.h"
     32 #include "llvm/IR/Argument.h"
     33 #include "llvm/IR/DerivedTypes.h"
     34 #include "llvm/IR/Constants.h"
     35 #include "llvm/IR/Function.h"
     36 #include "llvm/IR/Instruction.h"
     37 #include "llvm/IR/Instructions.h"
     38 #include "llvm/IR/Metadata.h"
     39 #include "llvm/IR/Module.h"
     40 #include "llvm/IR/Type.h"
     41 #include "llvm/IR/Use.h"
     42 #include "llvm/IR/User.h"
     43 #include "llvm/Pass.h"
     44 #include "llvm/Support/Casting.h"
     45 #include "llvm/Support/ErrorHandling.h"
     46 #include "llvm/Transforms/Utils/Cloning.h"
     47 #include "llvm/Transforms/Utils/ValueMapper.h"
     48 #include <cassert>
     49 #include <cstddef>
     50 #include <cstdint>
     51 #include <tuple>
     52 
     53 using namespace llvm;
     54 
     55 static StringRef GetImageSizeFunc =         "llvm.OpenCL.image.get.size";
     56 static StringRef GetImageFormatFunc =       "llvm.OpenCL.image.get.format";
     57 static StringRef GetImageResourceIDFunc =   "llvm.OpenCL.image.get.resource.id";
     58 static StringRef GetSamplerResourceIDFunc =
     59     "llvm.OpenCL.sampler.get.resource.id";
     60 
     61 static StringRef ImageSizeArgMDType =   "__llvm_image_size";
     62 static StringRef ImageFormatArgMDType = "__llvm_image_format";
     63 
     64 static StringRef KernelsMDNodeName = "opencl.kernels";
     65 static StringRef KernelArgMDNodeNames[] = {
     66   "kernel_arg_addr_space",
     67   "kernel_arg_access_qual",
     68   "kernel_arg_type",
     69   "kernel_arg_base_type",
     70   "kernel_arg_type_qual"};
     71 static const unsigned NumKernelArgMDNodes = 5;
     72 
     73 namespace {
     74 
     75 using MDVector = SmallVector<Metadata *, 8>;
     76 struct KernelArgMD {
     77   MDVector ArgVector[NumKernelArgMDNodes];
     78 };
     79 
     80 } // end anonymous namespace
     81 
     82 static inline bool
     83 IsImageType(StringRef TypeString) {
     84   return TypeString == "image2d_t" || TypeString == "image3d_t";
     85 }
     86 
     87 static inline bool
     88 IsSamplerType(StringRef TypeString) {
     89   return TypeString == "sampler_t";
     90 }
     91 
     92 static Function *
     93 GetFunctionFromMDNode(MDNode *Node) {
     94   if (!Node)
     95     return nullptr;
     96 
     97   size_t NumOps = Node->getNumOperands();
     98   if (NumOps != NumKernelArgMDNodes + 1)
     99     return nullptr;
    100 
    101   auto F = mdconst::dyn_extract<Function>(Node->getOperand(0));
    102   if (!F)
    103     return nullptr;
    104 
    105   // Sanity checks.
    106   size_t ExpectNumArgNodeOps = F->arg_size() + 1;
    107   for (size_t i = 0; i < NumKernelArgMDNodes; ++i) {
    108     MDNode *ArgNode = dyn_cast_or_null<MDNode>(Node->getOperand(i + 1));
    109     if (ArgNode->getNumOperands() != ExpectNumArgNodeOps)
    110       return nullptr;
    111     if (!ArgNode->getOperand(0))
    112       return nullptr;
    113 
    114     // FIXME: It should be possible to do image lowering when some metadata
    115     // args missing or not in the expected order.
    116     MDString *StringNode = dyn_cast<MDString>(ArgNode->getOperand(0));
    117     if (!StringNode || StringNode->getString() != KernelArgMDNodeNames[i])
    118       return nullptr;
    119   }
    120 
    121   return F;
    122 }
    123 
    124 static StringRef
    125 AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
    126   MDNode *ArgAQNode = cast<MDNode>(KernelMDNode->getOperand(2));
    127   return cast<MDString>(ArgAQNode->getOperand(ArgIdx + 1))->getString();
    128 }
    129 
    130 static StringRef
    131 ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
    132   MDNode *ArgTypeNode = cast<MDNode>(KernelMDNode->getOperand(3));
    133   return cast<MDString>(ArgTypeNode->getOperand(ArgIdx + 1))->getString();
    134 }
    135 
    136 static MDVector
    137 GetArgMD(MDNode *KernelMDNode, unsigned OpIdx) {
    138   MDVector Res;
    139   for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
    140     MDNode *Node = cast<MDNode>(KernelMDNode->getOperand(i + 1));
    141     Res.push_back(Node->getOperand(OpIdx));
    142   }
    143   return Res;
    144 }
    145 
    146 static void
    147 PushArgMD(KernelArgMD &MD, const MDVector &V) {
    148   assert(V.size() == NumKernelArgMDNodes);
    149   for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
    150     MD.ArgVector[i].push_back(V[i]);
    151   }
    152 }
    153 
    154 namespace {
    155 
    156 class R600OpenCLImageTypeLoweringPass : public ModulePass {
    157   static char ID;
    158 
    159   LLVMContext *Context;
    160   Type *Int32Type;
    161   Type *ImageSizeType;
    162   Type *ImageFormatType;
    163   SmallVector<Instruction *, 4> InstsToErase;
    164 
    165   bool replaceImageUses(Argument &ImageArg, uint32_t ResourceID,
    166                         Argument &ImageSizeArg,
    167                         Argument &ImageFormatArg) {
    168     bool Modified = false;
    169 
    170     for (auto &Use : ImageArg.uses()) {
    171       auto Inst = dyn_cast<CallInst>(Use.getUser());
    172       if (!Inst) {
    173         continue;
    174       }
    175 
    176       Function *F = Inst->getCalledFunction();
    177       if (!F)
    178         continue;
    179 
    180       Value *Replacement = nullptr;
    181       StringRef Name = F->getName();
    182       if (Name.startswith(GetImageResourceIDFunc)) {
    183         Replacement = ConstantInt::get(Int32Type, ResourceID);
    184       } else if (Name.startswith(GetImageSizeFunc)) {
    185         Replacement = &ImageSizeArg;
    186       } else if (Name.startswith(GetImageFormatFunc)) {
    187         Replacement = &ImageFormatArg;
    188       } else {
    189         continue;
    190       }
    191 
    192       Inst->replaceAllUsesWith(Replacement);
    193       InstsToErase.push_back(Inst);
    194       Modified = true;
    195     }
    196 
    197     return Modified;
    198   }
    199 
    200   bool replaceSamplerUses(Argument &SamplerArg, uint32_t ResourceID) {
    201     bool Modified = false;
    202 
    203     for (const auto &Use : SamplerArg.uses()) {
    204       auto Inst = dyn_cast<CallInst>(Use.getUser());
    205       if (!Inst) {
    206         continue;
    207       }
    208 
    209       Function *F = Inst->getCalledFunction();
    210       if (!F)
    211         continue;
    212 
    213       Value *Replacement = nullptr;
    214       StringRef Name = F->getName();
    215       if (Name == GetSamplerResourceIDFunc) {
    216         Replacement = ConstantInt::get(Int32Type, ResourceID);
    217       } else {
    218         continue;
    219       }
    220 
    221       Inst->replaceAllUsesWith(Replacement);
    222       InstsToErase.push_back(Inst);
    223       Modified = true;
    224     }
    225 
    226     return Modified;
    227   }
    228 
    229   bool replaceImageAndSamplerUses(Function *F, MDNode *KernelMDNode) {
    230     uint32_t NumReadOnlyImageArgs = 0;
    231     uint32_t NumWriteOnlyImageArgs = 0;
    232     uint32_t NumSamplerArgs = 0;
    233 
    234     bool Modified = false;
    235     InstsToErase.clear();
    236     for (auto ArgI = F->arg_begin(); ArgI != F->arg_end(); ++ArgI) {
    237       Argument &Arg = *ArgI;
    238       StringRef Type = ArgTypeFromMD(KernelMDNode, Arg.getArgNo());
    239 
    240       // Handle image types.
    241       if (IsImageType(Type)) {
    242         StringRef AccessQual = AccessQualFromMD(KernelMDNode, Arg.getArgNo());
    243         uint32_t ResourceID;
    244         if (AccessQual == "read_only") {
    245           ResourceID = NumReadOnlyImageArgs++;
    246         } else if (AccessQual == "write_only") {
    247           ResourceID = NumWriteOnlyImageArgs++;
    248         } else {
    249           llvm_unreachable("Wrong image access qualifier.");
    250         }
    251 
    252         Argument &SizeArg = *(++ArgI);
    253         Argument &FormatArg = *(++ArgI);
    254         Modified |= replaceImageUses(Arg, ResourceID, SizeArg, FormatArg);
    255 
    256       // Handle sampler type.
    257       } else if (IsSamplerType(Type)) {
    258         uint32_t ResourceID = NumSamplerArgs++;
    259         Modified |= replaceSamplerUses(Arg, ResourceID);
    260       }
    261     }
    262     for (unsigned i = 0; i < InstsToErase.size(); ++i) {
    263       InstsToErase[i]->eraseFromParent();
    264     }
    265 
    266     return Modified;
    267   }
    268 
    269   std::tuple<Function *, MDNode *>
    270   addImplicitArgs(Function *F, MDNode *KernelMDNode) {
    271     bool Modified = false;
    272 
    273     FunctionType *FT = F->getFunctionType();
    274     SmallVector<Type *, 8> ArgTypes;
    275 
    276     // Metadata operands for new MDNode.
    277     KernelArgMD NewArgMDs;
    278     PushArgMD(NewArgMDs, GetArgMD(KernelMDNode, 0));
    279 
    280     // Add implicit arguments to the signature.
    281     for (unsigned i = 0; i < FT->getNumParams(); ++i) {
    282       ArgTypes.push_back(FT->getParamType(i));
    283       MDVector ArgMD = GetArgMD(KernelMDNode, i + 1);
    284       PushArgMD(NewArgMDs, ArgMD);
    285 
    286       if (!IsImageType(ArgTypeFromMD(KernelMDNode, i)))
    287         continue;
    288 
    289       // Add size implicit argument.
    290       ArgTypes.push_back(ImageSizeType);
    291       ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageSizeArgMDType);
    292       PushArgMD(NewArgMDs, ArgMD);
    293 
    294       // Add format implicit argument.
    295       ArgTypes.push_back(ImageFormatType);
    296       ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageFormatArgMDType);
    297       PushArgMD(NewArgMDs, ArgMD);
    298 
    299       Modified = true;
    300     }
    301     if (!Modified) {
    302       return std::make_tuple(nullptr, nullptr);
    303     }
    304 
    305     // Create function with new signature and clone the old body into it.
    306     auto NewFT = FunctionType::get(FT->getReturnType(), ArgTypes, false);
    307     auto NewF = Function::Create(NewFT, F->getLinkage(), F->getName());
    308     ValueToValueMapTy VMap;
    309     auto NewFArgIt = NewF->arg_begin();
    310     for (auto &Arg: F->args()) {
    311       auto ArgName = Arg.getName();
    312       NewFArgIt->setName(ArgName);
    313       VMap[&Arg] = &(*NewFArgIt++);
    314       if (IsImageType(ArgTypeFromMD(KernelMDNode, Arg.getArgNo()))) {
    315         (NewFArgIt++)->setName(Twine("__size_") + ArgName);
    316         (NewFArgIt++)->setName(Twine("__format_") + ArgName);
    317       }
    318     }
    319     SmallVector<ReturnInst*, 8> Returns;
    320     CloneFunctionInto(NewF, F, VMap, /*ModuleLevelChanges=*/false, Returns);
    321 
    322     // Build new MDNode.
    323     SmallVector<Metadata *, 6> KernelMDArgs;
    324     KernelMDArgs.push_back(ConstantAsMetadata::get(NewF));
    325     for (unsigned i = 0; i < NumKernelArgMDNodes; ++i)
    326       KernelMDArgs.push_back(MDNode::get(*Context, NewArgMDs.ArgVector[i]));
    327     MDNode *NewMDNode = MDNode::get(*Context, KernelMDArgs);
    328 
    329     return std::make_tuple(NewF, NewMDNode);
    330   }
    331 
    332   bool transformKernels(Module &M) {
    333     NamedMDNode *KernelsMDNode = M.getNamedMetadata(KernelsMDNodeName);
    334     if (!KernelsMDNode)
    335       return false;
    336 
    337     bool Modified = false;
    338     for (unsigned i = 0; i < KernelsMDNode->getNumOperands(); ++i) {
    339       MDNode *KernelMDNode = KernelsMDNode->getOperand(i);
    340       Function *F = GetFunctionFromMDNode(KernelMDNode);
    341       if (!F)
    342         continue;
    343 
    344       Function *NewF;
    345       MDNode *NewMDNode;
    346       std::tie(NewF, NewMDNode) = addImplicitArgs(F, KernelMDNode);
    347       if (NewF) {
    348         // Replace old function and metadata with new ones.
    349         F->eraseFromParent();
    350         M.getFunctionList().push_back(NewF);
    351         M.getOrInsertFunction(NewF->getName(), NewF->getFunctionType(),
    352                               NewF->getAttributes());
    353         KernelsMDNode->setOperand(i, NewMDNode);
    354 
    355         F = NewF;
    356         KernelMDNode = NewMDNode;
    357         Modified = true;
    358       }
    359 
    360       Modified |= replaceImageAndSamplerUses(F, KernelMDNode);
    361     }
    362 
    363     return Modified;
    364   }
    365 
    366 public:
    367   R600OpenCLImageTypeLoweringPass() : ModulePass(ID) {}
    368 
    369   bool runOnModule(Module &M) override {
    370     Context = &M.getContext();
    371     Int32Type = Type::getInt32Ty(M.getContext());
    372     ImageSizeType = ArrayType::get(Int32Type, 3);
    373     ImageFormatType = ArrayType::get(Int32Type, 2);
    374 
    375     return transformKernels(M);
    376   }
    377 
    378   StringRef getPassName() const override {
    379     return "R600 OpenCL Image Type Pass";
    380   }
    381 };
    382 
    383 } // end anonymous namespace
    384 
    385 char R600OpenCLImageTypeLoweringPass::ID = 0;
    386 
    387 ModulePass *llvm::createR600OpenCLImageTypeLoweringPass() {
    388   return new R600OpenCLImageTypeLoweringPass();
    389 }
    390