1 //===-- PTXISelLowering.cpp - PTX DAG Lowering Implementation -------------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // This file implements the PTXTargetLowering class. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PTX.h" 15 #include "PTXISelLowering.h" 16 #include "PTXMachineFunctionInfo.h" 17 #include "PTXRegisterInfo.h" 18 #include "PTXSubtarget.h" 19 #include "llvm/Function.h" 20 #include "llvm/Support/ErrorHandling.h" 21 #include "llvm/CodeGen/CallingConvLower.h" 22 #include "llvm/CodeGen/MachineFunction.h" 23 #include "llvm/CodeGen/MachineRegisterInfo.h" 24 #include "llvm/CodeGen/SelectionDAG.h" 25 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h" 26 #include "llvm/Support/Debug.h" 27 #include "llvm/Support/raw_ostream.h" 28 29 using namespace llvm; 30 31 //===----------------------------------------------------------------------===// 32 // TargetLowering Implementation 33 //===----------------------------------------------------------------------===// 34 35 PTXTargetLowering::PTXTargetLowering(TargetMachine &TM) 36 : TargetLowering(TM, new TargetLoweringObjectFileELF()) { 37 // Set up the register classes. 38 addRegisterClass(MVT::i1, PTX::RegPredRegisterClass); 39 addRegisterClass(MVT::i16, PTX::RegI16RegisterClass); 40 addRegisterClass(MVT::i32, PTX::RegI32RegisterClass); 41 addRegisterClass(MVT::i64, PTX::RegI64RegisterClass); 42 addRegisterClass(MVT::f32, PTX::RegF32RegisterClass); 43 addRegisterClass(MVT::f64, PTX::RegF64RegisterClass); 44 45 setBooleanContents(ZeroOrOneBooleanContent); 46 setBooleanVectorContents(ZeroOrOneBooleanContent); // FIXME: Is this correct? 47 setMinFunctionAlignment(2); 48 49 //////////////////////////////////// 50 /////////// Expansion ////////////// 51 //////////////////////////////////// 52 53 // (any/zero/sign) extload => load + (any/zero/sign) extend 54 55 setLoadExtAction(ISD::EXTLOAD, MVT::i16, Expand); 56 setLoadExtAction(ISD::ZEXTLOAD, MVT::i16, Expand); 57 setLoadExtAction(ISD::SEXTLOAD, MVT::i16, Expand); 58 59 // f32 extload => load + fextend 60 61 setLoadExtAction(ISD::EXTLOAD, MVT::f32, Expand); 62 63 // f64 truncstore => trunc + store 64 65 setTruncStoreAction(MVT::f64, MVT::f32, Expand); 66 67 // sign_extend_inreg => sign_extend 68 69 setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand); 70 71 // br_cc => brcond 72 73 setOperationAction(ISD::BR_CC, MVT::Other, Expand); 74 75 // select_cc => setcc 76 77 setOperationAction(ISD::SELECT_CC, MVT::Other, Expand); 78 setOperationAction(ISD::SELECT_CC, MVT::f32, Expand); 79 setOperationAction(ISD::SELECT_CC, MVT::f64, Expand); 80 81 //////////////////////////////////// 82 //////////// Legal ///////////////// 83 //////////////////////////////////// 84 85 setOperationAction(ISD::ConstantFP, MVT::f32, Legal); 86 setOperationAction(ISD::ConstantFP, MVT::f64, Legal); 87 88 //////////////////////////////////// 89 //////////// Custom //////////////// 90 //////////////////////////////////// 91 92 // customise setcc to use bitwise logic if possible 93 94 setOperationAction(ISD::SETCC, MVT::i1, Custom); 95 96 // customize translation of memory addresses 97 98 setOperationAction(ISD::GlobalAddress, MVT::i32, Custom); 99 setOperationAction(ISD::GlobalAddress, MVT::i64, Custom); 100 101 // Compute derived properties from the register classes 102 computeRegisterProperties(); 103 } 104 105 EVT PTXTargetLowering::getSetCCResultType(EVT VT) const { 106 return MVT::i1; 107 } 108 109 SDValue PTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { 110 switch (Op.getOpcode()) { 111 default: 112 llvm_unreachable("Unimplemented operand"); 113 case ISD::SETCC: 114 return LowerSETCC(Op, DAG); 115 case ISD::GlobalAddress: 116 return LowerGlobalAddress(Op, DAG); 117 } 118 } 119 120 const char *PTXTargetLowering::getTargetNodeName(unsigned Opcode) const { 121 switch (Opcode) { 122 default: 123 llvm_unreachable("Unknown opcode"); 124 case PTXISD::COPY_ADDRESS: 125 return "PTXISD::COPY_ADDRESS"; 126 case PTXISD::LOAD_PARAM: 127 return "PTXISD::LOAD_PARAM"; 128 case PTXISD::STORE_PARAM: 129 return "PTXISD::STORE_PARAM"; 130 case PTXISD::READ_PARAM: 131 return "PTXISD::READ_PARAM"; 132 case PTXISD::WRITE_PARAM: 133 return "PTXISD::WRITE_PARAM"; 134 case PTXISD::EXIT: 135 return "PTXISD::EXIT"; 136 case PTXISD::RET: 137 return "PTXISD::RET"; 138 case PTXISD::CALL: 139 return "PTXISD::CALL"; 140 } 141 } 142 143 //===----------------------------------------------------------------------===// 144 // Custom Lower Operation 145 //===----------------------------------------------------------------------===// 146 147 SDValue PTXTargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const { 148 assert(Op.getValueType() == MVT::i1 && "SetCC type must be 1-bit integer"); 149 SDValue Op0 = Op.getOperand(0); 150 SDValue Op1 = Op.getOperand(1); 151 SDValue Op2 = Op.getOperand(2); 152 DebugLoc dl = Op.getDebugLoc(); 153 ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get(); 154 155 // Look for X == 0, X == 1, X != 0, or X != 1 156 // We can simplify these to bitwise logic 157 158 if (Op1.getOpcode() == ISD::Constant && 159 (cast<ConstantSDNode>(Op1)->getZExtValue() == 1 || 160 cast<ConstantSDNode>(Op1)->isNullValue()) && 161 (CC == ISD::SETEQ || CC == ISD::SETNE)) { 162 163 return DAG.getNode(ISD::AND, dl, MVT::i1, Op0, Op1); 164 } 165 166 return DAG.getNode(ISD::SETCC, dl, MVT::i1, Op0, Op1, Op2); 167 } 168 169 SDValue PTXTargetLowering:: 170 LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const { 171 EVT PtrVT = getPointerTy(); 172 DebugLoc dl = Op.getDebugLoc(); 173 const GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal(); 174 175 assert(PtrVT.isSimple() && "Pointer must be to primitive type."); 176 177 SDValue targetGlobal = DAG.getTargetGlobalAddress(GV, dl, PtrVT); 178 SDValue movInstr = DAG.getNode(PTXISD::COPY_ADDRESS, 179 dl, 180 PtrVT.getSimpleVT(), 181 targetGlobal); 182 183 return movInstr; 184 } 185 186 //===----------------------------------------------------------------------===// 187 // Calling Convention Implementation 188 //===----------------------------------------------------------------------===// 189 190 SDValue PTXTargetLowering:: 191 LowerFormalArguments(SDValue Chain, 192 CallingConv::ID CallConv, 193 bool isVarArg, 194 const SmallVectorImpl<ISD::InputArg> &Ins, 195 DebugLoc dl, 196 SelectionDAG &DAG, 197 SmallVectorImpl<SDValue> &InVals) const { 198 if (isVarArg) llvm_unreachable("PTX does not support varargs"); 199 200 MachineFunction &MF = DAG.getMachineFunction(); 201 const PTXSubtarget& ST = getTargetMachine().getSubtarget<PTXSubtarget>(); 202 PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>(); 203 PTXParamManager &PM = MFI->getParamManager(); 204 205 switch (CallConv) { 206 default: 207 llvm_unreachable("Unsupported calling convention"); 208 break; 209 case CallingConv::PTX_Kernel: 210 MFI->setKernel(true); 211 break; 212 case CallingConv::PTX_Device: 213 MFI->setKernel(false); 214 break; 215 } 216 217 // We do one of two things here: 218 // IsKernel || SM >= 2.0 -> Use param space for arguments 219 // SM < 2.0 -> Use registers for arguments 220 if (MFI->isKernel() || ST.useParamSpaceForDeviceArgs()) { 221 // We just need to emit the proper LOAD_PARAM ISDs 222 for (unsigned i = 0, e = Ins.size(); i != e; ++i) { 223 assert((!MFI->isKernel() || Ins[i].VT != MVT::i1) && 224 "Kernels cannot take pred operands"); 225 226 unsigned ParamSize = Ins[i].VT.getStoreSizeInBits(); 227 unsigned Param = PM.addArgumentParam(ParamSize); 228 const std::string &ParamName = PM.getParamName(Param); 229 SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(), 230 MVT::Other); 231 SDValue ArgValue = DAG.getNode(PTXISD::LOAD_PARAM, dl, Ins[i].VT, Chain, 232 ParamValue); 233 InVals.push_back(ArgValue); 234 } 235 } 236 else { 237 for (unsigned i = 0, e = Ins.size(); i != e; ++i) { 238 EVT RegVT = Ins[i].VT; 239 TargetRegisterClass* TRC = getRegClassFor(RegVT); 240 241 // Use a unique index in the instruction to prevent instruction folding. 242 // Yes, this is a hack. 243 SDValue Index = DAG.getTargetConstant(i, MVT::i32); 244 unsigned Reg = MF.getRegInfo().createVirtualRegister(TRC); 245 SDValue ArgValue = DAG.getNode(PTXISD::READ_PARAM, dl, RegVT, Chain, 246 Index); 247 248 InVals.push_back(ArgValue); 249 250 MFI->addArgReg(Reg); 251 } 252 } 253 254 return Chain; 255 } 256 257 SDValue PTXTargetLowering:: 258 LowerReturn(SDValue Chain, 259 CallingConv::ID CallConv, 260 bool isVarArg, 261 const SmallVectorImpl<ISD::OutputArg> &Outs, 262 const SmallVectorImpl<SDValue> &OutVals, 263 DebugLoc dl, 264 SelectionDAG &DAG) const { 265 if (isVarArg) llvm_unreachable("PTX does not support varargs"); 266 267 switch (CallConv) { 268 default: 269 llvm_unreachable("Unsupported calling convention."); 270 case CallingConv::PTX_Kernel: 271 assert(Outs.size() == 0 && "Kernel must return void."); 272 return DAG.getNode(PTXISD::EXIT, dl, MVT::Other, Chain); 273 case CallingConv::PTX_Device: 274 assert(Outs.size() <= 1 && "Can at most return one value."); 275 break; 276 } 277 278 MachineFunction& MF = DAG.getMachineFunction(); 279 PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>(); 280 PTXParamManager &PM = MFI->getParamManager(); 281 282 SDValue Flag; 283 const PTXSubtarget& ST = getTargetMachine().getSubtarget<PTXSubtarget>(); 284 285 if (ST.useParamSpaceForDeviceArgs()) { 286 assert(Outs.size() < 2 && "Device functions can return at most one value"); 287 288 if (Outs.size() == 1) { 289 unsigned ParamSize = OutVals[0].getValueType().getSizeInBits(); 290 unsigned Param = PM.addReturnParam(ParamSize); 291 const std::string &ParamName = PM.getParamName(Param); 292 SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(), 293 MVT::Other); 294 Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain, 295 ParamValue, OutVals[0]); 296 } 297 } else { 298 for (unsigned i = 0, e = Outs.size(); i != e; ++i) { 299 EVT RegVT = Outs[i].VT; 300 TargetRegisterClass* TRC = 0; 301 302 // Determine which register class we need 303 if (RegVT == MVT::i1) { 304 TRC = PTX::RegPredRegisterClass; 305 } 306 else if (RegVT == MVT::i16) { 307 TRC = PTX::RegI16RegisterClass; 308 } 309 else if (RegVT == MVT::i32) { 310 TRC = PTX::RegI32RegisterClass; 311 } 312 else if (RegVT == MVT::i64) { 313 TRC = PTX::RegI64RegisterClass; 314 } 315 else if (RegVT == MVT::f32) { 316 TRC = PTX::RegF32RegisterClass; 317 } 318 else if (RegVT == MVT::f64) { 319 TRC = PTX::RegF64RegisterClass; 320 } 321 else { 322 llvm_unreachable("Unknown parameter type"); 323 } 324 325 unsigned Reg = MF.getRegInfo().createVirtualRegister(TRC); 326 327 SDValue Copy = DAG.getCopyToReg(Chain, dl, Reg, OutVals[i]/*, Flag*/); 328 SDValue OutReg = DAG.getRegister(Reg, RegVT); 329 330 Chain = DAG.getNode(PTXISD::WRITE_PARAM, dl, MVT::Other, Copy, OutReg); 331 332 MFI->addRetReg(Reg); 333 } 334 } 335 336 if (Flag.getNode() == 0) { 337 return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain); 338 } 339 else { 340 return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain, Flag); 341 } 342 } 343 344 SDValue 345 PTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee, 346 CallingConv::ID CallConv, bool isVarArg, 347 bool &isTailCall, 348 const SmallVectorImpl<ISD::OutputArg> &Outs, 349 const SmallVectorImpl<SDValue> &OutVals, 350 const SmallVectorImpl<ISD::InputArg> &Ins, 351 DebugLoc dl, SelectionDAG &DAG, 352 SmallVectorImpl<SDValue> &InVals) const { 353 354 MachineFunction& MF = DAG.getMachineFunction(); 355 PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>(); 356 PTXParamManager &PM = MFI->getParamManager(); 357 358 assert(getTargetMachine().getSubtarget<PTXSubtarget>().callsAreHandled() && 359 "Calls are not handled for the target device"); 360 361 std::vector<SDValue> Ops; 362 // The layout of the ops will be [Chain, #Ins, Ins, Callee, #Outs, Outs] 363 Ops.resize(Outs.size() + Ins.size() + 4); 364 365 Ops[0] = Chain; 366 367 // Identify the callee function 368 const GlobalValue *GV = cast<GlobalAddressSDNode>(Callee)->getGlobal(); 369 assert(cast<Function>(GV)->getCallingConv() == CallingConv::PTX_Device && 370 "PTX function calls must be to PTX device functions"); 371 Callee = DAG.getTargetGlobalAddress(GV, dl, getPointerTy()); 372 Ops[Ins.size()+2] = Callee; 373 374 // Generate STORE_PARAM nodes for each function argument. In PTX, function 375 // arguments are explicitly stored into .param variables and passed as 376 // arguments. There is no register/stack-based calling convention in PTX. 377 Ops[Ins.size()+3] = DAG.getTargetConstant(OutVals.size(), MVT::i32); 378 for (unsigned i = 0; i != OutVals.size(); ++i) { 379 unsigned Size = OutVals[i].getValueType().getSizeInBits(); 380 unsigned Param = PM.addLocalParam(Size); 381 const std::string &ParamName = PM.getParamName(Param); 382 SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(), 383 MVT::Other); 384 Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain, 385 ParamValue, OutVals[i]); 386 Ops[i+Ins.size()+4] = ParamValue; 387 } 388 389 std::vector<SDValue> InParams; 390 391 // Generate list of .param variables to hold the return value(s). 392 Ops[1] = DAG.getTargetConstant(Ins.size(), MVT::i32); 393 for (unsigned i = 0; i < Ins.size(); ++i) { 394 unsigned Size = Ins[i].VT.getStoreSizeInBits(); 395 unsigned Param = PM.addLocalParam(Size); 396 const std::string &ParamName = PM.getParamName(Param); 397 SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(), 398 MVT::Other); 399 Ops[i+2] = ParamValue; 400 InParams.push_back(ParamValue); 401 } 402 403 Ops[0] = Chain; 404 405 // Create the CALL node. 406 Chain = DAG.getNode(PTXISD::CALL, dl, MVT::Other, &Ops[0], Ops.size()); 407 408 // Create the LOAD_PARAM nodes that retrieve the function return value(s). 409 for (unsigned i = 0; i < Ins.size(); ++i) { 410 SDValue Load = DAG.getNode(PTXISD::LOAD_PARAM, dl, Ins[i].VT, Chain, 411 InParams[i]); 412 InVals.push_back(Load); 413 } 414 415 return Chain; 416 } 417 418 unsigned PTXTargetLowering::getNumRegisters(LLVMContext &Context, EVT VT) { 419 // All arguments consist of one "register," regardless of the type. 420 return 1; 421 } 422 423