Home | History | Annotate | Download | only in SPIRV
      1 //===- SPIRVLowerOCLBlocks.cpp - Lower OpenCL blocks ------------*- C++ -*-===//
      2 //
      3 //                     The LLVM/SPIR-V Translator
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 // Copyright (c) 2014 Advanced Micro Devices, Inc. All rights reserved.
      9 //
     10 // Permission is hereby granted, free of charge, to any person obtaining a
     11 // copy of this software and associated documentation files (the "Software"),
     12 // to deal with the Software without restriction, including without limitation
     13 // the rights to use, copy, modify, merge, publish, distribute, sublicense,
     14 // and/or sell copies of the Software, and to permit persons to whom the
     15 // Software is furnished to do so, subject to the following conditions:
     16 //
     17 // Redistributions of source code must retain the above copyright notice,
     18 // this list of conditions and the following disclaimers.
     19 // Redistributions in binary form must reproduce the above copyright notice,
     20 // this list of conditions and the following disclaimers in the documentation
     21 // and/or other materials provided with the distribution.
     22 // Neither the names of Advanced Micro Devices, Inc., nor the names of its
     23 // contributors may be used to endorse or promote products derived from this
     24 // Software without specific prior written permission.
     25 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
     26 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
     27 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
     28 // CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
     29 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
     30 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH
     31 // THE SOFTWARE.
     32 //
     33 //===----------------------------------------------------------------------===//
     34 /// \file
     35 ///
     36 /// This file implements lowering of OpenCL blocks to functions.
     37 ///
     38 //===----------------------------------------------------------------------===//
     39 
     40 #ifndef OCLLOWERBLOCKS_H_
     41 #define OCLLOWERBLOCKS_H_
     42 
     43 #include "SPIRVInternal.h"
     44 #include "OCLUtil.h"
     45 
     46 #include "llvm/ADT/DenseMap.h"
     47 #include "llvm/ADT/SetVector.h"
     48 #include "llvm/ADT/StringSwitch.h"
     49 #include "llvm/ADT/Triple.h"
     50 #include "llvm/Analysis/AliasAnalysis.h"
     51 #include "llvm/Analysis/AssumptionCache.h"
     52 #include "llvm/Analysis/CallGraph.h"
     53 #include "llvm/IR/Verifier.h"
     54 #include "llvm/Bitcode/ReaderWriter.h"
     55 #include "llvm/IR/Constants.h"
     56 #include "llvm/IR/DerivedTypes.h"
     57 #include "llvm/IR/Function.h"
     58 #include "llvm/IR/InstrTypes.h"
     59 #include "llvm/IR/Instructions.h"
     60 #include "llvm/IR/Module.h"
     61 #include "llvm/IR/Operator.h"
     62 #include "llvm/Pass.h"
     63 #include "llvm/PassSupport.h"
     64 #include "llvm/Support/Casting.h"
     65 #include "llvm/Support/Debug.h"
     66 #include "llvm/Support/raw_ostream.h"
     67 #include "llvm/Support/ToolOutputFile.h"
     68 #include "llvm/Transforms/Utils/Cloning.h"
     69 
     70 #include <iostream>
     71 #include <list>
     72 #include <memory>
     73 #include <set>
     74 #include <sstream>
     75 #include <vector>
     76 
     77 #define DEBUG_TYPE "spvblocks"
     78 
     79 using namespace llvm;
     80 using namespace SPIRV;
     81 using namespace OCLUtil;
     82 
     83 namespace SPIRV{
     84 
     85 /// Lower SPIR2 blocks to function calls.
     86 ///
     87 /// SPIR2 representation of blocks:
     88 ///
     89 /// block = spir_block_bind(bitcast(block_func), context_len, context_align,
     90 ///   context)
     91 /// block_func_ptr = bitcast(spir_get_block_invoke(block))
     92 /// context_ptr = spir_get_block_context(block)
     93 /// ret = block_func_ptr(context_ptr, args)
     94 ///
     95 /// Propagates block_func to each spir_get_block_invoke through def-use chain of
     96 /// spir_block_bind, so that
     97 /// ret = block_func(context, args)
     98 class SPIRVLowerOCLBlocks: public ModulePass {
     99 public:
    100   SPIRVLowerOCLBlocks():ModulePass(ID), M(nullptr){
    101     initializeSPIRVLowerOCLBlocksPass(*PassRegistry::getPassRegistry());
    102   }
    103 
    104   virtual void getAnalysisUsage(AnalysisUsage &AU) const {
    105     AU.addRequired<CallGraphWrapperPass>();
    106     //AU.addRequired<AliasAnalysis>();
    107     AU.addRequired<AssumptionCacheTracker>();
    108   }
    109 
    110   virtual bool runOnModule(Module &Module) {
    111     M = &Module;
    112     lowerBlockBind();
    113     lowerGetBlockInvoke();
    114     lowerGetBlockContext();
    115     erase(M->getFunction(SPIR_INTRINSIC_GET_BLOCK_INVOKE));
    116     erase(M->getFunction(SPIR_INTRINSIC_GET_BLOCK_CONTEXT));
    117     erase(M->getFunction(SPIR_INTRINSIC_BLOCK_BIND));
    118     DEBUG(dbgs() << "------- After OCLLowerBlocks ------------\n" <<
    119                     *M << '\n');
    120     return true;
    121   }
    122 
    123   static char ID;
    124 private:
    125   const static int MaxIter = 1000;
    126   Module *M;
    127 
    128   bool
    129   lowerBlockBind() {
    130     auto F = M->getFunction(SPIR_INTRINSIC_BLOCK_BIND);
    131     if (!F)
    132       return false;
    133     int Iter = MaxIter;
    134     while(lowerBlockBind(F) && Iter > 0){
    135       Iter--;
    136       DEBUG(dbgs() << "-------------- after iteration " << MaxIter - Iter <<
    137           " --------------\n" << *M << '\n');
    138     }
    139     assert(Iter > 0 && "Too many iterations");
    140     return true;
    141   }
    142 
    143   bool
    144   eraseUselessFunctions() {
    145     bool changed = false;
    146     for (auto I = M->begin(), E = M->end(); I != E;) {
    147       Function *F = static_cast<Function*>(I++);
    148       if (!GlobalValue::isInternalLinkage(F->getLinkage()) &&
    149           !F->isDeclaration())
    150         continue;
    151 
    152       dumpUsers(F, "[eraseUselessFunctions] ");
    153       for (auto UI = F->user_begin(), UE = F->user_end(); UI != UE;) {
    154         auto U = *UI++;
    155         if (auto CE = dyn_cast<ConstantExpr>(U)){
    156           if (CE->use_empty()) {
    157             CE->dropAllReferences();
    158             changed = true;
    159           }
    160         }
    161       }
    162       if (F->use_empty()) {
    163         erase(F);
    164         changed = true;
    165       }
    166     }
    167     return changed;
    168   }
    169 
    170   void
    171   lowerGetBlockInvoke() {
    172     if (auto F = M->getFunction(SPIR_INTRINSIC_GET_BLOCK_INVOKE)) {
    173       for (auto UI = F->user_begin(), UE = F->user_end(); UI != UE;) {
    174         auto CI = dyn_cast<CallInst>(*UI++);
    175         assert(CI && "Invalid usage of spir_get_block_invoke");
    176         lowerGetBlockInvoke(CI);
    177       }
    178     }
    179   }
    180 
    181   void
    182   lowerGetBlockContext() {
    183     if (auto F = M->getFunction(SPIR_INTRINSIC_GET_BLOCK_CONTEXT)) {
    184       for (auto UI = F->user_begin(), UE = F->user_end(); UI != UE;) {
    185         auto CI = dyn_cast<CallInst>(*UI++);
    186         assert(CI && "Invalid usage of spir_get_block_context");
    187         lowerGetBlockContext(CI);
    188       }
    189     }
    190   }
    191   /// Lower calls of spir_block_bind.
    192   /// Return true if the Module is changed.
    193   bool
    194   lowerBlockBind(Function *BlockBindFunc) {
    195     bool changed = false;
    196     for (auto I = BlockBindFunc->user_begin(), E = BlockBindFunc->user_end();
    197         I != E;) {
    198       DEBUG(dbgs() << "[lowerBlockBind] " << **I << '\n');
    199       // Handle spir_block_bind(bitcast(block_func), context_len,
    200       // context_align, context)
    201       auto CallBlkBind = cast<CallInst>(*I++);
    202       Function *InvF = nullptr;
    203       Value *Ctx = nullptr;
    204       Value *CtxLen = nullptr;
    205       Value *CtxAlign = nullptr;
    206       getBlockInvokeFuncAndContext(CallBlkBind, &InvF, &Ctx, &CtxLen,
    207           &CtxAlign);
    208       for (auto II = CallBlkBind->user_begin(), EE = CallBlkBind->user_end();
    209           II != EE;) {
    210         auto BlkUser = *II++;
    211         SPIRVDBG(dbgs() << "  Block user: " << *BlkUser << '\n');
    212         if (auto Ret = dyn_cast<ReturnInst>(BlkUser)) {
    213           bool Inlined = false;
    214           changed |= lowerReturnBlock(Ret, CallBlkBind, Inlined);
    215           if (Inlined)
    216             return true;
    217         } else if (auto CI = dyn_cast<CallInst>(BlkUser)){
    218           auto CallBindF = CI->getCalledFunction();
    219           auto Name = CallBindF->getName();
    220           std::string DemangledName;
    221           if (Name == SPIR_INTRINSIC_GET_BLOCK_INVOKE) {
    222             assert(CI->getArgOperand(0) == CallBlkBind);
    223             changed |= lowerGetBlockInvoke(CI, cast<Function>(InvF));
    224           } else if (Name == SPIR_INTRINSIC_GET_BLOCK_CONTEXT) {
    225             assert(CI->getArgOperand(0) == CallBlkBind);
    226             // Handle context_ptr = spir_get_block_context(block)
    227             lowerGetBlockContext(CI, Ctx);
    228             changed = true;
    229           } else if (oclIsBuiltin(Name, &DemangledName)) {
    230             lowerBlockBuiltin(CI, InvF, Ctx, CtxLen, CtxAlign, DemangledName);
    231             changed = true;
    232           } else
    233             llvm_unreachable("Invalid block user");
    234         }
    235       }
    236       erase(CallBlkBind);
    237     }
    238     changed |= eraseUselessFunctions();
    239     return changed;
    240   }
    241 
    242   void
    243   lowerGetBlockContext(CallInst *CallGetBlkCtx, Value *Ctx = nullptr) {
    244     if (!Ctx)
    245       getBlockInvokeFuncAndContext(CallGetBlkCtx->getArgOperand(0), nullptr,
    246           &Ctx);
    247     CallGetBlkCtx->replaceAllUsesWith(Ctx);
    248     DEBUG(dbgs() << "  [lowerGetBlockContext] " << *CallGetBlkCtx << " => " <<
    249         *Ctx << "\n\n");
    250     erase(CallGetBlkCtx);
    251   }
    252 
    253   bool
    254   lowerGetBlockInvoke(CallInst *CallGetBlkInvoke,
    255       Function *InvokeF = nullptr) {
    256     bool changed = false;
    257     for (auto UI = CallGetBlkInvoke->user_begin(),
    258         UE = CallGetBlkInvoke->user_end();
    259         UI != UE;) {
    260       // Handle block_func_ptr = bitcast(spir_get_block_invoke(block))
    261       auto CallInv = cast<Instruction>(*UI++);
    262       auto Cast = dyn_cast<BitCastInst>(CallInv);
    263       if (Cast)
    264         CallInv = dyn_cast<Instruction>(*CallInv->user_begin());
    265       DEBUG(dbgs() << "[lowerGetBlockInvoke]  " << *CallInv);
    266       // Handle ret = block_func_ptr(context_ptr, args)
    267       auto CI = cast<CallInst>(CallInv);
    268       auto F = CI->getCalledValue();
    269       if (InvokeF == nullptr) {
    270         getBlockInvokeFuncAndContext(CallGetBlkInvoke->getArgOperand(0),
    271             &InvokeF, nullptr);
    272         assert(InvokeF);
    273       }
    274       assert(F->getType() == InvokeF->getType());
    275       CI->replaceUsesOfWith(F, InvokeF);
    276       DEBUG(dbgs() << " => " << *CI << "\n\n");
    277       erase(Cast);
    278       changed = true;
    279     }
    280     erase(CallGetBlkInvoke);
    281     return changed;
    282   }
    283 
    284   void
    285   lowerBlockBuiltin(CallInst *CI, Function *InvF, Value *Ctx, Value *CtxLen,
    286       Value *CtxAlign, const std::string& DemangledName) {
    287     mutateCallInstSPIRV (M, CI, [=](CallInst *CI, std::vector<Value *> &Args) {
    288       size_t I = 0;
    289       size_t E = Args.size();
    290       for (; I != E; ++I) {
    291         if (isPointerToOpaqueStructType(Args[I]->getType(),
    292             SPIR_TYPE_NAME_BLOCK_T)) {
    293           break;
    294         }
    295       }
    296       assert (I < E);
    297       Args[I] = castToVoidFuncPtr(InvF);
    298       if (I + 1 == E) {
    299         Args.push_back(Ctx);
    300         Args.push_back(CtxLen);
    301         Args.push_back(CtxAlign);
    302       } else {
    303         Args.insert(Args.begin() + I + 1, CtxAlign);
    304         Args.insert(Args.begin() + I + 1, CtxLen);
    305         Args.insert(Args.begin() + I + 1, Ctx);
    306       }
    307       if (DemangledName == kOCLBuiltinName::EnqueueKernel) {
    308         // Insert event arguments if there are not.
    309         if (!isa<IntegerType>(Args[3]->getType())) {
    310           Args.insert(Args.begin() + 3, getInt32(M, 0));
    311           Args.insert(Args.begin() + 4, getOCLNullClkEventPtr());
    312         }
    313         if (!isOCLClkEventPtrType(Args[5]->getType()))
    314           Args.insert(Args.begin() + 5, getOCLNullClkEventPtr());
    315       }
    316       return getSPIRVFuncName(OCLSPIRVBuiltinMap::map(DemangledName));
    317     });
    318   }
    319   /// Transform return of a block.
    320   /// The function returning a block is inlined since the context cannot be
    321   /// passed to another function.
    322   /// Returns true of module is changed.
    323   bool
    324   lowerReturnBlock(ReturnInst *Ret, Value *CallBlkBind, bool &Inlined) {
    325     auto F = Ret->getParent()->getParent();
    326     auto changed = false;
    327     for (auto UI = F->user_begin(), UE = F->user_end(); UI != UE;) {
    328       auto U = *UI++;
    329       dumpUsers(U);
    330       auto Inst = dyn_cast<Instruction>(U);
    331       if (Inst && Inst->use_empty()) {
    332         erase(Inst);
    333         changed = true;
    334         continue;
    335       }
    336       auto CI = dyn_cast<CallInst>(U);
    337       if(!CI || CI->getCalledFunction() != F)
    338         continue;
    339 
    340       DEBUG(dbgs() << "[lowerReturnBlock] inline " << F->getName() << '\n');
    341       auto CG = &getAnalysis<CallGraphWrapperPass>().getCallGraph();
    342       auto ACT = &getAnalysis<AssumptionCacheTracker>();
    343       //auto AA = &getAnalysis<AliasAnalysis>();
    344       //InlineFunctionInfo IFI(CG, M->getDataLayout(), AA, ACT);
    345       InlineFunctionInfo IFI(CG, ACT);
    346       InlineFunction(CI, IFI);
    347       Inlined = true;
    348     }
    349     return changed || Inlined;
    350   }
    351 
    352   void
    353   getBlockInvokeFuncAndContext(Value *Blk, Function **PInvF, Value **PCtx,
    354       Value **PCtxLen = nullptr, Value **PCtxAlign = nullptr){
    355     Function *InvF = nullptr;
    356     Value *Ctx = nullptr;
    357     Value *CtxLen = nullptr;
    358     Value *CtxAlign = nullptr;
    359     if (auto CallBlkBind = dyn_cast<CallInst>(Blk)) {
    360       assert(CallBlkBind->getCalledFunction()->getName() ==
    361           SPIR_INTRINSIC_BLOCK_BIND && "Invalid block");
    362       InvF = dyn_cast<Function>(
    363           CallBlkBind->getArgOperand(0)->stripPointerCasts());
    364       CtxLen = CallBlkBind->getArgOperand(1);
    365       CtxAlign = CallBlkBind->getArgOperand(2);
    366       Ctx = CallBlkBind->getArgOperand(3);
    367     } else if (auto F = dyn_cast<Function>(Blk->stripPointerCasts())) {
    368       InvF = F;
    369       Ctx = Constant::getNullValue(IntegerType::getInt8PtrTy(M->getContext()));
    370     } else if (auto Load = dyn_cast<LoadInst>(Blk)) {
    371       auto Op = Load->getPointerOperand();
    372       if (auto GV = dyn_cast<GlobalVariable>(Op)) {
    373         if (GV->isConstant()) {
    374           InvF = cast<Function>(GV->getInitializer()->stripPointerCasts());
    375           Ctx = Constant::getNullValue(IntegerType::getInt8PtrTy(M->getContext()));
    376         } else {
    377           llvm_unreachable("load non-constant block?");
    378         }
    379       } else {
    380         llvm_unreachable("Loading block from non global?");
    381       }
    382     } else {
    383       llvm_unreachable("Invalid block");
    384     }
    385     DEBUG(dbgs() << "  Block invocation func: " << InvF->getName() << '\n' <<
    386         "  Block context: " << *Ctx << '\n');
    387     assert(InvF && Ctx && "Invalid block");
    388     if (PInvF)
    389       *PInvF = InvF;
    390     if (PCtx)
    391       *PCtx = Ctx;
    392     if (PCtxLen)
    393       *PCtxLen = CtxLen;
    394     if (PCtxAlign)
    395       *PCtxAlign = CtxAlign;
    396   }
    397   void
    398   erase(Instruction *I) {
    399     if (!I)
    400       return;
    401     if (I->use_empty()) {
    402       I->dropAllReferences();
    403       I->eraseFromParent();
    404     }
    405     else
    406       dumpUsers(I);
    407   }
    408   void
    409   erase(ConstantExpr *I) {
    410     if (!I)
    411       return;
    412     if (I->use_empty()) {
    413       I->dropAllReferences();
    414       I->destroyConstant();
    415     } else
    416       dumpUsers(I);
    417   }
    418   void
    419   erase(Function *F) {
    420     if (!F)
    421       return;
    422     if (!F->use_empty()) {
    423       dumpUsers(F);
    424       return;
    425     }
    426     F->dropAllReferences();
    427     auto &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
    428     CG.removeFunctionFromModule(new CallGraphNode(F));
    429   }
    430 
    431   llvm::PointerType* getOCLClkEventType() {
    432     return getOrCreateOpaquePtrType(M, SPIR_TYPE_NAME_CLK_EVENT_T,
    433         SPIRAS_Global);
    434   }
    435 
    436   llvm::PointerType* getOCLClkEventPtrType() {
    437     return PointerType::get(getOCLClkEventType(), SPIRAS_Generic);
    438   }
    439 
    440   bool isOCLClkEventPtrType(Type *T) {
    441     if (auto PT = dyn_cast<PointerType>(T))
    442       return isPointerToOpaqueStructType(
    443         PT->getElementType(), SPIR_TYPE_NAME_CLK_EVENT_T);
    444     return false;
    445   }
    446 
    447   llvm::Constant* getOCLNullClkEventPtr() {
    448     return Constant::getNullValue(getOCLClkEventPtrType());
    449   }
    450 
    451   void dumpGetBlockInvokeUsers(StringRef Prompt) {
    452     DEBUG(dbgs() << Prompt);
    453     dumpUsers(M->getFunction(SPIR_INTRINSIC_GET_BLOCK_INVOKE));
    454   }
    455 };
    456 
    457 char SPIRVLowerOCLBlocks::ID = 0;
    458 }
    459 
    460 INITIALIZE_PASS_BEGIN(SPIRVLowerOCLBlocks, "spvblocks",
    461     "SPIR-V lower OCL blocks", false, false)
    462 INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
    463 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
    464 //INITIALIZE_AG_DEPENDENCY(AliasAnalysis)
    465 INITIALIZE_PASS_END(SPIRVLowerOCLBlocks, "spvblocks",
    466     "SPIR-V lower OCL blocks", false, false)
    467 
    468 ModulePass *llvm::createSPIRVLowerOCLBlocks() {
    469   return new SPIRVLowerOCLBlocks();
    470 }
    471 
    472 #endif /* OCLLOWERBLOCKS_H_ */
    473