1 //===- SPIRVLowerOCLBlocks.cpp - Lower OpenCL blocks ------------*- C++ -*-===// 2 // 3 // The LLVM/SPIR-V Translator 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 // Copyright (c) 2014 Advanced Micro Devices, Inc. All rights reserved. 9 // 10 // Permission is hereby granted, free of charge, to any person obtaining a 11 // copy of this software and associated documentation files (the "Software"), 12 // to deal with the Software without restriction, including without limitation 13 // the rights to use, copy, modify, merge, publish, distribute, sublicense, 14 // and/or sell copies of the Software, and to permit persons to whom the 15 // Software is furnished to do so, subject to the following conditions: 16 // 17 // Redistributions of source code must retain the above copyright notice, 18 // this list of conditions and the following disclaimers. 19 // Redistributions in binary form must reproduce the above copyright notice, 20 // this list of conditions and the following disclaimers in the documentation 21 // and/or other materials provided with the distribution. 22 // Neither the names of Advanced Micro Devices, Inc., nor the names of its 23 // contributors may be used to endorse or promote products derived from this 24 // Software without specific prior written permission. 25 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 26 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 27 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 28 // CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 29 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 30 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH 31 // THE SOFTWARE. 32 // 33 //===----------------------------------------------------------------------===// 34 /// \file 35 /// 36 /// This file implements lowering of OpenCL blocks to functions. 37 /// 38 //===----------------------------------------------------------------------===// 39 40 #ifndef OCLLOWERBLOCKS_H_ 41 #define OCLLOWERBLOCKS_H_ 42 43 #include "SPIRVInternal.h" 44 #include "OCLUtil.h" 45 46 #include "llvm/ADT/DenseMap.h" 47 #include "llvm/ADT/SetVector.h" 48 #include "llvm/ADT/StringSwitch.h" 49 #include "llvm/ADT/Triple.h" 50 #include "llvm/Analysis/AliasAnalysis.h" 51 #include "llvm/Analysis/AssumptionCache.h" 52 #include "llvm/Analysis/CallGraph.h" 53 #include "llvm/IR/Verifier.h" 54 #include "llvm/Bitcode/ReaderWriter.h" 55 #include "llvm/IR/Constants.h" 56 #include "llvm/IR/DerivedTypes.h" 57 #include "llvm/IR/Function.h" 58 #include "llvm/IR/InstrTypes.h" 59 #include "llvm/IR/Instructions.h" 60 #include "llvm/IR/Module.h" 61 #include "llvm/IR/Operator.h" 62 #include "llvm/Pass.h" 63 #include "llvm/PassSupport.h" 64 #include "llvm/Support/Casting.h" 65 #include "llvm/Support/Debug.h" 66 #include "llvm/Support/raw_ostream.h" 67 #include "llvm/Support/ToolOutputFile.h" 68 #include "llvm/Transforms/Utils/Cloning.h" 69 70 #include <iostream> 71 #include <list> 72 #include <memory> 73 #include <set> 74 #include <sstream> 75 #include <vector> 76 77 #define DEBUG_TYPE "spvblocks" 78 79 using namespace llvm; 80 using namespace SPIRV; 81 using namespace OCLUtil; 82 83 namespace SPIRV{ 84 85 /// Lower SPIR2 blocks to function calls. 86 /// 87 /// SPIR2 representation of blocks: 88 /// 89 /// block = spir_block_bind(bitcast(block_func), context_len, context_align, 90 /// context) 91 /// block_func_ptr = bitcast(spir_get_block_invoke(block)) 92 /// context_ptr = spir_get_block_context(block) 93 /// ret = block_func_ptr(context_ptr, args) 94 /// 95 /// Propagates block_func to each spir_get_block_invoke through def-use chain of 96 /// spir_block_bind, so that 97 /// ret = block_func(context, args) 98 class SPIRVLowerOCLBlocks: public ModulePass { 99 public: 100 SPIRVLowerOCLBlocks():ModulePass(ID), M(nullptr){ 101 initializeSPIRVLowerOCLBlocksPass(*PassRegistry::getPassRegistry()); 102 } 103 104 virtual void getAnalysisUsage(AnalysisUsage &AU) const { 105 AU.addRequired<CallGraphWrapperPass>(); 106 //AU.addRequired<AliasAnalysis>(); 107 AU.addRequired<AssumptionCacheTracker>(); 108 } 109 110 virtual bool runOnModule(Module &Module) { 111 M = &Module; 112 lowerBlockBind(); 113 lowerGetBlockInvoke(); 114 lowerGetBlockContext(); 115 erase(M->getFunction(SPIR_INTRINSIC_GET_BLOCK_INVOKE)); 116 erase(M->getFunction(SPIR_INTRINSIC_GET_BLOCK_CONTEXT)); 117 erase(M->getFunction(SPIR_INTRINSIC_BLOCK_BIND)); 118 DEBUG(dbgs() << "------- After OCLLowerBlocks ------------\n" << 119 *M << '\n'); 120 return true; 121 } 122 123 static char ID; 124 private: 125 const static int MaxIter = 1000; 126 Module *M; 127 128 bool 129 lowerBlockBind() { 130 auto F = M->getFunction(SPIR_INTRINSIC_BLOCK_BIND); 131 if (!F) 132 return false; 133 int Iter = MaxIter; 134 while(lowerBlockBind(F) && Iter > 0){ 135 Iter--; 136 DEBUG(dbgs() << "-------------- after iteration " << MaxIter - Iter << 137 " --------------\n" << *M << '\n'); 138 } 139 assert(Iter > 0 && "Too many iterations"); 140 return true; 141 } 142 143 bool 144 eraseUselessFunctions() { 145 bool changed = false; 146 for (auto I = M->begin(), E = M->end(); I != E;) { 147 Function *F = static_cast<Function*>(I++); 148 if (!GlobalValue::isInternalLinkage(F->getLinkage()) && 149 !F->isDeclaration()) 150 continue; 151 152 dumpUsers(F, "[eraseUselessFunctions] "); 153 for (auto UI = F->user_begin(), UE = F->user_end(); UI != UE;) { 154 auto U = *UI++; 155 if (auto CE = dyn_cast<ConstantExpr>(U)){ 156 if (CE->use_empty()) { 157 CE->dropAllReferences(); 158 changed = true; 159 } 160 } 161 } 162 if (F->use_empty()) { 163 erase(F); 164 changed = true; 165 } 166 } 167 return changed; 168 } 169 170 void 171 lowerGetBlockInvoke() { 172 if (auto F = M->getFunction(SPIR_INTRINSIC_GET_BLOCK_INVOKE)) { 173 for (auto UI = F->user_begin(), UE = F->user_end(); UI != UE;) { 174 auto CI = dyn_cast<CallInst>(*UI++); 175 assert(CI && "Invalid usage of spir_get_block_invoke"); 176 lowerGetBlockInvoke(CI); 177 } 178 } 179 } 180 181 void 182 lowerGetBlockContext() { 183 if (auto F = M->getFunction(SPIR_INTRINSIC_GET_BLOCK_CONTEXT)) { 184 for (auto UI = F->user_begin(), UE = F->user_end(); UI != UE;) { 185 auto CI = dyn_cast<CallInst>(*UI++); 186 assert(CI && "Invalid usage of spir_get_block_context"); 187 lowerGetBlockContext(CI); 188 } 189 } 190 } 191 /// Lower calls of spir_block_bind. 192 /// Return true if the Module is changed. 193 bool 194 lowerBlockBind(Function *BlockBindFunc) { 195 bool changed = false; 196 for (auto I = BlockBindFunc->user_begin(), E = BlockBindFunc->user_end(); 197 I != E;) { 198 DEBUG(dbgs() << "[lowerBlockBind] " << **I << '\n'); 199 // Handle spir_block_bind(bitcast(block_func), context_len, 200 // context_align, context) 201 auto CallBlkBind = cast<CallInst>(*I++); 202 Function *InvF = nullptr; 203 Value *Ctx = nullptr; 204 Value *CtxLen = nullptr; 205 Value *CtxAlign = nullptr; 206 getBlockInvokeFuncAndContext(CallBlkBind, &InvF, &Ctx, &CtxLen, 207 &CtxAlign); 208 for (auto II = CallBlkBind->user_begin(), EE = CallBlkBind->user_end(); 209 II != EE;) { 210 auto BlkUser = *II++; 211 SPIRVDBG(dbgs() << " Block user: " << *BlkUser << '\n'); 212 if (auto Ret = dyn_cast<ReturnInst>(BlkUser)) { 213 bool Inlined = false; 214 changed |= lowerReturnBlock(Ret, CallBlkBind, Inlined); 215 if (Inlined) 216 return true; 217 } else if (auto CI = dyn_cast<CallInst>(BlkUser)){ 218 auto CallBindF = CI->getCalledFunction(); 219 auto Name = CallBindF->getName(); 220 std::string DemangledName; 221 if (Name == SPIR_INTRINSIC_GET_BLOCK_INVOKE) { 222 assert(CI->getArgOperand(0) == CallBlkBind); 223 changed |= lowerGetBlockInvoke(CI, cast<Function>(InvF)); 224 } else if (Name == SPIR_INTRINSIC_GET_BLOCK_CONTEXT) { 225 assert(CI->getArgOperand(0) == CallBlkBind); 226 // Handle context_ptr = spir_get_block_context(block) 227 lowerGetBlockContext(CI, Ctx); 228 changed = true; 229 } else if (oclIsBuiltin(Name, &DemangledName)) { 230 lowerBlockBuiltin(CI, InvF, Ctx, CtxLen, CtxAlign, DemangledName); 231 changed = true; 232 } else 233 llvm_unreachable("Invalid block user"); 234 } 235 } 236 erase(CallBlkBind); 237 } 238 changed |= eraseUselessFunctions(); 239 return changed; 240 } 241 242 void 243 lowerGetBlockContext(CallInst *CallGetBlkCtx, Value *Ctx = nullptr) { 244 if (!Ctx) 245 getBlockInvokeFuncAndContext(CallGetBlkCtx->getArgOperand(0), nullptr, 246 &Ctx); 247 CallGetBlkCtx->replaceAllUsesWith(Ctx); 248 DEBUG(dbgs() << " [lowerGetBlockContext] " << *CallGetBlkCtx << " => " << 249 *Ctx << "\n\n"); 250 erase(CallGetBlkCtx); 251 } 252 253 bool 254 lowerGetBlockInvoke(CallInst *CallGetBlkInvoke, 255 Function *InvokeF = nullptr) { 256 bool changed = false; 257 for (auto UI = CallGetBlkInvoke->user_begin(), 258 UE = CallGetBlkInvoke->user_end(); 259 UI != UE;) { 260 // Handle block_func_ptr = bitcast(spir_get_block_invoke(block)) 261 auto CallInv = cast<Instruction>(*UI++); 262 auto Cast = dyn_cast<BitCastInst>(CallInv); 263 if (Cast) 264 CallInv = dyn_cast<Instruction>(*CallInv->user_begin()); 265 DEBUG(dbgs() << "[lowerGetBlockInvoke] " << *CallInv); 266 // Handle ret = block_func_ptr(context_ptr, args) 267 auto CI = cast<CallInst>(CallInv); 268 auto F = CI->getCalledValue(); 269 if (InvokeF == nullptr) { 270 getBlockInvokeFuncAndContext(CallGetBlkInvoke->getArgOperand(0), 271 &InvokeF, nullptr); 272 assert(InvokeF); 273 } 274 assert(F->getType() == InvokeF->getType()); 275 CI->replaceUsesOfWith(F, InvokeF); 276 DEBUG(dbgs() << " => " << *CI << "\n\n"); 277 erase(Cast); 278 changed = true; 279 } 280 erase(CallGetBlkInvoke); 281 return changed; 282 } 283 284 void 285 lowerBlockBuiltin(CallInst *CI, Function *InvF, Value *Ctx, Value *CtxLen, 286 Value *CtxAlign, const std::string& DemangledName) { 287 mutateCallInstSPIRV (M, CI, [=](CallInst *CI, std::vector<Value *> &Args) { 288 size_t I = 0; 289 size_t E = Args.size(); 290 for (; I != E; ++I) { 291 if (isPointerToOpaqueStructType(Args[I]->getType(), 292 SPIR_TYPE_NAME_BLOCK_T)) { 293 break; 294 } 295 } 296 assert (I < E); 297 Args[I] = castToVoidFuncPtr(InvF); 298 if (I + 1 == E) { 299 Args.push_back(Ctx); 300 Args.push_back(CtxLen); 301 Args.push_back(CtxAlign); 302 } else { 303 Args.insert(Args.begin() + I + 1, CtxAlign); 304 Args.insert(Args.begin() + I + 1, CtxLen); 305 Args.insert(Args.begin() + I + 1, Ctx); 306 } 307 if (DemangledName == kOCLBuiltinName::EnqueueKernel) { 308 // Insert event arguments if there are not. 309 if (!isa<IntegerType>(Args[3]->getType())) { 310 Args.insert(Args.begin() + 3, getInt32(M, 0)); 311 Args.insert(Args.begin() + 4, getOCLNullClkEventPtr()); 312 } 313 if (!isOCLClkEventPtrType(Args[5]->getType())) 314 Args.insert(Args.begin() + 5, getOCLNullClkEventPtr()); 315 } 316 return getSPIRVFuncName(OCLSPIRVBuiltinMap::map(DemangledName)); 317 }); 318 } 319 /// Transform return of a block. 320 /// The function returning a block is inlined since the context cannot be 321 /// passed to another function. 322 /// Returns true of module is changed. 323 bool 324 lowerReturnBlock(ReturnInst *Ret, Value *CallBlkBind, bool &Inlined) { 325 auto F = Ret->getParent()->getParent(); 326 auto changed = false; 327 for (auto UI = F->user_begin(), UE = F->user_end(); UI != UE;) { 328 auto U = *UI++; 329 dumpUsers(U); 330 auto Inst = dyn_cast<Instruction>(U); 331 if (Inst && Inst->use_empty()) { 332 erase(Inst); 333 changed = true; 334 continue; 335 } 336 auto CI = dyn_cast<CallInst>(U); 337 if(!CI || CI->getCalledFunction() != F) 338 continue; 339 340 DEBUG(dbgs() << "[lowerReturnBlock] inline " << F->getName() << '\n'); 341 auto CG = &getAnalysis<CallGraphWrapperPass>().getCallGraph(); 342 auto ACT = &getAnalysis<AssumptionCacheTracker>(); 343 //auto AA = &getAnalysis<AliasAnalysis>(); 344 //InlineFunctionInfo IFI(CG, M->getDataLayout(), AA, ACT); 345 InlineFunctionInfo IFI(CG, ACT); 346 InlineFunction(CI, IFI); 347 Inlined = true; 348 } 349 return changed || Inlined; 350 } 351 352 void 353 getBlockInvokeFuncAndContext(Value *Blk, Function **PInvF, Value **PCtx, 354 Value **PCtxLen = nullptr, Value **PCtxAlign = nullptr){ 355 Function *InvF = nullptr; 356 Value *Ctx = nullptr; 357 Value *CtxLen = nullptr; 358 Value *CtxAlign = nullptr; 359 if (auto CallBlkBind = dyn_cast<CallInst>(Blk)) { 360 assert(CallBlkBind->getCalledFunction()->getName() == 361 SPIR_INTRINSIC_BLOCK_BIND && "Invalid block"); 362 InvF = dyn_cast<Function>( 363 CallBlkBind->getArgOperand(0)->stripPointerCasts()); 364 CtxLen = CallBlkBind->getArgOperand(1); 365 CtxAlign = CallBlkBind->getArgOperand(2); 366 Ctx = CallBlkBind->getArgOperand(3); 367 } else if (auto F = dyn_cast<Function>(Blk->stripPointerCasts())) { 368 InvF = F; 369 Ctx = Constant::getNullValue(IntegerType::getInt8PtrTy(M->getContext())); 370 } else if (auto Load = dyn_cast<LoadInst>(Blk)) { 371 auto Op = Load->getPointerOperand(); 372 if (auto GV = dyn_cast<GlobalVariable>(Op)) { 373 if (GV->isConstant()) { 374 InvF = cast<Function>(GV->getInitializer()->stripPointerCasts()); 375 Ctx = Constant::getNullValue(IntegerType::getInt8PtrTy(M->getContext())); 376 } else { 377 llvm_unreachable("load non-constant block?"); 378 } 379 } else { 380 llvm_unreachable("Loading block from non global?"); 381 } 382 } else { 383 llvm_unreachable("Invalid block"); 384 } 385 DEBUG(dbgs() << " Block invocation func: " << InvF->getName() << '\n' << 386 " Block context: " << *Ctx << '\n'); 387 assert(InvF && Ctx && "Invalid block"); 388 if (PInvF) 389 *PInvF = InvF; 390 if (PCtx) 391 *PCtx = Ctx; 392 if (PCtxLen) 393 *PCtxLen = CtxLen; 394 if (PCtxAlign) 395 *PCtxAlign = CtxAlign; 396 } 397 void 398 erase(Instruction *I) { 399 if (!I) 400 return; 401 if (I->use_empty()) { 402 I->dropAllReferences(); 403 I->eraseFromParent(); 404 } 405 else 406 dumpUsers(I); 407 } 408 void 409 erase(ConstantExpr *I) { 410 if (!I) 411 return; 412 if (I->use_empty()) { 413 I->dropAllReferences(); 414 I->destroyConstant(); 415 } else 416 dumpUsers(I); 417 } 418 void 419 erase(Function *F) { 420 if (!F) 421 return; 422 if (!F->use_empty()) { 423 dumpUsers(F); 424 return; 425 } 426 F->dropAllReferences(); 427 auto &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); 428 CG.removeFunctionFromModule(new CallGraphNode(F)); 429 } 430 431 llvm::PointerType* getOCLClkEventType() { 432 return getOrCreateOpaquePtrType(M, SPIR_TYPE_NAME_CLK_EVENT_T, 433 SPIRAS_Global); 434 } 435 436 llvm::PointerType* getOCLClkEventPtrType() { 437 return PointerType::get(getOCLClkEventType(), SPIRAS_Generic); 438 } 439 440 bool isOCLClkEventPtrType(Type *T) { 441 if (auto PT = dyn_cast<PointerType>(T)) 442 return isPointerToOpaqueStructType( 443 PT->getElementType(), SPIR_TYPE_NAME_CLK_EVENT_T); 444 return false; 445 } 446 447 llvm::Constant* getOCLNullClkEventPtr() { 448 return Constant::getNullValue(getOCLClkEventPtrType()); 449 } 450 451 void dumpGetBlockInvokeUsers(StringRef Prompt) { 452 DEBUG(dbgs() << Prompt); 453 dumpUsers(M->getFunction(SPIR_INTRINSIC_GET_BLOCK_INVOKE)); 454 } 455 }; 456 457 char SPIRVLowerOCLBlocks::ID = 0; 458 } 459 460 INITIALIZE_PASS_BEGIN(SPIRVLowerOCLBlocks, "spvblocks", 461 "SPIR-V lower OCL blocks", false, false) 462 INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) 463 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) 464 //INITIALIZE_AG_DEPENDENCY(AliasAnalysis) 465 INITIALIZE_PASS_END(SPIRVLowerOCLBlocks, "spvblocks", 466 "SPIR-V lower OCL blocks", false, false) 467 468 ModulePass *llvm::createSPIRVLowerOCLBlocks() { 469 return new SPIRVLowerOCLBlocks(); 470 } 471 472 #endif /* OCLLOWERBLOCKS_H_ */ 473