Home | History | Annotate | Download | only in CodeGen
      1 //===----- CGCUDANV.cpp - Interface to NVIDIA CUDA Runtime ----------------===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 // This provides a class for CUDA code generation targeting the NVIDIA CUDA
     11 // runtime library.
     12 //
     13 //===----------------------------------------------------------------------===//
     14 
     15 #include "CGCUDARuntime.h"
     16 #include "CodeGenFunction.h"
     17 #include "CodeGenModule.h"
     18 #include "clang/AST/Decl.h"
     19 #include "llvm/IR/BasicBlock.h"
     20 #include "llvm/IR/Constants.h"
     21 #include "llvm/IR/DerivedTypes.h"
     22 #include "llvm/Support/CallSite.h"
     23 #include <vector>
     24 
     25 using namespace clang;
     26 using namespace CodeGen;
     27 
     28 namespace {
     29 
     30 class CGNVCUDARuntime : public CGCUDARuntime {
     31 
     32 private:
     33   llvm::Type *IntTy, *SizeTy;
     34   llvm::PointerType *CharPtrTy, *VoidPtrTy;
     35 
     36   llvm::Constant *getSetupArgumentFn() const;
     37   llvm::Constant *getLaunchFn() const;
     38 
     39 public:
     40   CGNVCUDARuntime(CodeGenModule &CGM);
     41 
     42   void EmitDeviceStubBody(CodeGenFunction &CGF, FunctionArgList &Args);
     43 };
     44 
     45 }
     46 
     47 CGNVCUDARuntime::CGNVCUDARuntime(CodeGenModule &CGM) : CGCUDARuntime(CGM) {
     48   CodeGen::CodeGenTypes &Types = CGM.getTypes();
     49   ASTContext &Ctx = CGM.getContext();
     50 
     51   IntTy = Types.ConvertType(Ctx.IntTy);
     52   SizeTy = Types.ConvertType(Ctx.getSizeType());
     53 
     54   CharPtrTy = llvm::PointerType::getUnqual(Types.ConvertType(Ctx.CharTy));
     55   VoidPtrTy = cast<llvm::PointerType>(Types.ConvertType(Ctx.VoidPtrTy));
     56 }
     57 
     58 llvm::Constant *CGNVCUDARuntime::getSetupArgumentFn() const {
     59   // cudaError_t cudaSetupArgument(void *, size_t, size_t)
     60   std::vector<llvm::Type*> Params;
     61   Params.push_back(VoidPtrTy);
     62   Params.push_back(SizeTy);
     63   Params.push_back(SizeTy);
     64   return CGM.CreateRuntimeFunction(llvm::FunctionType::get(IntTy,
     65                                                            Params, false),
     66                                    "cudaSetupArgument");
     67 }
     68 
     69 llvm::Constant *CGNVCUDARuntime::getLaunchFn() const {
     70   // cudaError_t cudaLaunch(char *)
     71   std::vector<llvm::Type*> Params;
     72   Params.push_back(CharPtrTy);
     73   return CGM.CreateRuntimeFunction(llvm::FunctionType::get(IntTy,
     74                                                            Params, false),
     75                                    "cudaLaunch");
     76 }
     77 
     78 void CGNVCUDARuntime::EmitDeviceStubBody(CodeGenFunction &CGF,
     79                                          FunctionArgList &Args) {
     80   // Build the argument value list and the argument stack struct type.
     81   SmallVector<llvm::Value *, 16> ArgValues;
     82   std::vector<llvm::Type *> ArgTypes;
     83   for (FunctionArgList::const_iterator I = Args.begin(), E = Args.end();
     84        I != E; ++I) {
     85     llvm::Value *V = CGF.GetAddrOfLocalVar(*I);
     86     ArgValues.push_back(V);
     87     assert(isa<llvm::PointerType>(V->getType()) && "Arg type not PointerType");
     88     ArgTypes.push_back(cast<llvm::PointerType>(V->getType())->getElementType());
     89   }
     90   llvm::StructType *ArgStackTy = llvm::StructType::get(
     91       CGF.getLLVMContext(), ArgTypes);
     92 
     93   llvm::BasicBlock *EndBlock = CGF.createBasicBlock("setup.end");
     94 
     95   // Emit the calls to cudaSetupArgument
     96   llvm::Constant *cudaSetupArgFn = getSetupArgumentFn();
     97   for (unsigned I = 0, E = Args.size(); I != E; ++I) {
     98     llvm::Value *Args[3];
     99     llvm::BasicBlock *NextBlock = CGF.createBasicBlock("setup.next");
    100     Args[0] = CGF.Builder.CreatePointerCast(ArgValues[I], VoidPtrTy);
    101     Args[1] = CGF.Builder.CreateIntCast(
    102         llvm::ConstantExpr::getSizeOf(ArgTypes[I]),
    103         SizeTy, false);
    104     Args[2] = CGF.Builder.CreateIntCast(
    105         llvm::ConstantExpr::getOffsetOf(ArgStackTy, I),
    106         SizeTy, false);
    107     llvm::CallSite CS = CGF.EmitRuntimeCallOrInvoke(cudaSetupArgFn, Args);
    108     llvm::Constant *Zero = llvm::ConstantInt::get(IntTy, 0);
    109     llvm::Value *CSZero = CGF.Builder.CreateICmpEQ(CS.getInstruction(), Zero);
    110     CGF.Builder.CreateCondBr(CSZero, NextBlock, EndBlock);
    111     CGF.EmitBlock(NextBlock);
    112   }
    113 
    114   // Emit the call to cudaLaunch
    115   llvm::Constant *cudaLaunchFn = getLaunchFn();
    116   llvm::Value *Arg = CGF.Builder.CreatePointerCast(CGF.CurFn, CharPtrTy);
    117   CGF.EmitRuntimeCallOrInvoke(cudaLaunchFn, Arg);
    118   CGF.EmitBranch(EndBlock);
    119 
    120   CGF.EmitBlock(EndBlock);
    121 }
    122 
    123 CGCUDARuntime *CodeGen::CreateNVCUDARuntime(CodeGenModule &CGM) {
    124   return new CGNVCUDARuntime(CGM);
    125 }
    126