1 //===-- PTXISelLowering.cpp - PTX DAG Lowering Implementation -------------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // This file implements the PTXTargetLowering class. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PTXISelLowering.h" 15 #include "PTX.h" 16 #include "PTXMachineFunctionInfo.h" 17 #include "PTXRegisterInfo.h" 18 #include "PTXSubtarget.h" 19 #include "llvm/Function.h" 20 #include "llvm/Support/ErrorHandling.h" 21 #include "llvm/CodeGen/CallingConvLower.h" 22 #include "llvm/CodeGen/MachineFunction.h" 23 #include "llvm/CodeGen/MachineFrameInfo.h" 24 #include "llvm/CodeGen/MachineRegisterInfo.h" 25 #include "llvm/CodeGen/SelectionDAG.h" 26 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h" 27 #include "llvm/Support/Debug.h" 28 #include "llvm/Support/raw_ostream.h" 29 30 using namespace llvm; 31 32 //===----------------------------------------------------------------------===// 33 // TargetLowering Implementation 34 //===----------------------------------------------------------------------===// 35 36 PTXTargetLowering::PTXTargetLowering(TargetMachine &TM) 37 : TargetLowering(TM, new TargetLoweringObjectFileELF()) { 38 // Set up the register classes. 39 addRegisterClass(MVT::i1, PTX::RegPredRegisterClass); 40 addRegisterClass(MVT::i16, PTX::RegI16RegisterClass); 41 addRegisterClass(MVT::i32, PTX::RegI32RegisterClass); 42 addRegisterClass(MVT::i64, PTX::RegI64RegisterClass); 43 addRegisterClass(MVT::f32, PTX::RegF32RegisterClass); 44 addRegisterClass(MVT::f64, PTX::RegF64RegisterClass); 45 46 setBooleanContents(ZeroOrOneBooleanContent); 47 setBooleanVectorContents(ZeroOrOneBooleanContent); // FIXME: Is this correct? 48 setMinFunctionAlignment(2); 49 50 // Let LLVM use loads/stores for all mem* operations 51 maxStoresPerMemcpy = 4096; 52 maxStoresPerMemmove = 4096; 53 maxStoresPerMemset = 4096; 54 55 //////////////////////////////////// 56 /////////// Expansion ////////////// 57 //////////////////////////////////// 58 59 // (any/zero/sign) extload => load + (any/zero/sign) extend 60 61 setLoadExtAction(ISD::EXTLOAD, MVT::i16, Expand); 62 setLoadExtAction(ISD::ZEXTLOAD, MVT::i16, Expand); 63 setLoadExtAction(ISD::SEXTLOAD, MVT::i16, Expand); 64 65 // f32 extload => load + fextend 66 67 setLoadExtAction(ISD::EXTLOAD, MVT::f32, Expand); 68 69 // f64 truncstore => trunc + store 70 71 setTruncStoreAction(MVT::f64, MVT::f32, Expand); 72 73 // sign_extend_inreg => sign_extend 74 75 setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand); 76 77 // br_cc => brcond 78 79 setOperationAction(ISD::BR_CC, MVT::Other, Expand); 80 81 // select_cc => setcc 82 83 setOperationAction(ISD::SELECT_CC, MVT::Other, Expand); 84 setOperationAction(ISD::SELECT_CC, MVT::f32, Expand); 85 setOperationAction(ISD::SELECT_CC, MVT::f64, Expand); 86 87 //////////////////////////////////// 88 //////////// Legal ///////////////// 89 //////////////////////////////////// 90 91 setOperationAction(ISD::ConstantFP, MVT::f32, Legal); 92 setOperationAction(ISD::ConstantFP, MVT::f64, Legal); 93 94 //////////////////////////////////// 95 //////////// Custom //////////////// 96 //////////////////////////////////// 97 98 // customise setcc to use bitwise logic if possible 99 100 //setOperationAction(ISD::SETCC, MVT::i1, Custom); 101 setOperationAction(ISD::SETCC, MVT::i1, Legal); 102 103 // customize translation of memory addresses 104 105 setOperationAction(ISD::GlobalAddress, MVT::i32, Custom); 106 setOperationAction(ISD::GlobalAddress, MVT::i64, Custom); 107 108 // Compute derived properties from the register classes 109 computeRegisterProperties(); 110 } 111 112 EVT PTXTargetLowering::getSetCCResultType(EVT VT) const { 113 return MVT::i1; 114 } 115 116 SDValue PTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { 117 switch (Op.getOpcode()) { 118 default: 119 llvm_unreachable("Unimplemented operand"); 120 case ISD::SETCC: 121 return LowerSETCC(Op, DAG); 122 case ISD::GlobalAddress: 123 return LowerGlobalAddress(Op, DAG); 124 } 125 } 126 127 const char *PTXTargetLowering::getTargetNodeName(unsigned Opcode) const { 128 switch (Opcode) { 129 default: 130 llvm_unreachable("Unknown opcode"); 131 case PTXISD::COPY_ADDRESS: 132 return "PTXISD::COPY_ADDRESS"; 133 case PTXISD::LOAD_PARAM: 134 return "PTXISD::LOAD_PARAM"; 135 case PTXISD::STORE_PARAM: 136 return "PTXISD::STORE_PARAM"; 137 case PTXISD::READ_PARAM: 138 return "PTXISD::READ_PARAM"; 139 case PTXISD::WRITE_PARAM: 140 return "PTXISD::WRITE_PARAM"; 141 case PTXISD::EXIT: 142 return "PTXISD::EXIT"; 143 case PTXISD::RET: 144 return "PTXISD::RET"; 145 case PTXISD::CALL: 146 return "PTXISD::CALL"; 147 } 148 } 149 150 //===----------------------------------------------------------------------===// 151 // Custom Lower Operation 152 //===----------------------------------------------------------------------===// 153 154 SDValue PTXTargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const { 155 assert(Op.getValueType() == MVT::i1 && "SetCC type must be 1-bit integer"); 156 SDValue Op0 = Op.getOperand(0); 157 SDValue Op1 = Op.getOperand(1); 158 SDValue Op2 = Op.getOperand(2); 159 DebugLoc dl = Op.getDebugLoc(); 160 //ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get(); 161 162 // Look for X == 0, X == 1, X != 0, or X != 1 163 // We can simplify these to bitwise logic 164 165 //if (Op1.getOpcode() == ISD::Constant && 166 // (cast<ConstantSDNode>(Op1)->getZExtValue() == 1 || 167 // cast<ConstantSDNode>(Op1)->isNullValue()) && 168 // (CC == ISD::SETEQ || CC == ISD::SETNE)) { 169 // 170 // return DAG.getNode(ISD::AND, dl, MVT::i1, Op0, Op1); 171 //} 172 173 //ConstantSDNode* COp1 = cast<ConstantSDNode>(Op1); 174 //if(COp1 && COp1->getZExtValue() == 1) { 175 // if(CC == ISD::SETNE) { 176 // return DAG.getNode(PTX::XORripreds, dl, MVT::i1, Op0); 177 // } 178 //} 179 180 llvm_unreachable("setcc was not matched by a pattern!"); 181 182 return DAG.getNode(ISD::SETCC, dl, MVT::i1, Op0, Op1, Op2); 183 } 184 185 SDValue PTXTargetLowering:: 186 LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const { 187 EVT PtrVT = getPointerTy(); 188 DebugLoc dl = Op.getDebugLoc(); 189 const GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal(); 190 191 assert(PtrVT.isSimple() && "Pointer must be to primitive type."); 192 193 SDValue targetGlobal = DAG.getTargetGlobalAddress(GV, dl, PtrVT); 194 SDValue movInstr = DAG.getNode(PTXISD::COPY_ADDRESS, 195 dl, 196 PtrVT.getSimpleVT(), 197 targetGlobal); 198 199 return movInstr; 200 } 201 202 //===----------------------------------------------------------------------===// 203 // Calling Convention Implementation 204 //===----------------------------------------------------------------------===// 205 206 SDValue PTXTargetLowering:: 207 LowerFormalArguments(SDValue Chain, 208 CallingConv::ID CallConv, 209 bool isVarArg, 210 const SmallVectorImpl<ISD::InputArg> &Ins, 211 DebugLoc dl, 212 SelectionDAG &DAG, 213 SmallVectorImpl<SDValue> &InVals) const { 214 if (isVarArg) llvm_unreachable("PTX does not support varargs"); 215 216 MachineFunction &MF = DAG.getMachineFunction(); 217 const PTXSubtarget& ST = getTargetMachine().getSubtarget<PTXSubtarget>(); 218 PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>(); 219 PTXParamManager &PM = MFI->getParamManager(); 220 221 switch (CallConv) { 222 default: 223 llvm_unreachable("Unsupported calling convention"); 224 case CallingConv::PTX_Kernel: 225 MFI->setKernel(true); 226 break; 227 case CallingConv::PTX_Device: 228 MFI->setKernel(false); 229 break; 230 } 231 232 // We do one of two things here: 233 // IsKernel || SM >= 2.0 -> Use param space for arguments 234 // SM < 2.0 -> Use registers for arguments 235 if (MFI->isKernel() || ST.useParamSpaceForDeviceArgs()) { 236 // We just need to emit the proper LOAD_PARAM ISDs 237 for (unsigned i = 0, e = Ins.size(); i != e; ++i) { 238 assert((!MFI->isKernel() || Ins[i].VT != MVT::i1) && 239 "Kernels cannot take pred operands"); 240 241 unsigned ParamSize = Ins[i].VT.getStoreSizeInBits(); 242 unsigned Param = PM.addArgumentParam(ParamSize); 243 const std::string &ParamName = PM.getParamName(Param); 244 SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(), 245 MVT::Other); 246 SDValue ArgValue = DAG.getNode(PTXISD::LOAD_PARAM, dl, Ins[i].VT, Chain, 247 ParamValue); 248 InVals.push_back(ArgValue); 249 } 250 } 251 else { 252 for (unsigned i = 0, e = Ins.size(); i != e; ++i) { 253 EVT RegVT = Ins[i].VT; 254 const TargetRegisterClass* TRC = getRegClassFor(RegVT); 255 unsigned RegType; 256 257 // Determine which register class we need 258 if (RegVT == MVT::i1) 259 RegType = PTXRegisterType::Pred; 260 else if (RegVT == MVT::i16) 261 RegType = PTXRegisterType::B16; 262 else if (RegVT == MVT::i32) 263 RegType = PTXRegisterType::B32; 264 else if (RegVT == MVT::i64) 265 RegType = PTXRegisterType::B64; 266 else if (RegVT == MVT::f32) 267 RegType = PTXRegisterType::F32; 268 else if (RegVT == MVT::f64) 269 RegType = PTXRegisterType::F64; 270 else 271 llvm_unreachable("Unknown parameter type"); 272 273 // Use a unique index in the instruction to prevent instruction folding. 274 // Yes, this is a hack. 275 SDValue Index = DAG.getTargetConstant(i, MVT::i32); 276 unsigned Reg = MF.getRegInfo().createVirtualRegister(TRC); 277 SDValue ArgValue = DAG.getNode(PTXISD::READ_PARAM, dl, RegVT, Chain, 278 Index); 279 280 InVals.push_back(ArgValue); 281 282 MFI->addRegister(Reg, RegType, PTXRegisterSpace::Argument); 283 } 284 } 285 286 return Chain; 287 } 288 289 SDValue PTXTargetLowering:: 290 LowerReturn(SDValue Chain, 291 CallingConv::ID CallConv, 292 bool isVarArg, 293 const SmallVectorImpl<ISD::OutputArg> &Outs, 294 const SmallVectorImpl<SDValue> &OutVals, 295 DebugLoc dl, 296 SelectionDAG &DAG) const { 297 if (isVarArg) llvm_unreachable("PTX does not support varargs"); 298 299 switch (CallConv) { 300 default: 301 llvm_unreachable("Unsupported calling convention."); 302 case CallingConv::PTX_Kernel: 303 assert(Outs.size() == 0 && "Kernel must return void."); 304 return DAG.getNode(PTXISD::EXIT, dl, MVT::Other, Chain); 305 case CallingConv::PTX_Device: 306 assert(Outs.size() <= 1 && "Can at most return one value."); 307 break; 308 } 309 310 MachineFunction& MF = DAG.getMachineFunction(); 311 PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>(); 312 PTXParamManager &PM = MFI->getParamManager(); 313 314 SDValue Flag; 315 const PTXSubtarget& ST = getTargetMachine().getSubtarget<PTXSubtarget>(); 316 317 if (ST.useParamSpaceForDeviceArgs()) { 318 assert(Outs.size() < 2 && "Device functions can return at most one value"); 319 320 if (Outs.size() == 1) { 321 unsigned ParamSize = OutVals[0].getValueType().getSizeInBits(); 322 unsigned Param = PM.addReturnParam(ParamSize); 323 const std::string &ParamName = PM.getParamName(Param); 324 SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(), 325 MVT::Other); 326 Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain, 327 ParamValue, OutVals[0]); 328 } 329 } else { 330 for (unsigned i = 0, e = Outs.size(); i != e; ++i) { 331 EVT RegVT = Outs[i].VT; 332 const TargetRegisterClass* TRC; 333 unsigned RegType; 334 335 // Determine which register class we need 336 if (RegVT == MVT::i1) { 337 TRC = PTX::RegPredRegisterClass; 338 RegType = PTXRegisterType::Pred; 339 } 340 else if (RegVT == MVT::i16) { 341 TRC = PTX::RegI16RegisterClass; 342 RegType = PTXRegisterType::B16; 343 } 344 else if (RegVT == MVT::i32) { 345 TRC = PTX::RegI32RegisterClass; 346 RegType = PTXRegisterType::B32; 347 } 348 else if (RegVT == MVT::i64) { 349 TRC = PTX::RegI64RegisterClass; 350 RegType = PTXRegisterType::B64; 351 } 352 else if (RegVT == MVT::f32) { 353 TRC = PTX::RegF32RegisterClass; 354 RegType = PTXRegisterType::F32; 355 } 356 else if (RegVT == MVT::f64) { 357 TRC = PTX::RegF64RegisterClass; 358 RegType = PTXRegisterType::F64; 359 } 360 else { 361 llvm_unreachable("Unknown parameter type"); 362 } 363 364 unsigned Reg = MF.getRegInfo().createVirtualRegister(TRC); 365 366 SDValue Copy = DAG.getCopyToReg(Chain, dl, Reg, OutVals[i]/*, Flag*/); 367 SDValue OutReg = DAG.getRegister(Reg, RegVT); 368 369 Chain = DAG.getNode(PTXISD::WRITE_PARAM, dl, MVT::Other, Copy, OutReg); 370 371 MFI->addRegister(Reg, RegType, PTXRegisterSpace::Return); 372 } 373 } 374 375 if (Flag.getNode() == 0) { 376 return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain); 377 } 378 else { 379 return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain, Flag); 380 } 381 } 382 383 SDValue 384 PTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee, 385 CallingConv::ID CallConv, bool isVarArg, 386 bool doesNotRet, bool &isTailCall, 387 const SmallVectorImpl<ISD::OutputArg> &Outs, 388 const SmallVectorImpl<SDValue> &OutVals, 389 const SmallVectorImpl<ISD::InputArg> &Ins, 390 DebugLoc dl, SelectionDAG &DAG, 391 SmallVectorImpl<SDValue> &InVals) const { 392 393 MachineFunction& MF = DAG.getMachineFunction(); 394 PTXMachineFunctionInfo *PTXMFI = MF.getInfo<PTXMachineFunctionInfo>(); 395 PTXParamManager &PM = PTXMFI->getParamManager(); 396 MachineFrameInfo *MFI = MF.getFrameInfo(); 397 398 assert(getTargetMachine().getSubtarget<PTXSubtarget>().callsAreHandled() && 399 "Calls are not handled for the target device"); 400 401 // Identify the callee function 402 const GlobalValue *GV = cast<GlobalAddressSDNode>(Callee)->getGlobal(); 403 const Function *function = cast<Function>(GV); 404 405 // allow non-device calls only for printf 406 bool isPrintf = function->getName() == "printf" || function->getName() == "puts"; 407 408 assert((isPrintf || function->getCallingConv() == CallingConv::PTX_Device) && 409 "PTX function calls must be to PTX device functions"); 410 411 unsigned outSize = isPrintf ? 2 : Outs.size(); 412 413 std::vector<SDValue> Ops; 414 // The layout of the ops will be [Chain, #Ins, Ins, Callee, #Outs, Outs] 415 Ops.resize(outSize + Ins.size() + 4); 416 417 Ops[0] = Chain; 418 419 // Identify the callee function 420 Callee = DAG.getTargetGlobalAddress(GV, dl, getPointerTy()); 421 Ops[Ins.size()+2] = Callee; 422 423 // #Outs 424 Ops[Ins.size()+3] = DAG.getTargetConstant(outSize, MVT::i32); 425 426 if (isPrintf) { 427 // first argument is the address of the global string variable in memory 428 unsigned Param0 = PM.addLocalParam(getPointerTy().getSizeInBits()); 429 SDValue ParamValue0 = DAG.getTargetExternalSymbol(PM.getParamName(Param0).c_str(), 430 MVT::Other); 431 Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain, 432 ParamValue0, OutVals[0]); 433 Ops[Ins.size()+4] = ParamValue0; 434 435 // alignment is the maximum size of all the arguments 436 unsigned alignment = 0; 437 for (unsigned i = 1; i < OutVals.size(); ++i) { 438 alignment = std::max(alignment, 439 OutVals[i].getValueType().getSizeInBits()); 440 } 441 442 // size is the alignment multiplied by the number of arguments 443 unsigned size = alignment * (OutVals.size() - 1); 444 445 // second argument is the address of the stack object (unless no arguments) 446 unsigned Param1 = PM.addLocalParam(getPointerTy().getSizeInBits()); 447 SDValue ParamValue1 = DAG.getTargetExternalSymbol(PM.getParamName(Param1).c_str(), 448 MVT::Other); 449 Ops[Ins.size()+5] = ParamValue1; 450 451 if (size > 0) 452 { 453 // create a local stack object to store the arguments 454 unsigned StackObject = MFI->CreateStackObject(size / 8, alignment / 8, false); 455 SDValue FrameIndex = DAG.getFrameIndex(StackObject, getPointerTy()); 456 457 // store each of the arguments to the stack in turn 458 for (unsigned int i = 1; i != OutVals.size(); i++) { 459 SDValue FrameAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), FrameIndex, DAG.getTargetConstant((i - 1) * 8, getPointerTy())); 460 Chain = DAG.getStore(Chain, dl, OutVals[i], FrameAddr, 461 MachinePointerInfo(), 462 false, false, 0); 463 } 464 465 // copy the address of the local frame index to get the address in non-local space 466 SDValue genericAddr = DAG.getNode(PTXISD::COPY_ADDRESS, dl, getPointerTy(), FrameIndex); 467 468 // store this address in the second argument 469 Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain, ParamValue1, genericAddr); 470 } 471 } 472 else 473 { 474 // Generate STORE_PARAM nodes for each function argument. In PTX, function 475 // arguments are explicitly stored into .param variables and passed as 476 // arguments. There is no register/stack-based calling convention in PTX. 477 for (unsigned i = 0; i != OutVals.size(); ++i) { 478 unsigned Size = OutVals[i].getValueType().getSizeInBits(); 479 unsigned Param = PM.addLocalParam(Size); 480 const std::string &ParamName = PM.getParamName(Param); 481 SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(), 482 MVT::Other); 483 Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain, 484 ParamValue, OutVals[i]); 485 Ops[i+Ins.size()+4] = ParamValue; 486 } 487 } 488 489 std::vector<SDValue> InParams; 490 491 // Generate list of .param variables to hold the return value(s). 492 Ops[1] = DAG.getTargetConstant(Ins.size(), MVT::i32); 493 for (unsigned i = 0; i < Ins.size(); ++i) { 494 unsigned Size = Ins[i].VT.getStoreSizeInBits(); 495 unsigned Param = PM.addLocalParam(Size); 496 const std::string &ParamName = PM.getParamName(Param); 497 SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(), 498 MVT::Other); 499 Ops[i+2] = ParamValue; 500 InParams.push_back(ParamValue); 501 } 502 503 Ops[0] = Chain; 504 505 // Create the CALL node. 506 Chain = DAG.getNode(PTXISD::CALL, dl, MVT::Other, &Ops[0], Ops.size()); 507 508 // Create the LOAD_PARAM nodes that retrieve the function return value(s). 509 for (unsigned i = 0; i < Ins.size(); ++i) { 510 SDValue Load = DAG.getNode(PTXISD::LOAD_PARAM, dl, Ins[i].VT, Chain, 511 InParams[i]); 512 InVals.push_back(Load); 513 } 514 515 return Chain; 516 } 517 518 unsigned PTXTargetLowering::getNumRegisters(LLVMContext &Context, EVT VT) { 519 // All arguments consist of one "register," regardless of the type. 520 return 1; 521 } 522 523