Home | History | Annotate | Download | only in Chapter4
      1 #include "llvm/Analysis/Passes.h"
      2 #include "llvm/ExecutionEngine/ExecutionEngine.h"
      3 #include "llvm/ExecutionEngine/JIT.h"
      4 #include "llvm/IR/DataLayout.h"
      5 #include "llvm/IR/DerivedTypes.h"
      6 #include "llvm/IR/IRBuilder.h"
      7 #include "llvm/IR/LLVMContext.h"
      8 #include "llvm/IR/Module.h"
      9 #include "llvm/IR/Verifier.h"
     10 #include "llvm/PassManager.h"
     11 #include "llvm/Support/TargetSelect.h"
     12 #include "llvm/Transforms/Scalar.h"
     13 #include <cctype>
     14 #include <cstdio>
     15 #include <map>
     16 #include <string>
     17 #include <vector>
     18 using namespace llvm;
     19 
     20 //===----------------------------------------------------------------------===//
     21 // Lexer
     22 //===----------------------------------------------------------------------===//
     23 
     24 // The lexer returns tokens [0-255] if it is an unknown character, otherwise one
     25 // of these for known things.
     26 enum Token {
     27   tok_eof = -1,
     28 
     29   // commands
     30   tok_def = -2, tok_extern = -3,
     31 
     32   // primary
     33   tok_identifier = -4, tok_number = -5
     34 };
     35 
// Side-channel values written by gettok() alongside the token code it returns.
static std::string IdentifierStr;  // Text of the last identifier lexed (set for
                                   // tok_identifier, and also for the keywords
                                   // "def"/"extern" before they are classified).
static double NumVal;              // Value of the last numeric literal (tok_number).
     38 
     39 /// gettok - Return the next token from standard input.
     40 static int gettok() {
     41   static int LastChar = ' ';
     42 
     43   // Skip any whitespace.
     44   while (isspace(LastChar))
     45     LastChar = getchar();
     46 
     47   if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
     48     IdentifierStr = LastChar;
     49     while (isalnum((LastChar = getchar())))
     50       IdentifierStr += LastChar;
     51 
     52     if (IdentifierStr == "def") return tok_def;
     53     if (IdentifierStr == "extern") return tok_extern;
     54     return tok_identifier;
     55   }
     56 
     57   if (isdigit(LastChar) || LastChar == '.') {   // Number: [0-9.]+
     58     std::string NumStr;
     59     do {
     60       NumStr += LastChar;
     61       LastChar = getchar();
     62     } while (isdigit(LastChar) || LastChar == '.');
     63 
     64     NumVal = strtod(NumStr.c_str(), 0);
     65     return tok_number;
     66   }
     67 
     68   if (LastChar == '#') {
     69     // Comment until end of line.
     70     do LastChar = getchar();
     71     while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
     72 
     73     if (LastChar != EOF)
     74       return gettok();
     75   }
     76 
     77   // Check for end of file.  Don't eat the EOF.
     78   if (LastChar == EOF)
     79     return tok_eof;
     80 
     81   // Otherwise, just return the character as its ascii value.
     82   int ThisChar = LastChar;
     83   LastChar = getchar();
     84   return ThisChar;
     85 }
     86 
     87 //===----------------------------------------------------------------------===//
     88 // Abstract Syntax Tree (aka Parse Tree)
     89 //===----------------------------------------------------------------------===//
     90 namespace {
     91 /// ExprAST - Base class for all expression nodes.
     92 class ExprAST {
     93 public:
     94   virtual ~ExprAST() {}
     95   virtual Value *Codegen() = 0;
     96 };
     97 
     98 /// NumberExprAST - Expression class for numeric literals like "1.0".
     99 class NumberExprAST : public ExprAST {
    100   double Val;
    101 public:
    102   NumberExprAST(double val) : Val(val) {}
    103   virtual Value *Codegen();
    104 };
    105 
    106 /// VariableExprAST - Expression class for referencing a variable, like "a".
    107 class VariableExprAST : public ExprAST {
    108   std::string Name;
    109 public:
    110   VariableExprAST(const std::string &name) : Name(name) {}
    111   virtual Value *Codegen();
    112 };
    113 
    114 /// BinaryExprAST - Expression class for a binary operator.
    115 class BinaryExprAST : public ExprAST {
    116   char Op;
    117   ExprAST *LHS, *RHS;
    118 public:
    119   BinaryExprAST(char op, ExprAST *lhs, ExprAST *rhs)
    120     : Op(op), LHS(lhs), RHS(rhs) {}
    121   virtual Value *Codegen();
    122 };
    123 
    124 /// CallExprAST - Expression class for function calls.
    125 class CallExprAST : public ExprAST {
    126   std::string Callee;
    127   std::vector<ExprAST*> Args;
    128 public:
    129   CallExprAST(const std::string &callee, std::vector<ExprAST*> &args)
    130     : Callee(callee), Args(args) {}
    131   virtual Value *Codegen();
    132 };
    133 
    134 /// PrototypeAST - This class represents the "prototype" for a function,
    135 /// which captures its name, and its argument names (thus implicitly the number
    136 /// of arguments the function takes).
    137 class PrototypeAST {
    138   std::string Name;
    139   std::vector<std::string> Args;
    140 public:
    141   PrototypeAST(const std::string &name, const std::vector<std::string> &args)
    142     : Name(name), Args(args) {}
    143 
    144   Function *Codegen();
    145 };
    146 
    147 /// FunctionAST - This class represents a function definition itself.
    148 class FunctionAST {
    149   PrototypeAST *Proto;
    150   ExprAST *Body;
    151 public:
    152   FunctionAST(PrototypeAST *proto, ExprAST *body)
    153     : Proto(proto), Body(body) {}
    154 
    155   Function *Codegen();
    156 };
    157 } // end anonymous namespace
    158 
    159 //===----------------------------------------------------------------------===//
    160 // Parser
    161 //===----------------------------------------------------------------------===//
    162 
    163 /// CurTok/getNextToken - Provide a simple token buffer.  CurTok is the current
    164 /// token the parser is looking at.  getNextToken reads another token from the
    165 /// lexer and updates CurTok with its results.
    166 static int CurTok;
    167 static int getNextToken() {
    168   return CurTok = gettok();
    169 }
    170 
/// BinopPrecedence - This holds the precedence for each binary operator that is
/// defined, keyed by the operator character.  main() installs the standard
/// operators; GetTokPrecedence() treats values <= 0 as "not an operator".
static std::map<char, int> BinopPrecedence;
    174 
    175 /// GetTokPrecedence - Get the precedence of the pending binary operator token.
    176 static int GetTokPrecedence() {
    177   if (!isascii(CurTok))
    178     return -1;
    179 
    180   // Make sure it's a declared binop.
    181   int TokPrec = BinopPrecedence[CurTok];
    182   if (TokPrec <= 0) return -1;
    183   return TokPrec;
    184 }
    185 
    186 /// Error* - These are little helper functions for error handling.
    187 ExprAST *Error(const char *Str) { fprintf(stderr, "Error: %s\n", Str);return 0;}
    188 PrototypeAST *ErrorP(const char *Str) { Error(Str); return 0; }
    189 FunctionAST *ErrorF(const char *Str) { Error(Str); return 0; }
    190 
    191 static ExprAST *ParseExpression();
    192 
    193 /// identifierexpr
    194 ///   ::= identifier
    195 ///   ::= identifier '(' expression* ')'
    196 static ExprAST *ParseIdentifierExpr() {
    197   std::string IdName = IdentifierStr;
    198 
    199   getNextToken();  // eat identifier.
    200 
    201   if (CurTok != '(') // Simple variable ref.
    202     return new VariableExprAST(IdName);
    203 
    204   // Call.
    205   getNextToken();  // eat (
    206   std::vector<ExprAST*> Args;
    207   if (CurTok != ')') {
    208     while (1) {
    209       ExprAST *Arg = ParseExpression();
    210       if (!Arg) return 0;
    211       Args.push_back(Arg);
    212 
    213       if (CurTok == ')') break;
    214 
    215       if (CurTok != ',')
    216         return Error("Expected ')' or ',' in argument list");
    217       getNextToken();
    218     }
    219   }
    220 
    221   // Eat the ')'.
    222   getNextToken();
    223 
    224   return new CallExprAST(IdName, Args);
    225 }
    226 
    227 /// numberexpr ::= number
    228 static ExprAST *ParseNumberExpr() {
    229   ExprAST *Result = new NumberExprAST(NumVal);
    230   getNextToken(); // consume the number
    231   return Result;
    232 }
    233 
    234 /// parenexpr ::= '(' expression ')'
    235 static ExprAST *ParseParenExpr() {
    236   getNextToken();  // eat (.
    237   ExprAST *V = ParseExpression();
    238   if (!V) return 0;
    239 
    240   if (CurTok != ')')
    241     return Error("expected ')'");
    242   getNextToken();  // eat ).
    243   return V;
    244 }
    245 
    246 /// primary
    247 ///   ::= identifierexpr
    248 ///   ::= numberexpr
    249 ///   ::= parenexpr
    250 static ExprAST *ParsePrimary() {
    251   switch (CurTok) {
    252   default: return Error("unknown token when expecting an expression");
    253   case tok_identifier: return ParseIdentifierExpr();
    254   case tok_number:     return ParseNumberExpr();
    255   case '(':            return ParseParenExpr();
    256   }
    257 }
    258 
    259 /// binoprhs
    260 ///   ::= ('+' primary)*
    261 static ExprAST *ParseBinOpRHS(int ExprPrec, ExprAST *LHS) {
    262   // If this is a binop, find its precedence.
    263   while (1) {
    264     int TokPrec = GetTokPrecedence();
    265 
    266     // If this is a binop that binds at least as tightly as the current binop,
    267     // consume it, otherwise we are done.
    268     if (TokPrec < ExprPrec)
    269       return LHS;
    270 
    271     // Okay, we know this is a binop.
    272     int BinOp = CurTok;
    273     getNextToken();  // eat binop
    274 
    275     // Parse the primary expression after the binary operator.
    276     ExprAST *RHS = ParsePrimary();
    277     if (!RHS) return 0;
    278 
    279     // If BinOp binds less tightly with RHS than the operator after RHS, let
    280     // the pending operator take RHS as its LHS.
    281     int NextPrec = GetTokPrecedence();
    282     if (TokPrec < NextPrec) {
    283       RHS = ParseBinOpRHS(TokPrec+1, RHS);
    284       if (RHS == 0) return 0;
    285     }
    286 
    287     // Merge LHS/RHS.
    288     LHS = new BinaryExprAST(BinOp, LHS, RHS);
    289   }
    290 }
    291 
    292 /// expression
    293 ///   ::= primary binoprhs
    294 ///
    295 static ExprAST *ParseExpression() {
    296   ExprAST *LHS = ParsePrimary();
    297   if (!LHS) return 0;
    298 
    299   return ParseBinOpRHS(0, LHS);
    300 }
    301 
    302 /// prototype
    303 ///   ::= id '(' id* ')'
    304 static PrototypeAST *ParsePrototype() {
    305   if (CurTok != tok_identifier)
    306     return ErrorP("Expected function name in prototype");
    307 
    308   std::string FnName = IdentifierStr;
    309   getNextToken();
    310 
    311   if (CurTok != '(')
    312     return ErrorP("Expected '(' in prototype");
    313 
    314   std::vector<std::string> ArgNames;
    315   while (getNextToken() == tok_identifier)
    316     ArgNames.push_back(IdentifierStr);
    317   if (CurTok != ')')
    318     return ErrorP("Expected ')' in prototype");
    319 
    320   // success.
    321   getNextToken();  // eat ')'.
    322 
    323   return new PrototypeAST(FnName, ArgNames);
    324 }
    325 
    326 /// definition ::= 'def' prototype expression
    327 static FunctionAST *ParseDefinition() {
    328   getNextToken();  // eat def.
    329   PrototypeAST *Proto = ParsePrototype();
    330   if (Proto == 0) return 0;
    331 
    332   if (ExprAST *E = ParseExpression())
    333     return new FunctionAST(Proto, E);
    334   return 0;
    335 }
    336 
    337 /// toplevelexpr ::= expression
    338 static FunctionAST *ParseTopLevelExpr() {
    339   if (ExprAST *E = ParseExpression()) {
    340     // Make an anonymous proto.
    341     PrototypeAST *Proto = new PrototypeAST("", std::vector<std::string>());
    342     return new FunctionAST(Proto, E);
    343   }
    344   return 0;
    345 }
    346 
    347 /// external ::= 'extern' prototype
    348 static PrototypeAST *ParseExtern() {
    349   getNextToken();  // eat extern.
    350   return ParsePrototype();
    351 }
    352 
    353 //===----------------------------------------------------------------------===//
    354 // Code Generation
    355 //===----------------------------------------------------------------------===//
    356 
// Shared code generation state used by the Codegen() methods below.

// Module that receives every generated function; created in main().
static Module *TheModule;
// Instruction builder; FunctionAST::Codegen points it at each new entry block.
static IRBuilder<> Builder(getGlobalContext());
// Variables visible in the function being generated: cleared by
// FunctionAST::Codegen and filled with the arguments by PrototypeAST::Codegen.
static std::map<std::string, Value*> NamedValues;
// Per-function optimization pipeline; main() points this at a pass manager it
// owns and resets it to null once the main loop has finished.
static FunctionPassManager *TheFPM;
    361 
    362 Value *ErrorV(const char *Str) { Error(Str); return 0; }
    363 
    364 Value *NumberExprAST::Codegen() {
    365   return ConstantFP::get(getGlobalContext(), APFloat(Val));
    366 }
    367 
    368 Value *VariableExprAST::Codegen() {
    369   // Look this variable up in the function.
    370   Value *V = NamedValues[Name];
    371   return V ? V : ErrorV("Unknown variable name");
    372 }
    373 
    374 Value *BinaryExprAST::Codegen() {
    375   Value *L = LHS->Codegen();
    376   Value *R = RHS->Codegen();
    377   if (L == 0 || R == 0) return 0;
    378 
    379   switch (Op) {
    380   case '+': return Builder.CreateFAdd(L, R, "addtmp");
    381   case '-': return Builder.CreateFSub(L, R, "subtmp");
    382   case '*': return Builder.CreateFMul(L, R, "multmp");
    383   case '<':
    384     L = Builder.CreateFCmpULT(L, R, "cmptmp");
    385     // Convert bool 0/1 to double 0.0 or 1.0
    386     return Builder.CreateUIToFP(L, Type::getDoubleTy(getGlobalContext()),
    387                                 "booltmp");
    388   default: return ErrorV("invalid binary operator");
    389   }
    390 }
    391 
    392 Value *CallExprAST::Codegen() {
    393   // Look up the name in the global module table.
    394   Function *CalleeF = TheModule->getFunction(Callee);
    395   if (CalleeF == 0)
    396     return ErrorV("Unknown function referenced");
    397 
    398   // If argument mismatch error.
    399   if (CalleeF->arg_size() != Args.size())
    400     return ErrorV("Incorrect # arguments passed");
    401 
    402   std::vector<Value*> ArgsV;
    403   for (unsigned i = 0, e = Args.size(); i != e; ++i) {
    404     ArgsV.push_back(Args[i]->Codegen());
    405     if (ArgsV.back() == 0) return 0;
    406   }
    407 
    408   return Builder.CreateCall(CalleeF, ArgsV, "calltmp");
    409 }
    410 
    411 Function *PrototypeAST::Codegen() {
    412   // Make the function type:  double(double,double) etc.
    413   std::vector<Type*> Doubles(Args.size(),
    414                              Type::getDoubleTy(getGlobalContext()));
    415   FunctionType *FT = FunctionType::get(Type::getDoubleTy(getGlobalContext()),
    416                                        Doubles, false);
    417 
    418   Function *F = Function::Create(FT, Function::ExternalLinkage, Name, TheModule);
    419 
    420   // If F conflicted, there was already something named 'Name'.  If it has a
    421   // body, don't allow redefinition or reextern.
    422   if (F->getName() != Name) {
    423     // Delete the one we just made and get the existing one.
    424     F->eraseFromParent();
    425     F = TheModule->getFunction(Name);
    426 
    427     // If F already has a body, reject this.
    428     if (!F->empty()) {
    429       ErrorF("redefinition of function");
    430       return 0;
    431     }
    432 
    433     // If F took a different number of args, reject.
    434     if (F->arg_size() != Args.size()) {
    435       ErrorF("redefinition of function with different # args");
    436       return 0;
    437     }
    438   }
    439 
    440   // Set names for all arguments.
    441   unsigned Idx = 0;
    442   for (Function::arg_iterator AI = F->arg_begin(); Idx != Args.size();
    443        ++AI, ++Idx) {
    444     AI->setName(Args[Idx]);
    445 
    446     // Add arguments to variable symbol table.
    447     NamedValues[Args[Idx]] = AI;
    448   }
    449 
    450   return F;
    451 }
    452 
    453 Function *FunctionAST::Codegen() {
    454   NamedValues.clear();
    455 
    456   Function *TheFunction = Proto->Codegen();
    457   if (TheFunction == 0)
    458     return 0;
    459 
    460   // Create a new basic block to start insertion into.
    461   BasicBlock *BB = BasicBlock::Create(getGlobalContext(), "entry", TheFunction);
    462   Builder.SetInsertPoint(BB);
    463 
    464   if (Value *RetVal = Body->Codegen()) {
    465     // Finish off the function.
    466     Builder.CreateRet(RetVal);
    467 
    468     // Validate the generated code, checking for consistency.
    469     verifyFunction(*TheFunction);
    470 
    471     // Optimize the function.
    472     TheFPM->run(*TheFunction);
    473 
    474     return TheFunction;
    475   }
    476 
    477   // Error reading body, remove function.
    478   TheFunction->eraseFromParent();
    479   return 0;
    480 }
    481 
    482 //===----------------------------------------------------------------------===//
    483 // Top-Level parsing and JIT Driver
    484 //===----------------------------------------------------------------------===//
    485 
// Engine used to turn generated functions into callable native code; created
// in main() and used by HandleTopLevelExpression().
static ExecutionEngine *TheExecutionEngine;
    487 
    488 static void HandleDefinition() {
    489   if (FunctionAST *F = ParseDefinition()) {
    490     if (Function *LF = F->Codegen()) {
    491       fprintf(stderr, "Read function definition:");
    492       LF->dump();
    493     }
    494   } else {
    495     // Skip token for error recovery.
    496     getNextToken();
    497   }
    498 }
    499 
    500 static void HandleExtern() {
    501   if (PrototypeAST *P = ParseExtern()) {
    502     if (Function *F = P->Codegen()) {
    503       fprintf(stderr, "Read extern: ");
    504       F->dump();
    505     }
    506   } else {
    507     // Skip token for error recovery.
    508     getNextToken();
    509   }
    510 }
    511 
    512 static void HandleTopLevelExpression() {
    513   // Evaluate a top-level expression into an anonymous function.
    514   if (FunctionAST *F = ParseTopLevelExpr()) {
    515     if (Function *LF = F->Codegen()) {
    516       // JIT the function, returning a function pointer.
    517       void *FPtr = TheExecutionEngine->getPointerToFunction(LF);
    518 
    519       // Cast it to the right type (takes no arguments, returns a double) so we
    520       // can call it as a native function.
    521       double (*FP)() = (double (*)())(intptr_t)FPtr;
    522       fprintf(stderr, "Evaluated to %f\n", FP());
    523     }
    524   } else {
    525     // Skip token for error recovery.
    526     getNextToken();
    527   }
    528 }
    529 
    530 /// top ::= definition | external | expression | ';'
    531 static void MainLoop() {
    532   while (1) {
    533     fprintf(stderr, "ready> ");
    534     switch (CurTok) {
    535     case tok_eof:    return;
    536     case ';':        getNextToken(); break;  // ignore top-level semicolons.
    537     case tok_def:    HandleDefinition(); break;
    538     case tok_extern: HandleExtern(); break;
    539     default:         HandleTopLevelExpression(); break;
    540     }
    541   }
    542 }
    543 
    544 //===----------------------------------------------------------------------===//
    545 // "Library" functions that can be "extern'd" from user code.
    546 //===----------------------------------------------------------------------===//
    547 
    548 /// putchard - putchar that takes a double and returns 0.
    549 extern "C"
    550 double putchard(double X) {
    551   putchar((char)X);
    552   return 0;
    553 }
    554 
    555 //===----------------------------------------------------------------------===//
    556 // Main driver code.
    557 //===----------------------------------------------------------------------===//
    558 
    559 int main() {
    560   InitializeNativeTarget();
    561   LLVMContext &Context = getGlobalContext();
    562 
    563   // Install standard binary operators.
    564   // 1 is lowest precedence.
    565   BinopPrecedence['<'] = 10;
    566   BinopPrecedence['+'] = 20;
    567   BinopPrecedence['-'] = 20;
    568   BinopPrecedence['*'] = 40;  // highest.
    569 
    570   // Prime the first token.
    571   fprintf(stderr, "ready> ");
    572   getNextToken();
    573 
    574   // Make the module, which holds all the code.
    575   TheModule = new Module("my cool jit", Context);
    576 
    577   // Create the JIT.  This takes ownership of the module.
    578   std::string ErrStr;
    579   TheExecutionEngine = EngineBuilder(TheModule).setErrorStr(&ErrStr).create();
    580   if (!TheExecutionEngine) {
    581     fprintf(stderr, "Could not create ExecutionEngine: %s\n", ErrStr.c_str());
    582     exit(1);
    583   }
    584 
    585   FunctionPassManager OurFPM(TheModule);
    586 
    587   // Set up the optimizer pipeline.  Start with registering info about how the
    588   // target lays out data structures.
    589   TheModule->setDataLayout(TheExecutionEngine->getDataLayout());
    590   OurFPM.add(new DataLayoutPass(TheModule));
    591   // Provide basic AliasAnalysis support for GVN.
    592   OurFPM.add(createBasicAliasAnalysisPass());
    593   // Do simple "peephole" optimizations and bit-twiddling optzns.
    594   OurFPM.add(createInstructionCombiningPass());
    595   // Reassociate expressions.
    596   OurFPM.add(createReassociatePass());
    597   // Eliminate Common SubExpressions.
    598   OurFPM.add(createGVNPass());
    599   // Simplify the control flow graph (deleting unreachable blocks, etc).
    600   OurFPM.add(createCFGSimplificationPass());
    601 
    602   OurFPM.doInitialization();
    603 
    604   // Set the global so the code gen can use this.
    605   TheFPM = &OurFPM;
    606 
    607   // Run the main "interpreter loop" now.
    608   MainLoop();
    609 
    610   TheFPM = 0;
    611 
    612   // Print out all of the generated code.
    613   TheModule->dump();
    614 
    615   return 0;
    616 }
    617