      1 //
      2 //                     The LLVM Compiler Infrastructure
      3 //
      4 // This file is distributed under the University of Illinois Open Source
      5 // License. See LICENSE.TXT for details.
      6 //
      7 //===----------------------------------------------------------------------===//
      8 //
      9 // This file defines the interfaces that NVPTX uses to lower LLVM code into a
     10 // selection DAG.
     11 //
     12 //===----------------------------------------------------------------------===//
     13 
     14 
     15 #include "NVPTX.h"
     16 #include "NVPTXISelLowering.h"
     17 #include "NVPTXTargetMachine.h"
     18 #include "NVPTXTargetObjectFile.h"
     19 #include "NVPTXUtilities.h"
     20 #include "llvm/Intrinsics.h"
     21 #include "llvm/IntrinsicInst.h"
     22 #include "llvm/Support/CommandLine.h"
     23 #include "llvm/DerivedTypes.h"
     24 #include "llvm/GlobalValue.h"
     25 #include "llvm/Module.h"
     26 #include "llvm/Function.h"
     27 #include "llvm/CodeGen/Analysis.h"
     28 #include "llvm/CodeGen/MachineFrameInfo.h"
     29 #include "llvm/CodeGen/MachineFunction.h"
     30 #include "llvm/CodeGen/MachineInstrBuilder.h"
     31 #include "llvm/CodeGen/MachineRegisterInfo.h"
     32 #include "llvm/Support/CallSite.h"
     33 #include "llvm/Support/ErrorHandling.h"
     34 #include "llvm/Support/Debug.h"
     35 #include "llvm/Support/raw_ostream.h"
     36 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
     37 #include "llvm/MC/MCSectionELF.h"
     38 #include <sstream>
     39 
     40 #undef DEBUG_TYPE
     41 #define DEBUG_TYPE "nvptx-lower"
     42 
     43 using namespace llvm;
     44 
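         // Monotonically increasing id assigned to each lowered call site.  It is
         // used to name the .callprototype emitted for indirect calls and to pair
         // the CALLSEQ_START/CALLSEQ_END markers (see LowerCall below).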
     45 static unsigned int uniqueCallSite = 0;
     46 
     47 static cl::opt<bool>
     48 RetainVectorOperands("nvptx-codegen-vectors",
     49      cl::desc("NVPTX Specific: Retain LLVM's vectors and generate PTX vectors"),
     50                      cl::init(true));
     51 
     52 static cl::opt<bool>
     53 sched4reg("nvptx-sched4reg",
      54           cl::desc("NVPTX Specific: schedule for register pressure"),
     55           cl::init(false));
     56 
     57 // NVPTXTargetLowering Constructor.
     58 NVPTXTargetLowering::NVPTXTargetLowering(NVPTXTargetMachine &TM)
     59 : TargetLowering(TM, new NVPTXTargetObjectFile()),
     60   nvTM(&TM),
     61   nvptxSubtarget(TM.getSubtarget<NVPTXSubtarget>()) {
     62 
      63   // Always lower memset, memcpy, and memmove intrinsics to load/store
      64   // instructions, rather than generating calls to memset, memcpy, or
      65   // memmove.
     66   maxStoresPerMemset = (unsigned)0xFFFFFFFF;
     67   maxStoresPerMemcpy = (unsigned)0xFFFFFFFF;
     68   maxStoresPerMemmove = (unsigned)0xFFFFFFFF;
     69 
     70   setBooleanContents(ZeroOrNegativeOneBooleanContent);
     71 
      72   // Jumps are expensive. Don't create extra control flow for 'and'/'or'
      73   // conditions in branches.
     74   setJumpIsExpensive(true);
     75 
      76   // By default, use Source scheduling; schedule for register pressure if requested.
     77   if (sched4reg)
     78     setSchedulingPreference(Sched::RegPressure);
     79   else
     80     setSchedulingPreference(Sched::Source);
     81 
     82   addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
     83   addRegisterClass(MVT::i8, &NVPTX::Int8RegsRegClass);
     84   addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
     85   addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
     86   addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
     87   addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
     88   addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);
     89 
     90   if (RetainVectorOperands) {
     91     addRegisterClass(MVT::v2f32, &NVPTX::V2F32RegsRegClass);
     92     addRegisterClass(MVT::v4f32, &NVPTX::V4F32RegsRegClass);
     93     addRegisterClass(MVT::v2i32, &NVPTX::V2I32RegsRegClass);
     94     addRegisterClass(MVT::v4i32, &NVPTX::V4I32RegsRegClass);
     95     addRegisterClass(MVT::v2f64, &NVPTX::V2F64RegsRegClass);
     96     addRegisterClass(MVT::v2i64, &NVPTX::V2I64RegsRegClass);
     97     addRegisterClass(MVT::v2i16, &NVPTX::V2I16RegsRegClass);
     98     addRegisterClass(MVT::v4i16, &NVPTX::V4I16RegsRegClass);
     99     addRegisterClass(MVT::v2i8, &NVPTX::V2I8RegsRegClass);
    100     addRegisterClass(MVT::v4i8, &NVPTX::V4I8RegsRegClass);
    101 
    102     setOperationAction(ISD::BUILD_VECTOR, MVT::v4i32  , Custom);
    103     setOperationAction(ISD::BUILD_VECTOR, MVT::v4f32  , Custom);
    104     setOperationAction(ISD::BUILD_VECTOR, MVT::v4i16  , Custom);
    105     setOperationAction(ISD::BUILD_VECTOR, MVT::v4i8   , Custom);
    106     setOperationAction(ISD::BUILD_VECTOR, MVT::v2i64  , Custom);
    107     setOperationAction(ISD::BUILD_VECTOR, MVT::v2f64  , Custom);
    108     setOperationAction(ISD::BUILD_VECTOR, MVT::v2i32  , Custom);
    109     setOperationAction(ISD::BUILD_VECTOR, MVT::v2f32  , Custom);
    110     setOperationAction(ISD::BUILD_VECTOR, MVT::v2i16  , Custom);
    111     setOperationAction(ISD::BUILD_VECTOR, MVT::v2i8   , Custom);
    112 
    113     setOperationAction(ISD::EXTRACT_SUBVECTOR, MVT::v4i32  , Custom);
    114     setOperationAction(ISD::EXTRACT_SUBVECTOR, MVT::v4f32  , Custom);
    115     setOperationAction(ISD::EXTRACT_SUBVECTOR, MVT::v4i16  , Custom);
    116     setOperationAction(ISD::EXTRACT_SUBVECTOR, MVT::v4i8   , Custom);
    117     setOperationAction(ISD::EXTRACT_SUBVECTOR, MVT::v2i64  , Custom);
    118     setOperationAction(ISD::EXTRACT_SUBVECTOR, MVT::v2f64  , Custom);
    119     setOperationAction(ISD::EXTRACT_SUBVECTOR, MVT::v2i32  , Custom);
    120     setOperationAction(ISD::EXTRACT_SUBVECTOR, MVT::v2f32  , Custom);
    121     setOperationAction(ISD::EXTRACT_SUBVECTOR, MVT::v2i16  , Custom);
    122     setOperationAction(ISD::EXTRACT_SUBVECTOR, MVT::v2i8   , Custom);
    123   }
    124 
    125   // Operations not directly supported by NVPTX.
    126   setOperationAction(ISD::SELECT_CC,         MVT::Other, Expand);
    127   setOperationAction(ISD::BR_CC,             MVT::Other, Expand);
    128   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i64, Expand);
    129   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i32, Expand);
    130   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i16, Expand);
    131   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i8 , Expand);
    132   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1 , Expand);
    133 
    134   if (nvptxSubtarget.hasROT64()) {
    135     setOperationAction(ISD::ROTL , MVT::i64, Legal);
    136     setOperationAction(ISD::ROTR , MVT::i64, Legal);
    137   }
    138   else {
    139     setOperationAction(ISD::ROTL , MVT::i64, Expand);
    140     setOperationAction(ISD::ROTR , MVT::i64, Expand);
    141   }
    142   if (nvptxSubtarget.hasROT32()) {
    143     setOperationAction(ISD::ROTL , MVT::i32, Legal);
    144     setOperationAction(ISD::ROTR , MVT::i32, Legal);
    145   }
    146   else {
    147     setOperationAction(ISD::ROTL , MVT::i32, Expand);
    148     setOperationAction(ISD::ROTR , MVT::i32, Expand);
    149   }
    150 
    151   setOperationAction(ISD::ROTL , MVT::i16, Expand);
    152   setOperationAction(ISD::ROTR , MVT::i16, Expand);
    153   setOperationAction(ISD::ROTL , MVT::i8, Expand);
    154   setOperationAction(ISD::ROTR , MVT::i8, Expand);
    155   setOperationAction(ISD::BSWAP , MVT::i16, Expand);
    156   setOperationAction(ISD::BSWAP , MVT::i32, Expand);
    157   setOperationAction(ISD::BSWAP , MVT::i64, Expand);
    158 
    159   // Indirect branch is not supported.
    160   // This also disables Jump Table creation.
    161   setOperationAction(ISD::BR_JT,             MVT::Other, Expand);
    162   setOperationAction(ISD::BRIND,             MVT::Other, Expand);
    163 
    164   setOperationAction(ISD::GlobalAddress   , MVT::i32  , Custom);
    165   setOperationAction(ISD::GlobalAddress   , MVT::i64  , Custom);
    166 
     167   // We want to legalize constant-related memmove and memcpy
     168   // intrinsics.
    169   setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
    170 
    171   // Turn FP extload into load/fextend
    172   setLoadExtAction(ISD::EXTLOAD, MVT::f32, Expand);
    173   // Turn FP truncstore into trunc + store.
    174   setTruncStoreAction(MVT::f64, MVT::f32, Expand);
    175 
     176   // PTX does not support load/store of predicate registers.
    177   setOperationAction(ISD::LOAD, MVT::i1, Expand);
    178   setLoadExtAction(ISD::SEXTLOAD, MVT::i1, Promote);
    179   setLoadExtAction(ISD::ZEXTLOAD, MVT::i1, Promote);
    180   setOperationAction(ISD::STORE, MVT::i1, Expand);
    181   setTruncStoreAction(MVT::i64, MVT::i1, Expand);
    182   setTruncStoreAction(MVT::i32, MVT::i1, Expand);
    183   setTruncStoreAction(MVT::i16, MVT::i1, Expand);
    184   setTruncStoreAction(MVT::i8, MVT::i1, Expand);
    185 
     186   // These are legal in NVPTX.
    187   setOperationAction(ISD::ConstantFP,         MVT::f64, Legal);
    188   setOperationAction(ISD::ConstantFP,         MVT::f32, Legal);
    189 
    190   // TRAP can be lowered to PTX trap
    191   setOperationAction(ISD::TRAP,               MVT::Other, Legal);
    192 
     193   // By default, CONCAT_VECTORS is implemented via store/load
     194   // through the stack. That is slow and uses local memory, so we
     195   // custom-lower it instead.
    196   setOperationAction(ISD::CONCAT_VECTORS, MVT::v4i32  , Custom);
    197   setOperationAction(ISD::CONCAT_VECTORS, MVT::v4f32  , Custom);
    198   setOperationAction(ISD::CONCAT_VECTORS, MVT::v4i16  , Custom);
    199   setOperationAction(ISD::CONCAT_VECTORS, MVT::v4i8   , Custom);
    200   setOperationAction(ISD::CONCAT_VECTORS, MVT::v2i64  , Custom);
    201   setOperationAction(ISD::CONCAT_VECTORS, MVT::v2f64  , Custom);
    202   setOperationAction(ISD::CONCAT_VECTORS, MVT::v2i32  , Custom);
    203   setOperationAction(ISD::CONCAT_VECTORS, MVT::v2f32  , Custom);
    204   setOperationAction(ISD::CONCAT_VECTORS, MVT::v2i16  , Custom);
    205   setOperationAction(ISD::CONCAT_VECTORS, MVT::v2i8   , Custom);
    206 
    207   // Expand vector int to float and float to int conversions
    208   // - For SINT_TO_FP and UINT_TO_FP, the src type
    209   //   (Node->getOperand(0).getValueType())
    210   //   is used to determine the action, while for FP_TO_UINT and FP_TO_SINT,
    211   //   the dest type (Node->getValueType(0)) is used.
    212   //
    213   //   See VectorLegalizer::LegalizeOp() (LegalizeVectorOps.cpp) for the vector
    214   //   case, and
    215   //   SelectionDAGLegalize::LegalizeOp() (LegalizeDAG.cpp) for the scalar case.
    216   //
    217   //   That is why v4i32 or v2i32 are used here.
    218   //
    219   //   The expansion for vectors happens in VectorLegalizer::LegalizeOp()
    220   //   (LegalizeVectorOps.cpp).
    221   setOperationAction(ISD::SINT_TO_FP, MVT::v4i32, Expand);
    222   setOperationAction(ISD::SINT_TO_FP, MVT::v2i32, Expand);
    223   setOperationAction(ISD::UINT_TO_FP, MVT::v4i32, Expand);
    224   setOperationAction(ISD::UINT_TO_FP, MVT::v2i32, Expand);
    225   setOperationAction(ISD::FP_TO_SINT, MVT::v2i32, Expand);
    226   setOperationAction(ISD::FP_TO_SINT, MVT::v4i32, Expand);
    227   setOperationAction(ISD::FP_TO_UINT, MVT::v2i32, Expand);
    228   setOperationAction(ISD::FP_TO_UINT, MVT::v4i32, Expand);
    229 
     230   // Now deduce the register properties based on the
     231   // above-mentioned actions.
    232   computeRegisterProperties();
    233 }
    234 
    235 
    236 const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
    237   switch (Opcode) {
    238   default: return 0;
    239   case NVPTXISD::CALL:            return "NVPTXISD::CALL";
    240   case NVPTXISD::RET_FLAG:        return "NVPTXISD::RET_FLAG";
    241   case NVPTXISD::Wrapper:         return "NVPTXISD::Wrapper";
    242   case NVPTXISD::NVBuiltin:       return "NVPTXISD::NVBuiltin";
    243   case NVPTXISD::DeclareParam:    return "NVPTXISD::DeclareParam";
    244   case NVPTXISD::DeclareScalarParam:
    245     return "NVPTXISD::DeclareScalarParam";
    246   case NVPTXISD::DeclareRet:      return "NVPTXISD::DeclareRet";
    247   case NVPTXISD::DeclareRetParam: return "NVPTXISD::DeclareRetParam";
    248   case NVPTXISD::PrintCall:       return "NVPTXISD::PrintCall";
    249   case NVPTXISD::LoadParam:       return "NVPTXISD::LoadParam";
    250   case NVPTXISD::StoreParam:      return "NVPTXISD::StoreParam";
    251   case NVPTXISD::StoreParamS32:   return "NVPTXISD::StoreParamS32";
    252   case NVPTXISD::StoreParamU32:   return "NVPTXISD::StoreParamU32";
    253   case NVPTXISD::MoveToParam:     return "NVPTXISD::MoveToParam";
    254   case NVPTXISD::CallArgBegin:    return "NVPTXISD::CallArgBegin";
    255   case NVPTXISD::CallArg:         return "NVPTXISD::CallArg";
    256   case NVPTXISD::LastCallArg:     return "NVPTXISD::LastCallArg";
    257   case NVPTXISD::CallArgEnd:      return "NVPTXISD::CallArgEnd";
    258   case NVPTXISD::CallVoid:        return "NVPTXISD::CallVoid";
    259   case NVPTXISD::CallVal:         return "NVPTXISD::CallVal";
    260   case NVPTXISD::CallSymbol:      return "NVPTXISD::CallSymbol";
    261   case NVPTXISD::Prototype:       return "NVPTXISD::Prototype";
    262   case NVPTXISD::MoveParam:       return "NVPTXISD::MoveParam";
    263   case NVPTXISD::MoveRetval:      return "NVPTXISD::MoveRetval";
    264   case NVPTXISD::MoveToRetval:    return "NVPTXISD::MoveToRetval";
    265   case NVPTXISD::StoreRetval:     return "NVPTXISD::StoreRetval";
    266   case NVPTXISD::PseudoUseParam:  return "NVPTXISD::PseudoUseParam";
    267   case NVPTXISD::RETURN:          return "NVPTXISD::RETURN";
    268   case NVPTXISD::CallSeqBegin:    return "NVPTXISD::CallSeqBegin";
    269   case NVPTXISD::CallSeqEnd:      return "NVPTXISD::CallSeqEnd";
    270   }
    271 }
    272 
    273 
    274 SDValue
    275 NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
    276   DebugLoc dl = Op.getDebugLoc();
    277   const GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal();
    278   Op = DAG.getTargetGlobalAddress(GV, dl, getPointerTy());
    279   return DAG.getNode(NVPTXISD::Wrapper, dl, getPointerTy(), Op);
    280 }
    281 
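         // getPrototype builds the .callprototype string that is emitted for
         // indirect calls.  The result has the general form
         //   prototype_<n> : .callprototype (<ret decl>) _ (<param decls>);
         // For example (values purely illustrative), a function taking and
         // returning a 32-bit scalar under the ABI would get
         //   prototype_0 : .callprototype (.param .b32 _) _ (.param .b32 _);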
    282 std::string NVPTXTargetLowering::getPrototype(Type *retTy,
    283                                               const ArgListTy &Args,
    284                                     const SmallVectorImpl<ISD::OutputArg> &Outs,
    285                                               unsigned retAlignment) const {
    286 
    287   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
    288 
    289   std::stringstream O;
    290   O << "prototype_" << uniqueCallSite << " : .callprototype ";
    291 
    292   if (retTy->getTypeID() == Type::VoidTyID)
    293     O << "()";
    294   else {
    295     O << "(";
    296     if (isABI) {
    297       if (retTy->isPrimitiveType() || retTy->isIntegerTy()) {
    298         unsigned size = 0;
    299         if (const IntegerType *ITy = dyn_cast<IntegerType>(retTy)) {
    300           size = ITy->getBitWidth();
    301           if (size < 32) size = 32;
    302         }
    303         else {
    304           assert(retTy->isFloatingPointTy() &&
    305                  "Floating point type expected here");
    306           size = retTy->getPrimitiveSizeInBits();
    307         }
    308 
    309         O << ".param .b" << size << " _";
    310       }
    311       else if (isa<PointerType>(retTy))
    312         O << ".param .b" << getPointerTy().getSizeInBits()
    313         << " _";
    314       else {
    315         if ((retTy->getTypeID() == Type::StructTyID) ||
    316             isa<VectorType>(retTy)) {
    317           SmallVector<EVT, 16> vtparts;
    318           ComputeValueVTs(*this, retTy, vtparts);
    319           unsigned totalsz = 0;
    320           for (unsigned i=0,e=vtparts.size(); i!=e; ++i) {
    321             unsigned elems = 1;
    322             EVT elemtype = vtparts[i];
    323             if (vtparts[i].isVector()) {
    324               elems = vtparts[i].getVectorNumElements();
    325               elemtype = vtparts[i].getVectorElementType();
    326             }
    327             for (unsigned j=0, je=elems; j!=je; ++j) {
    328               unsigned sz = elemtype.getSizeInBits();
    329               if (elemtype.isInteger() && (sz < 8)) sz = 8;
    330               totalsz += sz/8;
    331             }
    332           }
    333           O << ".param .align "
    334               << retAlignment
    335               << " .b8 _["
    336               << totalsz << "]";
    337         }
    338         else {
    339           assert(false &&
    340                  "Unknown return type");
    341         }
    342       }
    343     }
    344     else {
    345       SmallVector<EVT, 16> vtparts;
    346       ComputeValueVTs(*this, retTy, vtparts);
    347       unsigned idx = 0;
    348       for (unsigned i=0,e=vtparts.size(); i!=e; ++i) {
    349         unsigned elems = 1;
    350         EVT elemtype = vtparts[i];
    351         if (vtparts[i].isVector()) {
    352           elems = vtparts[i].getVectorNumElements();
    353           elemtype = vtparts[i].getVectorElementType();
    354         }
    355 
    356         for (unsigned j=0, je=elems; j!=je; ++j) {
    357           unsigned sz = elemtype.getSizeInBits();
    358           if (elemtype.isInteger() && (sz < 32)) sz = 32;
    359           O << ".reg .b" << sz << " _";
    360           if (j<je-1) O << ", ";
    361           ++idx;
    362         }
    363         if (i < e-1)
    364           O << ", ";
    365       }
    366     }
    367     O << ") ";
    368   }
    369   O << "_ (";
    370 
    371   bool first = true;
    372   MVT thePointerTy = getPointerTy();
    373 
    374   for (unsigned i=0,e=Args.size(); i!=e; ++i) {
    375     const Type *Ty = Args[i].Ty;
    376     if (!first) {
    377       O << ", ";
    378     }
    379     first = false;
    380 
    381     if (Outs[i].Flags.isByVal() == false) {
    382       unsigned sz = 0;
    383       if (isa<IntegerType>(Ty)) {
    384         sz = cast<IntegerType>(Ty)->getBitWidth();
    385         if (sz < 32) sz = 32;
    386       }
    387       else if (isa<PointerType>(Ty))
    388         sz = thePointerTy.getSizeInBits();
    389       else
    390         sz = Ty->getPrimitiveSizeInBits();
    391       if (isABI)
    392         O << ".param .b" << sz << " ";
    393       else
    394         O << ".reg .b" << sz << " ";
    395       O << "_";
    396       continue;
    397     }
    398     const PointerType *PTy = dyn_cast<PointerType>(Ty);
    399     assert(PTy &&
    400            "Param with byval attribute should be a pointer type");
    401     Type *ETy = PTy->getElementType();
    402 
    403     if (isABI) {
    404       unsigned align = Outs[i].Flags.getByValAlign();
    405       unsigned sz = getTargetData()->getTypeAllocSize(ETy);
    406       O << ".param .align " << align
    407           << " .b8 ";
    408       O << "_";
    409       O << "[" << sz << "]";
    410       continue;
    411     }
    412     else {
    413       SmallVector<EVT, 16> vtparts;
    414       ComputeValueVTs(*this, ETy, vtparts);
    415       for (unsigned i=0,e=vtparts.size(); i!=e; ++i) {
    416         unsigned elems = 1;
    417         EVT elemtype = vtparts[i];
    418         if (vtparts[i].isVector()) {
    419           elems = vtparts[i].getVectorNumElements();
    420           elemtype = vtparts[i].getVectorElementType();
    421         }
    422 
    423         for (unsigned j=0,je=elems; j!=je; ++j) {
    424           unsigned sz = elemtype.getSizeInBits();
    425           if (elemtype.isInteger() && (sz < 32)) sz = 32;
    426           O << ".reg .b" << sz << " ";
    427           O << "_";
    428           if (j<je-1) O << ", ";
    429         }
    430         if (i<e-1)
    431           O << ", ";
    432       }
    433       continue;
    434     }
    435   }
    436   O << ");";
    437   return O.str();
    438 }
    439 
    440 
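         // LowerCall emits the NVPTX-specific node sequence that the AsmPrinter
         // later prints as a PTX ABI call.  As a rough sketch (register and
         // parameter names are illustrative only), a call such as
         //   %r = call i32 @foo(i32 %x)
         // ends up printed as something like
         //   .param .b32 param0;
         //   st.param.b32 [param0+0], %rx;
         //   .param .b32 retval0;
         //   call.uni (retval0), foo, (param0);
         //   ld.param.b32 %r, [retval0+0];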
    441 SDValue
    442 NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
    443                                SmallVectorImpl<SDValue> &InVals) const {
    444   SelectionDAG &DAG                     = CLI.DAG;
    445   DebugLoc &dl                          = CLI.DL;
    446   SmallVector<ISD::OutputArg, 32> &Outs = CLI.Outs;
    447   SmallVector<SDValue, 32> &OutVals     = CLI.OutVals;
    448   SmallVector<ISD::InputArg, 32> &Ins   = CLI.Ins;
    449   SDValue Chain                         = CLI.Chain;
    450   SDValue Callee                        = CLI.Callee;
    451   bool &isTailCall                      = CLI.IsTailCall;
    452   ArgListTy &Args                       = CLI.Args;
    453   Type *retTy                           = CLI.RetTy;
    454   ImmutableCallSite *CS                 = CLI.CS;
    455 
    456   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
    457 
    458   SDValue tempChain = Chain;
    459   Chain = DAG.getCALLSEQ_START(Chain,
    460                                DAG.getIntPtrConstant(uniqueCallSite, true));
    461   SDValue InFlag = Chain.getValue(1);
    462 
    463   assert((Outs.size() == Args.size()) &&
    464          "Unexpected number of arguments to function call");
    465   unsigned paramCount = 0;
     466   // Declare the .param or .reg spaces needed to pass values
     467   // to the function.
    468   for (unsigned i=0, e=Outs.size(); i!=e; ++i) {
    469     EVT VT = Outs[i].VT;
    470 
    471     if (Outs[i].Flags.isByVal() == false) {
    472       // Plain scalar
    473       // for ABI,    declare .param .b<size> .param<n>;
    474       // for nonABI, declare .reg .b<size> .param<n>;
    475       unsigned isReg = 1;
    476       if (isABI)
    477         isReg = 0;
    478       unsigned sz = VT.getSizeInBits();
    479       if (VT.isInteger() && (sz < 32)) sz = 32;
    480       SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
    481       SDValue DeclareParamOps[] = { Chain,
    482                                     DAG.getConstant(paramCount, MVT::i32),
    483                                     DAG.getConstant(sz, MVT::i32),
    484                                     DAG.getConstant(isReg, MVT::i32),
    485                                     InFlag };
    486       Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
    487                           DeclareParamOps, 5);
    488       InFlag = Chain.getValue(1);
    489       SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
    490       SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
    491                              DAG.getConstant(0, MVT::i32), OutVals[i], InFlag };
    492 
    493       unsigned opcode = NVPTXISD::StoreParam;
    494       if (isReg)
    495         opcode = NVPTXISD::MoveToParam;
    496       else {
    497         if (Outs[i].Flags.isZExt())
    498           opcode = NVPTXISD::StoreParamU32;
    499         else if (Outs[i].Flags.isSExt())
    500           opcode = NVPTXISD::StoreParamS32;
    501       }
    502       Chain = DAG.getNode(opcode, dl, CopyParamVTs, CopyParamOps, 5);
    503 
    504       InFlag = Chain.getValue(1);
    505       ++paramCount;
    506       continue;
    507     }
    508     // struct or vector
    509     SmallVector<EVT, 16> vtparts;
    510     const PointerType *PTy = dyn_cast<PointerType>(Args[i].Ty);
    511     assert(PTy &&
    512            "Type of a byval parameter should be pointer");
    513     ComputeValueVTs(*this, PTy->getElementType(), vtparts);
    514 
    515     if (isABI) {
    516       // declare .param .align 16 .b8 .param<n>[<size>];
    517       unsigned sz = Outs[i].Flags.getByValSize();
    518       SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
     519       // The ByValAlign in Outs[i].Flags is always set at this point, so we
     520       // don't need to worry about natural alignment or not.
     521       // See TargetLowering::LowerCallTo().
    522       SDValue DeclareParamOps[] = { Chain,
    523                        DAG.getConstant(Outs[i].Flags.getByValAlign(), MVT::i32),
    524                                     DAG.getConstant(paramCount, MVT::i32),
    525                                     DAG.getConstant(sz, MVT::i32),
    526                                     InFlag };
    527       Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
    528                           DeclareParamOps, 5);
    529       InFlag = Chain.getValue(1);
    530       unsigned curOffset = 0;
    531       for (unsigned j=0,je=vtparts.size(); j!=je; ++j) {
    532         unsigned elems = 1;
    533         EVT elemtype = vtparts[j];
    534         if (vtparts[j].isVector()) {
    535           elems = vtparts[j].getVectorNumElements();
    536           elemtype = vtparts[j].getVectorElementType();
    537         }
    538         for (unsigned k=0,ke=elems; k!=ke; ++k) {
    539           unsigned sz = elemtype.getSizeInBits();
    540           if (elemtype.isInteger() && (sz < 8)) sz = 8;
    541           SDValue srcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(),
    542                                         OutVals[i],
    543                                         DAG.getConstant(curOffset,
    544                                                         getPointerTy()));
    545           SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr,
    546                                 MachinePointerInfo(), false, false, false, 0);
    547           SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
    548           SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount,
    549                                                             MVT::i32),
    550                                            DAG.getConstant(curOffset, MVT::i32),
    551                                                             theVal, InFlag };
    552           Chain = DAG.getNode(NVPTXISD::StoreParam, dl, CopyParamVTs,
    553                               CopyParamOps, 5);
    554           InFlag = Chain.getValue(1);
    555           curOffset += sz/8;
    556         }
    557       }
    558       ++paramCount;
    559       continue;
    560     }
     561     // Non-ABI, struct or vector:
     562     // declare a bunch of .reg .b<size> .param<n> registers.
    563     unsigned curOffset = 0;
    564     for (unsigned j=0,je=vtparts.size(); j!=je; ++j) {
    565       unsigned elems = 1;
    566       EVT elemtype = vtparts[j];
    567       if (vtparts[j].isVector()) {
    568         elems = vtparts[j].getVectorNumElements();
    569         elemtype = vtparts[j].getVectorElementType();
    570       }
    571       for (unsigned k=0,ke=elems; k!=ke; ++k) {
    572         unsigned sz = elemtype.getSizeInBits();
    573         if (elemtype.isInteger() && (sz < 32)) sz = 32;
    574         SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
    575         SDValue DeclareParamOps[] = { Chain, DAG.getConstant(paramCount,
    576                                                              MVT::i32),
    577                                                   DAG.getConstant(sz, MVT::i32),
    578                                                    DAG.getConstant(1, MVT::i32),
    579                                                              InFlag };
    580         Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
    581                             DeclareParamOps, 5);
    582         InFlag = Chain.getValue(1);
    583         SDValue srcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[i],
    584                                       DAG.getConstant(curOffset,
    585                                                       getPointerTy()));
    586         SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr,
    587                                   MachinePointerInfo(), false, false, false, 0);
    588         SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
    589         SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
    590                                    DAG.getConstant(0, MVT::i32), theVal,
    591                                    InFlag };
    592         Chain = DAG.getNode(NVPTXISD::MoveToParam, dl, CopyParamVTs,
    593                             CopyParamOps, 5);
    594         InFlag = Chain.getValue(1);
    595         ++paramCount;
    596       }
    597     }
    598   }
    599 
    600   GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
    601   unsigned retAlignment = 0;
    602 
    603   // Handle Result
    604   unsigned retCount = 0;
    605   if (Ins.size() > 0) {
    606     SmallVector<EVT, 16> resvtparts;
    607     ComputeValueVTs(*this, retTy, resvtparts);
    608 
     609     // Declare one .param .align 16 .b8 func_retval0[<size>] for ABI, or
     610     // individual .reg .b<size> func_retval<0..> registers for non-ABI.
    611     unsigned resultsz = 0;
    612     for (unsigned i=0,e=resvtparts.size(); i!=e; ++i) {
    613       unsigned elems = 1;
    614       EVT elemtype = resvtparts[i];
    615       if (resvtparts[i].isVector()) {
    616         elems = resvtparts[i].getVectorNumElements();
    617         elemtype = resvtparts[i].getVectorElementType();
    618       }
    619       for (unsigned j=0,je=elems; j!=je; ++j) {
    620         unsigned sz = elemtype.getSizeInBits();
    621         if (isABI == false) {
    622           if (elemtype.isInteger() && (sz < 32)) sz = 32;
    623         }
    624         else {
    625           if (elemtype.isInteger() && (sz < 8)) sz = 8;
    626         }
    627         if (isABI == false) {
    628           SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
    629           SDValue DeclareRetOps[] = { Chain, DAG.getConstant(2, MVT::i32),
    630                                       DAG.getConstant(sz, MVT::i32),
    631                                       DAG.getConstant(retCount, MVT::i32),
    632                                       InFlag };
    633           Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs,
    634                               DeclareRetOps, 5);
    635           InFlag = Chain.getValue(1);
    636           ++retCount;
    637         }
    638         resultsz += sz;
    639       }
    640     }
    641     if (isABI) {
    642       if (retTy->isPrimitiveType() || retTy->isIntegerTy() ||
    643           retTy->isPointerTy() ) {
     644         // Scalars need to be at least 32 bits wide.
    645         if (resultsz < 32)
    646           resultsz = 32;
    647         SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
    648         SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, MVT::i32),
    649                                     DAG.getConstant(resultsz, MVT::i32),
    650                                     DAG.getConstant(0, MVT::i32), InFlag };
    651         Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs,
    652                             DeclareRetOps, 5);
    653         InFlag = Chain.getValue(1);
    654       }
    655       else {
    656         if (Func) { // direct call
    657           if (!llvm::getAlign(*(CS->getCalledFunction()), 0, retAlignment))
    658             retAlignment = getTargetData()->getABITypeAlignment(retTy);
    659         } else { // indirect call
    660           const CallInst *CallI = dyn_cast<CallInst>(CS->getInstruction());
    661           if (!llvm::getAlign(*CallI, 0, retAlignment))
    662             retAlignment = getTargetData()->getABITypeAlignment(retTy);
    663         }
    664         SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
    665         SDValue DeclareRetOps[] = { Chain, DAG.getConstant(retAlignment,
    666                                                            MVT::i32),
    667                                           DAG.getConstant(resultsz/8, MVT::i32),
    668                                          DAG.getConstant(0, MVT::i32), InFlag };
    669         Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl, DeclareRetVTs,
    670                             DeclareRetOps, 5);
    671         InFlag = Chain.getValue(1);
    672       }
    673     }
    674   }
    675 
    676   if (!Func) {
     677     // This is the indirect function call case: PTX requires a prototype of
     678     // the form
     679     // proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _);
     680     // to be emitted, and the label has to be used as the last arg of the
     681     // call instruction.
     682     // The prototype is embedded in a string and passed as the operand of an
     683     // INLINEASM SDNode.
    684     SDVTList InlineAsmVTs = DAG.getVTList(MVT::Other, MVT::Glue);
    685     std::string proto_string = getPrototype(retTy, Args, Outs, retAlignment);
    686     const char *asmstr = nvTM->getManagedStrPool()->
    687         getManagedString(proto_string.c_str())->c_str();
    688     SDValue InlineAsmOps[] = { Chain,
    689                                DAG.getTargetExternalSymbol(asmstr,
    690                                                            getPointerTy()),
    691                                                            DAG.getMDNode(0),
    692                                    DAG.getTargetConstant(0, MVT::i32), InFlag };
    693     Chain = DAG.getNode(ISD::INLINEASM, dl, InlineAsmVTs, InlineAsmOps, 5);
    694     InFlag = Chain.getValue(1);
    695   }
    696   // Op to just print "call"
    697   SDVTList PrintCallVTs = DAG.getVTList(MVT::Other, MVT::Glue);
    698   SDValue PrintCallOps[] = { Chain,
    699                              DAG.getConstant(isABI ? ((Ins.size()==0) ? 0 : 1)
    700                                  : retCount, MVT::i32),
    701                                    InFlag };
    702   Chain = DAG.getNode(Func?(NVPTXISD::PrintCallUni):(NVPTXISD::PrintCall), dl,
    703       PrintCallVTs, PrintCallOps, 3);
    704   InFlag = Chain.getValue(1);
    705 
    706   // Ops to print out the function name
    707   SDVTList CallVoidVTs = DAG.getVTList(MVT::Other, MVT::Glue);
    708   SDValue CallVoidOps[] = { Chain, Callee, InFlag };
    709   Chain = DAG.getNode(NVPTXISD::CallVoid, dl, CallVoidVTs, CallVoidOps, 3);
    710   InFlag = Chain.getValue(1);
    711 
    712   // Ops to print out the param list
    713   SDVTList CallArgBeginVTs = DAG.getVTList(MVT::Other, MVT::Glue);
    714   SDValue CallArgBeginOps[] = { Chain, InFlag };
    715   Chain = DAG.getNode(NVPTXISD::CallArgBegin, dl, CallArgBeginVTs,
    716                       CallArgBeginOps, 2);
    717   InFlag = Chain.getValue(1);
    718 
    719   for (unsigned i=0, e=paramCount; i!=e; ++i) {
    720     unsigned opcode;
    721     if (i==(e-1))
    722       opcode = NVPTXISD::LastCallArg;
    723     else
    724       opcode = NVPTXISD::CallArg;
    725     SDVTList CallArgVTs = DAG.getVTList(MVT::Other, MVT::Glue);
    726     SDValue CallArgOps[] = { Chain, DAG.getConstant(1, MVT::i32),
    727                              DAG.getConstant(i, MVT::i32),
    728                              InFlag };
    729     Chain = DAG.getNode(opcode, dl, CallArgVTs, CallArgOps, 4);
    730     InFlag = Chain.getValue(1);
    731   }
    732   SDVTList CallArgEndVTs = DAG.getVTList(MVT::Other, MVT::Glue);
    733   SDValue CallArgEndOps[] = { Chain,
    734                               DAG.getConstant(Func ? 1 : 0, MVT::i32),
    735                               InFlag };
    736   Chain = DAG.getNode(NVPTXISD::CallArgEnd, dl, CallArgEndVTs, CallArgEndOps,
    737                       3);
    738   InFlag = Chain.getValue(1);
    739 
    740   if (!Func) {
    741     SDVTList PrototypeVTs = DAG.getVTList(MVT::Other, MVT::Glue);
    742     SDValue PrototypeOps[] = { Chain,
    743                                DAG.getConstant(uniqueCallSite, MVT::i32),
    744                                InFlag };
    745     Chain = DAG.getNode(NVPTXISD::Prototype, dl, PrototypeVTs, PrototypeOps, 3);
    746     InFlag = Chain.getValue(1);
    747   }
    748 
    749   // Generate loads from param memory/moves from registers for result
    750   if (Ins.size() > 0) {
    751     if (isABI) {
    752       unsigned resoffset = 0;
    753       for (unsigned i=0,e=Ins.size(); i!=e; ++i) {
    754         unsigned sz = Ins[i].VT.getSizeInBits();
    755         if (Ins[i].VT.isInteger() && (sz < 8)) sz = 8;
    756         std::vector<EVT> LoadRetVTs;
    757         LoadRetVTs.push_back(Ins[i].VT);
    758         LoadRetVTs.push_back(MVT::Other); LoadRetVTs.push_back(MVT::Glue);
    759         std::vector<SDValue> LoadRetOps;
    760         LoadRetOps.push_back(Chain);
    761         LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
    762         LoadRetOps.push_back(DAG.getConstant(resoffset, MVT::i32));
    763         LoadRetOps.push_back(InFlag);
    764         SDValue retval = DAG.getNode(NVPTXISD::LoadParam, dl, LoadRetVTs,
    765                                      &LoadRetOps[0], LoadRetOps.size());
    766         Chain = retval.getValue(1);
    767         InFlag = retval.getValue(2);
    768         InVals.push_back(retval);
    769         resoffset += sz/8;
    770       }
    771     }
    772     else {
    773       SmallVector<EVT, 16> resvtparts;
    774       ComputeValueVTs(*this, retTy, resvtparts);
    775 
    776       assert(Ins.size() == resvtparts.size() &&
    777              "Unexpected number of return values in non-ABI case");
    778       unsigned paramNum = 0;
    779       for (unsigned i=0,e=Ins.size(); i!=e; ++i) {
    780         assert(EVT(Ins[i].VT) == resvtparts[i] &&
    781                "Unexpected EVT type in non-ABI case");
    782         unsigned numelems = 1;
    783         EVT elemtype = Ins[i].VT;
    784         if (Ins[i].VT.isVector()) {
    785           numelems = Ins[i].VT.getVectorNumElements();
    786           elemtype = Ins[i].VT.getVectorElementType();
    787         }
    788         std::vector<SDValue> tempRetVals;
    789         for (unsigned j=0; j<numelems; ++j) {
    790           std::vector<EVT> MoveRetVTs;
    791           MoveRetVTs.push_back(elemtype);
    792           MoveRetVTs.push_back(MVT::Other); MoveRetVTs.push_back(MVT::Glue);
    793           std::vector<SDValue> MoveRetOps;
    794           MoveRetOps.push_back(Chain);
    795           MoveRetOps.push_back(DAG.getConstant(0, MVT::i32));
    796           MoveRetOps.push_back(DAG.getConstant(paramNum, MVT::i32));
    797           MoveRetOps.push_back(InFlag);
    798           SDValue retval = DAG.getNode(NVPTXISD::LoadParam, dl, MoveRetVTs,
    799                                        &MoveRetOps[0], MoveRetOps.size());
    800           Chain = retval.getValue(1);
    801           InFlag = retval.getValue(2);
    802           tempRetVals.push_back(retval);
    803           ++paramNum;
    804         }
    805         if (Ins[i].VT.isVector())
    806           InVals.push_back(DAG.getNode(ISD::BUILD_VECTOR, dl, Ins[i].VT,
    807                                        &tempRetVals[0], tempRetVals.size()));
    808         else
    809           InVals.push_back(tempRetVals[0]);
    810       }
    811     }
    812   }
    813   Chain = DAG.getCALLSEQ_END(Chain,
    814                              DAG.getIntPtrConstant(uniqueCallSite, true),
    815                              DAG.getIntPtrConstant(uniqueCallSite+1, true),
    816                              InFlag);
    817   uniqueCallSite++;
    818 
    819   // set isTailCall to false for now, until we figure out how to express
    820   // tail call optimization in PTX
    821   isTailCall = false;
    822   return Chain;
    823 }
    824 
     825 // By default, CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
     826 // (see LegalizeDAG.cpp). This is slow and uses local memory.
     827 // We use extract/insert/build vector, just as LegalizeOp() did in LLVM 2.5.
    828 SDValue NVPTXTargetLowering::
    829 LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
    830   SDNode *Node = Op.getNode();
    831   DebugLoc dl = Node->getDebugLoc();
    832   SmallVector<SDValue, 8> Ops;
    833   unsigned NumOperands = Node->getNumOperands();
    834   for (unsigned i=0; i < NumOperands; ++i) {
    835     SDValue SubOp = Node->getOperand(i);
    836     EVT VVT = SubOp.getNode()->getValueType(0);
    837     EVT EltVT = VVT.getVectorElementType();
    838     unsigned NumSubElem = VVT.getVectorNumElements();
    839     for (unsigned j=0; j < NumSubElem; ++j) {
    840       Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, SubOp,
    841                                 DAG.getIntPtrConstant(j)));
    842     }
    843   }
    844   return DAG.getNode(ISD::BUILD_VECTOR, dl, Node->getValueType(0),
    845                      &Ops[0], Ops.size());
    846 }
    847 
    848 SDValue NVPTXTargetLowering::
    849 LowerOperation(SDValue Op, SelectionDAG &DAG) const {
    850   switch (Op.getOpcode()) {
    851   case ISD::RETURNADDR: return SDValue();
    852   case ISD::FRAMEADDR:  return SDValue();
    853   case ISD::GlobalAddress:      return LowerGlobalAddress(Op, DAG);
    854   case ISD::INTRINSIC_W_CHAIN: return Op;
    855   case ISD::BUILD_VECTOR:
    856   case ISD::EXTRACT_SUBVECTOR:
    857     return Op;
    858   case ISD::CONCAT_VECTORS: return LowerCONCAT_VECTORS(Op, DAG);
    859   default:
    860     llvm_unreachable("Custom lowering not defined for operation");
    861   }
    862 }
    863 
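         // getExtSymb returns a TargetExternalSymbol named <inname><idx>; for
         // example, getExtSymb(DAG, ".PARAM", 3, MVT::i32) yields the symbol
         // ".PARAM3".  The name is stored in the target machine's managed string
         // pool so that the returned C string remains valid.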
    864 SDValue
    865 NVPTXTargetLowering::getExtSymb(SelectionDAG &DAG, const char *inname, int idx,
    866                                 EVT v) const {
    867   std::string *name = nvTM->getManagedStrPool()->getManagedString(inname);
    868   std::stringstream suffix;
    869   suffix << idx;
    870   *name += suffix.str();
    871   return DAG.getTargetExternalSymbol(name->c_str(), v);
    872 }
    873 
    874 SDValue
    875 NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx, EVT v) const {
    876   return getExtSymb(DAG, ".PARAM", idx, v);
    877 }
    878 
    879 SDValue
    880 NVPTXTargetLowering::getParamHelpSymbol(SelectionDAG &DAG, int idx) {
    881   return getExtSymb(DAG, ".HLPPARAM", idx);
    882 }
    883 
    884 // Check to see if the kernel argument is image*_t or sampler_t
    885 
    886 bool llvm::isImageOrSamplerVal(const Value *arg, const Module *context) {
    887   static const char *const specialTypes[] = {
    888                                              "struct._image2d_t",
    889                                              "struct._image3d_t",
    890                                              "struct._sampler_t"
    891   };
    892 
    893   const Type *Ty = arg->getType();
    894   const PointerType *PTy = dyn_cast<PointerType>(Ty);
    895 
    896   if (!PTy)
    897     return false;
    898 
    899   if (!context)
    900     return false;
    901 
    902   const StructType *STy = dyn_cast<StructType>(PTy->getElementType());
    903   const std::string TypeName = STy ? STy->getName() : "";
    904 
    905   for (int i = 0, e = array_lengthof(specialTypes); i != e; ++i)
    906     if (TypeName == specialTypes[i])
    907       return true;
    908 
    909   return false;
    910 }
    911 
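         // Lower the incoming formal arguments.  For kernels and ABI device
         // functions, scalar arguments are loaded from their .param symbols and
         // byval aggregates are referenced through a MoveParam of the param
         // symbol; in the non-ABI case, param symbols are moved into registers
         // (with byval aggregates copied piecewise into a local stack object).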
    912 SDValue
    913 NVPTXTargetLowering::LowerFormalArguments(SDValue Chain,
    914                                         CallingConv::ID CallConv, bool isVarArg,
    915                                       const SmallVectorImpl<ISD::InputArg> &Ins,
    916                                           DebugLoc dl, SelectionDAG &DAG,
    917                                        SmallVectorImpl<SDValue> &InVals) const {
    918   MachineFunction &MF = DAG.getMachineFunction();
    919   const TargetData *TD = getTargetData();
    920 
    921   const Function *F = MF.getFunction();
    922   const AttrListPtr &PAL = F->getAttributes();
    923 
    924   SDValue Root = DAG.getRoot();
    925   std::vector<SDValue> OutChains;
    926 
    927   bool isKernel = llvm::isKernelFunction(*F);
    928   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
    929 
    930   std::vector<Type *> argTypes;
    931   std::vector<const Argument *> theArgs;
    932   for (Function::const_arg_iterator I = F->arg_begin(), E = F->arg_end();
    933       I != E; ++I) {
    934     theArgs.push_back(I);
    935     argTypes.push_back(I->getType());
    936   }
    937   assert(argTypes.size() == Ins.size() &&
    938          "Ins types and function types did not match");
    939 
    940   int idx = 0;
    941   for (unsigned i=0, e=Ins.size(); i!=e; ++i, ++idx) {
    942     Type *Ty = argTypes[i];
    943     EVT ObjectVT = getValueType(Ty);
    944     assert(ObjectVT == Ins[i].VT &&
    945            "Ins type did not match function type");
    946 
     947     // If the kernel argument is image*_t or sampler_t, convert it to
     948     // an i32 constant holding the parameter position. This can later be
     949     // matched in the AsmPrinter to output the correct mangled name.
    950     if (isImageOrSamplerVal(theArgs[i],
    951                            (theArgs[i]->getParent() ?
    952                                theArgs[i]->getParent()->getParent() : 0))) {
    953       assert(isKernel && "Only kernels can have image/sampler params");
    954       InVals.push_back(DAG.getConstant(i+1, MVT::i32));
    955       continue;
    956     }
    957 
    958     if (theArgs[i]->use_empty()) {
    959       // argument is dead
    960       InVals.push_back(DAG.getNode(ISD::UNDEF, dl, ObjectVT));
    961       continue;
    962     }
    963 
     964     // In the following cases, assign a node order of "idx+1"
     965     // to newly created nodes. The SDNodes for params have to
     966     // appear in the same order as they appear in the original
     967     // function. "idx+1" holds that order.
    968     if (PAL.paramHasAttr(i+1, Attribute::ByVal) == false) {
    969       // A plain scalar.
    970       if (isABI || isKernel) {
    971         // If ABI, load from the param symbol
    972         SDValue Arg = getParamSymbol(DAG, idx);
    973         Value *srcValue = new Argument(PointerType::get(ObjectVT.getTypeForEVT(
    974             F->getContext()),
    975             llvm::ADDRESS_SPACE_PARAM));
    976         SDValue p = DAG.getLoad(ObjectVT, dl, Root, Arg,
    977                                 MachinePointerInfo(srcValue), false, false,
    978                                 false,
    979                                 TD->getABITypeAlignment(ObjectVT.getTypeForEVT(
    980                                   F->getContext())));
    981         if (p.getNode())
    982           DAG.AssignOrdering(p.getNode(), idx+1);
    983         InVals.push_back(p);
    984       }
    985       else {
    986         // If no ABI, just move the param symbol
    987         SDValue Arg = getParamSymbol(DAG, idx, ObjectVT);
    988         SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
    989         if (p.getNode())
    990           DAG.AssignOrdering(p.getNode(), idx+1);
    991         InVals.push_back(p);
    992       }
    993       continue;
    994     }
    995 
    996     // Param has ByVal attribute
    997     if (isABI || isKernel) {
    998       // Return MoveParam(param symbol).
     999       // Ideally, the param symbol could be returned directly,
    1000       // but when the SDNode builder decides to use it in a CopyToReg(),
    1001       // machine-instruction selection fails because a TargetExternalSymbol
    1002       // (which is not lowered) is target dependent, and CopyToReg assumes
    1003       // its source has already been lowered.
   1004       SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
   1005       SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
   1006       if (p.getNode())
   1007         DAG.AssignOrdering(p.getNode(), idx+1);
   1008       if (isKernel)
   1009         InVals.push_back(p);
   1010       else {
   1011         SDValue p2 = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, ObjectVT,
   1012                     DAG.getConstant(Intrinsic::nvvm_ptr_local_to_gen, MVT::i32),
   1013                                  p);
   1014         InVals.push_back(p2);
   1015       }
   1016     } else {
    1017       // We have to move a set of param symbols to registers, store
    1018       // them locally, and return the local pointer in InVals.
   1019       const PointerType *elemPtrType = dyn_cast<PointerType>(argTypes[i]);
   1020       assert(elemPtrType &&
   1021              "Byval parameter should be a pointer type");
   1022       Type *elemType = elemPtrType->getElementType();
   1023       // Compute the constituent parts
   1024       SmallVector<EVT, 16> vtparts;
   1025       SmallVector<uint64_t, 16> offsets;
   1026       ComputeValueVTs(*this, elemType, vtparts, &offsets, 0);
   1027       unsigned totalsize = 0;
   1028       for (unsigned j=0, je=vtparts.size(); j!=je; ++j)
   1029         totalsize += vtparts[j].getStoreSizeInBits();
   1030       SDValue localcopy =  DAG.getFrameIndex(MF.getFrameInfo()->
   1031                                       CreateStackObject(totalsize/8, 16, false),
   1032                                              getPointerTy());
   1033       unsigned sizesofar = 0;
   1034       std::vector<SDValue> theChains;
   1035       for (unsigned j=0, je=vtparts.size(); j!=je; ++j) {
   1036         unsigned numElems = 1;
   1037         if (vtparts[j].isVector()) numElems = vtparts[j].getVectorNumElements();
   1038         for (unsigned k=0, ke=numElems; k!=ke; ++k) {
   1039           EVT tmpvt = vtparts[j];
   1040           if (tmpvt.isVector()) tmpvt = tmpvt.getVectorElementType();
   1041           SDValue arg = DAG.getNode(NVPTXISD::MoveParam, dl, tmpvt,
   1042                                     getParamSymbol(DAG, idx, tmpvt));
   1043           SDValue addr = DAG.getNode(ISD::ADD, dl, getPointerTy(), localcopy,
   1044                                     DAG.getConstant(sizesofar, getPointerTy()));
   1045           theChains.push_back(DAG.getStore(Chain, dl, arg, addr,
   1046                                         MachinePointerInfo(), false, false, 0));
   1047           sizesofar += tmpvt.getStoreSizeInBits()/8;
   1048           ++idx;
   1049         }
   1050       }
   1051       --idx;
   1052       Chain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, &theChains[0],
   1053                           theChains.size());
   1054       InVals.push_back(localcopy);
   1055     }
   1056   }
   1057 
    1058   // Clang will check explicit varargs and issue an error if any are used.
    1059   // However, Clang will let code with an implicit vararg declaration like
    1060   // f() pass.
    1061   // We treat this case as if the argument list is empty.
   1062   //if (F.isVarArg()) {
   1063   // assert(0 && "VarArg not supported yet!");
   1064   //}
   1065 
   1066   if (!OutChains.empty())
   1067     DAG.setRoot(DAG.getNode(ISD::TokenFactor, dl, MVT::Other,
   1068                             &OutChains[0], OutChains.size()));
   1069 
   1070   return Chain;
   1071 }
   1072 
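         // Lower the return value(s).  Vectors are split into their elements, and
         // each piece is written back with StoreRetval at its byte offset under
         // the ABI, or with MoveToRetval into return register <idx> otherwise;
         // the function then ends with a RET_FLAG node.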
   1073 SDValue
   1074 NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
   1075                                  bool isVarArg,
   1076                                  const SmallVectorImpl<ISD::OutputArg> &Outs,
   1077                                  const SmallVectorImpl<SDValue> &OutVals,
   1078                                  DebugLoc dl, SelectionDAG &DAG) const {
   1079 
   1080   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
   1081 
   1082   unsigned sizesofar = 0;
   1083   unsigned idx = 0;
   1084   for (unsigned i=0, e=Outs.size(); i!=e; ++i) {
   1085     SDValue theVal = OutVals[i];
   1086     EVT theValType = theVal.getValueType();
   1087     unsigned numElems = 1;
   1088     if (theValType.isVector()) numElems = theValType.getVectorNumElements();
   1089     for (unsigned j=0,je=numElems; j!=je; ++j) {
   1090       SDValue tmpval = theVal;
   1091       if (theValType.isVector())
   1092         tmpval = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
   1093                              theValType.getVectorElementType(),
   1094                              tmpval, DAG.getIntPtrConstant(j));
   1095       Chain = DAG.getNode(isABI ? NVPTXISD::StoreRetval :NVPTXISD::MoveToRetval,
   1096           dl, MVT::Other,
   1097           Chain,
   1098           DAG.getConstant(isABI ? sizesofar : idx, MVT::i32),
   1099           tmpval);
   1100       if (theValType.isVector())
   1101         sizesofar += theValType.getVectorElementType().getStoreSizeInBits()/8;
   1102       else
   1103         sizesofar += theValType.getStoreSizeInBits()/8;
   1104       ++idx;
   1105     }
   1106   }
   1107 
   1108   return DAG.getNode(NVPTXISD::RET_FLAG, dl, MVT::Other, Chain);
   1109 }
   1110 
   1111 void
   1112 NVPTXTargetLowering::LowerAsmOperandForConstraint(SDValue Op,
   1113                                                   std::string &Constraint,
   1114                                                   std::vector<SDValue> &Ops,
   1115                                                   SelectionDAG &DAG) const
   1116 {
   1117   if (Constraint.length() > 1)
   1118     return;
   1119   else
   1120     TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
   1121 }
   1122 
    1123 // NVPTX supports vectors of legal types of any length in intrinsics, because
    1124 // the NVPTX-specific type legalizer
    1125 // will legalize them to a PTX-supported length.
   1126 bool
   1127 NVPTXTargetLowering::isTypeSupportedInIntrinsic(MVT VT) const {
   1128   if (isTypeLegal(VT))
   1129     return true;
   1130   if (VT.isVector()) {
   1131     MVT eVT = VT.getVectorElementType();
   1132     if (isTypeLegal(eVT))
   1133       return true;
   1134   }
   1135   return false;
   1136 }
   1137 
   1138 
    1139 // llvm.ptx.memcpy.const and llvm.ptx.memmove.const need to be modeled as
    1140 // TgtMemIntrinsic because we need information that is only available in
    1141 // the "Value" type of the destination pointer; in particular, the address
    1142 // space information.
   1144 bool
   1145 NVPTXTargetLowering::getTgtMemIntrinsic(IntrinsicInfo& Info, const CallInst &I,
   1146                                         unsigned Intrinsic) const {
   1147   switch (Intrinsic) {
   1148   default:
   1149     return false;
   1150 
   1151   case Intrinsic::nvvm_atomic_load_add_f32:
   1152     Info.opc = ISD::INTRINSIC_W_CHAIN;
   1153     Info.memVT = MVT::f32;
   1154     Info.ptrVal = I.getArgOperand(0);
   1155     Info.offset = 0;
   1156     Info.vol = 0;
   1157     Info.readMem = true;
   1158     Info.writeMem = true;
   1159     Info.align = 0;
   1160     return true;
   1161 
   1162   case Intrinsic::nvvm_atomic_load_inc_32:
   1163   case Intrinsic::nvvm_atomic_load_dec_32:
   1164     Info.opc = ISD::INTRINSIC_W_CHAIN;
   1165     Info.memVT = MVT::i32;
   1166     Info.ptrVal = I.getArgOperand(0);
   1167     Info.offset = 0;
   1168     Info.vol = 0;
   1169     Info.readMem = true;
   1170     Info.writeMem = true;
   1171     Info.align = 0;
   1172     return true;
   1173 
   1174   case Intrinsic::nvvm_ldu_global_i:
   1175   case Intrinsic::nvvm_ldu_global_f:
   1176   case Intrinsic::nvvm_ldu_global_p:
   1177 
   1178     Info.opc = ISD::INTRINSIC_W_CHAIN;
   1179     if (Intrinsic == Intrinsic::nvvm_ldu_global_i)
   1180       Info.memVT = MVT::i32;
   1181     else if (Intrinsic == Intrinsic::nvvm_ldu_global_p)
   1182       Info.memVT = getPointerTy();
   1183     else
   1184       Info.memVT = MVT::f32;
   1185     Info.ptrVal = I.getArgOperand(0);
   1186     Info.offset = 0;
   1187     Info.vol = 0;
   1188     Info.readMem = true;
   1189     Info.writeMem = false;
   1190     Info.align = 0;
   1191     return true;
   1192 
   1193   }
   1194   return false;
   1195 }
   1196 
   1197 /// isLegalAddressingMode - Return true if the addressing mode represented
   1198 /// by AM is legal for this target, for a load/store of the specified type.
   1199 /// Used to guide target specific optimizations, like loop strength reduction
   1200 /// (LoopStrengthReduce.cpp) and memory optimization for address mode
   1201 /// (CodeGenPrepare.cpp)
   1202 bool
   1203 NVPTXTargetLowering::isLegalAddressingMode(const AddrMode &AM,
   1204                                            Type *Ty) const {
   1205 
   1206   // AddrMode - This represents an addressing mode of:
   1207   //    BaseGV + BaseOffs + BaseReg + Scale*ScaleReg
   1208   //
   1209   // The legal address modes are
   1210   // - [avar]
   1211   // - [areg]
   1212   // - [areg+immoff]
   1213   // - [immAddr]
   1214 
   1215   if (AM.BaseGV) {
   1216     if (AM.BaseOffs || AM.HasBaseReg || AM.Scale)
   1217       return false;
   1218     return true;
   1219   }
   1220 
   1221   switch (AM.Scale) {
   1222   case 0:  // "r", "r+i" or "i" is allowed
   1223     break;
   1224   case 1:
   1225     if (AM.HasBaseReg)  // "r+r+i" or "r+r" is not allowed.
   1226       return false;
   1227     // Otherwise we have r+i.
   1228     break;
   1229   default:
   1230     // No scale > 1 is allowed
   1231     return false;
   1232   }
   1233   return true;
   1234 }
   1235 
   1236 //===----------------------------------------------------------------------===//
   1237 //                         NVPTX Inline Assembly Support
   1238 //===----------------------------------------------------------------------===//
   1239 
   1240 /// getConstraintType - Given a constraint letter, return the type of
   1241 /// constraint it is for this target.
   1242 NVPTXTargetLowering::ConstraintType
   1243 NVPTXTargetLowering::getConstraintType(const std::string &Constraint) const {
   1244   if (Constraint.size() == 1) {
   1245     switch (Constraint[0]) {
   1246     default:
   1247       break;
   1248     case 'r':
   1249     case 'h':
   1250     case 'c':
   1251     case 'l':
   1252     case 'f':
   1253     case 'd':
   1254     case '0':
   1255     case 'N':
   1256       return C_RegisterClass;
   1257     }
   1258   }
   1259   return TargetLowering::getConstraintType(Constraint);
   1260 }
   1261 
   1262 
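         // Map the single-letter PTX inline-asm constraints to register classes:
         // 'c' -> Int8Regs, 'h' -> Int16Regs, 'r' -> Int32Regs, 'l'/'N' ->
         // Int64Regs, 'f' -> Float32Regs, 'd' -> Float64Regs.  Anything else
         // falls back to the generic TargetLowering handling.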
   1263 std::pair<unsigned, const TargetRegisterClass*>
   1264 NVPTXTargetLowering::getRegForInlineAsmConstraint(const std::string &Constraint,
   1265                                                   EVT VT) const {
   1266   if (Constraint.size() == 1) {
   1267     switch (Constraint[0]) {
   1268     case 'c':
   1269       return std::make_pair(0U, &NVPTX::Int8RegsRegClass);
   1270     case 'h':
   1271       return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
   1272     case 'r':
   1273       return std::make_pair(0U, &NVPTX::Int32RegsRegClass);
   1274     case 'l':
   1275     case 'N':
   1276       return std::make_pair(0U, &NVPTX::Int64RegsRegClass);
   1277     case 'f':
   1278       return std::make_pair(0U, &NVPTX::Float32RegsRegClass);
   1279     case 'd':
   1280       return std::make_pair(0U, &NVPTX::Float64RegsRegClass);
   1281     }
   1282   }
   1283   return TargetLowering::getRegForInlineAsmConstraint(Constraint, VT);
   1284 }
   1285 
   1286 
   1287 
   1288 /// getFunctionAlignment - Return the Log2 alignment of this function.
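         /// NVPTX uses a fixed value of 4 here, i.e. functions are aligned to
         /// 16 bytes.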
   1289 unsigned NVPTXTargetLowering::getFunctionAlignment(const Function *) const {
   1290   return 4;
   1291 }
   1292