Home | History | Annotate | Download | only in Chapter3
#include "llvm/DerivedTypes.h"
#include "llvm/LLVMContext.h"
#include "llvm/Module.h"
#include "llvm/Analysis/Verifier.h"
#include "llvm/Support/IRBuilder.h"
#include <cctype>
#include <cstdio>
#include <cstdlib>
#include <map>
#include <string>
#include <vector>
     10 using namespace llvm;
     11 
     12 //===----------------------------------------------------------------------===//
     13 // Lexer
     14 //===----------------------------------------------------------------------===//
     15 
     16 // The lexer returns tokens [0-255] if it is an unknown character, otherwise one
     17 // of these for known things.
     18 enum Token {
     19   tok_eof = -1,
     20 
     21   // commands
     22   tok_def = -2, tok_extern = -3,
     23 
     24   // primary
     25   tok_identifier = -4, tok_number = -5
     26 };
     27 
     28 static std::string IdentifierStr;  // Filled in if tok_identifier
     29 static double NumVal;              // Filled in if tok_number
     30 
     31 /// gettok - Return the next token from standard input.
     32 static int gettok() {
     33   static int LastChar = ' ';
     34 
     35   // Skip any whitespace.
     36   while (isspace(LastChar))
     37     LastChar = getchar();
     38 
     39   if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
     40     IdentifierStr = LastChar;
     41     while (isalnum((LastChar = getchar())))
     42       IdentifierStr += LastChar;
     43 
     44     if (IdentifierStr == "def") return tok_def;
     45     if (IdentifierStr == "extern") return tok_extern;
     46     return tok_identifier;
     47   }
     48 
     49   if (isdigit(LastChar) || LastChar == '.') {   // Number: [0-9.]+
     50     std::string NumStr;
     51     do {
     52       NumStr += LastChar;
     53       LastChar = getchar();
     54     } while (isdigit(LastChar) || LastChar == '.');
     55 
     56     NumVal = strtod(NumStr.c_str(), 0);
     57     return tok_number;
     58   }
     59 
     60   if (LastChar == '#') {
     61     // Comment until end of line.
     62     do LastChar = getchar();
     63     while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
     64 
     65     if (LastChar != EOF)
     66       return gettok();
     67   }
     68 
     69   // Check for end of file.  Don't eat the EOF.
     70   if (LastChar == EOF)
     71     return tok_eof;
     72 
     73   // Otherwise, just return the character as its ascii value.
     74   int ThisChar = LastChar;
     75   LastChar = getchar();
     76   return ThisChar;
     77 }
     78 
     79 //===----------------------------------------------------------------------===//
     80 // Abstract Syntax Tree (aka Parse Tree)
     81 //===----------------------------------------------------------------------===//
     82 
     83 /// ExprAST - Base class for all expression nodes.
     84 class ExprAST {
     85 public:
     86   virtual ~ExprAST() {}
     87   virtual Value *Codegen() = 0;
     88 };
     89 
     90 /// NumberExprAST - Expression class for numeric literals like "1.0".
     91 class NumberExprAST : public ExprAST {
     92   double Val;
     93 public:
     94   NumberExprAST(double val) : Val(val) {}
     95   virtual Value *Codegen();
     96 };
     97 
     98 /// VariableExprAST - Expression class for referencing a variable, like "a".
     99 class VariableExprAST : public ExprAST {
    100   std::string Name;
    101 public:
    102   VariableExprAST(const std::string &name) : Name(name) {}
    103   virtual Value *Codegen();
    104 };
    105 
    106 /// BinaryExprAST - Expression class for a binary operator.
    107 class BinaryExprAST : public ExprAST {
    108   char Op;
    109   ExprAST *LHS, *RHS;
    110 public:
    111   BinaryExprAST(char op, ExprAST *lhs, ExprAST *rhs)
    112     : Op(op), LHS(lhs), RHS(rhs) {}
    113   virtual Value *Codegen();
    114 };
    115 
    116 /// CallExprAST - Expression class for function calls.
    117 class CallExprAST : public ExprAST {
    118   std::string Callee;
    119   std::vector<ExprAST*> Args;
    120 public:
    121   CallExprAST(const std::string &callee, std::vector<ExprAST*> &args)
    122     : Callee(callee), Args(args) {}
    123   virtual Value *Codegen();
    124 };
    125 
    126 /// PrototypeAST - This class represents the "prototype" for a function,
    127 /// which captures its name, and its argument names (thus implicitly the number
    128 /// of arguments the function takes).
    129 class PrototypeAST {
    130   std::string Name;
    131   std::vector<std::string> Args;
    132 public:
    133   PrototypeAST(const std::string &name, const std::vector<std::string> &args)
    134     : Name(name), Args(args) {}
    135 
    136   Function *Codegen();
    137 };
    138 
    139 /// FunctionAST - This class represents a function definition itself.
    140 class FunctionAST {
    141   PrototypeAST *Proto;
    142   ExprAST *Body;
    143 public:
    144   FunctionAST(PrototypeAST *proto, ExprAST *body)
    145     : Proto(proto), Body(body) {}
    146 
    147   Function *Codegen();
    148 };
    149 
    150 //===----------------------------------------------------------------------===//
    151 // Parser
    152 //===----------------------------------------------------------------------===//
    153 
    154 /// CurTok/getNextToken - Provide a simple token buffer.  CurTok is the current
    155 /// token the parser is looking at.  getNextToken reads another token from the
    156 /// lexer and updates CurTok with its results.
    157 static int CurTok;
    158 static int getNextToken() {
    159   return CurTok = gettok();
    160 }
    161 
    162 /// BinopPrecedence - This holds the precedence for each binary operator that is
    163 /// defined.
    164 static std::map<char, int> BinopPrecedence;
    165 
    166 /// GetTokPrecedence - Get the precedence of the pending binary operator token.
    167 static int GetTokPrecedence() {
    168   if (!isascii(CurTok))
    169     return -1;
    170 
    171   // Make sure it's a declared binop.
    172   int TokPrec = BinopPrecedence[CurTok];
    173   if (TokPrec <= 0) return -1;
    174   return TokPrec;
    175 }
    176 
    177 /// Error* - These are little helper functions for error handling.
    178 ExprAST *Error(const char *Str) { fprintf(stderr, "Error: %s\n", Str);return 0;}
    179 PrototypeAST *ErrorP(const char *Str) { Error(Str); return 0; }
    180 FunctionAST *ErrorF(const char *Str) { Error(Str); return 0; }
    181 
    182 static ExprAST *ParseExpression();
    183 
    184 /// identifierexpr
    185 ///   ::= identifier
    186 ///   ::= identifier '(' expression* ')'
    187 static ExprAST *ParseIdentifierExpr() {
    188   std::string IdName = IdentifierStr;
    189 
    190   getNextToken();  // eat identifier.
    191 
    192   if (CurTok != '(') // Simple variable ref.
    193     return new VariableExprAST(IdName);
    194 
    195   // Call.
    196   getNextToken();  // eat (
    197   std::vector<ExprAST*> Args;
    198   if (CurTok != ')') {
    199     while (1) {
    200       ExprAST *Arg = ParseExpression();
    201       if (!Arg) return 0;
    202       Args.push_back(Arg);
    203 
    204       if (CurTok == ')') break;
    205 
    206       if (CurTok != ',')
    207         return Error("Expected ')' or ',' in argument list");
    208       getNextToken();
    209     }
    210   }
    211 
    212   // Eat the ')'.
    213   getNextToken();
    214 
    215   return new CallExprAST(IdName, Args);
    216 }
    217 
    218 /// numberexpr ::= number
    219 static ExprAST *ParseNumberExpr() {
    220   ExprAST *Result = new NumberExprAST(NumVal);
    221   getNextToken(); // consume the number
    222   return Result;
    223 }
    224 
    225 /// parenexpr ::= '(' expression ')'
    226 static ExprAST *ParseParenExpr() {
    227   getNextToken();  // eat (.
    228   ExprAST *V = ParseExpression();
    229   if (!V) return 0;
    230 
    231   if (CurTok != ')')
    232     return Error("expected ')'");
    233   getNextToken();  // eat ).
    234   return V;
    235 }
    236 
    237 /// primary
    238 ///   ::= identifierexpr
    239 ///   ::= numberexpr
    240 ///   ::= parenexpr
    241 static ExprAST *ParsePrimary() {
    242   switch (CurTok) {
    243   default: return Error("unknown token when expecting an expression");
    244   case tok_identifier: return ParseIdentifierExpr();
    245   case tok_number:     return ParseNumberExpr();
    246   case '(':            return ParseParenExpr();
    247   }
    248 }
    249 
    250 /// binoprhs
    251 ///   ::= ('+' primary)*
    252 static ExprAST *ParseBinOpRHS(int ExprPrec, ExprAST *LHS) {
    253   // If this is a binop, find its precedence.
    254   while (1) {
    255     int TokPrec = GetTokPrecedence();
    256 
    257     // If this is a binop that binds at least as tightly as the current binop,
    258     // consume it, otherwise we are done.
    259     if (TokPrec < ExprPrec)
    260       return LHS;
    261 
    262     // Okay, we know this is a binop.
    263     int BinOp = CurTok;
    264     getNextToken();  // eat binop
    265 
    266     // Parse the primary expression after the binary operator.
    267     ExprAST *RHS = ParsePrimary();
    268     if (!RHS) return 0;
    269 
    270     // If BinOp binds less tightly with RHS than the operator after RHS, let
    271     // the pending operator take RHS as its LHS.
    272     int NextPrec = GetTokPrecedence();
    273     if (TokPrec < NextPrec) {
    274       RHS = ParseBinOpRHS(TokPrec+1, RHS);
    275       if (RHS == 0) return 0;
    276     }
    277 
    278     // Merge LHS/RHS.
    279     LHS = new BinaryExprAST(BinOp, LHS, RHS);
    280   }
    281 }
    282 
    283 /// expression
    284 ///   ::= primary binoprhs
    285 ///
    286 static ExprAST *ParseExpression() {
    287   ExprAST *LHS = ParsePrimary();
    288   if (!LHS) return 0;
    289 
    290   return ParseBinOpRHS(0, LHS);
    291 }
    292 
    293 /// prototype
    294 ///   ::= id '(' id* ')'
    295 static PrototypeAST *ParsePrototype() {
    296   if (CurTok != tok_identifier)
    297     return ErrorP("Expected function name in prototype");
    298 
    299   std::string FnName = IdentifierStr;
    300   getNextToken();
    301 
    302   if (CurTok != '(')
    303     return ErrorP("Expected '(' in prototype");
    304 
    305   std::vector<std::string> ArgNames;
    306   while (getNextToken() == tok_identifier)
    307     ArgNames.push_back(IdentifierStr);
    308   if (CurTok != ')')
    309     return ErrorP("Expected ')' in prototype");
    310 
    311   // success.
    312   getNextToken();  // eat ')'.
    313 
    314   return new PrototypeAST(FnName, ArgNames);
    315 }
    316 
    317 /// definition ::= 'def' prototype expression
    318 static FunctionAST *ParseDefinition() {
    319   getNextToken();  // eat def.
    320   PrototypeAST *Proto = ParsePrototype();
    321   if (Proto == 0) return 0;
    322 
    323   if (ExprAST *E = ParseExpression())
    324     return new FunctionAST(Proto, E);
    325   return 0;
    326 }
    327 
    328 /// toplevelexpr ::= expression
    329 static FunctionAST *ParseTopLevelExpr() {
    330   if (ExprAST *E = ParseExpression()) {
    331     // Make an anonymous proto.
    332     PrototypeAST *Proto = new PrototypeAST("", std::vector<std::string>());
    333     return new FunctionAST(Proto, E);
    334   }
    335   return 0;
    336 }
    337 
    338 /// external ::= 'extern' prototype
    339 static PrototypeAST *ParseExtern() {
    340   getNextToken();  // eat extern.
    341   return ParsePrototype();
    342 }
    343 
    344 //===----------------------------------------------------------------------===//
    345 // Code Generation
    346 //===----------------------------------------------------------------------===//
    347 
    348 static Module *TheModule;
    349 static IRBuilder<> Builder(getGlobalContext());
    350 static std::map<std::string, Value*> NamedValues;
    351 
    352 Value *ErrorV(const char *Str) { Error(Str); return 0; }
    353 
    354 Value *NumberExprAST::Codegen() {
    355   return ConstantFP::get(getGlobalContext(), APFloat(Val));
    356 }
    357 
    358 Value *VariableExprAST::Codegen() {
    359   // Look this variable up in the function.
    360   Value *V = NamedValues[Name];
    361   return V ? V : ErrorV("Unknown variable name");
    362 }
    363 
    364 Value *BinaryExprAST::Codegen() {
    365   Value *L = LHS->Codegen();
    366   Value *R = RHS->Codegen();
    367   if (L == 0 || R == 0) return 0;
    368 
    369   switch (Op) {
    370   case '+': return Builder.CreateFAdd(L, R, "addtmp");
    371   case '-': return Builder.CreateFSub(L, R, "subtmp");
    372   case '*': return Builder.CreateFMul(L, R, "multmp");
    373   case '<':
    374     L = Builder.CreateFCmpULT(L, R, "cmptmp");
    375     // Convert bool 0/1 to double 0.0 or 1.0
    376     return Builder.CreateUIToFP(L, Type::getDoubleTy(getGlobalContext()),
    377                                 "booltmp");
    378   default: return ErrorV("invalid binary operator");
    379   }
    380 }
    381 
    382 Value *CallExprAST::Codegen() {
    383   // Look up the name in the global module table.
    384   Function *CalleeF = TheModule->getFunction(Callee);
    385   if (CalleeF == 0)
    386     return ErrorV("Unknown function referenced");
    387 
    388   // If argument mismatch error.
    389   if (CalleeF->arg_size() != Args.size())
    390     return ErrorV("Incorrect # arguments passed");
    391 
    392   std::vector<Value*> ArgsV;
    393   for (unsigned i = 0, e = Args.size(); i != e; ++i) {
    394     ArgsV.push_back(Args[i]->Codegen());
    395     if (ArgsV.back() == 0) return 0;
    396   }
    397 
    398   return Builder.CreateCall(CalleeF, ArgsV, "calltmp");
    399 }
    400 
    401 Function *PrototypeAST::Codegen() {
    402   // Make the function type:  double(double,double) etc.
    403   std::vector<Type*> Doubles(Args.size(),
    404                              Type::getDoubleTy(getGlobalContext()));
    405   FunctionType *FT = FunctionType::get(Type::getDoubleTy(getGlobalContext()),
    406                                        Doubles, false);
    407 
    408   Function *F = Function::Create(FT, Function::ExternalLinkage, Name, TheModule);
    409 
    410   // If F conflicted, there was already something named 'Name'.  If it has a
    411   // body, don't allow redefinition or reextern.
    412   if (F->getName() != Name) {
    413     // Delete the one we just made and get the existing one.
    414     F->eraseFromParent();
    415     F = TheModule->getFunction(Name);
    416 
    417     // If F already has a body, reject this.
    418     if (!F->empty()) {
    419       ErrorF("redefinition of function");
    420       return 0;
    421     }
    422 
    423     // If F took a different number of args, reject.
    424     if (F->arg_size() != Args.size()) {
    425       ErrorF("redefinition of function with different # args");
    426       return 0;
    427     }
    428   }
    429 
    430   // Set names for all arguments.
    431   unsigned Idx = 0;
    432   for (Function::arg_iterator AI = F->arg_begin(); Idx != Args.size();
    433        ++AI, ++Idx) {
    434     AI->setName(Args[Idx]);
    435 
    436     // Add arguments to variable symbol table.
    437     NamedValues[Args[Idx]] = AI;
    438   }
    439 
    440   return F;
    441 }
    442 
    443 Function *FunctionAST::Codegen() {
    444   NamedValues.clear();
    445 
    446   Function *TheFunction = Proto->Codegen();
    447   if (TheFunction == 0)
    448     return 0;
    449 
    450   // Create a new basic block to start insertion into.
    451   BasicBlock *BB = BasicBlock::Create(getGlobalContext(), "entry", TheFunction);
    452   Builder.SetInsertPoint(BB);
    453 
    454   if (Value *RetVal = Body->Codegen()) {
    455     // Finish off the function.
    456     Builder.CreateRet(RetVal);
    457 
    458     // Validate the generated code, checking for consistency.
    459     verifyFunction(*TheFunction);
    460 
    461     return TheFunction;
    462   }
    463 
    464   // Error reading body, remove function.
    465   TheFunction->eraseFromParent();
    466   return 0;
    467 }
    468 
    469 //===----------------------------------------------------------------------===//
    470 // Top-Level parsing and JIT Driver
    471 //===----------------------------------------------------------------------===//
    472 
    473 static void HandleDefinition() {
    474   if (FunctionAST *F = ParseDefinition()) {
    475     if (Function *LF = F->Codegen()) {
    476       fprintf(stderr, "Read function definition:");
    477       LF->dump();
    478     }
    479   } else {
    480     // Skip token for error recovery.
    481     getNextToken();
    482   }
    483 }
    484 
    485 static void HandleExtern() {
    486   if (PrototypeAST *P = ParseExtern()) {
    487     if (Function *F = P->Codegen()) {
    488       fprintf(stderr, "Read extern: ");
    489       F->dump();
    490     }
    491   } else {
    492     // Skip token for error recovery.
    493     getNextToken();
    494   }
    495 }
    496 
    497 static void HandleTopLevelExpression() {
    498   // Evaluate a top-level expression into an anonymous function.
    499   if (FunctionAST *F = ParseTopLevelExpr()) {
    500     if (Function *LF = F->Codegen()) {
    501       fprintf(stderr, "Read top-level expression:");
    502       LF->dump();
    503     }
    504   } else {
    505     // Skip token for error recovery.
    506     getNextToken();
    507   }
    508 }
    509 
    510 /// top ::= definition | external | expression | ';'
    511 static void MainLoop() {
    512   while (1) {
    513     fprintf(stderr, "ready> ");
    514     switch (CurTok) {
    515     case tok_eof:    return;
    516     case ';':        getNextToken(); break;  // ignore top-level semicolons.
    517     case tok_def:    HandleDefinition(); break;
    518     case tok_extern: HandleExtern(); break;
    519     default:         HandleTopLevelExpression(); break;
    520     }
    521   }
    522 }
    523 
    524 //===----------------------------------------------------------------------===//
    525 // "Library" functions that can be "extern'd" from user code.
    526 //===----------------------------------------------------------------------===//
    527 
    528 /// putchard - putchar that takes a double and returns 0.
    529 extern "C"
    530 double putchard(double X) {
    531   putchar((char)X);
    532   return 0;
    533 }
    534 
    535 //===----------------------------------------------------------------------===//
    536 // Main driver code.
    537 //===----------------------------------------------------------------------===//
    538 
    539 int main() {
    540   LLVMContext &Context = getGlobalContext();
    541 
    542   // Install standard binary operators.
    543   // 1 is lowest precedence.
    544   BinopPrecedence['<'] = 10;
    545   BinopPrecedence['+'] = 20;
    546   BinopPrecedence['-'] = 20;
    547   BinopPrecedence['*'] = 40;  // highest.
    548 
    549   // Prime the first token.
    550   fprintf(stderr, "ready> ");
    551   getNextToken();
    552 
    553   // Make the module, which holds all the code.
    554   TheModule = new Module("my cool jit", Context);
    555 
    556   // Run the main "interpreter loop" now.
    557   MainLoop();
    558 
    559   // Print out all of the generated code.
    560   TheModule->dump();
    561 
    562   return 0;
    563 }
    564