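// Home | History | Annotate | Download | only in Chapter4
// (The line above appears to be a navigation bar left over from the page this
// listing was copied from; it is not part of the program.)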
      1 #include "llvm/ADT/STLExtras.h"
      2 #include "llvm/Analysis/Passes.h"
      3 #include "llvm/IR/IRBuilder.h"
      4 #include "llvm/IR/LLVMContext.h"
      5 #include "llvm/IR/LegacyPassManager.h"
      6 #include "llvm/IR/Module.h"
      7 #include "llvm/IR/Verifier.h"
      8 #include "llvm/Support/TargetSelect.h"
      9 #include "llvm/Transforms/Scalar.h"
     10 #include <cctype>
     11 #include <cstdio>
     12 #include <map>
     13 #include <string>
     14 #include <vector>
     15 #include "../include/KaleidoscopeJIT.h"
     16 
     17 using namespace llvm;
     18 using namespace llvm::orc;
     19 
     20 //===----------------------------------------------------------------------===//
     21 // Lexer
     22 //===----------------------------------------------------------------------===//
     23 
     24 // The lexer returns tokens [0-255] if it is an unknown character, otherwise one
     25 // of these for known things.
     26 enum Token {
     27   tok_eof = -1,
     28 
     29   // commands
     30   tok_def = -2,
     31   tok_extern = -3,
     32 
     33   // primary
     34   tok_identifier = -4,
     35   tok_number = -5
     36 };
     37 
     38 static std::string IdentifierStr; // Filled in if tok_identifier
     39 static double NumVal;             // Filled in if tok_number
     40 
     41 /// gettok - Return the next token from standard input.
     42 static int gettok() {
     43   static int LastChar = ' ';
     44 
     45   // Skip any whitespace.
     46   while (isspace(LastChar))
     47     LastChar = getchar();
     48 
     49   if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
     50     IdentifierStr = LastChar;
     51     while (isalnum((LastChar = getchar())))
     52       IdentifierStr += LastChar;
     53 
     54     if (IdentifierStr == "def")
     55       return tok_def;
     56     if (IdentifierStr == "extern")
     57       return tok_extern;
     58     return tok_identifier;
     59   }
     60 
     61   if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
     62     std::string NumStr;
     63     do {
     64       NumStr += LastChar;
     65       LastChar = getchar();
     66     } while (isdigit(LastChar) || LastChar == '.');
     67 
     68     NumVal = strtod(NumStr.c_str(), nullptr);
     69     return tok_number;
     70   }
     71 
     72   if (LastChar == '#') {
     73     // Comment until end of line.
     74     do
     75       LastChar = getchar();
     76     while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
     77 
     78     if (LastChar != EOF)
     79       return gettok();
     80   }
     81 
     82   // Check for end of file.  Don't eat the EOF.
     83   if (LastChar == EOF)
     84     return tok_eof;
     85 
     86   // Otherwise, just return the character as its ascii value.
     87   int ThisChar = LastChar;
     88   LastChar = getchar();
     89   return ThisChar;
     90 }
     91 
     92 //===----------------------------------------------------------------------===//
     93 // Abstract Syntax Tree (aka Parse Tree)
     94 //===----------------------------------------------------------------------===//
     95 namespace {
     96 /// ExprAST - Base class for all expression nodes.
     97 class ExprAST {
     98 public:
     99   virtual ~ExprAST() {}
    100   virtual Value *codegen() = 0;
    101 };
    102 
    103 /// NumberExprAST - Expression class for numeric literals like "1.0".
    104 class NumberExprAST : public ExprAST {
    105   double Val;
    106 
    107 public:
    108   NumberExprAST(double Val) : Val(Val) {}
    109   Value *codegen() override;
    110 };
    111 
    112 /// VariableExprAST - Expression class for referencing a variable, like "a".
    113 class VariableExprAST : public ExprAST {
    114   std::string Name;
    115 
    116 public:
    117   VariableExprAST(const std::string &Name) : Name(Name) {}
    118   Value *codegen() override;
    119 };
    120 
    121 /// BinaryExprAST - Expression class for a binary operator.
    122 class BinaryExprAST : public ExprAST {
    123   char Op;
    124   std::unique_ptr<ExprAST> LHS, RHS;
    125 
    126 public:
    127   BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
    128                 std::unique_ptr<ExprAST> RHS)
    129       : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
    130   Value *codegen() override;
    131 };
    132 
    133 /// CallExprAST - Expression class for function calls.
    134 class CallExprAST : public ExprAST {
    135   std::string Callee;
    136   std::vector<std::unique_ptr<ExprAST>> Args;
    137 
    138 public:
    139   CallExprAST(const std::string &Callee,
    140               std::vector<std::unique_ptr<ExprAST>> Args)
    141       : Callee(Callee), Args(std::move(Args)) {}
    142   Value *codegen() override;
    143 };
    144 
    145 /// PrototypeAST - This class represents the "prototype" for a function,
    146 /// which captures its name, and its argument names (thus implicitly the number
    147 /// of arguments the function takes).
    148 class PrototypeAST {
    149   std::string Name;
    150   std::vector<std::string> Args;
    151 
    152 public:
    153   PrototypeAST(const std::string &Name, std::vector<std::string> Args)
    154       : Name(Name), Args(std::move(Args)) {}
    155   Function *codegen();
    156   const std::string &getName() const { return Name; }
    157 };
    158 
    159 /// FunctionAST - This class represents a function definition itself.
    160 class FunctionAST {
    161   std::unique_ptr<PrototypeAST> Proto;
    162   std::unique_ptr<ExprAST> Body;
    163 
    164 public:
    165   FunctionAST(std::unique_ptr<PrototypeAST> Proto,
    166               std::unique_ptr<ExprAST> Body)
    167       : Proto(std::move(Proto)), Body(std::move(Body)) {}
    168   Function *codegen();
    169 };
    170 } // end anonymous namespace
    171 
    172 //===----------------------------------------------------------------------===//
    173 // Parser
    174 //===----------------------------------------------------------------------===//
    175 
    176 /// CurTok/getNextToken - Provide a simple token buffer.  CurTok is the current
    177 /// token the parser is looking at.  getNextToken reads another token from the
    178 /// lexer and updates CurTok with its results.
    179 static int CurTok;
    180 static int getNextToken() { return CurTok = gettok(); }
    181 
    182 /// BinopPrecedence - This holds the precedence for each binary operator that is
    183 /// defined.
    184 static std::map<char, int> BinopPrecedence;
    185 
    186 /// GetTokPrecedence - Get the precedence of the pending binary operator token.
    187 static int GetTokPrecedence() {
    188   if (!isascii(CurTok))
    189     return -1;
    190 
    191   // Make sure it's a declared binop.
    192   int TokPrec = BinopPrecedence[CurTok];
    193   if (TokPrec <= 0)
    194     return -1;
    195   return TokPrec;
    196 }
    197 
    198 /// Error* - These are little helper functions for error handling.
    199 std::unique_ptr<ExprAST> Error(const char *Str) {
    200   fprintf(stderr, "Error: %s\n", Str);
    201   return nullptr;
    202 }
    203 
    204 std::unique_ptr<PrototypeAST> ErrorP(const char *Str) {
    205   Error(Str);
    206   return nullptr;
    207 }
    208 
    209 static std::unique_ptr<ExprAST> ParseExpression();
    210 
    211 /// numberexpr ::= number
    212 static std::unique_ptr<ExprAST> ParseNumberExpr() {
    213   auto Result = llvm::make_unique<NumberExprAST>(NumVal);
    214   getNextToken(); // consume the number
    215   return std::move(Result);
    216 }
    217 
    218 /// parenexpr ::= '(' expression ')'
    219 static std::unique_ptr<ExprAST> ParseParenExpr() {
    220   getNextToken(); // eat (.
    221   auto V = ParseExpression();
    222   if (!V)
    223     return nullptr;
    224 
    225   if (CurTok != ')')
    226     return Error("expected ')'");
    227   getNextToken(); // eat ).
    228   return V;
    229 }
    230 
    231 /// identifierexpr
    232 ///   ::= identifier
    233 ///   ::= identifier '(' expression* ')'
    234 static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
    235   std::string IdName = IdentifierStr;
    236 
    237   getNextToken(); // eat identifier.
    238 
    239   if (CurTok != '(') // Simple variable ref.
    240     return llvm::make_unique<VariableExprAST>(IdName);
    241 
    242   // Call.
    243   getNextToken(); // eat (
    244   std::vector<std::unique_ptr<ExprAST>> Args;
    245   if (CurTok != ')') {
    246     while (1) {
    247       if (auto Arg = ParseExpression())
    248         Args.push_back(std::move(Arg));
    249       else
    250         return nullptr;
    251 
    252       if (CurTok == ')')
    253         break;
    254 
    255       if (CurTok != ',')
    256         return Error("Expected ')' or ',' in argument list");
    257       getNextToken();
    258     }
    259   }
    260 
    261   // Eat the ')'.
    262   getNextToken();
    263 
    264   return llvm::make_unique<CallExprAST>(IdName, std::move(Args));
    265 }
    266 
    267 /// primary
    268 ///   ::= identifierexpr
    269 ///   ::= numberexpr
    270 ///   ::= parenexpr
    271 static std::unique_ptr<ExprAST> ParsePrimary() {
    272   switch (CurTok) {
    273   default:
    274     return Error("unknown token when expecting an expression");
    275   case tok_identifier:
    276     return ParseIdentifierExpr();
    277   case tok_number:
    278     return ParseNumberExpr();
    279   case '(':
    280     return ParseParenExpr();
    281   }
    282 }
    283 
    284 /// binoprhs
    285 ///   ::= ('+' primary)*
    286 static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
    287                                               std::unique_ptr<ExprAST> LHS) {
    288   // If this is a binop, find its precedence.
    289   while (1) {
    290     int TokPrec = GetTokPrecedence();
    291 
    292     // If this is a binop that binds at least as tightly as the current binop,
    293     // consume it, otherwise we are done.
    294     if (TokPrec < ExprPrec)
    295       return LHS;
    296 
    297     // Okay, we know this is a binop.
    298     int BinOp = CurTok;
    299     getNextToken(); // eat binop
    300 
    301     // Parse the primary expression after the binary operator.
    302     auto RHS = ParsePrimary();
    303     if (!RHS)
    304       return nullptr;
    305 
    306     // If BinOp binds less tightly with RHS than the operator after RHS, let
    307     // the pending operator take RHS as its LHS.
    308     int NextPrec = GetTokPrecedence();
    309     if (TokPrec < NextPrec) {
    310       RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
    311       if (!RHS)
    312         return nullptr;
    313     }
    314 
    315     // Merge LHS/RHS.
    316     LHS =
    317         llvm::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
    318   }
    319 }
    320 
    321 /// expression
    322 ///   ::= primary binoprhs
    323 ///
    324 static std::unique_ptr<ExprAST> ParseExpression() {
    325   auto LHS = ParsePrimary();
    326   if (!LHS)
    327     return nullptr;
    328 
    329   return ParseBinOpRHS(0, std::move(LHS));
    330 }
    331 
    332 /// prototype
    333 ///   ::= id '(' id* ')'
    334 static std::unique_ptr<PrototypeAST> ParsePrototype() {
    335   if (CurTok != tok_identifier)
    336     return ErrorP("Expected function name in prototype");
    337 
    338   std::string FnName = IdentifierStr;
    339   getNextToken();
    340 
    341   if (CurTok != '(')
    342     return ErrorP("Expected '(' in prototype");
    343 
    344   std::vector<std::string> ArgNames;
    345   while (getNextToken() == tok_identifier)
    346     ArgNames.push_back(IdentifierStr);
    347   if (CurTok != ')')
    348     return ErrorP("Expected ')' in prototype");
    349 
    350   // success.
    351   getNextToken(); // eat ')'.
    352 
    353   return llvm::make_unique<PrototypeAST>(FnName, std::move(ArgNames));
    354 }
    355 
    356 /// definition ::= 'def' prototype expression
    357 static std::unique_ptr<FunctionAST> ParseDefinition() {
    358   getNextToken(); // eat def.
    359   auto Proto = ParsePrototype();
    360   if (!Proto)
    361     return nullptr;
    362 
    363   if (auto E = ParseExpression())
    364     return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
    365   return nullptr;
    366 }
    367 
    368 /// toplevelexpr ::= expression
    369 static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
    370   if (auto E = ParseExpression()) {
    371     // Make an anonymous proto.
    372     auto Proto = llvm::make_unique<PrototypeAST>("__anon_expr",
    373                                                  std::vector<std::string>());
    374     return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
    375   }
    376   return nullptr;
    377 }
    378 
    379 /// external ::= 'extern' prototype
    380 static std::unique_ptr<PrototypeAST> ParseExtern() {
    381   getNextToken(); // eat extern.
    382   return ParsePrototype();
    383 }
    384 
    385 //===----------------------------------------------------------------------===//
    386 // Code Generation
    387 //===----------------------------------------------------------------------===//
    388 
    389 static std::unique_ptr<Module> TheModule;
    390 static IRBuilder<> Builder(getGlobalContext());
    391 static std::map<std::string, Value *> NamedValues;
    392 static std::unique_ptr<legacy::FunctionPassManager> TheFPM;
    393 static std::unique_ptr<KaleidoscopeJIT> TheJIT;
    394 static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
    395 
    396 Value *ErrorV(const char *Str) {
    397   Error(Str);
    398   return nullptr;
    399 }
    400 
    401 Function *getFunction(std::string Name) {
    402   // First, see if the function has already been added to the current module.
    403   if (auto *F = TheModule->getFunction(Name))
    404     return F;
    405 
    406   // If not, check whether we can codegen the declaration from some existing
    407   // prototype.
    408   auto FI = FunctionProtos.find(Name);
    409   if (FI != FunctionProtos.end())
    410     return FI->second->codegen();
    411 
    412   // If no existing prototype exists, return null.
    413   return nullptr;
    414 }
    415 
    416 Value *NumberExprAST::codegen() {
    417   return ConstantFP::get(getGlobalContext(), APFloat(Val));
    418 }
    419 
    420 Value *VariableExprAST::codegen() {
    421   // Look this variable up in the function.
    422   Value *V = NamedValues[Name];
    423   if (!V)
    424     return ErrorV("Unknown variable name");
    425   return V;
    426 }
    427 
    428 Value *BinaryExprAST::codegen() {
    429   Value *L = LHS->codegen();
    430   Value *R = RHS->codegen();
    431   if (!L || !R)
    432     return nullptr;
    433 
    434   switch (Op) {
    435   case '+':
    436     return Builder.CreateFAdd(L, R, "addtmp");
    437   case '-':
    438     return Builder.CreateFSub(L, R, "subtmp");
    439   case '*':
    440     return Builder.CreateFMul(L, R, "multmp");
    441   case '<':
    442     L = Builder.CreateFCmpULT(L, R, "cmptmp");
    443     // Convert bool 0/1 to double 0.0 or 1.0
    444     return Builder.CreateUIToFP(L, Type::getDoubleTy(getGlobalContext()),
    445                                 "booltmp");
    446   default:
    447     return ErrorV("invalid binary operator");
    448   }
    449 }
    450 
    451 Value *CallExprAST::codegen() {
    452   // Look up the name in the global module table.
    453   Function *CalleeF = getFunction(Callee);
    454   if (!CalleeF)
    455     return ErrorV("Unknown function referenced");
    456 
    457   // If argument mismatch error.
    458   if (CalleeF->arg_size() != Args.size())
    459     return ErrorV("Incorrect # arguments passed");
    460 
    461   std::vector<Value *> ArgsV;
    462   for (unsigned i = 0, e = Args.size(); i != e; ++i) {
    463     ArgsV.push_back(Args[i]->codegen());
    464     if (!ArgsV.back())
    465       return nullptr;
    466   }
    467 
    468   return Builder.CreateCall(CalleeF, ArgsV, "calltmp");
    469 }
    470 
    471 Function *PrototypeAST::codegen() {
    472   // Make the function type:  double(double,double) etc.
    473   std::vector<Type *> Doubles(Args.size(),
    474                               Type::getDoubleTy(getGlobalContext()));
    475   FunctionType *FT =
    476       FunctionType::get(Type::getDoubleTy(getGlobalContext()), Doubles, false);
    477 
    478   Function *F =
    479       Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
    480 
    481   // Set names for all arguments.
    482   unsigned Idx = 0;
    483   for (auto &Arg : F->args())
    484     Arg.setName(Args[Idx++]);
    485 
    486   return F;
    487 }
    488 
    489 Function *FunctionAST::codegen() {
    490   // Transfer ownership of the prototype to the FunctionProtos map, but keep a
    491   // reference to it for use below.
    492   auto &P = *Proto;
    493   FunctionProtos[Proto->getName()] = std::move(Proto);
    494   Function *TheFunction = getFunction(P.getName());
    495   if (!TheFunction)
    496     return nullptr;
    497 
    498   // Create a new basic block to start insertion into.
    499   BasicBlock *BB = BasicBlock::Create(getGlobalContext(), "entry", TheFunction);
    500   Builder.SetInsertPoint(BB);
    501 
    502   // Record the function arguments in the NamedValues map.
    503   NamedValues.clear();
    504   for (auto &Arg : TheFunction->args())
    505     NamedValues[Arg.getName()] = &Arg;
    506 
    507   if (Value *RetVal = Body->codegen()) {
    508     // Finish off the function.
    509     Builder.CreateRet(RetVal);
    510 
    511     // Validate the generated code, checking for consistency.
    512     verifyFunction(*TheFunction);
    513 
    514     // Run the optimizer on the function.
    515     TheFPM->run(*TheFunction);
    516 
    517     return TheFunction;
    518   }
    519 
    520   // Error reading body, remove function.
    521   TheFunction->eraseFromParent();
    522   return nullptr;
    523 }
    524 
    525 //===----------------------------------------------------------------------===//
    526 // Top-Level parsing and JIT Driver
    527 //===----------------------------------------------------------------------===//
    528 
    529 static void InitializeModuleAndPassManager() {
    530   // Open a new module.
    531   TheModule = llvm::make_unique<Module>("my cool jit", getGlobalContext());
    532   TheModule->setDataLayout(TheJIT->getTargetMachine().createDataLayout());
    533 
    534   // Create a new pass manager attached to it.
    535   TheFPM = llvm::make_unique<legacy::FunctionPassManager>(TheModule.get());
    536 
    537   // Do simple "peephole" optimizations and bit-twiddling optzns.
    538   TheFPM->add(createInstructionCombiningPass());
    539   // Reassociate expressions.
    540   TheFPM->add(createReassociatePass());
    541   // Eliminate Common SubExpressions.
    542   TheFPM->add(createGVNPass());
    543   // Simplify the control flow graph (deleting unreachable blocks, etc).
    544   TheFPM->add(createCFGSimplificationPass());
    545 
    546   TheFPM->doInitialization();
    547 }
    548 
    549 static void HandleDefinition() {
    550   if (auto FnAST = ParseDefinition()) {
    551     if (auto *FnIR = FnAST->codegen()) {
    552       fprintf(stderr, "Read function definition:");
    553       FnIR->dump();
    554       TheJIT->addModule(std::move(TheModule));
    555       InitializeModuleAndPassManager();
    556     }
    557   } else {
    558     // Skip token for error recovery.
    559     getNextToken();
    560   }
    561 }
    562 
    563 static void HandleExtern() {
    564   if (auto ProtoAST = ParseExtern()) {
    565     if (auto *FnIR = ProtoAST->codegen()) {
    566       fprintf(stderr, "Read extern: ");
    567       FnIR->dump();
    568       FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
    569     }
    570   } else {
    571     // Skip token for error recovery.
    572     getNextToken();
    573   }
    574 }
    575 
    576 static void HandleTopLevelExpression() {
    577   // Evaluate a top-level expression into an anonymous function.
    578   if (auto FnAST = ParseTopLevelExpr()) {
    579     if (FnAST->codegen()) {
    580 
    581       // JIT the module containing the anonymous expression, keeping a handle so
    582       // we can free it later.
    583       auto H = TheJIT->addModule(std::move(TheModule));
    584       InitializeModuleAndPassManager();
    585 
    586       // Search the JIT for the __anon_expr symbol.
    587       auto ExprSymbol = TheJIT->findSymbol("__anon_expr");
    588       assert(ExprSymbol && "Function not found");
    589 
    590       // Get the symbol's address and cast it to the right type (takes no
    591       // arguments, returns a double) so we can call it as a native function.
    592       double (*FP)() = (double (*)())(intptr_t)ExprSymbol.getAddress();
    593       fprintf(stderr, "Evaluated to %f\n", FP());
    594 
    595       // Delete the anonymous expression module from the JIT.
    596       TheJIT->removeModule(H);
    597     }
    598   } else {
    599     // Skip token for error recovery.
    600     getNextToken();
    601   }
    602 }
    603 
    604 /// top ::= definition | external | expression | ';'
    605 static void MainLoop() {
    606   while (1) {
    607     fprintf(stderr, "ready> ");
    608     switch (CurTok) {
    609     case tok_eof:
    610       return;
    611     case ';': // ignore top-level semicolons.
    612       getNextToken();
    613       break;
    614     case tok_def:
    615       HandleDefinition();
    616       break;
    617     case tok_extern:
    618       HandleExtern();
    619       break;
    620     default:
    621       HandleTopLevelExpression();
    622       break;
    623     }
    624   }
    625 }
    626 
    627 //===----------------------------------------------------------------------===//
    628 // "Library" functions that can be "extern'd" from user code.
    629 //===----------------------------------------------------------------------===//
    630 
    631 /// putchard - putchar that takes a double and returns 0.
    632 extern "C" double putchard(double X) {
    633   fputc((char)X, stderr);
    634   return 0;
    635 }
    636 
    637 /// printd - printf that takes a double prints it as "%f\n", returning 0.
    638 extern "C" double printd(double X) {
    639   fprintf(stderr, "%f\n", X);
    640   return 0;
    641 }
    642 
    643 //===----------------------------------------------------------------------===//
    644 // Main driver code.
    645 //===----------------------------------------------------------------------===//
    646 
    647 int main() {
    648   InitializeNativeTarget();
    649   InitializeNativeTargetAsmPrinter();
    650   InitializeNativeTargetAsmParser();
    651 
    652   // Install standard binary operators.
    653   // 1 is lowest precedence.
    654   BinopPrecedence['<'] = 10;
    655   BinopPrecedence['+'] = 20;
    656   BinopPrecedence['-'] = 20;
    657   BinopPrecedence['*'] = 40; // highest.
    658 
    659   // Prime the first token.
    660   fprintf(stderr, "ready> ");
    661   getNextToken();
    662 
    663   TheJIT = llvm::make_unique<KaleidoscopeJIT>();
    664 
    665   InitializeModuleAndPassManager();
    666 
    667   // Run the main "interpreter loop" now.
    668   MainLoop();
    669 
    670   return 0;
    671 }
    672