1 //===- R600OpenCLImageTypeLoweringPass.cpp ------------------------------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 /// \file 11 /// This pass resolves calls to OpenCL image attribute, image resource ID and 12 /// sampler resource ID getter functions. 13 /// 14 /// Image attributes (size and format) are expected to be passed to the kernel 15 /// as kernel arguments immediately following the image argument itself, 16 /// therefore this pass adds image size and format arguments to the kernel 17 /// functions in the module. The kernel functions with image arguments are 18 /// re-created using the new signature. The new arguments are added to the 19 /// kernel metadata with kernel_arg_type set to "image_size" or "image_format". 20 /// Note: this pass may invalidate pointers to functions. 21 /// 22 /// Resource IDs of read-only images, write-only images and samplers are 23 /// defined to be their index among the kernel arguments of the same 24 /// type and access qualifier. 25 // 26 //===----------------------------------------------------------------------===// 27 28 #include "AMDGPU.h" 29 #include "llvm/ADT/SmallVector.h" 30 #include "llvm/ADT/StringRef.h" 31 #include "llvm/ADT/Twine.h" 32 #include "llvm/IR/Argument.h" 33 #include "llvm/IR/DerivedTypes.h" 34 #include "llvm/IR/Constants.h" 35 #include "llvm/IR/Function.h" 36 #include "llvm/IR/Instruction.h" 37 #include "llvm/IR/Instructions.h" 38 #include "llvm/IR/Metadata.h" 39 #include "llvm/IR/Module.h" 40 #include "llvm/IR/Type.h" 41 #include "llvm/IR/Use.h" 42 #include "llvm/IR/User.h" 43 #include "llvm/Pass.h" 44 #include "llvm/Support/Casting.h" 45 #include "llvm/Support/ErrorHandling.h" 46 #include "llvm/Transforms/Utils/Cloning.h" 47 #include "llvm/Transforms/Utils/ValueMapper.h" 48 #include <cassert> 49 #include <cstddef> 50 #include <cstdint> 51 #include <tuple> 52 53 using namespace llvm; 54 55 static StringRef GetImageSizeFunc = "llvm.OpenCL.image.get.size"; 56 static StringRef GetImageFormatFunc = "llvm.OpenCL.image.get.format"; 57 static StringRef GetImageResourceIDFunc = "llvm.OpenCL.image.get.resource.id"; 58 static StringRef GetSamplerResourceIDFunc = 59 "llvm.OpenCL.sampler.get.resource.id"; 60 61 static StringRef ImageSizeArgMDType = "__llvm_image_size"; 62 static StringRef ImageFormatArgMDType = "__llvm_image_format"; 63 64 static StringRef KernelsMDNodeName = "opencl.kernels"; 65 static StringRef KernelArgMDNodeNames[] = { 66 "kernel_arg_addr_space", 67 "kernel_arg_access_qual", 68 "kernel_arg_type", 69 "kernel_arg_base_type", 70 "kernel_arg_type_qual"}; 71 static const unsigned NumKernelArgMDNodes = 5; 72 73 namespace { 74 75 using MDVector = SmallVector<Metadata *, 8>; 76 struct KernelArgMD { 77 MDVector ArgVector[NumKernelArgMDNodes]; 78 }; 79 80 } // end anonymous namespace 81 82 static inline bool 83 IsImageType(StringRef TypeString) { 84 return TypeString == "image2d_t" || TypeString == "image3d_t"; 85 } 86 87 static inline bool 88 IsSamplerType(StringRef TypeString) { 89 return TypeString == "sampler_t"; 90 } 91 92 static Function * 93 GetFunctionFromMDNode(MDNode *Node) { 94 if (!Node) 95 return nullptr; 96 97 size_t NumOps = Node->getNumOperands(); 98 if (NumOps != NumKernelArgMDNodes + 1) 99 return nullptr; 100 101 auto F = mdconst::dyn_extract<Function>(Node->getOperand(0)); 102 if (!F) 103 return nullptr; 104 105 // Sanity checks. 106 size_t ExpectNumArgNodeOps = F->arg_size() + 1; 107 for (size_t i = 0; i < NumKernelArgMDNodes; ++i) { 108 MDNode *ArgNode = dyn_cast_or_null<MDNode>(Node->getOperand(i + 1)); 109 if (ArgNode->getNumOperands() != ExpectNumArgNodeOps) 110 return nullptr; 111 if (!ArgNode->getOperand(0)) 112 return nullptr; 113 114 // FIXME: It should be possible to do image lowering when some metadata 115 // args missing or not in the expected order. 116 MDString *StringNode = dyn_cast<MDString>(ArgNode->getOperand(0)); 117 if (!StringNode || StringNode->getString() != KernelArgMDNodeNames[i]) 118 return nullptr; 119 } 120 121 return F; 122 } 123 124 static StringRef 125 AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx) { 126 MDNode *ArgAQNode = cast<MDNode>(KernelMDNode->getOperand(2)); 127 return cast<MDString>(ArgAQNode->getOperand(ArgIdx + 1))->getString(); 128 } 129 130 static StringRef 131 ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx) { 132 MDNode *ArgTypeNode = cast<MDNode>(KernelMDNode->getOperand(3)); 133 return cast<MDString>(ArgTypeNode->getOperand(ArgIdx + 1))->getString(); 134 } 135 136 static MDVector 137 GetArgMD(MDNode *KernelMDNode, unsigned OpIdx) { 138 MDVector Res; 139 for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) { 140 MDNode *Node = cast<MDNode>(KernelMDNode->getOperand(i + 1)); 141 Res.push_back(Node->getOperand(OpIdx)); 142 } 143 return Res; 144 } 145 146 static void 147 PushArgMD(KernelArgMD &MD, const MDVector &V) { 148 assert(V.size() == NumKernelArgMDNodes); 149 for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) { 150 MD.ArgVector[i].push_back(V[i]); 151 } 152 } 153 154 namespace { 155 156 class R600OpenCLImageTypeLoweringPass : public ModulePass { 157 static char ID; 158 159 LLVMContext *Context; 160 Type *Int32Type; 161 Type *ImageSizeType; 162 Type *ImageFormatType; 163 SmallVector<Instruction *, 4> InstsToErase; 164 165 bool replaceImageUses(Argument &ImageArg, uint32_t ResourceID, 166 Argument &ImageSizeArg, 167 Argument &ImageFormatArg) { 168 bool Modified = false; 169 170 for (auto &Use : ImageArg.uses()) { 171 auto Inst = dyn_cast<CallInst>(Use.getUser()); 172 if (!Inst) { 173 continue; 174 } 175 176 Function *F = Inst->getCalledFunction(); 177 if (!F) 178 continue; 179 180 Value *Replacement = nullptr; 181 StringRef Name = F->getName(); 182 if (Name.startswith(GetImageResourceIDFunc)) { 183 Replacement = ConstantInt::get(Int32Type, ResourceID); 184 } else if (Name.startswith(GetImageSizeFunc)) { 185 Replacement = &ImageSizeArg; 186 } else if (Name.startswith(GetImageFormatFunc)) { 187 Replacement = &ImageFormatArg; 188 } else { 189 continue; 190 } 191 192 Inst->replaceAllUsesWith(Replacement); 193 InstsToErase.push_back(Inst); 194 Modified = true; 195 } 196 197 return Modified; 198 } 199 200 bool replaceSamplerUses(Argument &SamplerArg, uint32_t ResourceID) { 201 bool Modified = false; 202 203 for (const auto &Use : SamplerArg.uses()) { 204 auto Inst = dyn_cast<CallInst>(Use.getUser()); 205 if (!Inst) { 206 continue; 207 } 208 209 Function *F = Inst->getCalledFunction(); 210 if (!F) 211 continue; 212 213 Value *Replacement = nullptr; 214 StringRef Name = F->getName(); 215 if (Name == GetSamplerResourceIDFunc) { 216 Replacement = ConstantInt::get(Int32Type, ResourceID); 217 } else { 218 continue; 219 } 220 221 Inst->replaceAllUsesWith(Replacement); 222 InstsToErase.push_back(Inst); 223 Modified = true; 224 } 225 226 return Modified; 227 } 228 229 bool replaceImageAndSamplerUses(Function *F, MDNode *KernelMDNode) { 230 uint32_t NumReadOnlyImageArgs = 0; 231 uint32_t NumWriteOnlyImageArgs = 0; 232 uint32_t NumSamplerArgs = 0; 233 234 bool Modified = false; 235 InstsToErase.clear(); 236 for (auto ArgI = F->arg_begin(); ArgI != F->arg_end(); ++ArgI) { 237 Argument &Arg = *ArgI; 238 StringRef Type = ArgTypeFromMD(KernelMDNode, Arg.getArgNo()); 239 240 // Handle image types. 241 if (IsImageType(Type)) { 242 StringRef AccessQual = AccessQualFromMD(KernelMDNode, Arg.getArgNo()); 243 uint32_t ResourceID; 244 if (AccessQual == "read_only") { 245 ResourceID = NumReadOnlyImageArgs++; 246 } else if (AccessQual == "write_only") { 247 ResourceID = NumWriteOnlyImageArgs++; 248 } else { 249 llvm_unreachable("Wrong image access qualifier."); 250 } 251 252 Argument &SizeArg = *(++ArgI); 253 Argument &FormatArg = *(++ArgI); 254 Modified |= replaceImageUses(Arg, ResourceID, SizeArg, FormatArg); 255 256 // Handle sampler type. 257 } else if (IsSamplerType(Type)) { 258 uint32_t ResourceID = NumSamplerArgs++; 259 Modified |= replaceSamplerUses(Arg, ResourceID); 260 } 261 } 262 for (unsigned i = 0; i < InstsToErase.size(); ++i) { 263 InstsToErase[i]->eraseFromParent(); 264 } 265 266 return Modified; 267 } 268 269 std::tuple<Function *, MDNode *> 270 addImplicitArgs(Function *F, MDNode *KernelMDNode) { 271 bool Modified = false; 272 273 FunctionType *FT = F->getFunctionType(); 274 SmallVector<Type *, 8> ArgTypes; 275 276 // Metadata operands for new MDNode. 277 KernelArgMD NewArgMDs; 278 PushArgMD(NewArgMDs, GetArgMD(KernelMDNode, 0)); 279 280 // Add implicit arguments to the signature. 281 for (unsigned i = 0; i < FT->getNumParams(); ++i) { 282 ArgTypes.push_back(FT->getParamType(i)); 283 MDVector ArgMD = GetArgMD(KernelMDNode, i + 1); 284 PushArgMD(NewArgMDs, ArgMD); 285 286 if (!IsImageType(ArgTypeFromMD(KernelMDNode, i))) 287 continue; 288 289 // Add size implicit argument. 290 ArgTypes.push_back(ImageSizeType); 291 ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageSizeArgMDType); 292 PushArgMD(NewArgMDs, ArgMD); 293 294 // Add format implicit argument. 295 ArgTypes.push_back(ImageFormatType); 296 ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageFormatArgMDType); 297 PushArgMD(NewArgMDs, ArgMD); 298 299 Modified = true; 300 } 301 if (!Modified) { 302 return std::make_tuple(nullptr, nullptr); 303 } 304 305 // Create function with new signature and clone the old body into it. 306 auto NewFT = FunctionType::get(FT->getReturnType(), ArgTypes, false); 307 auto NewF = Function::Create(NewFT, F->getLinkage(), F->getName()); 308 ValueToValueMapTy VMap; 309 auto NewFArgIt = NewF->arg_begin(); 310 for (auto &Arg: F->args()) { 311 auto ArgName = Arg.getName(); 312 NewFArgIt->setName(ArgName); 313 VMap[&Arg] = &(*NewFArgIt++); 314 if (IsImageType(ArgTypeFromMD(KernelMDNode, Arg.getArgNo()))) { 315 (NewFArgIt++)->setName(Twine("__size_") + ArgName); 316 (NewFArgIt++)->setName(Twine("__format_") + ArgName); 317 } 318 } 319 SmallVector<ReturnInst*, 8> Returns; 320 CloneFunctionInto(NewF, F, VMap, /*ModuleLevelChanges=*/false, Returns); 321 322 // Build new MDNode. 323 SmallVector<Metadata *, 6> KernelMDArgs; 324 KernelMDArgs.push_back(ConstantAsMetadata::get(NewF)); 325 for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) 326 KernelMDArgs.push_back(MDNode::get(*Context, NewArgMDs.ArgVector[i])); 327 MDNode *NewMDNode = MDNode::get(*Context, KernelMDArgs); 328 329 return std::make_tuple(NewF, NewMDNode); 330 } 331 332 bool transformKernels(Module &M) { 333 NamedMDNode *KernelsMDNode = M.getNamedMetadata(KernelsMDNodeName); 334 if (!KernelsMDNode) 335 return false; 336 337 bool Modified = false; 338 for (unsigned i = 0; i < KernelsMDNode->getNumOperands(); ++i) { 339 MDNode *KernelMDNode = KernelsMDNode->getOperand(i); 340 Function *F = GetFunctionFromMDNode(KernelMDNode); 341 if (!F) 342 continue; 343 344 Function *NewF; 345 MDNode *NewMDNode; 346 std::tie(NewF, NewMDNode) = addImplicitArgs(F, KernelMDNode); 347 if (NewF) { 348 // Replace old function and metadata with new ones. 349 F->eraseFromParent(); 350 M.getFunctionList().push_back(NewF); 351 M.getOrInsertFunction(NewF->getName(), NewF->getFunctionType(), 352 NewF->getAttributes()); 353 KernelsMDNode->setOperand(i, NewMDNode); 354 355 F = NewF; 356 KernelMDNode = NewMDNode; 357 Modified = true; 358 } 359 360 Modified |= replaceImageAndSamplerUses(F, KernelMDNode); 361 } 362 363 return Modified; 364 } 365 366 public: 367 R600OpenCLImageTypeLoweringPass() : ModulePass(ID) {} 368 369 bool runOnModule(Module &M) override { 370 Context = &M.getContext(); 371 Int32Type = Type::getInt32Ty(M.getContext()); 372 ImageSizeType = ArrayType::get(Int32Type, 3); 373 ImageFormatType = ArrayType::get(Int32Type, 2); 374 375 return transformKernels(M); 376 } 377 378 StringRef getPassName() const override { 379 return "R600 OpenCL Image Type Pass"; 380 } 381 }; 382 383 } // end anonymous namespace 384 385 char R600OpenCLImageTypeLoweringPass::ID = 0; 386 387 ModulePass *llvm::createR600OpenCLImageTypeLoweringPass() { 388 return new R600OpenCLImageTypeLoweringPass(); 389 } 390