Home | History | Annotate | Download | only in PTX
      1 //===-- PTXAsmPrinter.cpp - PTX LLVM assembly writer ----------------------===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 // This file contains a printer that converts from our internal representation
     11 // of machine-dependent LLVM code to PTX assembly language.
     12 //
     13 //===----------------------------------------------------------------------===//
     14 
     15 #define DEBUG_TYPE "ptx-asm-printer"
     16 
#include "PTXAsmPrinter.h"
#include "PTX.h"
#include "PTXMachineFunctionInfo.h"
#include "PTXParamManager.h"
#include "PTXRegisterInfo.h"
#include "PTXTargetMachine.h"
#include "llvm/Argument.h"
#include "llvm/DerivedTypes.h"
#include "llvm/Function.h"
#include "llvm/Module.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/DebugInfo.h"
#include "llvm/CodeGen/AsmPrinter.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/MC/MCContext.h"
#include "llvm/MC/MCExpr.h"
#include "llvm/MC/MCInst.h"
#include "llvm/MC/MCStreamer.h"
#include "llvm/MC/MCSymbol.h"
#include "llvm/Target/Mangler.h"
#include "llvm/Target/TargetLoweringObjectFile.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/raw_ostream.h"

#include <vector>
     48 
     49 using namespace llvm;
     50 
     51 static const char PARAM_PREFIX[] = "__param_";
     52 static const char RETURN_PREFIX[] = "__ret_";
     53 
     54 static const char *getRegisterTypeName(unsigned RegType) {
     55   switch (RegType) {
     56   default:
     57     llvm_unreachable("Unknown register type");
     58   case PTXRegisterType::Pred:
     59     return ".pred";
     60   case PTXRegisterType::B16:
     61     return ".b16";
     62   case PTXRegisterType::B32:
     63     return ".b32";
     64   case PTXRegisterType::B64:
     65     return ".b64";
     66   case PTXRegisterType::F32:
     67     return ".f32";
     68   case PTXRegisterType::F64:
     69     return ".f64";
     70   }
     71 }
     72 
     73 static const char *getStateSpaceName(unsigned addressSpace) {
     74   switch (addressSpace) {
     75   default: llvm_unreachable("Unknown state space");
     76   case PTXStateSpace::Global:    return "global";
     77   case PTXStateSpace::Constant:  return "const";
     78   case PTXStateSpace::Local:     return "local";
     79   case PTXStateSpace::Parameter: return "param";
     80   case PTXStateSpace::Shared:    return "shared";
     81   }
     82 }
     83 
     84 static const char *getTypeName(Type* type) {
     85   while (true) {
     86     switch (type->getTypeID()) {
     87       default: llvm_unreachable("Unknown type");
     88       case Type::FloatTyID: return ".f32";
     89       case Type::DoubleTyID: return ".f64";
     90       case Type::IntegerTyID:
     91         switch (type->getPrimitiveSizeInBits()) {
     92           default: llvm_unreachable("Unknown integer bit-width");
     93           case 16: return ".u16";
     94           case 32: return ".u32";
     95           case 64: return ".u64";
     96         }
     97       case Type::ArrayTyID:
     98       case Type::PointerTyID:
     99         type = dyn_cast<SequentialType>(type)->getElementType();
    100         break;
    101     }
    102   }
    103   return NULL;
    104 }
    105 
    106 bool PTXAsmPrinter::doFinalization(Module &M) {
    107   // XXX Temproarily remove global variables so that doFinalization() will not
    108   // emit them again (global variables are emitted at beginning).
    109 
    110   Module::GlobalListType &global_list = M.getGlobalList();
    111   int i, n = global_list.size();
    112   GlobalVariable **gv_array = new GlobalVariable* [n];
    113 
    114   // first, back-up GlobalVariable in gv_array
    115   i = 0;
    116   for (Module::global_iterator I = global_list.begin(), E = global_list.end();
    117        I != E; ++I)
    118     gv_array[i++] = &*I;
    119 
    120   // second, empty global_list
    121   while (!global_list.empty())
    122     global_list.remove(global_list.begin());
    123 
    124   // call doFinalization
    125   bool ret = AsmPrinter::doFinalization(M);
    126 
    127   // now we restore global variables
    128   for (i = 0; i < n; i ++)
    129     global_list.insert(global_list.end(), gv_array[i]);
    130 
    131   delete[] gv_array;
    132   return ret;
    133 }
    134 
    135 void PTXAsmPrinter::EmitStartOfAsmFile(Module &M)
    136 {
    137   const PTXSubtarget& ST = TM.getSubtarget<PTXSubtarget>();
    138 
    139   // Emit the PTX .version and .target attributes
    140   OutStreamer.EmitRawText(Twine("\t.version ") + ST.getPTXVersionString());
    141   OutStreamer.EmitRawText(Twine("\t.target ") + ST.getTargetString() +
    142                                 (ST.supportsDouble() ? ""
    143                                                      : ", map_f64_to_f32"));
    144   // .address_size directive is optional, but it must immediately follow
    145   // the .target directive if present within a module
    146   if (ST.supportsPTX23()) {
    147     const char *addrSize = ST.is64Bit() ? "64" : "32";
    148     OutStreamer.EmitRawText(Twine("\t.address_size ") + addrSize);
    149   }
    150 
    151   OutStreamer.AddBlankLine();
    152 
    153   // Define any .file directives
    154   DebugInfoFinder DbgFinder;
    155   DbgFinder.processModule(M);
    156 
    157   for (DebugInfoFinder::iterator I = DbgFinder.compile_unit_begin(),
    158        E = DbgFinder.compile_unit_end(); I != E; ++I) {
    159     DICompileUnit DIUnit(*I);
    160     StringRef FN = DIUnit.getFilename();
    161     StringRef Dir = DIUnit.getDirectory();
    162     GetOrCreateSourceID(FN, Dir);
    163   }
    164 
    165   OutStreamer.AddBlankLine();
    166 
    167   // declare external functions
    168   for (Module::const_iterator i = M.begin(), e = M.end();
    169        i != e; ++i)
    170     EmitFunctionDeclaration(i);
    171 
    172   // declare global variables
    173   for (Module::const_global_iterator i = M.global_begin(), e = M.global_end();
    174        i != e; ++i)
    175     EmitVariableDeclaration(i);
    176 }
    177 
    178 void PTXAsmPrinter::EmitFunctionBodyStart() {
    179   OutStreamer.EmitRawText(Twine("{"));
    180 
    181   const PTXMachineFunctionInfo *MFI = MF->getInfo<PTXMachineFunctionInfo>();
    182   const PTXParamManager &PM = MFI->getParamManager();
    183 
    184   // Print register definitions
    185   SmallString<128> regDefs;
    186   raw_svector_ostream os(regDefs);
    187   unsigned numRegs;
    188 
    189   // pred
    190   numRegs = MFI->countRegisters(PTXRegisterType::Pred, PTXRegisterSpace::Reg);
    191   if(numRegs > 0)
    192     os << "\t.reg .pred %p<" << numRegs << ">;\n";
    193 
    194   // i16
    195   numRegs = MFI->countRegisters(PTXRegisterType::B16, PTXRegisterSpace::Reg);
    196   if(numRegs > 0)
    197     os << "\t.reg .b16 %rh<" << numRegs << ">;\n";
    198 
    199   // i32
    200   numRegs = MFI->countRegisters(PTXRegisterType::B32, PTXRegisterSpace::Reg);
    201   if(numRegs > 0)
    202     os << "\t.reg .b32 %r<" << numRegs << ">;\n";
    203 
    204   // i64
    205   numRegs = MFI->countRegisters(PTXRegisterType::B64, PTXRegisterSpace::Reg);
    206   if(numRegs > 0)
    207     os << "\t.reg .b64 %rd<" << numRegs << ">;\n";
    208 
    209   // f32
    210   numRegs = MFI->countRegisters(PTXRegisterType::F32, PTXRegisterSpace::Reg);
    211   if(numRegs > 0)
    212     os << "\t.reg .f32 %f<" << numRegs << ">;\n";
    213 
    214   // f64
    215   numRegs = MFI->countRegisters(PTXRegisterType::F64, PTXRegisterSpace::Reg);
    216   if(numRegs > 0)
    217     os << "\t.reg .f64 %fd<" << numRegs << ">;\n";
    218 
    219   // Local params
    220   for (PTXParamManager::param_iterator i = PM.local_begin(), e = PM.local_end();
    221        i != e; ++i)
    222     os << "\t.param .b" << PM.getParamSize(*i) << ' ' << PM.getParamName(*i)
    223        << ";\n";
    224 
    225   OutStreamer.EmitRawText(os.str());
    226 
    227 
    228   const MachineFrameInfo* FrameInfo = MF->getFrameInfo();
    229   DEBUG(dbgs() << "Have " << FrameInfo->getNumObjects()
    230                << " frame object(s)\n");
    231   for (unsigned i = 0, e = FrameInfo->getNumObjects(); i != e; ++i) {
    232     DEBUG(dbgs() << "Size of object: " << FrameInfo->getObjectSize(i) << "\n");
    233     if (FrameInfo->getObjectSize(i) > 0) {
    234       OutStreamer.EmitRawText("\t.local .align " +
    235                               Twine(FrameInfo->getObjectAlignment(i)) +
    236                               " .b8 __local" +
    237                               Twine(i) +
    238                               "[" +
    239                               Twine(FrameInfo->getObjectSize(i)) +
    240                               "];");
    241     }
    242   }
    243 
    244   //unsigned Index = 1;
    245   // Print parameter passing params
    246   //for (PTXMachineFunctionInfo::param_iterator
    247   //     i = MFI->paramBegin(), e = MFI->paramEnd(); i != e; ++i) {
    248   //  std::string def = "\t.param .b";
    249   //  def += utostr(*i);
    250   //  def += " __ret_";
    251   //  def += utostr(Index);
    252   //  Index++;
    253   //  def += ";";
    254   //  OutStreamer.EmitRawText(Twine(def));
    255   //}
    256 }
    257 
    258 void PTXAsmPrinter::EmitFunctionBodyEnd() {
    259   OutStreamer.EmitRawText(Twine("}"));
    260 }
    261 
    262 void PTXAsmPrinter::EmitInstruction(const MachineInstr *MI) {
    263   MCInst TmpInst;
    264   LowerPTXMachineInstrToMCInst(MI, TmpInst, *this);
    265   OutStreamer.EmitInstruction(TmpInst);
    266 }
    267 
    268 void PTXAsmPrinter::EmitVariableDeclaration(const GlobalVariable *gv) {
    269   // Check to see if this is a special global used by LLVM, if so, emit it.
    270   if (EmitSpecialLLVMGlobal(gv))
    271     return;
    272 
    273   MCSymbol *gvsym = Mang->getSymbol(gv);
    274 
    275   assert(gvsym->isUndefined() && "Cannot define a symbol twice!");
    276 
    277   SmallString<128> decl;
    278   raw_svector_ostream os(decl);
    279 
    280   // check if it is defined in some other translation unit
    281   if (gv->isDeclaration())
    282     os << ".extern ";
    283 
    284   // state space: e.g., .global
    285   os << '.' << getStateSpaceName(gv->getType()->getAddressSpace()) << ' ';
    286 
    287   // alignment (optional)
    288   unsigned alignment = gv->getAlignment();
    289   if (alignment != 0)
    290     os << ".align " << gv->getAlignment() << ' ';
    291 
    292 
    293   if (PointerType::classof(gv->getType())) {
    294     PointerType* pointerTy = dyn_cast<PointerType>(gv->getType());
    295     Type* elementTy = pointerTy->getElementType();
    296 
    297     if (elementTy->isArrayTy()) {
    298       assert(elementTy->isArrayTy() && "Only pointers to arrays are supported");
    299 
    300       ArrayType* arrayTy = dyn_cast<ArrayType>(elementTy);
    301       elementTy = arrayTy->getElementType();
    302 
    303       unsigned numElements = arrayTy->getNumElements();
    304 
    305       while (elementTy->isArrayTy()) {
    306         arrayTy = dyn_cast<ArrayType>(elementTy);
    307         elementTy = arrayTy->getElementType();
    308 
    309         numElements *= arrayTy->getNumElements();
    310       }
    311 
    312       // FIXME: isPrimitiveType() == false for i16?
    313       assert(elementTy->isSingleValueType() &&
    314              "Non-primitive types are not handled");
    315 
    316       // Find the size of the element in bits
    317       unsigned elementSize = elementTy->getPrimitiveSizeInBits();
    318 
    319       os << ".b" << elementSize << ' ' << gvsym->getName()
    320          << '[' << numElements << ']';
    321     } else {
    322       os << ".b8" << gvsym->getName() << "[]";
    323     }
    324 
    325     // handle string constants (assume ConstantArray means string)
    326     if (gv->hasInitializer()) {
    327       const Constant *C = gv->getInitializer();
    328       if (const ConstantArray *CA = dyn_cast<ConstantArray>(C)) {
    329         os << " = {";
    330 
    331         for (unsigned i = 0, e = C->getNumOperands(); i != e; ++i) {
    332           if (i > 0)
    333             os << ',';
    334 
    335           os << "0x";
    336           os.write_hex(cast<ConstantInt>(CA->getOperand(i))->getZExtValue());
    337         }
    338 
    339         os << '}';
    340       }
    341     }
    342   } else {
    343     // Note: this is currently the fall-through case and most likely generates
    344     //       incorrect code.
    345     os << getTypeName(gv->getType()) << ' ' << gvsym->getName();
    346 
    347     if (isa<ArrayType>(gv->getType()) || isa<PointerType>(gv->getType()))
    348       os << "[]";
    349   }
    350 
    351   os << ';';
    352 
    353   OutStreamer.EmitRawText(os.str());
    354   OutStreamer.AddBlankLine();
    355 }
    356 
    357 void PTXAsmPrinter::EmitFunctionEntryLabel() {
    358   // The function label could have already been emitted if two symbols end up
    359   // conflicting due to asm renaming.  Detect this and emit an error.
    360   if (!CurrentFnSym->isUndefined())
    361     report_fatal_error("'" + Twine(CurrentFnSym->getName()) +
    362                        "' label emitted multiple times to assembly file");
    363 
    364   const PTXMachineFunctionInfo *MFI = MF->getInfo<PTXMachineFunctionInfo>();
    365   const PTXParamManager &PM = MFI->getParamManager();
    366   const bool isKernel = MFI->isKernel();
    367   const PTXSubtarget& ST = TM.getSubtarget<PTXSubtarget>();
    368 
    369   SmallString<128> decl;
    370   raw_svector_ostream os(decl);
    371   os << (isKernel ? ".entry" : ".func");
    372 
    373   if (!isKernel) {
    374     os << " (";
    375     if (ST.useParamSpaceForDeviceArgs()) {
    376       for (PTXParamManager::param_iterator i = PM.ret_begin(), e = PM.ret_end(),
    377            b = i; i != e; ++i) {
    378         if (i != b)
    379           os << ", ";
    380 
    381         os << ".param .b" << PM.getParamSize(*i) << ' ' << PM.getParamName(*i);
    382       }
    383     } else {
    384       for (PTXMachineFunctionInfo::reg_iterator
    385            i = MFI->retreg_begin(), e = MFI->retreg_end(), b = i;
    386            i != e; ++i) {
    387         if (i != b)
    388           os << ", ";
    389 
    390         os << ".reg " << getRegisterTypeName(MFI->getRegisterType(*i)) << ' '
    391            << MFI->getRegisterName(*i);
    392       }
    393     }
    394     os << ')';
    395   }
    396 
    397   // Print function name
    398   os << ' ' << CurrentFnSym->getName() << " (";
    399 
    400   const Function *F = MF->getFunction();
    401 
    402   // Print parameters
    403   if (isKernel || ST.useParamSpaceForDeviceArgs()) {
    404     /*for (PTXParamManager::param_iterator i = PM.arg_begin(), e = PM.arg_end(),
    405          b = i; i != e; ++i) {
    406       if (i != b)
    407         os << ", ";
    408 
    409       os << ".param .b" << PM.getParamSize(*i) << ' ' << PM.getParamName(*i);
    410     }*/
    411     int Counter = 1;
    412     for (Function::const_arg_iterator i = F->arg_begin(), e = F->arg_end(),
    413          b = i; i != e; ++i) {
    414       if (i != b)
    415         os << ", ";
    416       const Type *ArgType = (*i).getType();
    417       os << ".param .b";
    418       if (ArgType->isPointerTy()) {
    419         if (ST.is64Bit())
    420           os << "64";
    421         else
    422           os << "32";
    423       } else {
    424         os << ArgType->getPrimitiveSizeInBits();
    425       }
    426       if (ArgType->isPointerTy() && ST.emitPtrAttribute()) {
    427         const PointerType *PtrType = dyn_cast<const PointerType>(ArgType);
    428         os << " .ptr";
    429         switch (PtrType->getAddressSpace()) {
    430         default:
    431           llvm_unreachable("Unknown address space in argument");
    432         case PTXStateSpace::Global:
    433           os << " .global";
    434           break;
    435         case PTXStateSpace::Shared:
    436           os << " .shared";
    437           break;
    438         }
    439       }
    440       os << " __param_" << Counter++;
    441     }
    442   } else {
    443     for (PTXMachineFunctionInfo::reg_iterator
    444          i = MFI->argreg_begin(), e = MFI->argreg_end(), b = i;
    445          i != e; ++i) {
    446       if (i != b)
    447         os << ", ";
    448 
    449       os << ".reg " << getRegisterTypeName(MFI->getRegisterType(*i)) << ' '
    450          << MFI->getRegisterName(*i);
    451     }
    452   }
    453   os << ')';
    454 
    455   OutStreamer.EmitRawText(os.str());
    456 }
    457 
    458 void PTXAsmPrinter::EmitFunctionDeclaration(const Function* func)
    459 {
    460   const PTXSubtarget& ST = TM.getSubtarget<PTXSubtarget>();
    461 
    462   std::string decl = "";
    463 
    464   // hard-coded emission of extern vprintf function
    465 
    466   if (func->getName() == "printf" || func->getName() == "puts") {
    467     decl += ".extern .func (.param .b32 __param_1) vprintf (.param .b";
    468     if (ST.is64Bit())
    469       decl += "64";
    470     else
    471       decl += "32";
    472     decl += " __param_2, .param .b";
    473     if (ST.is64Bit())
    474       decl += "64";
    475     else
    476       decl += "32";
    477     decl += " __param_3)\n";
    478   }
    479 
    480   OutStreamer.EmitRawText(Twine(decl));
    481 }
    482 
    483 unsigned PTXAsmPrinter::GetOrCreateSourceID(StringRef FileName,
    484                                             StringRef DirName) {
    485   // If FE did not provide a file name, then assume stdin.
    486   if (FileName.empty())
    487     return GetOrCreateSourceID("<stdin>", StringRef());
    488 
    489   // MCStream expects full path name as filename.
    490   if (!DirName.empty() && !sys::path::is_absolute(FileName)) {
    491     SmallString<128> FullPathName = DirName;
    492     sys::path::append(FullPathName, FileName);
    493     // Here FullPathName will be copied into StringMap by GetOrCreateSourceID.
    494     return GetOrCreateSourceID(StringRef(FullPathName), StringRef());
    495   }
    496 
    497   StringMapEntry<unsigned> &Entry = SourceIdMap.GetOrCreateValue(FileName);
    498   if (Entry.getValue())
    499     return Entry.getValue();
    500 
    501   unsigned SrcId = SourceIdMap.size();
    502   Entry.setValue(SrcId);
    503 
    504   // Print out a .file directive to specify files for .loc directives.
    505   OutStreamer.EmitDwarfFileDirective(SrcId, "", Entry.getKey());
    506 
    507   return SrcId;
    508 }
    509 
    510 MCOperand PTXAsmPrinter::GetSymbolRef(const MachineOperand &MO,
    511                                       const MCSymbol *Symbol) {
    512   const MCExpr *Expr;
    513   Expr = MCSymbolRefExpr::Create(Symbol, MCSymbolRefExpr::VK_None, OutContext);
    514   return MCOperand::CreateExpr(Expr);
    515 }
    516 
    517 MCOperand PTXAsmPrinter::lowerOperand(const MachineOperand &MO) {
    518   MCOperand MCOp;
    519   const PTXMachineFunctionInfo *MFI = MF->getInfo<PTXMachineFunctionInfo>();
    520   unsigned EncodedReg;
    521   switch (MO.getType()) {
    522   default:
    523     llvm_unreachable("Unknown operand type");
    524   case MachineOperand::MO_Register:
    525     if (MO.getReg() > 0) {
    526       // Encode the register
    527       EncodedReg = MFI->getEncodedRegister(MO.getReg());
    528     } else {
    529       EncodedReg = 0;
    530     }
    531     MCOp = MCOperand::CreateReg(EncodedReg);
    532     break;
    533   case MachineOperand::MO_Immediate:
    534     MCOp = MCOperand::CreateImm(MO.getImm());
    535     break;
    536   case MachineOperand::MO_MachineBasicBlock:
    537     MCOp = MCOperand::CreateExpr(MCSymbolRefExpr::Create(
    538                                  MO.getMBB()->getSymbol(), OutContext));
    539     break;
    540   case MachineOperand::MO_GlobalAddress:
    541     MCOp = GetSymbolRef(MO, Mang->getSymbol(MO.getGlobal()));
    542     break;
    543   case MachineOperand::MO_ExternalSymbol:
    544     MCOp = GetSymbolRef(MO, GetExternalSymbolSymbol(MO.getSymbolName()));
    545     break;
    546   case MachineOperand::MO_FPImmediate:
    547     APFloat Val = MO.getFPImm()->getValueAPF();
    548     bool ignored;
    549     Val.convert(APFloat::IEEEdouble, APFloat::rmTowardZero, &ignored);
    550     MCOp = MCOperand::CreateFPImm(Val.convertToDouble());
    551     break;
    552   }
    553 
    554   return MCOp;
    555 }
    556 
    557 // Force static initialization.
    558 extern "C" void LLVMInitializePTXAsmPrinter() {
    559   RegisterAsmPrinter<PTXAsmPrinter> X(ThePTX32Target);
    560   RegisterAsmPrinter<PTXAsmPrinter> Y(ThePTX64Target);
    561 }
    562