Home | History | Annotate | Download | only in Chapter3
      1 #include "llvm/ADT/APFloat.h"
      2 #include "llvm/ADT/STLExtras.h"
      3 #include "llvm/IR/BasicBlock.h"
      4 #include "llvm/IR/Constants.h"
      5 #include "llvm/IR/DerivedTypes.h"
      6 #include "llvm/IR/Function.h"
      7 #include "llvm/IR/IRBuilder.h"
      8 #include "llvm/IR/LLVMContext.h"
      9 #include "llvm/IR/Module.h"
     10 #include "llvm/IR/Type.h"
     11 #include "llvm/IR/Verifier.h"
     12 #include <cctype>
     13 #include <cstdio>
     14 #include <cstdlib>
     15 #include <map>
     16 #include <memory>
     17 #include <string>
     18 #include <vector>
     19 
     20 using namespace llvm;
     21 
     22 //===----------------------------------------------------------------------===//
     23 // Lexer
     24 //===----------------------------------------------------------------------===//
     25 
     26 // The lexer returns tokens [0-255] if it is an unknown character, otherwise one
     27 // of these for known things.
     28 enum Token {
     29   tok_eof = -1,
     30 
     31   // commands
     32   tok_def = -2,
     33   tok_extern = -3,
     34 
     35   // primary
     36   tok_identifier = -4,
     37   tok_number = -5
     38 };
     39 
     40 static std::string IdentifierStr; // Filled in if tok_identifier
     41 static double NumVal;             // Filled in if tok_number
     42 
     43 /// gettok - Return the next token from standard input.
     44 static int gettok() {
     45   static int LastChar = ' ';
     46 
     47   // Skip any whitespace.
     48   while (isspace(LastChar))
     49     LastChar = getchar();
     50 
     51   if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
     52     IdentifierStr = LastChar;
     53     while (isalnum((LastChar = getchar())))
     54       IdentifierStr += LastChar;
     55 
     56     if (IdentifierStr == "def")
     57       return tok_def;
     58     if (IdentifierStr == "extern")
     59       return tok_extern;
     60     return tok_identifier;
     61   }
     62 
     63   if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
     64     std::string NumStr;
     65     do {
     66       NumStr += LastChar;
     67       LastChar = getchar();
     68     } while (isdigit(LastChar) || LastChar == '.');
     69 
     70     NumVal = strtod(NumStr.c_str(), nullptr);
     71     return tok_number;
     72   }
     73 
     74   if (LastChar == '#') {
     75     // Comment until end of line.
     76     do
     77       LastChar = getchar();
     78     while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
     79 
     80     if (LastChar != EOF)
     81       return gettok();
     82   }
     83 
     84   // Check for end of file.  Don't eat the EOF.
     85   if (LastChar == EOF)
     86     return tok_eof;
     87 
     88   // Otherwise, just return the character as its ascii value.
     89   int ThisChar = LastChar;
     90   LastChar = getchar();
     91   return ThisChar;
     92 }
     93 
     94 //===----------------------------------------------------------------------===//
     95 // Abstract Syntax Tree (aka Parse Tree)
     96 //===----------------------------------------------------------------------===//
     97 namespace {
     98 /// ExprAST - Base class for all expression nodes.
     99 class ExprAST {
    100 public:
    101   virtual ~ExprAST() {}
    102   virtual Value *codegen() = 0;
    103 };
    104 
    105 /// NumberExprAST - Expression class for numeric literals like "1.0".
    106 class NumberExprAST : public ExprAST {
    107   double Val;
    108 
    109 public:
    110   NumberExprAST(double Val) : Val(Val) {}
    111   Value *codegen() override;
    112 };
    113 
    114 /// VariableExprAST - Expression class for referencing a variable, like "a".
    115 class VariableExprAST : public ExprAST {
    116   std::string Name;
    117 
    118 public:
    119   VariableExprAST(const std::string &Name) : Name(Name) {}
    120   Value *codegen() override;
    121 };
    122 
    123 /// BinaryExprAST - Expression class for a binary operator.
    124 class BinaryExprAST : public ExprAST {
    125   char Op;
    126   std::unique_ptr<ExprAST> LHS, RHS;
    127 
    128 public:
    129   BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
    130                 std::unique_ptr<ExprAST> RHS)
    131       : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
    132   Value *codegen() override;
    133 };
    134 
    135 /// CallExprAST - Expression class for function calls.
    136 class CallExprAST : public ExprAST {
    137   std::string Callee;
    138   std::vector<std::unique_ptr<ExprAST>> Args;
    139 
    140 public:
    141   CallExprAST(const std::string &Callee,
    142               std::vector<std::unique_ptr<ExprAST>> Args)
    143       : Callee(Callee), Args(std::move(Args)) {}
    144   Value *codegen() override;
    145 };
    146 
    147 /// PrototypeAST - This class represents the "prototype" for a function,
    148 /// which captures its name, and its argument names (thus implicitly the number
    149 /// of arguments the function takes).
    150 class PrototypeAST {
    151   std::string Name;
    152   std::vector<std::string> Args;
    153 
    154 public:
    155   PrototypeAST(const std::string &Name, std::vector<std::string> Args)
    156       : Name(Name), Args(std::move(Args)) {}
    157   Function *codegen();
    158   const std::string &getName() const { return Name; }
    159 };
    160 
    161 /// FunctionAST - This class represents a function definition itself.
    162 class FunctionAST {
    163   std::unique_ptr<PrototypeAST> Proto;
    164   std::unique_ptr<ExprAST> Body;
    165 
    166 public:
    167   FunctionAST(std::unique_ptr<PrototypeAST> Proto,
    168               std::unique_ptr<ExprAST> Body)
    169       : Proto(std::move(Proto)), Body(std::move(Body)) {}
    170   Function *codegen();
    171 };
    172 } // end anonymous namespace
    173 
    174 //===----------------------------------------------------------------------===//
    175 // Parser
    176 //===----------------------------------------------------------------------===//
    177 
    178 /// CurTok/getNextToken - Provide a simple token buffer.  CurTok is the current
    179 /// token the parser is looking at.  getNextToken reads another token from the
    180 /// lexer and updates CurTok with its results.
    181 static int CurTok;
    182 static int getNextToken() { return CurTok = gettok(); }
    183 
    184 /// BinopPrecedence - This holds the precedence for each binary operator that is
    185 /// defined.
    186 static std::map<char, int> BinopPrecedence;
    187 
    188 /// GetTokPrecedence - Get the precedence of the pending binary operator token.
    189 static int GetTokPrecedence() {
    190   if (!isascii(CurTok))
    191     return -1;
    192 
    193   // Make sure it's a declared binop.
    194   int TokPrec = BinopPrecedence[CurTok];
    195   if (TokPrec <= 0)
    196     return -1;
    197   return TokPrec;
    198 }
    199 
    200 /// LogError* - These are little helper functions for error handling.
    201 std::unique_ptr<ExprAST> LogError(const char *Str) {
    202   fprintf(stderr, "Error: %s\n", Str);
    203   return nullptr;
    204 }
    205 
    206 std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) {
    207   LogError(Str);
    208   return nullptr;
    209 }
    210 
    211 static std::unique_ptr<ExprAST> ParseExpression();
    212 
    213 /// numberexpr ::= number
    214 static std::unique_ptr<ExprAST> ParseNumberExpr() {
    215   auto Result = llvm::make_unique<NumberExprAST>(NumVal);
    216   getNextToken(); // consume the number
    217   return std::move(Result);
    218 }
    219 
    220 /// parenexpr ::= '(' expression ')'
    221 static std::unique_ptr<ExprAST> ParseParenExpr() {
    222   getNextToken(); // eat (.
    223   auto V = ParseExpression();
    224   if (!V)
    225     return nullptr;
    226 
    227   if (CurTok != ')')
    228     return LogError("expected ')'");
    229   getNextToken(); // eat ).
    230   return V;
    231 }
    232 
    233 /// identifierexpr
    234 ///   ::= identifier
    235 ///   ::= identifier '(' expression* ')'
    236 static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
    237   std::string IdName = IdentifierStr;
    238 
    239   getNextToken(); // eat identifier.
    240 
    241   if (CurTok != '(') // Simple variable ref.
    242     return llvm::make_unique<VariableExprAST>(IdName);
    243 
    244   // Call.
    245   getNextToken(); // eat (
    246   std::vector<std::unique_ptr<ExprAST>> Args;
    247   if (CurTok != ')') {
    248     while (true) {
    249       if (auto Arg = ParseExpression())
    250         Args.push_back(std::move(Arg));
    251       else
    252         return nullptr;
    253 
    254       if (CurTok == ')')
    255         break;
    256 
    257       if (CurTok != ',')
    258         return LogError("Expected ')' or ',' in argument list");
    259       getNextToken();
    260     }
    261   }
    262 
    263   // Eat the ')'.
    264   getNextToken();
    265 
    266   return llvm::make_unique<CallExprAST>(IdName, std::move(Args));
    267 }
    268 
    269 /// primary
    270 ///   ::= identifierexpr
    271 ///   ::= numberexpr
    272 ///   ::= parenexpr
    273 static std::unique_ptr<ExprAST> ParsePrimary() {
    274   switch (CurTok) {
    275   default:
    276     return LogError("unknown token when expecting an expression");
    277   case tok_identifier:
    278     return ParseIdentifierExpr();
    279   case tok_number:
    280     return ParseNumberExpr();
    281   case '(':
    282     return ParseParenExpr();
    283   }
    284 }
    285 
    286 /// binoprhs
    287 ///   ::= ('+' primary)*
    288 static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
    289                                               std::unique_ptr<ExprAST> LHS) {
    290   // If this is a binop, find its precedence.
    291   while (true) {
    292     int TokPrec = GetTokPrecedence();
    293 
    294     // If this is a binop that binds at least as tightly as the current binop,
    295     // consume it, otherwise we are done.
    296     if (TokPrec < ExprPrec)
    297       return LHS;
    298 
    299     // Okay, we know this is a binop.
    300     int BinOp = CurTok;
    301     getNextToken(); // eat binop
    302 
    303     // Parse the primary expression after the binary operator.
    304     auto RHS = ParsePrimary();
    305     if (!RHS)
    306       return nullptr;
    307 
    308     // If BinOp binds less tightly with RHS than the operator after RHS, let
    309     // the pending operator take RHS as its LHS.
    310     int NextPrec = GetTokPrecedence();
    311     if (TokPrec < NextPrec) {
    312       RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
    313       if (!RHS)
    314         return nullptr;
    315     }
    316 
    317     // Merge LHS/RHS.
    318     LHS =
    319         llvm::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
    320   }
    321 }
    322 
    323 /// expression
    324 ///   ::= primary binoprhs
    325 ///
    326 static std::unique_ptr<ExprAST> ParseExpression() {
    327   auto LHS = ParsePrimary();
    328   if (!LHS)
    329     return nullptr;
    330 
    331   return ParseBinOpRHS(0, std::move(LHS));
    332 }
    333 
    334 /// prototype
    335 ///   ::= id '(' id* ')'
    336 static std::unique_ptr<PrototypeAST> ParsePrototype() {
    337   if (CurTok != tok_identifier)
    338     return LogErrorP("Expected function name in prototype");
    339 
    340   std::string FnName = IdentifierStr;
    341   getNextToken();
    342 
    343   if (CurTok != '(')
    344     return LogErrorP("Expected '(' in prototype");
    345 
    346   std::vector<std::string> ArgNames;
    347   while (getNextToken() == tok_identifier)
    348     ArgNames.push_back(IdentifierStr);
    349   if (CurTok != ')')
    350     return LogErrorP("Expected ')' in prototype");
    351 
    352   // success.
    353   getNextToken(); // eat ')'.
    354 
    355   return llvm::make_unique<PrototypeAST>(FnName, std::move(ArgNames));
    356 }
    357 
    358 /// definition ::= 'def' prototype expression
    359 static std::unique_ptr<FunctionAST> ParseDefinition() {
    360   getNextToken(); // eat def.
    361   auto Proto = ParsePrototype();
    362   if (!Proto)
    363     return nullptr;
    364 
    365   if (auto E = ParseExpression())
    366     return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
    367   return nullptr;
    368 }
    369 
    370 /// toplevelexpr ::= expression
    371 static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
    372   if (auto E = ParseExpression()) {
    373     // Make an anonymous proto.
    374     auto Proto = llvm::make_unique<PrototypeAST>("__anon_expr",
    375                                                  std::vector<std::string>());
    376     return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
    377   }
    378   return nullptr;
    379 }
    380 
    381 /// external ::= 'extern' prototype
    382 static std::unique_ptr<PrototypeAST> ParseExtern() {
    383   getNextToken(); // eat extern.
    384   return ParsePrototype();
    385 }
    386 
    387 //===----------------------------------------------------------------------===//
    388 // Code Generation
    389 //===----------------------------------------------------------------------===//
    390 
    391 static LLVMContext TheContext;
    392 static IRBuilder<> Builder(TheContext);
    393 static std::unique_ptr<Module> TheModule;
    394 static std::map<std::string, Value *> NamedValues;
    395 
    396 Value *LogErrorV(const char *Str) {
    397   LogError(Str);
    398   return nullptr;
    399 }
    400 
    401 Value *NumberExprAST::codegen() {
    402   return ConstantFP::get(TheContext, APFloat(Val));
    403 }
    404 
    405 Value *VariableExprAST::codegen() {
    406   // Look this variable up in the function.
    407   Value *V = NamedValues[Name];
    408   if (!V)
    409     return LogErrorV("Unknown variable name");
    410   return V;
    411 }
    412 
    413 Value *BinaryExprAST::codegen() {
    414   Value *L = LHS->codegen();
    415   Value *R = RHS->codegen();
    416   if (!L || !R)
    417     return nullptr;
    418 
    419   switch (Op) {
    420   case '+':
    421     return Builder.CreateFAdd(L, R, "addtmp");
    422   case '-':
    423     return Builder.CreateFSub(L, R, "subtmp");
    424   case '*':
    425     return Builder.CreateFMul(L, R, "multmp");
    426   case '<':
    427     L = Builder.CreateFCmpULT(L, R, "cmptmp");
    428     // Convert bool 0/1 to double 0.0 or 1.0
    429     return Builder.CreateUIToFP(L, Type::getDoubleTy(TheContext), "booltmp");
    430   default:
    431     return LogErrorV("invalid binary operator");
    432   }
    433 }
    434 
    435 Value *CallExprAST::codegen() {
    436   // Look up the name in the global module table.
    437   Function *CalleeF = TheModule->getFunction(Callee);
    438   if (!CalleeF)
    439     return LogErrorV("Unknown function referenced");
    440 
    441   // If argument mismatch error.
    442   if (CalleeF->arg_size() != Args.size())
    443     return LogErrorV("Incorrect # arguments passed");
    444 
    445   std::vector<Value *> ArgsV;
    446   for (unsigned i = 0, e = Args.size(); i != e; ++i) {
    447     ArgsV.push_back(Args[i]->codegen());
    448     if (!ArgsV.back())
    449       return nullptr;
    450   }
    451 
    452   return Builder.CreateCall(CalleeF, ArgsV, "calltmp");
    453 }
    454 
    455 Function *PrototypeAST::codegen() {
    456   // Make the function type:  double(double,double) etc.
    457   std::vector<Type *> Doubles(Args.size(), Type::getDoubleTy(TheContext));
    458   FunctionType *FT =
    459       FunctionType::get(Type::getDoubleTy(TheContext), Doubles, false);
    460 
    461   Function *F =
    462       Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
    463 
    464   // Set names for all arguments.
    465   unsigned Idx = 0;
    466   for (auto &Arg : F->args())
    467     Arg.setName(Args[Idx++]);
    468 
    469   return F;
    470 }
    471 
    472 Function *FunctionAST::codegen() {
    473   // First, check for an existing function from a previous 'extern' declaration.
    474   Function *TheFunction = TheModule->getFunction(Proto->getName());
    475 
    476   if (!TheFunction)
    477     TheFunction = Proto->codegen();
    478 
    479   if (!TheFunction)
    480     return nullptr;
    481 
    482   // Create a new basic block to start insertion into.
    483   BasicBlock *BB = BasicBlock::Create(TheContext, "entry", TheFunction);
    484   Builder.SetInsertPoint(BB);
    485 
    486   // Record the function arguments in the NamedValues map.
    487   NamedValues.clear();
    488   for (auto &Arg : TheFunction->args())
    489     NamedValues[Arg.getName()] = &Arg;
    490 
    491   if (Value *RetVal = Body->codegen()) {
    492     // Finish off the function.
    493     Builder.CreateRet(RetVal);
    494 
    495     // Validate the generated code, checking for consistency.
    496     verifyFunction(*TheFunction);
    497 
    498     return TheFunction;
    499   }
    500 
    501   // Error reading body, remove function.
    502   TheFunction->eraseFromParent();
    503   return nullptr;
    504 }
    505 
    506 //===----------------------------------------------------------------------===//
    507 // Top-Level parsing and JIT Driver
    508 //===----------------------------------------------------------------------===//
    509 
    510 static void HandleDefinition() {
    511   if (auto FnAST = ParseDefinition()) {
    512     if (auto *FnIR = FnAST->codegen()) {
    513       fprintf(stderr, "Read function definition:");
    514       FnIR->dump();
    515     }
    516   } else {
    517     // Skip token for error recovery.
    518     getNextToken();
    519   }
    520 }
    521 
    522 static void HandleExtern() {
    523   if (auto ProtoAST = ParseExtern()) {
    524     if (auto *FnIR = ProtoAST->codegen()) {
    525       fprintf(stderr, "Read extern: ");
    526       FnIR->dump();
    527     }
    528   } else {
    529     // Skip token for error recovery.
    530     getNextToken();
    531   }
    532 }
    533 
    534 static void HandleTopLevelExpression() {
    535   // Evaluate a top-level expression into an anonymous function.
    536   if (auto FnAST = ParseTopLevelExpr()) {
    537     if (auto *FnIR = FnAST->codegen()) {
    538       fprintf(stderr, "Read top-level expression:");
    539       FnIR->dump();
    540     }
    541   } else {
    542     // Skip token for error recovery.
    543     getNextToken();
    544   }
    545 }
    546 
    547 /// top ::= definition | external | expression | ';'
    548 static void MainLoop() {
    549   while (true) {
    550     fprintf(stderr, "ready> ");
    551     switch (CurTok) {
    552     case tok_eof:
    553       return;
    554     case ';': // ignore top-level semicolons.
    555       getNextToken();
    556       break;
    557     case tok_def:
    558       HandleDefinition();
    559       break;
    560     case tok_extern:
    561       HandleExtern();
    562       break;
    563     default:
    564       HandleTopLevelExpression();
    565       break;
    566     }
    567   }
    568 }
    569 
    570 //===----------------------------------------------------------------------===//
    571 // Main driver code.
    572 //===----------------------------------------------------------------------===//
    573 
    574 int main() {
    575   // Install standard binary operators.
    576   // 1 is lowest precedence.
    577   BinopPrecedence['<'] = 10;
    578   BinopPrecedence['+'] = 20;
    579   BinopPrecedence['-'] = 20;
    580   BinopPrecedence['*'] = 40; // highest.
    581 
    582   // Prime the first token.
    583   fprintf(stderr, "ready> ");
    584   getNextToken();
    585 
    586   // Make the module, which holds all the code.
    587   TheModule = llvm::make_unique<Module>("my cool jit", TheContext);
    588 
    589   // Run the main "interpreter loop" now.
    590   MainLoop();
    591 
    592   // Print out all of the generated code.
    593   TheModule->dump();
    594 
    595   return 0;
    596 }
    597