Home | History | Annotate | Download | only in PTX
      1 //===-- PTXISelLowering.cpp - PTX DAG Lowering Implementation -------------===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 // This file implements the PTXTargetLowering class.
     11 //
     12 //===----------------------------------------------------------------------===//
     13 
     14 #include "PTX.h"
     15 #include "PTXISelLowering.h"
     16 #include "PTXMachineFunctionInfo.h"
     17 #include "PTXRegisterInfo.h"
     18 #include "PTXSubtarget.h"
     19 #include "llvm/Function.h"
     20 #include "llvm/Support/ErrorHandling.h"
     21 #include "llvm/CodeGen/CallingConvLower.h"
     22 #include "llvm/CodeGen/MachineFunction.h"
     23 #include "llvm/CodeGen/MachineRegisterInfo.h"
     24 #include "llvm/CodeGen/SelectionDAG.h"
     25 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
     26 #include "llvm/Support/Debug.h"
     27 #include "llvm/Support/raw_ostream.h"
     28 
     29 using namespace llvm;
     30 
     31 //===----------------------------------------------------------------------===//
     32 // TargetLowering Implementation
     33 //===----------------------------------------------------------------------===//
     34 
     35 PTXTargetLowering::PTXTargetLowering(TargetMachine &TM)
     36   : TargetLowering(TM, new TargetLoweringObjectFileELF()) {
     37   // Set up the register classes.
     38   addRegisterClass(MVT::i1,  PTX::RegPredRegisterClass);
     39   addRegisterClass(MVT::i16, PTX::RegI16RegisterClass);
     40   addRegisterClass(MVT::i32, PTX::RegI32RegisterClass);
     41   addRegisterClass(MVT::i64, PTX::RegI64RegisterClass);
     42   addRegisterClass(MVT::f32, PTX::RegF32RegisterClass);
     43   addRegisterClass(MVT::f64, PTX::RegF64RegisterClass);
     44 
     45   setBooleanContents(ZeroOrOneBooleanContent);
     46   setBooleanVectorContents(ZeroOrOneBooleanContent); // FIXME: Is this correct?
     47   setMinFunctionAlignment(2);
     48 
     49   ////////////////////////////////////
     50   /////////// Expansion //////////////
     51   ////////////////////////////////////
     52 
     53   // (any/zero/sign) extload => load + (any/zero/sign) extend
     54 
     55   setLoadExtAction(ISD::EXTLOAD, MVT::i16, Expand);
     56   setLoadExtAction(ISD::ZEXTLOAD, MVT::i16, Expand);
     57   setLoadExtAction(ISD::SEXTLOAD, MVT::i16, Expand);
     58 
     59   // f32 extload => load + fextend
     60 
     61   setLoadExtAction(ISD::EXTLOAD, MVT::f32, Expand);
     62 
     63   // f64 truncstore => trunc + store
     64 
     65   setTruncStoreAction(MVT::f64, MVT::f32, Expand);
     66 
     67   // sign_extend_inreg => sign_extend
     68 
     69   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand);
     70 
     71   // br_cc => brcond
     72 
     73   setOperationAction(ISD::BR_CC, MVT::Other, Expand);
     74 
     75   // select_cc => setcc
     76 
     77   setOperationAction(ISD::SELECT_CC, MVT::Other, Expand);
     78   setOperationAction(ISD::SELECT_CC, MVT::f32, Expand);
     79   setOperationAction(ISD::SELECT_CC, MVT::f64, Expand);
     80 
     81   ////////////////////////////////////
     82   //////////// Legal /////////////////
     83   ////////////////////////////////////
     84 
     85   setOperationAction(ISD::ConstantFP, MVT::f32, Legal);
     86   setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
     87 
     88   ////////////////////////////////////
     89   //////////// Custom ////////////////
     90   ////////////////////////////////////
     91 
     92   // customise setcc to use bitwise logic if possible
     93 
     94   setOperationAction(ISD::SETCC, MVT::i1, Custom);
     95 
     96   // customize translation of memory addresses
     97 
     98   setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
     99   setOperationAction(ISD::GlobalAddress, MVT::i64, Custom);
    100 
    101   // Compute derived properties from the register classes
    102   computeRegisterProperties();
    103 }
    104 
    105 EVT PTXTargetLowering::getSetCCResultType(EVT VT) const {
    106   return MVT::i1;
    107 }
    108 
    109 SDValue PTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
    110   switch (Op.getOpcode()) {
    111     default:
    112       llvm_unreachable("Unimplemented operand");
    113     case ISD::SETCC:
    114       return LowerSETCC(Op, DAG);
    115     case ISD::GlobalAddress:
    116       return LowerGlobalAddress(Op, DAG);
    117   }
    118 }
    119 
    120 const char *PTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
    121   switch (Opcode) {
    122     default:
    123       llvm_unreachable("Unknown opcode");
    124     case PTXISD::COPY_ADDRESS:
    125       return "PTXISD::COPY_ADDRESS";
    126     case PTXISD::LOAD_PARAM:
    127       return "PTXISD::LOAD_PARAM";
    128     case PTXISD::STORE_PARAM:
    129       return "PTXISD::STORE_PARAM";
    130     case PTXISD::READ_PARAM:
    131       return "PTXISD::READ_PARAM";
    132     case PTXISD::WRITE_PARAM:
    133       return "PTXISD::WRITE_PARAM";
    134     case PTXISD::EXIT:
    135       return "PTXISD::EXIT";
    136     case PTXISD::RET:
    137       return "PTXISD::RET";
    138     case PTXISD::CALL:
    139       return "PTXISD::CALL";
    140   }
    141 }
    142 
    143 //===----------------------------------------------------------------------===//
    144 //                      Custom Lower Operation
    145 //===----------------------------------------------------------------------===//
    146 
    147 SDValue PTXTargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const {
    148   assert(Op.getValueType() == MVT::i1 && "SetCC type must be 1-bit integer");
    149   SDValue Op0 = Op.getOperand(0);
    150   SDValue Op1 = Op.getOperand(1);
    151   SDValue Op2 = Op.getOperand(2);
    152   DebugLoc dl = Op.getDebugLoc();
    153   ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();
    154 
    155   // Look for X == 0, X == 1, X != 0, or X != 1
    156   // We can simplify these to bitwise logic
    157 
    158   if (Op1.getOpcode() == ISD::Constant &&
    159       (cast<ConstantSDNode>(Op1)->getZExtValue() == 1 ||
    160        cast<ConstantSDNode>(Op1)->isNullValue()) &&
    161       (CC == ISD::SETEQ || CC == ISD::SETNE)) {
    162 
    163     return DAG.getNode(ISD::AND, dl, MVT::i1, Op0, Op1);
    164   }
    165 
    166   return DAG.getNode(ISD::SETCC, dl, MVT::i1, Op0, Op1, Op2);
    167 }
    168 
    169 SDValue PTXTargetLowering::
    170 LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
    171   EVT PtrVT = getPointerTy();
    172   DebugLoc dl = Op.getDebugLoc();
    173   const GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal();
    174 
    175   assert(PtrVT.isSimple() && "Pointer must be to primitive type.");
    176 
    177   SDValue targetGlobal = DAG.getTargetGlobalAddress(GV, dl, PtrVT);
    178   SDValue movInstr = DAG.getNode(PTXISD::COPY_ADDRESS,
    179                                  dl,
    180                                  PtrVT.getSimpleVT(),
    181                                  targetGlobal);
    182 
    183   return movInstr;
    184 }
    185 
    186 //===----------------------------------------------------------------------===//
    187 //                      Calling Convention Implementation
    188 //===----------------------------------------------------------------------===//
    189 
    190 SDValue PTXTargetLowering::
    191   LowerFormalArguments(SDValue Chain,
    192                        CallingConv::ID CallConv,
    193                        bool isVarArg,
    194                        const SmallVectorImpl<ISD::InputArg> &Ins,
    195                        DebugLoc dl,
    196                        SelectionDAG &DAG,
    197                        SmallVectorImpl<SDValue> &InVals) const {
    198   if (isVarArg) llvm_unreachable("PTX does not support varargs");
    199 
    200   MachineFunction &MF = DAG.getMachineFunction();
    201   const PTXSubtarget& ST = getTargetMachine().getSubtarget<PTXSubtarget>();
    202   PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
    203   PTXParamManager &PM = MFI->getParamManager();
    204 
    205   switch (CallConv) {
    206     default:
    207       llvm_unreachable("Unsupported calling convention");
    208       break;
    209     case CallingConv::PTX_Kernel:
    210       MFI->setKernel(true);
    211       break;
    212     case CallingConv::PTX_Device:
    213       MFI->setKernel(false);
    214       break;
    215   }
    216 
    217   // We do one of two things here:
    218   // IsKernel || SM >= 2.0  ->  Use param space for arguments
    219   // SM < 2.0               ->  Use registers for arguments
    220   if (MFI->isKernel() || ST.useParamSpaceForDeviceArgs()) {
    221     // We just need to emit the proper LOAD_PARAM ISDs
    222     for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
    223       assert((!MFI->isKernel() || Ins[i].VT != MVT::i1) &&
    224              "Kernels cannot take pred operands");
    225 
    226       unsigned ParamSize = Ins[i].VT.getStoreSizeInBits();
    227       unsigned Param = PM.addArgumentParam(ParamSize);
    228       const std::string &ParamName = PM.getParamName(Param);
    229       SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
    230                                                        MVT::Other);
    231       SDValue ArgValue = DAG.getNode(PTXISD::LOAD_PARAM, dl, Ins[i].VT, Chain,
    232                                      ParamValue);
    233       InVals.push_back(ArgValue);
    234     }
    235   }
    236   else {
    237     for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
    238       EVT                  RegVT = Ins[i].VT;
    239       TargetRegisterClass* TRC   = getRegClassFor(RegVT);
    240 
    241       // Use a unique index in the instruction to prevent instruction folding.
    242       // Yes, this is a hack.
    243       SDValue Index = DAG.getTargetConstant(i, MVT::i32);
    244       unsigned Reg = MF.getRegInfo().createVirtualRegister(TRC);
    245       SDValue ArgValue = DAG.getNode(PTXISD::READ_PARAM, dl, RegVT, Chain,
    246                                      Index);
    247 
    248       InVals.push_back(ArgValue);
    249 
    250       MFI->addArgReg(Reg);
    251     }
    252   }
    253 
    254   return Chain;
    255 }
    256 
    257 SDValue PTXTargetLowering::
    258   LowerReturn(SDValue Chain,
    259               CallingConv::ID CallConv,
    260               bool isVarArg,
    261               const SmallVectorImpl<ISD::OutputArg> &Outs,
    262               const SmallVectorImpl<SDValue> &OutVals,
    263               DebugLoc dl,
    264               SelectionDAG &DAG) const {
    265   if (isVarArg) llvm_unreachable("PTX does not support varargs");
    266 
    267   switch (CallConv) {
    268     default:
    269       llvm_unreachable("Unsupported calling convention.");
    270     case CallingConv::PTX_Kernel:
    271       assert(Outs.size() == 0 && "Kernel must return void.");
    272       return DAG.getNode(PTXISD::EXIT, dl, MVT::Other, Chain);
    273     case CallingConv::PTX_Device:
    274       assert(Outs.size() <= 1 && "Can at most return one value.");
    275       break;
    276   }
    277 
    278   MachineFunction& MF = DAG.getMachineFunction();
    279   PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
    280   PTXParamManager &PM = MFI->getParamManager();
    281 
    282   SDValue Flag;
    283   const PTXSubtarget& ST = getTargetMachine().getSubtarget<PTXSubtarget>();
    284 
    285   if (ST.useParamSpaceForDeviceArgs()) {
    286     assert(Outs.size() < 2 && "Device functions can return at most one value");
    287 
    288     if (Outs.size() == 1) {
    289       unsigned ParamSize = OutVals[0].getValueType().getSizeInBits();
    290       unsigned Param = PM.addReturnParam(ParamSize);
    291       const std::string &ParamName = PM.getParamName(Param);
    292       SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
    293                                                        MVT::Other);
    294       Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain,
    295                           ParamValue, OutVals[0]);
    296     }
    297   } else {
    298     for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
    299       EVT                  RegVT = Outs[i].VT;
    300       TargetRegisterClass* TRC = 0;
    301 
    302       // Determine which register class we need
    303       if (RegVT == MVT::i1) {
    304         TRC = PTX::RegPredRegisterClass;
    305       }
    306       else if (RegVT == MVT::i16) {
    307         TRC = PTX::RegI16RegisterClass;
    308       }
    309       else if (RegVT == MVT::i32) {
    310         TRC = PTX::RegI32RegisterClass;
    311       }
    312       else if (RegVT == MVT::i64) {
    313         TRC = PTX::RegI64RegisterClass;
    314       }
    315       else if (RegVT == MVT::f32) {
    316         TRC = PTX::RegF32RegisterClass;
    317       }
    318       else if (RegVT == MVT::f64) {
    319         TRC = PTX::RegF64RegisterClass;
    320       }
    321       else {
    322         llvm_unreachable("Unknown parameter type");
    323       }
    324 
    325       unsigned Reg = MF.getRegInfo().createVirtualRegister(TRC);
    326 
    327       SDValue Copy = DAG.getCopyToReg(Chain, dl, Reg, OutVals[i]/*, Flag*/);
    328       SDValue OutReg = DAG.getRegister(Reg, RegVT);
    329 
    330       Chain = DAG.getNode(PTXISD::WRITE_PARAM, dl, MVT::Other, Copy, OutReg);
    331 
    332       MFI->addRetReg(Reg);
    333     }
    334   }
    335 
    336   if (Flag.getNode() == 0) {
    337     return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain);
    338   }
    339   else {
    340     return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain, Flag);
    341   }
    342 }
    343 
    344 SDValue
    345 PTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee,
    346                              CallingConv::ID CallConv, bool isVarArg,
    347                              bool &isTailCall,
    348                              const SmallVectorImpl<ISD::OutputArg> &Outs,
    349                              const SmallVectorImpl<SDValue> &OutVals,
    350                              const SmallVectorImpl<ISD::InputArg> &Ins,
    351                              DebugLoc dl, SelectionDAG &DAG,
    352                              SmallVectorImpl<SDValue> &InVals) const {
    353 
    354   MachineFunction& MF = DAG.getMachineFunction();
    355   PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
    356   PTXParamManager &PM = MFI->getParamManager();
    357 
    358   assert(getTargetMachine().getSubtarget<PTXSubtarget>().callsAreHandled() &&
    359          "Calls are not handled for the target device");
    360 
    361   std::vector<SDValue> Ops;
    362   // The layout of the ops will be [Chain, #Ins, Ins, Callee, #Outs, Outs]
    363   Ops.resize(Outs.size() + Ins.size() + 4);
    364 
    365   Ops[0] = Chain;
    366 
    367   // Identify the callee function
    368   const GlobalValue *GV = cast<GlobalAddressSDNode>(Callee)->getGlobal();
    369   assert(cast<Function>(GV)->getCallingConv() == CallingConv::PTX_Device &&
    370          "PTX function calls must be to PTX device functions");
    371   Callee = DAG.getTargetGlobalAddress(GV, dl, getPointerTy());
    372   Ops[Ins.size()+2] = Callee;
    373 
    374   // Generate STORE_PARAM nodes for each function argument.  In PTX, function
    375   // arguments are explicitly stored into .param variables and passed as
    376   // arguments. There is no register/stack-based calling convention in PTX.
    377   Ops[Ins.size()+3] = DAG.getTargetConstant(OutVals.size(), MVT::i32);
    378   for (unsigned i = 0; i != OutVals.size(); ++i) {
    379     unsigned Size = OutVals[i].getValueType().getSizeInBits();
    380     unsigned Param = PM.addLocalParam(Size);
    381     const std::string &ParamName = PM.getParamName(Param);
    382     SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
    383                                                      MVT::Other);
    384     Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain,
    385                         ParamValue, OutVals[i]);
    386     Ops[i+Ins.size()+4] = ParamValue;
    387   }
    388 
    389   std::vector<SDValue> InParams;
    390 
    391   // Generate list of .param variables to hold the return value(s).
    392   Ops[1] = DAG.getTargetConstant(Ins.size(), MVT::i32);
    393   for (unsigned i = 0; i < Ins.size(); ++i) {
    394     unsigned Size = Ins[i].VT.getStoreSizeInBits();
    395     unsigned Param = PM.addLocalParam(Size);
    396     const std::string &ParamName = PM.getParamName(Param);
    397     SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
    398                                                      MVT::Other);
    399     Ops[i+2] = ParamValue;
    400     InParams.push_back(ParamValue);
    401   }
    402 
    403   Ops[0] = Chain;
    404 
    405   // Create the CALL node.
    406   Chain = DAG.getNode(PTXISD::CALL, dl, MVT::Other, &Ops[0], Ops.size());
    407 
    408   // Create the LOAD_PARAM nodes that retrieve the function return value(s).
    409   for (unsigned i = 0; i < Ins.size(); ++i) {
    410     SDValue Load = DAG.getNode(PTXISD::LOAD_PARAM, dl, Ins[i].VT, Chain,
    411                                InParams[i]);
    412     InVals.push_back(Load);
    413   }
    414 
    415   return Chain;
    416 }
    417 
    418 unsigned PTXTargetLowering::getNumRegisters(LLVMContext &Context, EVT VT) {
    419   // All arguments consist of one "register," regardless of the type.
    420   return 1;
    421 }
    422 
    423