Home | History | Annotate | Download | only in NVPTX
      1 //===- NVPTXLowerAggrCopies.cpp - ------------------------------*- C++ -*--===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
// Lower aggregate copies and the memset, memcpy and memmove intrinsics into
// loops when the size is large or is not a compile-time constant.
     11 //
     12 //===----------------------------------------------------------------------===//
     13 
     14 #include "NVPTXLowerAggrCopies.h"
     15 #include "llvm/IR/Constants.h"
     16 #include "llvm/IR/DataLayout.h"
     17 #include "llvm/IR/Function.h"
     18 #include "llvm/IR/IRBuilder.h"
     19 #include "llvm/IR/InstIterator.h"
     20 #include "llvm/IR/Instructions.h"
     21 #include "llvm/IR/IntrinsicInst.h"
     22 #include "llvm/IR/Intrinsics.h"
     23 #include "llvm/IR/LLVMContext.h"
     24 #include "llvm/IR/Module.h"
     25 
     26 using namespace llvm;
     27 
     28 namespace llvm { FunctionPass *createLowerAggrCopies(); }
     29 
     30 char NVPTXLowerAggrCopies::ID = 0;
     31 
     32 // Lower MemTransferInst or load-store pair to loop
     33 static void convertTransferToLoop(
     34     Instruction *splitAt, Value *srcAddr, Value *dstAddr, Value *len,
     35     //unsigned numLoads,
     36     bool srcVolatile, bool dstVolatile, LLVMContext &Context, Function &F) {
     37   Type *indType = len->getType();
     38 
     39   BasicBlock *origBB = splitAt->getParent();
     40   BasicBlock *newBB = splitAt->getParent()->splitBasicBlock(splitAt, "split");
     41   BasicBlock *loopBB = BasicBlock::Create(Context, "loadstoreloop", &F, newBB);
     42 
     43   origBB->getTerminator()->setSuccessor(0, loopBB);
     44   IRBuilder<> builder(origBB, origBB->getTerminator());
     45 
     46   // srcAddr and dstAddr are expected to be pointer types,
     47   // so no check is made here.
     48   unsigned srcAS = dyn_cast<PointerType>(srcAddr->getType())->getAddressSpace();
     49   unsigned dstAS = dyn_cast<PointerType>(dstAddr->getType())->getAddressSpace();
     50 
     51   // Cast pointers to (char *)
     52   srcAddr = builder.CreateBitCast(srcAddr, Type::getInt8PtrTy(Context, srcAS));
     53   dstAddr = builder.CreateBitCast(dstAddr, Type::getInt8PtrTy(Context, dstAS));
     54 
     55   IRBuilder<> loop(loopBB);
     56   // The loop index (ind) is a phi node.
     57   PHINode *ind = loop.CreatePHI(indType, 0);
     58   // Incoming value for ind is 0
     59   ind->addIncoming(ConstantInt::get(indType, 0), origBB);
     60 
     61   // load from srcAddr+ind
     62   Value *val = loop.CreateLoad(loop.CreateGEP(srcAddr, ind), srcVolatile);
     63   // store at dstAddr+ind
     64   loop.CreateStore(val, loop.CreateGEP(dstAddr, ind), dstVolatile);
     65 
     66   // The value for ind coming from backedge is (ind + 1)
     67   Value *newind = loop.CreateAdd(ind, ConstantInt::get(indType, 1));
     68   ind->addIncoming(newind, loopBB);
     69 
     70   loop.CreateCondBr(loop.CreateICmpULT(newind, len), loopBB, newBB);
     71 }
     72 
     73 // Lower MemSetInst to loop
     74 static void convertMemSetToLoop(Instruction *splitAt, Value *dstAddr,
     75                                 Value *len, Value *val, LLVMContext &Context,
     76                                 Function &F) {
     77   BasicBlock *origBB = splitAt->getParent();
     78   BasicBlock *newBB = splitAt->getParent()->splitBasicBlock(splitAt, "split");
     79   BasicBlock *loopBB = BasicBlock::Create(Context, "loadstoreloop", &F, newBB);
     80 
     81   origBB->getTerminator()->setSuccessor(0, loopBB);
     82   IRBuilder<> builder(origBB, origBB->getTerminator());
     83 
     84   unsigned dstAS = dyn_cast<PointerType>(dstAddr->getType())->getAddressSpace();
     85 
     86   // Cast pointer to the type of value getting stored
     87   dstAddr =
     88       builder.CreateBitCast(dstAddr, PointerType::get(val->getType(), dstAS));
     89 
     90   IRBuilder<> loop(loopBB);
     91   PHINode *ind = loop.CreatePHI(len->getType(), 0);
     92   ind->addIncoming(ConstantInt::get(len->getType(), 0), origBB);
     93 
     94   loop.CreateStore(val, loop.CreateGEP(dstAddr, ind), false);
     95 
     96   Value *newind = loop.CreateAdd(ind, ConstantInt::get(len->getType(), 1));
     97   ind->addIncoming(newind, loopBB);
     98 
     99   loop.CreateCondBr(loop.CreateICmpULT(newind, len), loopBB, newBB);
    100 }
    101 
    102 bool NVPTXLowerAggrCopies::runOnFunction(Function &F) {
    103   SmallVector<LoadInst *, 4> aggrLoads;
    104   SmallVector<MemTransferInst *, 4> aggrMemcpys;
    105   SmallVector<MemSetInst *, 4> aggrMemsets;
    106 
    107   const DataLayout *DL = &getAnalysis<DataLayoutPass>().getDataLayout();
    108   LLVMContext &Context = F.getParent()->getContext();
    109 
    110   //
    111   // Collect all the aggrLoads, aggrMemcpys and addrMemsets.
    112   //
    113   //const BasicBlock *firstBB = &F.front();  // first BB in F
    114   for (Function::iterator BI = F.begin(), BE = F.end(); BI != BE; ++BI) {
    115     //BasicBlock *bb = BI;
    116     for (BasicBlock::iterator II = BI->begin(), IE = BI->end(); II != IE;
    117          ++II) {
    118       if (LoadInst *load = dyn_cast<LoadInst>(II)) {
    119 
    120         if (load->hasOneUse() == false)
    121           continue;
    122 
    123         if (DL->getTypeStoreSize(load->getType()) < MaxAggrCopySize)
    124           continue;
    125 
    126         User *use = load->user_back();
    127         if (StoreInst *store = dyn_cast<StoreInst>(use)) {
    128           if (store->getOperand(0) != load) //getValueOperand
    129             continue;
    130           aggrLoads.push_back(load);
    131         }
    132       } else if (MemTransferInst *intr = dyn_cast<MemTransferInst>(II)) {
    133         Value *len = intr->getLength();
    134         // If the number of elements being copied is greater
    135         // than MaxAggrCopySize, lower it to a loop
    136         if (ConstantInt *len_int = dyn_cast<ConstantInt>(len)) {
    137           if (len_int->getZExtValue() >= MaxAggrCopySize) {
    138             aggrMemcpys.push_back(intr);
    139           }
    140         } else {
    141           // turn variable length memcpy/memmov into loop
    142           aggrMemcpys.push_back(intr);
    143         }
    144       } else if (MemSetInst *memsetintr = dyn_cast<MemSetInst>(II)) {
    145         Value *len = memsetintr->getLength();
    146         if (ConstantInt *len_int = dyn_cast<ConstantInt>(len)) {
    147           if (len_int->getZExtValue() >= MaxAggrCopySize) {
    148             aggrMemsets.push_back(memsetintr);
    149           }
    150         } else {
    151           // turn variable length memset into loop
    152           aggrMemsets.push_back(memsetintr);
    153         }
    154       }
    155     }
    156   }
    157   if ((aggrLoads.size() == 0) && (aggrMemcpys.size() == 0) &&
    158       (aggrMemsets.size() == 0))
    159     return false;
    160 
    161   //
    162   // Do the transformation of an aggr load/copy/set to a loop
    163   //
    164   for (unsigned i = 0, e = aggrLoads.size(); i != e; ++i) {
    165     LoadInst *load = aggrLoads[i];
    166     StoreInst *store = dyn_cast<StoreInst>(*load->user_begin());
    167     Value *srcAddr = load->getOperand(0);
    168     Value *dstAddr = store->getOperand(1);
    169     unsigned numLoads = DL->getTypeStoreSize(load->getType());
    170     Value *len = ConstantInt::get(Type::getInt32Ty(Context), numLoads);
    171 
    172     convertTransferToLoop(store, srcAddr, dstAddr, len, load->isVolatile(),
    173                           store->isVolatile(), Context, F);
    174 
    175     store->eraseFromParent();
    176     load->eraseFromParent();
    177   }
    178 
    179   for (unsigned i = 0, e = aggrMemcpys.size(); i != e; ++i) {
    180     MemTransferInst *cpy = aggrMemcpys[i];
    181     Value *len = cpy->getLength();
    182     // llvm 2.7 version of memcpy does not have volatile
    183     // operand yet. So always making it non-volatile
    184     // optimistically, so that we don't see unnecessary
    185     // st.volatile in ptx
    186     convertTransferToLoop(cpy, cpy->getSource(), cpy->getDest(), len, false,
    187                           false, Context, F);
    188     cpy->eraseFromParent();
    189   }
    190 
    191   for (unsigned i = 0, e = aggrMemsets.size(); i != e; ++i) {
    192     MemSetInst *memsetinst = aggrMemsets[i];
    193     Value *len = memsetinst->getLength();
    194     Value *val = memsetinst->getValue();
    195     convertMemSetToLoop(memsetinst, memsetinst->getDest(), len, val, Context,
    196                         F);
    197     memsetinst->eraseFromParent();
    198   }
    199 
    200   return true;
    201 }
    202 
    203 FunctionPass *llvm::createLowerAggrCopies() {
    204   return new NVPTXLowerAggrCopies();
    205 }
    206