Home | History | Annotate | Download | only in Chapter4
      1 #include "llvm/DerivedTypes.h"
      2 #include "llvm/ExecutionEngine/ExecutionEngine.h"
      3 #include "llvm/ExecutionEngine/JIT.h"
      4 #include "llvm/LLVMContext.h"
      5 #include "llvm/Module.h"
      6 #include "llvm/PassManager.h"
      7 #include "llvm/Analysis/Verifier.h"
      8 #include "llvm/Analysis/Passes.h"
      9 #include "llvm/Target/TargetData.h"
     10 #include "llvm/Transforms/Scalar.h"
     11 #include "llvm/Support/IRBuilder.h"
     12 #include "llvm/Support/TargetSelect.h"
     13 #include <cstdio>
     14 #include <string>
     15 #include <map>
     16 #include <vector>
     17 using namespace llvm;
     18 
     19 //===----------------------------------------------------------------------===//
     20 // Lexer
     21 //===----------------------------------------------------------------------===//
     22 
     // Token codes produced by the lexer.  A character the lexer does not
     // recognise is handed back as its own value in [0-255]; everything the
     // lexer does recognise is reported with one of the negative codes below.
     enum Token {
       // end of input
       tok_eof = -1,

       // commands
       tok_def = -2,
       tok_extern = -3,

       // primary
       tok_identifier = -4,
       tok_number = -5
     };
     34 
     35 static std::string IdentifierStr;  // Filled in if tok_identifier
     36 static double NumVal;              // Filled in if tok_number
     37 
     38 /// gettok - Return the next token from standard input.
     39 static int gettok() {
     40   static int LastChar = ' ';
     41 
     42   // Skip any whitespace.
     43   while (isspace(LastChar))
     44     LastChar = getchar();
     45 
     46   if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
     47     IdentifierStr = LastChar;
     48     while (isalnum((LastChar = getchar())))
     49       IdentifierStr += LastChar;
     50 
     51     if (IdentifierStr == "def") return tok_def;
     52     if (IdentifierStr == "extern") return tok_extern;
     53     return tok_identifier;
     54   }
     55 
     56   if (isdigit(LastChar) || LastChar == '.') {   // Number: [0-9.]+
     57     std::string NumStr;
     58     do {
     59       NumStr += LastChar;
     60       LastChar = getchar();
     61     } while (isdigit(LastChar) || LastChar == '.');
     62 
     63     NumVal = strtod(NumStr.c_str(), 0);
     64     return tok_number;
     65   }
     66 
     67   if (LastChar == '#') {
     68     // Comment until end of line.
     69     do LastChar = getchar();
     70     while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
     71 
     72     if (LastChar != EOF)
     73       return gettok();
     74   }
     75 
     76   // Check for end of file.  Don't eat the EOF.
     77   if (LastChar == EOF)
     78     return tok_eof;
     79 
     80   // Otherwise, just return the character as its ascii value.
     81   int ThisChar = LastChar;
     82   LastChar = getchar();
     83   return ThisChar;
     84 }
     85 
     86 //===----------------------------------------------------------------------===//
     87 // Abstract Syntax Tree (aka Parse Tree)
     88 //===----------------------------------------------------------------------===//
     89 
     90 /// ExprAST - Base class for all expression nodes.
     91 class ExprAST {
     92 public:
     93   virtual ~ExprAST() {}
     94   virtual Value *Codegen() = 0;
     95 };
     96 
     97 /// NumberExprAST - Expression class for numeric literals like "1.0".
     98 class NumberExprAST : public ExprAST {
     99   double Val;
    100 public:
    101   NumberExprAST(double val) : Val(val) {}
    102   virtual Value *Codegen();
    103 };
    104 
    105 /// VariableExprAST - Expression class for referencing a variable, like "a".
    106 class VariableExprAST : public ExprAST {
    107   std::string Name;
    108 public:
    109   VariableExprAST(const std::string &name) : Name(name) {}
    110   virtual Value *Codegen();
    111 };
    112 
    113 /// BinaryExprAST - Expression class for a binary operator.
    114 class BinaryExprAST : public ExprAST {
    115   char Op;
    116   ExprAST *LHS, *RHS;
    117 public:
    118   BinaryExprAST(char op, ExprAST *lhs, ExprAST *rhs)
    119     : Op(op), LHS(lhs), RHS(rhs) {}
    120   virtual Value *Codegen();
    121 };
    122 
    123 /// CallExprAST - Expression class for function calls.
    124 class CallExprAST : public ExprAST {
    125   std::string Callee;
    126   std::vector<ExprAST*> Args;
    127 public:
    128   CallExprAST(const std::string &callee, std::vector<ExprAST*> &args)
    129     : Callee(callee), Args(args) {}
    130   virtual Value *Codegen();
    131 };
    132 
    133 /// PrototypeAST - This class represents the "prototype" for a function,
    134 /// which captures its name, and its argument names (thus implicitly the number
    135 /// of arguments the function takes).
    136 class PrototypeAST {
    137   std::string Name;
    138   std::vector<std::string> Args;
    139 public:
    140   PrototypeAST(const std::string &name, const std::vector<std::string> &args)
    141     : Name(name), Args(args) {}
    142 
    143   Function *Codegen();
    144 };
    145 
    146 /// FunctionAST - This class represents a function definition itself.
    147 class FunctionAST {
    148   PrototypeAST *Proto;
    149   ExprAST *Body;
    150 public:
    151   FunctionAST(PrototypeAST *proto, ExprAST *body)
    152     : Proto(proto), Body(body) {}
    153 
    154   Function *Codegen();
    155 };
    156 
    157 //===----------------------------------------------------------------------===//
    158 // Parser
    159 //===----------------------------------------------------------------------===//
    160 
    161 /// CurTok/getNextToken - Provide a simple token buffer.  CurTok is the current
    162 /// token the parser is looking at.  getNextToken reads another token from the
    163 /// lexer and updates CurTok with its results.
    164 static int CurTok;
    165 static int getNextToken() {
    166   return CurTok = gettok();
    167 }
    168 
    169 /// BinopPrecedence - This holds the precedence for each binary operator that is
    170 /// defined.
    171 static std::map<char, int> BinopPrecedence;
    172 
    173 /// GetTokPrecedence - Get the precedence of the pending binary operator token.
    174 static int GetTokPrecedence() {
    175   if (!isascii(CurTok))
    176     return -1;
    177 
    178   // Make sure it's a declared binop.
    179   int TokPrec = BinopPrecedence[CurTok];
    180   if (TokPrec <= 0) return -1;
    181   return TokPrec;
    182 }
    183 
    184 /// Error* - These are little helper functions for error handling.
    185 ExprAST *Error(const char *Str) { fprintf(stderr, "Error: %s\n", Str);return 0;}
    186 PrototypeAST *ErrorP(const char *Str) { Error(Str); return 0; }
    187 FunctionAST *ErrorF(const char *Str) { Error(Str); return 0; }
    188 
    189 static ExprAST *ParseExpression();
    190 
    191 /// identifierexpr
    192 ///   ::= identifier
    193 ///   ::= identifier '(' expression* ')'
    194 static ExprAST *ParseIdentifierExpr() {
    195   std::string IdName = IdentifierStr;
    196 
    197   getNextToken();  // eat identifier.
    198 
    199   if (CurTok != '(') // Simple variable ref.
    200     return new VariableExprAST(IdName);
    201 
    202   // Call.
    203   getNextToken();  // eat (
    204   std::vector<ExprAST*> Args;
    205   if (CurTok != ')') {
    206     while (1) {
    207       ExprAST *Arg = ParseExpression();
    208       if (!Arg) return 0;
    209       Args.push_back(Arg);
    210 
    211       if (CurTok == ')') break;
    212 
    213       if (CurTok != ',')
    214         return Error("Expected ')' or ',' in argument list");
    215       getNextToken();
    216     }
    217   }
    218 
    219   // Eat the ')'.
    220   getNextToken();
    221 
    222   return new CallExprAST(IdName, Args);
    223 }
    224 
    225 /// numberexpr ::= number
    226 static ExprAST *ParseNumberExpr() {
    227   ExprAST *Result = new NumberExprAST(NumVal);
    228   getNextToken(); // consume the number
    229   return Result;
    230 }
    231 
    232 /// parenexpr ::= '(' expression ')'
    233 static ExprAST *ParseParenExpr() {
    234   getNextToken();  // eat (.
    235   ExprAST *V = ParseExpression();
    236   if (!V) return 0;
    237 
    238   if (CurTok != ')')
    239     return Error("expected ')'");
    240   getNextToken();  // eat ).
    241   return V;
    242 }
    243 
    244 /// primary
    245 ///   ::= identifierexpr
    246 ///   ::= numberexpr
    247 ///   ::= parenexpr
    248 static ExprAST *ParsePrimary() {
    249   switch (CurTok) {
    250   default: return Error("unknown token when expecting an expression");
    251   case tok_identifier: return ParseIdentifierExpr();
    252   case tok_number:     return ParseNumberExpr();
    253   case '(':            return ParseParenExpr();
    254   }
    255 }
    256 
    257 /// binoprhs
    258 ///   ::= ('+' primary)*
    259 static ExprAST *ParseBinOpRHS(int ExprPrec, ExprAST *LHS) {
    260   // If this is a binop, find its precedence.
    261   while (1) {
    262     int TokPrec = GetTokPrecedence();
    263 
    264     // If this is a binop that binds at least as tightly as the current binop,
    265     // consume it, otherwise we are done.
    266     if (TokPrec < ExprPrec)
    267       return LHS;
    268 
    269     // Okay, we know this is a binop.
    270     int BinOp = CurTok;
    271     getNextToken();  // eat binop
    272 
    273     // Parse the primary expression after the binary operator.
    274     ExprAST *RHS = ParsePrimary();
    275     if (!RHS) return 0;
    276 
    277     // If BinOp binds less tightly with RHS than the operator after RHS, let
    278     // the pending operator take RHS as its LHS.
    279     int NextPrec = GetTokPrecedence();
    280     if (TokPrec < NextPrec) {
    281       RHS = ParseBinOpRHS(TokPrec+1, RHS);
    282       if (RHS == 0) return 0;
    283     }
    284 
    285     // Merge LHS/RHS.
    286     LHS = new BinaryExprAST(BinOp, LHS, RHS);
    287   }
    288 }
    289 
    290 /// expression
    291 ///   ::= primary binoprhs
    292 ///
    293 static ExprAST *ParseExpression() {
    294   ExprAST *LHS = ParsePrimary();
    295   if (!LHS) return 0;
    296 
    297   return ParseBinOpRHS(0, LHS);
    298 }
    299 
    300 /// prototype
    301 ///   ::= id '(' id* ')'
    302 static PrototypeAST *ParsePrototype() {
    303   if (CurTok != tok_identifier)
    304     return ErrorP("Expected function name in prototype");
    305 
    306   std::string FnName = IdentifierStr;
    307   getNextToken();
    308 
    309   if (CurTok != '(')
    310     return ErrorP("Expected '(' in prototype");
    311 
    312   std::vector<std::string> ArgNames;
    313   while (getNextToken() == tok_identifier)
    314     ArgNames.push_back(IdentifierStr);
    315   if (CurTok != ')')
    316     return ErrorP("Expected ')' in prototype");
    317 
    318   // success.
    319   getNextToken();  // eat ')'.
    320 
    321   return new PrototypeAST(FnName, ArgNames);
    322 }
    323 
    324 /// definition ::= 'def' prototype expression
    325 static FunctionAST *ParseDefinition() {
    326   getNextToken();  // eat def.
    327   PrototypeAST *Proto = ParsePrototype();
    328   if (Proto == 0) return 0;
    329 
    330   if (ExprAST *E = ParseExpression())
    331     return new FunctionAST(Proto, E);
    332   return 0;
    333 }
    334 
    335 /// toplevelexpr ::= expression
    336 static FunctionAST *ParseTopLevelExpr() {
    337   if (ExprAST *E = ParseExpression()) {
    338     // Make an anonymous proto.
    339     PrototypeAST *Proto = new PrototypeAST("", std::vector<std::string>());
    340     return new FunctionAST(Proto, E);
    341   }
    342   return 0;
    343 }
    344 
    345 /// external ::= 'extern' prototype
    346 static PrototypeAST *ParseExtern() {
    347   getNextToken();  // eat extern.
    348   return ParsePrototype();
    349 }
    350 
    351 //===----------------------------------------------------------------------===//
    352 // Code Generation
    353 //===----------------------------------------------------------------------===//
    354 
    355 static Module *TheModule;
    356 static IRBuilder<> Builder(getGlobalContext());
    357 static std::map<std::string, Value*> NamedValues;
    358 static FunctionPassManager *TheFPM;
    359 
    360 Value *ErrorV(const char *Str) { Error(Str); return 0; }
    361 
    362 Value *NumberExprAST::Codegen() {
    363   return ConstantFP::get(getGlobalContext(), APFloat(Val));
    364 }
    365 
    366 Value *VariableExprAST::Codegen() {
    367   // Look this variable up in the function.
    368   Value *V = NamedValues[Name];
    369   return V ? V : ErrorV("Unknown variable name");
    370 }
    371 
    372 Value *BinaryExprAST::Codegen() {
    373   Value *L = LHS->Codegen();
    374   Value *R = RHS->Codegen();
    375   if (L == 0 || R == 0) return 0;
    376 
    377   switch (Op) {
    378   case '+': return Builder.CreateFAdd(L, R, "addtmp");
    379   case '-': return Builder.CreateFSub(L, R, "subtmp");
    380   case '*': return Builder.CreateFMul(L, R, "multmp");
    381   case '<':
    382     L = Builder.CreateFCmpULT(L, R, "cmptmp");
    383     // Convert bool 0/1 to double 0.0 or 1.0
    384     return Builder.CreateUIToFP(L, Type::getDoubleTy(getGlobalContext()),
    385                                 "booltmp");
    386   default: return ErrorV("invalid binary operator");
    387   }
    388 }
    389 
    390 Value *CallExprAST::Codegen() {
    391   // Look up the name in the global module table.
    392   Function *CalleeF = TheModule->getFunction(Callee);
    393   if (CalleeF == 0)
    394     return ErrorV("Unknown function referenced");
    395 
    396   // If argument mismatch error.
    397   if (CalleeF->arg_size() != Args.size())
    398     return ErrorV("Incorrect # arguments passed");
    399 
    400   std::vector<Value*> ArgsV;
    401   for (unsigned i = 0, e = Args.size(); i != e; ++i) {
    402     ArgsV.push_back(Args[i]->Codegen());
    403     if (ArgsV.back() == 0) return 0;
    404   }
    405 
    406   return Builder.CreateCall(CalleeF, ArgsV, "calltmp");
    407 }
    408 
    409 Function *PrototypeAST::Codegen() {
    410   // Make the function type:  double(double,double) etc.
    411   std::vector<Type*> Doubles(Args.size(),
    412                              Type::getDoubleTy(getGlobalContext()));
    413   FunctionType *FT = FunctionType::get(Type::getDoubleTy(getGlobalContext()),
    414                                        Doubles, false);
    415 
    416   Function *F = Function::Create(FT, Function::ExternalLinkage, Name, TheModule);
    417 
    418   // If F conflicted, there was already something named 'Name'.  If it has a
    419   // body, don't allow redefinition or reextern.
    420   if (F->getName() != Name) {
    421     // Delete the one we just made and get the existing one.
    422     F->eraseFromParent();
    423     F = TheModule->getFunction(Name);
    424 
    425     // If F already has a body, reject this.
    426     if (!F->empty()) {
    427       ErrorF("redefinition of function");
    428       return 0;
    429     }
    430 
    431     // If F took a different number of args, reject.
    432     if (F->arg_size() != Args.size()) {
    433       ErrorF("redefinition of function with different # args");
    434       return 0;
    435     }
    436   }
    437 
    438   // Set names for all arguments.
    439   unsigned Idx = 0;
    440   for (Function::arg_iterator AI = F->arg_begin(); Idx != Args.size();
    441        ++AI, ++Idx) {
    442     AI->setName(Args[Idx]);
    443 
    444     // Add arguments to variable symbol table.
    445     NamedValues[Args[Idx]] = AI;
    446   }
    447 
    448   return F;
    449 }
    450 
    451 Function *FunctionAST::Codegen() {
    452   NamedValues.clear();
    453 
    454   Function *TheFunction = Proto->Codegen();
    455   if (TheFunction == 0)
    456     return 0;
    457 
    458   // Create a new basic block to start insertion into.
    459   BasicBlock *BB = BasicBlock::Create(getGlobalContext(), "entry", TheFunction);
    460   Builder.SetInsertPoint(BB);
    461 
    462   if (Value *RetVal = Body->Codegen()) {
    463     // Finish off the function.
    464     Builder.CreateRet(RetVal);
    465 
    466     // Validate the generated code, checking for consistency.
    467     verifyFunction(*TheFunction);
    468 
    469     // Optimize the function.
    470     TheFPM->run(*TheFunction);
    471 
    472     return TheFunction;
    473   }
    474 
    475   // Error reading body, remove function.
    476   TheFunction->eraseFromParent();
    477   return 0;
    478 }
    479 
    480 //===----------------------------------------------------------------------===//
    481 // Top-Level parsing and JIT Driver
    482 //===----------------------------------------------------------------------===//
    483 
    484 static ExecutionEngine *TheExecutionEngine;
    485 
    486 static void HandleDefinition() {
    487   if (FunctionAST *F = ParseDefinition()) {
    488     if (Function *LF = F->Codegen()) {
    489       fprintf(stderr, "Read function definition:");
    490       LF->dump();
    491     }
    492   } else {
    493     // Skip token for error recovery.
    494     getNextToken();
    495   }
    496 }
    497 
    498 static void HandleExtern() {
    499   if (PrototypeAST *P = ParseExtern()) {
    500     if (Function *F = P->Codegen()) {
    501       fprintf(stderr, "Read extern: ");
    502       F->dump();
    503     }
    504   } else {
    505     // Skip token for error recovery.
    506     getNextToken();
    507   }
    508 }
    509 
    510 static void HandleTopLevelExpression() {
    511   // Evaluate a top-level expression into an anonymous function.
    512   if (FunctionAST *F = ParseTopLevelExpr()) {
    513     if (Function *LF = F->Codegen()) {
    514       // JIT the function, returning a function pointer.
    515       void *FPtr = TheExecutionEngine->getPointerToFunction(LF);
    516 
    517       // Cast it to the right type (takes no arguments, returns a double) so we
    518       // can call it as a native function.
    519       double (*FP)() = (double (*)())(intptr_t)FPtr;
    520       fprintf(stderr, "Evaluated to %f\n", FP());
    521     }
    522   } else {
    523     // Skip token for error recovery.
    524     getNextToken();
    525   }
    526 }
    527 
    528 /// top ::= definition | external | expression | ';'
    529 static void MainLoop() {
    530   while (1) {
    531     fprintf(stderr, "ready> ");
    532     switch (CurTok) {
    533     case tok_eof:    return;
    534     case ';':        getNextToken(); break;  // ignore top-level semicolons.
    535     case tok_def:    HandleDefinition(); break;
    536     case tok_extern: HandleExtern(); break;
    537     default:         HandleTopLevelExpression(); break;
    538     }
    539   }
    540 }
    541 
    542 //===----------------------------------------------------------------------===//
    543 // "Library" functions that can be "extern'd" from user code.
    544 //===----------------------------------------------------------------------===//
    545 
    /// putchard - Write X, converted to a char, to standard output.  Always
    /// returns 0.
    extern "C"
    double putchard(double X) {
      const char Ch = static_cast<char>(X);
      putchar(Ch);
      return 0;
    }
    552 
    553 //===----------------------------------------------------------------------===//
    554 // Main driver code.
    555 //===----------------------------------------------------------------------===//
    556 
    557 int main() {
    558   InitializeNativeTarget();
    559   LLVMContext &Context = getGlobalContext();
    560 
    561   // Install standard binary operators.
    562   // 1 is lowest precedence.
    563   BinopPrecedence['<'] = 10;
    564   BinopPrecedence['+'] = 20;
    565   BinopPrecedence['-'] = 20;
    566   BinopPrecedence['*'] = 40;  // highest.
    567 
    568   // Prime the first token.
    569   fprintf(stderr, "ready> ");
    570   getNextToken();
    571 
    572   // Make the module, which holds all the code.
    573   TheModule = new Module("my cool jit", Context);
    574 
    575   // Create the JIT.  This takes ownership of the module.
    576   std::string ErrStr;
    577   TheExecutionEngine = EngineBuilder(TheModule).setErrorStr(&ErrStr).create();
    578   if (!TheExecutionEngine) {
    579     fprintf(stderr, "Could not create ExecutionEngine: %s\n", ErrStr.c_str());
    580     exit(1);
    581   }
    582 
    583   FunctionPassManager OurFPM(TheModule);
    584 
    585   // Set up the optimizer pipeline.  Start with registering info about how the
    586   // target lays out data structures.
    587   OurFPM.add(new TargetData(*TheExecutionEngine->getTargetData()));
    588   // Provide basic AliasAnalysis support for GVN.
    589   OurFPM.add(createBasicAliasAnalysisPass());
    590   // Do simple "peephole" optimizations and bit-twiddling optzns.
    591   OurFPM.add(createInstructionCombiningPass());
    592   // Reassociate expressions.
    593   OurFPM.add(createReassociatePass());
    594   // Eliminate Common SubExpressions.
    595   OurFPM.add(createGVNPass());
    596   // Simplify the control flow graph (deleting unreachable blocks, etc).
    597   OurFPM.add(createCFGSimplificationPass());
    598 
    599   OurFPM.doInitialization();
    600 
    601   // Set the global so the code gen can use this.
    602   TheFPM = &OurFPM;
    603 
    604   // Run the main "interpreter loop" now.
    605   MainLoop();
    606 
    607   TheFPM = 0;
    608 
    609   // Print out all of the generated code.
    610   TheModule->dump();
    611 
    612   return 0;
    613 }
    614